diff --git a/src/api.rs b/src/api.rs index 90fac8c..de9c0a0 100644 --- a/src/api.rs +++ b/src/api.rs @@ -104,6 +104,7 @@ pub struct UserDiff { pub notify_user: bool, pub need_validation: bool, pub validation_status: DiffStatus, + pub storages: Vec, } #[derive(Debug, PartialEq, Tsify, Serialize, Deserialize, Default)] @@ -705,17 +706,34 @@ fn create_diffs(device: &MutexGuard, process: &Process, new_state: &Proc let our_id = device.get_pairing_commitment(); - let fields_to_validate = if let Some(our_id) = our_id { - new_state.get_fields_to_validate_for_member(&our_id)? + let fields_to_validate: HashMap> = if let Some(our_id) = our_id { + let mut relevant_fields = HashMap::new(); + for (name, role) in new_state.roles.iter() { + if !role.members.contains(&our_id) { + // This role doesn't concern requester + continue; + } + let fields: Vec = role + .validation_rules + .iter() + .flat_map(|rule| rule.fields.clone()) + .collect(); + relevant_fields.extend(fields.into_iter().map(|field| (field, role.storages.clone()))); + } + relevant_fields } else { // Device is unpaired, we just take all the fields in the `pairing` role if let Some(pairing_role) = new_state.roles.get("pairing") { - pairing_role.validation_rules.iter().flat_map(|r| r.fields.clone()).collect() + let mut relevant_fields = HashMap::new(); + let fields: Vec = pairing_role.validation_rules.iter().flat_map(|r| r.fields.clone()).collect(); + relevant_fields.extend(fields.into_iter().map(|field| (field, pairing_role.storages.clone()))); + relevant_fields } else { return Err(AnyhowError::msg("Missing pairing role")) } }; + let new_state_id = &new_state.state_id; let new_public_data = &new_state.public_data; @@ -723,7 +741,7 @@ fn create_diffs(device: &MutexGuard, process: &Process, new_state: &Proc let process_id = process.get_process_id()?.to_string(); let mut diffs = vec![]; for (field, hash) in new_state_commitments.iter() { - let need_validation = fields_to_validate.contains(field); + let need_validation = fields_to_validate.contains_key(field); diffs.push(UserDiff { process_id: process_id.clone(), state_id: new_state_id.to_lower_hex_string(), @@ -734,6 +752,7 @@ fn create_diffs(device: &MutexGuard, process: &Process, new_state: &Proc need_validation, validation_status: DiffStatus::None, roles: new_state.roles.clone(), + storages: fields_to_validate.get(field).unwrap_or(&vec![]).clone(), }); } @@ -944,7 +963,7 @@ fn handle_prd( let sp_wallet = local_device.get_sp_client(); let local_address = sp_wallet.get_receiving_address().to_string(); - let mut relevant_fields: HashSet = HashSet::new(); + let mut relevant_fields: HashMap> = HashMap::new(); let shared_secrets = lock_shared_secrets()?; for (name, role) in state.roles.iter() { if !role.members.contains(&requester) { @@ -956,7 +975,7 @@ fn handle_prd( .iter() .flat_map(|rule| rule.fields.clone()) .collect(); - relevant_fields.extend(fields); + relevant_fields.extend(fields.into_iter().map(|field| (field, role.storages.clone()))); } let sender = local_device.get_pairing_commitment().ok_or(AnyhowError::msg("Device not paired"))?; @@ -970,7 +989,7 @@ fn handle_prd( state.pcd_commitment.clone(), ); - full_prd.filter_keys(&relevant_fields); + full_prd.filter_keys(&relevant_fields.iter().map(|(field, _)| field.clone()).collect()); let prd_msg = full_prd.to_network_msg(sp_wallet)?; let addresses = members_list.0.get(&sender).ok_or(AnyhowError::msg("Unknown requester"))?.get_addresses(); @@ -991,7 +1010,7 @@ fn handle_prd( let pcd_commitment = &state.pcd_commitment; for (field, hash) in pcd_commitment.iter() { // We only need field that are visible by requester - if !relevant_fields.contains(field.as_str()) { + if !relevant_fields.contains_key(field.as_str()) { continue; } let diff = UserDiff { @@ -999,6 +1018,7 @@ fn handle_prd( state_id: state_id.to_lower_hex_string(), value_commitment: hash.to_lower_hex_string(), field: field.to_owned(), + storages: relevant_fields.get(field.as_str()).unwrap().clone(), ..Default::default() }; diffs.push(diff); @@ -1038,6 +1058,7 @@ fn handle_decrypted_message( } } +// TODO make separate functions for the decryption itself and the parsing to avoid copy and serialization cost #[wasm_bindgen] pub fn parse_cipher(cipher_msg: String, members_list: OutPointMemberMap, processes: OutPointProcessMap) -> ApiResult { // Check that the cipher is not empty or too long