From 0ba5bc0e2aebbf959ab9fd60cd941fb1b0b9d870 Mon Sep 17 00:00:00 2001 From: Sosthene00 <674694@protonmail.ch> Date: Tue, 26 Mar 2024 17:18:34 +0100 Subject: [PATCH] aesgcm refactoring + unit tests --- crates/sp_client/src/aesgcm.rs | 303 ++++++++++++++++++++++++++++----- 1 file changed, 265 insertions(+), 38 deletions(-) diff --git a/crates/sp_client/src/aesgcm.rs b/crates/sp_client/src/aesgcm.rs index e0e2b13..0539f7e 100644 --- a/crates/sp_client/src/aesgcm.rs +++ b/crates/sp_client/src/aesgcm.rs @@ -29,40 +29,57 @@ use rand::{thread_rng, RngCore}; const HALFKEYSIZE: usize = SECRET_KEY_SIZE / 2; -pub type HalfKey = [u8; HALFKEYSIZE]; +pub struct HalfKey([u8; HALFKEYSIZE]); -pub enum EncryptionTarget { - Login(HalfKey), +impl TryFrom> for HalfKey { + type Error = anyhow::Error; + fn try_from(value: Vec) -> std::prelude::v1::Result { + if value.len() == HALFKEYSIZE { + let mut buf = [0u8; HALFKEYSIZE]; + buf.copy_from_slice(&value); + Ok(HalfKey(buf)) + } else { + Err(Error::msg("Invalid length for HalfKey")) + } + } } -pub enum DecryptionTarget { +impl HalfKey { + pub fn as_slice(&self) -> &[u8] { + &self.0 + } + + pub fn to_inner(&self) -> Vec { + self.0.to_vec() + } +} + +pub enum Purpose { Login, } -pub enum PlainText { - Login(HalfKey), -} - pub type CipherText = Vec; +pub type EncryptedKey = Vec; + pub struct Aes256Decryption { - target: DecryptionTarget, + pub purpose: Purpose, + cipher_text: CipherText, aes_key: [u8; 32], nonce: [u8; 12], - cipher_text: CipherText, } impl Aes256Decryption { pub fn new( - target: DecryptionTarget, + purpose: Purpose, + cipher_text: CipherText, encrypted_aes_key: Vec, shared_secret: SharedSecret, - cipher_text: CipherText, ) -> Result { if encrypted_aes_key.len() <= 12 { return Err(Error::msg("encrypted_aes_key is shorter than nonce length")); - } - // take the first 12 bytes form encrypted_aes_key as nonce + } // Actually we could probably test that if the remnant is not a multiple of 32, something's wrong + // take the first 12 bytes form encrypted_aes_key as nonce let (decrypt_key_nonce, encrypted_key) = encrypted_aes_key.split_at(12); // decrypt key with shared_secret obtained from transaction let decrypt_key_cipher = Aes256Gcm::new_from_slice(shared_secret.as_ref()) @@ -76,26 +93,29 @@ impl Aes256Decryption { let mut aes_key = [0u8; 32]; aes_key.copy_from_slice(&aes_key_plain); if cipher_text.len() <= 12 { - return Err(Error::msg("cipher_text is shorter than nonce lenght")); + return Err(Error::msg("cipher_text is shorter than nonce length")); } let (message_nonce, message_cipher) = cipher_text.split_at(12); let mut nonce = [0u8; 12]; nonce.copy_from_slice(message_nonce); Ok(Self { - target, + purpose, + cipher_text: message_cipher.to_vec(), aes_key, nonce, - cipher_text: message_cipher.to_vec(), }) } - pub fn decrypt_with_key(&self) -> Result { - match self.target { - DecryptionTarget::Login => self.decrypt_login() + pub fn decrypt_with_key(&self) -> Result<Vec<u8>> { + match self.purpose { + Purpose::Login => { + let half_key = self.decrypt_login()?; + Ok(half_key.to_inner()) + } } } - fn decrypt_login(&self) -> Result<PlainText> { + fn decrypt_login(&self) -> Result<HalfKey> { let cipher = Aes256Gcm::new(&self.aes_key.into()); let plain = cipher .decrypt(&self.nonce.into(), &*self.cipher_text) @@ -105,44 +125,70 @@ impl Aes256Decryption { } let mut key_half = [0u8; SECRET_KEY_SIZE / 2]; key_half.copy_from_slice(&plain); - Ok(PlainText::Login(key_half)) + Ok(HalfKey(key_half)) } } pub struct Aes256Encryption { - pub target: EncryptionTarget, + pub purpose: Purpose, + plaintext: Vec<u8>, aes_key: [u8; 32], nonce: [u8; 12], shared_secrets: HashMap<Txid, HashMap<SilentPaymentAddress, SharedSecret>>, } impl Aes256Encryption { - pub fn new(target: EncryptionTarget) -> Result<Self> { + pub fn new(purpose: Purpose, plaintext: Vec<u8>) -> Result<Self> { let mut rng = thread_rng(); let aes_key: [u8; 32] = Aes256Gcm::generate_key(&mut rng).into(); let nonce: [u8; 12] = Aes256Gcm::generate_nonce(&mut rng).into(); - Ok(Self { - target, - aes_key, - nonce, - shared_secrets: HashMap::new(), - }) + Self::import_key(purpose, plaintext, aes_key, nonce) } pub fn set_shared_secret( &mut self, shared_secrets: HashMap<Txid, HashMap<SilentPaymentAddress, SharedSecret>>, - ) -> Result<()> { - unimplemented!(); + ) { + self.shared_secrets = shared_secrets; + } + + pub fn encrypt_keys_with_shared_secrets( + &self, + ) -> Result<HashMap<SilentPaymentAddress, EncryptedKey>> { + let mut res = HashMap::new(); + let mut rng = thread_rng(); + + for (_, sp_address2shared_secret) in self.shared_secrets.iter() { + for (sp_address, shared_secret) in sp_address2shared_secret { + let cipher = Aes256Gcm::new_from_slice(shared_secret.as_ref()) + .map_err(|e| Error::msg(format!("{}", e)))?; + let nonce = Aes256Gcm::generate_nonce(&mut rng); + let encrypted_key = cipher + .encrypt(&nonce, self.aes_key.as_slice()) + .map_err(|e| Error::msg(format!("{}", e)))?; + + let mut ciphertext = Vec::<u8>::with_capacity(nonce.len() + encrypted_key.len()); + ciphertext.extend(nonce); + ciphertext.extend(encrypted_key); + + res.insert(sp_address.to_owned(), ciphertext); + } + } + Ok(res) } pub fn import_key( - target: EncryptionTarget, + purpose: Purpose, + plaintext: Vec<u8>, aes_key: [u8; 32], nonce: [u8; 12], ) -> Result<Self> { + if plaintext.len() == 0 { + return Err(Error::msg("Can't create encryption for an empty message")); + } Ok(Self { - target, + purpose, + plaintext, aes_key, nonce, shared_secrets: HashMap::new(), @@ -150,15 +196,16 @@ impl Aes256Encryption { } pub fn encrypt_with_aes_key(&self) -> Result<CipherText> { - match self.target { - EncryptionTarget::Login(half_key) => self.encrypt_login(half_key), + match self.purpose { + Purpose::Login => self.encrypt_login(), } } - fn encrypt_login(&self, plaintext: HalfKey) -> Result<CipherText> { + fn encrypt_login(&self) -> Result<CipherText> { + let half_key: HalfKey = self.plaintext.clone().try_into()?; let cipher = Aes256Gcm::new(&self.aes_key.into()); let cipher_text = cipher - .encrypt(&self.nonce.into(), plaintext.as_slice()) + .encrypt(&self.nonce.into(), half_key.as_slice()) .map_err(|e| Error::msg(format!("{}", e)))?; let mut res = Vec::with_capacity(self.nonce.len() + cipher_text.len()); res.extend_from_slice(&self.nonce); @@ -166,3 +213,183 @@ impl Aes256Encryption { Ok(res) } } + +#[cfg(test)] +mod tests { + use std::str::FromStr; + + use super::*; + + const ALICE_SP_ADDRESS: &str = "tsp1qqw3lqr6xravz9nf8ntazgwwl0fqv47kfjdxsnxs6eutavqfwyv5q6qk97mmyf6dtkdyzqlu2zv6h9j2ggclk7vn705q5u2phglpq7yw3dg5rwpdz"; + const BOB_SP_ADDRESS: &str = "tsp1qq2hlsgrj0gz8kcfkf9flqw5llz0u2vr04telqndku9mcqm6dl4fhvq60t8r78srrf56w9yr7w9e9dusc2wjqc30up6fjwnh9mw3e3veqegdmtf08"; + const TRANSACTION: &str = "4e6d03dec558e1b6624f813bf2da7cd8d8fb1c2296684c08cf38724dcfd8d10b"; + const ALICE_SHARED_SECRET: &str = "ccf02d364c2641ca129a3fdf49de57b705896e233f7ba6d738991993ea7e2106"; + const BOB_SHARED_SECRET: &str = "15ef3e377fb842e81de52dbaaea8ba30aeb051a81043ee19264afd27353da521"; + + #[test] + fn new_aes_empty_plaintext() { + let plaintext = Vec::new(); + let aes_enc = Aes256Encryption::new(Purpose::Login, plaintext); + + assert!(aes_enc.is_err()); + } + + #[test] + fn aes_encrypt_login_invalid_length() { + let plaintext = "example"; + let aes_enc_short = Aes256Encryption::new(Purpose::Login, plaintext.as_bytes().to_vec()); + + assert!(aes_enc_short.is_ok()); + + let cipher = aes_enc_short.unwrap().encrypt_with_aes_key(); + + assert!(cipher.is_err()); + + let plaintext = [1u8; 64]; + let aes_enc_long = Aes256Encryption::new(Purpose::Login, plaintext.to_vec()); + + assert!(aes_enc_long.is_ok()); + + let cipher = aes_enc_long.unwrap().encrypt_with_aes_key(); + + assert!(cipher.is_err()); + } + + #[test] + fn aes_encrypt_login() { + let plaintext = [1u8; HALFKEYSIZE]; + let aes_enc = Aes256Encryption::new(Purpose::Login, plaintext.to_vec()); + + assert!(aes_enc.is_ok()); + + let cipher = aes_enc.unwrap().encrypt_with_aes_key(); + + assert!(cipher.is_ok()); + } + + #[test] + fn aes_encrypt_key() { + let plaintext = [1u8; HALFKEYSIZE]; + let mut aes_enc = Aes256Encryption::new(Purpose::Login, plaintext.to_vec()).unwrap(); + + let mut shared_secrets: HashMap<Txid, _> = HashMap::new(); + let mut sp_address2shared_secrets: HashMap<SilentPaymentAddress, SharedSecret> = + HashMap::new(); + sp_address2shared_secrets.insert( + ALICE_SP_ADDRESS.try_into().unwrap(), + SharedSecret::from_str(ALICE_SHARED_SECRET).unwrap(), + ); + shared_secrets.insert( + Txid::from_str(TRANSACTION).unwrap(), + sp_address2shared_secrets, + ); + + aes_enc.set_shared_secret(shared_secrets); + + let sp_address2encrypted_keys = aes_enc.encrypt_keys_with_shared_secrets(); + + assert!(sp_address2encrypted_keys.is_ok()); + + let encrypted_key = sp_address2encrypted_keys + .unwrap() + .get(&ALICE_SP_ADDRESS.try_into().unwrap()) + .cloned(); + + let ciphertext = aes_enc.encrypt_with_aes_key(); + + assert!(ciphertext.is_ok()); + + let aes_dec = Aes256Decryption::new( + Purpose::Login, + ciphertext.unwrap(), + encrypted_key.unwrap(), + SharedSecret::from_str(ALICE_SHARED_SECRET).unwrap(), + ); + + assert!(aes_dec.is_ok()); + + let retrieved_plain = aes_dec.unwrap().decrypt_with_key(); + + assert!(retrieved_plain.is_ok()); + + assert!(retrieved_plain.unwrap() == plaintext); + } + + #[test] + fn aes_encrypt_key_many() { + let plaintext = [1u8; HALFKEYSIZE]; + let mut aes_enc = Aes256Encryption::new(Purpose::Login, plaintext.to_vec()).unwrap(); + + let mut shared_secrets: HashMap<Txid, _> = HashMap::new(); + let mut sp_address2shared_secrets: HashMap<SilentPaymentAddress, SharedSecret> = + HashMap::new(); + sp_address2shared_secrets.insert( + ALICE_SP_ADDRESS.try_into().unwrap(), + SharedSecret::from_str(ALICE_SHARED_SECRET).unwrap(), + ); + sp_address2shared_secrets.insert( + BOB_SP_ADDRESS.try_into().unwrap(), + SharedSecret::from_str(BOB_SHARED_SECRET).unwrap(), + ); + shared_secrets.insert( + Txid::from_str(TRANSACTION).unwrap(), + sp_address2shared_secrets, + ); + + aes_enc.set_shared_secret(shared_secrets); + + let mut sp_address2encrypted_keys = aes_enc.encrypt_keys_with_shared_secrets(); + + assert!(sp_address2encrypted_keys.is_ok()); + + // Alice + let encrypted_key = sp_address2encrypted_keys.as_mut() + .unwrap() + .get(&ALICE_SP_ADDRESS.try_into().unwrap()) + .cloned(); + + let ciphertext = aes_enc.encrypt_with_aes_key(); + + assert!(ciphertext.is_ok()); + + let aes_dec = Aes256Decryption::new( + Purpose::Login, + ciphertext.unwrap(), + encrypted_key.unwrap(), + SharedSecret::from_str(ALICE_SHARED_SECRET).unwrap(), + ); + + assert!(aes_dec.is_ok()); + + let retrieved_plain = aes_dec.unwrap().decrypt_with_key(); + + assert!(retrieved_plain.is_ok()); + + assert!(retrieved_plain.unwrap() == plaintext); + + // Bob + let encrypted_key = sp_address2encrypted_keys + .unwrap() + .get(&BOB_SP_ADDRESS.try_into().unwrap()) + .cloned(); + + let ciphertext = aes_enc.encrypt_with_aes_key(); + + assert!(ciphertext.is_ok()); + + let aes_dec = Aes256Decryption::new( + Purpose::Login, + ciphertext.unwrap(), + encrypted_key.unwrap(), + SharedSecret::from_str(BOB_SHARED_SECRET).unwrap(), + ); + + assert!(aes_dec.is_ok()); + + let retrieved_plain = aes_dec.unwrap().decrypt_with_key(); + + assert!(retrieved_plain.is_ok()); + + assert!(retrieved_plain.unwrap() == plaintext); + } +}