forked from LFhase/FeAT

[NeurIPS 2023] "Understanding and Improving Feature Learning for Out-of-Distribution Generalization"

tmlr-group/FeAT


FeAT: Feature Augmented Training


This repo contains the sample code for reproducing the results of our NeurIPS 2023 paper: Understanding and Improving Feature Learning for Out-of-Distribution Generalization, which was also presented as a spotlight at the ICLR DG workshop and at the ICML SCIS workshop. 😆😆😆

Updates:

  • The camera-ready version of the paper has been updated!
  • Detailed running instructions will be released soon!

What feature does ERM learn for generalization?

Empirical risk minimization (ERM) is the de facto objective in machine learning and achieves impressive generalization performance. Nevertheless, ERM is known to be prone to spurious correlations and is suspected to learn predictive but spurious features when minimizing the empirical risk. However, recent work (Rosenfeld et al., 2022; Kirichenko et al., 2022) empirically shows that ERM already learns invariant features, i.e., features that hold an invariant relation with the label, for both in-distribution and out-of-distribution (OOD) generalization.

We resolve the puzzle by theoretically proving that ERM essentially learns both spurious and invariant features. Meanwhile, we also find that OOD objectives such as IRMv1 can hardly learn new features, even at the beginning of optimization. Therefore, when optimizing OOD objectives such as IRMv1, pre-training the model with ERM is usually necessary for satisfactory performance. As shown in the right subfigure, the OOD performance of various OOD objectives first grows with more ERM pre-training epochs.

However, which features ERM prefers to learn depends on the inductive biases of the dataset and the architecture, and this limited feature learning can pose a bottleneck for OOD generalization. Therefore, we propose Feature Augmented Training (FeAT), which aims to learn all features so long as they are useful for generalization. At each round, FeAT divides the training data $\mathcal{D}_{tr}$ into an augmentation set $D^a$, containing features not yet sufficiently learned by the model, and a retention set $D^r$, containing features already learned by the current model. Training on the partitioned datasets augments the model with the new features contained in the growing augmentation sets while retaining the already learned features contained in the retention sets, which leads the model to learn richer features for OOD training and obtain better OOD performance.
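The round-based structure of FeAT can be sketched with a toy model. Here a "model" is caricatured as the set of features it has learned, and ERM's inductive bias is caricatured as picking up only the dominant feature of the augmentation set each round. All names are illustrative, not the authors' implementation (see the ColoredMNIST and WILDS folders for the real training code):

```python
from collections import Counter

def feat_rounds(features, num_rounds):
    """Toy FeAT loop: `features` lists the feature each example carries."""
    learned = set()  # the features the current "model" relies on
    for _ in range(num_rounds):
        # Partition the data at this round:
        #   D^a = examples whose features the model has not learned yet,
        #   D^r = examples whose features it already fits (anchors old features).
        aug = [f for f in features if f not in learned]  # D^a
        ret = [f for f in features if f in learned]      # D^r, retained during training
        if not aug:  # every feature is learned; stop early
            break
        # Caricatured ERM bias: each round only the dominant feature of the
        # augmentation set is absorbed, while training jointly on D^r keeps
        # previously learned features from being forgotten.
        dominant = Counter(aug).most_common(1)[0][0]
        learned.add(dominant)
    return learned

data = ["color"] * 3 + ["shape"] * 2 + ["texture"]
feat_rounds(data, num_rounds=1)  # only the dominant feature: {'color'}
feat_rounds(data, num_rounds=3)  # all three features after three rounds
```

One round learns only the dominant feature, which is exactly the limited feature learning FeAT is designed to overcome; running more rounds grows the augmentation set's relative weight on the remaining features until all of them are learned.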

For more interesting stories of rich feature learning, please see the repositories Bonsai and RRL, and the blog by Jianyu. 😆

Structure of Codebase

The codebase contains two parts, corresponding to the experiments presented in the paper:

  • ColoredMNIST: Proof of Concept on ColoredMNIST
  • WILDS: Verification of FeAT in WILDS

Dependencies

Our experiments run with cuda=10.2 and python=3.8.12 and the following key libraries:

wilds==2.0.0
torch==1.9.0
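A minimal environment setup might look as follows, assuming a pip-based workflow (conda works equally well; the exact torch wheel for cuda=10.2 may need a platform-specific index on your system):

```shell
# create and activate an isolated environment (virtualenv assumed)
python3.8 -m venv feat-env
source feat-env/bin/activate

# pin the key libraries to the versions listed above
pip install torch==1.9.0
pip install wilds==2.0.0
```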

ColoredMNIST

The corresponding code is in the folder ColoredMNIST. The code is modified from RFC.

WILDS

The corresponding code is in the folder WILDS. The code is modified from PAIR and spurious-feature-learning.

Misc

If you find our paper and repo useful, please cite our paper:

@inproceedings{chen2023FeAT,
  title={Understanding and Improving Feature Learning for Out-of-Distribution Generalization},
  author={Yongqiang Chen and Wei Huang and Kaiwen Zhou and Yatao Bian and Bo Han and James Cheng},
  booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
  year={2023},
  url={https://openreview.net/forum?id=eozEoAtjG8}
}
