Skip to content

Commit

Permalink
Refactor callers to support calls as values
Browse files Browse the repository at this point in the history
  • Loading branch information
tsandall committed Jan 27, 2018
1 parent 02e6868 commit 082e445
Show file tree
Hide file tree
Showing 8 changed files with 214 additions and 126 deletions.
80 changes: 51 additions & 29 deletions format/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"sort"

"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/types"
)

// Bytes formats Rego source code. The bytes provided do not have to be an entire
Expand Down Expand Up @@ -332,37 +331,33 @@ func (w *writer) writeFunctionCall(expr *ast.Expr, comments []*ast.Comment) []*a

terms := expr.Terms.([]*ast.Term)

if expr.Infix {
name := terms[0].Value.String()
if bi, ok := ast.BuiltinMap[name]; ok {
if types.Compare(bi.Decl.Result(), types.T) == 0 {
// Handle relational operators (=, !=, >, etc.)
comments = w.writeTerm(terms[1], comments)
w.write(" " + string(bi.Infix) + " ")
return w.writeTerm(terms[2], comments)
} else if bi.Infix != "" {
// Handle arithmetic operators (+, *, &, etc.)
comments = w.writeTerm(terms[3], comments)
w.write(" = ")
comments = w.writeTerm(terms[1], comments)
w.write(" " + string(bi.Infix) + " ")
return w.writeTerm(terms[2], comments)
}
}
comments = w.writeTerm(terms[len(terms)-1], comments)
w.write(" = " + string(terms[0].String()) + "(")
for i := 1; ; i++ {
comments = w.writeTerm(terms[i], comments)
if i < len(terms)-2 {
w.write(", ")
} else {
w.write(")")
break
}
}
bi, ok := ast.BuiltinMap[terms[0].Value.String()]
if !ok || bi.Infix == "" {
return w.writeFunctionCallPlain(terms, comments)
}

numDeclArgs := len(bi.Decl.Args())
numCallArgs := len(terms) - 1

if numCallArgs == numDeclArgs {
// Print infix where result is unassigned (e.g., x != y)
comments = w.writeTerm(terms[1], comments)
w.write(" " + string(bi.Infix) + " ")
return w.writeTerm(terms[2], comments)
} else if numCallArgs == numDeclArgs+1 {
// Print infix where result is assigned (e.g., z = x + y)
comments = w.writeTerm(terms[3], comments)
w.write(" " + ast.Equality.Infix + " ")
comments = w.writeTerm(terms[1], comments)
w.write(" " + bi.Infix + " ")
comments = w.writeTerm(terms[2], comments)
return comments
}

return w.writeFunctionCallPlain(terms, comments)
}

func (w *writer) writeFunctionCallPlain(terms []*ast.Term, comments []*ast.Comment) []*ast.Comment {
w.write(string(terms[0].String()) + "(")
for _, v := range terms[1 : len(terms)-1] {
comments = w.writeTerm(v, comments)
Expand All @@ -382,6 +377,10 @@ func (w *writer) writeWith(with *ast.With, comments []*ast.Comment) []*ast.Comme
}

func (w *writer) writeTerm(term *ast.Term, comments []*ast.Comment) []*ast.Comment {
return w.writeTermParens(false, term, comments)
}

func (w *writer) writeTermParens(parens bool, term *ast.Term, comments []*ast.Comment) []*ast.Comment {
comments = w.insertComments(comments, term.Location)
if !w.inline {
w.startLine()
Expand Down Expand Up @@ -411,6 +410,8 @@ func (w *writer) writeTerm(term *ast.Term, comments []*ast.Comment) []*ast.Comme
// not what x.String() would give us.
w.write(string(term.Location.Text))
}
case ast.Call:
comments = w.writeCall(parens, x, term.Location, comments)
case fmt.Stringer:
w.write(x.String())
}
Expand All @@ -421,6 +422,27 @@ func (w *writer) writeTerm(term *ast.Term, comments []*ast.Comment) []*ast.Comme
return comments
}

