Skip to content

Commit

Permalink
wasm: replace comparison special cases, add and use builtins (open-po…
Browse files Browse the repository at this point in the history
…licy-agent#3271)

The planner has further uses for ir.EqualStmt and ir.NotEqualStmt,
so they're not removed.

Signed-off-by: Stephan Renatus <[email protected]>
  • Loading branch information
srenatus authored Mar 15, 2021
1 parent c80fb04 commit 8f7185d
Show file tree
Hide file tree
Showing 8 changed files with 140 additions and 114 deletions.
12 changes: 12 additions & 0 deletions internal/compiler/wasm/opa/callgraph.csv
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,18 @@ parse_ip,memchr
opa_cidr_intersects,opa_value_type
opa_cidr_intersects,parse_cidr
opa_cidr_intersects,opa_boolean
opa_cmp_eq,opa_value_compare
opa_cmp_eq,opa_boolean
opa_cmp_neq,opa_value_compare
opa_cmp_neq,opa_boolean
opa_cmp_gt,opa_value_compare
opa_cmp_gt,opa_boolean
opa_cmp_gte,opa_value_compare
opa_cmp_gte,opa_boolean
opa_cmp_lt,opa_value_compare
opa_cmp_lt,opa_boolean
opa_cmp_lte,opa_value_compare
opa_cmp_lte,opa_boolean
opa_eval_ctx_new,opa_malloc
__force_import_opa_builtins,opa_builtin0
__force_import_opa_builtins,opa_builtin1
Expand Down
34 changes: 6 additions & 28 deletions internal/compiler/wasm/wasm.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,12 @@ var builtinsFunctions = map[string]string{
ast.NetCIDRContains.Name: "opa_cidr_contains",
ast.NetCIDROverlap.Name: "opa_cidr_contains",
ast.NetCIDRIntersects.Name: "opa_cidr_intersects",
ast.Equal.Name: "opa_cmp_eq",
ast.GreaterThan.Name: "opa_cmp_gt",
ast.GreaterThanEq.Name: "opa_cmp_gte",
ast.LessThan.Name: "opa_cmp_lt",
ast.LessThanEq.Name: "opa_cmp_lte",
ast.NotEqual.Name: "opa_cmp_neq",
ast.GlobMatch.Name: "opa_glob_match",
ast.JSONMarshal.Name: "opa_json_marshal",
ast.JSONUnmarshal.Name: "opa_json_unmarshal",
Expand Down Expand Up @@ -1017,34 +1023,6 @@ func (c *Compiler) compileBlock(block *ir.Block) ([]instruction.Instruction, err
instrs = append(instrs, instruction.Call{Index: c.function(opaValueCompare)})
instrs = append(instrs, instruction.BrIf{Index: 0})
}
case *ir.LessThanStmt:
instrs = append(instrs, c.instrRead(stmt.A))
instrs = append(instrs, c.instrRead(stmt.B))
instrs = append(instrs, instruction.Call{Index: c.function(opaValueCompare)})
instrs = append(instrs, instruction.I32Const{Value: 0})
instrs = append(instrs, instruction.I32GeS{})
instrs = append(instrs, instruction.BrIf{Index: 0})
case *ir.LessThanEqualStmt:
instrs = append(instrs, c.instrRead(stmt.A))
instrs = append(instrs, c.instrRead(stmt.B))
instrs = append(instrs, instruction.Call{Index: c.function(opaValueCompare)})
instrs = append(instrs, instruction.I32Const{Value: 0})
instrs = append(instrs, instruction.I32GtS{})
instrs = append(instrs, instruction.BrIf{Index: 0})
case *ir.GreaterThanStmt:
instrs = append(instrs, c.instrRead(stmt.A))
instrs = append(instrs, c.instrRead(stmt.B))
instrs = append(instrs, instruction.Call{Index: c.function(opaValueCompare)})
instrs = append(instrs, instruction.I32Const{Value: 0})
instrs = append(instrs, instruction.I32LeS{})
instrs = append(instrs, instruction.BrIf{Index: 0})
case *ir.GreaterThanEqualStmt:
instrs = append(instrs, c.instrRead(stmt.A))
instrs = append(instrs, c.instrRead(stmt.B))
instrs = append(instrs, instruction.Call{Index: c.function(opaValueCompare)})
instrs = append(instrs, instruction.I32Const{Value: 0})
instrs = append(instrs, instruction.I32LtS{})
instrs = append(instrs, instruction.BrIf{Index: 0})
case *ir.NotEqualStmt:
if stmt.A == stmt.B { // same local, same bool constant, or same string constant
instrs = append(instrs, instruction.Br{Index: 0})
Expand Down
32 changes: 0 additions & 32 deletions internal/ir/ir.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,38 +356,6 @@ type EqualStmt struct {
Location
}

// LessThanStmt represents a < check of two local variables.
type LessThanStmt struct {
A LocalOrConst
B LocalOrConst

Location
}

// LessThanEqualStmt represents a <= check of two local variables.
type LessThanEqualStmt struct {
A LocalOrConst
B LocalOrConst

Location
}

// GreaterThanStmt represents a > check of two local variables.
type GreaterThanStmt struct {
A LocalOrConst
B LocalOrConst

Location
}

// GreaterThanEqualStmt represents a >= check of two local variables.
type GreaterThanEqualStmt struct {
A LocalOrConst
B LocalOrConst

Location
}

// NotEqualStmt represents a != check of two local variables.
type NotEqualStmt struct {
A LocalOrConst
Expand Down
49 changes: 1 addition & 48 deletions internal/planner/planner.go
Original file line number Diff line number Diff line change
Expand Up @@ -728,54 +728,7 @@ func (p *Planner) planExprCall(e *ast.Expr, iter planiter) error {
switch operator {
case ast.Equality.Name:
return p.planUnify(e.Operand(0), e.Operand(1), iter)
case ast.Equal.Name:
return p.planBinaryExpr(e, func(a, b ir.LocalOrConst) error {
p.appendStmt(&ir.EqualStmt{
A: a,
B: b,
})
return iter()
})
case ast.LessThan.Name:
return p.planBinaryExpr(e, func(a, b ir.LocalOrConst) error {
p.appendStmt(&ir.LessThanStmt{
A: a,
B: b,
})
return iter()
})
case ast.LessThanEq.Name:
return p.planBinaryExpr(e, func(a, b ir.LocalOrConst) error {
p.appendStmt(&ir.LessThanEqualStmt{
A: a,
B: b,
})
return iter()
})
case ast.GreaterThan.Name:
return p.planBinaryExpr(e, func(a, b ir.LocalOrConst) error {
p.appendStmt(&ir.GreaterThanStmt{
A: a,
B: b,
})
return iter()
})
case ast.GreaterThanEq.Name:
return p.planBinaryExpr(e, func(a, b ir.LocalOrConst) error {
p.appendStmt(&ir.GreaterThanEqualStmt{
A: a,
B: b,
})
return iter()
})
case ast.NotEqual.Name:
return p.planBinaryExpr(e, func(a, b ir.LocalOrConst) error {
p.appendStmt(&ir.NotEqualStmt{
A: a,
B: b,
})
return iter()
})

default:

var relation bool
Expand Down
15 changes: 9 additions & 6 deletions internal/planner/planner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -463,8 +463,8 @@ p[1] {
}
`},
exps: map[ir.Stmt]string{
&ir.GreaterThanStmt{}: "module-0.rego:4:3: 1 > 2",
&ir.SetAddStmt{}: "module-0.rego:3:1: p[1]",
&ir.CallStmt{}: "module-0.rego:4:3: 1 > 2",
&ir.SetAddStmt{}: "module-0.rego:3:1: p[1]",
},
where: funcs,
},
Expand Down Expand Up @@ -511,9 +511,10 @@ a = { "a": "b" |
1 > 0
}`},
exps: map[ir.Stmt]string{
&ir.GreaterThanStmt{}: "module-0.rego:3:3: 1 > 0",
&ir.CallStmt{}: "module-0.rego:3:3: 1 > 0",
&ir.ObjectInsertOnceStmt{}: "module-0.rego:2:5: { \"a\": \"b\" |\n 1 > 0\n}",
},
where: funcs,
},
{
note: "array comprehension in policy",
Expand All @@ -524,9 +525,10 @@ a = [ "a" |
1 > 0
]`},
exps: map[ir.Stmt]string{
&ir.GreaterThanStmt{}: "module-0.rego:3:3: 1 > 0",
&ir.CallStmt{}: "module-0.rego:3:3: 1 > 0",
&ir.ArrayAppendStmt{}: "module-0.rego:2:5: [ \"a\" |\n 1 > 0\n]",
},
where: funcs,
},
{
note: "set comprehension in policy",
Expand All @@ -537,9 +539,10 @@ a = { "a" |
1 > 0
}`},
exps: map[ir.Stmt]string{
&ir.GreaterThanStmt{}: "module-0.rego:3:3: 1 > 0",
&ir.SetAddStmt{}: "module-0.rego:2:5: { \"a\" |\n 1 > 0\n}",
&ir.CallStmt{}: "module-0.rego:3:3: 1 > 0",
&ir.SetAddStmt{}: "module-0.rego:2:5: { \"a\" |\n 1 > 0\n}",
},
where: funcs,
},
{
note: "set in policy",
Expand Down
66 changes: 66 additions & 0 deletions test/wasm/assets/018_builtins.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,70 @@
cases:
- note: equal built-in (true)
query: equal(1,1,x)
want_result:
- x: true
- # NOTE: This is a property of the planner/compiler interaction, more so
# than it is a property of the builtin implementation. Therefore, we only
# assert it once for equal/true, equal/false; instead of duplicating all
# the other comparison test cases.
note: equal built-in (true, result not captured)
query: equal(1,1)
want_defined: true
- note: equal built-in (false)
query: equal(1,2,x)
want_result:
- x: false
- note: equal built-in (false, result not captured)
query: equal(1,2)
want_defined: false
- note: gt built-in (true)
query: gt(1,0,x)
want_result:
- x: true
- note: gt built-in (false)
query: gt(1,2,x)
want_result:
- x: false
- note: gte built-in (true)
query: gte(1,0,x)
want_result:
- x: true
- note: gte built-in (true, equal)
query: gte(1,1,x)
want_result:
- x: true
- note: gte built-in (false)
query: gte(1,2,x)
want_result:
- x: false
- note: lt built-in (true)
query: lt(0,1,x)
want_result:
- x: true
- note: lt built-in (false)
query: lt(2,1,x)
want_result:
- x: false
- note: lte built-in (true)
query: lte(0,1,x)
want_result:
- x: true
- note: lte built-in (true, equal)
query: lte(0,0,x)
want_result:
- x: true
- note: lte built-in (false)
query: lte(2,1,x)
want_result:
- x: false
- note: neq built-in (true)
query: neq(0,1,x)
want_result:
- x: true
- note: neq built-in (false)
query: neq(1,1,x)
want_result:
- x: false
- note: abs built-in
query: abs(-1,x)
want_result: [{'x': 1}]
Expand Down
38 changes: 38 additions & 0 deletions wasm/src/comparisons.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#include "value.h"
#include "comparisons.h"

OPA_BUILTIN
opa_value *opa_cmp_eq(opa_value *a, opa_value *b)
{
return opa_boolean(opa_value_compare(a, b) == 0);
}

OPA_BUILTIN
opa_value *opa_cmp_neq(opa_value *a, opa_value *b)
{
return opa_boolean(opa_value_compare(a, b) != 0);
}

OPA_BUILTIN
opa_value *opa_cmp_gt(opa_value *a, opa_value *b)
{
return opa_boolean(opa_value_compare(a, b) > 0);
}

OPA_BUILTIN
opa_value *opa_cmp_gte(opa_value *a, opa_value *b)
{
return opa_boolean(opa_value_compare(a, b) >= 0);
}

OPA_BUILTIN
opa_value *opa_cmp_lt(opa_value *a, opa_value *b)
{
return opa_boolean(opa_value_compare(a, b) < 0);
}

OPA_BUILTIN
opa_value *opa_cmp_lte(opa_value *a, opa_value *b)
{
return opa_boolean(opa_value_compare(a, b) <= 0);
}
8 changes: 8 additions & 0 deletions wasm/src/comparisons.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#include "value.h"

opa_value *opa_cmp_eq(opa_value *a, opa_value *b);
opa_value *opa_cmp_neq(opa_value *a, opa_value *b);
opa_value *opa_cmp_gt(opa_value *a, opa_value *b);
opa_value *opa_cmp_gte(opa_value *a, opa_value *b);
opa_value *opa_cmp_lt(opa_value *a, opa_value *b);
opa_value *opa_cmp_lte(opa_value *a, opa_value *b);

0 comments on commit 8f7185d

Please sign in to comment.