diff --git a/crates/sp_client/src/aesgcm.rs b/crates/sp_client/src/aesgcm.rs index 0539f7e..b3d55c7 100644 --- a/crates/sp_client/src/aesgcm.rs +++ b/crates/sp_client/src/aesgcm.rs @@ -73,25 +73,32 @@ impl Aes256Decryption { pub fn new( purpose: Purpose, cipher_text: CipherText, - encrypted_aes_key: Vec, - shared_secret: SharedSecret, + encrypted_aes_key: Vec, // If shared_secret is none this is actually the aes_key + shared_secret: Option, // We don't need that for certain purpose, like Login ) -> Result { - if encrypted_aes_key.len() <= 12 { - return Err(Error::msg("encrypted_aes_key is shorter than nonce length")); - } // Actually we could probably test that if the remnant is not a multiple of 32, something's wrong - // take the first 12 bytes form encrypted_aes_key as nonce - let (decrypt_key_nonce, encrypted_key) = encrypted_aes_key.split_at(12); - // decrypt key with shared_secret obtained from transaction - let decrypt_key_cipher = Aes256Gcm::new_from_slice(shared_secret.as_ref()) - .map_err(|e| Error::msg(format!("{}", e)))?; - let aes_key_plain = decrypt_key_cipher - .decrypt(decrypt_key_nonce.into(), encrypted_key) - .map_err(|e| Error::msg(format!("{}", e)))?; - if aes_key_plain.len() != 32 { - return Err(Error::msg("Invalid length for decrypted key")); - } let mut aes_key = [0u8; 32]; - aes_key.copy_from_slice(&aes_key_plain); + if let Some(shared_secret) = shared_secret { + if encrypted_aes_key.len() <= 12 { + return Err(Error::msg("encrypted_aes_key is shorter than nonce length")); + } // Actually we could probably test that if the remnant is not a multiple of 32, something's wrong + // take the first 12 bytes form encrypted_aes_key as nonce + let (decrypt_key_nonce, encrypted_key) = encrypted_aes_key.split_at(12); + // decrypt key with shared_secret obtained from transaction + let decrypt_key_cipher = Aes256Gcm::new_from_slice(shared_secret.as_ref()) + .map_err(|e| Error::msg(format!("{}", e)))?; + let aes_key_plain = decrypt_key_cipher + .decrypt(decrypt_key_nonce.into(), encrypted_key) + .map_err(|e| Error::msg(format!("{}", e)))?; + if aes_key_plain.len() != 32 { + return Err(Error::msg("Invalid length for decrypted key")); + } + aes_key.copy_from_slice(&aes_key_plain); + } else { + if encrypted_aes_key.len() != 32 { + return Err(Error::msg("Invalid length for decrypted key")); + } + aes_key.copy_from_slice(&encrypted_aes_key); + } if cipher_text.len() <= 12 { return Err(Error::msg("cipher_text is shorter than nonce length")); } @@ -258,13 +265,22 @@ mod tests { #[test] fn aes_encrypt_login() { let plaintext = [1u8; HALFKEYSIZE]; - let aes_enc = Aes256Encryption::new(Purpose::Login, plaintext.to_vec()); + let aes_key = Aes256Gcm::generate_key(&mut thread_rng()); + let nonce = Aes256Gcm::generate_nonce(&mut thread_rng()); + let aes_enc = Aes256Encryption::import_key(Purpose::Login, plaintext.to_vec(), aes_key.into(), nonce.into()); assert!(aes_enc.is_ok()); let cipher = aes_enc.unwrap().encrypt_with_aes_key(); assert!(cipher.is_ok()); + + let mut plain_key = [0u8;32]; + plain_key.copy_from_slice(&aes_key.to_vec()); + + let aes_dec = Aes256Decryption::new(Purpose::Login, cipher.unwrap(), plain_key.to_vec(), None); + + assert!(aes_dec.is_ok()); } #[test] @@ -303,7 +319,7 @@ mod tests { Purpose::Login, ciphertext.unwrap(), encrypted_key.unwrap(), - SharedSecret::from_str(ALICE_SHARED_SECRET).unwrap(), + Some(SharedSecret::from_str(ALICE_SHARED_SECRET).unwrap()), ); assert!(aes_dec.is_ok()); @@ -356,7 +372,7 @@ mod tests { Purpose::Login, ciphertext.unwrap(), encrypted_key.unwrap(), - SharedSecret::from_str(ALICE_SHARED_SECRET).unwrap(), + Some(SharedSecret::from_str(ALICE_SHARED_SECRET).unwrap()), ); assert!(aes_dec.is_ok()); @@ -381,7 +397,7 @@ mod tests { Purpose::Login, ciphertext.unwrap(), encrypted_key.unwrap(), - SharedSecret::from_str(BOB_SHARED_SECRET).unwrap(), + Some(SharedSecret::from_str(BOB_SHARED_SECRET).unwrap()), ); assert!(aes_dec.is_ok());