Skip to content

Commit

Permalink
better prank (foundry-rs#246)
Browse files Browse the repository at this point in the history
  • Loading branch information
brockelmore authored Dec 17, 2021
1 parent e36dea2 commit b60d973
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 78 deletions.
42 changes: 9 additions & 33 deletions evm-adapters/src/sputnik/cheatcodes/cheatcode_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,12 +278,7 @@ impl<'a, 'b, B: Backend, P: PrecompileSet> CheatcodeStackExecutor<'a, 'b, B, P>

/// Given a transaction's calldata, it tries to parse it as an [`HEVM cheatcode`](super::HEVM)
/// call and modify the state accordingly.
fn apply_cheatcode(
&mut self,
input: Vec<u8>,
transfer: Option<Transfer>,
target_gas: Option<u64>,
) -> Capture<(ExitReason, Vec<u8>), Infallible> {
fn apply_cheatcode(&mut self, input: Vec<u8>) -> Capture<(ExitReason, Vec<u8>), Infallible> {
let mut res = vec![];

// Get a mutable ref to the state so we can apply the cheats
Expand Down Expand Up @@ -382,31 +377,7 @@ impl<'a, 'b, B: Backend, P: PrecompileSet> CheatcodeStackExecutor<'a, 'b, B, P>
}
HEVMCalls::Prank(inner) => {
let caller = inner.0;
let address = inner.1;
let input = inner.2;

let value =
if let Some(ref transfer) = transfer { transfer.value } else { U256::zero() };

// change origin
let context = Context { caller, address, apparent_value: value };
let ret = self.call(
address,
Some(Transfer { source: caller, target: address, value }),
input.to_vec(),
target_gas,
false,
context,
);
res = match ret {
Capture::Exit((successful, v)) => match successful {
ExitReason::Succeed(_) => {
ethers::abi::encode(&[Token::Bool(true), Token::Bytes(v.to_vec())])
}
_ => ethers::abi::encode(&[Token::Bool(false), Token::Bytes(v.to_vec())]),
},
_ => vec![],
};
self.state_mut().next_msg_sender = Some(caller);
}
HEVMCalls::ExpectRevert(inner) => {
if self.state().expected_revert.is_some() {
Expand Down Expand Up @@ -757,9 +728,14 @@ impl<'a, 'b, B: Backend, P: PrecompileSet> Handler for CheatcodeStackExecutor<'a
// (e.g. with the StateManager)

let expected_revert = self.state_mut().expected_revert.take();
let caller = self.state_mut().next_msg_sender.take();
let mut new_context = context;
if let Some(caller) = caller {
new_context.caller = caller;
}

if code_address == *CHEATCODE_ADDRESS {
self.apply_cheatcode(input, transfer, target_gas)
self.apply_cheatcode(input)
} else if code_address == *CONSOLE_ADDRESS {
self.console_log(input)
} else {
Expand All @@ -771,7 +747,7 @@ impl<'a, 'b, B: Backend, P: PrecompileSet> Handler for CheatcodeStackExecutor<'a
is_static,
true,
true,
context,
new_context,
);

if let Some(expected_revert) = expected_revert {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub struct MemoryStackStateOwned<'config, B> {
pub backend: B,
pub substate: MemoryStackSubstate<'config>,
pub expected_revert: Option<Vec<u8>>,
pub next_msg_sender: Option<H160>,
}

impl<'config, B: Backend> MemoryStackStateOwned<'config, B> {
Expand All @@ -26,7 +27,12 @@ impl<'config, B: Backend> MemoryStackStateOwned<'config, B> {

impl<'config, B: Backend> MemoryStackStateOwned<'config, B> {
pub fn new(metadata: StackSubstateMetadata<'config>, backend: B) -> Self {
Self { backend, substate: MemoryStackSubstate::new(metadata), expected_revert: None }
Self {
backend,
substate: MemoryStackSubstate::new(metadata),
expected_revert: None,
next_msg_sender: None,
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion evm-adapters/src/sputnik/cheatcodes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ ethers::contract::abigen!(
ffi(string[])(bytes)
addr(uint256)(address)
sign(uint256,bytes32)(uint8,bytes32,bytes32)
prank(address,address,bytes)(bool,bytes)
prank(address)
deal(address,uint256)
etch(address,bytes)
expectRevert(bytes)
Expand Down
48 changes: 5 additions & 43 deletions evm-adapters/testdata/CheatCodes.sol
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ interface Hevm {
function addr(uint256) external returns (address);
// Performs a foreign function call via terminal, (stringInputs) => (result)
function ffi(string[] calldata) external returns (bytes memory);
// Calls another contract with a specified `msg.sender`, (newSender, contract, input) => (success, returnData)
function prank(address, address, bytes calldata) external payable returns (bool, bytes memory);
// Sets the *next* call's msg.sender to be the input address
function prank(address) external;
// Sets an address' balance, (who, newBalance)
function deal(address, uint256) external;
// Sets an address' code, (who, newCode)
Expand Down Expand Up @@ -155,50 +155,12 @@ contract CheatCodes is DSTest {
function testPrank() public {
Prank prank = new Prank();
address new_sender = address(1337);
bytes4 sig = prank.checksOriginAndSender.selector;
string memory input = "And his name is JOHN CENA!";
bytes memory calld = abi.encodePacked(sig, abi.encode(input));
address origin = tx.origin;
address sender = msg.sender;
(bool success, bytes memory ret) = hevm.prank(new_sender, address(prank), calld);
assertTrue(success);
string memory expectedRetString = "SUPER SLAM!";
string memory actualRet = abi.decode(ret, (string));
assertEq(actualRet, expectedRetString);

// make sure we returned back to normal
assertEq(origin, tx.origin);
assertEq(sender, msg.sender);
}

function testPrankValue() public {
Prank prank = new Prank();
// setup the call
address new_sender = address(1337);
bytes4 sig = prank.checksOriginAndSender.selector;
hevm.prank(new_sender);
string memory input = "And his name is JOHN CENA!";
bytes memory calld = abi.encodePacked(sig, abi.encode(input));
address origin = tx.origin;
address sender = msg.sender;

// give the sender some monies
hevm.deal(new_sender, 1337);

// call the function passing in a value. the eth is pulled from the new sender
sig = hevm.prank.selector;
calld = abi.encodePacked(sig, abi.encode(new_sender, address(prank), calld));

// this is nested low level calls effectively
(bool high_level_success, bytes memory outerRet) = address(hevm).call{value: 1}(calld);
assertTrue(high_level_success);
(bool success, bytes memory ret) = abi.decode(outerRet, (bool,bytes));
assertTrue(success);
string memory retString = prank.checksOriginAndSender(input);
string memory expectedRetString = "SUPER SLAM!";
string memory actualRet = abi.decode(ret, (string));
assertEq(actualRet, expectedRetString);

// make sure we returned back to normal
assertEq(origin, tx.origin);
assertEq(retString, expectedRetString);
assertEq(sender, msg.sender);
}

Expand Down

0 comments on commit b60d973

Please sign in to comment.