sdk_common/src/updates.rs
2025-11-27 16:54:59 +01:00

248 lines
7.2 KiB
Rust

use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::mem;
use spdk_core::{
bitcoin::{absolute::Height, BlockHash, OutPoint},
OwnedOutput, Updater,
};
use std::collections::{HashMap, HashSet};
/// Snapshot of how far a scan has advanced.
///
/// All three fields are block heights in their consensus `u32` form; they are
/// filled from `Height::to_consensus_u32` when `StateUpdater` reports progress.
#[derive(Debug, Serialize, Deserialize)]
pub struct ScanProgress {
    /// Height at which the scan started.
    pub start: u32,
    /// Height that was most recently reported as scanned.
    pub current: u32,
    /// Height at which the scan is expected to stop.
    pub end: u32,
}
/// Message describing the result of draining a `StateUpdater`.
///
/// Built by `StateUpdater::to_update` and delivered through
/// `UpdateSink::send_state_update`.
#[derive(Debug, Serialize, Deserialize)]
pub enum StateUpdate {
    /// Nothing was recorded since the last drain; only the last known height
    /// is reported.
    NoUpdate {
        /// Last height stored in the updater.
        blkheight: Height,
    },
    /// Block data was recorded since the last drain.
    Update {
        /// Height of the block the data belongs to.
        blkheight: Height,
        /// Hash of the block the data belongs to.
        blkhash: BlockHash,
        /// Outputs handed to `record_block_outputs`, keyed by outpoint.
        found_outputs: HashMap<OutPoint, OwnedOutput>,
        /// Outpoints handed to `record_block_inputs`
        /// (presumably spent outputs; confirm against `spdk_core`).
        found_inputs: HashSet<OutPoint>,
    },
}
#[cfg(all(not(target_arch = "wasm32"), not(feature = "blindbit-wasm")))]
use std::sync::{mpsc::{self, Receiver, Sender}};
/// Destination for scan progress and state updates on native targets.
///
/// Only compiled when the target is not `wasm32` and the `blindbit-wasm`
/// feature is disabled. Implementors must be `Send + Sync` because the active
/// sink is kept as an `Arc<dyn UpdateSink>` inside a global `RwLock`.
#[cfg(all(not(target_arch = "wasm32"), not(feature = "blindbit-wasm")))]
pub trait UpdateSink: Send + Sync {
    /// Delivers a scan progress snapshot to whoever is listening.
    fn send_scan_progress(&self, progress: ScanProgress) -> Result<()>;

    /// Delivers a state update to whoever is listening.
    fn send_state_update(&self, update: StateUpdate) -> Result<()>;
}
/// Destination for scan progress and state updates on wasm targets.
///
/// Only compiled when the target is `wasm32` and the `blindbit-wasm` feature
/// is enabled. Unlike the native variant there is no `Send + Sync` bound: the
/// active sink is kept as an `Rc<dyn UpdateSink>` in a thread-local slot.
#[cfg(all(target_arch = "wasm32", feature = "blindbit-wasm"))]
pub trait UpdateSink {
    /// Delivers a scan progress snapshot to whoever is listening.
    fn send_scan_progress(&self, progress: ScanProgress) -> Result<()>;

