PyTorch implementation of DANN (Domain-Adversarial Training of Neural Networks)
Unsupervised Domain Adaptation by Backpropagation
Yaroslav Ganin, Victor Lempitsky
In PMLR-2015
Domain-Adversarial Training of Neural Networks
Yaroslav Ganin et al.
In JMLR-2016
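As described in the papers above, the core of DANN is a gradient reversal layer. In the forward pass it acts as the identity. In the backward pass it multiplies the gradient by -λ. As a result, the label classifier is trained normally while the shared feature extractor is pushed to confuse the domain classifier. The following minimal PyTorch sketch illustrates the idea. The names `GradientReversal`, `grad_reverse` and `DANN`, and all layer sizes, are hypothetical and not taken from this repository.

```python
import torch
from torch import nn
from torch.autograd import Function


class GradientReversal(Function):
    """Identity in the forward pass; scales the gradient by -lambd in the backward pass."""

    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse (and scale) the gradient flowing back into the feature extractor.
        # lambd is a plain float and needs no gradient, so return None for it.
        return grad_output.neg() * ctx.lambd, None


def grad_reverse(x, lambd=1.0):
    return GradientReversal.apply(x, lambd)


class DANN(nn.Module):
    """Hypothetical toy DANN on flattened 28x28 inputs; sizes are illustrative only."""

    def __init__(self, in_dim=784, feat_dim=64, n_classes=10):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(in_dim, feat_dim), nn.ReLU())
        self.label_head = nn.Linear(feat_dim, n_classes)  # digit classifier
        self.domain_head = nn.Linear(feat_dim, 2)         # source vs. target

    def forward(self, x, lambd=1.0):
        f = self.features(x)
        class_logits = self.label_head(f)
        # Only the domain branch sees reversed gradients.
        domain_logits = self.domain_head(grad_reverse(f, lambd))
        return class_logits, domain_logits


if __name__ == "__main__":
    x = torch.ones(3, requires_grad=True)
    grad_reverse(x, 0.5).sum().backward()
    print(x.grad)  # each entry is -0.5: the gradient of sum() is 1, reversed and scaled
```

With the reversal layer in place, a DANN step can simply minimize `label_loss + domain_loss`. The sign flip for the feature extractor happens inside the backward pass, so no separate adversarial optimizer is needed.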
Install library versions that are compatible with your environment.
git clone https://github.com/NaJaeMin92/pytorch-DANN.git
cd pytorch-DANN
conda create -n dann python=3.7
conda activate dann
pip install -r requirements.txt
python=3.7
pytorch=1.12.1
matplotlib=3.2.2
sklearn=1.0.2
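To compare an existing environment against the versions listed above, a small check like the following may help. It only reports differences, because other compatible versions may also work. The `EXPECTED` mapping simply copies the list above, and the helper name `installed_versions` is hypothetical.

```python
import sys

import matplotlib
import sklearn
import torch

# Versions listed in this README (assumed to be the tested setup).
EXPECTED = {"python": "3.7", "torch": "1.12.1", "matplotlib": "3.2.2", "sklearn": "1.0.2"}


def installed_versions():
    return {
        "python": f"{sys.version_info.major}.{sys.version_info.minor}",
        "torch": str(torch.__version__).split("+")[0],  # drop local suffixes such as "+cu113"
        "matplotlib": matplotlib.__version__,
        "sklearn": sklearn.__version__,
    }


if __name__ == "__main__":
    for name, found in installed_versions().items():
        status = "ok" if found == EXPECTED[name] else "differs"
        print(f"{name}: expected {EXPECTED[name]}, found {found} ({status})")
```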
Running the command below will execute both source-only and DANN training and testing:
python main.py
# You can adjust training settings in 'params.py', including batch size and the number of training epochs.
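Changing the number of epochs interacts with how DANN weights the domain loss. The DANN paper (Ganin et al., JMLR 2016) does not use a fixed weight. Instead, it ramps the adaptation factor λ from 0 towards 1 as training progress p goes from 0 to 1, using λ_p = 2 / (1 + exp(-γ·p)) - 1 with γ = 10. It also anneals the learning rate as μ_p = μ0 / (1 + α·p)^β with μ0 = 0.01, α = 10 and β = 0.75. This README does not say whether `params.py` exposes these schedules. The sketch below follows the paper, and the helper names are hypothetical.

```python
import math


def adaptation_factor(p, gamma=10.0):
    """Paper schedule: lambda_p = 2 / (1 + exp(-gamma * p)) - 1, for progress p in [0, 1]."""
    return 2.0 / (1.0 + math.exp(-gamma * p)) - 1.0


def learning_rate(p, mu0=0.01, alpha=10.0, beta=0.75):
    """Paper annealing: mu_p = mu0 / (1 + alpha * p) ** beta."""
    return mu0 / (1.0 + alpha * p) ** beta


def progress(epoch, batch_idx, n_epochs, n_batches):
    """Fraction of training completed; epoch and batch_idx are 0-based."""
    return (epoch * n_batches + batch_idx) / (n_epochs * n_batches)


if __name__ == "__main__":
    for p in (0.0, 0.25, 0.5, 1.0):
        print(f"p={p:.2f}  lambda={adaptation_factor(p):.4f}  lr={learning_rate(p):.5f}")
```

Starting with λ = 0 lets the label classifier learn useful features before the domain classifier's reversed gradient starts to dominate.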
Our code can also visualize t-SNE embeddings of the learned features, both before and after domain adaptation, using sklearn.manifold.
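The following minimal sketch shows a t-SNE plot of this kind using `sklearn.manifold.TSNE`. It embeds source and target feature vectors jointly and colours the points by domain. Well-aligned domains should overlap after adaptation. The function name `plot_domain_tsne` and its arguments are hypothetical. The repository's own plotting code may differ, for example by also colouring points by digit class.

```python
import matplotlib
matplotlib.use("Agg")  # render to a file without needing a display
import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE


def plot_domain_tsne(source_feats, target_feats, path="tsne.png", seed=0):
    """Embed source and target features together in 2-D and save a scatter plot."""
    feats = np.concatenate([source_feats, target_feats], axis=0)
    domains = np.array([0] * len(source_feats) + [1] * len(target_feats))
    # perplexity must be smaller than the total number of samples
    emb = TSNE(n_components=2, perplexity=30, init="pca", random_state=seed).fit_transform(feats)

    fig, ax = plt.subplots(figsize=(5, 5))
    for d, name in ((0, "source"), (1, "target")):
        pts = emb[domains == d]
        ax.scatter(pts[:, 0], pts[:, 1], s=5, label=name)
    ax.legend()
    fig.savefig(path)
    plt.close(fig)
    return emb
```

In practice, `source_feats` and `target_feats` would be the feature-extractor outputs for a sample of MNIST and MNIST-M images. Calling the function once before and once after training gives the two plots to compare.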
MNIST -> MNIST-M
Source only
Method | Test #1 | Test #2 | Test #3 | Test #4 | Test #5 | Avg. |
---|---|---|---|---|---|---|
Source Accuracy | 89 | 98 | 98 | 90 | 98 | 94.6 |
Target Accuracy | 47 | 56 | 54 | 46 | 53 | 51.2 |
DANN
Method | Test #1 | Test #2 | Test #3 | Test #4 | Test #5 | Avg. |
---|---|---|---|---|---|---|
Source Accuracy | 96 | 96 | 97 | 97 | 96 | 96.4 |
Target Accuracy | 83 | 78 | 80 | 80 | 78 | 79.8 |
Domain Accuracy | 60 | 60 | 61 | 64 | 61 | 61.2 |
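This README does not show how `main.py` computes the Domain Accuracy row. A common definition, assumed in this sketch, is the accuracy of the 2-way domain classifier at telling source samples (label 0) from target samples (label 1). With equal numbers of source and target samples, 50% is chance level. A lower domain accuracy therefore means the features are harder to separate by domain. The function name `domain_accuracy` is hypothetical.

```python
import torch


@torch.no_grad()
def domain_accuracy(domain_logits_src, domain_logits_tgt):
    """Accuracy (%) of a 2-way domain classifier; source is labelled 0, target 1 (assumed)."""
    logits = torch.cat([domain_logits_src, domain_logits_tgt])
    labels = torch.cat([
        torch.zeros(len(domain_logits_src), dtype=torch.long),
        torch.ones(len(domain_logits_tgt), dtype=torch.long),
    ])
    preds = logits.argmax(dim=1)  # predicted domain per sample
    return 100.0 * (preds == labels).float().mean().item()


if __name__ == "__main__":
    src = torch.tensor([[2.0, 0.0], [2.0, 0.0]])  # both predicted "source" (correct)
    tgt = torch.tensor([[0.0, 2.0], [2.0, 0.0]])  # one correct, one mistaken for source
    print(domain_accuracy(src, tgt))  # 3 of 4 correct -> 75.0
```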