[NeurIPS 2024] Official implementation of the paper "Point-PRC: A Prompt Learning Based Regulation Framework for Generalizable Point Cloud Analysis"

auniquesun/Point-PRC

Point-PRC

This repository is the official implementation of the NeurIPS 2024 paper "Point-PRC: A Prompt Learning Based Regulation Framework for Generalizable Point Cloud Analysis".

Motivation

Motivation of our research: to improve performance on downstream 3D tasks while maintaining the good generalization of large 3D models. The experiments are conducted on ShapeNetCoreV2. ULIP-2 reaches 71.22% zero-shot recognition accuracy on this dataset. Recent works built on ULIP-2 introduce lightweight prompt tuning (PT) to further boost target tasks (75.80% accuracy). However, we observe that these improvements come at the expense of a severe drop in 3D domain generalization (e.g., 57.07% accuracy on new classes, far below the 71.22% zero-shot accuracy), so we develop a systematic regulation constraint (RC) framework to address this challenge.

Introduction

This paper investigates the 3D domain generalization (3DDG) ability of large 3D models based on prevalent prompt learning. Recent works demonstrate the performances of 3D point cloud recognition can be boosted remarkably by parameter-efficient prompt tuning. However, we observe that the improvement on downstream tasks comes at the expense of a severe drop in 3D domain generalization. To resolve this challenge, we present a comprehensive regulation framework that allows the learnable prompts to actively interact with the well-learned general knowledge in large 3D models to maintain good generalization. Specifically, the proposed framework imposes multiple explicit constraints on the prompt learning trajectory by maximizing the mutual agreement between task-specific predictions and task-agnostic knowledge. We design the regulation framework as a plug-and-play module to embed into existing representative large 3D models. Surprisingly, our method not only realizes consistently increasing generalization ability but also enhances task-specific 3D recognition performances across various 3DDG benchmarks by a clear margin. Considering the lack of study and evaluation on 3DDG, we also create three new benchmarks, namely base-to-new, cross-dataset and few-shot generalization benchmarks, to enrich the field and inspire future research.
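
The mutual-agreement idea can be illustrated with a minimal sketch. It assumes the constraint is an L1 distance between features computed with the learnable prompts and features from the frozen pre-trained model (the `l1_dist` value in the usage commands suggests such a distance). The function names and the weighting factor `lam` are hypothetical and are not taken from this repository; the actual constraints are defined in the paper and the source code.

```python
from typing import Sequence


def l1_distance(prompted: Sequence[float], frozen: Sequence[float]) -> float:
    """Mean absolute difference between two equally sized feature vectors."""
    if len(prompted) != len(frozen):
        raise ValueError("feature vectors must have the same length")
    return sum(abs(p - f) for p, f in zip(prompted, frozen)) / len(prompted)


def regulated_loss(task_loss: float,
                   prompted: Sequence[float],
                   frozen: Sequence[float],
                   lam: float = 1.0) -> float:
    """Task loss plus a penalty that keeps prompted features near the frozen ones.

    `lam` is a hypothetical weighting factor chosen for illustration.
    """
    return task_loss + lam * l1_distance(prompted, frozen)


# Toy example: the prompted features drift from the frozen ones in one dimension.
print(regulated_loss(0.5, [1.0, 2.0], [1.0, 4.0], lam=2.0))
```

A larger `lam` in this sketch pulls the prompted features closer to the task-agnostic ones, trading some task-specific fit for generalization.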

Environment

Package Setup

  • dassl
  • Ubuntu 23.10
  • Python 3.8.16
  • PyTorch 1.12.0
  • CUDA 11.6
  • torchvision 0.13.0
  • timm 0.9.16
  • pueue & pueued 2.0.4
  # NOTE: Option 1 is recommended. A complete package list is provided in `env.yaml`.
  # option 1: create the conda virtual env on your own
  conda create -n pointprc python=3.8.16
  conda activate pointprc
  # install torch
  pip install torch==1.12.0+cu116 torchvision==0.13.0+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
  # install dassl
  git clone https://github.com/auniquesun/dassl
  cd dassl/
  python setup.py develop # (no need to re-build if the source code is modified)

  # option 2: create conda virtual env according to the provided env.yaml
  conda env create -f env.yaml
  conda activate pointprc

pueue is a shell command management tool. We use it to schedule the model training and evaluation tasks; please refer to the official page for installation and basic usage. We recommend it because it lets you run experiments at scale and thus saves time.

NOTE: We provide a complete package list of our virtual environment in env.yaml. Feel free to check whether you need a specific package. If so, install it with pip, e.g.

  pip install h5py==3.10.0 plyfile==1.0.3
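
After installation, the package versions listed above can be compared against the active environment. The following is a minimal sketch that uses only the standard library; the helper names `installed_version` and `check_versions` are hypothetical and are not part of this repository.

```python
from importlib import metadata  # in the standard library since Python 3.8

# Versions listed in the Package Setup section of this README.
EXPECTED = {"torch": "1.12.0", "torchvision": "0.13.0", "timm": "0.9.16"}


def installed_version(package):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def check_versions(expected=EXPECTED, lookup=installed_version):
    """Return {package: (expected, found)} for every mismatch."""
    problems = {}
    for package, wanted in expected.items():
        found = lookup(package)
        # Wheels may carry a local suffix such as "+cu116"; compare the public part only.
        if found is None or found.split("+")[0] != wanted:
            problems[package] = (wanted, found)
    return problems


for name, (wanted, found) in check_versions().items():
    print(f"{name}: expected {wanted}, found {found}")
```

The script prints nothing when all three packages match the listed versions.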

Pre-trained Weights

  1. In the experiments, we use the following models as the baselines. The pre-trained weights of these models can be found in their public GitHub repositories.

    • NOTE: ULIP-2 uses the same text encoder as ULIP.
  2. Make a folder called weights under this project and save the pre-trained weights into this folder.

Datasets

  1. We conduct experiments on three new 3D domain generalization (3DDG) benchmarks proposed by us, as introduced in the next section.

    • base-to-new class generalization (base2new)
    • cross-dataset generalization (xset)
    • few-shot generalization (fewshot)
  2. The structure of these benchmarks should be organized as follows.

    /path/to/Point-PRC
    |----data # placed at the same level as `trainers`, `weights`, etc.
        |----base2new
            |----modelnet40
            |----scanobjectnn
            |----shapenetcorev2
        |----xset
            |----corruption
            |----dg
            |----sim2real
            |----pointda
        |----fewshot
            |----modelnet40
            |----scanobjectnn
            |----shapenetcorev2
  3. You can find the usage instructions and download links of these new 3DDG benchmarks in the following section.
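
Before launching experiments, the directory layout shown above can be verified with a short script. This is a minimal sketch: the folder names are copied from the tree, while the helper name `missing_dirs` is hypothetical and not part of this repository.

```python
from pathlib import Path

# Layout copied from the tree above; keys are the benchmark folders under `data`.
EXPECTED_LAYOUT = {
    "base2new": ["modelnet40", "scanobjectnn", "shapenetcorev2"],
    "xset": ["corruption", "dg", "sim2real", "pointda"],
    "fewshot": ["modelnet40", "scanobjectnn", "shapenetcorev2"],
}


def missing_dirs(project_root):
    """Return the expected dataset folders that do not exist under project_root/data."""
    data_root = Path(project_root) / "data"
    missing = []
    for benchmark, subdirs in EXPECTED_LAYOUT.items():
        for sub in subdirs:
            path = data_root / benchmark / sub
            if not path.is_dir():
                missing.append(path)
    return missing


# Run from the project root; prints one line per missing folder.
for path in missing_dirs("."):
    print(f"missing: {path}")
```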

New 3DDG Benchmarks

Base-to-new Class Generalization

  1. The datasets used in this benchmark can be downloaded according to the following links.

  2. The following table shows the statistics of this benchmark.

Cross-dataset Generalization

  1. The datasets used in this benchmark can be downloaded according to the following links.

  2. The following table shows the statistics of this benchmark.

Few-shot Generalization

  1. Although this benchmark contains the same datasets as Base-to-new Class Generalization, it investigates model generalization in an extremely low-data regime (1, 2, 4, 8, and 16 shots), which differs substantially from the evaluation setting of that benchmark.

  2. The following table shows the statistics of this benchmark.
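
The K-shot regime can be illustrated with a generic sampling sketch that draws at most K training samples per class. How this repository actually selects the shots is not stated here, so the function `sample_few_shot` and its `seed` parameter are assumptions made for illustration only.

```python
import random
from collections import defaultdict


def sample_few_shot(labels, num_shots, seed=0):
    """Return sorted indices with at most `num_shots` samples per class."""
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    rng = random.Random(seed)  # fixed seed keeps the subset reproducible
    chosen = []
    for label in sorted(by_class):
        indices = by_class[label]
        # Classes with fewer than `num_shots` samples contribute all of them.
        chosen.extend(rng.sample(indices, min(num_shots, len(indices))))
    return sorted(chosen)


# Toy label list with three classes of sizes 3, 2 and 1.
toy_labels = [0, 0, 0, 1, 1, 2]
print(sample_few_shot(toy_labels, num_shots=2))
```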

Usage

The experiment settings and implementation details are described in the paper; see Section 3.3, Section 4.1, and Appendices A.1 and A.2.

Base-to-new Class Generalization

  1. This part corresponds to the experiments in Section 4.2 (Table 1) and Appendix (Table 9).

  2. To evaluate performance on this benchmark, use scripts/pointprc/base2new_train.sh and scripts/pointprc/base2new_test.sh. The former trains a model on base classes while the latter evaluates the trained model on new classes. Both scripts take 18 input arguments, as commented in the script files.

  3. Taking S-PB_T50_RS (the hardest split of ScanObjectNN) as an example, we train the model on base classes and then evaluate the performance on new classes.

    # prompt learning on base classes
    ./scripts/pointprc/base2new_train.sh 0 data/base2new/scanobjectnn scanobjectnn custom_ulip manual64 9 2 2 16 20 hardest full task_perform False False False ulip2 l1_dist

    # test on novel classes
    ./scripts/pointprc/base2new_test.sh 0 data/base2new/scanobjectnn scanobjectnn custom_ulip manual64 9 2 2 16 20 hardest full task_perform False False False ulip2 l1_dist
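
Because the train and test scripts take the same 18 positional arguments, the argument list can be defined once and reused so the two invocations cannot drift apart. This is a minimal sketch; the helper `build_command` is hypothetical, and the values are copied from the example above without interpreting their meaning.

```python
import shlex

# The 18 positional values from the S-PB_T50_RS example above, in order.
# Their meaning is documented in the comments of the script files.
SHARED_ARGS = (
    "0 data/base2new/scanobjectnn scanobjectnn custom_ulip manual64 "
    "9 2 2 16 20 hardest full task_perform False False False ulip2 l1_dist"
).split()


def build_command(script, args=SHARED_ARGS):
    """Return the argv list for `script`, checking the expected argument count."""
    if len(args) != 18:
        raise ValueError(f"expected 18 arguments, got {len(args)}")
    return [script, *args]


train_cmd = build_command("./scripts/pointprc/base2new_train.sh")
test_cmd = build_command("./scripts/pointprc/base2new_test.sh")

# Print the commands; shlex.join quotes arguments for the shell when needed.
print(shlex.join(train_cmd))
print(shlex.join(test_cmd))
```

To execute instead of print, each list can be passed to `subprocess.run(cmd, check=True)` from the project root.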

Few-shot Generalization

  1. This part corresponds to the experiments in Section 4.4 (Figure 4).

  2. To evaluate performance on this benchmark, use scripts/pointprc/fewshot_train_eval.sh. This script also takes 18 command-line arguments, as explained in the file.

  3. Taking ShapeNetCoreV2 as an example, we train the model using 1, 2, 4, 8 and 16 shots per class and then evaluate the performance on the whole test set of the dataset.

    # train using (1/2/4/8/16)-shot and evaluate on the whole test set
    ./scripts/pointprc/fewshot_train_eval.sh 0 data/fewshot/shapenetcorev2 shapenetcorev2 custom_ulip 9 2 2 1 50 obj_only full task_perform False False False manual64 ulip2 l1_dist
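
To cover all five shot settings, the command above can be generated once per shot count. The sketch below ASSUMES that the eighth positional argument (the value `1` in the example) is the number of shots; this is an inference from the example, not something confirmed here, so check the comments in the script file before relying on it. The helper `commands_for_shots` is hypothetical.

```python
import shlex

SCRIPT = "./scripts/pointprc/fewshot_train_eval.sh"

# Positional values copied from the ShapeNetCoreV2 example above.
BASE_ARGS = (
    "0 data/fewshot/shapenetcorev2 shapenetcorev2 custom_ulip 9 2 2 1 50 "
    "obj_only full task_perform False False False manual64 ulip2 l1_dist"
).split()

# ASSUMPTION: the 8th positional argument (value "1" above) is the shot count.
SHOT_POSITION = 7  # zero-based index into BASE_ARGS


def commands_for_shots(shots=(1, 2, 4, 8, 16)):
    """Return one shell command string per shot count."""
    commands = []
    for num_shots in shots:
        args = list(BASE_ARGS)
        args[SHOT_POSITION] = str(num_shots)
        commands.append(shlex.join([SCRIPT, *args]))
    return commands


for command in commands_for_shots():
    print(command)
```

Each printed line can be run directly or handed to a scheduler such as pueue.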

Cross-dataset Generalization

  1. This part corresponds to the experiments in Section 4.3 (Tables 2 and 3) and the Appendix (Tables 10 and 11).

  2. This benchmark has four types of evaluation settings: OOD Generalization, Data Corruption, Sim-to-Real, and PointDA; see Table 5 in the Appendix of the paper.

  3. To evaluate performance on the OOD Generalization setting of this benchmark, use scripts/pointprc/xset_test_dg.sh. This script accepts 17 command-line arguments, as explained in the file.

    • Note that in this setting, ShapeNetCoreV2 serves as the source domain and the other datasets (ModelNet40, S-PB_T50_RS, S-OBJ_BG, S-OBJ_ONLY, Omni3D) act as the target domains.
    • Since the model has been trained on ShapeNetCoreV2 in the Few-shot Generalization benchmark, when evaluating the performance on a target domain (e.g., Omni3D), we directly load the 16-shot ShapeNetCoreV2 weights.
    # evaluate on the target domain (Omni3D) with the 16-shot ShapeNetCoreV2 weights
    ./scripts/pointprc/xset_test_dg.sh 0 data/xset/omniobject3d omniobject3d shapenetcorev2 12 4 4 50 obj_only full task_perform False False False manual64 ulip2 l1_dist
  4. To evaluate performance on the other three settings, check the following scripts for details.
    • scripts/pointprc/xset_corrupt.sh for Data Corruption
    • scripts/pointprc/xset_train_sim2real.sh and scripts/pointprc/xset_test_sim2real.sh for Sim-to-Real
    • scripts/pointprc/xset_train_pointda.sh and scripts/pointprc/xset_test_pointda.sh for PointDA

Citation

    @inproceedings{sun24point-prc,
        title={Point-PRC: A Prompt Learning Based Regulation Framework for Generalizable Point Cloud Analysis},
        author={Sun, Hongyu and Ke, Qiuhong and Wang, Yongcai and Chen, Wang and Yang, Kang and Li, Deying and Cai, Jianfei},
        booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems (NeurIPS)},
        year={2024},
        url={https://openreview.net/forum?id=g7lYP11Erv}
    }

Acknowledgement

Our implementation is partially inspired by the following projects. Thanks for their great work.

  1. Dassl
  2. ULIP
  3. PointCLIP
  4. PointCLIP V2
  5. PromptSRC

Contact

If you have any questions about our work, please search the related issues or create a new one in this repository.