Skip to content

Commit

Permalink
Remove dead code
Browse files Browse the repository at this point in the history
  • Loading branch information
clarus committed Sep 18, 2024
1 parent 66281db commit 8047c53
Showing 1 changed file with 53 additions and 40 deletions.
93 changes: 53 additions & 40 deletions coq/scripts/shallow_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,35 +30,41 @@ def variables_names_to_coq(as_pattern: bool, variable_names) -> str:
return variable_name_to_coq(variable_names[0].get('name'))
else:
quote = "'" if as_pattern else ""
return quote + f"({', '.join(variable_name_to_coq(variable_name.get('name')) for variable_name in variable_names)})"
return \
quote + "(" + \
', '.join(
variable_name_to_coq(variable_name.get('name'))
for variable_name in variable_names
) + \
")"

def node_in_block_to_coq(level: int, node):
def node_in_block_to_coq(node):
node_type = node.get('nodeType')

if node_type in ['YulVariableDeclaration', 'YulAssignment']:
return node_to_coq(level, node)
return node_to_coq(node)

elif node_type in ['YulIf', 'YulSwitch']:
return \
"do~ [[\n" + \
indent(node_to_coq(level + 1, node)) + "\n" + \
indent(node_to_coq(node)) + "\n" + \
"]] in"

elif node_type in ['YulBlock', 'YulForLoop']:
return \
"do~\n" + \
indent(node_to_coq(level + 1, node)) + "\n" + \
indent(node_to_coq(node)) + "\n" + \
"in"

return \
"do~ [[ " + node_to_coq(level, node) + " ]] in"
"do~ [[ " + node_to_coq(node) + " ]] in"

def block_to_coq(level: int, node, result: str) -> str:
def block_to_coq(node, result: str) -> str:
node_type = node.get('nodeType')

if node_type == 'YulBlock':
statements = [
node_in_block_to_coq(level, stmt)
node_in_block_to_coq(stmt)
for stmt in node.get('statements', [])
if stmt.get('nodeType') != 'YulFunctionDefinition'
] + [result]
Expand Down Expand Up @@ -98,24 +104,24 @@ def is_function_pure(function_name: str) -> bool:
# return function_name in pure_functions
return False

def node_to_coq(level: int, node) -> str:
def node_to_coq(node) -> str:
if isinstance(node, dict):
node_type = node.get('nodeType')

if node_type == 'YulBlock':
return block_to_coq(level, node, "M.pure tt")
return block_to_coq(node, "M.pure tt")

elif node_type == 'YulFunctionDefinition':
return "(* Function definition should be handled at top level *)"

elif node_type == 'YulVariableDeclaration':
variables = variables_names_to_coq(True, node.get('variables', []))
value = node_to_coq(level + 1, node.get('value'))
value = node_to_coq(node.get('value'))
return f"let~ {variables} := [[ {value} ]] in"

elif node_type == 'YulAssignment':
variable = variables_names_to_coq(True, node.get('variableNames'))
value = node_to_coq(level + 1, node.get('value'))
value = node_to_coq(node.get('value'))
return f"let~ {variable} := [[ {value} ]] in"

elif node_type == 'YulFunctionCall':
Expand All @@ -126,7 +132,7 @@ def node_to_coq(level: int, node) -> str:
args: list[str] = [
paren(
arg.get('nodeType') not in ['YulLiteral', 'YulIdentifier'],
node_to_coq(level + 1, arg),
node_to_coq(arg),
)
for arg in node.get('arguments', [])
]
Expand All @@ -152,21 +158,21 @@ def node_to_coq(level: int, node) -> str:
return node.get('value', 'Unknown literal')

elif node_type == 'YulExpressionStatement':
return node_to_coq(level + 1, node.get('expression'))
return node_to_coq(node.get('expression'))

elif node_type == 'YulIf':
condition = node_to_coq(level, node.get('condition'))
true_body = node_to_coq(level + 1, node.get('body'))
condition = node_to_coq(node.get('condition'))
true_body = node_to_coq(node.get('body'))
return \
f"M.if_unit (| {condition},\n" + \
indent(true_body) + "\n" + \
"|)"

elif node_type == 'YulSwitch':
expression = node_to_coq(level, node.get('expression'))
expression = node_to_coq(node.get('expression'))
cases = [
f"if δ =? {node_to_coq(level, case.get('value'))} then\n" + \
indent(node_to_coq(level + 1, case.get('body')))
f"if δ =? {node_to_coq(case.get('value'))} then\n" + \
indent(node_to_coq(case.get('body')))
for case in node.get('cases', [])
]
return \
Expand All @@ -186,10 +192,10 @@ def node_to_coq(level: int, node) -> str:
return "M.continue"

