// sdk_common/src/pcd.rs
//
// 332 lines
// 11 KiB
// Rust

use std::{collections::HashSet, str::FromStr};
use anyhow::{Result, Error};
use aes_gcm::{aead::{Aead, Payload}, AeadCore, Aes256Gcm, KeyInit};
use log::debug;
use rand::thread_rng;
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use sp_client::{bitcoin::{hashes::{sha256t_hash_newtype, Hash, HashEngine}, hex::{DisplayHex, FromHex}, XOnlyPublicKey}, silentpayments::utils::SilentPaymentAddress};
use tsify::Tsify;
use crate::{crypto::AAD, signature::{AnkValidationNoHash, AnkValidationYesHash, Proof}};
/// A participant identified by one or more silent payment addresses.
///
/// The addresses are counted as the member's "devices" when checking
/// per-member signature quotas (see `ValidationRule::satisfy_min_sig_member`).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Tsify)]
#[tsify(into_wasm_abi, from_wasm_abi)]
pub struct Member {
    /// String-encoded silent payment addresses. `Member::new` checks that the
    /// list is non-empty and duplicate-free; deserialization does not.
    sp_addresses: Vec<String>
}
impl Member {
    /// Builds a member from its silent payment addresses.
    ///
    /// The addresses are stored in their string encoding.
    ///
    /// # Errors
    /// Returns an error if `sp_addresses` is empty or contains the same address twice.
    pub fn new(sp_addresses: Vec<SilentPaymentAddress>) -> Result<Self> {
        if sp_addresses.is_empty() {
            return Err(Error::msg("empty address set"));
        }
        let mut seen = HashSet::with_capacity(sp_addresses.len());
        for s in sp_addresses.iter() {
            if !seen.insert(s.clone()) {
                return Err(Error::msg("Duplicate addresses found"));
            }
        }
        let res: Vec<String> = sp_addresses
            .iter()
            .map(|a| Into::<String>::into(*a))
            .collect();
        Ok(Self { sp_addresses: res })
    }

    /// Returns a copy of the member's addresses in their string encoding.
    pub fn get_addresses(&self) -> Vec<String> {
        self.sp_addresses.clone()
    }

    /// Returns `true` if `key` is the x-only spend key of one of this member's addresses.
    ///
    /// Addresses that fail to parse are treated as non-matching instead of
    /// panicking: `Member` is `Deserialize`, so its strings are not guaranteed
    /// to have been validated by [`Member::new`].
    pub fn key_is_part_of_member(&self, key: &XOnlyPublicKey) -> bool {
        self.sp_addresses.iter().any(|a| {
            SilentPaymentAddress::try_from(a.as_str())
                .map(|addr| addr.get_spend_key().x_only_public_key().0 == *key)
                .unwrap_or(false)
        })
    }
}
// Tagged SHA-256 hash type used to commit to PCD contents.
// `AnkPcdTag` fixes the tag string "4nk/Pcd"; `AnkPcdHash` is the resulting hash newtype.
// `#[hash_newtype(forward)]` presumably selects forward (non-reversed) byte order for
// display — confirm against the bitcoin_hashes version in use.
sha256t_hash_newtype! {
pub struct AnkPcdTag = hash_str("4nk/Pcd");
#[hash_newtype(forward)]
pub struct AnkPcdHash(_);
}
impl AnkPcdHash {
    /// Tagged hash of the compact JSON text of `value` (as produced by `Value::to_string`).
    pub fn from_value(value: &Value) -> Self {
        let serialized = value.to_string();
        let mut engine = Self::engine();
        engine.input(serialized.as_bytes());
        Self::from_engine(engine)
    }