    /// Delivers a state update to whoever is listening.
    fn send_state_update(&self, update: StateUpdate) -> Result<()>;
}
/// `UpdateSink` implementation backed by two `std::sync::mpsc` channels.
///
/// One channel carries `ScanProgress` messages and the other carries
/// `StateUpdate` messages; the matching receivers are returned by
/// `NativeUpdateSink::new`.
#[cfg(all(not(target_arch = "wasm32"), not(feature = "blindbit-wasm")))]
pub struct NativeUpdateSink {
    /// Sending half of the scan progress channel.
    scan_tx: Sender<ScanProgress>,
    /// Sending half of the state update channel.
    state_tx: Sender<StateUpdate>,
}
#[cfg(all(not(target_arch = "wasm32"), not(feature = "blindbit-wasm")))]
impl NativeUpdateSink {
    /// Builds a sink together with the receiving ends of its two channels.
    ///
    /// Returns `(sink, scan_progress_receiver, state_update_receiver)`.
    pub fn new() -> (Self, Receiver<ScanProgress>, Receiver<StateUpdate>) {
        let (progress_sender, progress_receiver) = mpsc::channel();
        let (state_sender, state_receiver) = mpsc::channel();

        let sink = Self {
            scan_tx: progress_sender,
            state_tx: state_sender,
        };

        (sink, progress_receiver, state_receiver)
    }
}
#[cfg(all(not(target_arch = "wasm32"), not(feature = "blindbit-wasm")))]
impl UpdateSink for NativeUpdateSink {
fn send_scan_progress(&self, progress: ScanProgress) -> Result<()> {
self.scan_tx.send(progress)?;
Ok(())
}
fn send_state_update(&self, update: StateUpdate) -> Result<()> {
self.state_tx.send(update)?;
Ok(())
}
}
#[cfg(all(target_arch = "wasm32", feature = "blindbit-wasm"))]
use futures::channel::mpsc::{unbounded, UnboundedSender, UnboundedReceiver};
/// `UpdateSink` implementation backed by two unbounded `futures` channels.
///
/// One channel carries `ScanProgress` messages and the other carries
/// `StateUpdate` messages; the matching receivers are returned by
/// `WasmUpdateSink::new`.
#[cfg(all(target_arch = "wasm32", feature = "blindbit-wasm"))]
pub struct WasmUpdateSink {
    /// Sending half of the scan progress channel.
    scan_tx: UnboundedSender<ScanProgress>,
    /// Sending half of the state update channel.
    state_tx: UnboundedSender<StateUpdate>,
}
#[cfg(all(target_arch = "wasm32", feature = "blindbit-wasm"))]
impl WasmUpdateSink {
    /// Builds a reference-counted sink together with the receiving ends of
    /// its two channels.
    ///
    /// Returns `(sink, scan_progress_receiver, state_update_receiver)`.
    pub fn new() -> (
        Rc<Self>,
        UnboundedReceiver<ScanProgress>,
        UnboundedReceiver<StateUpdate>,
    ) {
        let (progress_sender, progress_receiver) = unbounded();
        let (state_sender, state_receiver) = unbounded();

        let sink = Rc::new(Self {
            scan_tx: progress_sender,
            state_tx: state_sender,
        });

        (sink, progress_receiver, state_receiver)
    }
}
#[cfg(all(target_arch = "wasm32", feature = "blindbit-wasm"))]
impl UpdateSink for WasmUpdateSink {
    /// Pushes `progress` onto the scan progress channel.
    ///
    /// # Errors
    ///
    /// Returns an error carrying the channel's failure message if the
    /// unbounded send is rejected.
    fn send_scan_progress(&self, progress: ScanProgress) -> Result<()> {
        self.scan_tx.unbounded_send(progress).map_err(|send_err| {
            anyhow::Error::msg(format!("Failed to send scan progress: {}", send_err))
        })
    }