func (w *writer) writeCall(parens bool, x ast.Call, loc *ast.Location, comments []*ast.Comment) []*ast.Comment {

bi, ok := ast.BuiltinMap[x[0].String()]
if !ok || bi.Infix == "" {
return w.writeFunctionCallPlain([]*ast.Term(x), comments)
}

// TODO(tsandall): improve to consider precedence?
if parens {
w.write("(")
}
comments = w.writeTermParens(true, x[1], comments)
w.write(" " + bi.Infix + " ")
comments = w.writeTermParens(true, x[2], comments)
if parens {
w.write(")")
}

return comments
}

func (w *writer) writeObject(obj ast.Object, loc *ast.Location, comments []*ast.Comment) []*ast.Comment {
w.write("{")
defer w.write("}")
Expand Down
15 changes: 13 additions & 2 deletions format/format_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package format

import (
"bytes"
"fmt"
"io/ioutil"
"path/filepath"
"strings"
Expand Down Expand Up @@ -71,7 +72,7 @@ func TestFormatSource(t *testing.T) {
}

if ln, at := differsAt(formatted, expected); ln != 0 {
t.Fatalf("Expected formatted bytes to equal expected bytes but differed near line %d / byte %d:\n%s", ln, at, formatted)
t.Fatalf("Expected formatted bytes to equal expected bytes but differed near line %d / byte %d (got: %q, expected: %q):\n%s", ln, at, formatted[at], expected[at], prefixWithLineNumbers(formatted))
}

if _, err := ast.ParseModule(rego+".tmp", string(formatted)); err != nil {
Expand All @@ -84,7 +85,7 @@ func TestFormatSource(t *testing.T) {
}

if ln, at := differsAt(formatted, expected); ln != 0 {
t.Fatalf("Expected roundtripped bytes to equal expected bytes but differed near line %d / byte %d:\n%s", ln, at, formatted)
t.Fatalf("Expected roundtripped bytes to equal expected bytes but differed near line %d / byte %d:\n%s", ln, at, prefixWithLineNumbers(formatted))
}

})
Expand All @@ -110,3 +111,13 @@ func differsAt(a, b []byte) (int, int) {
}
return ln, minLen
}

func prefixWithLineNumbers(bs []byte) []byte {
raw := string(bs)
lines := strings.Split(raw, "\n")
format := fmt.Sprintf("%%%dd %%s", len(fmt.Sprint(len(lines)+1)))
for i, line := range lines {
lines[i] = fmt.Sprintf(format, i+1, line)
}
return []byte(strings.Join(lines, "\n"))
}
14 changes: 14 additions & 0 deletions format/testfiles/test.rego
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,20 @@ y = x[_]
}
} # Comment on rule closing brace

nested_infix {
x + 1
x = y + 2
plus(x, 1, 2)
plus(x, 1)
y = f(x)
f(x, y)
y = x + 1 + 2
x = y + # comment
z
x = (a + b) / 2
f((a+b)/2)
}

# more comments!
# more comments!
# more comments!
Expand Down
18 changes: 16 additions & 2 deletions format/testfiles/test.rego.formatted
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ default foo = false

foo[x] {
not x = g
1 = f(x)
z = g(x, "foo")
f(x) = 1
g(x, "foo") = z
}

globals = {
Expand Down Expand Up @@ -178,6 +178,20 @@ p[x] = y {
}
} # Comment on rule closing brace

nested_infix {
x + 1
x = y + 2
2 = x + 1
x + 1
y = f(x)
f(x, y)
y = (x + 1) + 2
x = y + z # comment

x = (a + b) / 2
f((a + b) / 2)
}

# more comments!
# more comments!
# more comments!
Expand Down
Loading

0 comments on commit 082e445

Please sign in to comment.