This is a PyTorch implementation of the paper Disentangled Representation Learning for Text-Video Retrieval:
```bibtex
@Article{DRLTVR2022,
author = {Qiang Wang and Yanhao Zhang and Yun Zheng and Pan Pan and Xian-Sheng Hua},
journal = {arXiv:2203.07111},
title = {Disentangled Representation Learning for Text-Video Retrieval},
year = {2022},
}
```
- Setup
- Fine-tuning code
- Visualization demo
Setup: clone the repository, create a conda environment, and install the dependencies, including PyTorch 1.8.1 built for CUDA 10.2:

```bash
git clone https://github.com/foolwood/DRL.git
cd DRL
conda create -n drl python=3.9
conda activate drl
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple/
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 -f https://download.pytorch.org/whl/torch_stable.html
```
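To confirm that pip picked up the CUDA 10.2 builds pinned above, a small standard-library check along these lines may help. It is a sketch, not part of the repository: it only compares installed package metadata against the pins copied from the install command, and reports a package as `None` when it is not installed at all.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the install command above.
EXPECTED = {"torch": "1.8.1+cu102", "torchvision": "0.9.1+cu102"}


def find_mismatches(expected):
    """Return {package: installed_version_or_None} for every pin that does not match."""
    mismatches = {}
    for package, wanted in expected.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            installed = None  # the package is not installed in this environment
        if installed != wanted:
            mismatches[package] = installed
    return mismatches


if __name__ == "__main__":
    problems = find_mismatches(EXPECTED)
    if problems:
        for package, installed in problems.items():
            print(f"{package}: expected {EXPECTED[package]}, found {installed}")
    else:
        print("torch/torchvision pins look correct")
```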
Download the CLIP checkpoint into `tvr/models` (ViT-B/32 by default; uncomment a line to fetch the ViT-B/16 or ViT-L/14 weights instead):

```bash
cd tvr/models
wget https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt
# wget https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt
# wget https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt
cd ../..  # back to the repository root
```
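OpenAI's CLIP download URLs embed the SHA256 digest of the checkpoint as the second-to-last path segment (OpenAI's `clip` package checks its own downloads the same way), so a truncated or corrupted `wget` can be caught before training. The helper names below are illustrative, and the default path assumes the script is run from the repository root:

```python
import hashlib
import sys
from pathlib import Path

CLIP_URL = (
    "https://openaipublic.azureedge.net/clip/models/"
    "40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"
)


def expected_sha256(url):
    """The second-to-last path segment of an OpenAI CLIP URL is the file's SHA256."""
    return url.rstrip("/").split("/")[-2]


def file_sha256(path, chunk_size=1 << 20):
    """Hash the file in 1 MiB chunks so the checkpoint is never fully loaded into memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


if __name__ == "__main__":
    # Optional first argument: a different checkpoint path.
    path = Path(sys.argv[1] if len(sys.argv) > 1 else "tvr/models/ViT-B-32.pt")
    if not path.is_file():
        print(f"{path} not found; download it first")
    elif file_sha256(path) == expected_sha256(CLIP_URL):
        print(f"{path}: checksum OK")
    else:
        print(f"{path}: checksum MISMATCH, re-download the file")
```

For ViT-B/16 or ViT-L/14, swap in the corresponding URL from the commented lines.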
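Prepare the MSR-VTT data from the repository root; the videos and annotations come from the Frozen in Time release:

```bash
cd data/MSR-VTT
wget https://www.robots.ox.ac.uk/~maxbain/frozen-in-time/data/MSRVTT.zip ; unzip MSRVTT.zip
mv MSRVTT/videos/all ./videos ; mv MSRVTT/annotation/MSR_VTT.json ./anns/MSRVTT_data.json
cd ../..  # back to the repository root
```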
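Before launching training, it can be worth checking that every annotated video actually has a file on disk. The sketch below assumes (this is not taken from the repository's data loader) that `MSRVTT_data.json` holds a `sentences` list of `{"video_id", "caption"}` entries, as in CLIP4Clip's MSR-VTT annotation file, and that each video is stored as `<video_id>.mp4`; `check_msrvtt` is a hypothetical helper meant to be run from the repository root.

```python
import json
from collections import Counter
from pathlib import Path


def check_msrvtt(root="data/MSR-VTT"):
    """Count captions and videos, and list annotated videos with no .mp4 file.

    Assumed layout: anns/MSRVTT_data.json with a "sentences" list of
    {"video_id": ..., "caption": ...} dicts, and videos/<video_id>.mp4.
    """
    root = Path(root)
    with open(root / "anns" / "MSRVTT_data.json") as f:
        sentences = json.load(f)["sentences"]
    captions_per_video = Counter(s["video_id"] for s in sentences)
    video_files = {p.stem for p in (root / "videos").glob("*.mp4")}
    return {
        "captions": len(sentences),
        "videos_with_captions": len(captions_per_video),
        "video_files": len(video_files),
        "missing_videos": sorted(set(captions_per_video) - video_files),
    }


if __name__ == "__main__":
    ann = Path("data/MSR-VTT/anns/MSRVTT_data.json")
    if ann.is_file():
        report = check_msrvtt()
        print({k: v for k, v in report.items() if k != "missing_videos"})
        print("first missing videos:", report["missing_videos"][:10])
    else:
        print(f"{ann} not found; run the data preparation commands first")
```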
- Train on MSR-VTT 1k with 4 GPUs (run from the repository root):

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 \
main.py --do_train 1 --workers 8 --n_display 50 \
--epochs 5 --lr 1e-4 --coef_lr 1e-3 --batch_size 128 --batch_size_val 128 \
--anno_path data/MSR-VTT/anns --video_path data/MSR-VTT/videos --datatype msrvtt \
--max_words 32 --max_frames 12 --video_framerate 1 \
--base_encoder ViT-B/32 --agg_module seqTransf \
--interaction wti --wti_arch 2 --cdcr 3 --cdcr_alpha1 0.11 --cdcr_alpha2 0.0 --cdcr_lambda 0.001 \
--output_dir ckpts/ckpt_msrvtt_wti_cdcr
```
Results of the ablation experiments, reproducible with the provided scripts:
| Config | Feature | GPUs | Text-Video R@1 | Text-Video R@5 | Text-Video R@10 | Text-Video MdR | Text-Video MnR | Video-Text R@1 | Video-Text R@5 | Video-Text R@10 | Video-Text MdR | Video-Text MnR | Train time (h) |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| CLIP4Clip | ViT-B/32 | 4 | 42.8 | 72.1 | 81.4 | 2.0 | 16.3 | 44.1 | 70.5 | 80.5 | 2.0 | 11.8 | 10.5 |
| zero-shot | ViT-B/32 | 4 | 31.1 | 53.7 | 63.4 | 4.0 | 41.6 | 26.5 | 50.1 | 61.7 | 5.0 | 39.9 | - |
| Interaction | | | | | | | | | | | | | |
| DP+None | ViT-B/32 | 4 | 42.9 | 70.6 | 81.4 | 2.0 | 15.4 | 43.0 | 71.1 | 81.1 | 2.0 | 11.8 | 2.5 |
| DP+seqTransf | ViT-B/32 | 4 | 42.8 | 71.1 | 81.1 | 2.0 | 15.6 | 44.1 | 70.9 | 80.9 | 2.0 | 11.7 | 2.6 |
| XTI+None | ViT-B/32 | 4 | 40.5 | 71.1 | 82.6 | 2.0 | 13.6 | 42.7 | 70.8 | 80.2 | 2.0 | 12.5 | 14.3 |
| XTI+seqTransf | ViT-B/32 | 4 | 42.4 | 71.3 | 80.9 | 2.0 | 15.2 | 40.1 | 69.2 | 79.6 | 2.0 | 15.8 | 16.8 |
| TI+seqTransf | ViT-B/32 | 4 | 44.8 | 73.0 | 82.2 | 2.0 | 13.4 | 42.6 | 72.7 | 82.8 | 2.0 | 9.1 | 2.6 |
| WTI+seqTransf | ViT-B/32 | 4 | 46.6 | 73.4 | 83.5 | 2.0 | 13.0 | 45.4 | 73.4 | 81.9 | 2.0 | 9.2 | 2.6 |
| Channel DeCorrelation Regularization | | | | | | | | | | | | | |
| DP+seqTransf+CDCR | ViT-B/32 | 4 | 43.9 | 71.1 | 81.2 | 2.0 | 15.3 | 42.3 | 70.3 | 81.1 | 2.0 | 11.4 | 2.6 |
| TI+seqTransf+CDCR | ViT-B/32 | 4 | 45.8 | 73.0 | 81.9 | 2.0 | 12.8 | 43.3 | 71.8 | 82.7 | 2.0 | 8.9 | 2.6 |
| WTI+seqTransf+CDCR | ViT-B/32 | 4 | 47.6 | 73.4 | 83.3 | 2.0 | 12.8 | 45.1 | 72.9 | 83.5 | 2.0 | 9.2 | 2.6 |
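Note: these results are slightly higher than those reported in the paper because new hyperparameters were used. R@K is recall at rank K (higher is better); MdR and MnR are the median and mean rank (lower is better). DP denotes dot-product similarity, TI token-wise interaction, WTI weighted token-wise interaction, and CDCR Channel DeCorrelation Regularization; seqTransf corresponds to `--agg_module seqTransf`.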
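The interaction variants in the table differ in how text and video tokens are compared. The sketch below is a plain-Python illustration under stated assumptions, not the repository's implementation: DP pools each side into one vector (here simply the mean of its tokens, standing in for the real text and video embeddings) and takes a single cosine similarity; token-wise interaction (TI) matches every text token to its best frame and every frame to its best text token, then averages the two directions; the weighted variant (WTI) replaces the uniform 1/N token weights with softmax weights. In DRL those weights are learned, whereas here the logits are passed in by hand as a stand-in.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def normalize(v):
    norm = math.sqrt(dot(v, v))
    return [x / norm for x in v]


def mean(vectors):
    # Column-wise mean of a list of equal-length vectors.
    return [sum(col) / len(vectors) for col in zip(*vectors)]


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def dp_score(text_tokens, frame_tokens):
    """DP: pool each modality into one vector, then take one cosine similarity."""
    return dot(normalize(mean(text_tokens)), normalize(mean(frame_tokens)))


def token_wise_score(text_tokens, frame_tokens, text_logits=None, frame_logits=None):
    """TI when no logits are given (uniform 1/N weights); WTI-style when logits are given."""
    t = [normalize(x) for x in text_tokens]
    v = [normalize(x) for x in frame_tokens]
    sim = [[dot(ti, vj) for vj in v] for ti in t]  # Nt x Nv cosine similarities
    wt = softmax(text_logits) if text_logits is not None else [1 / len(t)] * len(t)
    wv = softmax(frame_logits) if frame_logits is not None else [1 / len(v)] * len(v)
    # Each text token keeps its best-matching frame, and each frame its best text token.
    text_to_video = sum(w * max(row) for w, row in zip(wt, sim))
    video_to_text = sum(w * max(col) for w, col in zip(wv, zip(*sim)))
    return (text_to_video + video_to_text) / 2


if __name__ == "__main__":
    text = [[1.0, 0.0], [0.0, 1.0]]
    frames = [[1.0, 0.0], [1.0, 0.0]]
    print("DP :", dp_score(text, frames))
    print("TI :", token_wise_score(text, frames))
    print("WTI:", token_wise_score(text, frames, text_logits=[2.0, -2.0]))
```

With these toy inputs plain TI scores 0.75 because the second text token matches no frame, while logits that down-weight that token push the WTI-style score close to 1. This illustrates why learning per-token weights can help when some words carry little visual content.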
Run our visualization demo using matplotlib (no GPU needed):
See LICENSE for details.
Our code is partly based on CLIP4Clip.