Skip to content

Commit 668192f

Browse files
Unit test demonstrating prediction from loaded model and export.
Change: 128431529
1 parent 3e7aab5 commit 668192f

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed

tensorflow/contrib/learn/python/learn/estimators/linear_test.py

+50
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from __future__ import division
2020
from __future__ import print_function
2121

22+
import tempfile
23+
2224
import numpy as np
2325
import tensorflow as tf
2426

@@ -49,6 +51,54 @@ def input_fn():
4951
self.assertLess(loss2, 0.01)
5052
self.assertTrue('centered_bias_weight' in classifier.get_variable_names())
5153

54+
def testTrainSaveLoad(self):
55+
"""Tests that insures you can save and reload a trained model."""
56+
57+
def input_fn():
58+
return {
59+
'age': tf.constant([1]),
60+
'language': tf.SparseTensor(values=['english'],
61+
indices=[[0, 0]],
62+
shape=[1, 1])
63+
}, tf.constant([[1]])
64+
65+
language = tf.contrib.layers.sparse_column_with_hash_bucket('language', 100)
66+
age = tf.contrib.layers.real_valued_column('age')
67+
68+
model_dir = tempfile.mkdtemp()
69+
classifier = tf.contrib.learn.LinearClassifier(
70+
model_dir=model_dir,
71+
feature_columns=[age, language])
72+
classifier.fit(input_fn=input_fn, steps=100)
73+
out1 = classifier.predict(input_fn=input_fn)
74+
del classifier
75+
76+
classifier2 = tf.contrib.learn.LinearClassifier(
77+
model_dir=model_dir,
78+
feature_columns=[age, language])
79+
out2 = classifier2.predict(input_fn=input_fn)
80+
self.assertEqual(out1, out2)
81+
82+
def testExport(self):
83+
"""Tests that export model for servo works."""
84+
85+
def input_fn():
86+
return {
87+
'age': tf.constant([1]),
88+
'language': tf.SparseTensor(values=['english'],
89+
indices=[[0, 0]],
90+
shape=[1, 1])
91+
}, tf.constant([[1]])
92+
93+
language = tf.contrib.layers.sparse_column_with_hash_bucket('language', 100)
94+
age = tf.contrib.layers.real_valued_column('age')
95+
96+
export_dir = tempfile.mkdtemp()
97+
classifier = tf.contrib.learn.LinearClassifier(
98+
feature_columns=[age, language])
99+
classifier.fit(input_fn=input_fn, steps=100)
100+
tf.contrib.learn.utils.export.export_estimator(classifier, export_dir)
101+
52102
def testDisableCenteredBias(self):
53103
"""Tests that we can disable centered bias."""
54104

0 commit comments

Comments
 (0)