Skip to content

Commit

Permalink
perf(trie): use local ThreadPool in Parallel::multiproof (paradigmxyz…
Browse files Browse the repository at this point in the history
  • Loading branch information
fgimenez authored Dec 19, 2024
1 parent 5639552 commit d1b3dee
Showing 1 changed file with 75 additions and 14 deletions.
89 changes: 75 additions & 14 deletions crates/trie/parallel/src/proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ use reth_trie::{
};
use reth_trie_common::proof::ProofRetainer;
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseTrieCursorFactory};
use std::sync::Arc;
use tracing::{debug, error};
use std::{sync::Arc, time::Instant};
use tracing::{debug, error, trace};

#[cfg(feature = "metrics")]
use crate::metrics::ParallelStateRootMetrics;

/// TODO:
#[derive(Debug)]
pub struct ParallelProof<Factory> {
pub struct ParallelProof<'env, Factory> {
/// Consistent view of the database.
view: ConsistentDbView<Factory>,
/// The sorted collection of cached in-memory intermediate trie nodes that
Expand All @@ -46,25 +46,29 @@ pub struct ParallelProof<Factory> {
pub prefix_sets: Arc<TriePrefixSetsMut>,
/// Flag indicating whether to include branch node hash masks in the proof.
collect_branch_node_hash_masks: bool,
/// Thread pool for local tasks
thread_pool: &'env rayon::ThreadPool,
/// Parallel state root metrics.
#[cfg(feature = "metrics")]
metrics: ParallelStateRootMetrics,
}

impl<Factory> ParallelProof<Factory> {
impl<'env, Factory> ParallelProof<'env, Factory> {
/// Create new state proof generator.
pub fn new(
view: ConsistentDbView<Factory>,
nodes_sorted: Arc<TrieUpdatesSorted>,
state_sorted: Arc<HashedPostStateSorted>,
prefix_sets: Arc<TriePrefixSetsMut>,
thread_pool: &'env rayon::ThreadPool,
) -> Self {
Self {
view,
nodes_sorted,
state_sorted,
prefix_sets,
collect_branch_node_hash_masks: false,
thread_pool,
#[cfg(feature = "metrics")]
metrics: ParallelStateRootMetrics::default(),
}
Expand All @@ -77,7 +81,7 @@ impl<Factory> ParallelProof<Factory> {
}
}

impl<Factory> ParallelProof<Factory>
impl<Factory> ParallelProof<'_, Factory>
where
Factory: DatabaseProviderFactory<Provider: BlockReader>
+ StateCommitmentProvider
Expand Down Expand Up @@ -112,26 +116,50 @@ where
prefix_sets.account_prefix_set.iter().map(|nibbles| B256::from_slice(&nibbles.pack())),
prefix_sets.storage_prefix_sets.clone(),
);
let storage_root_targets_len = storage_root_targets.len();

debug!(
target: "trie::parallel_state_root",
total_targets = storage_root_targets_len,
"Starting parallel proof generation"
);

// Pre-calculate storage roots for accounts which were changed.
tracker.set_precomputed_storage_roots(storage_root_targets.len() as u64);
debug!(target: "trie::parallel_state_root", len = storage_root_targets.len(), "pre-generating storage proofs");

let mut storage_proofs =
B256HashMap::with_capacity_and_hasher(storage_root_targets.len(), Default::default());

for (hashed_address, prefix_set) in
storage_root_targets.into_iter().sorted_unstable_by_key(|(address, _)| *address)
{
let view = self.view.clone();
let target_slots = targets.get(&hashed_address).cloned().unwrap_or_default();

let trie_nodes_sorted = self.nodes_sorted.clone();
let hashed_state_sorted = self.state_sorted.clone();
let collect_masks = self.collect_branch_node_hash_masks;

let (tx, rx) = std::sync::mpsc::sync_channel(1);

rayon::spawn_fifo(move || {
self.thread_pool.spawn_fifo(move || {
debug!(
target: "trie::parallel",
?hashed_address,
"Starting proof calculation"
);

let task_start = Instant::now();
let result = (|| -> Result<_, ParallelStateRootError> {
let provider_start = Instant::now();
let provider_ro = view.provider_ro()?;
trace!(
target: "trie::parallel",
?hashed_address,
provider_time = ?provider_start.elapsed(),
"Got provider"
);

let cursor_start = Instant::now();
let trie_cursor_factory = InMemoryTrieCursorFactory::new(
DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
&trie_nodes_sorted,
Expand All @@ -140,19 +168,42 @@ where
DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
&hashed_state_sorted,
);
trace!(
target: "trie::parallel",
?hashed_address,
cursor_time = ?cursor_start.elapsed(),
"Created cursors"
);

StorageProof::new_hashed(
let proof_start = Instant::now();
let proof_result = StorageProof::new_hashed(
trie_cursor_factory,
hashed_cursor_factory,
hashed_address,
)
.with_prefix_set_mut(PrefixSetMut::from(prefix_set.iter().cloned()))
.with_branch_node_hash_masks(self.collect_branch_node_hash_masks)
.with_branch_node_hash_masks(collect_masks)
.storage_multiproof(target_slots)
.map_err(|e| ParallelStateRootError::Other(e.to_string()))
.map_err(|e| ParallelStateRootError::Other(e.to_string()));

trace!(
target: "trie::parallel",
?hashed_address,
proof_time = ?proof_start.elapsed(),
"Completed proof calculation"
);

proof_result
})();
if let Err(err) = tx.send(result) {
error!(target: "trie::parallel", ?hashed_address, err_content = ?err.0, "Failed to send proof result");

if let Err(e) = tx.send(result) {
error!(
target: "trie::parallel",
?hashed_address,
error = ?e,
task_time = ?task_start.elapsed(),
"Failed to send proof result"
);
}
});
storage_proofs.insert(hashed_address, rx);
Expand Down Expand Up @@ -338,12 +389,22 @@ mod tests {
let trie_cursor_factory = DatabaseTrieCursorFactory::new(provider_rw.tx_ref());
let hashed_cursor_factory = DatabaseHashedCursorFactory::new(provider_rw.tx_ref());

let num_threads =
std::thread::available_parallelism().map_or(1, |num| (num.get() / 2).max(1));

let state_root_task_pool = rayon::ThreadPoolBuilder::new()
.num_threads(num_threads)
.thread_name(|i| format!("proof-worker-{}", i))
.build()
.expect("Failed to create proof worker thread pool");

assert_eq!(
ParallelProof::new(
consistent_view,
Default::default(),
Default::default(),
Default::default()
Default::default(),
&state_root_task_pool
)
.multiproof(targets.clone())
.unwrap(),
Expand Down

0 comments on commit d1b3dee

Please sign in to comment.