diff --git a/core/Cargo.toml b/core/Cargo.toml
index 8509db15e..1311e5a67 100644
--- a/core/Cargo.toml
+++ b/core/Cargo.toml
@@ -70,3 +70,4 @@ ethers-core = { workspace = true }
 sp1 = ["dep:sp1-driver", "sp1-driver/enable"]
 risc0 = ["dep:risc0-driver", "risc0-driver/enable"]
 sgx = ["dep:sgx-prover", "sgx-prover/enable"]
+test-utils = []
diff --git a/core/src/provider/mod.rs b/core/src/provider/mod.rs
index 3d7f30ce6..7b5356da2 100644
--- a/core/src/provider/mod.rs
+++ b/core/src/provider/mod.rs
@@ -11,6 +11,7 @@ use crate::{
 };
 
 pub mod db;
+pub mod persistent_map;
 pub mod rpc;
 
 #[allow(async_fn_in_trait)]
diff --git a/core/src/provider/persistent_map.rs b/core/src/provider/persistent_map.rs
new file mode 100644
index 000000000..f6af306ee
--- /dev/null
+++ b/core/src/provider/persistent_map.rs
@@ -0,0 +1,156 @@
+use alloy_rpc_types::EIP1186AccountProofResponse;
+use alloy_rpc_types::EIP1186StorageProof;
+use reth_primitives::{Address, U256};
+use reth_revm::primitives::AccountInfo;
+use serde::{Deserialize, Serialize};
+use std::{collections::HashMap, fs::File, hash::Hash, io::Write, path::PathBuf};
+
+#[derive(Hash, Eq, PartialEq)]
+pub struct StorageSlotKey {
+    address: Address,
+    slot: U256,
+}
+
+impl Serialize for StorageSlotKey {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: serde::Serializer,
+    {
+        format!("{}:{}", self.address, self.slot).serialize(serializer)
+    }
+}
+
+impl<'de> Deserialize<'de> for StorageSlotKey {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        let s = String::deserialize(deserializer)?;
+        let mut parts = s.split(':');
+        let address = parts
+            .next()
+            .ok_or_else(|| serde::de::Error::custom("Missing address"))?
+            .parse()
+            .map_err(serde::de::Error::custom)?;
+        let slot = parts
+            .next()
+            .ok_or_else(|| serde::de::Error::custom("Missing slot"))?
+            .parse()
+            .map_err(serde::de::Error::custom)?;
+        Ok(Self { address, slot })
+    }
+}
+
+impl From<(Address, U256)> for StorageSlotKey {
+    fn from((address, slot): (Address, U256)) -> Self {
+        Self { address, slot }
+    }
+}
+
+impl From<StorageSlotKey> for (Address, U256) {
+    fn from(key: StorageSlotKey) -> Self {
+        (key.address, key.slot)
+    }
+}
+
+pub struct PersistentBlockData {
+    base: PathBuf,
+}
+
+impl PersistentBlockData {
+    pub fn new(base: impl Into<PathBuf>) -> Self {
+        Self { base: base.into() }
+    }
+
+    pub fn accounts(&self, block_number: u64) -> PersistentMap<Address, AccountInfo> {
+        PersistentMap::new(format!(
+            "{}/{}-accounts.json",
+            self.base.display(),
+            block_number
+        ))
+    }
+
+    pub fn storage_values(&self, block_number: u64) -> PersistentMap<StorageSlotKey, U256> {
+        PersistentMap::new(format!(
+            "{}/{}-storage_values.json",
+            self.base.display(),
+            block_number
+        ))
+    }
+
+    pub fn account_proofs(
+        &self,
+        block_number: u64,
+    ) -> PersistentMap<Address, EIP1186AccountProofResponse> {
+        PersistentMap::new(format!(
+            "{}/{}-account_proofs.json",
+            self.base.display(),
+            block_number
+        ))
+    }
+
+    pub fn account_storage_proofs(
+        &self,
+        block_number: u64,
+    ) -> PersistentMap<StorageSlotKey, EIP1186StorageProof> {
+        PersistentMap::new(format!(
+            "{}/{}-account_storage_proofs.json",
+            self.base.display(),
+            block_number
+        ))
+    }
+}
+
+/// A simple cache that mirrors part of the `HashMap` API and persists the data to a file when `save` is called.
+pub struct PersistentMap<
+    K: Hash + Eq + Serialize + for<'de> Deserialize<'de>,
+    V: Serialize + for<'de> Deserialize<'de>,
+> {
+    file_path: PathBuf,
+    map: HashMap<K, V>,
+}
+
+impl<
+        K: Hash + Eq + Serialize + for<'de> Deserialize<'de>,
+        V: Serialize + for<'de> Deserialize<'de>,
+    > PersistentMap<K, V>
+{
+    pub fn new(file_path: impl Into<PathBuf>) -> Self {
+        let file_path = file_path.into();
+        if let Some(parent) = file_path.parent() {
+            std::fs::create_dir_all(parent).unwrap_or_else(|e| {
+                tracing::warn!("Failed to create directory {}: {}", parent.display(), e);
+            });
+        }
+
+        // Load the storage from the file
+        let map = File::open(&file_path)
+            .and_then(|file| {
+                serde_json::from_reader::<_, HashMap<K, V>>(file)
+                    .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
+            })
+            .unwrap_or(HashMap::new());
+
+        Self { file_path, map }
+    }
+
+    pub fn contains_key(&self, key: &K) -> bool {
+        self.map.contains_key(key)
+    }
+
+    pub fn get(&self, key: &K) -> Option<&V> {
+        self.map.get(key)
+    }
+
+    pub fn insert(&mut self, key: K, value: V) -> Option<V> {
+        self.map.insert(key, value)
+    }
+
+    pub fn save(&self) -> std::io::Result<()> {
+        let mut file = File::create(&self.file_path)?;
+        serde_json::to_writer(&mut file, &self.map)
+            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
+        file.flush()?;
+        Ok(())
+    }
+}
diff --git a/core/src/provider/rpc.rs b/core/src/provider/rpc.rs
index 495262430..cf12d506d 100644
--- a/core/src/provider/rpc.rs
+++ b/core/src/provider/rpc.rs
@@ -19,13 +19,28 @@ pub struct RpcBlockDataProvider {
     pub provider: ReqwestProvider,
     pub client: RpcClient<Http<Client>>,
     block_number: u64,
+
+    #[cfg(any(test, feature = "test-utils"))]
+    pub persistent_block_data:
+        std::sync::Arc<tokio::sync::Mutex<super::persistent_map::PersistentBlockData>>,
 }
 
 impl RpcBlockDataProvider {
     pub fn new(url: &str, block_number: u64) -> RaikoResult<Self> {
         let url =
             reqwest::Url::parse(url).map_err(|_| RaikoError::RPC("Invalid RPC URL".to_owned()))?;
         Ok(Self {
+            #[cfg(any(test, feature = "test-utils"))]
+            persistent_block_data: ::std::sync::Arc::new(tokio::sync::Mutex::new(
+                super::persistent_map::PersistentBlockData::new(format!(
+                    "../testdata/{}",
+                    url.to_string()
+                        .trim_start_matches("https://")
+                        .trim_end_matches("/"),
+                )),
+            )),
             provider: ProviderBuilder::new().on_provider(RootProvider::new_http(url.clone())),
             client: ClientBuilder::default().http(url),
             block_number,
@@ -84,10 +99,23 @@ impl BlockDataProvider for RpcBlockDataProvider {
     }
 
     async fn get_accounts(&self, accounts: &[Address]) -> RaikoResult<Vec<AccountInfo>> {
-        let mut all_accounts = Vec::with_capacity(accounts.len());
+        #[cfg(any(test, feature = "test-utils"))]
+        let all_accounts = &mut self
+            .persistent_block_data
+            .lock()
+            .await
+            .accounts(self.block_number);
+
+        #[cfg(not(any(test, feature = "test-utils")))]
+        let mut all_accounts = HashMap::with_capacity(accounts.len());
+
+        let to_fetch_accounts: Vec<_> = accounts
+            .iter()
+            .filter(|address| !all_accounts.contains_key(*address))
+            .collect();
 
         let max_batch_size = 250;
-        for accounts in accounts.chunks(max_batch_size) {
+        for accounts in to_fetch_accounts.chunks(max_batch_size) {
             let mut batch = self.client.new_batch();
 
             let mut nonce_requests = Vec::with_capacity(max_batch_size);
@@ -134,10 +162,10 @@ impl BlockDataProvider for RpcBlockDataProvider {
                 .await
                 .map_err(|e| RaikoError::RPC(format!("Error sending batch request {e}")))?;
 
-            let mut accounts = vec![];
             // Collect the data from the batch
-            for ((nonce_request, balance_request), code_request) in nonce_requests
-                .into_iter()
+            for (((address, nonce_request), balance_request), code_request) in accounts
+                .iter()
+                .zip(nonce_requests.into_iter())
                 .zip(balance_requests.into_iter())
                 .zip(code_requests.into_iter())
             {
@@ -160,21 +188,34 @@ impl BlockDataProvider for RpcBlockDataProvider {
                 let bytecode = Bytecode::new_raw(code);
 
                 let account_info = AccountInfo::new(balance, nonce, bytecode.hash_slow(), bytecode);
-
-                accounts.push(account_info);
+                all_accounts.insert(**address, account_info);
             }
-
-            all_accounts.append(&mut accounts);
         }
 
-        Ok(all_accounts)
+        Ok(accounts
+            .iter()
+            .map(|address| all_accounts.get(address).expect("checked above").clone())
+            .collect())
     }
 
     async fn get_storage_values(&self, accounts: &[(Address, U256)]) -> RaikoResult<Vec<U256>> {
-        let mut all_values = Vec::with_capacity(accounts.len());
+        #[cfg(any(test, feature = "test-utils"))]
+        let all_values = &mut self
+            .persistent_block_data
+            .lock()
+            .await
+            .storage_values(self.block_number);
+
+        #[cfg(not(any(test, feature = "test-utils")))]
+        let mut all_values: HashMap<(Address, U256), U256> = HashMap::with_capacity(accounts.len());
+
+        let to_fetch_slots: Vec<_> = accounts
+            .iter()
+            .filter(|(address, slot)| !all_values.contains_key(&(*address, *slot).into()))
+            .collect();
 
         let max_batch_size = 1000;
-        for accounts in accounts.chunks(max_batch_size) {
+        for accounts in to_fetch_slots.chunks(max_batch_size) {
             let mut batch = self.client.new_batch();
 
             let mut requests = Vec::with_capacity(max_batch_size);
@@ -199,20 +240,20 @@ impl BlockDataProvider for RpcBlockDataProvider {
                 .await
                 .map_err(|e| RaikoError::RPC(format!("Error sending batch request {e}")))?;
 
-            let mut values = Vec::with_capacity(max_batch_size);
             // Collect the data from the batch
-            for request in requests {
-                values.push(
-                    request.await.map_err(|e| {
-                        RaikoError::RPC(format!("Error collecting request data: {e}"))
-                    })?,
-                );
+            for ((address, slot), request) in accounts.iter().zip(requests.into_iter()) {
+                let value = request
+                    .await
+                    .map_err(|e| RaikoError::RPC(format!("Error collecting request data: {e}")))?;
+                all_values.insert((*address, *slot).into(), value);
             }
-
-            all_values.append(&mut values);
         }
 
-        Ok(all_values)
+        Ok(accounts
+            .iter()
+            .map(|(address, slot)| all_values.get(&(*address, *slot).into()).unwrap())
+            .cloned()
+            .collect())
     }
 
@@ -222,13 +263,34 @@ impl BlockDataProvider for RpcBlockDataProvider {
     async fn get_merkle_proofs(
         &self,
         block_number: u64,
         accounts: HashMap<Address, Vec<U256>>,
         offset: usize,
         num_storage_proofs: usize,
     ) -> RaikoResult<MerkleProof> {
-        let mut storage_proofs: MerkleProof = HashMap::new();
+        #[cfg(any(test, feature = "test-utils"))]
+        let account_proofs = &mut self
+            .persistent_block_data
+            .lock()
+            .await
+            .account_proofs(block_number);
+
+        #[cfg(any(test, feature = "test-utils"))]
+        let account_storage_proofs = &mut self
+            .persistent_block_data
+            .lock()
+            .await
+            .account_storage_proofs(block_number);
+
+        #[cfg(not(any(test, feature = "test-utils")))]
+        let account_proofs = &mut HashMap::with_capacity(accounts.len());
+        #[cfg(not(any(test, feature = "test-utils")))]
+        let account_storage_proofs = &mut HashMap::<
+            (Address, U256),
+            alloy_rpc_types::EIP1186StorageProof,
+        >::with_capacity(accounts.len());
+
         let mut idx = offset;
 
-        let mut accounts = accounts.clone();
+        let mut accounts_mut = accounts.clone();
 
-        let batch_limit = 1000;
-        while !accounts.is_empty() {
+        let batch_limit = 1;
+        while !accounts_mut.is_empty() {
             #[cfg(debug_assertions)]
             raiko_lib::inplace_print(&format!(
                 "fetching storage proof {idx}/{num_storage_proofs}..."
@@ -243,10 +305,10 @@ impl BlockDataProvider for RpcBlockDataProvider {
             let mut requests = Vec::new();
 
             let mut batch_size = 0;
-            while !accounts.is_empty() && batch_size < batch_limit {
+            while !accounts_mut.is_empty() && batch_size < batch_limit {
                 let mut address_to_remove = None;
-                if let Some((address, keys)) = accounts.iter_mut().next() {
+                if let Some((address, keys)) = accounts_mut.iter_mut().next() {
                     // Calculate how many keys we can still process
                     let num_keys_to_process = if batch_size + keys.len() < batch_limit {
                         keys.len()
@@ -260,40 +322,51 @@ impl BlockDataProvider for RpcBlockDataProvider {
                     }
 
                     // Extract the keys to process
-                    let keys_to_process = keys
-                        .drain(0..num_keys_to_process)
-                        .map(StorageKey::from)
+                    let keys_to_process = keys.drain(0..num_keys_to_process).collect::<Vec<_>>();
+                    let to_fetch_keys = keys_to_process
+                        .iter()
+                        .filter(|key| {
+                            !account_storage_proofs.contains_key(&(*address, **key).into())
+                        })
+                        .cloned()
                         .collect::<Vec<_>>();
-
-                    // Add the request
-                    requests.push(Box::pin(
-                        batch
-                            .add_call::<_, EIP1186AccountProofResponse>(
-                                "eth_getProof",
-                                &(
-                                    *address,
-                                    keys_to_process.clone(),
-                                    BlockId::from(block_number),
-                                ),
-                            )
-                            .map_err(|_| {
-                                RaikoError::RPC(
-                                    "Failed adding eth_getProof call to batch".to_owned(),
+                    if !to_fetch_keys.is_empty() || !account_proofs.contains_key(address) {
+                        requests.push(Box::pin(
+                            batch
+                                .add_call::<_, EIP1186AccountProofResponse>(
+                                    "eth_getProof",
+                                    &(
+                                        *address,
+                                        to_fetch_keys
+                                            .iter()
+                                            .map(|key| StorageKey::from(*key))
+                                            .collect::<Vec<_>>(),
+                                        BlockId::from(block_number),
+                                    ),
                                 )
-                            })?,
-                    ));
+                                .map_err(|_| {
+                                    RaikoError::RPC(
+                                        "Failed adding eth_getProof call to batch".to_owned(),
+                                    )
+                                })?,
+                        ));
+                    }
 
                     // Keep track of how many keys were processed
                     // Add an additional 1 for the account proof itself
-                    batch_size += 1 + keys_to_process.len();
+                    batch_size += 1 + to_fetch_keys.len();
                 }
 
                 // Remove the address if all keys were processed for this account
                 if let Some(address) = address_to_remove {
-                    accounts.remove(&address);
+                    accounts_mut.remove(&address);
                }
             }
 
+            if requests.is_empty() {
+                continue;
+            }
+
             // Send the batch
             batch
                 .send()
@@ -302,19 +375,40 @@ impl BlockDataProvider for RpcBlockDataProvider {
                 .await
                 .map_err(|e| RaikoError::RPC(format!("Error sending batch request {e}")))?;
 
             // Collect the data from the batch
             for request in requests {
-                let mut proof = request
+                let proof = request
                     .await
                     .map_err(|e| RaikoError::RPC(format!("Error collecting request data: {e}")))?;
                 idx += proof.storage_proof.len();
-                if let Some(map_proof) = storage_proofs.get_mut(&proof.address) {
-                    map_proof.storage_proof.append(&mut proof.storage_proof);
-                } else {
-                    storage_proofs.insert(proof.address, proof);
+
+                if !account_proofs.contains_key(&proof.address) {
+                    let mut account_only_proof = proof.clone();
+                    account_only_proof.storage_proof = vec![];
+                    account_proofs.insert(proof.address, account_only_proof);
+                }
+
+                for slot_proof in proof.storage_proof {
+                    account_storage_proofs.insert(
+                        (proof.address.clone(), slot_proof.key.0.into()).into(),
+                        slot_proof,
+                    );
                 }
             }
         }
         clear_line();
 
-        Ok(storage_proofs)
+        Ok(accounts
+            .into_iter()
+            .map(|(address, keys)| {
+                let mut account_proof = account_proofs.get(&address).unwrap().clone();
+                account_proof.storage_proof = vec![];
+                for key in keys {
+                    let storage_proof = account_storage_proofs
+                        .get(&(address.clone(), key).into())
+                        .expect("checked above");
+                    account_proof.storage_proof.push(storage_proof.clone());
+                }
+                (address, account_proof)
+            })
+            .collect())
     }
 }
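
Review note (not part of the patch): a minimal usage sketch of the `PersistentMap` cache added in `persistent_map.rs`. It assumes the `core` crate is importable as `raiko_core`; the `testdata/demo` directory, block number, and slot value are invented for illustration.

```rust
use raiko_core::provider::persistent_map::{PersistentBlockData, StorageSlotKey};
use reth_primitives::{Address, U256};

fn main() -> std::io::Result<()> {
    // One JSON file per (block, data kind), e.g. "testdata/demo/17000000-storage_values.json".
    let data = PersistentBlockData::new("testdata/demo"); // hypothetical base directory
    let mut values = data.storage_values(17_000_000); // hypothetical block number

    let key: StorageSlotKey = (Address::ZERO, U256::from(1u64)).into();
    if !values.contains_key(&key) {
        // Stand-in for a value that would otherwise be fetched over RPC.
        values.insert(key, U256::from(42u64));
    }

    // Nothing is written to disk until save() is called explicitly.
    values.save()
}
```

In `RpcBlockDataProvider` the same maps are keyed by RPC host and block number under `../testdata/`, so repeated test runs against the same block skip the network for anything already cached; fetched data only survives a run if `save()` is invoked.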