forked from google/neural-tangents
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict_test.py
959 lines (816 loc) · 37.1 KB
/
predict_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for `utils/predict.py`."""
import math
from absl.testing import absltest
from jax import test_util as jtu
from jax.api import grad
from jax.api import jit
from jax.api import vmap
from jax.config import config
from jax.experimental import optimizers
from jax.flatten_util import ravel_pytree
from jax.lib import xla_bridge
import jax.numpy as np
import jax.random as random
from neural_tangents import predict
from neural_tangents import stax
from neural_tangents.utils import batch
from neural_tangents.utils import empirical
from neural_tangents.utils import test_utils
from neural_tangents.utils import utils
config.parse_flags_with_absl()

MATRIX_SHAPES = [(3, 3), (4, 4)]
OUTPUT_LOGITS = [1, 2]
GETS = ('ntk', 'nngp', ('ntk', 'nngp'))

# Comparison tolerances; loosened when 64-bit mode is disabled.
RTOL = 0.01
ATOL = 0.01
if not config.read('jax_enable_x64'):
  RTOL = 0.02
  ATOL = 0.02

# Network readout types understood by `_build_network`.
FLAT = 'FLAT'
POOLING = 'POOLING'

# TODO(schsam): Add a pooling test when multiple inputs are supported in
# Conv + Pooling.
TRAIN_SHAPES = [(4, 8), (8, 8), (6, 4, 4, 3)]
TEST_SHAPES = [(6, 8), (16, 8), (2, 4, 4, 3)]
# One network type per (train, test) shape pair. These three lists are zipped
# together when generating test cases, so keep their lengths in sync instead
# of relying on `zip` to silently drop extra entries.
NETWORK = [FLAT] * len(TRAIN_SHAPES)
assert len(TRAIN_SHAPES) == len(TEST_SHAPES) == len(NETWORK)

CONVOLUTION_CHANNELS = 256

test_utils.update_test_tolerance()
def _build_network(input_shape, network, out_logits):
  """Builds the `stax.serial` network used by the tests.

  Args:
    input_shape: per-example input shape (without the batch dimension).
      Rank 1 selects a fully-connected network; rank 3 selects a convolutional
      network.
    network: readout type, `FLAT` or `POOLING` (only `FLAT` is allowed for
      rank-1 inputs).
    out_logits: width of the final `Dense` layer.

  Returns:
    The `(init_fn, apply_fn, kernel_fn)` triple returned by `stax.serial`.

  Raises:
    ValueError: for an unknown `network` with image inputs, or for inputs that
      are neither rank 1 nor rank 3.
  """
  rank = len(input_shape)

  if rank == 1:
    assert network == FLAT
    return stax.serial(
        stax.Dense(4096, W_std=1.2, b_std=0.05),
        stax.Erf(),
        stax.Dense(out_logits, W_std=1.2, b_std=0.05))

  if rank != 3:
    raise ValueError('Expected flat or image test input.')

  # Both image networks share the conv trunk and dense head; only the layer
  # that collapses the spatial dimensions differs.
  if network == POOLING:
    readout = stax.GlobalAvgPool()
  elif network == FLAT:
    readout = stax.Flatten()
  else:
    raise ValueError('Unexpected network type found: {}'.format(network))

  return stax.serial(
      stax.Conv(CONVOLUTION_CHANNELS, (3, 3), W_std=2.0, b_std=0.05),
      readout,
      stax.Dense(out_logits, W_std=2.0, b_std=0.05))
def _empirical_kernel(key, input_shape, network, out_logits):
  """Builds a network and its empirical kernel at freshly sampled parameters.

  Args:
    key: PRNG key used to initialize the parameters.
    input_shape: per-example input shape (without the batch dimension).
    network: readout type forwarded to `_build_network`.
    out_logits: number of network outputs.

  Returns:
    `(params, apply_fn, kernel_fn)`. `kernel_fn(x1, x2, get)` is jitted with
    `get` static and evaluates the empirical kernel at `params` without
    tracing out any output axes (`trace_axes=()`).
  """
  init_fn, apply_fn, _ = _build_network(input_shape, network, out_logits)
  _, params = init_fn(key, (-1,) + input_shape)
  empirical_fn = empirical.empirical_kernel_fn(apply_fn, trace_axes=())

  def kernel_fn(x1, x2, get):
    return empirical_fn(x1, x2, get, params)

  return params, apply_fn, jit(kernel_fn, static_argnums=(2,))
def _theoretical_kernel(key, input_shape, network, out_logits):
  """Builds a network and returns its analytic `stax` kernel function.

  Args:
    key: PRNG key used to initialize the parameters.
    input_shape: per-example input shape (without the batch dimension).
    network: readout type forwarded to `_build_network`.
    out_logits: number of network outputs.

  Returns:
    `(params, apply_fn, kernel_fn)`, where `kernel_fn` is the third element
    returned by `_build_network`, jitted with its `get` argument static.
    `params` are sampled so callers can also run the finite network.
  """
  init_fn, apply_fn, analytic_kernel_fn = _build_network(
      input_shape, network, out_logits)
  params = init_fn(key, (-1,) + input_shape)[1]
  return params, apply_fn, jit(analytic_kernel_fn, static_argnums=(2,))
# Kernel builders keyed by the name that appears in generated test-case names.
# Each builder returns `(params, apply_fn, kernel_fn)`.
KERNELS = dict(
    empirical=_empirical_kernel,
    theoretical=_theoretical_kernel,
)
class PredictTest(jtu.JaxTestCase):
  """Tests for the predictors in `neural_tangents.predict`."""

  def _assertAllClose(self, x, y, rtol):
    """Asserts two pytrees agree up to a relative L1 discrepancy `rtol`.

    Both pytrees are flattened with `ravel_pytree`, and
    `2 * |x - y|_1 / (|x|_1 + |y|_1 + 1e-4)` must be below `rtol`.
    """
    x = ravel_pytree(x)[0]
    y = ravel_pytree(y)[0]
    diff = 2 * np.sum(np.abs(x - y)) / (np.sum(np.abs(x))
                                        + np.sum(np.abs(y)) + 1e-4)
    self.assertLess(diff, rtol)

  def _test_zero_time(self, predictor, fx_train_0, fx_test_0, g_td, momentum):
    """Checks that predictions at `t=0` reproduce the initial outputs.

    Covers train+test prediction and train-only prediction. When `momentum`
    is not `None`, it also covers the `predict.ODEState` interface for both
    a full state and a train-only state.
    """
    fx_train_t0, fx_test_t0 = predictor(0.0, fx_train_0, fx_test_0, g_td)
    self.assertAllClose(fx_train_0, fx_train_t0)
    self.assertAllClose(fx_test_0, fx_test_t0)

    fx_train_only_t0 = predictor(0.0, fx_train_0, None, g_td)
    self.assertAllClose(fx_train_0, fx_train_only_t0)

    if momentum is not None:
      # Test state-based prediction
      state_0 = predict.ODEState(fx_train_0, fx_test_0)
      state_t0 = predictor(0.0, state_0, None, g_td)
      self.assertAllClose(state_0.fx_train, state_t0.fx_train)
      self.assertAllClose(state_0.fx_test, state_t0.fx_test)

      # Feed the train-only state itself (not `state_0`) so this code path is
      # actually exercised.
      state_train_only_0 = predict.ODEState(fx_train_0)
      state_train_only_t0 = predictor(0.0, state_train_only_0, None, g_td)
      self.assertAllClose(state_train_only_0.fx_train,
                          state_train_only_t0.fx_train)

  def _test_inf_time(self, predictor, fx_train_0, fx_test_0, g_td, y_train):
    """Checks infinite-time predictions.

    Train predictions at `t=np.inf` must match `y_train`, and `t=np.inf` must
    agree with `t=None` with and without test points.
    """
    pr_inf = predictor(np.inf, fx_train_0)
    self.assertAllClose(pr_inf, y_train, check_dtypes=False)
    self.assertAllClose(pr_inf, predictor(None, fx_train_0))
    self.assertAllClose(predictor(np.inf, fx_train_0, fx_test_0, g_td),
                        predictor(None, fx_train_0, fx_test_0, g_td))

  def _test_multi_step(self, predictor, fx_train_0, fx_test_0, g_td, momentum):
    """Checks that an array of times matches per-time predictions.

    When `momentum` is not `None`, it also checks that resuming from an
    intermediate `predict.ODEState` reproduces the later prediction.
    """
    ts = np.arange(6).reshape((2, 1, 3))

    fx_train_single, fx_test_single = predictor(ts, fx_train_0, fx_test_0, g_td)

    fx_train_concat, fx_test_concat = [], []
    for t in ts.ravel():
      fx_train_concat_t, fx_test_concat_t = predictor(t, fx_train_0, fx_test_0,
                                                      g_td)
      fx_train_concat += [fx_train_concat_t]
      fx_test_concat += [fx_test_concat_t]

    # Stack per-time outputs and restore the leading `ts.shape` time axes.
    fx_train_concat = np.stack(fx_train_concat).reshape(
        ts.shape + fx_train_single.shape[ts.ndim:])
    fx_test_concat = np.stack(fx_test_concat).reshape(
        ts.shape + fx_test_single.shape[ts.ndim:])

    self.assertAllClose(fx_train_concat, fx_train_single)
    self.assertAllClose(fx_test_concat, fx_test_single)

    if momentum is not None:
      state_0 = predict.ODEState(fx_train_0, fx_test_0)

      t_1 = (0, 0, 2)
      state_1 = predictor(ts[t_1], state_0, None, g_td)
      self.assertAllClose(fx_train_single[t_1], state_1.fx_train)
      self.assertAllClose(fx_test_single[t_1], state_1.fx_test)

      # Advance from `state_1` by the remaining time to reach the last `ts`.
      t_max = (-1,) * ts.ndim
      state_max = predictor(ts[t_max] - ts[t_1], state_1, None, g_td)
      self.assertAllClose(fx_train_single[t_max], state_max.fx_train)
      self.assertAllClose(fx_test_single[t_max], state_max.fx_test)

  @classmethod
  def _get_inputs(cls, out_logits, test_shape, train_shape):
    """Draws random inputs and binary targets from a fixed PRNG seed.

    Returns:
      `(key, x_test, x_train, y_train)`, where `y_train` has shape
      `(train_shape[0], out_logits)` and float32 entries in {0, 1}.
    """
    key = random.PRNGKey(0)
    key, split = random.split(key)
    x_train = random.normal(split, train_shape)

    key, split = random.split(key)
    y_train = np.array(
        random.bernoulli(split, shape=(train_shape[0], out_logits)), np.float32)

    key, split = random.split(key)
    x_test = random.normal(split, test_shape)
    return key, x_test, x_train, y_train

  @jtu.parameterized.named_parameters(
      jtu.cases_from_list({
          'testcase_name':
              '_train={}_test={}_network={}_logits={}_{}_'
              'momentum={}_lr={}_loss={}_t={}'.format(train, test, network,
                                                     out_logits, name,
                                                     momentum, learning_rate,
                                                     loss, t),
          'train_shape': train,
          'test_shape': test,
          'network': network,
          'out_logits': out_logits,
          'fn_and_kernel': fn,
          'momentum': momentum,
          'loss': loss,
          'learning_rate': learning_rate,
          't': t,
      } for train, test, network in zip(TRAIN_SHAPES, TEST_SHAPES, NETWORK)
                          for out_logits in OUTPUT_LOGITS
                          for name, fn in KERNELS.items()
                          for momentum in [None, 0.9]
                          for learning_rate in [0.0002]
                          for t in [5]
                          for loss in ['mse_analytic', 'mse'])
  )
  def testNTKGDPrediction(self, train_shape, test_shape, network, out_logits,
                          fn_and_kernel, momentum, learning_rate, t, loss):
    """Compares kernel-based GD predictions with explicitly training `f`.

    Also runs the zero-time, multi-step and (for the analytic MSE predictor)
    infinite-time consistency checks on the same predictor.
    """
    key, x_test, x_train, y_train = self._get_inputs(out_logits, test_shape,
                                                     train_shape)

    params, f, ntk = fn_and_kernel(key, train_shape[1:], network, out_logits)

    g_dd = ntk(x_train, None, 'ntk')
    g_td = ntk(x_test, x_train, 'ntk')

    # Regress to an MSE loss.
    loss_fn = lambda y, y_hat: 0.5 * np.mean((y - y_hat)**2)
    grad_loss = jit(grad(lambda params, x: loss_fn(f(params, x), y_train)))

    # `_empirical_kernel` keeps both output axes (`trace_axes=()`), producing a
    # 4-D `g_dd`; otherwise the logit axis is treated as traced out.
    trace_axes = () if g_dd.ndim == 4 else (-1,)
    if loss == 'mse_analytic':
      if momentum is not None:
        raise absltest.SkipTest(momentum)
      predictor = predict.gradient_descent_mse(g_dd, y_train,
                                               learning_rate=learning_rate,
                                               trace_axes=trace_axes)
    elif loss == 'mse':
      predictor = predict.gradient_descent(loss_fn, g_dd, y_train,
                                           learning_rate=learning_rate,
                                           momentum=momentum,
                                           trace_axes=trace_axes)
    else:
      raise NotImplementedError(loss)

    predictor = jit(predictor)

    fx_train_0 = f(params, x_train)
    fx_test_0 = f(params, x_test)

    self._test_zero_time(predictor, fx_train_0, fx_test_0, g_td, momentum)
    self._test_multi_step(predictor, fx_train_0, fx_test_0, g_td, momentum)
    if loss == 'mse_analytic':
      self._test_inf_time(predictor, fx_train_0, fx_test_0, g_td, y_train)

    if momentum is None:
      opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
    else:
      opt_init, opt_update, get_params = optimizers.momentum(learning_rate,
                                                             momentum)

    opt_state = opt_init(params)
    for i in range(t):
      params = get_params(opt_state)
      opt_state = opt_update(i, grad_loss(params, x_train), opt_state)

    params = get_params(opt_state)

    fx_train_nn, fx_test_nn = f(params, x_train), f(params, x_test)
    fx_train_t, fx_test_t = predictor(t, fx_train_0, fx_test_0, g_td)

    self.assertAllClose(fx_train_nn, fx_train_t, rtol=RTOL, atol=ATOL)
    self.assertAllClose(fx_test_nn, fx_test_t, rtol=RTOL, atol=ATOL)

  @classmethod
  def _cov_empirical(cls, x):
    """Sample covariance of already-centered samples.

    Per the einsum subscripts, `x` is indexed `[sample, time, point, logit]`.
    The result is `[time, point, point]`, averaged over samples and logits.
    No mean is subtracted here; callers pass centered `x`.
    """
    return np.einsum('itjk,itlk->tjl', x, x, optimize=True) / (x.shape[0] *
                                                               x.shape[-1])

  @jtu.parameterized.named_parameters(
      jtu.cases_from_list({
          'testcase_name':
              '_train={}_test={}_network={}_logits={}'.format(
                  train, test, network, out_logits),
          'train_shape': train,
          'test_shape': test,
          'network': network,
          'out_logits': out_logits,
      } for train, test, network in zip(
          TRAIN_SHAPES[:1], TEST_SHAPES[:1], NETWORK[:1])
                          for out_logits in [1]))
  def testNTKMeanCovPrediction(self, train_shape, test_shape, network,
                               out_logits):
    """Checks ensemble mean/covariance against a Monte Carlo estimate.

    The estimate comes from 4096 empirical-NTK `gradient_descent_mse`
    predictors with independent initializations.
    """
    key, x_test, x_train, y_train = self._get_inputs(out_logits, test_shape,
                                                     train_shape)
    init_fn, f, kernel_fn = stax.serial(
        stax.Dense(512, W_std=1.2, b_std=0.05), stax.Erf(),
        stax.Dense(out_logits, W_std=1.2, b_std=0.05))

    reg = 1e-6
    predictor = predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
                                                      y_train, diag_reg=reg)
    ts = np.array([1., 5., 10.])

    fx_test_inf, cov_test_inf = predictor(ts, x_test, 'ntk', True)
    self.assertEqual(cov_test_inf.shape[1], x_test.shape[0])
    self.assertGreater(np.min(np.linalg.eigh(cov_test_inf)[0]), -1e-8)

    fx_train_inf, cov_train_inf = predictor(ts, None, 'ntk', True)
    self.assertEqual(cov_train_inf.shape[1], x_train.shape[0])
    self.assertGreater(np.min(np.linalg.eigh(cov_train_inf)[0]), -1e-8)

    _kernel_fn = empirical.empirical_kernel_fn(f)
    kernel_fn = jit(lambda x1, x2, params: _kernel_fn(x1, x2, 'ntk', params))

    def predict_empirical(key):
      # Finite-width prediction for a single random initialization.
      _, params = init_fn(key, train_shape)
      g_dd = kernel_fn(x_train, None, params)
      g_td = kernel_fn(x_test, x_train, params)
      predict_fn = predict.gradient_descent_mse(g_dd, y_train, diag_reg=reg)

      fx_train_0 = f(params, x_train)
      fx_test_0 = f(params, x_test)
      return predict_fn(ts, fx_train_0, fx_test_0, g_td)

    def predict_mc(count, key):
      # Monte Carlo mean and covariance over `count` initializations.
      key = random.split(key, count)
      fx_train, fx_test = vmap(predict_empirical)(key)
      fx_train_mean = np.mean(fx_train, axis=0)
      fx_test_mean = np.mean(fx_test, axis=0)

      fx_train_centered = fx_train - fx_train_mean
      fx_test_centered = fx_test - fx_test_mean

      cov_train = PredictTest._cov_empirical(fx_train_centered)
      cov_test = PredictTest._cov_empirical(fx_test_centered)

      return fx_train_mean, fx_test_mean, cov_train, cov_test

    fx_train_mc, fx_test_mc, cov_train_mc, cov_test_mc = predict_mc(4096, key)
    rtol = 0.05
    self._assertAllClose(fx_train_mc, fx_train_inf, rtol)
    self._assertAllClose(cov_train_mc, cov_train_inf, rtol)
    self._assertAllClose(cov_test_mc, cov_test_inf, rtol)
    self._assertAllClose(fx_test_mc, fx_test_inf, rtol)

  @jtu.parameterized.named_parameters(
      jtu.cases_from_list({
          'testcase_name':
              '_train={}_test={}_network={}_logits={}'.format(
                  train, test, network, out_logits),
          'train_shape': train,
          'test_shape': test,
          'network': network,
          'out_logits': out_logits,
      } for train, test, network in zip(
          TRAIN_SHAPES[:-1], TEST_SHAPES[:-1], NETWORK[:-1])
                          for out_logits in OUTPUT_LOGITS))
  def testGradientDescentMseEnsembleGet(self, train_shape, test_shape, network,
                                        out_logits):
    """Checks the output structure of the ensemble predictor for each `get`."""
    _, x_test, x_train, y_train = self._get_inputs(out_logits, test_shape,
                                                   train_shape)
    _, _, kernel_fn = _build_network(train_shape[1:], network, out_logits)

    predictor = predict.gradient_descent_mse_ensemble(kernel_fn,
                                                      x_train,
                                                      y_train,
                                                      diag_reg=0.)
    for x in [None, 'x_test']:
      with self.subTest(x=x):
        x = x if x is None else x_test
        out = predictor(None, x, 'ntk', compute_cov=True)
        self.assertIsInstance(out, predict.Gaussian)

        out = predictor(1., x, 'nngp', compute_cov=True)
        self.assertIsInstance(out, predict.Gaussian)

        out = predictor(np.array([0., 1.]), x, ('ntk',), compute_cov=True)
        self.assertLen(out, 1)
        self.assertIsInstance(out[0], predict.Gaussian)

        out = predictor(2., x, ('ntk', 'nngp'), compute_cov=True)
        self.assertLen(out, 2)
        self.assertIsInstance(out[0], predict.Gaussian)
        self.assertIsInstance(out[1], predict.Gaussian)

        # Reversing the requested order must reverse the outputs.
        out2 = predictor(2., x, ('nngp', 'ntk'), compute_cov=True)
        self.assertAllClose(out[0], out2[1])
        self.assertAllClose(out[1], out2[0])

  @jtu.parameterized.named_parameters(
      jtu.cases_from_list({
          'testcase_name':
              '_train={}_test={}_network={}_logits={}_get={}'.format(
                  train, test, network, out_logits, get),
          'train_shape': train,
          'test_shape': test,
          'network': network,
          'out_logits': out_logits,
          'get': get,
      } for train, test, network in zip(
          TRAIN_SHAPES[:-1], TEST_SHAPES[:-1], NETWORK[:-1])
                          for out_logits in OUTPUT_LOGITS for get in GETS))
  def testInfiniteTimeAgreement(self, train_shape, test_shape, network,
                                out_logits, get):
    """Checks that `t=np.inf` and `t=None` agree.

    For the train set, it also checks that `x_test=None` and `x_test=x_train`
    give the same result.
    """
    _, x_test, x_train, y_train = self._get_inputs(out_logits, test_shape,
                                                   train_shape)
    _, _, kernel_fn = _build_network(train_shape[1:], network, out_logits)

    reg = 0.
    predictor = predict.gradient_descent_mse_ensemble(kernel_fn,
                                                      x_train,
                                                      y_train,
                                                      diag_reg=reg)

    for x in (None, 'x_test'):
      with self.subTest(x=x):
        x = x if x is None else x_test
        fin = predictor(t=np.inf, x_test=x, get=get, compute_cov=True)
        inf = predictor(t=None, x_test=x, get=get, compute_cov=True)
        self.assertAllClose(inf, fin)
        if x is None:
          fin_x = predictor(t=np.inf, x_test=x_train, get=get, compute_cov=True)
          inf_x = predictor(t=None, x_test=x_train, get=get, compute_cov=True)
          self.assertAllClose(inf, inf_x)
          self.assertAllClose(inf_x, fin_x)

  @jtu.parameterized.named_parameters(
      jtu.cases_from_list({
          'testcase_name':
              '_train={}_test={}_network={}_logits={}'.format(
                  train, test, network, out_logits),
          'train_shape': train,
          'test_shape': test,
          'network': network,
          'out_logits': out_logits,
      } for train, test, network in zip(
          TRAIN_SHAPES[:-1], TEST_SHAPES[:-1], NETWORK[:-1])
                          for out_logits in OUTPUT_LOGITS))
  def testZeroTimeAgreement(self, train_shape, test_shape, network, out_logits):
    """Test that the NTK and NNGP agree at t=0.

    At `t=0` both predictions must equal the prior: a zero mean and the NNGP
    covariance of the evaluated points.
    """
    _, x_test, x_train, y_train = self._get_inputs(out_logits, test_shape,
                                                   train_shape)
    _, _, ker_fun = _build_network(train_shape[1:], network, out_logits)

    reg = 1e-7
    predictor = predict.gradient_descent_mse_ensemble(
        ker_fun,
        x_train,
        y_train,
        diag_reg=reg)

    for x in (None, 'x_test'):
      with self.subTest(x=x):
        x = x if x is None else x_test
        zero = predictor(t=0.0, x_test=x, get=('NTK', 'NNGP'), compute_cov=True)
        if x is None:
          k = ker_fun(x_train, None, get='nngp')
          ref = (np.zeros_like(y_train, k.dtype), k)
        else:
          ref = (np.zeros((test_shape[0], out_logits)),
                 ker_fun(x_test, None, get='nngp'))

        self.assertAllClose((ref,) * 2, zero, check_dtypes=False)
        if x is None:
          zero_x = predictor(t=0.0, x_test=x_train, get=('NTK', 'NNGP'),
                             compute_cov=True)
          self.assertAllClose((ref,) * 2, zero_x)

  @jtu.parameterized.named_parameters(
      jtu.cases_from_list({
          'testcase_name':
              '_train={}_test={}_network={}_logits={}'.format(
                  train, test, network, out_logits),
          'train_shape': train,
          'test_shape': test,
          'network': network,
          'out_logits': out_logits,
      } for train, test, network in zip(
          TRAIN_SHAPES[:-1], TEST_SHAPES[:-1], NETWORK[:-1])
                          for out_logits in OUTPUT_LOGITS))
  def testNTK_NTKNNGPAgreement(self, train_shape, test_shape, network,
                               out_logits):
    """Cross-checks the NTK and NNGP code paths.

    Each kernel is fed through the other's equations, and predictions at an
    array of times are checked against per-time predictions.
    """
    _, x_test, x_train, y_train = self._get_inputs(out_logits, test_shape,
                                                   train_shape)
    _, _, ker_fun = _build_network(train_shape[1:], network, out_logits)

    reg = 1e-7
    predictor = predict.gradient_descent_mse_ensemble(ker_fun,
                                                      x_train,
                                                      y_train,
                                                      diag_reg=reg)

    ts = np.logspace(-2, 8, 10).reshape((5, 2))

    for t in (None, 'ts'):
      for x in (None, 'x_test'):
        with self.subTest(t=t, x=x):
          x = x if x is None else x_test
          t = t if t is None else ts

          ntk = predictor(t=t, get='ntk', x_test=x)

          # Test time broadcasting
          if t is not None:
            ntk_ind = np.array([predictor(t=t_i, get='ntk', x_test=x)
                                for t_i in t.ravel()]).reshape(
                                    t.shape + ntk.shape[2:])
            self.assertAllClose(ntk_ind, ntk)

          # Create a hacked kernel function that always returns the ntk kernel
          def always_ntk(x1, x2, get=('nngp', 'ntk')):
            out = ker_fun(x1, x2, get=('nngp', 'ntk'))
            if get == 'nngp' or get == 'ntk':
              return out.ntk
            else:
              return out._replace(nngp=out.ntk)

          predictor_ntk = predict.gradient_descent_mse_ensemble(always_ntk,
                                                                x_train,
                                                                y_train,
                                                                diag_reg=reg)

          ntk_nngp = predictor_ntk(t=t, get='nngp', x_test=x)

          # Test if you use nngp equations with ntk, you get the same mean
          self.assertAllClose(ntk, ntk_nngp)

          # Next test that if you go through the NTK code path, but with only
          # the NNGP kernel, we recreate the NNGP dynamics.
          # Create a hacked kernel function that always returns the nngp kernel
          def always_nngp(x1, x2, get=('nngp', 'ntk')):
            out = ker_fun(x1, x2, get=('nngp', 'ntk'))
            if get == 'nngp' or get == 'ntk':
              return out.nngp
            else:
              return out._replace(ntk=out.nngp)

          predictor_nngp = predict.gradient_descent_mse_ensemble(always_nngp,
                                                                 x_train,
                                                                 y_train,
                                                                 diag_reg=reg)

          nngp_cov = predictor(t=t,
                               get='nngp',
                               x_test=x,
                               compute_cov=True).covariance

          # test time broadcasting for covariance
          nngp_ntk_cov = predictor_nngp(t=t,
                                        get='ntk',
                                        x_test=x,
                                        compute_cov=True).covariance
          if t is not None:
            nngp_ntk_cov_ind = np.array(
                [predictor_nngp(t=t_i,
                                get='ntk',
                                x_test=x,
                                compute_cov=True).covariance
                 for t_i in t.ravel()]).reshape(t.shape + nngp_cov.shape[2:])
            self.assertAllClose(nngp_ntk_cov_ind, nngp_ntk_cov)

          # Test if you use ntk equations with nngp, you get the same cov
          # Although, due to accumulation of numerical errors, only roughly.
          self.assertAllClose(nngp_cov, nngp_ntk_cov)

  @jtu.parameterized.named_parameters(
      jtu.cases_from_list({
          'testcase_name':
              '_train={}_test={}_network={}_logits={}'.format(
                  train, test, network, out_logits),
          'train_shape': train,
          'test_shape': test,
          'network': network,
          'out_logits': out_logits,
      } for train, test, network in zip(
          TRAIN_SHAPES[:-1], TEST_SHAPES[:-1], NETWORK[:-1])
                          for out_logits in OUTPUT_LOGITS))
  def testPredCovPosDef(self, train_shape, test_shape, network, out_logits):
    """Checks predicted covariances are symmetric and numerically PSD."""
    _, x_test, x_train, y_train = self._get_inputs(out_logits, test_shape,
                                                   train_shape)
    _, _, ker_fun = _build_network(train_shape[1:], network, out_logits)

    ts = np.logspace(-3, 3, 10)
    predict_fn_mse_ens = predict.gradient_descent_mse_ensemble(
        ker_fun, x_train, y_train)

    for get in ('nngp', 'ntk'):
      for x in (None, 'x_test'):
        for t in (None, 'ts'):
          with self.subTest(get=get, x=x, t=t):
            cov = predict_fn_mse_ens(t=t if t is None else ts,
                                     get=get,
                                     x_test=x if x is None else x_test,
                                     compute_cov=True).covariance

            self.assertAllClose(cov, np.moveaxis(cov, -1, -2))
            self.assertGreater(np.min(np.linalg.eigh(cov)[0]), -1e-4)

  @jtu.parameterized.named_parameters(
      jtu.cases_from_list({
          'testcase_name':
              '_train={}_test={}_network={}_logits={}'.format(
                  train, test, network, out_logits),
          'train_shape': train,
          'test_shape': test,
          'network': network,
          'out_logits': out_logits,
      } for train, test, network in zip(
          TRAIN_SHAPES[:1], TEST_SHAPES[:1], NETWORK[:1])
                          for out_logits in [1]))
  def testTrainedEnsemblePredCov(self, train_shape, test_shape, network,
                                 out_logits):
    """Compares the predicted NTK mean/cov with a trained network ensemble."""
    training_steps = 1000
    learning_rate = 0.1
    ensemble_size = 1024

    init_fn, apply_fn, kernel_fn = stax.serial(
        stax.Dense(128, W_std=1.2, b_std=0.05), stax.Erf(),
        stax.Dense(out_logits, W_std=1.2, b_std=0.05))

    opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
    opt_update = jit(opt_update)

    key, x_test, x_train, y_train = self._get_inputs(out_logits, test_shape,
                                                     train_shape)
    predict_fn_mse_ens = predict.gradient_descent_mse_ensemble(
        kernel_fn,
        x_train,
        y_train,
        learning_rate=learning_rate,
        diag_reg=0.)

    train = (x_train, y_train)
    ensemble_key = random.split(key, ensemble_size)

    loss = jit(lambda params, x, y: 0.5 * np.mean((apply_fn(params, x) - y)**2))
    grad_loss = jit(lambda state, x, y: grad(loss)(get_params(state), x, y))

    def train_network(key):
      # Trains one ensemble member with full-batch SGD.
      _, params = init_fn(key, (-1,) + train_shape[1:])
      opt_state = opt_init(params)
      for i in range(training_steps):
        opt_state = opt_update(i, grad_loss(opt_state, *train), opt_state)

      return get_params(opt_state)

    params = vmap(train_network)(ensemble_key)
    rtol = 0.08

    for x in [None, 'x_test']:
      with self.subTest(x=x):
        x = x if x is None else x_test
        x_fin = x_train if x is None else x_test
        ensemble_fx = vmap(apply_fn, (0, None))(params, x_fin)

        mean_emp = np.mean(ensemble_fx, axis=0)
        mean_subtracted = ensemble_fx - mean_emp
        # Covariance over ensemble members (axis 0), averaged over logits.
        cov_emp = np.einsum(
            'ijk,ilk->jl', mean_subtracted, mean_subtracted, optimize=True) / (
                mean_subtracted.shape[0] * mean_subtracted.shape[-1])

        ntk = predict_fn_mse_ens(training_steps, x, 'ntk', compute_cov=True)
        self._assertAllClose(mean_emp, ntk.mean, rtol)
        self._assertAllClose(cov_emp, ntk.covariance, rtol)

  def testGradientDescentMseEnsembleTrain(self):
    """Checks train-set predictions for `x_test=None` vs `x_test=x_train`."""
    key = random.PRNGKey(1)
    x = random.normal(key, (8, 4, 6, 3))
    _, _, kernel_fn = stax.serial(stax.Conv(1, (2, 2)),
                                  stax.Relu(),
                                  stax.Conv(1, (2, 1)))
    y = random.normal(key, (8, 2, 5, 1))
    predictor = predict.gradient_descent_mse_ensemble(kernel_fn, x, y)

    for t in [None, np.array([0., 1., 10.])]:
      with self.subTest(t=t):
        y_none = predictor(t, None, None, compute_cov=True)
        y_x = predictor(t, x, None, compute_cov=True)
        self._assertAllClose(y_none, y_x, 0.04)

  def testGpInference(self):
    """Checks `gp_inference` against the infinite-time ensemble predictor.

    Runs for both analytic and empirical kernels. Requesting a covariance on
    test points with `get` that excludes 'nngp' must raise `ValueError`.
    """
    reg = 1e-5
    key = random.PRNGKey(1)
    x_train = random.normal(key, (4, 2))
    init_fn, apply_fn, kernel_fn_analytic = stax.serial(
        stax.Dense(32, 2., 0.5),
        stax.Relu(),
        stax.Dense(10, 2., 0.5))
    y_train = random.normal(key, (4, 10))
    for kernel_fn_is_analytic in [True, False]:
      if kernel_fn_is_analytic:
        kernel_fn = kernel_fn_analytic
      else:
        _, params = init_fn(key, x_train.shape)
        kernel_fn_empirical = empirical.empirical_kernel_fn(apply_fn)

        def kernel_fn(x1, x2, get):
          return kernel_fn_empirical(x1, x2, get, params)

      for get in [None,
                  'nngp', 'ntk',
                  ('nngp',), ('ntk',),
                  ('nngp', 'ntk'), ('ntk', 'nngp')]:
        k_dd = kernel_fn(x_train, None, get)

        gp_inference = predict.gp_inference(k_dd, y_train, diag_reg=reg)
        gd_ensemble = predict.gradient_descent_mse_ensemble(kernel_fn,
                                                            x_train,
                                                            y_train,
                                                            diag_reg=reg)
        for x_test in [None, 'x_test']:
          x_test = None if x_test is None else random.normal(key, (8, 2))
          k_td = None if x_test is None else kernel_fn(x_test, x_train, get)

          for compute_cov in [True, False]:
            with self.subTest(kernel_fn_is_analytic=kernel_fn_is_analytic,
                              get=get,
                              x_test=x_test if x_test is None else 'x_test',
                              compute_cov=compute_cov):
              if compute_cov:
                nngp_tt = (True if x_test is None else
                           kernel_fn(x_test, None, 'nngp'))
              else:
                nngp_tt = None

              out_ens = gd_ensemble(None, x_test, get, compute_cov)
              out_ens_inf = gd_ensemble(np.inf, x_test, get, compute_cov)
              self._assertAllClose(out_ens_inf, out_ens, 0.08)

              if (get is not None and
                  'nngp' not in get and
                  compute_cov and
                  k_td is not None):
                with self.assertRaises(ValueError):
                  gp_inference(get=get, k_test_train=k_td,
                               nngp_test_test=nngp_tt)
              else:
                out_gp_inf = gp_inference(get=get, k_test_train=k_td,
                                          nngp_test_test=nngp_tt)
                self.assertAllClose(out_ens, out_gp_inf)

  def testPredictOnCPU(self):
    """Checks batched-kernel predictions and where the results are stored.

    `t=None` and `t=np.inf` must agree. Test-point outputs must be on CPU
    unless `store_on_device` is set on a non-CPU backend.
    """
    x_train = random.normal(random.PRNGKey(1), (4, 4, 4, 2))
    x_test = random.normal(random.PRNGKey(1), (8, 4, 4, 2))

    y_train = random.uniform(random.PRNGKey(1), (4, 2))

    _, _, kernel_fn = stax.serial(
        stax.Conv(1, (3, 3)), stax.Relu(), stax.Flatten(), stax.Dense(1))

    for store_on_device in [False, True]:
      for device_count in [0, 1]:
        for get in ['ntk', 'nngp', ('nngp', 'ntk'), ('ntk', 'nngp')]:
          for x in [None, 'x_test']:
            with self.subTest(
                store_on_device=store_on_device,
                device_count=device_count,
                get=get,
                x=x):
              kernel_fn_batched = batch.batch(kernel_fn, 2, device_count,
                                              store_on_device)
              predictor = predict.gradient_descent_mse_ensemble(
                  kernel_fn_batched, x_train, y_train)

              x = x if x is None else x_test
              predict_none = predictor(None, x, get, compute_cov=True)
              predict_inf = predictor(np.inf, x, get, compute_cov=True)
              self.assertAllClose(predict_none, predict_inf)

              if x is not None:
                on_cpu = (not store_on_device or
                          xla_bridge.get_backend().platform == 'cpu')
                self.assertEqual(on_cpu, utils.is_on_cpu(predict_inf))
                self.assertEqual(on_cpu, utils.is_on_cpu(predict_none))

  def testPredictND(self):
    """Checks predictor output shapes for N-D outputs and many `trace_axes`."""
    n_chan = 6
    key = random.PRNGKey(1)
    im_shape = (5, 4, 3)
    n_train = 2
    n_test = 2
    x_train = random.normal(key, (n_train,) + im_shape)
    y_train = random.uniform(key, (n_train, 3, 2, n_chan))
    init_fn, apply_fn, _ = stax.Conv(n_chan, (3, 2), (1, 2))
    _, params = init_fn(key, x_train.shape)
    fx_train_0 = apply_fn(params, x_train)

    for trace_axes in [(),
                       (-1,),
                       (-2,),
                       (-3,),
                       (0, 1),
                       (2, 3),
                       (2,),
                       (1, 3),
                       (0, -1),
                       (0, 0, -3),
                       (0, 1, 2, 3),
                       (0, 1, -1, 2)]:
      for ts in [None, np.arange(6).reshape((2, 3))]:
        for x in [None, 'x_test']:
          with self.subTest(trace_axes=trace_axes, ts=ts, x=x):
            t_shape = ts.shape if ts is not None else ()
            y_test_shape = t_shape + (n_test,) + y_train.shape[1:]
            y_train_shape = t_shape + y_train.shape
            x = x if x is None else random.normal(key, (n_test,) + im_shape)
            fx_test_0 = None if x is None else apply_fn(params, x)

            kernel_fn = empirical.empirical_kernel_fn(apply_fn,
                                                      trace_axes=trace_axes)

            # TODO(romann): investigate the SIGTERM error on CPU.
            # kernel_fn = jit(kernel_fn, static_argnums=(2,))
            ntk_train_train = kernel_fn(x_train, None, 'ntk', params)
            if x is not None:
              ntk_test_train = kernel_fn(x, x_train, 'ntk', params)

            # Mean squared error, consistent with the other MSE losses here
            # (the mean is taken of the squared residuals, not squared after).
            loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
            predict_fn_mse = predict.gradient_descent_mse(ntk_train_train,
                                                          y_train,
                                                          trace_axes=trace_axes)

            predict_fn_mse_ensemble = predict.gradient_descent_mse_ensemble(
                kernel_fn, x_train, y_train, trace_axes=trace_axes,
                params=params
            )

            if x is None:
              p_train_mse = predict_fn_mse(ts, fx_train_0)
            else:
              p_train_mse, p_test_mse = predict_fn_mse(
                  ts, fx_train_0, fx_test_0, ntk_test_train)
              self.assertEqual(y_test_shape, p_test_mse.shape)
            self.assertEqual(y_train_shape, p_train_mse.shape)

            p_nngp_mse_ens, p_ntk_mse_ens = predict_fn_mse_ensemble(
                ts, x, ('nngp', 'ntk'), compute_cov=True)
            ref_shape = y_train_shape if x is None else y_test_shape
            self.assertEqual(ref_shape, p_ntk_mse_ens.mean.shape)
            self.assertEqual(ref_shape, p_nngp_mse_ens.mean.shape)

            if ts is not None:
              predict_fn = predict.gradient_descent(
                  loss, ntk_train_train, y_train, trace_axes=trace_axes)

              if x is None:
                p_train = predict_fn(ts, fx_train_0)
              else:
                p_train, p_test = predict_fn(
                    ts, fx_train_0, fx_test_0, ntk_test_train)
                self.assertEqual(y_test_shape, p_test.shape)
              self.assertEqual(y_train_shape, p_train.shape)

  @jtu.parameterized.named_parameters(
      jtu.cases_from_list({
          'testcase_name':
              '_train={}_network={}_logits={}_{}'.format(
                  train, network, out_logits, name),
          'train_shape': train,
          'network': network,
          'out_logits': out_logits,
          'fn_and_kernel': fn,
      } for train, network in zip(TRAIN_SHAPES, NETWORK)
                          for out_logits in OUTPUT_LOGITS
                          for name, fn in KERNELS.items()))
  def testMaxLearningRate(self, train_shape, network, out_logits,
                          fn_and_kernel):
    """Checks the `predict.max_learning_rate` threshold.

    SGD at 0.5x the rate must reduce the loss at least 10x in 20 steps. At 3x
    the rate, the loss must grow at least 10x (unless the ratio is NaN).
    """
    key = random.PRNGKey(0)

    key, split = random.split(key)
    if len(train_shape) == 2:
      train_shape = (train_shape[0] * 5, train_shape[1] * 10)
    else:
      train_shape = (16, 8, 8, 3)
    x_train = random.normal(split, train_shape)

    key, split = random.split(key)
    y_train = np.array(
        random.bernoulli(split, shape=(train_shape[0], out_logits)), np.float32)

    # Regress to an MSE loss. `f` and `get_params` are bound inside the loop
    # below; the closures look them up at call time.
    loss = lambda params, x: 0.5 * np.mean((f(params, x) - y_train) ** 2)
    grad_loss = jit(grad(loss))

    def get_loss(opt_state):
      return loss(get_params(opt_state), x_train)

    steps = 20

    for lr_factor in [0.5, 3.]:
      params, f, ntk = fn_and_kernel(key, train_shape[1:], network, out_logits)
      g_dd = ntk(x_train, None, 'ntk')

      step_size = predict.max_learning_rate(
          g_dd, y_train_size=y_train.size) * lr_factor

      opt_init, opt_update, get_params = optimizers.sgd(step_size)
      opt_state = opt_init(params)

      init_loss = get_loss(opt_state)

      for i in range(steps):
        params = get_params(opt_state)
        opt_state = opt_update(i, grad_loss(params, x_train), opt_state)

      trained_loss = get_loss(opt_state)
      loss_ratio = trained_loss / (init_loss + 1e-12)
      if lr_factor == 3.:
        if not math.isnan(loss_ratio):
          self.assertGreater(loss_ratio, 10.)
      else:
        self.assertLess(loss_ratio, 0.1)
if __name__ == '__main__':
  # Run this module's tests with the absl test runner when executed directly.
  absltest.main()