Commit

Updates to fit Domainbed repo
shahtalebi committed Jun 16, 2021
1 parent 1118fb2 commit a8c0a5e
Showing 11 changed files with 72 additions and 106 deletions.
8 changes: 8 additions & 0 deletions .idea/.gitignore

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions .idea/DomainBed.iml

6 changes: 6 additions & 0 deletions .idea/inspectionProfiles/profiles_settings.xml

4 changes: 4 additions & 0 deletions .idea/misc.xml

8 changes: 8 additions & 0 deletions .idea/modules.xml

6 changes: 6 additions & 0 deletions .idea/vcs.xml

50 changes: 27 additions & 23 deletions README.md
@@ -1,16 +1,12 @@
# SAND-mask Repository
# Welcome to DomainBed

This repo is the code release for "[SAND-mask: An Enhanced Gradient Masking Strategy for the Discovery of Invariances in Domain Generalization](https://arxiv.org/abs/2106.02266)".
DomainBed is a PyTorch suite containing benchmark datasets and algorithms for domain generalization, as introduced in [In Search of Lost Domain Generalization](https://arxiv.org/abs/2007.01434).

# Forked from DomainBed
## Current results

This project is mainly developed on top of the DomainBed repository, which is a PyTorch suite containing benchmark datasets and algorithms for domain generalization, as introduced in [In Search of Lost Domain Generalization](https://arxiv.org/abs/2007.01434).
![Result table](domainbed/results/2020_10_06_7df6f06/results.png)

## Published results

![Agnostic table](domainbed/results/Agnostic_Results.png)
![Oracle table](domainbed/results/Oracle_Results.png)
![Spirals table](domainbed/results/Spirals_Results.png)
Full results for [commit 7df6f06](https://github.com/facebookresearch/DomainBed/tree/7df6f06a6f9062284812a3f174c306218932c5e4) in LaTeX format available [here](domainbed/results/2020_10_06_7df6f06/results.tex).

## Available algorithms

@@ -28,10 +24,14 @@ The [currently available algorithms](domainbed/algorithms.py) are:
* Conditional Domain Adversarial Neural Network (CDANN, [Li et al., 2018](https://openaccess.thecvf.com/content_ECCV_2018/papers/Ya_Li_Deep_Domain_Generalization_ECCV_2018_paper.pdf))
* Style Agnostic Networks (SagNet, [Nam et al., 2020](https://arxiv.org/abs/1910.11645))
* Adaptive Risk Minimization (ARM, [Zhang et al., 2020](https://arxiv.org/abs/2007.02931)), contributed by [@zhangmarvin](https://github.com/zhangmarvin)
* Variance Risk Extrapolation (VREx, [Krueger et al., 2020](https://arxiv.org/abs/2003.00688)), contributed by [@zdhNarsil](https://github.com/zdhNarsil)
* Representation Self-Challenging (RSC, [Huang et al., 2020](https://arxiv.org/abs/2007.02454)), contributed by [@SirRob1997](https://github.com/SirRob1997)
----
* Spectral Decoupling (SD, [Pezeshki et al., 2020](https://arxiv.org/abs/2011.09468))
* Learning Explanations that are Hard to Vary (AND-Mask, [Parascandolo et al., 2020](https://arxiv.org/abs/2009.00329))
* SAND-mask: An Enhanced Gradient Masking Strategy for the Discovery of Invariances in Domain Generalization (SAND-Mask)
* Out-of-Distribution Generalization with Maximal Invariant Predictor (IGA, [Koyama et al., 2020](https://arxiv.org/abs/2008.01883))
* Gradient Matching for Domain Generalization (Fish, [Shi et al., 2021](https://arxiv.org/pdf/2104.09937.pdf))
* Self-supervised Contrastive Regularization (SelfReg, [Kim et al., 2021](https://arxiv.org/abs/2104.09841))
* Smoothed-AND mask (SAND-mask, [Shahtalebi et al., 2021](https://arxiv.org/abs/2106.02266))

Send us a PR to add your algorithm! Our implementations use ResNet50 / ResNet18 networks ([He et al., 2015](https://arxiv.org/abs/1512.03385)) and the hyper-parameter grids [described here](domainbed/hparams_registry.py).

@@ -45,7 +45,10 @@ The [currently available datasets](domainbed/datasets.py) are:
* PACS ([Li et al., 2017](https://arxiv.org/abs/1710.03077))
* Office-Home ([Venkateswara et al., 2017](https://arxiv.org/abs/1706.07522))
* A TerraIncognita ([Beery et al., 2018](https://arxiv.org/abs/1807.04975)) subset
* Spirals ([Parascandolo et al., 2020](https://arxiv.org/abs/2009.00329))
* DomainNet ([Peng et al., 2019](http://ai.bu.edu/M3SDA/))
* A SVIRO ([Dias Da Cruz et al., 2020](https://arxiv.org/abs/2001.03483)) subset
* WILDS ([Koh et al., 2020](https://arxiv.org/abs/2012.07421)) FMoW ([Christie et al., 2018](https://arxiv.org/abs/1711.07846)) about satellite images
* WILDS ([Koh et al., 2020](https://arxiv.org/abs/2012.07421)) Camelyon17 ([Bandi et al., 2019](https://pubmed.ncbi.nlm.nih.gov/30716025/)) about tumor detection in tissues

Send us a PR to add your dataset! Any custom image dataset with folder structure `dataset/domain/class/image.xyz` is readily usable. While we include some datasets from the [WILDS project](https://wilds.stanford.edu/), please use their [official code](https://github.com/p-lambda/wilds/) if you wish to participate in their leaderboard.
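
As a rough illustration of the `dataset/domain/class/image.xyz` layout mentioned above, the following sketch walks such a directory and groups image paths by domain and class. It is not DomainBed's own loading code; the helper name `index_domain_folders` is hypothetical and uses only the Python standard library.

```python
from collections import defaultdict
from pathlib import Path


def index_domain_folders(root):
    """Return {domain: {class_name: [image paths]}} for a
    dataset/domain/class/image.xyz layout (hypothetical helper)."""
    index = defaultdict(dict)
    # First level below the root: one directory per domain.
    for domain_dir in sorted(p for p in Path(root).iterdir() if p.is_dir()):
        # Second level: one directory per class.
        for class_dir in sorted(p for p in domain_dir.iterdir() if p.is_dir()):
            # Third level: the image files themselves.
            images = sorted(f for f in class_dir.iterdir() if f.is_file())
            index[domain_dir.name][class_dir.name] = images
    return dict(index)
```

Calling `index_domain_folders("/my/datasets/path/MyDataset")` (the path is a placeholder) gives a quick way to check that every domain exposes the same set of class folders before pointing a training run at the directory.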

@@ -62,17 +65,18 @@ Send us a PR to add your dataset! Any custom image dataset with folder structure
Download the datasets:

```sh
python -m domainbed.scripts.download \
--data_dir=/my/datasets/path
python3 -m domainbed.scripts.download \
--data_dir=./domainbed/data
```

Train a model:

```sh
python -m domainbed.scripts.train\
--algorithm SANDMask\
--dataset Spirals\
--test_env 0
python3 -m domainbed.scripts.train\
--data_dir=./domainbed/data/MNIST/\
--algorithm IGA\
--dataset ColoredMNIST\
--test_env 2
```

Launch a sweep:
@@ -91,10 +95,10 @@ python -m domainbed.scripts.sweep launch\
--data_dir=/my/datasets/path\
--output_dir=/my/sweep/output/path\
--command_launcher MyLauncher\
--algorithms SANDMask\
--datasets Spirals\
--n_hparams 20\
--n_trials 3
--algorithms ERM DANN\
--datasets RotatedMNIST VLCS\
--n_hparams 5\
--n_trials 1
```

After all jobs have either succeeded or failed, you can delete the data from failed jobs with ``python -m domainbed.scripts.sweep delete_incomplete`` and then re-launch them by running ``python -m domainbed.scripts.sweep launch`` again. Specify the same command-line arguments in all calls to `sweep` as you did the first time; this is how the sweep script knows which jobs were launched originally.
@@ -122,4 +126,4 @@ DATA_DIR=/my/datasets/path python -m unittest discover

## License

This source code is released under the MIT license, included [here](LICENSE).
This source code is released under the MIT license, included [here](LICENSE).
84 changes: 1 addition & 83 deletions domainbed/datasets.py
@@ -31,8 +31,7 @@
    "SVIRO",
    # WILDS datasets
    "WILDSCamelyon",
    "WILDSFMoW",
    "Spirals"
    "WILDSFMoW"
]

def get_dataset_class(dataset_name):
@@ -357,84 +356,3 @@ def __init__(self, root, test_envs, hparams):
        super().__init__(
            dataset, "region", test_envs, hparams['data_augmentation'], hparams)


class Spirals(MultipleDomainDataset):
    CHECKPOINT_FREQ = 10
    ENVIRONMENTS = [str(i) for i in range(16)]

    def __init__(self, root, test_env, hparams):
        super().__init__()
        self.datasets = []

        test_dataset = self.make_tensor_dataset(env='test')
        self.datasets.append(test_dataset)
        for env in self.ENVIRONMENTS:
            env_dataset = self.make_tensor_dataset(env=env, seed=int(env))
            self.datasets.append(env_dataset)

        self.input_shape = (18,)
        self.num_classes = 2

    def make_tensor_dataset(self, env, n_examples=1024, n_envs=16, n_revolutions=3, n_dims=16,
                            flip_first_signature=False,
                            seed=0):

        if env == 'test':
            inputs, labels = self.generate_environment(2000,
                                                       n_rotations=n_revolutions,
                                                       env=env,
                                                       n_envs=n_envs,
                                                       n_dims_signatures=n_dims,
                                                       seed=2 ** 32 - 1
                                                       )
        else:
            inputs, labels = self.generate_environment(n_examples,
                                                       n_rotations=n_revolutions,
                                                       env=env,
                                                       n_envs=n_envs,
                                                       n_dims_signatures=n_dims,
                                                       seed=seed
                                                       )
        if flip_first_signature:
            inputs[:1, 2:] = -inputs[:1, 2:]

        return TensorDataset(torch.tensor(inputs), torch.tensor(labels))

    def generate_environment(self, n_examples, n_rotations, env, n_envs,
                             n_dims_signatures,
                             seed=None):
        """
        env must either be "test" or an int between 0 and n_envs-1
        n_dims_signatures: how many dimensions for the signatures (spirals are always 2)
        seed: seed for numpy
        """
        assert env == 'test' or 0 <= int(env) < n_envs

        # Generate fixed dictionary of signatures
        rng = np.random.RandomState(seed)

        signatures_matrix = rng.randn(n_envs, n_dims_signatures)

        radii = rng.uniform(0.08, 1, n_examples)
        angles = 2 * n_rotations * np.pi * radii

        labels = rng.randint(0, 2, n_examples)
        angles = angles + np.pi * labels

        radii += rng.uniform(-0.02, 0.02, n_examples)
        xs = np.cos(angles) * radii
        ys = np.sin(angles) * radii

        if env == 'test':
            signatures = rng.randn(n_examples, n_dims_signatures)
        else:
            env = int(env)
            signatures_labels = np.array(labels * 2 - 1).reshape(1, -1)
            signatures = signatures_matrix[env] * signatures_labels.T

        signatures = np.stack(signatures)
        mechanisms = np.stack((xs, ys), axis=1)
        mechanisms /= mechanisms.std(axis=0)  # make approx unit variance (signatures already are)
        inputs = np.hstack((mechanisms, signatures))

        return inputs.astype(np.float32), labels.astype(np.long)
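
For readers who want to experiment with the removed Spirals data outside of DomainBed, here is a minimal standalone sketch of the same generation steps using NumPy only. The function name `make_spirals` is hypothetical, and the sketch casts labels with `np.int64` on the assumption that the `np.long` alias used above is unavailable in recent NumPy releases; treat it as an approximation of the removed code, not a drop-in replacement.

```python
import numpy as np


def make_spirals(n_examples, env, n_envs=16, n_rotations=3,
                 n_dims_signatures=16, seed=0):
    """Two-class spiral inputs plus per-environment signature features.

    env is either "test" or an int in [0, n_envs). Hypothetical helper
    that mirrors the steps of the removed generate_environment method.
    """
    assert env == "test" or 0 <= int(env) < n_envs
    rng = np.random.RandomState(seed)

    # One random signature vector per training environment.
    signatures_matrix = rng.randn(n_envs, n_dims_signatures)

    # Points along a spiral: the angle grows linearly with the radius.
    radii = rng.uniform(0.08, 1, n_examples)
    angles = 2 * n_rotations * np.pi * radii

    # The second class is the same spiral rotated by pi.
    labels = rng.randint(0, 2, n_examples)
    angles = angles + np.pi * labels

    # Small radial noise, then convert to Cartesian coordinates.
    radii = radii + rng.uniform(-0.02, 0.02, n_examples)
    mechanisms = np.stack((np.cos(angles) * radii, np.sin(angles) * radii), axis=1)
    mechanisms /= mechanisms.std(axis=0)  # approximately unit variance

    if env == "test":
        # At test time the signature carries no label information.
        signatures = rng.randn(n_examples, n_dims_signatures)
    else:
        # In a training environment the signature sign encodes the label.
        signs = (labels * 2 - 1).reshape(-1, 1)
        signatures = signatures_matrix[int(env)] * signs

    inputs = np.hstack((mechanisms, signatures))
    return inputs.astype(np.float32), labels.astype(np.int64)
```

With the defaults, `make_spirals(1024, env=0)` returns inputs of shape `(1024, 18)`: two spiral coordinates followed by 16 signature dimensions, which matches the `input_shape = (18,)` declared by the removed class.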
Binary file removed domainbed/results/Agnostic_Results.png
Binary file not shown.
Binary file removed domainbed/results/Oracle_Results.png
Binary file not shown.
Binary file removed domainbed/results/Spirals_Results.png
Binary file not shown.