Skip to content

Commit

Permalink
Parse branch targets the same as global symbols
Browse files Browse the repository at this point in the history
  • Loading branch information
simonlindholm committed Jan 1, 2024
1 parent 16c2eab commit 0514701
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 39 deletions.
1 change: 1 addition & 0 deletions m2c/arch_ppc.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def match(self, matcher: AsmMatcher) -> Optional[Replacement]:
isinstance(instr, Instruction)
and instr.mnemonic == "b"
and isinstance(instr.args[0], AsmGlobalSymbol)
and not matcher.branch_target_exists(instr.args[0].symbol_name)
):
return Replacement(
[
Expand Down
14 changes: 5 additions & 9 deletions m2c/asm_instruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,10 @@ class JumpTarget:
target: str

def __str__(self) -> str:
return f"{self.target}"
return self.target


Argument = Union[
Register, AsmGlobalSymbol, AsmAddressMode, Macro, AsmLiteral, BinOp, JumpTarget
]
Argument = Union[Register, AsmGlobalSymbol, AsmAddressMode, Macro, AsmLiteral, BinOp]


@dataclass(frozen=True)
Expand Down Expand Up @@ -234,10 +232,8 @@ def replace_bare_reg(


def get_jump_target(label: Argument) -> JumpTarget:
if isinstance(label, AsmGlobalSymbol):
return JumpTarget(label.symbol_name)
assert isinstance(label, JumpTarget), "invalid branch target"
return label
assert isinstance(label, AsmGlobalSymbol), "invalid branch target"
return JumpTarget(label.symbol_name)


# Main parser.
Expand Down Expand Up @@ -284,7 +280,7 @@ def expect(n: str) -> str:
if word in ["data", "sdata", "rodata", "rdata", "bss", "sbss", "text"]:
value = asm_section_global_symbol(word, 0)
else:
value = JumpTarget("." + word)
value = AsmGlobalSymbol("." + word)
elif tok == "%":
# A MIPS reloc macro, e.g. %hi(...) or %lo(...).
assert value is None
Expand Down
17 changes: 11 additions & 6 deletions m2c/asm_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,11 @@ def match_arg(self, a: Argument, e: Argument) -> bool:
if isinstance(e, Register):
return isinstance(a, Register) and self.match_reg(a, e)
if isinstance(e, AsmGlobalSymbol):
if e.symbol_name.isupper():
if e.symbol_name.startswith("."):
return isinstance(a, AsmGlobalSymbol) and self.match_label_use(
e.symbol_name, a.symbol_name
)
elif e.symbol_name.isupper():
return isinstance(a, AsmLiteral) and self.match_var(
self.symbolic_literals, e.symbol_name, a.value
)
Expand All @@ -159,10 +163,6 @@ def match_arg(self, a: Argument, e: Argument) -> bool:
and self.match_arg(a.lhs, e.lhs)
and self.match_reg(a.rhs, e.rhs)
)
if isinstance(e, JumpTarget):
return isinstance(a, JumpTarget) and self.match_label_use(
e.target, a.target
)
if isinstance(e, BinOp):
return isinstance(a, AsmLiteral) and a.value == self.eval_math(e)
assert False, f"bad pattern part: {e}"
Expand Down Expand Up @@ -198,6 +198,7 @@ def match_meta(self, ins: AsmInstruction) -> bool:
@dataclass
class AsmMatcher:
input: List[BodyPart]
labels: Set[str]
output: List[BodyPart] = field(default_factory=list)
index: int = 0

Expand Down Expand Up @@ -226,6 +227,9 @@ def derived_meta(self) -> InstructionMeta:
return part.meta.derived()
return InstructionMeta.missing()

def branch_target_exists(self, name: str) -> bool:
return name in self.labels

def apply(self, repl: Replacement, arch: ArchAsm) -> None:
# Track which registers are overwritten/clobbered in the replacement asm
repl_writes = []
Expand Down Expand Up @@ -262,7 +266,8 @@ def simplify_patterns(
"""Detect and simplify asm standard patterns emitted by known compilers. This is
especially useful for patterns that involve branches, which are hard to deal with
in the translate phase."""
matcher = AsmMatcher(body)
labels = {name for item in body if isinstance(item, Label) for name in item.names}
matcher = AsmMatcher(body, labels)
while matcher.index < len(matcher.input):
for pattern in patterns:
m = pattern.match(matcher)
Expand Down
42 changes: 25 additions & 17 deletions m2c/flow_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,10 +332,12 @@ def normalize_ido_likely_branches(function: Function, arch: ArchFlowGraph) -> Fu
new_label = f"_m2c_{old_label}_before"
label_before_instr[before_target] = new_label
insert_label_before[before_target] = new_label
new_target = JumpTarget(label_before_instr[before_target])
new_target = label_before_instr[before_target]
mn_unlikely = item.mnemonic[:-1] or "b"
item = arch.parse(
mn_unlikely, item.args[:-1] + [new_target], item.meta.derived()
mn_unlikely,
item.args[:-1] + [AsmGlobalSymbol(new_target)],
item.meta.derived(),
)
next_item = arch.parse("nop", [], item.meta.derived())
new_body.append((orig_item, item))
Expand Down Expand Up @@ -408,7 +410,7 @@ def build_blocks(

body_iter: Iterator[Union[Instruction, Label]] = iter(function.body)
branch_likely_counts: Counter[str] = Counter()
cond_return_target: Optional[JumpTarget] = None
cond_return_target: Optional[str] = None

def process_mips(item: Union[Instruction, Label]) -> None:
if isinstance(item, Label):
Expand Down Expand Up @@ -462,28 +464,32 @@ def process_mips(item: Union[Instruction, Label]) -> None:
process_after.append(label)
process_after.append(next_item.clone())
else:
target = item.jump_target
assert isinstance(target, JumpTarget), "has delay slot and isn't a call"
assert isinstance(
item.jump_target, JumpTarget
), "has delay slot and isn't a call"
target = item.jump_target.target
temp_label = f"_m2c_{label.names[0]}_skip"
meta = item.meta.derived()
nop = arch.parse("nop", [], meta)
block_builder.add_instruction(
arch.parse(
item.mnemonic,
item.args[:-1] + [JumpTarget(temp_label)],
item.args[:-1] + [AsmGlobalSymbol(temp_label)],
item.meta,
)
)
block_builder.add_instruction(nop)
block_builder.new_block()
block_builder.add_instruction(
arch.parse("b", [JumpTarget(label.names[0])], item.meta.derived())
arch.parse(
"b", [AsmGlobalSymbol(label.names[0])], item.meta.derived()
)
)
block_builder.add_instruction(nop.clone())
block_builder.new_block()
block_builder.set_label(Label([temp_label]))
block_builder.add_instruction(
arch.parse("b", [target], item.meta.derived())
arch.parse("b", [AsmGlobalSymbol(target)], item.meta.derived())
)
block_builder.add_instruction(next_item.clone())
block_builder.new_block()
Expand All @@ -498,14 +504,14 @@ def process_mips(item: Union[Instruction, Label]) -> None:

if item.is_branch_likely:
assert isinstance(item.jump_target, JumpTarget)
target = item.jump_target
branch_likely_counts[target.target] += 1
index = branch_likely_counts[target.target]
target = item.jump_target.target
branch_likely_counts[target] += 1
index = branch_likely_counts[target]
mn_inverted = invert_mips_branch_mnemonic(item.mnemonic[:-1])
temp_label = f"_m2c_{target.target}_branchlikelyskip_{index}"
temp_label = f"_m2c_{target}_branchlikelyskip_{index}"
branch_not = arch.parse(
mn_inverted,
item.args[:-1] + [JumpTarget(temp_label)],
item.args[:-1] + [AsmGlobalSymbol(temp_label)],
item.meta.derived(),
)
nop = arch.parse("nop", [], item.meta.derived())
Expand All @@ -514,7 +520,7 @@ def process_mips(item: Union[Instruction, Label]) -> None:
block_builder.new_block()
block_builder.add_instruction(next_item)
block_builder.add_instruction(
arch.parse("b", [target], item.meta.derived())
arch.parse("b", [AsmGlobalSymbol(target)], item.meta.derived())
)
block_builder.add_instruction(nop.clone())
block_builder.new_block()
Expand Down Expand Up @@ -553,11 +559,13 @@ def process_no_delay_slots(item: Union[Instruction, Label]) -> None:

if item.is_conditional and item.is_return:
if cond_return_target is None:
cond_return_target = JumpTarget("_m2c_conditionalreturn_")
cond_return_target = "_m2c_conditionalreturn_"
# Strip the "lr" off of the instruction
assert item.mnemonic[-2:] == "lr"
branch_instr = arch.parse(
item.mnemonic[:-2], [cond_return_target], item.meta.derived()
item.mnemonic[:-2],
[AsmGlobalSymbol(cond_return_target)],
item.meta.derived(),
)
block_builder.add_instruction(branch_instr)
block_builder.new_block()
Expand Down Expand Up @@ -596,7 +604,7 @@ def process_no_delay_slots(item: Union[Instruction, Label]) -> None:

if cond_return_target is not None:
# Add an empty return block at the end of the function
block_builder.set_label(Label([cond_return_target.target]))
block_builder.set_label(Label([cond_return_target]))
for instr in arch.missing_return():
block_builder.add_instruction(instr)
block_builder.new_block()
Expand Down
13 changes: 6 additions & 7 deletions m2c/ir_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ class IrMatch:
"""

symbolic_registers: Dict[str, Register] = field(default_factory=dict)
symbolic_labels: Dict[str, str] = field(default_factory=dict)
symbolic_args: Dict[str, Argument] = field(default_factory=dict)
ref_map: Dict[Reference, RefSet] = field(default_factory=dict)

Expand All @@ -131,6 +130,10 @@ def _is_symbolic_sym(arg: AsmGlobalSymbol) -> bool:
# Uppercase symbols are symbolic; everything else is literal
return arg.symbol_name.isupper()

@staticmethod
def _is_label_sym(arg: AsmGlobalSymbol) -> bool:
return arg.symbol_name.startswith(".")

def eval_math(self, pat: Argument) -> Argument:
# This function can only evaluate math in *patterns*, not candidate
# instructions. It does not need to support arbitrary math, only
Expand Down Expand Up @@ -168,13 +171,12 @@ def map_arg(self, key: Argument) -> Argument:
if isinstance(key, Register):
return self.map_reg(key)
if isinstance(key, AsmGlobalSymbol):
assert not self._is_label_sym(key), "not supported yet"
if self._is_symbolic_sym(key):
return self.symbolic_args[key.symbol_name]
return key
if isinstance(key, AsmAddressMode):
return AsmAddressMode(lhs=self.map_arg(key.lhs), rhs=self.map_reg(key.rhs))
if isinstance(key, JumpTarget):
return JumpTarget(self.symbolic_labels[key.target])
if isinstance(key, BinOp):
return self.eval_math(key)
assert False, f"bad pattern part: {key}"
Expand Down Expand Up @@ -232,6 +234,7 @@ def match_arg(self, pat: Argument, cand: Argument) -> bool:
return False
return self._match_var(self.symbolic_registers, pat.register_name, cand)
if isinstance(pat, AsmGlobalSymbol):
assert not self._is_label_sym(pat), "not supported yet"
if self._is_symbolic_sym(pat):
return self._match_var(self.symbolic_args, pat.symbol_name, cand)
return pat == cand
Expand All @@ -241,10 +244,6 @@ def match_arg(self, pat: Argument, cand: Argument) -> bool:
and self.match_arg(pat.lhs, cand.lhs)
and self.match_arg(pat.rhs, cand.rhs)
)
if isinstance(pat, JumpTarget):
return isinstance(cand, JumpTarget) and self._match_var(
self.symbolic_labels, pat.target, cand.target
)
if isinstance(pat, BinOp):
return self.eval_math(pat) == cand

Expand Down

0 comments on commit 0514701

Please sign in to comment.