    /// Tagged hash of `map`, serialized as a JSON object; identical to
    /// `from_value(&Value::Object(map.clone()))`.
    pub fn from_map(map: &Map<String, Value>) -> Self {
        Self::from_value(&Value::Object(map.clone()))
    }
}
pub trait Pcd<'a>: Serialize + Deserialize<'a> {
fn tagged_hash(&self) -> AnkPcdHash {
AnkPcdHash::from_value(&self.to_value())
}
fn encrypt_fields(&self, fields2keys: &mut Map<String, Value>, fields2cipher: &mut Map<String, Value>) -> Result<()> {
let as_value = self.to_value();
let as_map = as_value.as_object().ok_or_else(|| Error::msg("Expected object"))?;
let mut rng = thread_rng();
for (field, value) in as_map {
let aes_key = Aes256Gcm::generate_key(&mut rng);
let nonce = Aes256Gcm::generate_nonce(&mut rng);
fields2keys.insert(field.to_owned(), Value::String(aes_key.to_lower_hex_string()));
let encrypt_eng = Aes256Gcm::new(&aes_key);
let value_string = value.to_string();
let payload = Payload {
msg: value_string.as_bytes(),
aad: AAD,
};
let cipher = encrypt_eng.encrypt(&nonce, payload)
.map_err(|e| Error::msg(format!("Encryption failed for field {}: {}", field, e)))?;
let mut res = Vec::with_capacity(nonce.len() + cipher.len());
res.extend_from_slice(&nonce);
res.extend_from_slice(&cipher);
fields2cipher.insert(field.to_owned(), Value::String(res.to_lower_hex_string()));
}
Ok(())
}
fn decrypt_fields(&self, fields2keys: &Map<String, Value>, fields2plain: &mut Map<String, Value>) -> Result<()> {
let value = self.to_value();
let map = value.as_object().unwrap();
for (field, encrypted_value) in map.iter() {
if let Some(aes_key) = fields2keys.get(field) {
let key_buf = Vec::from_hex(&aes_key.to_string().trim_matches('\"'))?;
let decrypt_eng = Aes256Gcm::new(key_buf.as_slice().into());
let raw_cipher = Vec::from_hex(&encrypted_value.as_str().ok_or_else(|| Error::msg("Expected string"))?.trim_matches('\"'))?;
if raw_cipher.len() < 28 {
return Err(Error::msg(format!("Invalid ciphertext length for field {}", field)));
}
let payload = Payload {
msg: &raw_cipher[12..],
aad: AAD,
};
let plain = decrypt_eng.decrypt(raw_cipher[..12].into(), payload)
.map_err(|_| Error::msg(format!("Failed to decrypt field {}", field)))?;
let decrypted_value: String = String::from_utf8(plain)?;
fields2plain.insert(field.to_owned(), Value::String(decrypted_value));
} else {
fields2plain.insert(field.to_owned(), Value::Null);
}
}
Ok(())
}
fn to_value(&self) -> Value {
Value::from_str(&serde_json::to_string(&self).unwrap()).unwrap()
}
}
// Any raw JSON document can be hashed and field-encrypted with the default `Pcd` methods.
impl<'de> Pcd<'de> for Value {}
/// Approval requirement for changes to a set of fields.
///
/// A change to one of `fields` is accepted when at least
/// `ceil(members.len() * quorum)` members each provide enough valid "yes"
/// proofs (see `ValidationRule::is_satisfied`).
#[derive(Debug, Clone, Serialize, Deserialize, Tsify)]
#[tsify(into_wasm_abi, from_wasm_abi)]
pub struct ValidationRule {
    /// Fraction of members that must approve; must be in `[0.0, 1.0]`.
    /// `0.0` requires no approval at all; the original note says this means
    /// "reading right" (presumably read-only access — confirm).
    quorum: f32,
    /// Names of the fields this rule governs.
    pub fields: Vec<String>,
    /// Fraction of a member's devices (addresses) that must sign "yes"; must be
    /// in `[0.0, 1.0]`. `1.0` requires every device of the member to sign.
    min_sig_member: f32,
}
impl ValidationRule {
pub fn new(quorum: f32, fields: Vec<String>, min_sig_member: f32) -> Result<Self> {
if quorum < 0.0 || quorum > 1.0 {
return Err(Error::msg("quorum must be 0.0 < quorum <= 1.0"));
}
if min_sig_member < 0.0 || min_sig_member > 1.0 {
return Err(Error::msg("min_signatures_member must be 0.0 < min_signatures_member <= 1.0"));
}
if fields.is_empty() {
return Err(Error::msg("Fields can't be empty"));
}
let res = Self {
quorum,
fields,
min_sig_member,
};
Ok(res)
}
pub fn is_satisfied(&self, field: &str, new_state_hash: AnkPcdHash, proofs: &[&Proof], members: &[Member]) -> bool {
// Check if this rule applies to the field
if !self.fields.contains(&field.to_string()) {
return false;
}
let required_members = (members.len() as f32 * self.quorum).ceil() as usize;
let validating_members = members.iter()
.filter(|member| {
let member_proofs: Vec<&Proof> = proofs.iter()
.filter(|p| member.key_is_part_of_member(&p.get_key()))
.cloned()
.collect();
self.satisfy_min_sig_member(member, new_state_hash, &member_proofs).is_ok()
})
.count();
validating_members >= required_members
}
pub fn satisfy_min_sig_member(&self, member: &Member, new_state_hash: AnkPcdHash, proofs: &[&Proof]) -> Result<()> {
let required_sigs = (member.get_addresses().len() as f32 * self.min_sig_member).ceil() as usize;
if required_sigs > proofs.len() {
// We can't have more proofs than registered devices for one member
return Err(Error::msg("More proofs than devices for member"));
} else if proofs.len() < required_sigs {
// Even if all proof are valid yes, we don't reach the quota
return Err(Error::msg("Not enough provided proofs to reach quota"));
}
let mut yes_votes: Vec<Proof> = Vec::new();
let mut no_votes: Vec<Proof> = Vec::new();
// Compute both yes and no commitment
let yes = AnkValidationYesHash::from_commitment(new_state_hash).to_byte_array();
let no = AnkValidationNoHash::from_commitment(new_state_hash).to_byte_array();
// Validate proofs here
for proof in proofs {
if !proof.verify().is_ok() {
return Err(Error::msg("Invalid proof"));
}
let signed_message = proof.get_message();
if signed_message == yes {
yes_votes.push(**proof);
} else if signed_message == no {
no_votes.push(**proof);
} else {
return Err(Error::msg("We don't know what this proof signs for"));
}
}
if yes_votes.len() >= required_sigs {
Ok(())
} else {
Err(Error::msg("Not enough yes votes"))
}
}
}
/// A role: the members who hold it and the rules governing changes to its fields.
#[derive(Debug, Clone, Serialize, Deserialize, Tsify)]
#[tsify(into_wasm_abi, from_wasm_abi)]
pub struct RoleDefinition {
    /// Members whose proofs are counted when validating changes under this role.
    pub members: Vec<Member>,
    /// Rules; every modified field must satisfy at least one rule that lists it.
    pub validation_rules: Vec<ValidationRule>,
}
impl RoleDefinition {
    /// Returns `true` if every field changed between `previous_state` and `new_state`
    /// is validated by at least one of this role's rules.
    ///
    /// A field counts as modified when it is present in `new_state` and is either
    /// absent from `previous_state` or has a different value. If nothing was modified,
    /// the result is `true`. If either state is not a JSON object, the change cannot be
    /// checked field by field and the result is `false`.
    ///
    /// NOTE(review): fields present only in `previous_state` (i.e. removed) are not
    /// treated as modified and so need no approval — confirm this is intended.
    pub fn is_satisfied(&self, new_state: &Value, previous_state: &Value, proofs: &[&Proof]) -> bool {
        let (new_fields, previous_fields) = match (new_state.as_object(), previous_state.as_object()) {
            (Some(new_fields), Some(previous_fields)) => (new_fields, previous_fields),
            // Fail closed rather than panic on malformed states.
            _ => return false,
        };
        let new_state_hash = AnkPcdHash::from_value(new_state);
        new_fields
            .iter()
            .filter(|(key, value)| previous_fields.get(*key) != Some(*value))
            .all(|(field, _)| {
                self.validation_rules
                    .iter()
                    .any(|rule| rule.is_satisfied(field, new_state_hash, proofs, &self.members))
            })
    }

