-
Notifications
You must be signed in to change notification settings - Fork 89
/
Copy pathjax_getitem.py
336 lines (274 loc) · 13.8 KB
/
jax_getitem.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
import jax
import jax.tree_util
import awkward as ak
import numpy as np
from numbers import Integral, Real
# Pin JAX to the CPU backend and turn on 64-bit types so float64 inputs
# are not silently downcast to float32.
for _option, _value in (("jax_platform_name", "cpu"), ("jax_enable_x64", True)):
    jax.config.update(_option, _value)
del _option, _value
class AuxData(object):
    """Static (non-traced) part of a flattened Awkward array pytree.

    ``layout`` is the Awkward layout describing the array's structure, or
    ``None`` when the flattened value is a scalar.  ``special_flatten`` returns
    instances of this class as the pytree aux data.
    """

    def __init__(self, layout):
        self.layout = layout

    def __eq__(self, other):
        """Return True when both sides describe the same structure.

        Two layouts match when their ``form`` objects compare equal.  Two
        scalar entries (``layout is None``) always match, and a scalar never
        matches an array.  Comparing with a non-AuxData object defers to
        Python's default (identity) comparison.
        """
        if not isinstance(other, AuxData):
            return NotImplemented
        if self.layout is None or other.layout is None:
            # Without this guard two scalars compared as None (falsy) and a
            # scalar on the right-hand side raised AttributeError on ``.form``.
            return self.layout is None and other.layout is None
        return self.layout.form == other.layout.form

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None (unhashable).  Hash only
        # the property __eq__ guarantees equal objects share (whether a layout
        # is present), so the hash can never disagree with __eq__.
        return hash((AuxData, self.layout is None))
def find_dataptrs(layout):
    """Return the ``ptr`` of every ``ak.layout.NumpyArray`` node in *layout*.

    Pointers are listed in the order ``ak._util.recursive_walk`` visits the
    nodes; ``special_flatten`` walks layouts with the same function.
    """
    pointers = []

    def collect(node, depth, sink):
        if isinstance(node, ak.layout.NumpyArray):
            sink.append(node.ptr)

    ak._util.recursive_walk(layout, collect, args=(pointers,))
    return pointers
def _identity_key(layout):
    """Return ``(fieldloc, identity width)`` for a NumpyArray that has identities.

    This pair is what the slicing code uses to match a sliced NumpyArray with
    the node of the pre-slice layout it came from.
    """
    identities = layout.identities
    return (identities.fieldloc, np.asarray(identities).shape[1])


def _find_nparray_node_index(layout, outlayout):
    """Return the walk-order index of the NumpyArray in *layout* matching *outlayout*.

    Nodes match on identity ``fieldloc`` and identity width.  The index counts
    NumpyArray nodes only, in ``ak._util.recursive_walk`` order, which is the
    order ``find_dataptrs`` uses and hence the order of
    ``DifferentiableArray.tracers``.

    Raises:
        ValueError: if no NumpyArray node in *layout* matches.
    """
    target_fieldloc, target_width = _identity_key(outlayout)
    # recursive_walk re-passes ``args`` to every visit, so counters must live
    # in a mutable container rather than in rebound local variables.
    state = {"seen": 0, "match": -1}

    def visit(node, depth, state):
        if not isinstance(node, ak.layout.NumpyArray):
            return
        if (
            state["match"] == -1
            and node.identities.fieldloc == target_fieldloc
            and np.asarray(node.identities).shape[1] == target_width
        ):
            state["match"] = state["seen"]
        state["seen"] += 1

    ak._util.recursive_walk(layout, visit, args=(state,))
    if state["match"] == -1:
        raise ValueError("Couldn't find the node in new slice")
    return state["match"]


