Commit f5ba49c (parent 3656cf8): Release code for submission "Correcting for
Batch Effects Using Wasserstein Distance". PiperOrigin-RevId: 238348173

Showing 16 changed files with 4,543 additions and 0 deletions.

---

**`correct_batch_effects_wdn/README.md`**

# Correcting for Batch Effects Using Wasserstein Distance

This directory contains reference code for the paper
[Correcting for Batch Effects Using Wasserstein Distance](https://arxiv.org/abs/1711.00882).

The code is implemented in TensorFlow; the required packages are listed in
`requirements.txt`.
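
The dependencies can be installed with pip in the usual way:

```
pip install -r requirements.txt
```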

## Datasets
The datasets are two different types of embeddings derived from the raw image
dataset at https://data.broadinstitute.org/bbbc/BBBC021/: CellProfiler
embeddings and deep neural network embeddings.

### CellProfiler Embeddings
The original CellProfiler embeddings were downloaded from
http://pubs.broadinstitute.org/ljosa_jbiomolscreen_2013/ as csv files.

To convert them into a DataFrame and save it as an h5 file:

```
python -m correct_batch_effects_wdn.ljosa_embeddings_to_h5 \
  --ljosa_data_directory=${LJOSA_DATA_DIRECTORY}
```

The h5 file will be saved at

`${LJOSA_DATA_DIRECTORY}/ljosa_embeddings_462.h5`
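
As a quick sanity check, the saved embeddings can be loaded back into a pandas
DataFrame (a minimal sketch; if `read_hdf` cannot infer the key, inspect the
file to find the key used when saving):

```
import pandas as pd

# Load the saved CellProfiler embeddings; read_hdf infers the key when the
# file contains a single pandas object.
df = pd.read_hdf("ljosa_embeddings_462.h5")
print(df.shape)  # rows are samples; the file name suggests 462 dimensions
```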

We follow the paper https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3884769/ to
preprocess the CellProfiler embeddings:

```
python -m correct_batch_effects_wdn.ljosa_preprocessing \
  --original_df=${LJOSA_DATA_DIRECTORY}/ljosa_embeddings_462.h5 \
  --post_normalization_path=${LJOSA_DATA_DIRECTORY}/ljosa_embeddings_post_normalized.h5 \
  --post_fa_path=${LJOSA_DATA_DIRECTORY}/ljosa_embeddings_post_fa.h5
```

This will generate two h5 files. The first file,

`${LJOSA_DATA_DIRECTORY}/ljosa_embeddings_post_normalized.h5`,

contains embeddings in which each dimension has been normalized by percentile
matching. The second file,

`${LJOSA_DATA_DIRECTORY}/ljosa_embeddings_post_fa.h5`,

contains the post-normalized embeddings projected down to dimension 50 by
factor analysis.
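
Conceptually, the factor-analysis step resembles the following scikit-learn
sketch (an illustration under assumed defaults, not the exact
`ljosa_preprocessing` implementation; the HDF key below is arbitrary):

```
import pandas as pd
from sklearn.decomposition import FactorAnalysis

# Illustrative only: project post-normalized embeddings to 50 dimensions.
df = pd.read_hdf("ljosa_embeddings_post_normalized.h5")
fa = FactorAnalysis(n_components=50, random_state=0)
reduced = pd.DataFrame(fa.fit_transform(df.values), index=df.index)
reduced.to_hdf("ljosa_embeddings_post_fa.h5", key="embeddings")
```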

### Deep Neural Network Embeddings
Deep neural network embeddings are obtained by running a pipeline on the raw
image dataset: the raw images are corrected for imaging artifacts, cell patches
are extracted by cell-center finding, and a pre-trained deep neural network is
applied to the patch images to obtain embeddings. Each embedding has dimension
192, with 64 dimensions for each of the three stains. More details can be
found in the paper https://ai.google/research/pubs/pub46293. For proprietary
reasons, the pipeline code and the generated embeddings cannot be open sourced
here. Readers interested in testing the code can instead use feature vectors
generated from [inception_v3 on TensorFlow Hub](https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1).
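
For example, feature vectors could be extracted with the TF1-style Hub API
roughly as follows (a sketch; `image_batch` stands in for your own
preprocessed cell patches, resized to the module's expected 299x299 input):

```
import tensorflow as tf
import tensorflow_hub as hub

# Sketch: inception_v3 feature vectors as a stand-in for the proprietary
# embeddings. The module expects [batch, 299, 299, 3] images in [0, 1] and
# returns [batch, 2048] feature vectors.
module = hub.Module(
    "https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1")
images = tf.placeholder(tf.float32, shape=[None, 299, 299, 3])
features = module(images)

with tf.Session() as sess:
  sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
  vectors = sess.run(features, feed_dict={images: image_batch})
```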

## Model Training
A Wasserstein distance network is trained to correct for batch effects; a
conceptual sketch of the adversarial setup follows.
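
The idea is adversarial: a critic network estimates the Wasserstein distance
between embedding distributions from different batches, and a transformation
of the embeddings is trained to shrink that distance. The TF1-style sketch
below illustrates the core objective only; it is not the repository's
`forgetting_nuisance` implementation, and all names and sizes are made up:

```
import tensorflow as tf

def critic(x, reuse=False):
  """Small critic network, in the spirit of --layer_width/--num_layers."""
  with tf.variable_scope("critic", reuse=reuse):
    h = tf.layers.dense(x, 2, activation=tf.nn.relu)
    return tf.layers.dense(h, 1)

dim = 50  # matches --feature_dim for the CellProfiler embeddings
x_a = tf.placeholder(tf.float32, [None, dim])  # embeddings from one batch
x_b = tf.placeholder(tf.float32, [None, dim])  # embeddings from another batch

with tf.variable_scope("transform"):
  t_b = tf.layers.dense(x_b, dim, use_bias=False)  # learned correction of x_b

# The critic ascends E[f(A)] - E[f(T(B))]; the transform descends it. A
# Lipschitz constraint on the critic (e.g. weight clipping) is omitted here.
w_dist = tf.reduce_mean(critic(x_a)) - tf.reduce_mean(critic(t_b, reuse=True))

critic_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "critic")
trans_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "transform")
critic_step = tf.train.RMSPropOptimizer(1e-4).minimize(-w_dist,
                                                       var_list=critic_vars)
trans_step = tf.train.RMSPropOptimizer(1e-4).minimize(w_dist,
                                                      var_list=trans_vars)
# In training, the critic takes several steps per transform step
# (cf. --disc_steps_per_training_step=50 in the commands below).
```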

### CellProfiler Embeddings

```
python -m correct_batch_effects_wdn.forgetting_nuisance \
  --network_type=WassersteinNetwork \
  --input_df="${LJOSA_DATA_DIRECTORY}/ljosa_embeddings_post_fa.h5" \
  --num_steps_pretrain=100000 \
  --num_steps=5000 \
  --save_dir="${SAVE_DIR}/ljosa_embeddings_post_fa" \
  --disc_steps_per_training_step=50 \
  --checkpoint_interval=2000 \
  --nuisance_levels=batch \
  --batch_n=100 \
  --target_levels=compound \
  --feature_dim=50 \
  --layer_width=2 \
  --num_layers=2 \
  --learning_rate=1e-4
```

### Deep Neural Network Embeddings

```
python -m correct_batch_effects_wdn.forgetting_nuisance \
  --network_type=WassersteinNetwork \
  --input_df="${LJOSA_DATA_DIRECTORY}/ljosa_deep_post_tvn.h5" \
  --num_steps_pretrain=100000 \
  --num_steps=5000 \
  --save_dir="${SAVE_DIR}/ljosa_deep_post_tvn" \
  --disc_steps_per_training_step=50 \
  --checkpoint_interval=2000 \
  --nuisance_levels=batch \
  --batch_n=100 \
  --target_levels=compound \
  --feature_dim=192 \
  --layer_width=2 \
  --num_layers=2 \
  --learning_rate=1e-4
```

## Model Evaluation
Model performance is evaluated by a number of metrics quantifying how much
biological signal is preserved in the embeddings and how much batch effect has
been removed after applying the learned transformation. Note that each
training run saves its outputs under a directory whose name encodes the run's
flag settings; the `DF_DIR` values below must match those directory names
exactly.

### CellProfiler Embeddings

```
DF_DIR="${SAVE_DIR}/ljosa_embeddings_post_fa/(('input_df', \
'ljosa_embeddings_post_fa.h5'), ('network_type', 'WassersteinNetwork'), \
('num_steps_pretrain', 100000), ('num_steps', 5000), ('batch_n', 100), \
('learning_rate', 0.0001), ('feature_dim', 50), \
('disc_steps_per_training_step', 50), ('target_levels', \
('compound',)), ('nuisance_levels', ('batch',)), ('layer_width', 2), \
('num_layers', 2), ('lambda_mean', 0.0), ('lambda_cov', 0.0), \
('cov_fix', 0.001))"
python -m correct_batch_effects_wdn.evaluate_metrics \
  --transformation_file="${DF_DIR}/data.pkl" \
  --input_df="${LJOSA_DATA_DIRECTORY}/ljosa_embeddings_post_fa.h5" \
  --output_file="${DF_DIR}/evals.pkl" \
  --num_bootstrap=200
```

### Deep Neural Network Embeddings

```
DF_DIR="${SAVE_DIR}/ljosa_deep_post_tvn/(('input_df', \
'ljosa_deep_post_tvn.h5'), ('network_type', 'WassersteinNetwork'), \
('num_steps_pretrain', 100000), ('num_steps', 5000), ('batch_n', 100), \
('learning_rate', 0.0001), ('feature_dim', 192), \
('disc_steps_per_training_step', 50), ('target_levels', ('compound',)), \
('nuisance_levels', ('batch',)), ('layer_width', 2), ('num_layers', 2), \
('lambda_mean', 0.0), ('lambda_cov', 0.0), ('cov_fix', 0.001))"
python -m correct_batch_effects_wdn.evaluate_metrics \
  --transformation_file="${DF_DIR}/data.pkl" \
  --input_df="${LJOSA_DATA_DIRECTORY}/ljosa_deep_post_tvn.h5" \
  --output_file="${DF_DIR}/evals.pkl" \
  --num_bootstrap=200
```

### Sample Code for Loading `evals.pkl`

```
import six.moves.cPickle as pickle
from tensorflow import gfile


def load_contents(file_path):
  """Load and unpickle the contents of a file."""
  # Open in binary mode: the pickled contents are bytes, not text.
  with gfile.GFile(file_path, mode="rb") as f:
    contents = f.read()
  return pickle.loads(contents)


evals = load_contents(path)  # path points at an evals.pkl produced above
```

---

**`correct_batch_effects_wdn/distance.py`**

```python
# coding=utf-8
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Compute various distance metrics for probability densities."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import pandas as pd
import sklearn.metrics.pairwise


def _combine(v1, v2):
  """Combine a vector and a vector or array into a single vector."""
  return np.concatenate((v1, v2.reshape(-1)))


def _split(v, col1, col2):
  """Split a vector into a vector + a vector or array.

  The first vector is 1-D with col1 columns. The second has col2 columns and
  is a 1-D vector if len(v) == col1 + col2 or an array otherwise.

  Args:
    v: vector to split
    col1: number of columns for the first portion
    col2: number of columns for the second portion

  Returns:
    A tuple consisting of the first split vector and the second.
  """
  v1 = v[:col1]
  v2 = v[col1:]
  if len(v2) == col2:
    return v1, v2
  return v1, v2.reshape([-1, col2])


def _wrapped_dist_fn(v1, v2, dist_fn=None, dfcol=None, auxcol=None):
  """Wrapper for a distance function that splits the inputs.

  This allows us to use distances that require auxiliary quantities with
  sklearn's pairwise_distances function.

  Args:
    v1: first input vector - will be split
    v2: second input vector - will be split
    dist_fn: distance function to call on split vectors
    dfcol: number of columns for the first split portion
    auxcol: number of columns for the second split portion

  Returns:
    The value of dist_fn called on the split versions of v1 and v2.
  """
  v11, v12 = _split(v1, dfcol, auxcol)
  v21, v22 = _split(v2, dfcol, auxcol)
  return dist_fn(v11, v21, v12, v22)


def matrix(dist_fn, df, aux_df=None, n_jobs=1, **kwds):
  """Compute a distance matrix between rows of a DataFrame.

  Args:
    dist_fn: A distance function. If aux_df = None, should take 2 Series
      as arguments; if aux_df is a data frame, should take 4 Series as
      arguments (row1, row2, aux1, aux2).
    df: DataFrame for which we want to compute row distances
    aux_df: optional auxiliary DataFrame whose rows provide additional
      distance function arguments
    n_jobs: number of parallel jobs to use in computing the distance matrix.
      Note that setting n_jobs > 1 does not work well in Colab.
    **kwds: additional keyword arguments are passed to sklearn's
      pairwise_distances function

  Returns:
    A matrix of distances.
  """
  dfrow, dfcol = df.shape
  if aux_df is not None:
    auxrow, auxcol = aux_df.shape

  # aux_df specifies either a set of vectors of variances or arrays of
  # covariances for use with the distance functions below. sklearn's
  # pairwise distance function doesn't allow for this kind of side info,
  # so we need to flatten the side information and append it to the vectors
  # in df, then we need to wrap the distance functions so the side info is
  # split out before computing distances.
  if aux_df is not None:
    combined = np.zeros([dfrow, dfcol + int(auxrow / dfrow) * auxcol])
    for i, (idx, row) in enumerate(df.iterrows()):
      combined[i, :] = _combine(row.as_matrix(), aux_df.loc[idx].as_matrix())
    kwds.update(dist_fn=dist_fn, dfcol=dfcol, auxcol=auxcol)
    dist = sklearn.metrics.pairwise.pairwise_distances(
        X=combined, metric=_wrapped_dist_fn, n_jobs=n_jobs, **kwds)
  else:
    dist = sklearn.metrics.pairwise.pairwise_distances(
        X=df.as_matrix(), metric=dist_fn, n_jobs=n_jobs, **kwds)

  return pd.DataFrame(dist, columns=df.index, index=df.index)
```
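
For illustration, `matrix` accepts either a metric name understood by
sklearn's `pairwise_distances` or a custom callable; with `aux_df`, the
callable receives per-row side information. The data and distance below are
hypothetical, and a pandas version from this code's era (one that still
supports `DataFrame.as_matrix`) is assumed:

```python
import numpy as np
import pandas as pd

from correct_batch_effects_wdn import distance

df = pd.DataFrame(np.random.randn(4, 3), index=list("abcd"))
d_euc = distance.matrix("euclidean", df)  # plain pairwise distances

# aux_df rows are flattened, appended to df's rows, and split back out, so
# the custom distance receives (row1, row2, aux1, aux2).
aux = pd.DataFrame(np.ones((4, 3)), index=df.index)

def scaled_euclidean(x1, x2, var1, var2):
  # Hypothetical variance-scaled Euclidean distance.
  w = 2.0 / (var1 + var2)
  return np.sqrt(np.sum(w * (x1 - x2) ** 2))

d_scaled = distance.matrix(scaled_euclidean, df, aux_df=aux)
```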

---

**`correct_batch_effects_wdn/distance_test.py`**

```python
# coding=utf-8
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for Distance library."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import pandas as pd
import pandas.util.testing as pandas_testing
import tensorflow as tf

from correct_batch_effects_wdn import distance


class DistanceTest(tf.test.TestCase):

  def setUp(self):
    super(DistanceTest, self).setUp()
    self.m = pd.DataFrame({
        'v1': [1.0, 3.0],
        'v2': [2.0, 2.0]}, index=['a', 'b'])
    self.v = pd.DataFrame({
        'v1': [1.0, 3.0],
        'v2': [2.0, 1.0]}, index=['a', 'b'])
    # cosine distance between m[0,:] and m[1,:]
    self.dcos = 1.0 - 7.0 / np.sqrt(5.0 * 13.0)
    # euclidean distance between m[0,:] and m[1,:]
    self.deuc = np.sqrt((-2.0)**2.0 + 0.0**2.0)

  def testCombine(self):
    expected = np.array([1.0, 2.0, 3.0, 4.0])
    result = distance._combine(np.array([1.0, 2.0]), np.array([3.0, 4.0]))
    np.testing.assert_almost_equal(expected, result)
    expected = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    result = distance._combine(np.array([1.0, 2.0]),
                               np.array([[3.0, 4.0],
                                         [5.0, 6.0]]))
    np.testing.assert_almost_equal(expected, result)

  def testSplit(self):
    expected = (np.array([1.0, 2.0]), np.array([3.0, 4.0]))
    result = distance._split(np.array([1.0, 2.0, 3.0, 4.0]), 2, 2)
    np.testing.assert_almost_equal(expected[0], result[0])
    np.testing.assert_almost_equal(expected[1], result[1])
    expected = (np.array([1.0, 2.0]), np.array([[3.0, 4.0],
                                                [5.0, 6.0]]))
    result = distance._split(
        np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]), 2, 2)
    np.testing.assert_almost_equal(expected[0], result[0])
    np.testing.assert_almost_equal(expected[1], result[1])

  def testMatrix(self):
    expected = pd.DataFrame(
        {'a': [0.0, self.dcos],
         'b': [self.dcos, 0.0]}, index=['a', 'b'])
    pandas_testing.assert_frame_equal(
        expected,
        distance.matrix('cosine', self.m))

    expected = pd.DataFrame(
        {'a': [0.0, self.deuc],
         'b': [self.deuc, 0.0]}, index=['a', 'b'])
    pandas_testing.assert_frame_equal(
        expected,
        distance.matrix('euclidean', self.m))

    euc = lambda v1, v2: np.sqrt((v2 - v1).dot(v2 - v1))
    pandas_testing.assert_frame_equal(
        expected,
        distance.matrix(euc, self.m))


if __name__ == '__main__':
  tf.test.main()
```
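
Assuming the package is importable, the tests can be run in the standard way:

```
python -m correct_batch_effects_wdn.distance_test
```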