Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ override builtin functions #5156

Merged
merged 3 commits into from
Feb 20, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions llx/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -860,6 +860,17 @@ func BuiltinFunctionV2(typ types.Type, name string) (*chunkHandlerV2, error) {
func (e *blockExecutor) runBoundFunction(bind *RawData, chunk *Chunk, ref uint64) (*RawData, uint64, error) {
log.Trace().Uint64("ref", ref).Str("id", chunk.Id).Msg("exec> run bound function")

// check if the resource defines the function to allow providers to override
// builtin functions like `length` or any other function
if bind.Type.IsResource() && bind.Value != nil {
rr := bind.Value.(Resource)
resource := e.ctx.runtime.Schema().Lookup(rr.MqlName())
_, _, override := e.ctx.runtime.Schema().FindField(resource, chunk.Id)
if override {
return runResourceFunction(e, bind, chunk, ref)
}
}

fh, err := BuiltinFunctionV2(bind.Type, chunk.Id)
if err == nil {
res, dref, err := fh.f(e, bind, chunk, ref)
Expand All @@ -877,5 +888,6 @@ func (e *blockExecutor) runBoundFunction(bind *RawData, chunk *Chunk, ref uint64
if bind.Type.IsResource() {
return runResourceFunction(e, bind, chunk, ref)
}

return nil, 0, err
}
22 changes: 22 additions & 0 deletions mql/mql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,3 +350,25 @@ func TestDictMethods(t *testing.T) {
},
})
}

func TestBuiltinFunctionOverride(t *testing.T) {
x := testutils.InitTester(testutils.LinuxMock())
x.TestSimple(t, []testutils.SimpleTest{
// This access the resource length property,
// which overrides the builtin function `length`
{
Code: "mos.groups.length",
ResultIndex: 0, Expectation: int64(5),
},
// This calls the native builtin `length` function
{
Code: "mos.groups.list.length",
ResultIndex: 0, Expectation: int64(7),
},
// Same here, builtint `length` function
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: builtin

{
Code: "muser.groups.length",
ResultIndex: 0, Expectation: int64(2),
},
})
}
2 changes: 1 addition & 1 deletion mqlc/builtin_resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func compileResourceDefault(c *compiler, typ types.Type, ref uint64, id string,
}
}

fieldPath, fieldinfos, ok := c.findField(resource, id)
fieldPath, fieldinfos, ok := c.Schema.FindField(resource, id)
if !ok {
addFieldSuggestions(publicFieldsInfo(c, resource), id, c.Result)
return "", errors.New("cannot find field '" + id + "' in resource " + resource.Name)
Expand Down
37 changes: 1 addition & 36 deletions mqlc/mqlc.go
Original file line number Diff line number Diff line change
Expand Up @@ -882,41 +882,6 @@ func filterEmptyExpressions(expressions []*parser.Expression) []*parser.Expressi
return res
}

type fieldPath []string

// TODO: embed this into the Schema LookupField call!
func (c *compiler) findField(resource *resources.ResourceInfo, fieldName string) (fieldPath, []*resources.Field, bool) {
fieldInfo, ok := resource.Fields[fieldName]
if ok {
return fieldPath{fieldName}, []*resources.Field{fieldInfo}, true
}

for _, f := range resource.Fields {
if f.IsEmbedded {
typ := types.Type(f.Type)
nextResource := c.Schema.Lookup(typ.ResourceName())
if nextResource == nil {
continue
}
childFieldPath, childFieldInfos, ok := c.findField(nextResource, fieldName)
if ok {
fp := make(fieldPath, len(childFieldPath)+1)
fieldInfos := make([]*resources.Field, len(childFieldPath)+1)
fp[0] = f.Name
fieldInfos[0] = f
for i, n := range childFieldPath {
fp[i+1] = n
}
for i, f := range childFieldInfos {
fieldInfos[i+1] = f
}
return fp, fieldInfos, true
}
}
}
return nil, nil, false
}

// compile a bound identifier to its binding
// example: user { name } , where name is compiled bound to the user
// it will return false if it cannot bind the identifier
Expand All @@ -933,7 +898,7 @@ func (c *compiler) compileBoundIdentifierWithMqlCtx(id string, binding *variable
return true, types.Nil, errors.New("cannot find resource that is called by '" + id + "' of type " + typ.Label())
}

