This work provides extensive empirical results on training LMs to count. We find that while traditional RNNs trivially achieve inductive counting, Transformers have to rely on positional embeddings to count out-of-domain. Modern RNNs (e.g. rwkv, mamba) also largely underperform traditional RNNs in generalizing counting inductively.

🌟 This is the code repo for experiments performed in Language Models Need Inductive Biases to Count Inductively 🌟

File Structure

In /scripts, we maintain separate folders for different architecture types. Note that LSTM and RNN are subsumed under /scripts/s4.
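
A rough sketch of the layout, inferred from the paths referenced in this README (the mamba and rwkv folder names are assumed; exact contents may differ):

<path_to_this_repo>/
├── requirements.txt
└── scripts/
    ├── causal_transformer/   # run.py, config.py, config_taskspecific.py, tester.py
    ├── s4/                   # also covers LSTM and RNN; s4_requirements.txt
    ├── rwkv/                 # separate environment, see below
    └── mamba/                # separate environment, see below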

Python Environments

To support reproducibility for individual sets of experiments, mamba and rwkv each have their own environment, while causal_transformer and s4 share one. We therefore provide instructions for building three environments.

Here's how to set up the shared environment for causal_transformer and s4.

cd <path_to_this_repo> &&
python3 -m venv venv &&
source venv/bin/activate &&
pip install -r requirements.txt &&
cd scripts/s4 &&
pip install -r s4_requirements.txt
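
For later sessions, re-activate the same environment before running any training or testing commands:

cd <path_to_this_repo> && source venv/bin/activate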

Please follow these links to build the mamba and rwkv environments.

Generate Data

For examples of the input-output formats, see the validation and OOD testing files provided for each task.

Our training data is generated in this notebook.

Train Models

If this is your first time using accelerate and you haven't configured it yet, run accelerate config and answer the prompts accordingly, e.g.:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Remember to specify output_dir, ckpt_dir, and hf_cache_dir in config.py.
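
Assuming config.py sits alongside run.py in scripts/causal_transformer/ (and likewise in scripts/s4/), as the training commands below suggest, you can locate these settings with, e.g.:

grep -nE "output_dir|ckpt_dir|hf_cache_dir" scripts/causal_transformer/config.py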

Training command:

cd scripts/causal_transformer && # or cd scripts/s4
python run.py --task <task_name> --cuda 0 --port 29500

Notes

  • <task_name> can be chosen from scripts/causal_transformer/config_taskspecific.py, e.g. counting_samesymbol_mod10bos.

  • Model ckpts will be saved to the ckpt_dir specified in config.py. Model outputs during validation will be saved to output_dir. Specifically, each run creates its own folder under output_dir, named by its timestamp, which can be passed to tester.py through the "handle" argument.

  • If you're running multiple jobs on the same machine, use different ports (see the example below); otherwise, accelerate will complain about a busy port.
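
For instance, a sketch of two concurrent runs (the task name is just the example from above; --cuda 1 assumes a second GPU is available):

# terminal 1
cd scripts/causal_transformer &&
python run.py --task counting_samesymbol_mod10bos --cuda 0 --port 29500

# terminal 2
cd scripts/causal_transformer &&
python run.py --task counting_samesymbol_mod10bos --cuda 1 --port 29501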

Test Models

python tester.py --handle <timestamp>

E.g., timestamp = 0522_103640

Cite Us 🙏

@article{chang2024language,
  title={Language Models Need Inductive Biases to Count Inductively},
  author={Chang, Yingshan and Bisk, Yonatan},
  journal={arXiv preprint arXiv:2405.20131},
  year={2024}
}

Acknowledgements

  • The implementation of the causal Transformer, as well as its positional embedding variants, borrows heavily from Hugging Face's implementations of GPT-2, T5, and LLaMA.
  • We give credit to the official S4 repo for the implementation of s4.
  • We give credit to the official rwkv repo for the implementation of rwkv.
  • We give credit to the official mamba repo for the implementation of mamba, as well as the mamba-chat repo for setting up the mamba environment.
