From 120733e48098bcf8f17bc7515dddf1eac6e8a12f Mon Sep 17 00:00:00 2001 From: chillpill91 Date: Thu, 25 Apr 2024 14:05:32 +0100 Subject: [PATCH] update execution.rs file --- host/src/execution.rs | 59 ++++++++++++++++++++++++++++--------------- 1 file changed, 38 insertions(+), 21 deletions(-) diff --git a/host/src/execution.rs b/host/src/execution.rs index f48efbe32..3d2ace4d0 100644 --- a/host/src/execution.rs +++ b/host/src/execution.rs @@ -1,4 +1,4 @@ -use std::str::FromStr; +use std::{convert::TryFrom, str::FromStr}; use raiko_lib::{ builder::{BlockBuilderStrategy, TaikoStrategy}, @@ -16,26 +16,30 @@ use tracing::{info, warn}; use super::error::Result; use crate::{error::HostError, memory, preflight::preflight, request::ProofRequest}; +/// Execute the proof generation process pub async fn execute( - config: &serde_json::Value, + config: &Config, cached_input: Option, ) -> Result<(GuestInput, Proof)> { let total_proving_time = Measurement::start("", false); // Generate the input - let input = if let Some(cached_input) = cached_input { - println!("Using cached input"); - cached_input - } else { - memory::reset_stats(); - let measurement = Measurement::start("Generating input...", false); - let input = prepare_input(config).await?; - measurement.stop_with("=> Input generated"); - memory::print_stats("Input generation peak memory used: "); - input + let input = match cached_input { + Some(cached_input) => { + println!("Using cached input"); + cached_input + } + None => { + memory::reset_stats(); + let measurement = Measurement::start("Generating input...", false); + let input = prepare_input(config).await?; + measurement.stop_with("=> Input generated"); + memory::print_stats("Input generation peak memory used: "); + input + } }; - // 2. Test run the block + // Test run the block memory::reset_stats(); match TaikoStrategy::build_from(&input) { Ok((header, _mpt_node)) => { @@ -59,14 +63,12 @@ pub async fn execute( let res = D::run(input.clone(), output, config) .await .map(|proof| (input, proof)) - .map_err(|e| HostError::GuestError(e.to_string()))?; - + .map_err(|e| HostError::GuestError(e.to_string())); measurement.stop_with("=> Proof generated"); memory::print_stats("Prover peak memory used: "); total_proving_time.stop_with("====> Complete proof generated"); - - Ok(res) + res } Err(e) => { warn!("Proving bad block construction!"); @@ -75,9 +77,9 @@ pub async fn execute( } } -/// prepare input data for provers -pub async fn prepare_input(config: &serde_json::Value) -> Result { - let req = ProofRequest::deserialize(config).unwrap(); +/// Prepare input data for provers +async fn prepare_input(config: &Config) -> Result { + let req = ProofRequest::try_from(config)?; let block_number = req.block_number; let rpc = req.rpc.clone(); let l1_rpc = req.l1_rpc.clone(); @@ -111,7 +113,7 @@ impl Prover for NativeDriver { async fn run( _input: GuestInput, output: GuestOutput, - _request: &serde_json::Value, + _request: &Config, ) -> ProverResult { to_proof(Ok(NativeResponse { output })) } @@ -121,8 +123,23 @@ impl Prover for NativeDriver { } } +use std::convert::TryFrom; + +impl TryFrom<&serde_json::Value> for Config { + type Error = HostError; + + fn try_from(value: &serde_json::Value) -> Result { + let config: Config = serde_json::from_value(value.clone()) + .map_err(|e| HostError::DeserializeError(format!("Failed to deserialize config: {}", e)))?; + Ok(config) + } +} + + #[cfg(test)] mod tests { + use super::*; + #[tokio::test] async fn test_async_block() { let result = async { Result::<(), &'static str>::Err("error") };