fieldPath, fieldinfos, ok := c.findField(resource, id)
fieldPath, fieldinfos, ok := c.Schema.FindField(resource, id)
if ok {
fieldinfo := fieldinfos[len(fieldinfos)-1]
c.CompilerConfig.Stats.CallField(resource.Name, fieldinfo)
Expand Down
39 changes: 39 additions & 0 deletions providers-sdk/v1/resources/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,14 @@

package resources

import (
"go.mondoo.com/cnquery/v11/types"
)

type ResourcesSchema interface {
Lookup(resource string) *ResourceInfo
LookupField(resource string, field string) (*ResourceInfo, *Field)
FindField(resource *ResourceInfo, field string) (FieldPath, []*Field, bool)
AllResources() map[string]*ResourceInfo
}

Expand Down Expand Up @@ -117,6 +122,40 @@ func (s *Schema) LookupField(resource string, field string) (*ResourceInfo, *Fie
return res, res.Fields[field]
}

type FieldPath []string

func (s *Schema) FindField(resource *ResourceInfo, field string) (FieldPath, []*Field, bool) {
fieldInfo, ok := resource.Fields[field]
if ok {
return FieldPath{field}, []*Field{fieldInfo}, true
}

for _, f := range resource.Fields {
if f.IsEmbedded {
typ := types.Type(f.Type)
nextResource := s.Lookup(typ.ResourceName())
if nextResource == nil {
continue
}
childFieldPath, childFieldInfos, ok := s.FindField(nextResource, field)
if ok {
fp := make(FieldPath, len(childFieldPath)+1)
fieldInfos := make([]*Field, len(childFieldPath)+1)
fp[0] = f.Name
fieldInfos[0] = f
for i, n := range childFieldPath {
fp[i+1] = n
}
for i, f := range childFieldInfos {
fieldInfos[i+1] = f
}
return fp, fieldInfos, true
}
}
}
return nil, nil, false
}

func (s *Schema) AllResources() map[string]*ResourceInfo {
return s.Resources
}
40 changes: 40 additions & 0 deletions providers-sdk/v1/testutils/mockprovider/resources/all.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
package resources

import (
"fmt"

"go.mondoo.com/cnquery/v11/llx"
"go.mondoo.com/cnquery/v11/providers-sdk/v1/plugin"
)
Expand Down Expand Up @@ -56,3 +58,41 @@ func (c *mqlMuser) dict() (any, error) {
"string2": "👋",
}, nil
}

// This is an example of how we can override builtin functions today, this will have to change to provide
// a better mechanism to do so but for now, this pattern is adopted in multiple providers

// The example overrides the `length` builtin function by creating a custom list resource which
// essentially defers the loading of the actual "groups" (for this example) and provides a new function
// `length` that returns the number of "groups" but in a more performant way.

// groups() just initializes the custom list resource
func (c *mqlMos) groups() (*mqlCustomGroups, error) {
mqlResource, err := CreateResource(c.MqlRuntime, "customGroups", map[string]*llx.RawData{})
return mqlResource.(*mqlCustomGroups), err
}

// list() is where we actually load the real resources, which could be slow in big environments
func (c *mqlCustomGroups) list() ([]interface{}, error) {
res := []interface{}{}
for i := 0; i < 7; i++ {
group, err := CreateResource(c.MqlRuntime, "mgroup", map[string]*llx.RawData{
"name": llx.StringData(fmt.Sprintf("group%d", i+1)),
})
if err != nil {
return res, err
}
res = append(res, group)
}
return res, nil
}

// length() overrides the builtin function, this should be a more performant way to count
// the "groups"
//
// NOTE this length here is different from the builtin one just for testing
func (c *mqlCustomGroups) length() (int64, error) {
// use `c.MqlRuntime.Connection` to get the provider connection
// make performant API call to count resources
return 5, nil
}
13 changes: 13 additions & 0 deletions providers-sdk/v1/testutils/mockprovider/resources/mockprovider.lr
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,16 @@ muser {
mgroup {
name string
}

mos {
// example override builtin func
groups() customGroups
}

// definition of custom list resource
customGroups {
[]mgroup

// overrides builtin function
length() int
}
Loading
Loading