Skip to content

Commit

Permalink
ast: Adding annotation override support to all annotation types (open…
Browse files Browse the repository at this point in the history
…-policy-agent#4370)

Fixes: open-policy-agent#4331
Signed-off-by: Johan Fylling <[email protected]>
  • Loading branch information
johanfylling authored Feb 22, 2022
1 parent 9afdad7 commit 1d1eb4d
Show file tree
Hide file tree
Showing 10 changed files with 1,277 additions and 595 deletions.
550 changes: 550 additions & 0 deletions ast/annotations.go

Large diffs are not rendered by default.

507 changes: 507 additions & 0 deletions ast/annotations_test.go

Large diffs are not rendered by default.

173 changes: 5 additions & 168 deletions ast/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,16 +136,8 @@ func (tc *typeChecker) CheckBody(env *TypeEnv, body Body) (*TypeEnv, Errors) {
// CheckTypes runs type checking on the rules returns a TypeEnv if no errors
// are found. The resulting TypeEnv wraps the provided one. The resulting
// TypeEnv will be able to resolve types of refs that refer to rules.
func (tc *typeChecker) CheckTypes(env *TypeEnv, sorted []util.T) (*TypeEnv, Errors) {
func (tc *typeChecker) CheckTypes(env *TypeEnv, sorted []util.T, as *annotationSet) (*TypeEnv, Errors) {
env = tc.newEnv(env)
var as *annotationSet
if tc.ss != nil {
var errs Errors
as, errs = buildAnnotationSet(sorted)
if len(errs) > 0 {
return env, errs
}
}
for _, s := range sorted {
tc.checkRule(env, as, s.(*Rule))
}
Expand Down Expand Up @@ -1172,19 +1164,19 @@ func getObjectType(ref Ref, o types.Type, rule *Rule, d *types.DynamicProperty)

func getRuleAnnotation(as *annotationSet, rule *Rule) (result []*SchemaAnnotation) {

for _, x := range as.GetSubpackagesScope(rule.Module.Package.Path) {
for _, x := range as.getSubpackagesScope(rule.Module.Package.Path) {
result = append(result, x.Schemas...)
}

if x := as.GetPackageScope(rule.Module.Package); x != nil {
if x := as.getPackageScope(rule.Module.Package); x != nil {
result = append(result, x.Schemas...)
}

if x := as.GetDocumentScope(rule.Path()); x != nil {
if x := as.getDocumentScope(rule.Path()); x != nil {
result = append(result, x.Schemas...)
}

for _, x := range as.GetRuleScope(rule) {
for _, x := range as.getRuleScope(rule) {
result = append(result, x.Schemas...)
}

Expand Down Expand Up @@ -1215,158 +1207,3 @@ func processAnnotation(ss *SchemaSet, annot *SchemaAnnotation, rule *Rule, allow
func errAnnotationRedeclared(a *Annotations, other *Location) *Error {
return NewError(TypeErr, a.Location, "%v annotation redeclared: %v", a.Scope, other)
}

type annotationSet struct {
byRule map[*Rule][]*Annotations
byPackage map[*Package]*Annotations
byPath *annotationTreeNode
}

func buildAnnotationSet(rules []util.T) (*annotationSet, Errors) {
as := newAnnotationSet()
processed := map[*Module]struct{}{}
var errs Errors
for _, x := range rules {
module := x.(*Rule).Module
if _, ok := processed[module]; ok {
continue
}
processed[module] = struct{}{}
for _, a := range module.Annotations {
if err := as.Add(a); err != nil {
errs = append(errs, err)
}
}
}
if len(errs) > 0 {
return nil, errs
}
return as, nil
}

func newAnnotationSet() *annotationSet {
return &annotationSet{
byRule: map[*Rule][]*Annotations{},
byPackage: map[*Package]*Annotations{},
byPath: newAnnotationTree(),
}
}

func (as *annotationSet) Add(a *Annotations) *Error {
switch a.Scope {
case annotationScopeRule:
rule := a.node.(*Rule)
as.byRule[rule] = append(as.byRule[rule], a)
case annotationScopePackage:
pkg := a.node.(*Package)
if exist, ok := as.byPackage[pkg]; ok {
return errAnnotationRedeclared(a, exist.Location)
}
as.byPackage[pkg] = a
case annotationScopeDocument:
rule := a.node.(*Rule)
path := rule.Path()
x := as.byPath.Get(path)
if x != nil {
return errAnnotationRedeclared(a, x.Value.Location)
}
as.byPath.Insert(path, a)
case annotationScopeSubpackages:
pkg := a.node.(*Package)
x := as.byPath.Get(pkg.Path)
if x != nil {
return errAnnotationRedeclared(a, x.Value.Location)
}
as.byPath.Insert(pkg.Path, a)
}
return nil
}

func (as *annotationSet) GetRuleScope(r *Rule) []*Annotations {
if as == nil {
return nil
}
return as.byRule[r]
}

func (as *annotationSet) GetSubpackagesScope(path Ref) []*Annotations {
if as == nil {
return nil
}
return as.byPath.Ancestors(path)
}

func (as *annotationSet) GetDocumentScope(path Ref) *Annotations {
if as == nil {
return nil
}
if node := as.byPath.Get(path); node != nil {
return node.Value
}
return nil
}

func (as *annotationSet) GetPackageScope(pkg *Package) *Annotations {
if as == nil {
return nil
}
return as.byPackage[pkg]
}

type annotationTreeNode struct {
Value *Annotations
Children map[Value]*annotationTreeNode // we assume key elements are hashable (vars and strings only!)
}

func newAnnotationTree() *annotationTreeNode {
return &annotationTreeNode{
Value: nil,
Children: map[Value]*annotationTreeNode{},
}
}

func (t *annotationTreeNode) Insert(path Ref, value *Annotations) {
node := t
for _, k := range path {
child, ok := node.Children[k.Value]
if !ok {
child = newAnnotationTree()
node.Children[k.Value] = child
}
node = child
}
node.Value = value
}

func (t *annotationTreeNode) Get(path Ref) *annotationTreeNode {
node := t
for _, k := range path {
if node == nil {
return nil
}
child, ok := node.Children[k.Value]
if !ok {
return nil
}
node = child
}
return node
}

func (t *annotationTreeNode) Ancestors(path Ref) (result []*Annotations) {
node := t
for _, k := range path {
if node == nil {
return result
}
child, ok := node.Children[k.Value]
if !ok {
return result
}
if child.Value != nil {
result = append(result, child.Value)
}
node = child
}
return result
}
23 changes: 16 additions & 7 deletions ast/check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,8 @@ func TestCheckInferenceRules(t *testing.T) {

ref := MustParseRef(tc.ref)
checker := newTypeChecker()
env, err := checker.CheckTypes(nil, elems)
as, _ := buildAnnotationSet(elems)
env, err := checker.CheckTypes(nil, elems, as)

if err != nil {
t.Fatalf("Unexpected error %v:", err)
Expand Down Expand Up @@ -1020,9 +1021,12 @@ func TestFunctionTypeInferenceUnappliedWithObjectVarKey(t *testing.T) {
f(x) = y { y = {x: 1} }
`)

env, err := newTypeChecker().CheckTypes(newTypeChecker().Env(BuiltinMap), []util.T{
elems := []util.T{
module.Rules[0],
})
}

as, _ := buildAnnotationSet(elems)
env, err := newTypeChecker().CheckTypes(newTypeChecker().Env(BuiltinMap), elems, as)

if len(err) > 0 {
t.Fatal(err)
Expand Down Expand Up @@ -1218,8 +1222,10 @@ func TestCheckErrorOrdering(t *testing.T) {
inputReversed[1] = inputReversed[2]
inputReversed[2] = tmp

_, errs1 := newTypeChecker().CheckTypes(nil, input)
_, errs2 := newTypeChecker().CheckTypes(nil, inputReversed)
as, _ := buildAnnotationSet(input)
_, errs1 := newTypeChecker().CheckTypes(nil, input, as)
asReversed, _ := buildAnnotationSet(inputReversed)
_, errs2 := newTypeChecker().CheckTypes(nil, inputReversed, asReversed)

if errs1.Error() != errs2.Error() {
t.Fatalf("Expected error slices to be equal. errs1:\n\n%v\n\nerrs2:\n\n%v\n\n", errs1, errs2)
Expand Down Expand Up @@ -1265,7 +1271,8 @@ func newTestEnv(rs []string) *TypeEnv {
}
}

env, err := newTypeChecker().CheckTypes(newTypeChecker().Env(BuiltinMap), elems)
as, _ := buildAnnotationSet(elems)
env, err := newTypeChecker().CheckTypes(newTypeChecker().Env(BuiltinMap), elems, as)
if len(err) > 0 {
panic(err)
}
Expand Down Expand Up @@ -1983,7 +1990,9 @@ p { input = "foo" }`},
}

oldTypeEnv := newTypeChecker().WithSchemaSet(schemaSet).Env(BuiltinMap)
typeenv, errors := newTypeChecker().WithSchemaSet(schemaSet).CheckTypes(oldTypeEnv, elems)
as, errors := buildAnnotationSet(elems)
typeenv, checkErrors := newTypeChecker().WithSchemaSet(schemaSet).CheckTypes(oldTypeEnv, elems, as)
errors = append(errors, checkErrors...)
if len(errors) > 0 {
for _, e := range errors {
if tc.err == "" || !strings.Contains(e.Error(), tc.err) {
Expand Down
52 changes: 50 additions & 2 deletions ast/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ type Compiler struct {
debug debug.Debug // emits debug information produced during compilation
schemaSet *SchemaSet // user-supplied schemas for input and data documents
inputType types.Type // global input type retrieved from schema set
annotationSet *annotationSet // hierarchical set of annotations
strict bool // enforce strict compilation checks
}

Expand Down Expand Up @@ -284,7 +285,8 @@ func NewCompiler() *Compiler {
{"RewriteEquals", "compile_stage_rewrite_equals", c.rewriteEquals},
{"RewriteDynamicTerms", "compile_stage_rewrite_dynamic_terms", c.rewriteDynamicTerms},
{"CheckRecursion", "compile_stage_check_recursion", c.checkRecursion},
{"CheckTypes", "compile_stage_check_types", c.checkTypes},
{"SetAnnotationSet", "compile_stage_set_annotationset", c.setAnnotationSet}, // must be run after CheckRecursion
{"CheckTypes", "compile_stage_check_types", c.checkTypes}, // must be run after CheckRecursion
{"CheckUnsafeBuiltins", "compile_state_check_unsafe_builtins", c.checkUnsafeBuiltins},
{"CheckDeprecatedBuiltins", "compile_state_check_deprecated_builtins", c.checkDeprecatedBuiltins},
{"BuildRuleIndices", "compile_stage_rebuild_indices", c.buildRuleIndices},
Expand Down Expand Up @@ -536,6 +538,42 @@ func (c *Compiler) GetRulesWithPrefix(ref Ref) (rules []*Rule) {
return rules
}

func (c *Compiler) GetPackageAnnotations(pkg *Package) *Annotations {
as := c.annotationSet

subPkgAnnot := as.getSubpackagesScope(pkg.Path)

result := make([]*Annotations, 0, len(subPkgAnnot)+1)

result = append(result, subPkgAnnot...)

if x := as.getPackageScope(pkg); x != nil {
result = append(result, x)
}

return mergeAnnotationsList(result)
}

func (c *Compiler) GetRuleAnnotations(rule *Rule) *Annotations {
as := c.annotationSet

ruleAnnot := as.getRuleScope(rule)

result := make([]*Annotations, 0, len(ruleAnnot)+2)

if a := c.GetPackageAnnotations(rule.Module.Package); a != nil {
result = append(result, a)
}

if a := as.getDocumentScope(rule.Path()); a != nil {
result = append(result, a)
}

result = append(result, ruleAnnot...)

return mergeAnnotationsList(result)
}

func extractRules(s []util.T) (rules []*Rule) {
for _, r := range s {
rules = append(rules, r.(*Rule))
Expand Down Expand Up @@ -1151,6 +1189,16 @@ func parseSchema(schema interface{}) (types.Type, error) {
return types.A, nil
}

func (c *Compiler) setAnnotationSet() {
// Recursion is caught in earlier step, so this cannot fail.
sorted, _ := c.Graph.Sort()
as, errs := buildAnnotationSet(sorted)
for _, err := range errs {
c.err(err)
}
c.annotationSet = as
}

// checkTypes runs the type checker on all rules. The type checker builds a
// TypeEnv that is stored on the compiler.
func (c *Compiler) checkTypes() {
Expand All @@ -1160,7 +1208,7 @@ func (c *Compiler) checkTypes() {
WithSchemaSet(c.schemaSet).
WithInputType(c.inputType).
WithVarRewriter(rewriteVarsInRef(c.RewrittenVars))
env, errs := checker.CheckTypes(c.TypeEnv, sorted)
env, errs := checker.CheckTypes(c.TypeEnv, sorted, c.annotationSet)
for _, err := range errs {
c.err(err)
}
Expand Down
8 changes: 0 additions & 8 deletions ast/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,14 +170,6 @@ func (p *Parser) WithCapabilities(c *Capabilities) *Parser {
return p
}

const (
annotationScopePackage = "package"
annotationScopeImport = "import"
annotationScopeRule = "rule"
annotationScopeDocument = "document"
annotationScopeSubpackages = "subpackages"
)

func (p *Parser) parsedTermCacheLookup() (*Term, *state) {
l := p.s.loc.Offset
// stop comparing once the cached offsets are lower than l
Expand Down
Loading

0 comments on commit 1d1eb4d

Please sign in to comment.