diff --git a/src/main.rs b/src/main.rs index b43ab25..8967b1c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,13 +7,13 @@ use std::{ net::SocketAddr, path::PathBuf, str::FromStr, - sync::{Arc, Mutex, MutexGuard, OnceLock}, - time::{Duration, Instant}, + sync::{Arc, Mutex, MutexGuard}, }; use bitcoincore_rpc::json::{self as bitcoin_json}; use futures_util::{future, pin_mut, stream::TryStreamExt, FutureExt, StreamExt}; use log::{debug, error, warn}; +use message::{broadcast_message, process_message, BroadcastType, MessageCache, MESSAGECACHE}; use scan::compute_partial_tweak_to_transaction; use sdk_common::sp_client::bitcoin::{ consensus::deserialize, @@ -26,15 +26,12 @@ use sdk_common::sp_client::{ }; use sdk_common::{ error::AnkError, - network::{AnkFlag, AnkNetworkMsg, FaucetMessage, NewTxMessage}, + network::{AnkFlag, NewTxMessage}, }; use sdk_common::sp_client::spclient::{derive_keys_from_seed, SpClient, SpendKey}; +use tokio::net::{TcpListener, TcpStream}; use tokio::sync::mpsc::{unbounded_channel, UnboundedSender}; -use tokio::{ - net::{TcpListener, TcpStream}, - time, -}; use tokio_stream::wrappers::UnboundedReceiverStream; use tokio_tungstenite::tungstenite::Message; @@ -45,10 +42,10 @@ mod config; mod daemon; mod electrumclient; mod faucet; +mod message; mod scan; use crate::config::Config; -use crate::faucet::handle_faucet_request; use crate::{daemon::Daemon, scan::scan_blocks}; type Tx = UnboundedSender; @@ -112,66 +109,6 @@ impl WalletFile { } } -static MESSAGECACHE: OnceLock = OnceLock::new(); - -const MESSAGECACHEDURATION: Duration = Duration::from_secs(10); -const MESSAGECACHEINTERVAL: Duration = Duration::from_secs(2); - -#[derive(Debug)] -struct MessageCache { - store: Mutex>, -} - -impl MessageCache { - fn new() -> Self { - Self { - store: Mutex::new(HashMap::new()), - } - } - - fn insert(&self, key: String) { - let mut store = self.store.lock().unwrap(); - store.insert(key.clone(), Instant::now()); - } - - fn contains(&self, key: &str) -> bool { - let store = self.store.lock().unwrap(); - store.contains_key(key) - } - - async fn clean_up() { - let cache = MESSAGECACHE.get().unwrap(); - - let mut interval = time::interval(MESSAGECACHEINTERVAL); - - loop { - interval.tick().await; - - let mut store = cache.store.lock().unwrap(); - - let now = Instant::now(); - let to_rm: Vec = store - .iter() - .filter_map(|(entry, entrytime)| { - if let Some(duration) = now.checked_duration_since(*entrytime) { - if duration > MESSAGECACHEDURATION { - Some(entry.clone()) - } else { - None - } - } else { - None - } - }) - .collect(); - - for key in to_rm { - store.remove(&key); - } - } - } -} - const FAUCET_AMT: Amount = Amount::from_sat(100_000); pub(crate) trait MutexExt { @@ -202,60 +139,6 @@ impl SilentPaymentWallet { } } -enum BroadcastType { - Sender(SocketAddr), - #[allow(dead_code)] - ExcludeSender(SocketAddr), - #[allow(dead_code)] - ToAll, -} - -fn broadcast_message( - peers: PeerMap, - flag: AnkFlag, - payload: String, - broadcast: BroadcastType, -) -> Result<()> { - let ank_msg = AnkNetworkMsg { - flag, - content: payload, - }; - let msg = Message::Text(serde_json::to_string(&ank_msg)?); - log::debug!("Broadcasting message: {}", msg); - match broadcast { - BroadcastType::Sender(addr) => { - peers - .lock() - .map_err(|e| Error::msg(format!("Failed to lock peers: {}", e.to_string())))? - .iter() - .find(|(peer_addr, _)| peer_addr == &&addr) - .ok_or(Error::msg("Failed to find the sender in the peer_map"))? - .1 - .send(msg)?; - } - BroadcastType::ExcludeSender(addr) => { - peers - .lock() - .map_err(|e| Error::msg(format!("Failed to lock peers: {}", e.to_string())))? - .iter() - .filter(|(peer_addr, _)| peer_addr != &&addr) - .for_each(|(_, peer_tx)| { - let _ = peer_tx.send(msg.clone()); - }); - } - BroadcastType::ToAll => { - peers - .lock() - .map_err(|e| Error::msg(format!("Failed to lock peers: {}", e.to_string())))? - .iter() - .for_each(|(_, peer_tx)| { - let _ = peer_tx.send(msg.clone()); - }); - } - } - Ok(()) -} - fn handle_new_tx_request(new_tx_msg: &mut NewTxMessage, shared_daemon: SharedDaemon) -> Result<()> { let tx = deserialize::(&Vec::from_hex(&new_tx_msg.transaction)?)?; let mempool_accept = shared_daemon.lock_anyhow()?.test_mempool_accept(&tx)?; @@ -301,119 +184,17 @@ async fn handle_connection( let (outgoing, incoming) = ws_stream.split(); let broadcast_incoming = incoming.try_for_each(|msg| { - let peers = peers.clone(); if let Ok(raw_msg) = msg.to_text() { debug!("Received msg: {}", raw_msg); - let cache = MESSAGECACHE.get().expect("Cache should be initialized"); - if cache.contains(raw_msg) { - debug!("Message already processed, dropping"); - return future::ok(()); - } else { - cache.insert(raw_msg.to_owned()); - } - let parsed = serde_json::from_str::(raw_msg); - match parsed { - Ok(ank_msg) => match ank_msg.flag { - AnkFlag::Faucet => { - debug!("Received a faucet message"); - if let Ok(mut content) = - serde_json::from_str::(&ank_msg.content) - { - match handle_faucet_request( - &content, - sp_wallet.clone(), - shared_daemon.clone(), - ) { - Ok(new_tx_msg) => { - debug!( - "Obtained new_tx_msg: {}", - serde_json::to_string(&new_tx_msg).unwrap() - ); - } - Err(e) => { - log::error!("Failed to send faucet tx: {}", e); - content.error = Some(e.into()); - let payload = serde_json::to_string(&content) - .expect("Message type shouldn't fail"); - if let Err(e) = broadcast_message( - peers.clone(), - AnkFlag::Faucet, - payload, - BroadcastType::Sender(addr), - ) { - log::error!("Failed to broadcast message: {}", e); - } - } - } - } else { - log::error!("Invalid content for faucet message"); - } - } - AnkFlag::NewTx => { - debug!("Received a new tx message"); - if let Ok(mut new_tx_msg) = - serde_json::from_str::(&ank_msg.content) - { - match handle_new_tx_request(&mut new_tx_msg, shared_daemon.clone()) { - Ok(new_tx_msg) => { - // Repeat the msg to all except sender - if let Err(e) = broadcast_message( - peers.clone(), - AnkFlag::NewTx, - serde_json::to_string(&new_tx_msg) - .expect("This should not fail"), - BroadcastType::ExcludeSender(addr), - ) { - log::error!("Failed to send message with error: {}", e); - } - } - Err(e) => { - log::error!("handle_new_tx_request returned error: {}", e); - new_tx_msg.error = Some(e.into()); - if let Err(e) = broadcast_message( - peers.clone(), - AnkFlag::NewTx, - serde_json::to_string(&new_tx_msg) - .expect("This shouldn't fail"), - BroadcastType::Sender(addr), - ) { - log::error!("Failed to broadcast message: {}", e); - } - } - } - } else { - log::error!("Invalid content for new_tx message"); - } - } - AnkFlag::Cipher => { - // For now we just send it to everyone - debug!("Received a cipher message"); - if let Err(e) = broadcast_message( - peers.clone(), - AnkFlag::Cipher, - ank_msg.content, - BroadcastType::ExcludeSender(addr), - ) { - log::error!("Failed to send message with error: {}", e); - } - } - AnkFlag::Unknown => { - debug!("Received an unknown message"); - if let Err(e) = broadcast_message( - peers.clone(), - AnkFlag::Unknown, - ank_msg.content, - BroadcastType::ExcludeSender(addr), - ) { - log::error!("Failed to send message with error: {}", e); - } - } - }, - Err(_) => log::error!("Failed to parse network message"), - } + process_message( + raw_msg, + peers.clone(), + sp_wallet.clone(), + shared_daemon.clone(), + addr, + ); } else { - // we don't care - log::debug!("Received non-text message {} from peer {}", msg, addr); + debug!("Received non-text message {} from peer {}", msg, addr); } future::ok(()) }); @@ -598,7 +379,11 @@ async fn main() -> Result<()> { } // Subscribe to Bitcoin Core - tokio::spawn(handle_zmq(peers.clone(), shared_daemon.clone(), config.zmq_url)); + tokio::spawn(handle_zmq( + peers.clone(), + shared_daemon.clone(), + config.zmq_url, + )); // Create the event loop and TCP listener we'll accept connections on. let try_socket = TcpListener::bind("127.0.0.1:9090").await; diff --git a/src/message.rs b/src/message.rs new file mode 100644 index 0000000..ec58234 --- /dev/null +++ b/src/message.rs @@ -0,0 +1,271 @@ +use anyhow::{Error, Result}; +use std::{ + collections::HashMap, + net::SocketAddr, + sync::{Arc, Mutex, OnceLock}, + time::{Duration, Instant}, +}; +use tokio::time; +use tokio_tungstenite::tungstenite::Message; + +use sdk_common::network::{AnkFlag, AnkNetworkMsg, FaucetMessage, NewTxMessage}; + +use crate::{ + daemon::Daemon, faucet::handle_faucet_request, handle_new_tx_request, PeerMap, + SilentPaymentWallet, +}; + +pub(crate) static MESSAGECACHE: OnceLock = OnceLock::new(); + +const MESSAGECACHEDURATION: Duration = Duration::from_secs(10); +const MESSAGECACHEINTERVAL: Duration = Duration::from_secs(2); + +#[derive(Debug)] +pub(crate) struct MessageCache { + store: Mutex>, +} + +impl MessageCache { + pub fn new() -> Self { + Self { + store: Mutex::new(HashMap::new()), + } + } + + fn insert(&self, key: String) { + let mut store = self.store.lock().unwrap(); + store.insert(key.clone(), Instant::now()); + } + + fn contains(&self, key: &str) -> bool { + let store = self.store.lock().unwrap(); + store.contains_key(key) + } + + pub async fn clean_up() { + let cache = MESSAGECACHE.get().unwrap(); + + let mut interval = time::interval(MESSAGECACHEINTERVAL); + + loop { + interval.tick().await; + + let mut store = cache.store.lock().unwrap(); + + let now = Instant::now(); + let to_rm: Vec = store + .iter() + .filter_map(|(entry, entrytime)| { + if let Some(duration) = now.checked_duration_since(*entrytime) { + if duration > MESSAGECACHEDURATION { + Some(entry.clone()) + } else { + None + } + } else { + None + } + }) + .collect(); + + for key in to_rm { + store.remove(&key); + } + } + } +} + +pub(crate) enum BroadcastType { + Sender(SocketAddr), + #[allow(dead_code)] + ExcludeSender(SocketAddr), + #[allow(dead_code)] + ToAll, +} + +pub(crate) fn broadcast_message( + peers: PeerMap, + flag: AnkFlag, + payload: String, + broadcast: BroadcastType, +) -> Result<()> { + let ank_msg = AnkNetworkMsg { + flag, + content: payload, + }; + let msg = Message::Text(serde_json::to_string(&ank_msg)?); + log::debug!("Broadcasting message: {}", msg); + match broadcast { + BroadcastType::Sender(addr) => { + peers + .lock() + .map_err(|e| Error::msg(format!("Failed to lock peers: {}", e.to_string())))? + .iter() + .find(|(peer_addr, _)| peer_addr == &&addr) + .ok_or(Error::msg("Failed to find the sender in the peer_map"))? + .1 + .send(msg)?; + } + BroadcastType::ExcludeSender(addr) => { + peers + .lock() + .map_err(|e| Error::msg(format!("Failed to lock peers: {}", e.to_string())))? + .iter() + .filter(|(peer_addr, _)| peer_addr != &&addr) + .for_each(|(_, peer_tx)| { + let _ = peer_tx.send(msg.clone()); + }); + } + BroadcastType::ToAll => { + peers + .lock() + .map_err(|e| Error::msg(format!("Failed to lock peers: {}", e.to_string())))? + .iter() + .for_each(|(_, peer_tx)| { + let _ = peer_tx.send(msg.clone()); + }); + } + } + Ok(()) +} + +fn process_faucet_message( + ank_msg: AnkNetworkMsg, + peers: PeerMap, + sp_wallet: Arc, + shared_daemon: Arc>, + addr: SocketAddr, +) { + log::debug!("Received a faucet message"); + if let Ok(mut content) = serde_json::from_str::(&ank_msg.content) { + match handle_faucet_request(&content, sp_wallet.clone(), shared_daemon.clone()) { + Ok(new_tx_msg) => { + log::debug!( + "Obtained new_tx_msg: {}", + serde_json::to_string(&new_tx_msg).unwrap() + ); + } + Err(e) => { + log::error!("Failed to send faucet tx: {}", e); + content.error = Some(e.into()); + let payload = serde_json::to_string(&content).expect("Message type shouldn't fail"); + if let Err(e) = broadcast_message( + peers.clone(), + AnkFlag::Faucet, + payload, + BroadcastType::Sender(addr), + ) { + log::error!("Failed to broadcast message: {}", e); + } + } + } + } else { + log::error!("Invalid content for faucet message"); + } +} + +fn process_new_tx_message( + ank_msg: AnkNetworkMsg, + peers: PeerMap, + shared_daemon: Arc>, + addr: SocketAddr, +) { + log::debug!("Received a new tx message"); + if let Ok(mut new_tx_msg) = serde_json::from_str::(&ank_msg.content) { + match handle_new_tx_request(&mut new_tx_msg, shared_daemon.clone()) { + Ok(new_tx_msg) => { + // Repeat the msg to all except sender + if let Err(e) = broadcast_message( + peers.clone(), + AnkFlag::NewTx, + serde_json::to_string(&new_tx_msg).expect("This should not fail"), + BroadcastType::ExcludeSender(addr), + ) { + log::error!("Failed to send message with error: {}", e); + } + } + Err(e) => { + log::error!("handle_new_tx_request returned error: {}", e); + new_tx_msg.error = Some(e.into()); + if let Err(e) = broadcast_message( + peers.clone(), + AnkFlag::NewTx, + serde_json::to_string(&new_tx_msg).expect("This shouldn't fail"), + BroadcastType::Sender(addr), + ) { + log::error!("Failed to broadcast message: {}", e); + } + } + } + } else { + log::error!("Invalid content for new_tx message"); + } +} + +fn process_cipher_message( + ank_msg: AnkNetworkMsg, + peers: PeerMap, + addr: SocketAddr, +) { + // For now we just send it to everyone + log::debug!("Received a cipher message"); + + if let Err(e) = broadcast_message( + peers.clone(), + AnkFlag::Cipher, + ank_msg.content, + BroadcastType::ExcludeSender(addr), + ) { + log::error!("Failed to send message with error: {}", e); + } +} + +fn process_unknown_message( + ank_msg: AnkNetworkMsg, + peers: PeerMap, + addr: SocketAddr, +) { + log::debug!("Received an unknown message"); + if let Err(e) = broadcast_message( + peers.clone(), + AnkFlag::Unknown, + ank_msg.content, + BroadcastType::ExcludeSender(addr), + ) { + log::error!("Failed to send message with error: {}", e); + } +} + +pub fn process_message( + raw_msg: &str, + peers: PeerMap, + sp_wallet: Arc, + shared_daemon: Arc>, + addr: SocketAddr, +) { + log::debug!("Received msg: {}", raw_msg); + let cache = MESSAGECACHE.get().expect("Cache should be initialized"); + if cache.contains(raw_msg) { + log::debug!("Message already processed, dropping"); + return; + } else { + cache.insert(raw_msg.to_owned()); + } + match serde_json::from_str::(raw_msg) { + Ok(ank_msg) => match ank_msg.flag { + AnkFlag::Faucet => { + process_faucet_message(ank_msg, peers, sp_wallet, shared_daemon, addr) + } + AnkFlag::NewTx => { + process_new_tx_message(ank_msg, peers, shared_daemon, addr) + } + AnkFlag::Cipher => { + process_cipher_message(ank_msg, peers, addr) + } + AnkFlag::Unknown => { + process_unknown_message(ank_msg, peers, addr) + } + }, + Err(_) => log::error!("Failed to parse network message"), + } +}