    /// Returns the rules that list `field`.
    pub fn get_applicable_rules(&self, field: &str) -> Vec<&ValidationRule> {
        self.validation_rules
            .iter()
            .filter(|rule| rule.fields.iter().any(|f| f == field))
            .collect()
    }
}
/// Returns `true` when `map1` and `map2` have exactly the same keys and, for
/// every key, values of the same JSON kind (checked recursively for nested
/// arrays and objects by `compare_values`).
///
/// Only the shape is compared: `{"a": 1}` and `{"a": 2}` are considered equal.
pub fn compare_maps(map1: &Map<String, Value>, map2: &Map<String, Value>) -> bool {
    // Equal sizes plus "every key of map1 exists in map2" means identical key sets,
    // independent of iteration order and without allocating key vectors.
    map1.len() == map2.len()
        && map1.iter().all(|(key, value1)| match map2.get(key) {
            Some(value2) => compare_values(value1, value2),
            None => false,
        })
}
fn compare_values(value1: &Value, value2: &Value) -> bool {
if value1.is_null() && value2.is_null() {
return true;
} else if value1.is_boolean() && value2.is_boolean() {
return true;
} else if value1.is_number() && value2.is_number() {
return true;
} else if value1.is_string() && value2.is_string() {
return true;
} else if value1.is_array() && value2.is_array() {
return compare_arrays(value1.as_array().unwrap(), value2.as_array().unwrap());
} else if value1.is_object() && value2.is_object() {
// Recursive comparison for nested objects
return compare_maps(value1.as_object().unwrap(), value2.as_object().unwrap());
} else {
return false;
}
}
/// Returns `true` when the elements of the two arrays have the same JSON kind pairwise.
///
/// NOTE(review): only the common prefix is compared (`zip` stops at the shorter
/// array), so arrays of different lengths match when that prefix matches.
/// Confirm this is intended before relying on it for strict shape checks.
fn compare_arrays(array1: &[Value], array2: &[Value]) -> bool {
    array1
        .iter()
        .zip(array2)
        .all(|(elem1, elem2)| compare_values(elem1, elem2))
}