Implement distributed inception (tensorflow#44)
Implements a distributed trainer for Inception.
jmchen-g authored and mrry committed Apr 13, 2016
1 parent 9a1dfdf commit 84b58a6
Showing 6 changed files with 842 additions and 302 deletions.
702 changes: 417 additions & 285 deletions inception/README.md

Large diffs are not rendered by default.
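The README rewrite itself is not rendered above, but the trainer it documents is driven entirely by command-line flags. As a rough, illustrative sketch only (the real definitions are assumed to live in inception_distributed_train.py, which this view does not show), the flags read by the new binary could be declared along these lines:

import tensorflow as tf

# Illustrative declarations; names mirror the FLAGS accesses in
# imagenet_distributed_train.py below.
tf.app.flags.DEFINE_string('job_name', '', 'One of "ps" or "worker".')
tf.app.flags.DEFINE_integer('task_id', 0, 'Task index within the job.')
tf.app.flags.DEFINE_string('ps_hosts', '',
                           'Comma-separated list of parameter server '
                           'host:port pairs.')
tf.app.flags.DEFINE_string('worker_hosts', '',
                           'Comma-separated list of worker host:port pairs.')
tf.app.flags.DEFINE_string('train_dir', '/tmp/imagenet_train',
                           'Directory for checkpoints and event logs.')
tf.app.flags.DEFINE_string('subset', 'train', 'Dataset subset to train on.')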

22 changes: 22 additions & 0 deletions inception/inception/BUILD
@@ -102,6 +102,17 @@ py_binary(
    ],
)

py_binary(
    name = "imagenet_distributed_train",
    srcs = [
        "imagenet_distributed_train.py",
    ],
    deps = [
        ":imagenet_data",
        ":inception_distributed_train",
    ],
)

py_binary(
    name = "flowers_train",
    srcs = [
@@ -124,6 +135,17 @@ py_library(
    ],
)

py_library(
    name = "inception_distributed_train",
    srcs = [
        "inception_distributed_train.py",
    ],
    deps = [
        ":image_processing",
        ":inception",
    ],
)

py_binary(
    name = "build_image_data",
    srcs = ["data/build_image_data.py"],
65 changes: 65 additions & 0 deletions inception/inception/imagenet_distributed_train.py
@@ -0,0 +1,65 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=line-too-long
"""A binary to train Inception in a distributed manner using multiple systems.
Please see accompanying README.md for details and instructions.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from inception import inception_distributed_train
from inception.imagenet_data import ImagenetData

FLAGS = tf.app.flags.FLAGS


def main(unused_args):
  assert FLAGS.job_name in ['ps', 'worker'], 'job_name must be ps or worker'

  # Extract all the hostnames for the ps and worker jobs to construct the
  # cluster spec.
  ps_hosts = FLAGS.ps_hosts.split(',')
  worker_hosts = FLAGS.worker_hosts.split(',')
  tf.logging.info('PS hosts are: %s' % ps_hosts)
  tf.logging.info('Worker hosts are: %s' % worker_hosts)

  cluster_spec = tf.train.ClusterSpec({'ps': ps_hosts,
                                       'worker': worker_hosts})
  server = tf.train.Server(cluster_spec,
                           job_name=FLAGS.job_name,
                           task_index=FLAGS.task_id)

  if FLAGS.job_name == 'ps':
    # `ps` jobs wait for incoming connections from the workers.
    server.join()
  else:
    # `worker` jobs will actually do the work.
    dataset = ImagenetData(subset=FLAGS.subset)
    assert dataset.data_files()
    # Only the chief checks for or creates train_dir.
    if FLAGS.task_id == 0:
      if not tf.gfile.Exists(FLAGS.train_dir):
        tf.gfile.MakeDirs(FLAGS.train_dir)
    inception_distributed_train.train(server.target, dataset, cluster_spec)

if __name__ == '__main__':
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run()
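
For concreteness, a minimal, self-contained sketch (with made-up hostnames) of the cluster wiring that main() performs for a one-ps, two-worker configuration:

import tensorflow as tf

# Stand-ins for the --ps_hosts and --worker_hosts flag values.
ps_hosts = ['ps0.example.com:2222']
worker_hosts = ['worker0.example.com:2222', 'worker1.example.com:2222']

# Every process in the cluster builds the same ClusterSpec...
cluster_spec = tf.train.ClusterSpec({'ps': ps_hosts,
                                     'worker': worker_hosts})

# ...and starts one in-process gRPC server for its own (job, task) slot;
# e.g. the first worker runs with job_name='worker', task_index=0.
server = tf.train.Server(cluster_spec, job_name='worker', task_index=0)

# ps processes block in server.join(); workers pass server.target
# (a grpc:// URL) to the training loop, as main() does above.
print(server.target)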