forked from keras-team/keras
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_model_pickling.py
147 lines (107 loc) · 3.9 KB
/
test_model_pickling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import pytest
import sys
import numpy as np
from numpy.testing import assert_allclose
import keras
from keras import layers
from keras import optimizers
from keras import losses
from keras import metrics
if sys.version_info[0] == 3:
import pickle
else:
import cPickle as pickle
def test_sequential_model_pickling():
    """Round-trip a compiled, trained Sequential model through pickle.

    The unpickled copy must predict the same values as the source model,
    both right after loading and after each model has received one more
    identical training batch.
    """
    source_model = keras.Sequential([
        layers.Dense(2, input_shape=(3,)),
        layers.RepeatVector(3),
        layers.TimeDistributed(layers.Dense(3)),
    ])
    source_model.compile(
        loss=losses.MSE,
        optimizer=optimizers.RMSprop(lr=0.0001),
        metrics=[metrics.categorical_accuracy],
        sample_weight_mode='temporal')

    features = np.random.random((1, 3))
    targets = np.random.random((1, 3, 3))
    source_model.train_on_batch(features, targets)
    expected = source_model.predict(features)

    restored_model = pickle.loads(pickle.dumps(source_model))
    restored = restored_model.predict(features)
    assert_allclose(expected, restored, atol=1e-05)

    # Apply the same fresh batch to both models; their predictions must
    # still agree afterwards.
    features = np.random.random((1, 3))
    targets = np.random.random((1, 3, 3))
    for net in (source_model, restored_model):
        net.train_on_batch(features, targets)
    expected = source_model.predict(features)
    restored = restored_model.predict(features)
    assert_allclose(expected, restored, atol=1e-05)
def test_sequential_model_pickling_custom_objects():
    """Pickle round trip for a model compiled with a custom optimizer and loss.

    The model is unpickled inside a ``CustomObjectScope`` that maps the
    names ``'CustomSGD'`` and ``'custom_mse'`` to the locally defined
    objects; predictions before and after the round trip must match.
    """

    class CustomSGD(optimizers.SGD):
        """Trivial SGD subclass used as the custom optimizer."""
        pass

    def custom_mse(*args, **kwargs):
        """Forward every argument to ``losses.mse``."""
        return losses.mse(*args, **kwargs)

    net = keras.Sequential([
        layers.Dense(2, input_shape=(3,)),
        layers.Dense(3),
    ])
    net.compile(loss=custom_mse, optimizer=CustomSGD(), metrics=['acc'])

    features = np.random.random((1, 3))
    targets = np.random.random((1, 3))
    net.train_on_batch(features, targets)
    expected = net.predict(features)

    payload = pickle.dumps(net)
    custom_objects = {'CustomSGD': CustomSGD, 'custom_mse': custom_mse}
    with keras.utils.CustomObjectScope(custom_objects):
        net = pickle.loads(payload)

    restored = net.predict(features)
    assert_allclose(expected, restored, atol=1e-05)
def test_functional_model_pickling():
    """Pickle round trip for a compiled, trained functional-API model.

    Predictions on the same input must match before and after the model
    is dumped and reloaded with pickle.
    """
    net_input = keras.Input(shape=(3,))
    hidden = layers.Dense(2)(net_input)
    net_output = layers.Dense(3)(hidden)
    net = keras.Model(net_input, net_output)
    net.compile(
        loss=losses.MSE,
        optimizer=optimizers.Adam(),
        metrics=[metrics.categorical_accuracy])

    features = np.random.random((1, 3))
    targets = np.random.random((1, 3))
    net.train_on_batch(features, targets)
    expected = net.predict(features)

    payload = pickle.dumps(net)
    net = pickle.loads(payload)

    restored = net.predict(features)
    assert_allclose(expected, restored, atol=1e-05)
def test_pickling_multiple_metrics_outputs():
    """Pickle round trip for a two-output model with per-output metrics.

    The model is compiled with dict-valued ``loss`` and ``metrics`` keyed
    by output layer name. Predictions on a fixed input must match before
    and after the pickle round trip.
    """
    inputs = keras.Input(shape=(5,))
    hidden = layers.Dense(5)(inputs)
    output1 = layers.Dense(1, name='output1')(hidden)
    output2 = layers.Dense(1, name='output2')(hidden)
    model = keras.Model(inputs=inputs, outputs=[output1, output2])

    # Named `output_metrics`/`output_losses` (not `metrics`/`loss`) so the
    # module-level `metrics` import is not shadowed inside this test.
    output_names = ('output1', 'output2')
    output_metrics = {name: ['mse', 'binary_accuracy']
                      for name in output_names}
    output_losses = {name: 'mse' for name in output_names}
    model.compile(loss=output_losses, optimizer='sgd',
                  metrics=output_metrics)

    # Record the reference prediction before the round trip.
    x = np.array([[1, 1, 1, 1, 1]])
    out = model.predict(x)

    model = pickle.loads(pickle.dumps(model))

    out2 = model.predict(x)
    assert_allclose(out, out2, atol=1e-05)
def test_pickling_without_compilation():
    """Pickle and unpickle a Sequential model that was never compiled.

    There is no assertion on the result; the test passes as long as the
    dump/load round trip does not raise.
    """
    uncompiled = keras.Sequential([
        layers.Dense(2, input_shape=(3,)),
        layers.Dense(3),
    ])
    payload = pickle.dumps(uncompiled)
    uncompiled = pickle.loads(payload)
def test_pickling_right_after_compilation():
    """Pickle a model that is compiled but has not been trained.

    ``_make_train_function()`` is called explicitly before pickling. There
    is no assertion on the result; the test passes as long as the
    dump/load round trip does not raise.
    """
    compiled = keras.Sequential([
        layers.Dense(2, input_shape=(3,)),
        layers.Dense(3),
    ])
    compiled.compile(loss='mse', optimizer='sgd', metrics=['acc'])
    compiled._make_train_function()

    payload = pickle.dumps(compiled)
    compiled = pickle.loads(payload)
# Allow running this test module directly as a script: hand this file to
# pytest.
if __name__ == '__main__':
    pytest.main([__file__])