def _collect_preslice_identities(layout):
    """Return ``[(identity key, identities array), ...]`` for every NumpyArray in *layout*.

    Raises:
        ValueError: for union layouts, which are not supported.
        NotImplementedError: for any other unsupported layout type.
    """
    if isinstance(layout, ak.layout.NumpyArray):
        return [(_identity_key(layout), np.asarray(layout.identities))]
    if isinstance(layout, ak._util.listtypes):
        return _collect_preslice_identities(layout.content)
    if isinstance(layout, ak._util.indexedtypes):
        return _collect_preslice_identities(layout.project())
    if isinstance(layout, ak._util.uniontypes):
        raise ValueError("Can't differntiate an UnionArray type {0}".format(layout))
    if isinstance(layout, ak._util.recordtypes):
        found = []
        for content in layout.contents:
            found.extend(_collect_preslice_identities(content))
        return found
    if isinstance(layout, (ak._util.indexedoptiontypes,
                           ak.layout.BitMaskedArray,
                           ak.layout.ByteMaskedArray,
                           ak.layout.UnmaskedArray)):
        return _collect_preslice_identities(layout.content)
    raise NotImplementedError


def _intersection_indices(preslice_identities, postslice_identities):
    """Return the position of each post-slice identity row among the pre-slice rows.

    Every identity row is encoded as a single integer in a mixed-radix system
    whose digit bases are the per-column maxima (+1) of the pre-slice rows.
    ``np.searchsorted`` then locates the post-slice codes.  This assumes the
    pre-slice codes are sorted ascending and that each post-slice row occurs
    among them; neither assumption is checked.
    """
    bases = np.max(preslice_identities, axis=0) + 1
    multiplier = np.append(np.cumprod(bases[::-1])[-2::-1], 1)
    haystack = np.sum(preslice_identities * multiplier, axis=1)
    needle = np.sum(postslice_identities * multiplier, axis=1)
    return np.searchsorted(haystack, needle)


def _matching_preslice_identities(key, preslice_identities):
    """Return the pre-slice identities array registered under *key*.

    Raises:
        ValueError: if no entry of *preslice_identities* has that key.
    """
    for entry_key, identities in preslice_identities:
        if entry_key == key:
            return identities
    raise ValueError("Couldn't find postslice identities in preslice identities")


