KD3A: Unsupervised Multi-Source Decentralized Domain Adaptation via Knowledge Distillation (Accepted at ICML 2021)
This is the official implementation of KD3A, the model proposed in the paper KD3A: Unsupervised Multi-Source Decentralized Domain Adaptation via Knowledge Distillation.
Python Environment: >= 3.6
torch >= 1.2.0
torchvision >= 0.4.0
tensorboard >= 2.0.0
numpy
yaml
Users need to declare a base path that stores the datasets and the training logs. The directory structure should be:
base_path
│
└───dataset
│   │   DigitFive
│       │   mnist_data.mat
│       │   mnistm_with_label.mat
│       │   svhn_train_32x32.mat
│       │   ...
│   │   DomainNet
│       │   ...
│   │   OfficeCaltech10
│       │   ...
│   │   Office31
│       │   ...
│   │   AmazonReview
│       │   ...
└───trained_model_1
│   │   parameter
│   │   runs
└───trained_model_2
│   │   parameter
│   │   runs
...
└───trained_model_n
│   │   parameter
│   │   runs
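As a convenience, the dataset part of this layout can be created up front. The following is a minimal sketch, not part of the repository: `make_base_path` is a hypothetical helper, and the folder names are copied from the tree above. It leaves out the `trained_model_*` folders, on the assumption that the training run writes them.

```python
from pathlib import Path

# Dataset folder names copied from the directory tree in this README.
DATASETS = ["DigitFive", "DomainNet", "OfficeCaltech10", "Office31", "AmazonReview"]


def make_base_path(base_path):
    """Create base_path/dataset/<name> for each supported dataset.

    Hypothetical helper: it only creates empty folders; the dataset
    files still have to be downloaded and placed inside them.
    """
    root = Path(base_path)
    created = []
    for name in DATASETS:
        folder = root / "dataset" / name
        # exist_ok=True makes the call safe to repeat.
        folder.mkdir(parents=True, exist_ok=True)
        created.append(folder)
    return created
```

Calling `make_base_path("base_path")` and then copying the downloaded files into the matching folders gives the structure shown above.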
Our framework currently supports five multi-source domain adaptation datasets: DigitFive, DomainNet, AmazonReview, OfficeCaltech10, and Office31.
- DigitFive
  The DigitFive dataset can be accessed on Google Drive.
- DomainNet
  VisDA2019 provides the DomainNet dataset.
- AmazonReview
  The AmazonReview dataset can be accessed on Google Drive.
The configuration files are in the ./config folder; we provide four config files in .yaml format. To perform unsupervised multi-source decentralized domain adaptation on a specific dataset (e.g., DomainNet), run the following command:
python main.py --config DomainNet.yaml --target-domain clipart -bp base_path
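To adapt to each DomainNet domain in turn, the same command can be repeated with a different `--target-domain`. The sketch below is an illustration only: `build_commands` and `run_all` are hypothetical helpers, and the domain list is assembled from the names that appear in this README. Only the flags shown in the command above are used.

```python
import subprocess

# Domain names as they appear in this README.
DOMAINS = ["clipart", "infograph", "painting", "quickdraw", "real", "sketch"]


def build_commands(base_path, config="DomainNet.yaml", domains=DOMAINS):
    """Return one argument list per target domain, mirroring the README command."""
    return [
        ["python", "main.py", "--config", config,
         "--target-domain", domain, "-bp", base_path]
        for domain in domains
    ]


def run_all(base_path):
    """Run the training command once per target domain, stopping on failure."""
    for cmd in build_commands(base_path):
        subprocess.run(cmd, check=True)
```

`run_all("base_path")` must be called from the repository root so that `main.py` is found; the runs are sequential.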
For DomainNet, the model records the domain weights and the target-domain accuracy during training, in the following form:
Source Domains :['infograph', 'painting', 'quickdraw', 'real', 'sketch']
Domain Weight : [0.1044, 0.3263, 0.0068, 0.2531, 0.2832]
Target Domain clipart Accuracy Top1 : 0.726 Top5: 0.902
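The recorded lines are easy to post-process, for instance to see which source domain received the lowest weight. The following is a small sketch that assumes the log lines have exactly the format shown above; `parse_domain_weights` is a hypothetical helper, not a function from the repository.

```python
import ast

# Sample log text copied from the output shown in this README.
LOG = """Source Domains :['infograph', 'painting', 'quickdraw', 'real', 'sketch']
Domain Weight : [0.1044, 0.3263, 0.0068, 0.2531, 0.2832]
Target Domain clipart Accuracy Top1 : 0.726 Top5: 0.902"""


def parse_domain_weights(log_text):
    """Pair each source domain with its weight (assumes the format above)."""
    domains, weights = None, None
    for line in log_text.splitlines():
        key, sep, value = line.partition(":")
        if not sep:
            continue
        key = key.strip()
        if key == "Source Domains":
            # The value is a Python-style list literal of strings.
            domains = ast.literal_eval(value.strip())
        elif key == "Domain Weight":
            # The value is a Python-style list literal of floats.
            weights = ast.literal_eval(value.strip())
    if domains is None or weights is None:
        raise ValueError("log does not contain both domains and weights")
    if len(domains) != len(weights):
        raise ValueError("number of domains and weights differ")
    return dict(zip(domains, weights))


weights = parse_domain_weights(LOG)
lowest = min(weights, key=weights.get)  # 'quickdraw' for the sample above
```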
- Irrelevant Domains
  We treat quickdraw as the irrelevant domain, and KD3A assigns it a low weight during training.
- Malicious Domains
  We use a poisoning attack at a configurable level to create malicious domains. The related settings in the configuration file are as follows:
  UMDAConfig:
      malicious:
          attack_domain: "real"
          attack_level: 0.3
  With this setting, a poisoning attack is performed on the source domain real, with mislabeled samples at the level set by attack_level (0.3 here).
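This README does not spell out how the mislabeled samples are produced. As a rough illustration of label poisoning in general, the sketch below replaces a fraction `attack_level` of the labels with a different, randomly chosen class. It is a generic label-flipping example under that assumption, not the repository's implementation, and `poison_labels` is a hypothetical name.

```python
import random


def poison_labels(labels, attack_level, num_classes, seed=0):
    """Return a copy of labels with a fraction attack_level mislabeled.

    Generic label-flipping sketch (an assumption, not the KD3A code).
    Labels are expected to be integers in [0, num_classes).
    """
    if not 0.0 <= attack_level <= 1.0:
        raise ValueError("attack_level must be in [0, 1]")
    if num_classes < 2:
        raise ValueError("need at least two classes to mislabel")
    rng = random.Random(seed)
    poisoned = list(labels)
    n_flip = int(round(attack_level * len(poisoned)))
    # Pick distinct positions to corrupt.
    for i in rng.sample(range(len(poisoned)), n_flip):
        # Draw from the other num_classes - 1 classes, skipping the true one.
        wrong = rng.randrange(num_classes - 1)
        if wrong >= poisoned[i]:
            wrong += 1
        poisoned[i] = wrong
    return poisoned
```

With `attack_level=0.3`, 30 out of 100 labels end up different from the originals; the remaining labels are untouched.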
We also provide a setting in the .yaml config files to perform model aggregation with a given number of communication rounds:
UMDAConfig:
    communication_rounds: 1
The communication rounds can be set into
If you find this useful in your work, please consider citing:
@InProceedings{pmlr-v139-feng21f,
title = {KD3A: Unsupervised Multi-Source Decentralized Domain Adaptation via Knowledge Distillation},
author = {Feng, Haozhe and You, Zhaoyang and Chen, Minghao and Zhang, Tianye and Zhu, Minfeng and Wu, Fei and Wu, Chao and Chen, Wei},
booktitle = {Proceedings of the 38th International Conference on Machine Learning},
pages = {3274--3283},
year = {2021},
editor = {Meila, Marina and Zhang, Tong},
volume = {139},
series = {Proceedings of Machine Learning Research},
month = {18--24 Jul},
publisher = {PMLR}
}