Skip to content

Commit

Permalink
Update format package to tolerate nil locations
Browse files Browse the repository at this point in the history
Previously, the format package would return an error if any of the AST
nodes under the input were missing a location value. When the format
package was first implemented, the main use case was formatting policies
that people had written manually--which means they are provided to OPA
as files/raw strings. As a result, it made sense to treat a missing
location as an error condition because it simplifies the formatting
implementation.

However, when policies are generated (e.g., by partial evaluation) the
AST nodes do not typically carry locations. As a result, these AST nodes
cannot be formatted nicely.

These changes modify the format package to tolerate nil location values.
If a nil location value is encountered, the format package will set the
location value on the AST node to a default location, currently row 1
column 1 with text from the AST node's string representation.

Signed-off-by: Torin Sandall <[email protected]>
  • Loading branch information
tsandall committed Aug 30, 2018
1 parent 47456c6 commit e04365e
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 96 deletions.
103 changes: 102 additions & 1 deletion ast/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,17 @@ func IsKeyword(s string) bool {
}

type (
// Node represents a node in an AST. Nodes may be statements in a policy module
// or elements of an ad-hoc query, expression, etc.
Node interface {
fmt.Stringer
Loc() *Location
SetLoc(*Location)
}

// Statement represents a single statement in a policy module.
Statement interface {
Loc() *Location
Node
}
)

