demo_custom_tf_workflow.py
import numpy as np
import tensorflow as tf

from keras_core import Model
from keras_core import backend
from keras_core import initializers
from keras_core import layers
from keras_core import operations as ops
from keras_core import optimizers
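
# A minimal custom layer: its parameters are created directly as
# backend.Variable objects rather than via add_weight(), and the forward
# pass uses keras_core ops, so the layer itself is backend-agnostic.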
class MyDense(layers.Layer):
    def __init__(self, units, name=None):
        super().__init__(name=name)
        self.units = units

    def build(self, input_shape):
        input_dim = input_shape[-1]
        w_shape = (input_dim, self.units)
        w_value = initializers.GlorotUniform()(w_shape)
        self.w = backend.Variable(w_value, name="kernel")

        b_shape = (self.units,)
        b_value = initializers.Zeros()(b_shape)
        self.b = backend.Variable(b_value, name="bias")

    def call(self, inputs):
        return ops.matmul(inputs, self.w) + self.b
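
# A three-layer MLP built by composing MyDense layers in a Model subclass.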
class MyModel(Model):
    def __init__(self, hidden_dim, output_dim):
        super().__init__()
        self.dense1 = MyDense(hidden_dim)
        self.dense2 = MyDense(hidden_dim)
        self.dense3 = MyDense(output_dim)

    def call(self, x):
        x = self.dense1(x)
        x = self.dense2(x)
        return self.dense3(x)
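
# Synthetic data: a generator yielding 20 batches of random
# (inputs, targets) pairs with shapes (32, 128) and (32, 16).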
def Dataset():
    for _ in range(20):
        yield (
            np.random.random((32, 128)).astype("float32"),
            np.random.random((32, 16)).astype("float32"),
        )
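
# Sum-of-squared-errors loss, written with keras_core ops.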
def loss_fn(y_true, y_pred):
    return ops.sum((y_true - y_pred) ** 2)
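
# Build the model, a plain SGD optimizer, and the data generator.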
model = MyModel(hidden_dim=256, output_dim=16)
optimizer = optimizers.SGD(learning_rate=0.0001)
dataset = Dataset()
######### Custom TF workflow ###############
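
# A hand-written training step: the forward pass and loss are computed
# under a tf.GradientTape, and the whole step is XLA-compiled via
# jit_compile=True.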
@tf.function(jit_compile=True)
def train_step(data):
    x, y = data
    with tf.GradientTape() as tape:
        y_pred = model(x)
        loss = loss_fn(y, y_pred)
    # !! Glitch to be resolved !!
    # The tape records operations on the raw tf.Variable wrapped inside
    # each keras_core variable, so gradients are requested against
    # `v.value`; the optimizer still applies them to the wrappers.
    gradients = tape.gradient(
        loss, [v.value for v in model.trainable_variables]
    )
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss
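
# Train for one pass over the generator, printing the loss for each batch.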
for data in dataset:
    loss = train_step(data)
    print("Loss:", float(loss))