forked from tensorflow/models
Commit: First commit of learning_unsupervised_learning
Showing 21 changed files with 2,623 additions and 0 deletions.

research/learning_unsupervised_learning/.gitignore (1 addition)

*.pyc

research/learning_unsupervised_learning/README.md (37 additions)

# Learning Unsupervised Learning Rules

This repository contains code and weights for the learned update rule presented in "Learning Unsupervised Learning Rules." At this time, this code cannot meta-train the update rule.

### Structure

`run_eval.py` contains the main training loop. This constructs an op that runs one iteration of the learned update rule and assigns the results to variables. Additionally, it loads the weights from our pre-trained model.
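
The core of that loop is a TF1-style "compute new values, then assign them in place" op. A minimal, self-contained sketch of the pattern, with a toy stand-in for the learned rule (`learned_update_rule` here is hypothetical, not the repository's API):

```python
import tensorflow as tf

# Toy stand-in: any map from current variable values to new values fits here.
def learned_update_rule(variables):
  return [v * 0.9 for v in variables]  # illustrative only; the real rule is learned

w = tf.get_variable("w", shape=[4], initializer=tf.ones_initializer())
new_values = learned_update_rule([w])
# Assign the results back to the variables, mirroring what run_eval.py's op does.
train_op = tf.group(*[v.assign(nv) for v, nv in zip([w], new_values)])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(train_op)  # one iteration of the (stand-in) update rule
```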

The base model and the update rule architecture definition can be found in `architectures/more_local_weight_update.py`. For a complete description of the model, see our [paper](https://arxiv.org/abs/1804.00222).

### Dependencies

[absl](https://github.com/abseil/abseil-py), [tensorflow](https://tensorflow.org), [sonnet](https://github.com/deepmind/sonnet)

### Usage

First, download the [pre-trained optimizer model weights](https://storage.googleapis.com/learning_unsupervised_learning/200_tf_graph.zip) and extract them.

```bash
# move to the folder above this folder
cd path_to/research/learning_unsupervised_learning/../

# launch the eval script
python -m learning_unsupervised_learning.run_eval \
  --train_log_dir="/tmp/learning_unsupervised_learning" \
  --checkpoint_dir="/path/to/downloaded/model/tf_graph_data.ckpt"
```

### Contact

Luke Metz, Niru Maheswaranathan. GitHub: @lukemetz, @nirum. Email: {lmetz, nirum}@google.com
Empty file.

research/learning_unsupervised_learning/architectures/__init__.py (17 additions)
# Copyright 2018 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Absolute import so this works under both Python 2 and 3 (an implicit
# relative import here would only work under Python 2).
from learning_unsupervised_learning.architectures import more_local_weight_update

research/learning_unsupervised_learning/architectures/common.py (153 additions)
# Copyright 2018 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

import numpy as np
import sonnet as snt
import tensorflow as tf
from tensorflow.python.util import nest

from learning_unsupervised_learning import utils
from learning_unsupervised_learning import variable_replace


class LinearBatchNorm(snt.AbstractModule):
  """Module that applies a linear layer, then batch norm, then an activation fn."""

  def __init__(self, size, activation_fn=tf.nn.relu, name="LinearBatchNorm"):
    self.size = size
    self.activation_fn = activation_fn
    super(LinearBatchNorm, self).__init__(name=name)

  def _build(self, x):
    x = tf.to_float(x)
    initializers = {"w": tf.truncated_normal_initializer(stddev=0.01)}
    lin = snt.Linear(self.size, use_bias=False, initializers=initializers)
    z = lin(x)

    # Batch norm with a fixed scale of 1 and a learned offset "b", which
    # doubles as the bias of the (bias-free) linear layer above.
    scale = tf.constant(1., dtype=tf.float32)
    offset = tf.get_variable(
        "b",
        shape=[1, z.shape.as_list()[1]],
        initializer=tf.truncated_normal_initializer(stddev=0.1),
        dtype=tf.float32
    )

    mean, var = tf.nn.moments(z, [0], keep_dims=True)
    z = ((z - mean) * tf.rsqrt(var + 1e-6)) * scale + offset

    x_p = self.activation_fn(z)

    return z, x_p

  # This needs to work by string name, sadly, due to how the variable replace
  # works, and would also work even if the custom getter approach was used.
  # This is verbose, but it should at least be clear as to what is going on.
  # TODO(lmetz) a better way to do this (the next 3 functions:
  # _raw_name, w(), b())
  def _raw_name(self, var_name):
    """Return just the name of the variable, not the scopes."""
    return var_name.split("/")[-1].split(":")[0]

  @property
  def w(self):
    var_list = snt.get_variables_in_module(self)
    w = [x for x in var_list if self._raw_name(x.name) == "w"]
    assert len(w) == 1
    return w[0]

  @property
  def b(self):
    var_list = snt.get_variables_in_module(self)
    b = [x for x in var_list if self._raw_name(x.name) == "b"]
    assert len(b) == 1
    return b[0]
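
A quick usage sketch for `LinearBatchNorm`, assuming TF1 graph mode and Sonnet v1, with illustrative shapes (this example is ours, not part of the original file):

```python
import numpy as np
import tensorflow as tf

layer = LinearBatchNorm(size=32)            # the class defined above
x = tf.placeholder(tf.float32, [None, 16])
z, x_p = layer(x)  # normalized pre-activation and activation_fn(z)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  batch = np.random.randn(8, 16).astype(np.float32)
  z_np, x_np = sess.run([z, x_p], {x: batch})
  print(z_np.shape, x_np.shape)  # (8, 32) (8, 32)
```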


class Linear(snt.AbstractModule):
  """Linear layer that exposes its weight and bias variables by name."""

  def __init__(self, size, use_bias=True, init_const_mag=True):
    self.size = size
    self.use_bias = use_bias
    self.init_const_mag = init_const_mag
    super(Linear, self).__init__(name="commonLinear")

  def _build(self, x):
    if self.init_const_mag:
      initializers = {"w": tf.truncated_normal_initializer(stddev=0.01)}
    else:
      initializers = {}
    lin = snt.Linear(self.size, use_bias=self.use_bias, initializers=initializers)
    z = lin(x)
    return z

  # This needs to work by string name, sadly, due to how the variable replace
  # works, and would also work even if the custom getter approach was used.
  # This is verbose, but it should at least be clear as to what is going on.
  # TODO(lmetz) a better way to do this (the next 3 functions:
  # _raw_name, w(), b())
  def _raw_name(self, var_name):
    """Return just the name of the variable, not the scopes."""
    return var_name.split("/")[-1].split(":")[0]

  @property
  def w(self):
    var_list = snt.get_variables_in_module(self)
    if self.use_bias:
      assert len(var_list) == 2, "Found not 2 but %d" % len(var_list)
    else:
      assert len(var_list) == 1, "Found not 1 but %d" % len(var_list)
    w = [x for x in var_list if self._raw_name(x.name) == "w"]
    assert len(w) == 1
    return w[0]

  @property
  def b(self):
    var_list = snt.get_variables_in_module(self)
    assert len(var_list) == 2, "Found not 2 but %d" % len(var_list)
    b = [x for x in var_list if self._raw_name(x.name) == "b"]
    assert len(b) == 1
    return b[0]
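
The `w` and `b` properties exist so callers can address a layer's variables individually (e.g. for the variable replacement machinery mentioned in the comments above). A small sketch, again assuming TF1 and Sonnet v1, with this file's definitions in scope:

```python
import tensorflow as tf

lin = Linear(size=4)             # the class defined above
y = lin(tf.zeros([2, 3]))        # first call builds the variables
weight, bias = lin.w, lin.b      # looked up by raw variable name
print(weight.shape, bias.shape)  # (3, 4) (4,)
```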


def transformer_at_state(base_model, new_variables):
  """Get the base_model that has been transformed to use the variables
  in new_variables.

  Args:
    base_model: snt.Module
      Goes from batch to features.
    new_variables: list
      New list of variables to use.

  Returns:
    func: callable with the same API as base_model.
  """
  assert not variable_replace.in_variable_replace_scope()

  def _feature_transformer(input_data):
    """Feature transformer at the end of training."""
    initial_variables = base_model.get_variables()
    replacement = collections.OrderedDict(
        utils.eqzip(initial_variables, new_variables))
    with variable_replace.variable_replace(replacement):
      features = base_model(input_data)
    return features

  return _feature_transformer
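
A hypothetical call site for `transformer_at_state` (the names below are ours; the real call sites live elsewhere in the repository):

```python
# base_model: an already-built snt.Module mapping batches to features.
# theta_new: one tensor per variable of base_model, e.g. the parameter values
# produced by unrolling the learned update rule.
feature_fn = transformer_at_state(base_model, theta_new)
features = feature_fn(input_batch)  # same API as base_model, but new weights
```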