|
19 | 19 | from __future__ import division
|
20 | 20 | from __future__ import print_function
|
21 | 21 |
|
| 22 | +import tempfile |
| 23 | + |
22 | 24 | import numpy as np
|
23 | 25 | import tensorflow as tf
|
24 | 26 |
|
@@ -49,6 +51,54 @@ def input_fn():
|
49 | 51 | self.assertLess(loss2, 0.01)
|
50 | 52 | self.assertTrue('centered_bias_weight' in classifier.get_variable_names())
|
51 | 53 |
|
| 54 | + def testTrainSaveLoad(self): |
| 55 | + """Tests that insures you can save and reload a trained model.""" |
| 56 | + |
| 57 | + def input_fn(): |
| 58 | + return { |
| 59 | + 'age': tf.constant([1]), |
| 60 | + 'language': tf.SparseTensor(values=['english'], |
| 61 | + indices=[[0, 0]], |
| 62 | + shape=[1, 1]) |
| 63 | + }, tf.constant([[1]]) |
| 64 | + |
| 65 | + language = tf.contrib.layers.sparse_column_with_hash_bucket('language', 100) |
| 66 | + age = tf.contrib.layers.real_valued_column('age') |
| 67 | + |
| 68 | + model_dir = tempfile.mkdtemp() |
| 69 | + classifier = tf.contrib.learn.LinearClassifier( |
| 70 | + model_dir=model_dir, |
| 71 | + feature_columns=[age, language]) |
| 72 | + classifier.fit(input_fn=input_fn, steps=100) |
| 73 | + out1 = classifier.predict(input_fn=input_fn) |
| 74 | + del classifier |
| 75 | + |
| 76 | + classifier2 = tf.contrib.learn.LinearClassifier( |
| 77 | + model_dir=model_dir, |
| 78 | + feature_columns=[age, language]) |
| 79 | + out2 = classifier2.predict(input_fn=input_fn) |
| 80 | + self.assertEqual(out1, out2) |
| 81 | + |
| 82 | + def testExport(self): |
| 83 | + """Tests that export model for servo works.""" |
| 84 | + |
| 85 | + def input_fn(): |
| 86 | + return { |
| 87 | + 'age': tf.constant([1]), |
| 88 | + 'language': tf.SparseTensor(values=['english'], |
| 89 | + indices=[[0, 0]], |
| 90 | + shape=[1, 1]) |
| 91 | + }, tf.constant([[1]]) |
| 92 | + |
| 93 | + language = tf.contrib.layers.sparse_column_with_hash_bucket('language', 100) |
| 94 | + age = tf.contrib.layers.real_valued_column('age') |
| 95 | + |
| 96 | + export_dir = tempfile.mkdtemp() |
| 97 | + classifier = tf.contrib.learn.LinearClassifier( |
| 98 | + feature_columns=[age, language]) |
| 99 | + classifier.fit(input_fn=input_fn, steps=100) |
| 100 | + tf.contrib.learn.utils.export.export_estimator(classifier, export_dir) |
| 101 | + |
52 | 102 | def testDisableCenteredBias(self):
|
53 | 103 | """Tests that we can disable centered bias."""
|
54 | 104 |
|
|
0 commit comments