forked from open-policy-agent/opa
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathobject.go
217 lines (187 loc) · 6.02 KB
/
object.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
// Copyright 2020 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.
package topdown
import (
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/internal/ref"
"github.com/open-policy-agent/opa/topdown/builtins"
)
// builtinObjectUnion implements the two-operand object union builtin.
// Both operands must be objects; the result of merging the second into the
// first (via mergeWithOverwrite) is passed to iter.
func builtinObjectUnion(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
	left, err := builtins.ObjectOperand(operands[0].Value, 1)
	if err != nil {
		return err
	}

	right, err := builtins.ObjectOperand(operands[1].Value, 2)
	if err != nil {
		return err
	}

	return iter(ast.NewTerm(mergeWithOverwrite(left, right)))
}
// builtinObjectUnionN implements the n-ary object union builtin.
// The single operand must be an array whose elements are all objects; an
// operand-element error is returned for the first non-object element
// encountered (scanning from the end of the array). The merged object is
// passed to iter.
func builtinObjectUnionN(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
	arr, err := builtins.ArrayOperand(operands[0].Value, 1)
	if err != nil {
		return err
	}

	// Because we need merge-with-overwrite behavior, we can iterate
	// back-to-front, and get a mostly correct set of key assignments that
	// give us the "last assignment wins, with merges" behavior we want.
	// However, if a non-object overwrites an object value anywhere in the
	// chain of assignments for a key, we have to "freeze" that key to
	// prevent accidentally picking up nested objects that could merge with
	// it from earlier in the input array.
	// Example:
	//   Input: [{"a": {"b": 2}}, {"a": 4}, {"a": {"c": 3}}]
	//   Want Output: {"a": {"c": 3}}
	result := ast.NewObject()
	frozenKeys := map[*ast.Term]struct{}{}
	for i := arr.Len() - 1; i >= 0; i-- {
		elem := arr.Elem(i)
		o, ok := elem.Value.(ast.Object)
		if !ok {
			return builtins.NewOperandElementErr(1, arr, elem.Value, "object")
		}
		// The merge helper has no error result, so there is nothing to
		// check after this call.
		mergewithOverwriteInPlace(result, o, frozenKeys)
	}

	return iter(ast.NewTerm(result))
}
// builtinObjectRemove implements the object key-removal builtin.
// The first operand must be an object; the second is an array, set, or
// object supplying the keys to drop. A new object holding every remaining
// key/value pair is passed to iter.
func builtinObjectRemove(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
	source, err := builtins.ObjectOperand(operands[0].Value, 1)
	if err != nil {
		return err
	}

	// Collect the keys that must not appear in the result.
	dropped, err := getObjectKeysParam(operands[1].Value)
	if err != nil {
		return err
	}

	kept := ast.NewObject()
	source.Foreach(func(k *ast.Term, v *ast.Term) {
		if dropped.Contains(k) {
			return
		}
		kept.Insert(k, v)
	})

	return iter(ast.NewTerm(kept))
}
// builtinObjectFilter implements the object key-filter builtin.
// The first operand must be an object; the second is an array, set, or
// object supplying the keys of interest. The object is filtered with
// Object.Filter against a mask built from those keys, and the result is
// passed to iter.
func builtinObjectFilter(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
	source, err := builtins.ObjectOperand(operands[0].Value, 1)
	if err != nil {
		return err
	}

	wanted, err := getObjectKeysParam(operands[1].Value)
	if err != nil {
		return err
	}

	// Object.Filter takes an object as its filter, so map every requested
	// key to a null placeholder value.
	mask := ast.NewObject()
	wanted.Foreach(func(k *ast.Term) {
		mask.Insert(k, ast.NullTerm())
	})

	filtered, err := source.Filter(mask)
	if err != nil {
		return err
	}

	return iter(ast.NewTerm(filtered))
}
// builtinObjectGet implements the object lookup-with-default builtin.
// Operands are: the object, a key (or an array describing a nested path),
// and the fallback value passed to iter when the lookup fails.
func builtinObjectGet(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
	object, err := builtins.ObjectOperand(operands[0].Value, 1)
	if err != nil {
		return err
	}

	fallback := operands[2]

	path, err := builtins.ArrayOperand(operands[1].Value, 2)
	switch {
	case err != nil:
		// Not an array: treat the operand itself as a top-level key.
		if found := object.Get(operands[1]); found != nil {
			return iter(found)
		}
		return iter(fallback)
	case path.Len() == 0:
		// An empty path selects nothing nested, so yield the whole object.
		return iter(operands[0])
	}

	// Turn the array into a ref and look it up inside the object.
	found, err := object.Find(ref.ArrayPath(path))
	if err != nil {
		return iter(fallback)
	}

	return iter(ast.NewTerm(found))
}
// getObjectKeysParam collects key terms from the supplied value into a set.
// Arrays and sets contribute their elements; objects contribute their keys.
// Any other value type yields an operand type error for operand position 2.
func getObjectKeysParam(arrayOrSet ast.Value) (ast.Set, error) {
	collected := ast.NewSet()

	switch src := arrayOrSet.(type) {
	case *ast.Array:
		// The callback never returns an error, so Iter's result is ignored.
		_ = src.Iter(func(elem *ast.Term) error {
			collected.Add(elem)
			return nil
		})
	case ast.Set:
		src.Foreach(func(elem *ast.Term) {
			collected.Add(elem)
		})
	case ast.Object:
		src.Foreach(func(key, _ *ast.Term) {
			collected.Add(key)
		})
	default:
		return nil, builtins.NewOperandTypeErr(2, arrayOrSet, "object", "set", "array")
	}

	return collected, nil
}
// mergeWithOverwrite returns the result of merging objB into objA using
// Object.MergeWith. When a key is present in both and both values are
// objects, the values are merged recursively; otherwise the value from objB
// is used.
func mergeWithOverwrite(objA, objB ast.Object) ast.Object {
	resolve := func(existing, update *ast.Term) (*ast.Term, bool) {
		existingObj, existingIsObj := existing.Value.(ast.Object)
		updateObj, updateIsObj := update.Value.(ast.Object)

		if existingIsObj && updateIsObj {
			// Both sides are objects: descend and merge them.
			return ast.NewTerm(mergeWithOverwrite(existingObj, updateObj)), false
		}

		// Not mergeable: the right-hand value replaces the left-hand one.
		return update, false
	}

	result, _ := objA.MergeWith(objB, resolve)
	return result
}
// mergewithOverwriteInPlace modifies obj with any new keys from other, and
// recursively merges any keys where the values are both objects.
//
// For keys present in both where the values are not both objects, obj's
// existing value is kept and its term is recorded in frozenKeys; a frozen
// value is skipped by later recursive merges.
//
// Object values taken from other are copied before being stored in obj, so
// subsequent in-place merges never write into objects owned by the caller.
func mergewithOverwriteInPlace(obj, other ast.Object, frozenKeys map[*ast.Term]struct{}) {
	other.Foreach(func(k, v *ast.Term) {
		v2 := obj.Get(k)
		// The key doesn't exist in obj yet, so take the value from other.
		if v2 == nil {
			if nested, ok := v.Value.(ast.Object); ok {
				// Copy the nested object: a later recursive merge inserts
				// into whatever object obj holds for this key, and that
				// must not be the original object from other.
				obj.Insert(k, ast.NewTerm(nested.Copy()))
			} else {
				obj.Insert(k, v)
			}
			return
		}
		// The key exists in both. Merge or reject change.
		updateValueObj, updateIsObj := v.Value.(ast.Object)
		originalValueObj, originalIsObj := v2.Value.(ast.Object)
		// Both are objects? Merge recursively.
		if originalIsObj && updateIsObj {
			// Check to make sure that this key isn't frozen before merging.
			if _, frozen := frozenKeys[v2]; !frozen {
				mergewithOverwriteInPlace(originalValueObj, updateValueObj, frozenKeys)
			}
		} else {
			// Else, original value wins. Freeze the key.
			frozenKeys[v2] = struct{}{}
		}
	})
}
func init() {
RegisterBuiltinFunc(ast.ObjectUnion.Name, builtinObjectUnion)
RegisterBuiltinFunc(ast.ObjectUnionN.Name, builtinObjectUnionN)
RegisterBuiltinFunc(ast.ObjectRemove.Name, builtinObjectRemove)
RegisterBuiltinFunc(ast.ObjectFilter.Name, builtinObjectFilter)
RegisterBuiltinFunc(ast.ObjectGet.Name, builtinObjectGet)
}