Skip to content

Commit

Permalink
feat(engine): integrate state root task and comment it (paradigmxyz#1…
Browse files Browse the repository at this point in the history
  • Loading branch information
shekhirin authored Dec 17, 2024
1 parent e663f95 commit 48fee88
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 56 deletions.
18 changes: 8 additions & 10 deletions crates/engine/tree/benches/state_root_task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ use revm_primitives::{
Account as RevmAccount, AccountInfo, AccountStatus, Address, EvmState, EvmStorageSlot, HashMap,
B256, KECCAK_EMPTY, U256,
};
use std::sync::Arc;

#[derive(Debug, Clone)]
struct BenchParams {
Expand Down Expand Up @@ -137,16 +136,15 @@ fn bench_state_root(c: &mut Criterion) {
let state_updates = create_bench_state_updates(params);
setup_provider(&factory, &state_updates).expect("failed to setup provider");

let trie_input = Arc::new(TrieInput::from_state(Default::default()));

let config = StateRootConfig {
consistent_view: ConsistentDbView::new(factory, None),
input: trie_input,
};
let trie_input = TrieInput::from_state(Default::default());
let config = StateRootConfig::new_from_input(
ConsistentDbView::new(factory, None),
trie_input,
);
let provider = config.consistent_view.provider_ro().unwrap();
let nodes_sorted = config.input.nodes.clone().into_sorted();
let state_sorted = config.input.state.clone().into_sorted();
let prefix_sets = Arc::new(config.input.prefix_sets.clone());
let nodes_sorted = config.nodes_sorted.clone();
let state_sorted = config.state_sorted.clone();
let prefix_sets = config.prefix_sets.clone();

(config, state_updates, provider, nodes_sorted, state_sorted, prefix_sets)
},
Expand Down
108 changes: 82 additions & 26 deletions crates/engine/tree/src/tree/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2224,13 +2224,47 @@ where

let exec_time = Instant::now();

// TODO: create StateRootTask with the receiving end of a channel and
// pass the sending end of the channel to the state hook.
let noop_state_hook = |_state: &EvmState| {};
let persistence_not_in_progress = !self.persistence_state.in_progress();

// TODO: uncomment to use StateRootTask

// let (state_root_handle, state_hook) = if persistence_not_in_progress {
// let consistent_view = ConsistentDbView::new_with_latest_tip(self.provider.clone())?;
//
// let state_root_config = StateRootConfig::new_from_input(
// consistent_view.clone(),
// self.compute_trie_input(consistent_view, block.header().parent_hash())
// .map_err(ParallelStateRootError::into)?,
// );
//
// let provider_ro = consistent_view.provider_ro()?;
// let nodes_sorted = state_root_config.nodes_sorted.clone();
// let state_sorted = state_root_config.state_sorted.clone();
// let prefix_sets = state_root_config.prefix_sets.clone();
// let blinded_provider_factory = ProofBlindedProviderFactory::new(
// InMemoryTrieCursorFactory::new(
// DatabaseTrieCursorFactory::new(provider_ro.tx_ref()),
// &nodes_sorted,
// ),
// HashedPostStateCursorFactory::new(
// DatabaseHashedCursorFactory::new(provider_ro.tx_ref()),
// &state_sorted,
// ),
// prefix_sets,
// );
//
// let state_root_task = StateRootTask::new(state_root_config,
// blinded_provider_factory); let state_hook = state_root_task.state_hook();
// (Some(state_root_task.spawn(scope)), Box::new(state_hook) as Box<dyn OnStateHook>)
// } else {
// (None, Box::new(|_state: &EvmState| {}) as Box<dyn OnStateHook>)
// };
let state_hook = Box::new(|_state: &EvmState| {});

let output = self.metrics.executor.execute_metered(
executor,
(&block, U256::MAX).into(),
Box::new(noop_state_hook),
state_hook,
)?;

trace!(target: "engine::tree", elapsed=?exec_time.elapsed(), ?block_number, "Executed block");
Expand All @@ -2253,33 +2287,47 @@ where

trace!(target: "engine::tree", block=?sealed_block.num_hash(), "Calculating block state root");
let root_time = Instant::now();
let mut state_root_result = None;

// TODO: switch to calculate state root using `StateRootTask`.

// We attempt to compute state root in parallel if we are currently not persisting anything
// to database. This is safe, because the database state cannot change until we
// finish parallel computation. It is important that nothing is being persisted as
// we are computing in parallel, because we initialize a different database transaction
// per thread and it might end up with a different view of the database.
let persistence_in_progress = self.persistence_state.in_progress();
if !persistence_in_progress {
state_root_result = match self
.compute_state_root_parallel(block.header().parent_hash(), &hashed_state)
{
Ok((state_root, trie_output)) => Some((state_root, trie_output)),
let state_root_result = if persistence_not_in_progress {
// TODO: uncomment to use StateRootTask

// if let Some(state_root_handle) = state_root_handle {
// match state_root_handle.wait_for_result() {
// Ok((task_state_root, task_trie_updates)) => {
// info!(
// target: "engine::tree",
// block = ?sealed_block.num_hash(),
// ?task_state_root,
// "State root task finished"
// );
// }
// Err(error) => {
// info!(target: "engine::tree", ?error, "Failed to wait for state root task
// result"); }
// }
// }

match self.compute_state_root_parallel(block.header().parent_hash(), &hashed_state) {
Ok(result) => Some(result),
Err(ParallelStateRootError::Provider(ProviderError::ConsistentView(error))) => {
debug!(target: "engine", %error, "Parallel state root computation failed consistency check, falling back");
None
}
Err(error) => return Err(InsertBlockErrorKindTwo::Other(Box::new(error))),
};
}
}
} else {
None
};

let (state_root, trie_output) = if let Some(result) = state_root_result {
result
} else {
debug!(target: "engine::tree", block=?sealed_block.num_hash(), persistence_in_progress, "Failed to compute state root in parallel");
debug!(target: "engine::tree", block=?sealed_block.num_hash(), ?persistence_not_in_progress, "Failed to compute state root in parallel");
state_provider.state_root_with_updates(hashed_state.clone())?
};

Expand Down Expand Up @@ -2344,14 +2392,25 @@ where
parent_hash: B256,
hashed_state: &HashedPostState,
) -> Result<(B256, TrieUpdates), ParallelStateRootError> {
// TODO: when we switch to calculate state root using `StateRootTask` this
// method can be still useful to calculate the required `TrieInput` to
// create the task.
let consistent_view = ConsistentDbView::new_with_latest_tip(self.provider.clone())?;

let mut input = self.compute_trie_input(consistent_view.clone(), parent_hash)?;
// Extend with block we are validating root for.
input.append_ref(hashed_state);

ParallelStateRoot::new(consistent_view, input).incremental_root_with_updates()
}

/// Computes the trie input at the provided parent hash.
fn compute_trie_input(
&self,
consistent_view: ConsistentDbView<P>,
parent_hash: B256,
) -> Result<TrieInput, ParallelStateRootError> {
let mut input = TrieInput::default();

if let Some((historical, blocks)) = self.state.tree_state.blocks_by_hash(parent_hash) {
debug!(target: "engine::tree", %parent_hash, %historical, "Calculating state root in parallel, parent found in memory");
debug!(target: "engine::tree", %parent_hash, %historical, "Parent found in memory");
// Retrieve revert state for historical block.
let revert_state = consistent_view.revert_state(historical)?;
input.append(revert_state);
Expand All @@ -2362,15 +2421,12 @@ where
}
} else {
// The block attaches to canonical persisted parent.
debug!(target: "engine::tree", %parent_hash, "Calculating state root in parallel, parent found in disk");
debug!(target: "engine::tree", %parent_hash, "Parent found on disk");
let revert_state = consistent_view.revert_state(parent_hash)?;
input.append(revert_state);
}

// Extend with block we are validating root for.
input.append_ref(hashed_state);

ParallelStateRoot::new(consistent_view, input).incremental_root_with_updates()
Ok(input)
}

