Skip to content

Commit

Permalink
add compiler.Compiler and compiler.optimizer WithEnablePrintStatement…
Browse files Browse the repository at this point in the history
…s option and pass through to ast.Compiler

When using the compiler.Compiler; rego "print" statements get removed and there is no way on the API to enable them.
So the PrintHook is broken.

Signed-off-by: Kevin St. Pierre <[email protected]>
  • Loading branch information
kevinstyra authored and tsandall committed Jul 22, 2022
1 parent 9f9fbb9 commit 646c841
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 33 deletions.
82 changes: 49 additions & 33 deletions compile/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,24 +57,25 @@ const resultVar = ast.Var("result")

// Compiler implements bundle compilation and linking.
type Compiler struct {
capabilities *ast.Capabilities // the capabilities that compiled policies may require
bundle *bundle.Bundle // the bundle that the compiler operates on
revision *string // the revision to set on the output bundle
asBundle bool // whether to assume bundle layout on file loading or not
filter loader.Filter // filter to apply to file loader
paths []string // file paths to load. TODO(tsandall): add support for supplying readers for embedded users.
entrypoints orderedStringSet // policy entrypoints required for optimization and certain targets
optimizationLevel int // how aggressive should optimization be
target string // target type (wasm, rego, etc.)
output *io.Writer // output stream to write bundle to
entrypointrefs []*ast.Term // validated entrypoints computed from default decision or manually supplied entrypoints
compiler *ast.Compiler // rego ast compiler used for semantic checks and rewriting
policy *ir.Policy // planner output when wasm or plan targets are enabled
debug debug.Debug // optionally outputs debug information produced during build
bvc *bundle.VerificationConfig // represents the key configuration used to verify a signed bundle
bsc *bundle.SigningConfig // represents the key configuration used to generate a signed bundle
keyID string // represents the name of the default key used to verify a signed bundle
metadata *map[string]interface{} // represents additional data included in .manifest file
capabilities *ast.Capabilities // the capabilities that compiled policies may require
bundle *bundle.Bundle // the bundle that the compiler operates on
revision *string // the revision to set on the output bundle
asBundle bool // whether to assume bundle layout on file loading or not
filter loader.Filter // filter to apply to file loader
paths []string // file paths to load. TODO(tsandall): add support for supplying readers for embedded users.
entrypoints orderedStringSet // policy entrypoints required for optimization and certain targets
optimizationLevel int // how aggressive should optimization be
target string // target type (wasm, rego, etc.)
output *io.Writer // output stream to write bundle to
entrypointrefs []*ast.Term // validated entrypoints computed from default decision or manually supplied entrypoints
compiler *ast.Compiler // rego ast compiler used for semantic checks and rewriting
policy *ir.Policy // planner output when wasm or plan targets are enabled
debug debug.Debug // optionally outputs debug information produced during build
enablePrintStatements bool // optionally enable rego print statements
bvc *bundle.VerificationConfig // represents the key configuration used to verify a signed bundle
bsc *bundle.SigningConfig // represents the key configuration used to generate a signed bundle
keyID string // represents the name of the default key used to verify a signed bundle
metadata *map[string]interface{} // represents additional data included in .manifest file
}

// New returns a new compiler instance that can be invoked.
Expand Down Expand Up @@ -135,6 +136,14 @@ func (c *Compiler) WithDebug(sink io.Writer) *Compiler {
return c
}

// WithEnablePrintStatements enables print statements inside of modules compiled
// by the compiler. If print statements are not enabled, calls to print() are
// erased at compile-time.
func (c *Compiler) WithEnablePrintStatements(yes bool) *Compiler {
c.enablePrintStatements = yes
return c
}