Expand Down Expand Up @@ -263,9 +271,17 @@ func NewComment(text []byte) *Comment {

// Loc returns the location of the comment in the definition.
func (c *Comment) Loc() *Location {
if c == nil {
return nil
}
return c.Location
}

// SetLoc sets the location on c.
func (c *Comment) SetLoc(loc *Location) {
c.Location = loc
}

func (c *Comment) String() string {
return "#" + string(c.Text)
}
Expand All @@ -290,9 +306,17 @@ func (pkg *Package) Equal(other *Package) bool {

// Loc returns the location of the Package in the definition.
func (pkg *Package) Loc() *Location {
if pkg == nil {
return nil
}
return pkg.Location
}

// SetLoc sets the location on pkg.
func (pkg *Package) SetLoc(loc *Location) {
pkg.Location = loc
}

func (pkg *Package) String() string {
// Omit head as all packages have the DefaultRootDocument prepended at parse time.
path := make(Ref, len(pkg.Path)-1)
Expand Down Expand Up @@ -355,9 +379,17 @@ func (imp *Import) Equal(other *Import) bool {

// Loc returns the location of the Import in the definition.
func (imp *Import) Loc() *Location {
if imp == nil {
return nil
}
return imp.Location
}

// SetLoc sets the location on imp.
func (imp *Import) SetLoc(loc *Location) {
imp.Location = loc
}

// Name returns the variable that is used to refer to the imported virtual
// document. This is the alias if defined otherwise the last element in the
// path.
Expand Down Expand Up @@ -426,9 +458,17 @@ func (rule *Rule) Equal(other *Rule) bool {

// Loc returns the location of the Rule in the definition.
func (rule *Rule) Loc() *Location {
if rule == nil {
return nil
}
return rule.Location
}

// SetLoc sets the location on rule.
func (rule *Rule) SetLoc(loc *Location) {
rule.Location = loc
}

// Path returns a ref referring to the document produced by this rule. If rule
// is not contained in a module, this function panics.
func (rule *Rule) Path() Ref {
Expand Down Expand Up @@ -588,6 +628,19 @@ func (head *Head) Vars() VarSet {
return vis.vars
}

// Loc returns the Location of head.
func (head *Head) Loc() *Location {
if head == nil {
return nil
}
return head.Location
}

// SetLoc sets the location on head.
func (head *Head) SetLoc(loc *Location) {
head.Location = loc
}

// Copy returns a deep copy of a.
func (a Args) Copy() Args {
cpy := Args{}
Expand All @@ -605,6 +658,21 @@ func (a Args) String() string {
return "(" + strings.Join(buf, ", ") + ")"
}

// Loc returns the Location of a.
func (a Args) Loc() *Location {
if len(a) == 0 {
return nil
}
return a[0].Location
}

// SetLoc sets the location on a.
func (a Args) SetLoc(loc *Location) {
if len(a) != 0 {
a[0].SetLocation(loc)
}
}

// Vars returns a set of vars that appear in a.
func (a Args) Vars() VarSet {
vis := &VarVisitor{vars: VarSet{}}
Expand Down Expand Up @@ -719,6 +787,13 @@ func (body Body) Loc() *Location {
return body[0].Location
}

// SetLoc sets the location on body.
func (body Body) SetLoc(loc *Location) {
if len(body) != 0 {
body[0].SetLocation(loc)
}
}

func (body Body) String() string {
var buf []string
for _, v := range body {
Expand Down Expand Up @@ -957,6 +1032,19 @@ func (expr *Expr) SetLocation(loc *Location) *Expr {
return expr
}

// Loc returns the Location of expr.
func (expr *Expr) Loc() *Location {
if expr == nil {
return nil
}
return expr.Location
}

// SetLoc sets the location on expr.
func (expr *Expr) SetLoc(loc *Location) {
expr.SetLocation(loc)
}

func (expr *Expr) String() string {
var buf []string
if expr.Negated {
Expand Down Expand Up @@ -1048,6 +1136,19 @@ func (w *With) SetLocation(loc *Location) *With {
return w
}

// Loc returns the Location of w.
func (w *With) Loc() *Location {
if w == nil {
return nil
}
return w.Location
}

// SetLoc sets the location on w.
func (w *With) SetLoc(loc *Location) {
w.Location = loc
}

// RuleSet represents a collection of rules that produce a virtual document.
type RuleSet []*Rule

Expand Down
13 changes: 13 additions & 0 deletions ast/term.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,19 @@ func (term *Term) SetLocation(loc *Location) *Term {
return term
}

// Loc returns the Location of term.
func (term *Term) Loc() *Location {
if term == nil {
return nil
}
return term.Location
}

// SetLoc sets the location on term.
func (term *Term) SetLoc(loc *Location) {
term.SetLocation(loc)
}

// Copy returns a deep copy of term.
func (term *Term) Copy() *Term {

Expand Down
12 changes: 12 additions & 0 deletions ast/visit.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,18 @@ func WalkRules(x interface{}, f func(*Rule) bool) {
Walk(vis, x)
}

// WalkNodes calls the function f on all nodes under x. If the function f
// returns true, AST nodes under the last node will not be visited.
func WalkNodes(x interface{}, f func(Node) bool) {
vis := &GenericVisitor{func(x interface{}) bool {
if n, ok := x.(Node); ok {
return f(n)
}
return false
}}
Walk(vis, x)
}

// GenericVisitor implements the Visitor interface to provide
// a utility to walk over AST nodes using a closure. If the closure
// returns true, the visitor will not walk over AST nodes under x.
Expand Down
84 changes: 26 additions & 58 deletions format/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,53 +39,33 @@ func Source(filename string, src []byte) ([]byte, error) {
return formatted, nil
}

// Ast formats a Rego AST element. If the passed value is not a valid AST element,
// Ast returns nil and an error. Ast relies on all AST elements having non-nil
// Location values, and will return an error if this is not the case.
// MustAst is a helper function to format a Rego AST element. If any errors
// occurs this function will panic. This is mostly used for test
func MustAst(x interface{}) []byte {
bs, err := Ast(x)
if err != nil {
panic(err)
}
return bs
}

// Ast formats a Rego AST element. If the passed value is not a valid AST
// element, Ast returns nil and an error. Ast relies on all AST elements having
// non-nil Location values. If an AST element with a nil Location value is
// encountered, a default location will be set on the AST node.
func Ast(x interface{}) (formatted []byte, err error) {
defer func() {
// Ast relies on all terms in the ast element having non-nil Location
// values. If a location is nil, Ast will panic, so we need to recover
// gracefully.
if r := recover(); r != nil {
formatted = nil
switch r := r.(type) {
case nilLocationErr:
err = r
default:
panic(r)
}
}
}()

// Check all elements in the Ast interface have a location.
ast.Walk(ast.NewGenericVisitor(func(x interface{}) bool {
switch x := x.(type) {
case *ast.Module, ast.Value: // Pass, they don't have locations.
case ast.Body:
// Empty bodies do not have a location. The body formatting implementation must
// handle this.
if len(x) > 0 {
assertHasLocation(x)
}
case *ast.Term:
switch v := x.Value.(type) {
case ast.Ref:
if h := v[0]; !ast.RootDocumentNames.Contains(h) {
assertHasLocation(x)
}
case ast.Var:
if vt := ast.VarTerm(string(v)); !ast.RootDocumentNames.Contains(vt) {
assertHasLocation(x)
}
default:
assertHasLocation(x)
ast.WalkNodes(x, func(x ast.Node) bool {
if b, ok := x.(ast.Body); ok {
if len(b) == 0 {
return false
}
case *ast.Package, *ast.Import, *ast.Rule, *ast.Head, *ast.Expr, *ast.With, *ast.Comment:
assertHasLocation(x)
}
if x.Loc() == nil {
x.SetLoc(defaultLocation(x))
}
return false
}), x)
})

w := &writer{indent: "\t"}
switch x := x.(type) {
Expand Down Expand Up @@ -117,6 +97,10 @@ func Ast(x interface{}) (formatted []byte, err error) {
return w.buf.Bytes(), nil
}

func defaultLocation(x ast.Node) *ast.Location {
return ast.NewLocation([]byte(x.String()), "", 1, 1)
}

type writer struct {
buf bytes.Buffer

Expand Down Expand Up @@ -935,19 +919,3 @@ func (w *writer) down() {
}
w.level--
}

func assertHasLocation(xs ...interface{}) {
for _, x := range xs {
if getLoc(x) == nil {
panic(nilLocationErr{x})
}
}
}

type nilLocationErr struct {
x interface{}
}

func (err nilLocationErr) Error() string {
return fmt.Sprintf("nil location on %T", err.x)
}
15 changes: 10 additions & 5 deletions format/format_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,18 @@ func TestFormatNilLocation(t *testing.T) {
rule := ast.MustParseRule(`r = y { y = "foo" }`)
rule.Head.Location = nil

_, err := Ast(rule)
if err == nil {
t.Fatal("Expected error for rule with nil Location in head")
bs, err := Ast(rule)
if err != nil {
t.Fatal(err)
}

if _, ok := err.(nilLocationErr); !ok {
t.Fatalf("Expected nilLocationErr, got %v", err)
exp := strings.Trim(`
r = y {
y = "foo"
}`, " \n")

if string(bs) != exp {
t.Fatalf("Expected %q but got %q", exp, string(bs))
}
}

Expand Down
Loading

0 comments on commit e04365e

Please sign in to comment.