elif node_type == 'YulForLoop':
pre = node_in_block_to_coq(level, node.get('pre'))
condition = node_to_coq(level + 1, node.get('condition'))
post = node_to_coq(level + 1, node.get('post'))
body = node_to_coq(level + 1, node.get('body'))
pre = node_in_block_to_coq(node.get('pre'))
condition = node_to_coq(node.get('condition'))
post = node_to_coq(node.get('post'))
body = node_to_coq(node.get('body'))

return \
"(* for loop *)\n" + \
Expand Down Expand Up @@ -226,20 +232,22 @@ def function_result_type(arity: int) -> str:

return "(" + " * ".join(["U256.t"] * arity) + ")"

def function_definition_to_coq(level: int, node) -> str:
def function_definition_to_coq(node) -> str:
name = variable_name_to_coq(node.get('name'))
params = ''.join([
" (" + variable_name_to_coq(p['name']) + " : U256.t)"
for p in node.get('parameters', [])
])
result = function_result_value(node.get('returnVariables', []))
body = block_to_coq(level + 1, node.get('body'), result)
body = block_to_coq(node.get('body'), result)
return \
f"Definition {name}{params} : M.t {function_result_type(len(node.get('returnVariables', [])))} :=\n" + \
indent(body + ".")
f"Definition {name}{params} : M.t " + \
function_result_type(len(node.get('returnVariables', []))) + " :=\n" + \
indent(body) + "."

# Get the names of the functions called in a function.
# We take care of sorting the names in alphabetical order so that the output is deterministic.
# We take care of sorting the names in alphabetical order so that the output is
# deterministic.
def get_function_dependencies(function_node) -> list[str]:
dependencies = set()

Expand Down Expand Up @@ -296,16 +304,19 @@ def dfs(node, visited, stack, path):

def order_functions(ordered_names: list[str], function_nodes: list) -> list:
# Create a dictionary for quick lookup of index in ordered_names
name_order: dict[str, int] = {name: index for index, name in enumerate(ordered_names)}
name_order: dict[str, int] = {
name: index
for index, name in enumerate(ordered_names)
}

# Define a key function that returns the index of the function name in ordered_names
def key_func(node):
return name_order.get(node.get('name'), len(ordered_names)) # Put unknown names at the end
return name_order.get(node.get('name'), len(ordered_names))

# Sort the function_nodes using the key function
return sorted(function_nodes, key=key_func)

def top_level_to_coq(level: int, node) -> str:
def top_level_to_coq(node) -> str:
node_type = node.get('nodeType')

if node_type == 'YulBlock':
Expand All @@ -316,35 +327,37 @@ def top_level_to_coq(level: int, node) -> str:
dependencies = get_function_dependencies(statement)
functions_dependencies[function_name] = dependencies
ordered_function_names = topological_sort(functions_dependencies)
ordered_functions = \
order_functions(ordered_function_names, node.get('statements', []))
functions = [
function_definition_to_coq(level, stmt)
for stmt in order_functions(ordered_function_names, node.get('statements', []))
if stmt.get('nodeType') == 'YulFunctionDefinition'
function_definition_to_coq(function)
for function in ordered_functions
if function.get('nodeType') == 'YulFunctionDefinition'
]
body = \
"Definition body : M.t unit :=\n" + \
indent(node_to_coq(level + 1, node)) + "."
indent(node_to_coq(node)) + "."
return ("\n\n").join(functions + [body])

return f"(* Unsupported top-level node type: {node_type} *)"

def object_to_coq(level: int, node) -> str:
def object_to_coq(node) -> str:
node_type = node.get('nodeType')

if node_type == 'YulObject':
return \
"Module " + node['name'] + ".\n" + \
indent(object_to_coq(level + 1, node['code'])) + "\n" + \
indent(object_to_coq(node['code'])) + "\n" + \
"".join(
"\n" +
indent(object_to_coq(level + 1, child)) + "\n"
indent(object_to_coq(child)) + "\n"
for child in node.get('subObjects', [])
if child.get('nodeType') != 'YulData'
) + \
"End " + node['name'] + "."

elif node_type == 'YulCode':
return top_level_to_coq(level, node['block'])
return top_level_to_coq(node['block'])

elif node_type == 'YulData':
return "(* Data object not expected *)"
Expand All @@ -356,7 +369,7 @@ def main():
with open(sys.argv[1], 'r') as file:
data = json.load(file)

coq_code = object_to_coq(0, data)
coq_code = object_to_coq(data)

print("(* Generated by " + Path(__file__).name + " *)")
print("Require Import CoqOfSolidity.CoqOfSolidity.")
Expand Down

0 comments on commit 8047c53

Please sign in to comment.