Skip to content

Commit

Permalink
refactor(evm): split trace encoding in several functions (foundry-rs#…
Browse files Browse the repository at this point in the history
…1312)

* refactor: move node to separate module

* refactor: split decoding in several functions

* fix: no need for sorted maps here
  • Loading branch information
mattsse authored Apr 14, 2022
1 parent 64cbc98 commit 8c6f624
Show file tree
Hide file tree
Showing 5 changed files with 220 additions and 158 deletions.
169 changes: 42 additions & 127 deletions evm/src/trace/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ use super::{
identifier::TraceIdentifier, CallTraceArena, RawOrDecodedCall, RawOrDecodedLog,
RawOrDecodedReturnData,
};
use crate::abi::{CHEATCODE_ADDRESS, CONSOLE_ABI, HEVM_ABI};
use crate::{
abi::{CHEATCODE_ADDRESS, CONSOLE_ABI, HEVM_ABI},
trace::{node::CallTraceNode, utils},
};
use ethers::{
abi::{Abi, Address, Event, Function, Param, ParamType, Token},
types::H256,
};
use foundry_utils::format_token;
use std::collections::BTreeMap;
use std::collections::{BTreeMap, HashMap};

/// The call trace decoder.
///
Expand All @@ -20,13 +22,13 @@ use std::collections::BTreeMap;
#[derive(Default, Debug)]
pub struct CallTraceDecoder {
/// Information for decoding precompile calls.
pub precompiles: BTreeMap<Address, Function>,
pub precompiles: HashMap<Address, Function>,
/// Addresses identified to be a specific contract.
///
/// The values are in the form `"<artifact>:<contract>"`.
pub contracts: BTreeMap<Address, String>,
pub contracts: HashMap<Address, String>,
/// Address labels
pub labels: BTreeMap<Address, String>,
pub labels: HashMap<Address, String>,
/// A mapping of addresses to their known functions
pub functions: BTreeMap<[u8; 4], Vec<Function>>,
/// All known events
Expand Down Expand Up @@ -114,7 +116,7 @@ impl CallTraceDecoder {
),
]
.into(),
contracts: BTreeMap::new(),
contracts: Default::default(),
labels: [(CHEATCODE_ADDRESS, "VM".to_string())].into(),
functions: HEVM_ABI
.functions()
Expand Down Expand Up @@ -189,97 +191,22 @@ impl CallTraceDecoder {
pub fn decode(&self, traces: &mut CallTraceArena) {
for node in traces.arena.iter_mut() {
// Set contract name
if let Some(contract) = self.contracts.get(&node.trace.address) {
node.trace.contract = Some(contract.clone());
if let Some(contract) = self.contracts.get(&node.trace.address).cloned() {
node.trace.contract = Some(contract);
}

// Set label
if let Some(label) = self.labels.get(&node.trace.address) {
node.trace.label = Some(label.clone());
if let Some(label) = self.labels.get(&node.trace.address).cloned() {
node.trace.label = Some(label);
}

// Decode call
if let RawOrDecodedCall::Raw(bytes) = &node.trace.data {
if let Some(precompile_fn) = self.precompiles.get(&node.trace.address) {
node.trace.label = Some("PRECOMPILE".to_string());
node.trace.data = RawOrDecodedCall::Decoded(
precompile_fn.name.clone(),
precompile_fn.decode_input(&bytes[..]).map_or_else(
|_| vec![hex::encode(&bytes)],
|tokens| tokens.iter().map(|token| self.apply_label(token)).collect(),
),
);

if let RawOrDecodedReturnData::Raw(bytes) = &node.trace.output {
node.trace.output = RawOrDecodedReturnData::Decoded(
precompile_fn.decode_output(&bytes[..]).map_or_else(
|_| hex::encode(&bytes),
|tokens| {
tokens
.iter()
.map(|token| self.apply_label(token))
.collect::<Vec<_>>()
.join(", ")
},
),
);
}
} else if bytes.len() >= 4 {
if let Some(precompile_fn) = self.precompiles.get(&node.trace.address) {
node.decode_precompile(precompile_fn, &self.labels);
} else if let RawOrDecodedCall::Raw(ref bytes) = node.trace.data {
if bytes.len() >= 4 {
if let Some(funcs) = self.functions.get(&bytes[0..4]) {
// This is safe because (1) we would not have an entry for the given
// selector if no functions with that selector were added and (2) the same
// selector implies the function has the same name and inputs.
let func = &funcs[0];

// Decode inputs
let inputs = if !bytes[4..].is_empty() {
if node.trace.address == CHEATCODE_ADDRESS {
// Try to decode cheatcode inputs in a more custom way
self.decode_cheatcode_inputs(func, bytes).unwrap_or_else(|| {
func.decode_input(&bytes[4..])
.expect("bad function input decode")
.iter()
.map(|token| self.apply_label(token))
.collect()
})
} else {
match func.decode_input(&bytes[4..]) {
Ok(v) => {
v.iter().map(|token| self.apply_label(token)).collect()
}
Err(_) => Vec::new(),
}
}
} else {
Vec::new()
};
node.trace.data = RawOrDecodedCall::Decoded(func.name.clone(), inputs);

if let RawOrDecodedReturnData::Raw(bytes) = &node.trace.output {
if !bytes.is_empty() {
if node.trace.success {
if let Some(tokens) = funcs
.iter()
.find_map(|func| func.decode_output(&bytes[..]).ok())
{
node.trace.output = RawOrDecodedReturnData::Decoded(
tokens
.iter()
.map(|token| self.apply_label(token))
.collect::<Vec<_>>()
.join(", "),
);
}
} else if let Ok(decoded_error) =
foundry_utils::decode_revert(&bytes[..], Some(&self.errors))
{
node.trace.output = RawOrDecodedReturnData::Decoded(format!(
r#""{}""#,
decoded_error
));
}
}
}
node.decode_function(funcs, &self.labels, &self.errors);
}
} else {
node.trace.data = RawOrDecodedCall::Decoded("fallback".to_string(), Vec::new());
Expand All @@ -300,50 +227,38 @@ impl CallTraceDecoder {
}

// Decode events
node.logs.iter_mut().for_each(|log| {
if let RawOrDecodedLog::Raw(raw_log) = log {
if let Some(events) =
self.events.get(&(raw_log.topics[0], raw_log.topics.len() - 1))
{
for event in events {
if let Ok(decoded) = event.parse_log(raw_log.clone()) {
*log = RawOrDecodedLog::Decoded(
event.name.clone(),
decoded
.params
.into_iter()
.map(|param| (param.name, self.apply_label(&param.value)))
.collect(),
);
break
}
}
}
}
});
self.decode_events(node);
}
}

fn apply_label(&self, token: &Token) -> String {
match token {
Token::Address(addr) => {
if let Some(label) = self.labels.get(addr) {
format!("{}: [{:?}]", label, addr)
} else {
format_token(token)
fn decode_events(&self, node: &mut CallTraceNode) {
node.logs.iter_mut().for_each(|log| {
self.decode_event(log);
});
}

fn decode_event(&self, log: &mut RawOrDecodedLog) {
if let RawOrDecodedLog::Raw(raw_log) = log {
if let Some(events) = self.events.get(&(raw_log.topics[0], raw_log.topics.len() - 1)) {
for event in events {
if let Ok(decoded) = event.parse_log(raw_log.clone()) {
*log = RawOrDecodedLog::Decoded(
event.name.clone(),
decoded
.params
.into_iter()
.map(|param| (param.name, self.apply_label(&param.value)))
.collect(),
);
break
}
}
}
_ => format_token(token),
}
}

fn decode_cheatcode_inputs(&self, func: &Function, data: &[u8]) -> Option<Vec<String>> {
match func.name.as_str() {
"expectRevert" => foundry_utils::decode_revert(data, Some(&self.errors))
.ok()
.map(|decoded| vec![decoded]),
_ => None,
}
fn apply_label(&self, token: &Token) -> String {
utils::label(token, &self.labels)
}
}

Expand Down
22 changes: 4 additions & 18 deletions evm/src/trace/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
pub mod identifier;

mod decoder;
mod node;
mod utils;

pub use decoder::CallTraceDecoder;

use crate::{abi::CHEATCODE_ADDRESS, CallKind};
Expand All @@ -12,6 +15,7 @@ use ethers::{
abi::{Address, RawLog},
types::U256,
};
use node::CallTraceNode;
use serde::{Deserialize, Serialize};
use std::{
collections::HashSet,
Expand Down Expand Up @@ -198,24 +202,6 @@ pub enum LogCallOrder {
Call(usize),
}

#[derive(Default, Debug, Clone, Serialize, Deserialize)]
/// A node in the arena
pub struct CallTraceNode {
/// Parent node index in the arena
pub parent: Option<usize>,
/// Children node indexes in the arena
pub children: Vec<usize>,
/// This node's index in the arena
pub idx: usize,
/// The call trace
pub trace: CallTrace,
/// Logs
#[serde(skip)]
pub logs: Vec<RawOrDecodedLog>,
/// Ordering of child calls and logs
pub ordering: Vec<LogCallOrder>,
}

// TODO: Maybe unify with output
/// Raw or decoded calldata.
#[derive(Debug, Clone, Deserialize, Serialize)]
Expand Down
126 changes: 126 additions & 0 deletions evm/src/trace/node.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
use crate::{
executor::CHEATCODE_ADDRESS,
trace::{
utils, CallTrace, LogCallOrder, RawOrDecodedCall, RawOrDecodedLog, RawOrDecodedReturnData,
},
};
use ethers::{
abi::{Abi, Function},
types::Address,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// A node in the arena
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
pub struct CallTraceNode {
/// Parent node index in the arena
pub parent: Option<usize>,
/// Children node indexes in the arena
pub children: Vec<usize>,
/// This node's index in the arena
pub idx: usize,
/// The call trace
pub trace: CallTrace,
/// Logs
#[serde(skip)]
pub logs: Vec<RawOrDecodedLog>,
/// Ordering of child calls and logs
pub ordering: Vec<LogCallOrder>,
}

impl CallTraceNode {
/// Decode a regular function
pub fn decode_function(
&mut self,
funcs: &[Function],
labels: &HashMap<Address, String>,
errors: &Abi,
) {
debug_assert!(!funcs.is_empty(), "requires at least 1 func");
// This is safe because (1) we would not have an entry for the given
// selector if no functions with that selector were added and (2) the
// same selector implies the function has
// the same name and inputs.
let func = &funcs[0];

if let RawOrDecodedCall::Raw(ref bytes) = self.trace.data {
let inputs = if !bytes[4..].is_empty() {
if self.trace.address == CHEATCODE_ADDRESS {
// Try to decode cheatcode inputs in a more custom way
utils::decode_cheatcode_inputs(func, bytes, errors).unwrap_or_else(|| {
func.decode_input(&bytes[4..])
.expect("bad function input decode")
.iter()
.map(|token| utils::label(token, labels))
.collect()
})
} else {
match func.decode_input(&bytes[4..]) {
Ok(v) => v.iter().map(|token| utils::label(token, labels)).collect(),
Err(_) => Vec::new(),
}
}
} else {
Vec::new()
};
self.trace.data = RawOrDecodedCall::Decoded(func.name.clone(), inputs);

if let RawOrDecodedReturnData::Raw(bytes) = &self.trace.output {
if !bytes.is_empty() {
if self.trace.success {
if let Some(tokens) =
funcs.iter().find_map(|func| func.decode_output(&bytes[..]).ok())
{
self.trace.output = RawOrDecodedReturnData::Decoded(
tokens
.iter()
.map(|token| utils::label(token, labels))
.collect::<Vec<_>>()
.join(", "),
);
}
} else if let Ok(decoded_error) =
foundry_utils::decode_revert(&bytes[..], Some(errors))
{
self.trace.output =
RawOrDecodedReturnData::Decoded(format!(r#""{}""#, decoded_error));
}
}
}
}
}

/// Decode the node's tracing data for the given precompile function
pub fn decode_precompile(
&mut self,
precompile_fn: &Function,
labels: &HashMap<Address, String>,
) {
if let RawOrDecodedCall::Raw(ref bytes) = self.trace.data {
self.trace.label = Some("PRECOMPILE".to_string());
self.trace.data = RawOrDecodedCall::Decoded(
precompile_fn.name.clone(),
precompile_fn.decode_input(&bytes[..]).map_or_else(
|_| vec![hex::encode(&bytes)],
|tokens| tokens.iter().map(|token| utils::label(token, labels)).collect(),
),
);

if let RawOrDecodedReturnData::Raw(ref bytes) = self.trace.output {
self.trace.output = RawOrDecodedReturnData::Decoded(
precompile_fn.decode_output(&bytes[..]).map_or_else(
|_| hex::encode(&bytes),
|tokens| {
tokens
.iter()
.map(|token| utils::label(token, labels))
.collect::<Vec<_>>()
.join(", ")
},
),
);
}
}
}
}
Loading

0 comments on commit 8c6f624

Please sign in to comment.