
Commit 6c41355
update to python 3.
PiperOrigin-RevId: 246036311
Gil Tabak authored and copybara-github committed Apr 30, 2019
1 parent b94d08f commit 6c41355
Showing 6 changed files with 62 additions and 59 deletions.
1 change: 1 addition & 0 deletions correct_batch_effects_wdn/evaluate.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# Lint as: python2, python3
 """Functions to evaluate and summarize metadata by their embeddings.
 """
 
111 changes: 55 additions & 56 deletions correct_batch_effects_wdn/forgetting_nuisance.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# Lint as: python2, python3
 """Training to forget nuisance variables."""
 
 from __future__ import absolute_import
@@ -53,9 +54,10 @@
 MEAN_ONLY_NETWORK = "MeanOnlyNetwork"
 WASSERSTEIN_SQRD_NETWORK = "WassersteinSqrdNetwork"
 WASSERSTEIN_CUBED_NETWORK = "WassersteinCubedNetwork"
-POSSIBLE_NETWORKS = [WASSERSTEIN_NETWORK, WASSERSTEIN_2_NETWORK,
-                     WASSERSTEIN_SQRD_NETWORK, WASSERSTEIN_CUBED_NETWORK,
-                     MEAN_ONLY_NETWORK]
+POSSIBLE_NETWORKS = [
+    WASSERSTEIN_NETWORK, WASSERSTEIN_2_NETWORK, WASSERSTEIN_SQRD_NETWORK,
+    WASSERSTEIN_CUBED_NETWORK, MEAN_ONLY_NETWORK
+]
 
 FLAGS = flags.FLAGS
 
@@ -65,9 +67,8 @@
 flags.DEFINE_integer("num_steps", None, "Number of steps (after pretrain).")
 flags.DEFINE_integer("disc_steps_per_training_step", None, "Number critic steps"
                      "to use per main training step.")
-flags.DEFINE_enum(
-    "network_type", "WassersteinNetwork", POSSIBLE_NETWORKS,
-    "Network to use. Can be WassersteinNetwork.")
+flags.DEFINE_enum("network_type", "WassersteinNetwork", POSSIBLE_NETWORKS,
+                  "Network to use. Can be WassersteinNetwork.")
 flags.DEFINE_integer("batch_n", 10, "Number of points to use per minibatch"
                      "for each loss.")
 flags.DEFINE_float("learning_rate", 1e-4, "Initial learning rate to use.")
@@ -93,8 +94,9 @@
                    "transformation.")
 flags.DEFINE_integer("seed", 42, "Seed to use for numpy.")
 flags.DEFINE_integer("tf_seed", 42, "Seed to use for tensorflow.")
-flags.DEFINE_float("cov_fix", 0.001, "Multiple of identity to add if using"
-                   "Wasserstein-2 distance.")
+flags.DEFINE_float(
+    "cov_fix", 0.001, "Multiple of identity to add if using"
+    "Wasserstein-2 distance.")
 
 ################################################################################
 ##### Functions and classes for storing and retrieving data
@@ -449,8 +451,8 @@ def encoding_to_multiplexed_encoding(self, arr, encoding):
     ones = np.ones((num_points, 1))
     values_with_ones = np.hstack((arr, ones))
     ## TODO(tabakg): make this faster
-    multiplexed_values = self._zeros_maker((num_points,
-                                            (num_dim + 1) * num_categories))
+    multiplexed_values = self._zeros_maker(
+        (num_points, (num_dim + 1) * num_categories))
     for row_idx, (value, enc) in enumerate(zip(values_with_ones, encoding)):
       bin_num = self.encoding_to_num[tuple(enc)]
       multiplexed_values[row_idx, bin_num * (num_dim + 1):(bin_num + 1) *
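For orientation, here is a minimal self-contained NumPy sketch of the multiplexing idea in this hunk; the function name `multiplex` and the integer `category_ids` argument are hypothetical stand-ins for the original's `self._zeros_maker` and encoding-to-bin lookup. Each point gets a bias of 1 appended and is written into the column block belonging to its category.

```python
import numpy as np

def multiplex(arr, category_ids, num_categories):
  """Hypothetical sketch: place [x, 1] into the column block of x's category."""
  num_points, num_dim = arr.shape
  values_with_ones = np.hstack((arr, np.ones((num_points, 1))))
  out = np.zeros((num_points, (num_dim + 1) * num_categories))
  for row_idx, (value, cat) in enumerate(zip(values_with_ones, category_ids)):
    out[row_idx, cat * (num_dim + 1):(cat + 1) * (num_dim + 1)] = value
  return out
```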
@@ -596,7 +598,8 @@ def wasserstein_distance(x_,
   ## gradient penalty
   (gradient,) = tf.gradients(discriminator_model(x_hat), [x_hat])
   gradient_penalty = penalty_lambda * tf.square(
-      tf.maximum(0.0, tf.norm(gradient, ord=2) - 1.0))
+      tf.maximum(0.0,
+                 tf.norm(gradient, ord=2) - 1.0))
 
   ## calculate discriminator's loss
   disc_loss = (
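The expression being rewrapped here is a one-sided WGAN gradient penalty: the critic's gradient norm at the interpolate x_hat is penalized only where it exceeds 1. A minimal sketch, assuming TensorFlow 1.x (the tf.gradients API this file uses) and a hypothetical default for penalty_lambda:

```python
import tensorflow as tf  # assumes TensorFlow 1.x

def gradient_penalty(discriminator_model, x_hat, penalty_lambda=10.0):
  """One-sided WGAN gradient penalty: only norms above 1 are penalized."""
  (gradient,) = tf.gradients(discriminator_model(x_hat), [x_hat])
  return penalty_lambda * tf.square(
      tf.maximum(0.0, tf.norm(gradient, ord=2) - 1.0))
```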
@@ -662,10 +665,10 @@ def wasserstein_2_distance(x_, y_, mean_only=False):
   cov_y_ = cov_tf(mean_y_, y_) + FLAGS.cov_fix * tf.eye(FLAGS.feature_dim)
   sqrt_cov_x_ = mat_sqrt_tf(cov_x_)
 
-  prod = tf.matmul(
-      tf.matmul(sqrt_cov_x_, cov_y_), sqrt_cov_x_)
-  return transform.sum_of_square(mean_x_ - mean_y_) + tf.trace(
-      cov_x_ + cov_y_ - 2.0 * mat_sqrt_tf(prod))
+  prod = tf.matmul(tf.matmul(sqrt_cov_x_, cov_y_), sqrt_cov_x_)
+  return transform.sum_of_square(mean_x_ -
+                                 mean_y_) + tf.trace(cov_x_ + cov_y_ -
+                                                     2.0 * mat_sqrt_tf(prod))
 
 
 class Network(object):
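The rewrapped return statement is the closed-form squared 2-Wasserstein distance between Gaussians, ||mu_x - mu_y||^2 + tr(C_x + C_y - 2 (C_x^{1/2} C_y C_x^{1/2})^{1/2}). A NumPy/SciPy sketch of the same formula, with an `eps` argument standing in for the cov_fix regularizer:

```python
import numpy as np
from scipy.linalg import sqrtm

def wasserstein_2_gaussian(mean_x, cov_x, mean_y, cov_y, eps=1e-3):
  """Closed-form squared W2 distance between two Gaussians (a sketch)."""
  dim = cov_x.shape[0]
  cov_x = cov_x + eps * np.eye(dim)  # mirrors the cov_fix flag above
  cov_y = cov_y + eps * np.eye(dim)
  sqrt_cov_x = sqrtm(cov_x)
  prod = sqrtm(sqrt_cov_x @ cov_y @ sqrt_cov_x)
  w2 = np.sum((mean_x - mean_y) ** 2) + np.trace(cov_x + cov_y - 2.0 * prod)
  return w2.real  # sqrtm may return a complex array with ~zero imaginary part
```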
@@ -722,8 +725,8 @@ def __init__(self, holder, feature_dim, batch_n):
     self._feature_dim = feature_dim
     if self._input_dim != self._feature_dim:
       raise ValueError("Currently only supporting feature_dim == input_dim. "
-                       "But we have feature_dim = %s and input_dim = %s"
-                       % (feature_dim, holder.input_dim))
+                       "But we have feature_dim = %s and input_dim = %s" %
+                       (feature_dim, holder.input_dim))
     self._batch_n = batch_n
 
 
@@ -803,26 +806,24 @@ class WassersteinNetwork(Network):
       features.
     b (dict): Mapping from keys to b tensorflow tensors going from inputs to
       features.
-    ignore_disc (bool): If this is true, do not train a discriminator.
-      This should be set to True when using a distance that does not need to
-      be learned, e.g. the Wasserstein-2 distance.
+    ignore_disc (bool): If this is true, do not train a discriminator. This
+      should be set to True when using a distance that does not need to be
+      learned, e.g. the Wasserstein-2 distance.
   """
 
-  def __init__(
-      self,
-      holder,
-      feature_dim,
-      batch_n,
-      target_levels,
-      nuisance_levels,
-      layer_width=2,
-      num_layers=2,
-      lambda_mean=0.,
-      lambda_cov=0.,
-      power=1.,
-      ignore_disc=False,
-      mean_only=False
-  ):
+  def __init__(self,
+               holder,
+               feature_dim,
+               batch_n,
+               target_levels,
+               nuisance_levels,
+               layer_width=2,
+               num_layers=2,
+               lambda_mean=0.,
+               lambda_cov=0.,
+               power=1.,
+               ignore_disc=False,
+               mean_only=False):
     """Inits WassersteinNetwork.
 
     Args:
@@ -844,9 +845,9 @@ def __init__(
       lambda_mean (float): penalty term for the mean term of the transformation.
       lambda_cov (float): penalty term for the cov term of the transformation.
       power (float): power of each pair-wise wasserstein distance to use.
-      ignore_disc (bool): If this is true, do not train a discriminator.
-        This should be set to True when using a distance that does not need to
-        be learned, e.g. the Wasserstein-2 distance.
+      ignore_disc (bool): If this is true, do not train a discriminator. This
+        should be set to True when using a distance that does not need to be
+        learned, e.g. the Wasserstein-2 distance.
       mean_only (bool): Using the Wasserstein-2 distance, but only the mean
         component (i.e. no covariance dependence). This is for experimental
         purposes.
@@ -886,12 +887,11 @@ def __init__(
         ])))
 
     ## Map from each possible target key to all input keys that generated it.
-    self._keys_for_targets = {
-        target: [  ## pylint: disable=g-complex-comprehension
-            s for s in shufflers if target == get_filtered_key(
-                s[INPUT_KEY_INDEX], self._target_levels)
-        ] for target in self._unique_targets
-    }
+    self._keys_for_targets = collections.defaultdict(list)
+    for target in self._unique_targets:
+      for s in shufflers:
+        if target == get_filtered_key(s[INPUT_KEY_INDEX], self._target_levels):
+          self._keys_for_targets[target].append(s)
 
     ## Generate input placeholders.
     self._x_inputs, self._x_vals = self.get_input_vals()
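The pylint-flagged dict comprehension becomes an explicit defaultdict loop; the behavior is the same except that the result is a collections.defaultdict rather than a plain dict. A self-contained toy version of the grouping pattern, with hypothetical data:

```python
import collections

shufflers = [("plateA", "drugX"), ("plateB", "drugX"), ("plateA", "drugY")]
unique_targets = {s[1] for s in shufflers}  # group by the "target" sub-key

keys_for_targets = collections.defaultdict(list)
for target in unique_targets:
  for s in shufflers:
    if s[1] == target:
      keys_for_targets[target].append(s)

assert keys_for_targets["drugX"] == [("plateA", "drugX"), ("plateB", "drugX")]
```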
@@ -994,8 +994,8 @@ def pairwise_wasserstein(self, power=1.):
             self._num_layers,
             self._batch_n,
             seed=None)
-        wass_loss_target[target] += tf.math.pow(
-            wass_dists, power) / normalization
+        wass_loss_target[target] += tf.math.pow(wass_dists,
+                                                power) / normalization
         grad_pen_target[target] += grad_pens / normalization
 
     return wass_loss_target, grad_pen_target
@@ -1026,14 +1026,15 @@ def pairwise_wasserstein_2(self, mean_only=False):
       wass_loss_target[target] = 0
       normalization = num_per_target * (num_per_target - 1) / 2.
       ## Iterate through all pairs of nuisance for a given target
-      for i in xrange(num_per_target):
-        for j in xrange(i + 1, num_per_target):
+      for i in range(num_per_target):
+        for j in range(i + 1, num_per_target):
           key_i = tuple(self._keys_for_targets[target][i])
           key_j = tuple(self._keys_for_targets[target][j])
 
           ## Generate W2 distance and gradient penalty.
           wass_2_dists = wasserstein_2_distance(
-              self._features[key_i], self._features[key_j],
+              self._features[key_i],
+              self._features[key_j],
               mean_only=mean_only) / normalization
           wass_loss_target[target] += wass_2_dists
 
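The xrange builtin is gone in Python 3; range is already lazy, so the upper-triangular pair enumeration is unchanged. For illustration:

```python
num_per_target = 3
pairs = [(i, j) for i in range(num_per_target)
         for j in range(i + 1, num_per_target)]
assert pairs == [(0, 1), (0, 2), (1, 2)]
# Each unordered pair appears once, matching the n * (n - 1) / 2 normalization.
```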
@@ -1159,9 +1160,7 @@ def train(self,
     random_nums = []
 
     def do_loss_without_step():
-      """Get losses and update loss_hist and gran_pen_hist, no training step.
-      """
+      """Get losses and update loss_hist and gran_pen_hist, no training step."""
 
       feed_dict = {}
       for key, shuffler in self.holder.data_shufflers.items():
@@ -1209,8 +1208,8 @@ def do_train_step(trainer, increment_global_step_op, train=True):
         loss_val = sess.run([loss], feed_dict=feed_dict)[0]
         grad_pen_val = None
       else:
-        loss_val, grad_pen_val = sess.run(
-            [loss, grad_pen], feed_dict=feed_dict)
+        loss_val, grad_pen_val = sess.run([loss, grad_pen],
+                                          feed_dict=feed_dict)
       ## if trainer is not ran, increment global step anyway.
       step = sess.run(increment_global_step_op)
 
@@ -1271,8 +1270,8 @@ def do_train_step(trainer, increment_global_step_op, train=True):
         pos_dist = True  ## if ignoring discriminator, this is not an issue
       else:
         pos_dist = (-loss_hist[-1] > 0)
-      step = do_train_step(input_trainer, increment_global_step_op,
-                           train=pos_dist)
+      step = do_train_step(
+          input_trainer, increment_global_step_op, train=pos_dist)
       main_step = step - num_steps_pretrain
     sv.stop()
 
2 changes: 1 addition & 1 deletion correct_batch_effects_wdn/forgetting_nuisance_test.py
@@ -355,7 +355,7 @@ def testWassersteinNetwork(self):
         [network._features, network._x_vals], feed_dict=feed_dict)
 
     ## make sure each input used in every batch came from the actual inputs
-    for key, vals in x_vals.iteritems():
+    for key, vals in x_vals.items():
       for row in vals:  # iterate over every element
         ## identify distance from closest element in inputs
         differences = [la.norm(np.array(row) - np.array(candidates))
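dict.iteritems() does not exist in Python 3; items() returns a lazily iterated view and is the drop-in replacement. A toy example with hypothetical data:

```python
x_vals = {"key_a": [[1.0, 2.0]], "key_b": [[3.0, 4.0]]}
for key, vals in x_vals.items():  # Python 3: items() returns a view
  for row in vals:
    print(key, row)
```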
5 changes: 3 additions & 2 deletions correct_batch_effects_wdn/io_utils.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# Lint as: python2, python3
 """IO related utilities."""
 
 from __future__ import absolute_import
@@ -45,7 +46,7 @@ def write_dataframe_to_hdf5(df, path, complib='zlib', complevel=5, key='data'):
     store[key] = df
     # pylint: disable=protected-access
     buf = store._handle.get_file_image()
-  with gfile.GFile(path, 'w') as f:
+  with gfile.GFile(path, 'wb') as f:
     f.write(buf)
 
 
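The mode change matters because Python 3 separates bytes from str: PyTables' get_file_image() returns the HDF5 file image as bytes, and writing bytes to a text-mode handle raises TypeError. A minimal sketch with the built-in open (gfile.GFile follows the same mode semantics):

```python
buf = b"\x89HDF\r\n"  # hypothetical bytes standing in for an HDF5 file image
with open("/tmp/example.h5", "wb") as f:  # mode "w" would raise TypeError
  f.write(buf)
```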
@@ -59,7 +60,7 @@ def read_dataframe_from_hdf5(path, key='data'):
   Returns:
     pandas.DataFrame loaded from the HDF5 file.
   """
-  with gfile.GFile(path, 'r') as f:
+  with gfile.GFile(path, 'rb') as f:
     with pandas.HDFStore(
         'in_memory',
         mode='r',
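Reading mirrors the write side: the file image must come back as bytes so PyTables can open it in memory. A hedged sketch of the full pattern this truncated hunk sets up; the H5FD_CORE driver keyword arguments are assumptions based on PyTables' open_file API, forwarded through pandas.HDFStore:

```python
import pandas

def read_hdf5_image(path, key="data"):
  """Sketch: load a DataFrame from an HDF5 file image held in memory."""
  with open(path, "rb") as f:  # binary mode: the image is bytes
    with pandas.HDFStore(
        "in_memory",
        mode="r",
        driver="H5FD_CORE",           # PyTables in-memory driver (assumed)
        driver_core_backing_store=0,  # do not write changes back to disk
        driver_core_image=f.read()) as store:
      return store[key]
```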
1 change: 1 addition & 0 deletions correct_batch_effects_wdn/ljosa_preprocessing.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# Lint as: python2, python3
 """Pre-processing step for Ljosa embeddings.
 
 The goal is to replicate the work done in the Ljosa paper
1 change: 1 addition & 0 deletions correct_batch_effects_wdn/transform.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# Lint as: python2, python3
 """Functions to transform raw embeddings into more meaningful representations.
 """
 
