Updated implementation of L2TL.
PiperOrigin-RevId: 301932916
soarik authored and copybara-github committed Mar 20, 2020
1 parent ba0a31d commit b5bcea2
Showing 16 changed files with 319 additions and 145 deletions.
4 changes: 2 additions & 2 deletions l2tl/README.md
@@ -1,4 +1,4 @@
-# Codebase for "Learning to Transfer Learn" (WORK IN PROGRESS)
+# Codebase for "Learning to Transfer Learn"

Authors: Linchao Zhu, Sercan O. Arik, Yi Yang, and Tomas Pfister

@@ -8,4 +8,4 @@ Learning to transfer learn (L2TL), to improve transfer learning on a target data

This repository contains an example implementation of L2TL framework on the task of transferring knowledge from MNIST to SVHN.

-To run the experiments that compare training from random initialization, fine-tuning and L2TL, run `bash all_experiments.sh`.
+To run the experiments that compare training from random initialization, fine-tuning and L2TL, run `bash run.sh`.
82 changes: 0 additions & 82 deletions l2tl/all_experiments.sh

This file was deleted.

14 changes: 9 additions & 5 deletions l2tl/evaluate.py
@@ -23,7 +23,7 @@
from absl import flags
import model
import model_utils
-import tensorflow.compat.v1 as tf
+import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import
import tensorflow_datasets as tfds

flags.DEFINE_string('ckpt_path', '', 'Path to evaluation checkpoint')
@@ -36,8 +36,7 @@

NUM_EVAL_IMAGES = {
'mnist': 10000,
-'svhn_cropped': 26032,
-'svhn_cropped_small': 12000,
+'svhn_cropped_small': 6000,
}


@@ -56,7 +55,12 @@ def model_fn(features, labels, mode, params):

def get_logits():
"""Return the logits."""
-network_output = model.conv_model(feature, mode)
+network_output = model.conv_model(
+feature,
+mode,
+target_dataset=FLAGS.target_dataset,
+src_hw=FLAGS.src_hw,
+target_hw=FLAGS.target_hw)
name = FLAGS.cls_dense_name
with tf.variable_scope('target_CLS'):
logits = tf.layers.dense(
@@ -124,5 +128,5 @@ def make_input_dataset():


if __name__ == '__main__':
-tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)
+tf.logging.set_verbosity(tf.logging.INFO)
app.run(main)
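
Side note: the NUM_EVAL_IMAGES entry for svhn_cropped_small is halved to 6000 above; presumably this constant sizes the evaluation loop. A tiny pure-Python sketch of that bookkeeping, where the batch size and the ceiling division are illustrative assumptions rather than values from this commit:

num_eval_images = 6000   # updated 'svhn_cropped_small' entry
eval_batch_size = 32     # hypothetical value; the real flag is defined elsewhere in evaluate.py
eval_steps = -(-num_eval_images // eval_batch_size)  # ceiling division -> 188 batches
print(eval_steps)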
50 changes: 29 additions & 21 deletions l2tl/finetuning.py
@@ -25,8 +25,8 @@
from absl import flags
import model
import model_utils
-import tensorflow.compat.v1 as tf
-from tensorflow.compat.v1.python.estimator import estimator
+import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import
+from tensorflow.python.estimator import estimator
import tensorflow_datasets as tfds

flags.DEFINE_string(
@@ -38,7 +38,7 @@
'warm_start_ckpt_path', None, 'The path to the checkpoint '
'that will be used before training.')
flags.DEFINE_integer(
-'log_step_count_steps', 64, 'The number of steps at '
+'log_step_count_steps', 200, 'The number of steps at '
'which the global step information is logged.')
flags.DEFINE_integer('train_steps', 100, 'Number of steps for training.')
flags.DEFINE_float('target_base_learning_rate', 0.001,
@@ -49,23 +49,21 @@

FLAGS = flags.FLAGS

-NUM_TRAIN_IMAGES = {
-'mnist': 50000,
-'svhn_cropped': 73257,
-'svhn_cropped_small': 600,
-}


def lr_schedule():
"""Learning rate scheduling."""
-num_train_images = NUM_TRAIN_IMAGES[FLAGS.target_dataset]
target_lr = FLAGS.target_base_learning_rate
-steps_per_epoch = num_train_images // FLAGS.train_batch_size
current_step = tf.train.get_global_step()
-return tf.train.piecewise_constant(current_step, [
-steps_per_epoch * 80,
-steps_per_epoch * 120,
-], [target_lr, target_lr * 0.1, target_lr * 0.01])

+if FLAGS.target_dataset == 'mnist':
+return tf.train.piecewise_constant(current_step, [
+500,
+1500,
+], [target_lr, target_lr * 0.1, target_lr * 0.01])
+else:
+return tf.train.piecewise_constant(current_step, [
+800,
+], [target_lr, target_lr * 0.1])


def get_model_fn():
@@ -74,7 +72,6 @@ def get_model_fn():
def model_fn(features, labels, mode, params):
"""Returns the model function."""
feature = features['feature']
-print(feature)
labels = labels['label']
one_hot_labels = model_utils.get_label(
labels,
@@ -84,11 +81,20 @@ def model_fn(features, labels, mode, params):

def get_logits():
"""Return the logits."""
-avg_pool = model.conv_model(feature, mode)
+avg_pool = model.conv_model(
+feature,
+mode,
+target_dataset=FLAGS.target_dataset,
+src_hw=FLAGS.src_hw,
+target_hw=FLAGS.target_hw)
name = 'final_dense_dst'
with tf.variable_scope('target_CLS'):
logits = tf.layers.dense(
-inputs=avg_pool, units=FLAGS.src_num_classes, name=name)
+inputs=avg_pool,
+units=FLAGS.src_num_classes,
+name=name,
+kernel_initializer=tf.random_normal_initializer(stddev=.05),
+)
return logits

logits = get_logits()
@@ -112,8 +118,10 @@ def get_logits():
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
finetune_learning_rate = lr_schedule()
-optimizer = tf.train.AdamOptimizer(finetune_learning_rate)
-
+optimizer = tf.train.MomentumOptimizer(
+learning_rate=finetune_learning_rate,
+momentum=0.9,
+use_nesterov=True)
train_op = tf.contrib.slim.learning.create_train_op(loss, optimizer)
with tf.variable_scope('finetune'):
train_op = optimizer.minimize(loss, cur_finetune_step)
@@ -158,7 +166,7 @@ def main(unused_argv):
checkpoint_path = FLAGS.warm_start_ckpt_path
reader = tf.train.NewCheckpointReader(checkpoint_path)
for key in reader.get_variable_to_shape_map():
-keep_str = 'Momentum|global_step|finetune_global_step|Adam'
+keep_str = 'Momentum|global_step|finetune_global_step|Adam|final_dense_dst'
if not re.findall('({})'.format(keep_str,), key):
var_names.append(key)

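For intuition on the new schedule: lr_schedule now decays at fixed global steps instead of epoch multiples, and the optimizer switches from Adam to Nesterov momentum, so the decay comes entirely from the schedule. Below is a pure-Python mirror of tf.train.piecewise_constant (which keeps the current value while step <= boundary); illustration only, with the 0.005 base rate taken from ft.sh further down:

def piecewise_constant(step, boundaries, values):
    # Mirror of tf.train.piecewise_constant: values[0] while step <= boundaries[0],
    # then values[i + 1] once step passes boundaries[i].
    for boundary, value in zip(boundaries, values):
        if step <= boundary:
            return value
    return values[-1]

base_lr = 0.005  # --target_base_learning_rate in ft.sh
for step in (0, 800, 801, 1200):
    print(step, piecewise_constant(step, [800], [base_lr, base_lr * 0.1]))
# -> 0.005 up to step 800, then 0.0005 for the rest of the 1200-step run
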
31 changes: 31 additions & 0 deletions l2tl/ft.sh
@@ -0,0 +1,31 @@
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

model_name=finetune_svhn
steps=1200

python3 finetuning.py \
--target_dataset=svhn_cropped_small \
--train_steps=$steps \
--model_dir=trained_models/${model_name} \
--train_batch_size=8 \
--src_num_classes=5 \
--target_base_learning_rate=0.005 \
--warm_start_ckpt_path=trained_models/mnist_pretrain/model.ckpt-2000

python3 evaluate.py \
--ckpt_path=trained_models/${model_name}/model.ckpt-$steps \
--src_num_classes=5 \
--target_dataset=svhn_cropped_small \
--cls_dense_name=final_dense_dst
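
ft.sh warm-starts from the MNIST checkpoint written by mnist_pretrain.sh, and the keep_str change in finetuning.py above excludes optimizer slots, step counters, and the new final_dense_dst head from that warm start. A small sketch of the filter with hypothetical variable names (the checkpoint contents below are illustrative, not read from a real checkpoint):

import re

keep_str = 'Momentum|global_step|finetune_global_step|Adam|final_dense_dst'
ckpt_vars = [
    'conv2d/kernel',                      # backbone weight: warm-started
    'conv2d/kernel/Momentum',             # optimizer slot: skipped
    'global_step',                        # counter: skipped
    'target_CLS/final_dense_dst/kernel',  # target head: skipped, trained from scratch
]
warm_start_vars = [v for v in ckpt_vars if not re.findall('({})'.format(keep_str), v)]
print(warm_start_vars)  # ['conv2d/kernel']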
27 changes: 27 additions & 0 deletions l2tl/mnist_pretrain.sh
@@ -0,0 +1,27 @@
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

model_name=mnist_pretrain

python3 finetuning.py \
--target_dataset=mnist \
--train_steps=2000 \
--target_base_learning_rate=0.01 \
--model_dir=trained_models/${model_name}\
--train_batch_size=128

python3 evaluate.py \
--target_dataset=mnist \
--ckpt_path=trained_models/${model_name}/model.ckpt-2000 \
--cls_dense_name=final_dense_dst
20 changes: 12 additions & 8 deletions l2tl/model.py
@@ -21,7 +21,7 @@

from absl import flags

-import tensorflow.compat.v1 as tf
+import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import

flags.DEFINE_string('source_dataset', 'mnist', 'Name of the source dataset.')
flags.DEFINE_string('target_dataset', 'svhn_cropped_small',
@@ -35,22 +35,26 @@
FLAGS = flags.FLAGS


-def conv_model(features, mode, dataset_name=None, reuse=None):
+def conv_model(features,
+mode,
+target_dataset,
+src_hw=28,
+target_hw=32,
+dataset_name=None,
+reuse=None):
"""Architecture of the LeNet model for MNIST."""

def build_network(features, is_training):
"""Returns the network output."""
# Input reshape
-if dataset_name == 'mnist' or FLAGS.target_dataset == 'mnist':
-input_layer = tf.reshape(features, [-1, FLAGS.src_hw, FLAGS.src_hw, 1])
+if dataset_name == 'mnist' or target_dataset == 'mnist':
+input_layer = tf.reshape(features, [-1, src_hw, src_hw, 1])
input_layer = tf.pad(input_layer, [[0, 0], [2, 2], [2, 2], [0, 0]])
else:
-input_layer = tf.reshape(features,
-[-1, FLAGS.target_hw, FLAGS.target_hw, 3])
+input_layer = tf.reshape(features, [-1, target_hw, target_hw, 3])
input_layer = tf.image.rgb_to_grayscale(input_layer)

-input_layer = tf.reshape(input_layer,
-[-1, FLAGS.target_hw, FLAGS.target_hw, 1])
+input_layer = tf.reshape(input_layer, [-1, target_hw, target_hw, 1])
input_layer = tf.image.convert_image_dtype(input_layer, dtype=tf.float32)

discard_rate = 0.2
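With this signature change, callers pass the dataset geometry to conv_model explicitly instead of having the model read FLAGS, which is what the finetuning.py and evaluate.py edits above now do. A minimal, untested sketch of a call site under TF 1.x (the placeholder shape and mode value are assumptions for illustration):

import tensorflow as tf
import model  # l2tl/model.py

images = tf.placeholder(tf.float32, [None, 32 * 32 * 3])  # flattened SVHN batch, assumed layout
avg_pool = model.conv_model(
    images,
    mode=tf.estimator.ModeKeys.TRAIN,
    target_dataset='svhn_cropped_small',
    src_hw=28,
    target_hw=32)
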
5 changes: 1 addition & 4 deletions l2tl/model_utils.py
@@ -19,19 +19,16 @@
from __future__ import division
from __future__ import print_function

-import tensorflow.compat.v1 as tf
+import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import


def metric_fn(labels, logits):
"""Metric function for evaluation."""
predictions = tf.argmax(logits, axis=1)
top_1_accuracy = tf.metrics.accuracy(labels, predictions)
-in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
-top_5_accuracy = tf.metrics.mean(in_top_5)

return {
'top_1_accuracy': top_1_accuracy,
-'top_5_accuracy': top_5_accuracy,
}


2 changes: 1 addition & 1 deletion l2tl/requirements.txt
@@ -1,6 +1,6 @@
numpy>=1.16.2
tensorflow==1.13.1
-tensorflow-datasets
+tensorflow-datasets==1.3.0
scipy>=1.3.0
tensorflow-probability==0.5.0
tfp-nightly==0.7.0.dev20190529