    /// Pushes `update` onto the state update channel.
    ///
    /// # Errors
    ///
    /// Returns an error carrying the channel's failure message if the
    /// unbounded send is rejected.
    fn send_state_update(&self, update: StateUpdate) -> Result<()> {
        self.state_tx.unbounded_send(update).map_err(|send_err| {
            anyhow::Error::msg(format!("Failed to send state update: {}", send_err))
        })
    }
}
#[cfg(all(not(target_arch = "wasm32"), not(feature = "blindbit-wasm")))]
use std::sync::{Arc, RwLock};
#[cfg(all(not(target_arch = "wasm32"), not(feature = "blindbit-wasm")))]
// Global sink instance for native builds.
// Starts out empty (`None`); `init_update_sink` stores a sink here and
// `get_update_sink` hands out clones of the stored `Arc`.
static UPDATE_SINK: RwLock<Option<Arc<dyn UpdateSink>>> = RwLock::new(None);
#[cfg(all(target_arch = "wasm32", feature = "blindbit-wasm"))]
use std::cell::RefCell;
use std::rc::Rc;
#[cfg(all(target_arch = "wasm32", feature = "blindbit-wasm"))]
thread_local! {
    // Per-thread sink slot for wasm builds.
    // Starts out empty (`None`); `init_update_sink` stores a sink here and
    // `get_update_sink` hands out clones of the stored `Rc`.
    static UPDATE_SINK: RefCell<Option<Rc<dyn UpdateSink>>> = RefCell::new(None);
}
/// Installs `sink` as the global update sink, replacing any previous one.
///
/// # Panics
///
/// Panics if the global `RwLock` has been poisoned.
#[cfg(all(not(target_arch = "wasm32"), not(feature = "blindbit-wasm")))]
pub fn init_update_sink(sink: Arc<dyn UpdateSink>) {
    // The write guard is a temporary and is released at the end of this statement.
    *UPDATE_SINK.write().unwrap() = Some(sink);
}
/// Returns a handle to the global update sink, or `None` if none is installed.
///
/// # Panics
///
/// Panics if the global `RwLock` has been poisoned.
#[cfg(all(not(target_arch = "wasm32"), not(feature = "blindbit-wasm")))]
pub fn get_update_sink() -> Option<Arc<dyn UpdateSink>> {
    let slot = UPDATE_SINK.read().unwrap();
    // Only the reference count is bumped; the sink itself is shared.
    slot.as_ref().map(Arc::clone)
}
/// Installs `sink` as this thread's update sink, replacing any previous one.
///
/// # Panics
///
/// Panics if the thread-local slot is already borrowed.
#[cfg(all(target_arch = "wasm32", feature = "blindbit-wasm"))]
pub fn init_update_sink(sink: Rc<dyn UpdateSink>) {
    UPDATE_SINK.with(|slot| {
        let mut current = slot.borrow_mut();
        *current = Some(sink);
    });
}
/// Returns a handle to this thread's update sink, or `None` if none is installed.
///
/// # Panics
///
/// Panics if the thread-local slot is currently mutably borrowed.
#[cfg(all(target_arch = "wasm32", feature = "blindbit-wasm"))]
pub fn get_update_sink() -> Option<Rc<dyn UpdateSink>> {
    UPDATE_SINK.with(|slot| {
        let current = slot.borrow();
        // Only the reference count is bumped; the sink itself is shared.
        current.as_ref().map(Rc::clone)
    })
}
/// Accumulates the data recorded through the `Updater` trait until it is
/// drained into a `StateUpdate` by `to_update`.
#[derive(Debug)]
pub struct StateUpdater {
    /// `true` once block outputs or inputs have been recorded and not yet drained.
    update: bool,
    /// Hash of the block the pending data belongs to, if any.
    blkhash: Option<BlockHash>,
    /// Last height recorded, either by scan progress or by a block record.
    blkheight: Option<Height>,
    /// Outputs from the most recent `record_block_outputs` call.
    found_outputs: HashMap<OutPoint, OwnedOutput>,
    /// Outpoints from the most recent `record_block_inputs` call.
    found_inputs: HashSet<OutPoint>,
}
impl StateUpdater {
pub fn new() -> Self {
Self {
update: false,
blkheight: None,
blkhash: None,
found_outputs: HashMap::new(),
found_inputs: HashSet::new(),
}
}
pub fn to_update(&mut self) -> Result<StateUpdate> {
let blkheight = self
.blkheight
.ok_or(anyhow::Error::msg("blkheight not filled"))?;
if self.update {
self.update = false;
let blkhash = self.blkhash.ok_or(anyhow::Error::msg("blkhash not set"))?;
self.blkheight = None;
self.blkhash = None;
// take results, and insert new empty values
let found_inputs = mem::take(&mut self.found_inputs);
let found_outputs = mem::take(&mut self.found_outputs);
Ok(StateUpdate::Update {
blkheight,
blkhash,
found_outputs,
found_inputs,
})
} else {
Ok(StateUpdate::NoUpdate { blkheight })
}
}
}
impl Updater for StateUpdater {
    /// Remembers `current` as the latest height and, when a sink is
    /// installed, forwards the three heights to it as a `ScanProgress`.
    ///
    /// Returns the sink's error if sending fails; succeeds without sending
    /// when no sink is installed.
    fn record_scan_progress(&mut self, start: Height, current: Height, end: Height) -> Result<()> {
        self.blkheight = Some(current);

        match get_update_sink() {
            Some(sink) => {
                let progress = ScanProgress {
                    start: start.to_consensus_u32(),
                    current: current.to_consensus_u32(),
                    end: end.to_consensus_u32(),
                };
                sink.send_scan_progress(progress)
            }
            None => Ok(()),
        }
    }

    /// Stores the outputs found in a block and marks an update as pending.
    ///
    /// Any previously stored outputs are replaced.
    fn record_block_outputs(
        &mut self,
        height: Height,
        blkhash: BlockHash,
        found_outputs: HashMap<OutPoint, OwnedOutput>,
    ) -> Result<()> {
        // The flag, height and hash may already have been set by
        // `record_block_inputs`; setting them again is harmless.
        self.blkheight = Some(height);
        self.blkhash = Some(blkhash);
        self.found_outputs = found_outputs;
        self.update = true;
        Ok(())
    }

    /// Stores the inputs found in a block and marks an update as pending.
    ///
    /// Any previously stored inputs are replaced.
    fn record_block_inputs(
        &mut self,
        blkheight: Height,
        blkhash: BlockHash,
        found_inputs: HashSet<OutPoint>,
    ) -> Result<()> {
        self.found_inputs = found_inputs;
        self.blkhash = Some(blkhash);
        self.blkheight = Some(blkheight);
        self.update = true;
        Ok(())
    }

    /// Drains the accumulated state and hands it to the installed sink.
    ///
    /// When no sink is installed nothing is drained and the call succeeds.
    /// Errors from `to_update` or from the sink are propagated.
    fn save_to_persistent_storage(&mut self) -> Result<()> {
        match get_update_sink() {
            Some(sink) => {
                let update = self.to_update()?;
                sink.send_state_update(update)
            }
            None => Ok(()),
        }
    }
}