class DifferentiableArray(ak.Array):
    """An ``ak.Array`` stand-in that carries JAX tracers while JAX is tracing.

    Structure lives in ``aux_data.layout``, an Awkward layout or ``None`` for a
    scalar result.  Numeric content lives in ``tracers``: one JAX array per
    ``ak.layout.NumpyArray`` node, in ``find_dataptrs`` order.  Assigning to
    ``layout`` or to items raises ``ValueError``.
    """

    def __init__(self, aux_data, tracers):
        self.aux_data = aux_data
        self.tracers = tracers
        if self.aux_data.layout is not None:
            # Map each NumpyArray's ptr to the tracer that stands in for its data.
            self.data_ptrs = find_dataptrs(self.aux_data.layout)
            assert len(self.tracers) == len(self.data_ptrs)
            self.map_ptrs_to_tracers = dict(zip(self.data_ptrs, self.tracers))
        else:
            self.data_ptrs = None
            self.map_ptrs_to_tracers = None

    @property
    def layout(self):
        return self.aux_data.layout

    @layout.setter
    def layout(self, layout):
        raise ValueError(
            "this operation cannot be performed in a JAX-compiled or JAX-differentiated function"
        )

    def _tracer_for(self, nparray):
        """Return this array's tracer for NumpyArray node *nparray*.

        The node is looked up by ``ptr`` first.  A slice that reuses its parent's
        buffer presumably keeps the same ptr.  Otherwise the node is located in
        ``self.layout`` by identity.
        """
        if nparray.ptr in self.map_ptrs_to_tracers:
            return self.map_ptrs_to_tracers[nparray.ptr]
        return self.tracers[_find_nparray_node_index(self.layout, nparray)]

    def _fetch_children_tracers(self, layout, preslice_identities):
        """Collect, in walk order, the tracer slices backing each NumpyArray in *layout*.

        *layout* is a slice of ``self.layout``.  For each NumpyArray, its
        identities are matched against *preslice_identities* to select elements
        of the corresponding tracer with ``jax.numpy.take``.  A fresh list is
        built per call; a shared mutable default would duplicate tracers across
        record fields.
        """
        if isinstance(layout, ak.layout.NumpyArray):
            tracer = self._tracer_for(layout)
            indices = _intersection_indices(
                _matching_preslice_identities(_identity_key(layout), preslice_identities),
                np.asarray(layout.identities),
            )
            return [jax.numpy.take(tracer, indices)]
        if isinstance(layout, ak._util.listtypes):
            return self._fetch_children_tracers(layout.content, preslice_identities)
        if isinstance(layout, ak._util.uniontypes):
            raise ValueError("Can't differntiate an UnionArray type {0}".format(layout))
        if isinstance(layout, ak._util.recordtypes):
            children = []
            for content in layout.contents:
                children.extend(self._fetch_children_tracers(content, preslice_identities))
            return children
        if isinstance(layout, (ak._util.indexedtypes,
                               ak._util.indexedoptiontypes,
                               ak.layout.BitMaskedArray,
                               ak.layout.ByteMaskedArray,
                               ak.layout.UnmaskedArray)):
            return self._fetch_children_tracers(layout.content, preslice_identities)
        raise NotImplementedError(
            "fetch_children_tracer not completely implemented yet for {0}".format(layout)
        )

    def _getitem_scalar(self, where):
        """Handle a selection *where* that reduces the array to a single value."""

        def pick(array, selector):
            if isinstance(selector, (Integral, str)):
                if not isinstance(array.layout, ak.layout.NumpyArray):
                    # Previously fell through and returned None silently.
                    raise ValueError("Can't slice the array with {0}".format(selector))
                # ``array`` may be an intermediate DifferentiableArray produced
                # below, so resolve the tracer through *its* maps.
                return array._tracer_for(array.layout)[selector]
            if isinstance(selector, tuple):
                # Apply all but the last index, then pick with the last one.
                return pick(array[selector[:-1]], selector[-1])
            raise ValueError("Can't slice the array with {0}".format(selector))

        return DifferentiableArray(AuxData(None), [pick(self, where)])

    def __getitem__(self, where):
        """Slice the traced array, keeping the tracers aligned with the sliced layout.

        Raises:
            TypeError: if this array is a scalar (``layout is None``).
        """
        if self.layout is None:
            raise TypeError("Cannot slice a scalar")
        out = self.layout[where]
        if not isinstance(out, ak.layout.Content):
            return self._getitem_scalar(where)
        children = self._fetch_children_tracers(
            out, _collect_preslice_identities(self.aux_data.layout)
        )
        # Copy the slice and give it fresh identities so the result can itself
        # be sliced through this same machinery.
        out = out.deep_copy()
        out.setidentities()
        return DifferentiableArray(AuxData(out), children)

    def __setitem__(self, where, what):
        raise ValueError(
            "this operation cannot be performed in a JAX-compiled or JAX-differentiated function"
        )

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        """Apply a NumPy ufunc tracer-by-tracer using its ``jax.numpy`` counterpart.

        ``ak.Array`` arithmetic (``__add__``, ...) maps to NumPy ufuncs.  Here
        the ufunc is swapped for the same-named ``jax.numpy`` function.  It is
        then applied separately to the i-th tracer of every DifferentiableArray
        input; other inputs are passed through unchanged.  The result keeps this
        array's ``aux_data``.
        """
        # Optional sanity check: every traced input must share this structure.
        for x in inputs:
            if isinstance(x, DifferentiableArray):
                if self.layout is not None:
                    assert x.aux_data == self.aux_data
                else:
                    assert x.aux_data.layout == self.aux_data.layout
                assert len(x.tracers) == len(self.tracers)

        for name, np_ufunc in np.core.umath.__dict__.items():
            if ufunc is np_ufunc:
                ufunc = getattr(jax.numpy, name)
                break

        nexttracers = []
        for i in range(len(self.tracers)):
            nextinputs = [
                x.tracers[i] if isinstance(x, DifferentiableArray) else x
                for x in inputs
            ]
            nexttracers.append(getattr(ufunc, method)(*nextinputs, **kwargs))
        # Return a new DifferentiableArray so the result stays wrapped.
        return DifferentiableArray(self.aux_data, nexttracers)