/// Handles an error that occurred while inserting a block.
Expand Down Expand Up @@ -2648,7 +2704,7 @@ mod tests {
use reth_primitives::{Block, BlockExt, EthPrimitives};
use reth_provider::test_utils::MockEthProvider;
use reth_rpc_types_compat::engine::{block_to_payload_v1, payload::block_to_payload_v3};
use reth_trie::updates::TrieUpdates;
use reth_trie::{updates::TrieUpdates, HashedPostState};
use std::{
str::FromStr,
sync::mpsc::{channel, Sender},
Expand Down
72 changes: 52 additions & 20 deletions crates/engine/tree/src/tree/root.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,15 @@ use reth_provider::{
StateCommitmentProvider,
};
use reth_trie::{
proof::Proof, updates::TrieUpdates, HashedPostState, HashedStorage, MultiProof,
MultiProofTargets, Nibbles, TrieInput,
hashed_cursor::HashedPostStateCursorFactory,
prefix_set::TriePrefixSetsMut,
proof::Proof,
trie_cursor::InMemoryTrieCursorFactory,
updates::{TrieUpdates, TrieUpdatesSorted},
HashedPostState, HashedPostStateSorted, HashedStorage, MultiProof, MultiProofTargets, Nibbles,
TrieInput,
};
use reth_trie_db::DatabaseProof;
use reth_trie_db::{DatabaseHashedCursorFactory, DatabaseProof, DatabaseTrieCursorFactory};
use reth_trie_parallel::root::ParallelStateRootError;
use reth_trie_sparse::{
blinded::{BlindedProvider, BlindedProviderFactory},
Expand Down Expand Up @@ -72,12 +77,31 @@ impl StateRootHandle {
}

/// Common configuration for state root tasks
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct StateRootConfig<Factory> {
/// View over the state in the database.
pub consistent_view: ConsistentDbView<Factory>,
/// Latest trie input.
pub input: Arc<TrieInput>,
/// The sorted collection of cached in-memory intermediate trie nodes that
/// can be reused for computation.
pub nodes_sorted: Arc<TrieUpdatesSorted>,
/// The sorted in-memory overlay hashed state.
pub state_sorted: Arc<HashedPostStateSorted>,
/// The collection of prefix sets for the computation. Since the prefix sets _always_
/// invalidate the in-memory nodes, not all keys from `state_sorted` might be present here,
/// if we have cached nodes for them.
pub prefix_sets: Arc<TriePrefixSetsMut>,
}

impl<Factory> StateRootConfig<Factory> {
/// Creates a new state root config from the consistent view and the trie input.
pub fn new_from_input(consistent_view: ConsistentDbView<Factory>, input: TrieInput) -> Self {
Self {
consistent_view,
nodes_sorted: Arc::new(input.nodes.into_sorted()),
state_sorted: Arc::new(input.state.into_sorted()),
prefix_sets: Arc::new(input.prefix_sets),
}
}
}

/// Messages used internally by the state root task
Expand Down Expand Up @@ -321,8 +345,7 @@ where
/// Returns proof targets derived from the state update.
fn on_state_update(
scope: &rayon::Scope<'env>,
view: ConsistentDbView<Factory>,
input: Arc<TrieInput>,
config: StateRootConfig<Factory>,
update: EvmState,
fetched_proof_targets: &mut MultiProofTargets,
proof_sequence_number: u64,
Expand All @@ -335,7 +358,7 @@ where

// Dispatch proof gathering for this state update
scope.spawn(move |_| {
let provider = match view.provider_ro() {
let provider = match config.consistent_view.provider_ro() {
Ok(provider) => provider,
Err(error) => {
error!(target: "engine::root", ?error, "Could not get provider");
Expand All @@ -346,11 +369,18 @@ where
};

// TODO: replace with parallel proof
let result = Proof::overlay_multiproof(
provider.tx_ref(),
input.as_ref().clone(),
proof_targets.clone(),
);
let result = Proof::from_tx(provider.tx_ref())
.with_trie_cursor_factory(InMemoryTrieCursorFactory::new(
DatabaseTrieCursorFactory::new(provider.tx_ref()),
&config.nodes_sorted,
))
.with_hashed_cursor_factory(HashedPostStateCursorFactory::new(
DatabaseHashedCursorFactory::new(provider.tx_ref()),
&config.state_sorted,
))
.with_prefix_sets_mut(config.prefix_sets.as_ref().clone())
.with_branch_node_hash_masks(true)
.multiproof(proof_targets.clone());
match result {
Ok(proof) => {
let _ = state_root_message_sender.send(StateRootMessage::ProofCalculated(
Expand Down Expand Up @@ -472,8 +502,7 @@ where
);
Self::on_state_update(
scope,
self.config.consistent_view.clone(),
self.config.input.clone(),
self.config.clone(),
update,
&mut self.fetched_proof_targets,
self.proof_sequencer.next_sequence(),
Expand Down Expand Up @@ -859,13 +888,16 @@ mod tests {
}
}

let input = TrieInput::from_state(hashed_state);
let nodes_sorted = Arc::new(input.nodes.clone().into_sorted());
let state_sorted = Arc::new(input.state.clone().into_sorted());
let config = StateRootConfig {
consistent_view: ConsistentDbView::new(factory, None),
input: Arc::new(TrieInput::from_state(hashed_state)),
nodes_sorted: nodes_sorted.clone(),
state_sorted: state_sorted.clone(),
prefix_sets: Arc::new(input.prefix_sets),
};
let provider = config.consistent_view.provider_ro().unwrap();
let nodes_sorted = config.input.nodes.clone().into_sorted();
let state_sorted = config.input.state.clone().into_sorted();
let blinded_provider_factory = ProofBlindedProviderFactory::new(
InMemoryTrieCursorFactory::new(
DatabaseTrieCursorFactory::new(provider.tx_ref()),
Expand All @@ -875,7 +907,7 @@ mod tests {
DatabaseHashedCursorFactory::new(provider.tx_ref()),
&state_sorted,
),
Arc::new(config.input.prefix_sets.clone()),
config.prefix_sets.clone(),
);
let (root_from_task, _) = std::thread::scope(|std_scope| {
let task = StateRootTask::new(config, blinded_provider_factory);
Expand Down

0 comments on commit 48fee88

Please sign in to comment.