Broaded scope for the mutex when parsing ciphers

This commit is contained in:
Sosthene 2024-10-13 01:11:52 +02:00
parent d121e8f7b2
commit 43b4c5f4d2

View File

@ -850,8 +850,9 @@ fn send_data(prd: &Prd, shared_secret: &AnkSharedSecretHash) -> AnyhowResult<Api
fn decrypt_with_cached_messages( fn decrypt_with_cached_messages(
cipher: &[u8], cipher: &[u8],
messages: &mut MutexGuard<Vec<CachedMessage>>
) -> anyhow::Result<Option<(Vec<u8>, AnkSharedSecretHash)>> { ) -> anyhow::Result<Option<(Vec<u8>, AnkSharedSecretHash)>> {
let mut messages = lock_messages()?; debug!("cached messages: {:#?}", messages);
let nonce = Nonce::from_slice(&cipher[..12]); let nonce = Nonce::from_slice(&cipher[..12]);
@ -904,11 +905,7 @@ fn decrypt_with_cached_messages(
Ok(None) Ok(None)
} }
fn decrypt_with_known_processes(cipher: &[u8]) -> anyhow::Result<Option<(Vec<u8>, OutPoint)>> { fn decrypt_with_known_processes(cipher: &[u8], processes: MutexGuard<HashMap<OutPoint, Process>>) -> anyhow::Result<Option<(Vec<u8>, OutPoint)>> {
let processes = lock_processes()?;
debug!("Known processes: {:#?}", processes);
let nonce = Nonce::from_slice(&cipher[..12]); let nonce = Nonce::from_slice(&cipher[..12]);
for (outpoint, process) in processes.iter() { for (outpoint, process) in processes.iter() {
@ -1092,6 +1089,10 @@ fn handle_decrypted_message(
#[wasm_bindgen] #[wasm_bindgen]
pub fn parse_cipher(cipher_msg: String) -> ApiResult<ApiReturn> { pub fn parse_cipher(cipher_msg: String) -> ApiResult<ApiReturn> {
// We lock message cache and processes to prevent race conditions
let mut messages = lock_messages()?;
let processes = lock_processes()?;
// Check that the cipher is not empty or too long // Check that the cipher is not empty or too long
if cipher_msg.is_empty() || cipher_msg.len() > MAX_PRD_PAYLOAD_SIZE { if cipher_msg.is_empty() || cipher_msg.len() > MAX_PRD_PAYLOAD_SIZE {
return Err(ApiError::new( return Err(ApiError::new(
@ -1102,13 +1103,13 @@ pub fn parse_cipher(cipher_msg: String) -> ApiResult<ApiReturn> {
let cipher = Vec::from_hex(cipher_msg.trim_matches('"'))?; let cipher = Vec::from_hex(cipher_msg.trim_matches('"'))?;
// Try decrypting with cached messages first // Try decrypting with cached messages first
if let Ok(Some((plain, shared_secret))) = decrypt_with_cached_messages(&cipher) { if let Ok(Some((plain, shared_secret))) = decrypt_with_cached_messages(&cipher, &mut messages) {
return handle_decrypted_message(plain, Some(shared_secret), None) return handle_decrypted_message(plain, Some(shared_secret), None)
.map_err(|e| ApiError::new(format!("Failed to handle decrypted message: {}", e))); .map_err(|e| ApiError::new(format!("Failed to handle decrypted message: {}", e)));
} }
// If that fails, try decrypting with known processes // If that fails, try decrypting with known processes
if let Ok(Some((plain, root_commitment))) = decrypt_with_known_processes(&cipher) { if let Ok(Some((plain, root_commitment))) = decrypt_with_known_processes(&cipher, processes) {
return handle_decrypted_message(plain, None, Some(root_commitment)) return handle_decrypted_message(plain, None, Some(root_commitment))
.map_err(|e| ApiError::new(format!("Failed to handle decrypted message: {}", e))); .map_err(|e| ApiError::new(format!("Failed to handle decrypted message: {}", e)));
} }
@ -1118,8 +1119,7 @@ pub fn parse_cipher(cipher_msg: String) -> ApiResult<ApiReturn> {
return_msg.cipher = vec![cipher_msg]; return_msg.cipher = vec![cipher_msg];
return_msg.status = CachedMessageStatus::CipherWaitingTx; return_msg.status = CachedMessageStatus::CipherWaitingTx;
let mut messages_cache = lock_messages()?; messages.push(return_msg.clone());
messages_cache.push(return_msg.clone());
Ok(ApiReturn { Ok(ApiReturn {
updated_cached_msg: vec![return_msg], updated_cached_msg: vec![return_msg],