def special_flatten(array):
    """Flatten an Awkward array into ``(children, aux_data)`` for JAX.

    A ``DifferentiableArray`` already holds its tracers and ``AuxData``, so
    they are returned as they are.  For a plain ``ak.Array``:

    * every ``NumpyArray`` buffer becomes a JAX array child;
    * identities are attached to the layout (in place), which the slicing code
      in ``DifferentiableArray.__getitem__`` relies on;
    * the layout itself is kept as the static aux data.
    """
    if not isinstance(array, DifferentiableArray):
        buffers = []

        def gather(node, depth, sink):
            if isinstance(node, ak.layout.NumpyArray):
                sink.append(node)

        ak._util.recursive_walk(array.layout, gather, args=(buffers,))
        array.layout.setidentities()
        static = AuxData(array.layout)
        return [jax.numpy.asarray(buf) for buf in buffers], static
    return array.tracers, array.aux_data
def special_unflatten(aux_data, children):
    """Rebuild a value from the ``(aux_data, children)`` made by ``special_flatten``.

    * If any child is a JAX tracer, the result is a ``DifferentiableArray``
      wrapping the children.
    * If every child is ``None``, the result is ``None``.
    * A scalar (``aux_data.layout is None``) unwraps its single child to a
      Python scalar.
    * Otherwise each ``NumpyArray`` node of the stored layout is replaced by
      the next child, and an ``ak.Array`` is returned.
    """
    if any(isinstance(child, jax.core.Tracer) for child in children):
        return DifferentiableArray(aux_data, children)
    if all(child is None for child in children):
        return None
    if aux_data.layout is None:
        assert len(children) == 1
        return np.ndarray.item(np.asarray(children[0]))

    # One shared iterator so the k-th NumpyArray visited receives children[k].
    # A ``num=0`` default parameter restarts at 0 on every call and would give
    # every node children[0].
    # NOTE(review): assumes recursively_apply visits NumpyArray nodes in the
    # same order as recursive_walk in special_flatten — confirm.
    remaining = iter(children)

    def replace_buffer(layout):
        if isinstance(layout, ak.layout.NumpyArray):
            buffer = next(remaining)  # bind now, in visit order
            return lambda: ak.layout.NumpyArray(buffer)
        return None

    rebuilt = ak._util.recursively_apply(aux_data.layout, replace_buffer, pass_depth=False)
    return ak.Array(rebuilt)
# Let JAX flatten/unflatten both plain Awkward arrays and their traced wrapper.
for _pytree_cls in (ak.Array, DifferentiableArray):
    jax.tree_util.register_pytree_node(_pytree_cls, special_flatten, special_unflatten)
del _pytree_cls
###############################################################################
# TESTING
###############################################################################
#### ak.layout.NumpyArray ####
# Primal input and tangent: ten float64 values 0.0 .. 9.0.
test_numpyarray = ak.Array(np.arange(10.0))
test_numpyarray_tangent = ak.Array(np.arange(10.0))
def func_numpyarray_1(x):
    """Return the square of element 4."""
    picked = x[4]
    return picked ** 2


def func_numpyarray_2(x):
    """Return the elementwise sum of squares of windows ``x[2:5]`` and ``x[1:4]``."""
    upper = x[2:5] ** 2
    lower = x[1:4] ** 2
    return upper + lower


def func_numpyarray_3(x):
    """Return ``x`` reversed along its first axis."""
    flipped = x[::-1]
    return flipped
#### ak.layout.ListOffsetArray ####
# Jagged lists with an empty middle entry; the tangent is 1.0 only at the
# last element.
test_listoffsetarray = ak.Array([[1.0, 2.0, 3.0], [], [4.0, 5.0]])
test_listoffsetarray_tangent = ak.Array([[0.0, 0.0, 0.0], [], [0.0, 1.0]])
def func_listoffsetarray_1(x):
    """Return the third list times 2."""
    third = x[2]
    return third * 2


def func_listoffsetarray_2(x):
    """Return the elementwise square, written as ``x * x``."""
    squared = x * x
    return squared


def func_listoffsetarray_3(x):
    """Return ``x[0, 0] * x[2, 1]`` using tuple indices."""
    first = x[0, 0]
    second = x[2, 1]
    return first * second


def func_listoffsetarray_4(x):
    """Return the lists in reverse order, squared."""
    flipped = x[::-1]
    return flipped ** 2


def func_listoffsetarray_5(x):
    """Return twice every list except the last."""
    head = x[:-1]
    return 2 * head


