MASF

Domain Generalization via Model-Agnostic Learning of Semantic Features

We study the challenging problem of domain generalization, i.e., training a model on multi-domain source data such that it can directly generalize to unseen target domains. We adopt a model-agnostic learning paradigm with gradient-based meta-train and meta-test procedures to expose the optimization to domain shift. Further, we introduce two complementary losses which explicitly regularize the semantic structure of the feature space. Globally, we align a derived soft confusion matrix to preserve general knowledge about inter-class relationships. Locally, we promote domain-independent class-specific cohesion and separation of sample features with a metric-learning component.
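The global alignment term can be sketched in NumPy as follows. This is a minimal sketch under stated assumptions, not the repository's code: it assumes that row c of a domain's soft confusion matrix is the average of temperature-softened softmax outputs over that domain's samples labelled c, and that two domains are compared with a symmetrized KL divergence averaged over classes. The function names and the default temperature `tau=2.0` are illustrative assumptions.

```python
import numpy as np


def softened_softmax(logits, tau):
    """Row-wise softmax of logits / tau; a larger tau gives softer distributions."""
    z = logits / tau
    z = z - z.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def symmetric_kl(p, q, eps=1e-8):
    """0.5 * (KL(p || q) + KL(q || p)) for two probability vectors."""
    p = np.clip(p, eps, 1.0)  # avoid log(0)
    q = np.clip(q, eps, 1.0)
    return 0.5 * (np.sum(p * np.log(p / q)) + np.sum(q * np.log(q / p)))


def global_alignment_loss(logits_i, labels_i, logits_j, labels_j, tau=2.0):
    """Align per-class soft label distributions between domains i and j.

    Row c of each domain's soft confusion matrix is the mean softened
    class-probability vector of that domain's samples labelled c.
    Only classes present in both batches are compared (an assumption).
    """
    probs_i = softened_softmax(logits_i, tau)
    probs_j = softened_softmax(logits_j, tau)
    shared = np.intersect1d(labels_i, labels_j)  # classes seen in both domains
    if shared.size == 0:
        return 0.0
    per_class = [
        symmetric_kl(probs_i[labels_i == c].mean(axis=0),
                     probs_j[labels_j == c].mean(axis=0))
        for c in shared
    ]
    return float(np.mean(per_class))
```

Averaging softened probabilities, rather than comparing hard predictions, keeps the information about which wrong classes a sample resembles, which is the inter-class knowledge the abstract refers to. In a meta-learning episode this term would presumably be evaluated between each meta-train domain and the meta-test domain and averaged.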

This is the reference implementation of the domain generalization method described in our paper:

@inproceedings{dou2019domain,
    author = {Qi Dou and Daniel C. Castro and Konstantinos Kamnitsas and Ben Glocker},
    title = {Domain Generalization via Model-Agnostic Learning of Semantic Features},
    booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
    year = {2019},
}

If you make use of the code, please cite the paper in any resulting publications.

Setup

Check the dependencies listed in requirements.txt and, if necessary, install them with

pip install -r requirements.txt

Running MASF

Download the PACS dataset from here and place it in the data root /path/to/PACS_dataset; put the .txt file lists in /path/to/image/filelist.
Download the ImageNet-pretrained AlexNet weights bvlc_alexnet.npy from here.
To run MASF with art_painting as the target domain:

python main.py --dataset pacs --target_domain art_painting --inner_lr 1e-5 --outer_lr 1e-5 --metric_lr 1e-5 --margin 20
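The flag names suggest how one training episode is structured: with art_painting held out, the remaining source domains are split into meta-train and meta-test domains, the parameters take an inner gradient step with --inner_lr on the meta-train loss, and the meta-test loss is evaluated at those adapted parameters before an outer update with --outer_lr. The NumPy sketch below illustrates that gradient flow on a toy problem; it is an assumption-laden illustration, not main.py. Each domain's loss is a hypothetical quadratic 0.5 * ||theta - c_d||^2 standing in for the real task, global and local losses, which lets the second-order term through the inner step be written in closed form. The helper names and the weight `beta` are hypothetical, and --metric_lr and --margin (which presumably configure the metric-learning component) are not modeled.

```python
import numpy as np


def domain_grad(theta, center):
    """Gradient of the hypothetical toy loss 0.5 * ||theta - center||^2."""
    return theta - center


def meta_episode(theta, centers, meta_test_idx, inner_lr, outer_lr, beta=1.0):
    """One MAML-style update on L_tr(theta) + beta * L_te(theta').

    theta' = theta - inner_lr * grad L_tr(theta) are the meta-train-adapted parameters.
    """
    train_centers = [c for k, c in enumerate(centers) if k != meta_test_idx]
    test_center = centers[meta_test_idx]

    # Meta-train: gradient of the loss averaged over the meta-train domains.
    g_train = np.mean([domain_grad(theta, c) for c in train_centers], axis=0)
    theta_prime = theta - inner_lr * g_train

    # Meta-test: loss at theta', differentiated w.r.t. theta by the chain rule.
    # For this toy loss, d theta' / d theta = (1 - inner_lr) * I.
    g_test = (1.0 - inner_lr) * domain_grad(theta_prime, test_center)

    # Outer update on the combined meta-train + meta-test objective.
    return theta - outer_lr * (g_train + beta * g_test)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Three hypothetical source domains; the target domain is never seen.
    centers = [np.array([1.0, 0.0]), np.array([0.0, 1.0]), np.array([-1.0, 0.5])]
    theta = np.zeros(2)
    for _ in range(200):
        # Randomly choose which source domain acts as meta-test in this episode.
        k = int(rng.integers(len(centers)))
        theta = meta_episode(theta, centers, k, inner_lr=0.1, outer_lr=0.05)
    print("theta after 200 episodes:", theta)
```

Setting inner_lr=0 removes the dependence of theta' on the inner step, so the update reduces to plain joint training on all source domains; a nonzero inner step is what makes the outer update account for how a step on the meta-train domains affects the meta-test domain.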

Monitoring training with TensorBoard

TensorBoard logs of losses and gradients are stored in /log/. To view them, run

tensorboard --logdir /log/

Running on medical data

To run on a medical dataset, replace the functions construct_alexnet_weights() and forward_alexnex() with construct_unet_weights() and forward_unet(), which are demonstrated in the medical folder.