Skip to content

Commit

Permalink
Merge pull request ethereum#12976 from dflupu/mulmod-opti
Browse files Browse the repository at this point in the history
Add simplification rules for `mod(mul(X, Y), A)` & `mod(add(X, Y), A)`
  • Loading branch information
wechman authored Jun 22, 2022
2 parents 75300c3 + 8498fdf commit 02fdcb3
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 3 deletions.
2 changes: 2 additions & 0 deletions Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ Language Features:
Compiler Features:
* TypeChecker: Support using library constants in initializers of other constants.
* Yul IR Code Generation: Improved copy routines for arrays with packed storage layout.
* Yul Optimizer: Add rule to convert `mod(mul(X, Y), A)` into `mulmod(X, Y, A)`, if `A` is a power of two.
* Yul Optimizer: Add rule to convert `mod(add(X, Y), A)` into `addmod(X, Y, A)`, if `A` is a power of two.


Bugfixes:
Expand Down
27 changes: 25 additions & 2 deletions libevmasm/RuleList.h
Original file line number Diff line number Diff line change
Expand Up @@ -275,18 +275,41 @@ std::vector<SimplificationRule<Pattern>> simplificationRuleListPart4_5(

template <class Pattern>
std::vector<SimplificationRule<Pattern>> simplificationRuleListPart5(
bool _forYulOptimizer,
Pattern A,
Pattern B,
Pattern,
Pattern X,
Pattern
Pattern Y
)
{
using Word = typename Pattern::Word;
using Builtins = typename Pattern::Builtins;

std::vector<SimplificationRule<Pattern>> rules;

// The libevmasm optimizer does not support rules resulting in opcodes with more than two arguments.
if (_forYulOptimizer)
{
// Replace MOD(MUL(X, Y), A) with MULMOD(X, Y, A) iff A=2**N
rules.push_back({
Builtins::MOD(Builtins::MUL(X, Y), A),
[=]() -> Pattern { return Builtins::MULMOD(X, Y, A); },
[=] {
return A.d() > 0 && ((A.d() & (A.d() - 1)) == 0);
}
});

// Replace MOD(ADD(X, Y), A) with ADDMOD(X, Y, A) iff A=2**N
rules.push_back({
Builtins::MOD(Builtins::ADD(X, Y), A),
[=]() -> Pattern { return Builtins::ADDMOD(X, Y, A); },
[=] {
return A.d() > 0 && ((A.d() & (A.d() - 1)) == 0);
}
});
}

// Replace MOD X, <power-of-two> with AND X, <power-of-two> - 1
for (size_t i = 0; i < Pattern::WordSize; ++i)
{
Expand Down Expand Up @@ -798,7 +821,7 @@ std::vector<SimplificationRule<Pattern>> simplificationRuleList(
rules += simplificationRuleListPart3(A, B, C, W, X);
rules += simplificationRuleListPart4(A, B, C, W, X);
rules += simplificationRuleListPart4_5(A, B, C, W, X);
rules += simplificationRuleListPart5(A, B, C, W, X);
rules += simplificationRuleListPart5(_evmVersion.has_value(), A, B, C, W, X);
rules += simplificationRuleListPart6(A, B, C, W, X);
rules += simplificationRuleListPart7(A, B, C, W, X);
rules += simplificationRuleListPart8(A, B, C, W, X);
Expand Down
31 changes: 31 additions & 0 deletions test/formal/mod_add_to_addmod.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from opcodes import MOD, ADD, ADDMOD
from rule import Rule
from z3 import BitVec

"""
Rule:
MOD(ADD(X, Y), A) -> ADDMOD(X, Y, A)
given
A > 0
A & (A - 1) == 0
"""

rule = Rule()

n_bits = 32

# Input vars
X = BitVec('X', n_bits)
Y = BitVec('Y', n_bits)
A = BitVec('A', n_bits)

# Non optimized result
nonopt = MOD(ADD(X, Y), A)

# Optimized result
opt = ADDMOD(X, Y, A)

rule.require(A > 0)
rule.require(((A & (A - 1)) == 0))

rule.check(nonopt, opt)
31 changes: 31 additions & 0 deletions test/formal/mod_mul_to_mulmod.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from opcodes import MOD, MUL, MULMOD
from rule import Rule
from z3 import BitVec

"""
Rule:
MOD(MUL(X, Y), A) -> MULMOD(X, Y, A)
given
A > 0
A & (A - 1) == 0
"""

rule = Rule()

n_bits = 8

# Input vars
X = BitVec('X', n_bits)
Y = BitVec('Y', n_bits)
A = BitVec('A', n_bits)

# Non optimized result
nonopt = MOD(MUL(X, Y), A)

# Optimized result
opt = MULMOD(X, Y, A)

rule.require(A > 0)
rule.require(((A & (A - 1)) == 0))

rule.check(nonopt, opt)
8 changes: 7 additions & 1 deletion test/formal/opcodes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from z3 import BitVecVal, BV2Int, If, LShR, UDiv, ULT, UGT, URem
from z3 import BitVecVal, BV2Int, If, LShR, UDiv, ULT, UGT, URem, ZeroExt, Extract

def ADD(x, y):
return x + y
Expand All @@ -18,6 +18,12 @@ def SDIV(x, y):
def MOD(x, y):
return If(y == 0, 0, URem(x, y))

def MULMOD(x, y, m):
return If(m == 0, 0, Extract(x.size() - 1, 0, URem(ZeroExt(x.size(), x) * ZeroExt(x.size(), y), ZeroExt(m.size(), m))))

def ADDMOD(x, y, m):
return If(m == 0, 0, Extract(x.size() - 1, 0, URem(ZeroExt(1, x) + ZeroExt(1, y), ZeroExt(1, m))))

def SMOD(x, y):
return If(y == 0, 0, x % y)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
mstore(0, mod(add(mload(0), mload(1)), 32))
}
// ----
// step: expressionSimplifier
//
// {
// {
// let _3 := mload(1)
// let _4 := 0
// mstore(_4, addmod(mload(_4), _3, 32))
// }
// }
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
mstore(0, mod(mul(mload(0), mload(1)), 32))
}
// ----
// step: expressionSimplifier
//
// {
// {
// let _3 := mload(1)
// let _4 := 0
// mstore(_4, mulmod(mload(_4), _3, 32))
// }
// }

0 comments on commit 02fdcb3

Please sign in to comment.