Skip to content

Commit

Permalink
feat: allow msg.data in raw_call without slice (vyperlang#2902)
Browse files Browse the repository at this point in the history
a common use case for raw_call is to forward all calldata to an
implementation or logic contract of some kind. however, currently vyper 
only allows msg.data inside of slice() or len(); this does not allow one
to forward calldata of any size. this enables that use case without
requiring the length of msg.data to be known. it does this by simply
copying calldata to the location pointed to by `msize`.

this commit also fixes an annotation in abi_decode, and slightly cleans
up the logic in the raw_call implementation.
  • Loading branch information
charles-cooper authored Jun 9, 2022
1 parent 4243cbd commit b8dea0c
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 24 deletions.
29 changes: 29 additions & 0 deletions tests/parser/functions/test_raw_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,35 @@ def foo(_addr: address) -> int128:
assert caller.foo(target.address) == 42


def test_forward_calldata(get_contract, w3, keccak):
target_source = """
@external
def foo() -> uint256:
return 123
"""

caller_source = """
target: address
@external
def set_target(target: address):
self.target = target
@external
def __default__():
assert 123 == _abi_decode(raw_call(self.target, msg.data, max_outsize=32), uint256)
"""

target = get_contract(target_source)

caller = get_contract(caller_source)
caller.set_target(target.address, transact={})

# manually construct msg.data for `caller` contract
sig = keccak("foo()".encode()).hex()[:10]
w3.eth.send_transaction({"to": caller.address, "data": sig})


def test_static_call_fails_nonpayable(get_contract, assert_tx_failed):

target_source = """
Expand Down
48 changes: 25 additions & 23 deletions vyper/builtin_functions/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1169,8 +1169,25 @@ def build_IR(self, expr, args, kwargs, context):
expr,
)

eval_input_buf = ensure_in_memory(data, context)
input_buf = eval_seq(eval_input_buf)
if data.value == "~calldata":
call_ir = ["with", "mem_ofst", "msize"]
args_ofst = ["seq", ["calldatacopy", "mem_ofst", 0, "calldatasize"], "mem_ofst"]
args_len = "calldatasize"
else:
# some gymnastics to propagate constants (if eval_input_buf
# returns a static memory location)
eval_input_buf = ensure_in_memory(data, context)

input_buf = eval_seq(eval_input_buf)

if input_buf is None:
call_ir = ["with", "arg_buf", eval_input_buf]
input_buf = IRnode.from_list("arg_buf")
else:
call_ir = ["seq", eval_input_buf]

args_ofst = add_ofst(input_buf, 32)
args_len = ["mload", input_buf]

output_node = IRnode.from_list(
context.new_internal_variable(ByteArrayType(outsize)),
Expand All @@ -1180,16 +1197,10 @@ def build_IR(self, expr, args, kwargs, context):

bool_ty = BaseType("bool")

if input_buf is None:
call_ir = ["with", "arg_buf", eval_input_buf]
input_buf = IRnode.from_list("arg_buf")
else:
call_ir = ["seq", eval_input_buf]

# build IR for call or delegatecall
common_call_args = [
add_ofst(input_buf, 32),
["mload", input_buf], # buf len
args_ofst,
args_len,
# if there is no return value, the return offset can be 0
add_ofst(output_node, 32) if outsize else 0,
outsize,
Expand All @@ -1201,25 +1212,16 @@ def build_IR(self, expr, args, kwargs, context):
call_op = ["staticcall", gas, to, *common_call_args]
else:
call_op = ["call", gas, to, value, *common_call_args]

call_ir += [call_op]

# build sequence IR
if outsize:
# return minimum of outsize and returndatasize
size = [
"with",
"_l",
outsize,
["with", "_r", "returndatasize", ["if", ["gt", "_l", "_r"], "_r", "_l"]],
]
size = ["select", ["lt", outsize, "returndatasize"], outsize, "returndatasize"]

# store output size and return output location
store_output_size = [
"with",
"output_pos",
output_node,
["seq", ["mstore", "output_pos", size], "output_pos"],
]
store_output_size = ["seq", ["mstore", output_node, size], output_node]

bytes_ty = ByteArrayType(outsize)

Expand Down Expand Up @@ -2272,7 +2274,7 @@ def build_IR(self, expr, args, kwargs, context):
typ=output_typ,
location=data.location,
encoding=Encoding.ABI,
annotation="abi_decode {output_type}",
annotation=f"abi_decode({output_typ})",
)
)

Expand Down
3 changes: 2 additions & 1 deletion vyper/semantics/validation/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ def _validate_address_code_attribute(node: vy_ast.Attribute) -> None:
def _validate_msg_data_attribute(node: vy_ast.Attribute) -> None:
if isinstance(node.value, vy_ast.Name) and node.value.id == "msg" and node.attr == "data":
parent = node.get_ancestor()
if not isinstance(parent, vy_ast.Call) or parent.get("func.id") not in ("slice", "len"):
allowed_builtins = ("slice", "len", "raw_call")
if not isinstance(parent, vy_ast.Call) or parent.get("func.id") not in allowed_builtins:
raise StructureException(
"msg.data is only allowed inside of the slice or len functions", node
)
Expand Down

0 comments on commit b8dea0c

Please sign in to comment.