Skip to content

Commit

Permalink
OpenFlamingo v2
Browse files Browse the repository at this point in the history
  • Loading branch information
i-gao committed Jun 28, 2023
1 parent c2e80b4 commit 5c90779
Show file tree
Hide file tree
Showing 38 changed files with 3,928 additions and 955 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
*.pt
*.json

wandb/

Expand Down
14 changes: 13 additions & 1 deletion HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
## 2.0.0
* Add gradient checkpointing, FullyShardedDataParallel
* Model releases
* (CLIP ViT-L-14 / MPT-1B)
* (CLIP ViT-L-14 / MPT-1B Dolly)
* (CLIP ViT-L-14 / RedPajama-3B)
* (CLIP ViT-L-14 / RedPajama-3B Instruct)
* (CLIP ViT-L-14 / MPT-7B)
* Remove color jitter when training
* Fix cross-attention bug when calling generate()

## 1.0.0

* it works
* Initial code release
* Early model release (CLIP ViT-L-14 / LLaMA-7B)
44 changes: 0 additions & 44 deletions MODEL_CARD.md

This file was deleted.

119 changes: 69 additions & 50 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

[![PyPI version](https://badge.fury.io/py/open_flamingo.svg)](https://badge.fury.io/py/open_flamingo)

[Blog post](https://laion.ai/blog/open-flamingo/) | Paper (coming soon)
Blog posts: [1](https://laion.ai/blog/open-flamingo/), [2]() | Paper (coming soon) | [Demo](https://huggingface.co/spaces/openflamingo/OpenFlamingo)

Welcome to our open source version of DeepMind's [Flamingo](https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model) model! In this repository, we provide a PyTorch implementation for training and evaluating OpenFlamingo models. We also provide an initial [OpenFlamingo 9B model](https://huggingface.co/openflamingo/OpenFlamingo-9B) trained on a new [Multimodal C4](https://github.com/allenai/mmc4) dataset. Please refer to our blog post for more details.
Welcome to our open source implementation of DeepMind's [Flamingo](https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model)!

This repo is still under development, and we hope to release better performing and larger OpenFlamingo models soon. If you have any questions, please feel free to open an issue. We also welcome contributions!
In this repository, we provide a PyTorch implementation for training and evaluating OpenFlamingo models.
If you have any questions, please feel free to open an issue. We also welcome contributions!

# Table of Contents
- [Installation](#installation)
Expand Down Expand Up @@ -35,37 +36,69 @@ or to create a conda environment for running OpenFlamingo, run
conda env create -f environment.yml
```

# Usage
We provide an initial [OpenFlamingo 9B model](https://huggingface.co/openflamingo/OpenFlamingo-9B) using a CLIP ViT-Large vision encoder and a LLaMA-7B language model. In general, we support any [CLIP vision encoder](https://huggingface.co/models?search=clip). For the language model, we support [LLaMA](https://huggingface.co/models?search=llama), [OPT](https://huggingface.co/models?search=opt), [GPT-Neo](https://huggingface.co/models?search=gpt-neo), [GPT-J](https://huggingface.co/models?search=gptj), and [Pythia](https://huggingface.co/models?search=pythia) models.
# Approach
OpenFlamingo is a multimodal language model that can be used for a variety of tasks. It is trained on a large multimodal dataset (e.g. Multimodal C4) and can be used to generate text conditioned on interleaved images/text. For example, OpenFlamingo can be used to generate a caption for an image, or to generate a question given an image and a text passage. The benefit of this approach is that we are able to rapidly adapt to new tasks using in-context learning.

NOTE: To use LLaMA models, you will need to use this [script](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py) for converting LLaMA weights to HuggingFace format.
## Model architecture
OpenFlamingo combines a pretrained vision encoder and a language model using cross attention layers. The model architecture is shown below.

![OpenFlamingo architecture](docs/flamingo.png)
Credit: [Flamingo](https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model)

# Usage
## Initializing an OpenFlamingo model
We support pretrained vision encoders from the [OpenCLIP](https://github.com/mlfoundations/open_clip) package, which includes OpenAI's pretrained models.
We also support pretrained language models from the `transformers` package, such as [MPT](https://huggingface.co/models?search=mosaicml%20mpt), [RedPajama](https://huggingface.co/models?search=redpajama), [LLaMA](https://huggingface.co/models?search=llama), [OPT](https://huggingface.co/models?search=opt), [GPT-Neo](https://huggingface.co/models?search=gpt-neo), [GPT-J](https://huggingface.co/models?search=gptj), and [Pythia](https://huggingface.co/models?search=pythia) models.

``` python
from open_flamingo import create_model_and_transforms

model, image_processor, tokenizer = create_model_and_transforms(
clip_vision_encoder_path="ViT-L-14",
clip_vision_encoder_pretrained="openai",
lang_encoder_path="<path to llama weights in HuggingFace format>",
tokenizer_path="<path to llama tokenizer in HuggingFace format>",
cross_attn_every_n_layers=4
lang_encoder_path="anas-awadalla/mpt-1b-redpajama-200b",
tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b",
cross_attn_every_n_layers=1
)
```

## Released OpenFlamingo models
We have trained the following OpenFlamingo models so far.

|# params|Language model|Vision encoder|Xattn frequency*|COCO 4-shot CIDEr**|VQAv2 4-shot Accuracy**|Weights|
|------------|--------------|--------------|----------|-----------|-------|----|
|3B| mosaicml/mpt-1b-redpajama-200b | openai CLIP ViT-L/14 | 1 | - | 45.9 |[Link](https://huggingface.co/openflamingo/OpenFlamingo-3B-vitl-mpt1b)|
|3B| mosaicml/mpt-1b-redpajama-200b-dolly | openai CLIP ViT-L/14 | 1 | 82.7 | 46.8 |[Link](https://huggingface.co/openflamingo/OpenFlamingo-3B-vitl-mpt1b-langinstruct)|
|4B| togethercomputer/RedPajama-INCITE-Base-3B-v1 | openai CLIP ViT-L/14 | 2 | 81.8 | 48.1| [Link](https://huggingface.co/openflamingo/OpenFlamingo-4B-vitl-rpj3b)|
|4B| togethercomputer/RedPajama-INCITE-Instruct-3B-v1 | openai CLIP ViT-L/14 | 2 | 85.8 | 49.1 | [Link](https://huggingface.co/openflamingo/OpenFlamingo-4B-vitl-rpj3b-langinstruct)|
|9B| mosaicml/mpt-7b | openai CLIP ViT-L/14 | 4 | 89.0 | 52.3 | [Link](https://huggingface.co/openflamingo/OpenFlamingo-9B-vitl-mpt7b)|

*\* Xattn frequency refers to the `--cross_attn_every_n_layers` argument.*

*\*\* 4-shot COCO and VQAv2 performances were calculated over a sample of 5000 test split examples, following the [Flamingo paper](https://arxiv.org/abs/2204.14198).*

Note: as part of our v2 release, we have deprecated a previous LLaMA-based checkpoint. However, you can continue to use our older checkpoint using the new codebase.

## Downloading pretrained weights

To instantiate an OpenFlamingo model with one of our released weights, initialize the model as above and use the following code.

```python
# grab model checkpoint from huggingface hub
from huggingface_hub import hf_hub_download
import torch

checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-9B", "checkpoint.pt")
checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-3B-vitl-mpt1b", "checkpoint.pt")
model.load_state_dict(torch.load(checkpoint_path), strict=False)
```

## Generating text
Here is an example of generating text conditioned on interleaved images/text, in this case we will do few-shot image captioning.
Below is an example of generating text conditioned on interleaved images/text. In particular, let's try few-shot image captioning.

``` python
from PIL import Image
import requests
import torch

"""
Step 1: Load images
Expand Down Expand Up @@ -95,8 +128,7 @@ query_image = Image.open(
Step 2: Preprocessing images
Details: For OpenFlamingo, we expect the image to be a torch tensor of shape
batch_size x num_media x num_frames x channels x height x width.
In this case batch_size = 1, num_media = 3, num_frames = 1
(this will always be one expect for video which we don't support yet),
In this case batch_size = 1, num_media = 3, num_frames = 1,
channels = 3, height = 224, width = 224.
"""
vision_x = [image_processor(demo_image_one).unsqueeze(0), image_processor(demo_image_two).unsqueeze(0), image_processor(query_image).unsqueeze(0)]
Expand Down Expand Up @@ -130,45 +162,36 @@ generated_text = model.generate(
print("Generated text: ", tokenizer.decode(generated_text[0]))
```

# Approach
OpenFlamingo is a multimodal language model that can be used for a variety of tasks. It is trained on a large multimodal dataset (e.g. [Multimodal C4](https://github.com/allenai/mmc4)) and can be used to generate text conditioned on interleaved images/text. For example, OpenFlamingo can be used to generate a caption for an image, or to generate a question given an image and a text passage. The benefit of this approach is that we are able to rapidly adapt to new tasks using in-context training.

## Model architecture
OpenFlamingo seeks to fuse a pretrained vision encoder and a language model using cross attention layers. The model architecture is shown below.

![OpenFlamingo architecture](docs/flamingo.png)
Credit: [Flamingo](https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model)

# Training
To train a model, modify the following example command, which uses OPT 1.3B as an example LM:
We provide training scripts in `open_flamingo/train`. We provide an example Slurm script in `open_flamingo/scripts/run_train.py`, as well as the following example command:
```
torchrun --nnodes=1 --nproc_per_node=4 train.py \
--run_name flamingo3B \
--lm_path facebook/opt-1.3b \
--tokenizer_path facebook/opt-1.3b \
--dataset_resampled \
--laion_shards "/path/to/shards/shard-{0000..0999}.tar" \
--mmc4_shards "/path/to/shards/shard-{0000..0999}.tar" \
--batch_size_mmc4 4 \
--batch_size_laion 8 \
--train_num_samples_mmc4 125000 \
--train_num_samples_laion 250000 \
--loss_multiplier_laion 0.2 \
--workers=6 \
--num_epochs 250 \
--lr_scheduler constant \
--warmup_steps 5000 \
--use_media_placement_augmentation \
--mmc4_textsim_threshold 0.32
torchrun --nnodes=1 --nproc_per_node=4 open_flamingo/train/train.py \
--lm_path anas-awadalla/mpt-1b-redpajama-200b \
--tokenizer_path anas-awadalla/mpt-1b-redpajama-200b \
--cross_attn_every_n_layers 1 \
--dataset_resampled \
--batch_size_mmc4 32 \
--batch_size_laion 64 \
--train_num_samples_mmc4 125000\
--train_num_samples_laion 250000 \
--loss_multiplier_laion 0.2 \
--workers=4 \
--run_name OpenFlamingo-3B-vitl-mpt1b \
--num_epochs 480 \
--warmup_steps 1875 \
--mmc4_textsim_threshold 0.24 \
--laion_shards "/path/to/shards/shard-{0000..0999}.tar" \
--mmc4_shards "/path/to/shards/shard-{0000..0999}.tar" \
--report_to_wandb
```

## Dataset
We expect all our training datasets to be [WebDataset](https://github.com/webdataset/webdataset) shards.
We train our models on the [LAION 2B](https://huggingface.co/datasets/laion/laion2B-en) and [Multimodal C4](https://github.com/allenai/mmc4) datasets. By default the LAION 2B dataset is in WebDataset format if it is downloaded using the [img2dataset tool](https://github.com/rom1504/img2dataset) and Multimodal C4 can be converted to the WebDataset format using this [script](https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/train/convert_mmc4_to_wds.py).
*Note: The MPT-1B [base](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b) and [instruct](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b-dolly) modeling code does not accept the `labels` kwarg or compute cross-entropy loss directly within `forward()`, as expected by our codebase. We suggest using a modified version of the MPT-1B models found [here](https://huggingface.co/anas-awadalla/mpt-1b-redpajama-200b) and [here](https://huggingface.co/anas-awadalla/mpt-1b-redpajama-200b-dolly).*

For more details, see our [training README](https://github.com/mlfoundations/open_flamingo/tree/main/open_flamingo/train).


# Evaluation
We currently support running evaluations on [COCO](https://cocodataset.org/#home), [VQAv2](https://visualqa.org/index.html), [OKVQA](https://okvqa.allenai.org), [Flickr30k](https://www.kaggle.com/datasets/hsankesara/flickr-image-dataset), and [ImageNet](https://image-net.org/index.php). Note that currently these evaluations are ran in validation mode (as specified in the Flamingo paper). We will be adding support for running evaluations in test mode in the future.
An example evaluation script is at `open_flamingo/scripts/run_eval.sh`. Please see our [evaluation README](https://github.com/mlfoundations/open_flamingo/tree/main/open_flamingo/eval) for more details.


To run evaluations on OKVQA you will need to run the following command:
Expand All @@ -177,19 +200,15 @@ import nltk
nltk.download('wordnet')
```

To evaluate the model, run the script at `open_flamingo/scripts/run_eval.sh`

# Future plans
- [ ] Add support for video input
- [ ] Release better performing and larger OpenFlamingo models
- [ ] Expand our evaluation suite
- [ ] Add support for FSDP training

# Team

OpenFlamingo is developed by:

[Anas Awadalla](https://anas-awadalla.streamlit.app/), [Irena Gao](https://i-gao.github.io/), [Joshua Gardner](https://homes.cs.washington.edu/~jpgard/), [Jack Hessel](https://jmhessel.com/), [Yusuf Hanafy](https://www.linkedin.com/in/yusufhanafy/), [Wanrong Zhu](https://wanrong-zhu.com/), [Kalyani Marathe](https://sites.google.com/uw.edu/kalyanimarathe/home?authuser=0), [Yonatan Bitton](https://yonatanbitton.github.io/), [Samir Gadre](https://sagadre.github.io/), [Jenia Jitsev](https://scholar.google.de/citations?user=p1FuAMkAAAAJ&hl=en), [Simon Kornblith](https://simonster.com/), [Pang Wei Koh](https://koh.pw/), [Gabriel Ilharco](https://gabrielilharco.com/), [Mitchell Wortsman](https://mitchellnw.github.io/), [Ludwig Schmidt](https://people.csail.mit.edu/ludwigs/).
[Anas Awadalla*](https://anas-awadalla.streamlit.app/), [Irena Gao*](https://i-gao.github.io/), [Joshua Gardner](https://homes.cs.washington.edu/~jpgard/), [Jack Hessel](https://jmhessel.com/), [Yusuf Hanafy](https://www.linkedin.com/in/yusufhanafy/), [Wanrong Zhu](https://wanrong-zhu.com/), [Kalyani Marathe](https://sites.google.com/uw.edu/kalyanimarathe/home?authuser=0), [Yonatan Bitton](https://yonatanbitton.github.io/), [Samir Gadre](https://sagadre.github.io/), [Shiori Sagawa](https://cs.stanford.edu/~ssagawa/), [Jenia Jitsev](https://scholar.google.de/citations?user=p1FuAMkAAAAJ&hl=en), [Simon Kornblith](https://simonster.com/), [Pang Wei Koh](https://koh.pw/), [Gabriel Ilharco](https://gabrielilharco.com/), [Mitchell Wortsman](https://mitchellnw.github.io/), [Ludwig Schmidt](https://people.csail.mit.edu/ludwigs/).

The team is primarily from the University of Washington, Stanford, AI2, UCSB, and Google.

Expand Down
Loading

0 comments on commit 5c90779

Please sign in to comment.