// WithPaths adds input filepaths to read policy and data from.
func (c *Compiler) WithPaths(p ...string) *Compiler {
c.paths = append(c.paths, p...)
Expand Down Expand Up @@ -374,14 +383,15 @@ func (c *Compiler) optimize(ctx context.Context) error {

if c.optimizationLevel <= 0 {
var err error
c.compiler, err = compile(c.capabilities, c.bundle, c.debug)
c.compiler, err = compile(c.capabilities, c.bundle, c.debug, c.enablePrintStatements)
return err
}

o := newOptimizer(c.capabilities, c.bundle).
WithEntrypoints(c.entrypointrefs).
WithDebug(c.debug.Writer()).
WithShallowInlining(c.optimizationLevel <= 1)
WithShallowInlining(c.optimizationLevel <= 1).
WithEnablePrintStatements(c.enablePrintStatements)

err := o.Do(ctx)
if err != nil {
Expand All @@ -399,7 +409,7 @@ func (c *Compiler) compilePlan(ctx context.Context) error {
// AST compiler will not be set because the default target does not require it.
if c.compiler == nil {
var err error
c.compiler, err = compile(c.capabilities, c.bundle, c.debug)
c.compiler, err = compile(c.capabilities, c.bundle, c.debug, c.enablePrintStatements)
if err != nil {
return err
}
Expand Down Expand Up @@ -615,15 +625,16 @@ func (err undefinedEntrypointErr) Error() string {
}

type optimizer struct {
capabilities *ast.Capabilities
bundle *bundle.Bundle
compiler *ast.Compiler
entrypoints []*ast.Term
nsprefix string
resultsymprefix string
outputprefix string
shallow bool
debug debug.Debug
capabilities *ast.Capabilities
bundle *bundle.Bundle
compiler *ast.Compiler
entrypoints []*ast.Term
nsprefix string
resultsymprefix string
outputprefix string
shallow bool
debug debug.Debug
enablePrintStatements bool
}

func newOptimizer(c *ast.Capabilities, b *bundle.Bundle) *optimizer {
Expand All @@ -644,6 +655,11 @@ func (o *optimizer) WithDebug(sink io.Writer) *optimizer {
return o
}

func (o *optimizer) WithEnablePrintStatements(yes bool) *optimizer {
o.enablePrintStatements = yes
return o
}

func (o *optimizer) WithEntrypoints(es []*ast.Term) *optimizer {
o.entrypoints = es
return o
Expand Down Expand Up @@ -683,7 +699,7 @@ func (o *optimizer) Do(ctx context.Context) error {
for i, e := range o.entrypoints {

var err error
o.compiler, err = compile(o.capabilities, o.bundle, o.debug)
o.compiler, err = compile(o.capabilities, o.bundle, o.debug, o.enablePrintStatements)
if err != nil {
return err
}
Expand Down Expand Up @@ -939,7 +955,7 @@ func (o *optimizer) getSupportModuleFilename(used map[string]int, module *ast.Mo

var safePathPattern = regexp.MustCompile(`^[\w-_/]+$`)

func compile(c *ast.Capabilities, b *bundle.Bundle, dbg debug.Debug) (*ast.Compiler, error) {
func compile(c *ast.Capabilities, b *bundle.Bundle, dbg debug.Debug, enablePrintStatements bool) (*ast.Compiler, error) {

modules := map[string]*ast.Module{}

Expand All @@ -951,7 +967,7 @@ func compile(c *ast.Capabilities, b *bundle.Bundle, dbg debug.Debug) (*ast.Compi
modules[mf.URL] = mf.Parsed
}

compiler := ast.NewCompiler().WithCapabilities(c).WithDebug(dbg.Writer())
compiler := ast.NewCompiler().WithCapabilities(c).WithDebug(dbg.Writer()).WithEnablePrintStatements(enablePrintStatements)
compiler.Compile(modules)

if compiler.Failed() {
Expand Down
58 changes: 58 additions & 0 deletions compile/compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,64 @@ func TestCompilerLoadFilesystem(t *testing.T) {
})
}

func TestCompilerLoadFilesystemWithEnablePrintStatementsFalse(t *testing.T) {
files := map[string]string{
"test.rego": `
package test
allow { print(1) }
`,
"data.json": `
{"b1": {"k": "v"}}`,
}

test.WithTempFS(files, func(root string) {
compiler := New().
WithPaths(root).
WithTarget("plan").WithEntrypoints("test/allow").
WithEnablePrintStatements(false)

if err := compiler.Build(context.Background()); err != nil {
t.Fatal(err)
}

bundle := compiler.Bundle()

if strings.Contains(string(bundle.PlanModules[0].Raw), "internal.print") {
t.Fatalf("output different than expected:\n\ngot: %v\n\nfound: internal.print", string(bundle.PlanModules[0].Raw))
}
})
}

func TestCompilerLoadFilesystemWithEnablePrintStatementsTrue(t *testing.T) {
files := map[string]string{
"test.rego": `
package test
allow { print(1) }
`,
"data.json": `
{"b1": {"k": "v"}}`,
}

test.WithTempFS(files, func(root string) {
compiler := New().
WithPaths(root).
WithTarget("plan").WithEntrypoints("test/allow").
WithEnablePrintStatements(true)

if err := compiler.Build(context.Background()); err != nil {
t.Fatal(err)
}

bundle := compiler.Bundle()

if !strings.Contains(string(bundle.PlanModules[0].Raw), "internal.print") {
t.Fatalf("output different than expected:\n\ngot: %v\n\nmissing: internal.print", string(bundle.PlanModules[0].Raw))
}
})
}

func TestCompilerLoadHonorsFilter(t *testing.T) {
files := map[string]string{
"test.rego": `
Expand Down

0 comments on commit 646c841

Please sign in to comment.