def func_listoffsetarray_6(x):
    """Return ``x[0][0] * x[2][1]`` using chained single indices."""
    first = x[0][0]
    second = x[2][1]
    return first * second
#### ak.layout.RecordArray ####
# Jagged list of records with a float field "x" and a variable-length field
# "y"; the tangent has the same structure.
test_recordarray = ak.Array(
    [
        [
            {"x": 1.1, "y": [1.0]},
            {"x": 2.2, "y": [1.0, 2.2]},
        ],
        [],
        [
            {"x": 3.3, "y": [1.0, 2.0, 3.0]},
        ],
    ]
)
test_recordarray_tangent = ak.Array(
    [
        [
            {"x": 0.0, "y": [1.0]},
            {"x": 2.0, "y": [1.5, 0.0]},
        ],
        [],
        [
            {"x": 1.5, "y": [2.0, 0.5, 1.0]},
        ],
    ]
)
def func_recordarray_1(array):
    """Return ``2 * y[2][0][1] + 10`` on field ``y`` (chained indices)."""
    y = array.y
    return 2 * y[2][0][1] + 10


def func_recordarray_2(array):
    """Return ``2 * y[0][0][0] ** 2`` on field ``y`` (chained indices)."""
    y = array.y
    return 2 * y[0][0][0] ** 2


def func_recordarray_3(array):
    """Return ``2 * y[2][0] + 10`` on field ``y`` (chained indices)."""
    y = array.y
    return 2 * y[2][0] + 10


def func_recordarray_4(array):
    """Return ``2 * y[0][0] ** 2`` on field ``y`` (chained indices)."""
    y = array.y
    return 2 * y[0][0] ** 2


def func_recordarray_5(array):
    """Return ``2 * y[2] + 10`` on field ``y``."""
    y = array.y
    return 2 * y[2] + 10


def func_recordarray_6(array):
    """Return ``2 * y[0] ** 2`` on field ``y``."""
    y = array.y
    return 2 * y[0] ** 2


def func_recordarray_7(array):
    """Return twice the whole field ``y``."""
    y = array.y
    return 2 * y


def func_recordarray_8(array):
    """Return ``2 * y ** 2`` over the whole field ``y``."""
    y = array.y
    return 2 * y ** 2


def func_recordarray_9(array):
    """Return ``2 * y[2, 0, 1] + 10`` on field ``y`` (tuple index)."""
    y = array.y
    return 2 * y[2, 0, 1] + 10


def func_recordarray_10(array):
    """Return ``2 * y[0, 0, 0] ** 2`` on field ``y`` (tuple index)."""
    y = array.y
    return 2 * y[0, 0, 0] ** 2


def func_recordarray_11(array):
    """Return ``2 * y[2, 0] + 10`` on field ``y`` (tuple index)."""
    y = array.y
    return 2 * y[2, 0] + 10


def func_recordarray_12(array):
    """Return ``2 * y[0, 0] ** 2`` on field ``y`` (tuple index)."""
    y = array.y
    return 2 * y[0, 0] ** 2
# Forward-mode check on func_numpyarray_3 (x[::-1]): jax.jvp returns the
# primal output and the pushed-forward tangent.
primal_inputs = (test_numpyarray,)
tangent_inputs = (test_numpyarray_tangent,)
value_jvp, jvp_grad = jax.jvp(func_numpyarray_3, primal_inputs, tangent_inputs)
del primal_inputs, tangent_inputs
jit_value = jax.jit(func_numpyarray_3)(test_numpyarray)

# Disabled experiments:
# value_vjp, vjp_func = jax.vjp(func_recordarray_12, test_recordarray)
# print(type(value_vjp))
# print(vjp_func(test_recordarray))
# value, grad = jax.value_and_grad(func_numpyarray_2)(test_numpyarray)
# print("VJP value and grad is {0} and {1}".format(value_vjp, vjp_func(test_numpyarray)))
# print("Value and grad are {0} and {1}".format(value, grad))

jvp_message = "Value and Grad are {0} and {1}".format(value_jvp, jvp_grad)
print(jvp_message)
jit_message = "JIT value is {0}".format(jit_value)
print(jit_message)
del jvp_message, jit_message