forked from tensorflow/examples
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add CustomModel to be the base class for QA task.
PiperOrigin-RevId: 305842035
- Loading branch information
1 parent
2440842
commit 28ac71e
Showing
6 changed files
with
387 additions
and
104 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
70 changes: 70 additions & 0 deletions
70
tensorflow_examples/lite/model_maker/core/task/classification_model_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the 'License'); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an 'AS IS' BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import tensorflow.compat.v2 as tf | ||
from tensorflow_examples.lite.model_maker.core import model_export_format as mef | ||
from tensorflow_examples.lite.model_maker.core import test_util | ||
from tensorflow_examples.lite.model_maker.core.task import classification_model | ||
|
||
|
||
class MockClassificationModel(classification_model.ClassificationModel): | ||
|
||
def train(self, train_data, validation_data=None, **kwargs): | ||
pass | ||
|
||
def export(self, **kwargs): | ||
pass | ||
|
||
def evaluate(self, data, **kwargs): | ||
pass | ||
|
||
|
||
class ClassificationModelTest(tf.test.TestCase): | ||
|
||
def test_predict_top_k(self): | ||
input_shape = [24, 24, 3] | ||
num_classes = 2 | ||
model = MockClassificationModel( | ||
model_export_format=mef.ModelExportFormat.TFLITE, | ||
model_spec=None, | ||
index_to_label=['pos', 'neg'], | ||
num_classes=2, | ||
train_whole_model=False, | ||
shuffle=False) | ||
model.model = test_util.build_model(input_shape, num_classes) | ||
data = test_util.get_dataloader(2, input_shape, num_classes) | ||
|
||
topk_results = model.predict_top_k(data, k=2, batch_size=1) | ||
for topk_result in topk_results: | ||
top1_result, top2_result = topk_result[0], topk_result[1] | ||
top1_label, top1_prob = top1_result[0], top1_result[1] | ||
top2_label, top2_prob = top2_result[0], top2_result[1] | ||
|
||
self.assertIn(top1_label, model.index_to_label) | ||
self.assertIn(top2_label, model.index_to_label) | ||
self.assertNotEqual(top1_label, top2_label) | ||
|
||
self.assertLessEqual(top1_prob, 1) | ||
self.assertGreaterEqual(top1_prob, top2_prob) | ||
self.assertGreaterEqual(top2_prob, 0) | ||
|
||
self.assertEqual(top1_prob + top2_prob, 1.0) | ||
|
||
|
||
if __name__ == '__main__': | ||
tf.test.main() |
155 changes: 155 additions & 0 deletions
155
tensorflow_examples/lite/model_maker/core/task/custom_model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Base custom model that is already retained by data.""" | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import abc | ||
import os | ||
import tempfile | ||
|
||
import tensorflow.compat.v2 as tf | ||
from tensorflow_examples.lite.model_maker.core import compat | ||
from tensorflow_examples.lite.model_maker.core import model_export_format as mef | ||
|
||
DEFAULT_QUANTIZATION_STEPS = 2000 | ||
|
||
|
||
def get_representative_dataset_gen(dataset, num_steps): | ||
|
||
def representative_dataset_gen(): | ||
"""Generates representative dataset for quantized.""" | ||
for image, _ in dataset.take(num_steps): | ||
yield [image] | ||
|
||
return representative_dataset_gen | ||
|
||
|
||
class CustomModel(abc.ABC): | ||
""""The abstract base class that represents a Tensorflow classification model.""" | ||
|
||
def __init__(self, model_export_format, model_spec, shuffle): | ||
"""Initialize a instance with data, deploy mode and other related parameters. | ||
Args: | ||
model_export_format: Model export format such as saved_model / tflite. | ||
model_spec: Specification for the model. | ||
shuffle: Whether the data should be shuffled. | ||
""" | ||
if model_export_format != mef.ModelExportFormat.TFLITE: | ||
raise ValueError('Model export format %s is not supported currently.' % | ||
str(model_export_format)) | ||
|
||
self.model_export_format = model_export_format | ||
self.model_spec = model_spec | ||
self.shuffle = shuffle | ||
self.model = None | ||
|
||
def preprocess(self, sample_data, label): | ||
"""Preprocess the data.""" | ||
# TODO(yuqili): remove this method once preprocess for image classifier is | ||
# also moved to DataLoader part. | ||
return sample_data, label | ||
|
||
@abc.abstractmethod | ||
def train(self, train_data, validation_data=None, **kwargs): | ||
return | ||
|
||
@abc.abstractmethod | ||
def export(self, **kwargs): | ||
return | ||
|
||
def summary(self): | ||
self.model.summary() | ||
|
||
@abc.abstractmethod | ||
def evaluate(self, data, **kwargs): | ||
return | ||
|
||
def _gen_dataset(self, | ||
data, | ||
batch_size=32, | ||
is_training=True, | ||
input_pipeline_context=None): | ||
"""Generates training / validation dataset.""" | ||
# The dataset is always sharded by number of hosts. | ||
# num_input_pipelines is the number of hosts rather than number of cores. | ||
ds = data.dataset | ||
if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1: | ||
ds = ds.shard(input_pipeline_context.num_input_pipelines, | ||
input_pipeline_context.input_pipeline_id) | ||
|
||
ds = ds.map( | ||
self.preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE) | ||
|
||
if is_training: | ||
if self.shuffle: | ||
ds = ds.shuffle(buffer_size=min(data.size, 100)) | ||
ds = ds.repeat() | ||
|
||
ds = ds.batch(batch_size) | ||
ds = ds.prefetch(tf.data.experimental.AUTOTUNE) | ||
return ds | ||
|
||
def _export_tflite(self, | ||
tflite_filename, | ||
quantized=False, | ||
quantization_steps=None, | ||
representative_data=None): | ||
"""Converts the retrained model to tflite format and saves it. | ||
Args: | ||
tflite_filename: File name to save tflite model. | ||
quantized: boolean, if True, save quantized model. | ||
quantization_steps: Number of post-training quantization calibration steps | ||
to run. Used only if `quantized` is True. | ||
representative_data: Representative data used for post-training | ||
quantization. Used only if `quantized` is True. | ||
""" | ||
temp_dir = None | ||
if compat.get_tf_behavior() == 1: | ||
temp_dir = tempfile.TemporaryDirectory() | ||
save_path = os.path.join(temp_dir.name, 'saved_model') | ||
self.model.save(save_path, include_optimizer=False, save_format='tf') | ||
converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(save_path) | ||
else: | ||
converter = tf.lite.TFLiteConverter.from_keras_model(self.model) | ||
|
||
if quantized: | ||
if quantization_steps is None: | ||
quantization_steps = DEFAULT_QUANTIZATION_STEPS | ||
if representative_data is None: | ||
raise ValueError( | ||
'representative_data couldn\'t be None if model is quantized.') | ||
ds = self._gen_dataset( | ||
representative_data, batch_size=1, is_training=False) | ||
converter.representative_dataset = tf.lite.RepresentativeDataset( | ||
get_representative_dataset_gen(ds, quantization_steps)) | ||
|
||
converter.optimizations = [tf.lite.Optimize.DEFAULT] | ||
converter.inference_input_type = tf.uint8 | ||
converter.inference_output_type = tf.uint8 | ||
converter.target_spec.supported_ops = [ | ||
tf.lite.OpsSet.TFLITE_BUILTINS_INT8 | ||
] | ||
tflite_model = converter.convert() | ||
if temp_dir: | ||
temp_dir.cleanup() | ||
|
||
with tf.io.gfile.GFile(tflite_filename, 'wb') as f: | ||
f.write(tflite_model) | ||
|
||
tf.compat.v1.logging.info('Export to tflite model in %s.', tflite_filename) |
Oops, something went wrong.