forked from keras-team/keras
-
Notifications
You must be signed in to change notification settings - Fork 0
/
jax_integration_test.py
105 lines (88 loc) · 3 KB
/
jax_integration_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from keras_core import backend
from keras_core import initializers
from keras_core.backend import KerasTensor
from keras_core.layers.layer import Layer
from keras_core.operations import numpy as knp
from keras_core.operations.function import Function
class MiniDense(Layer):
    """Minimal dense layer computing ``matmul(inputs, w) + b``."""

    def __init__(self, units, name=None):
        """Store the output width; variables are created later in `build`."""
        super().__init__(name=name)
        self.units = units

    def build(self, input_shape):
        """Create the kernel `w` and bias `b` from the last input dimension."""
        fan_in = input_shape[-1]
        # Kernel first, bias second: keeps the variable creation order stable.
        self.w = backend.Variable(
            initializers.GlorotUniform()((fan_in, self.units))
        )
        self.b = backend.Variable(initializers.Zeros()((self.units,)))

    def call(self, inputs):
        """Return the affine projection of `inputs`."""
        projected = knp.matmul(inputs, self.w)
        return projected + self.b
class MiniDropout(Layer):
    """Minimal dropout layer backed by a fixed-seed `SeedGenerator`."""

    def __init__(self, rate, name=None):
        """Store the dropout rate and create the seed generator (seed 1337)."""
        super().__init__(name=name)
        self.rate = rate
        self.seed_generator = backend.random.SeedGenerator(1337)

    def call(self, inputs):
        """Apply `backend.random.dropout` to `inputs` at `self.rate`."""
        seed_source = self.seed_generator
        return backend.random.dropout(inputs, self.rate, seed=seed_source)
class MiniBatchNorm(Layer):
    """Minimal batch normalization over the last (feature) axis.

    With `training=True` the inputs are normalized with the statistics of
    the current batch (reduced over axis 0) and the moving mean/variance are
    updated as an exponential moving average. Otherwise the stored moving
    statistics are used. The result is scaled by `gamma` and shifted by
    `beta`.
    """

    def __init__(self, name=None):
        """Set the numerical-stability epsilon and moving-average momentum."""
        super().__init__(name=name)
        self.epsilon = 1e-5
        self.momentum = 0.99

    def build(self, input_shape):
        """Create per-feature statistics and affine parameters."""
        shape = (input_shape[-1],)
        self.mean = backend.Variable(
            initializers.Zeros()(shape), trainable=False
        )
        # The moving variance must start non-negative: a zero-centered random
        # initializer can yield negative entries, which would make the
        # square root in the inference branch NaN.
        self.variance = backend.Variable(
            initializers.Ones()(shape), trainable=False
        )
        self.beta = backend.Variable(initializers.Zeros()(shape))
        self.gamma = backend.Variable(initializers.Ones()(shape))

    def call(self, inputs, training=False):
        """Normalize `inputs`, then apply the `gamma` scale and `beta` shift.

        Args:
            inputs: Tensor whose last axis matches the built feature size.
            training: When true, use batch statistics and update the moving
                averages; otherwise use the stored moving statistics.

        Returns:
            The normalized, scaled and shifted tensor.
        """
        if training:
            mean = knp.mean(inputs, axis=(0,))  # TODO: extend to rank 3+
            variance = knp.var(inputs, axis=(0,))
            # Divide by the standard deviation, not the variance.
            outputs = (inputs - mean) / knp.sqrt(variance + self.epsilon)
            self.variance.assign(
                self.variance * self.momentum
                + variance * (1.0 - self.momentum)
            )
            self.mean.assign(
                self.mean * self.momentum + mean * (1.0 - self.momentum)
            )
        else:
            outputs = (inputs - self.mean) / knp.sqrt(
                self.variance + self.epsilon
            )
        outputs *= self.gamma
        outputs += self.beta
        return outputs
def _run_stack(dense, inputs):
    """Call `dense`, then a fresh MiniBatchNorm (training) and MiniDropout."""
    hidden = dense(inputs)
    hidden = MiniBatchNorm()(hidden, training=True)
    return MiniDropout(0.5)(hidden)


# Eager call: concrete tensors flow through the layers.
layer = MiniDense(5)
x = knp.zeros((3, 4))
y = _run_stack(layer, x)
assert y.shape == (3, 5)
assert layer.built
print(layer.variables)
assert len(layer.variables) == 2

# Symbolic call: a KerasTensor placeholder flows through the layers.
x = KerasTensor((3, 4))
layer = MiniDense(5)
y = _run_stack(layer, x)
assert y.shape == (3, 5)
assert layer.built
assert len(layer.variables) == 2

# Symbolic graph building: wrap the symbolic graph in a Function and run it.
x = KerasTensor((3, 4))
y = _run_stack(MiniDense(5), x)
fn = Function(inputs=x, outputs=y)
y_val = fn(knp.ones((3, 4)))
assert y_val.shape == (3, 5)
print(y_val)