
Image to Sphere: Learning Equivariant Features for Efficient Pose Prediction

Paper | Project Page


This repository implements a hybrid equivariant model for SO(3) reasoning from 2D images for object pose estimation. The underlying SO(3) symmetry of the pose estimation task is not accessible in an image, which can only be transformed by in-plane rotations. Our model, I2S, projects features from the image plane onto the sphere, which is SO(3)-transformable. The model can therefore leverage SO(3)-equivariant group convolutions, which improve sample efficiency. Conveniently, the outputs of the group convolution are coefficients over the Fourier basis of SO(3), which form a concise yet expressive representation of distributions over SO(3). Our model can capture complex pose distributions that arise from occlusions, ambiguity, or object symmetries.
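To make the image-to-sphere step concrete, here is a minimal sketch that assumes a plain orthographic lift of normalized image coordinates onto the upper unit hemisphere. The exact projection and any learned components used by I2S may differ (the Model Walkthrough notebook covers the real construction); the function names below are hypothetical. The last helper shows why in-plane image rotations are the only symmetry directly visible: under this lift, rotating the image corresponds to rotating the sphere about the optical (z) axis.

```python
import math
from typing import List, Tuple

Point3 = Tuple[float, float, float]


def pixel_to_hemisphere(u: float, v: float) -> Point3:
    """Orthographically lift a normalized image coordinate onto the upper unit hemisphere.

    Assumes (u, v) lies in the unit disk; z is chosen so the point has unit norm and z >= 0.
    """
    r2 = u * u + v * v
    if r2 > 1.0:
        raise ValueError("point lies outside the unit disk and has no orthographic preimage")
    return (u, v, math.sqrt(1.0 - r2))


def grid_to_hemisphere(n: int) -> List[Point3]:
    """Lift the centers of an n x n pixel grid, keeping only pixels inside the unit disk."""
    points = []
    for i in range(n):
        for j in range(n):
            # map pixel centers from index space to [-1, 1]
            u = (2 * j + 1) / n - 1
            v = (2 * i + 1) / n - 1
            if u * u + v * v <= 1.0:
                points.append(pixel_to_hemisphere(u, v))
    return points


def rotate_about_z(p: Point3, theta: float) -> Point3:
    """Rotate a 3D point about the optical (z) axis by theta radians."""
    c, s = math.cos(theta), math.sin(theta)
    x, y, z = p
    return (c * x - s * y, s * x + c * y, z)
```

For example, the image center lifts to the pole (0, 0, 1), and lifting a rotated pixel gives the same point as rotating the lifted pixel about z. A general 3D rotation, by contrast, has no counterpart acting on the flat image, which is what the spherical representation adds.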

Table of Contents

  1. Colab Demos
    1. Visualize Predictions
    2. Model Walkthrough
    3. Intro to Spherical Convolution
  2. Installation
  3. Dataset Preparation
  4. Train I2S
    1. ModelNet10-SO(3)
    2. SYMSOL
    3. PASCAL3D+
  5. Pretrained Models
  6. Citation
  7. Acknowledgements

Colab Demos

Visualize Predictions

This Colab notebook loads I2S models pretrained on PASCAL3D+ and ModelNet10-SO(3) and visualizes their output distributions for images from the test set. You can also upload your own images and see what the model predicts.

Model Walkthrough

This Colab notebook goes step-by-step through the construction of I2S, and illustrates how you can modify different components for a custom application.

Intro to Spherical Convolution

This Colab notebook helps you understand spherical harmonics and spherical convolution with some visualizations.
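As a quick companion to the notebook, the sketch below counts coefficients up to a maximum degree lmax. A signal on the sphere has 2l + 1 spherical harmonic coefficients per degree l, while a signal on SO(3) (such as the group-convolution output described above) has a (2l + 1) x (2l + 1) block of Wigner-D coefficients per degree. The value lmax = 6 is an arbitrary example, not necessarily the setting used by I2S.

```python
def num_sphere_coeffs(lmax: int) -> int:
    """Number of spherical harmonic coefficients Y_l^m for degrees l = 0..lmax."""
    # each degree l contributes 2l + 1 orders m = -l..l
    return sum(2 * l + 1 for l in range(lmax + 1))


def num_so3_coeffs(lmax: int) -> int:
    """Number of Wigner-D coefficients D^l_{mn} for degrees l = 0..lmax."""
    # each degree l contributes a (2l + 1) x (2l + 1) block
    return sum((2 * l + 1) ** 2 for l in range(lmax + 1))


if __name__ == "__main__":
    lmax = 6  # hypothetical example value
    print(num_sphere_coeffs(lmax), num_so3_coeffs(lmax))
```

These sums have closed forms, (lmax + 1)^2 on the sphere and (lmax + 1)(2 lmax + 1)(2 lmax + 3) / 3 on SO(3), so a band-limited distribution over SO(3) stays compact while still describing multimodal shapes.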

Installation

This code was tested with Python 3.8. You can install all necessary requirements with pip:

pip install -r requirements.txt

You may see many warnings from e3nn about deprecated functions. To suppress them, run commands as python -W ignore -m src.train ...
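If you call code from this repository inside your own script or notebook rather than through python -m, the standard-library warnings module can silence warnings for a single call, with an effect similar to -W ignore but scoped. This is a minimal sketch; call_quietly and noisy are hypothetical names, and noisy only stands in for a library call that emits a DeprecationWarning.

```python
import warnings
from typing import Any, Callable


def call_quietly(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Call fn with all Python warnings suppressed, restoring the previous filters afterwards."""
    with warnings.catch_warnings():
        # similar in effect to `python -W ignore`, but limited to this call
        warnings.simplefilter("ignore")
        return fn(*args, **kwargs)


def noisy() -> int:
    """Stand-in for a library function that uses a deprecated API."""
    warnings.warn("old API", DeprecationWarning)
    return 42
```

Because catch_warnings restores the filter state on exit, warnings raised elsewhere in your program are still reported.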

Dataset Preparation

Follow the instructions in datasets/README.md, and make sure to run the commands from within the datasets folder.

Train I2S

ModelNet10-SO(3)

python -m src.train --dataset_name=modelnet10 --encoder=resnet50_pretrained --seed=0

Rotation error (in radians) on the test set will be stored in results/modelnet10_resnet50-pretrained_seed0/eval.npy
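The eval.npy file holds per-image rotation errors in radians. This sketch assumes the error is the standard geodesic distance on SO(3), i.e. the angle of the relative rotation, computed as arccos((trace(R_pred^T R_true) - 1) / 2); check src for the exact computation used in this repository. rotation_error and rot_z are hypothetical helpers written only for illustration.

```python
import math
from typing import List, Sequence

Matrix = Sequence[Sequence[float]]


def rotation_error(r_pred: Matrix, r_true: Matrix) -> float:
    """Geodesic angle in radians between two 3x3 rotation matrices."""
    # trace(A^T B) equals the sum of elementwise products A_ij * B_ij
    trace = sum(r_pred[i][j] * r_true[i][j] for i in range(3) for j in range(3))
    cos_angle = (trace - 1.0) / 2.0
    # clamp so floating-point round-off cannot push acos outside its domain
    cos_angle = max(-1.0, min(1.0, cos_angle))
    return math.acos(cos_angle)


def rot_z(theta: float) -> List[List[float]]:
    """Rotation by theta radians about the z axis."""
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]
```

For example, rotation_error(rot_z(0.0), rot_z(math.radians(30))) is math.radians(30), and the error is at most pi for any pair of rotations.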

To train on the limited training set (20 views per instance), run:

python -m src.train --dataset_name=modelnet10-limited --encoder=resnet50_pretrained --seed=0

SYMSOL

Here is an example of training on SYMSOL I with 50k views per instance:

python -m src.train --dataset_name=symsolI-50000 --encoder=resnet50_pretrained --seed=0

Average log likelihood on the test set will be stored in results/symsolI-50000_resnet50-pretrained_seed0/eval_log_likelihood.npy
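The output folders in the SYMSOL and PASCAL3D+ examples follow the pattern results/{dataset_name}_{encoder}_seed{seed}, with underscores in --encoder turned into hyphens. The hypothetical helpers below reproduce that pattern. It is inferred from the examples in this README rather than from src.train, so confirm against the folder that training actually creates.

```python
def results_dir(dataset_name: str, encoder: str, seed: int) -> str:
    """Hypothetical: results folder name inferred from the examples in this README."""
    # e.g. resnet50_pretrained -> resnet50-pretrained
    return f"results/{dataset_name}_{encoder.replace('_', '-')}_seed{seed}"


def eval_file(dataset_name: str, encoder: str, seed: int, metric_file: str = "eval.npy") -> str:
    """Hypothetical: path to an evaluation file such as eval.npy or eval_log_likelihood.npy."""
    return f"{results_dir(dataset_name, encoder, seed)}/{metric_file}"
```

For example, eval_file("symsolI-50000", "resnet50_pretrained", 0, "eval_log_likelihood.npy") gives the SYMSOL path above.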

You can adjust the number of views (e.g., --dataset_name=symsolI-10000 uses 10k views per instance) or train on SYMSOL II objects: --dataset_name=symsolII-50000 trains on sphX, --dataset_name=symsolIII-50000 on cylO, and --dataset_name=symsolIIII-50000 on tetX. We train a single model on all of SYMSOL I, but a separate model for each SYMSOL II object.
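Because SYMSOL I uses one model while each SYMSOL II object gets its own, a full sweep is four training runs. This sketch builds those commands from the dataset names listed above; train_commands and SYMSOL_II_PREFIX are hypothetical names, and it assumes only view counts documented here (10000 and 50000) are valid.

```python
from typing import List

# dataset_name prefix for each SYMSOL II object, as listed in this README
SYMSOL_II_PREFIX = {"sphX": "symsolII", "cylO": "symsolIII", "tetX": "symsolIIII"}


def train_commands(views: int = 50000, encoder: str = "resnet50_pretrained", seed: int = 0) -> List[str]:
    """One command for all of SYMSOL I, plus one per SYMSOL II object."""
    names = [f"symsolI-{views}"] + [f"{prefix}-{views}" for prefix in SYMSOL_II_PREFIX.values()]
    return [
        f"python -m src.train --dataset_name={name} --encoder={encoder} --seed={seed}"
        for name in names
    ]


if __name__ == "__main__":
    for cmd in train_commands(views=10000):
        print(cmd)
```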

PASCAL3D+

python -m src.train --dataset_name=pascal3d-warp-synth --encoder=resnet101_pretrained --seed=0

Rotation error (in radians) on the test set will be stored in results/pascal3d-warp-synth_resnet101-pretrained_seed0/eval.npy
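To turn the stored radians into the kind of numbers reported below (median error in degrees, Acc@15), a minimal sketch follows. It assumes Acc@t means the fraction of test images whose error is strictly below t degrees; summarize_errors is a hypothetical helper. It takes a plain sequence, so an eval.npy file can be passed in after loading it with NumPy and calling .tolist().

```python
import math
import statistics
from typing import Dict, Sequence


def summarize_errors(errors_rad: Sequence[float], thresholds_deg: Sequence[float] = (15.0, 30.0)) -> Dict[str, float]:
    """Median rotation error in degrees plus Acc@t for each threshold t (in degrees)."""
    if len(errors_rad) == 0:
        raise ValueError("no errors to summarize")
    errors_deg = [math.degrees(e) for e in errors_rad]
    summary = {"median_deg": statistics.median(errors_deg)}
    for t in thresholds_deg:
        # Acc@t: fraction of images whose error is below t degrees
        summary[f"acc@{t:g}"] = sum(e < t for e in errors_deg) / len(errors_deg)
    return summary
```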

Pretrained Models

PASCAL3D+

Download the checkpoint here. It achieves a median rotation error of 9.6 degrees, averaged over all twelve classes, and was trained using the procedure outlined in the paper.

Use the following code to load the weights:
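import torch
from src.predictor import I2S

model = I2S(num_classes=12, encoder='resnet101')
# map_location lets the checkpoint load on a CPU-only machine
checkpoint = torch.load('pascal3d_checkpoint.pt', map_location='cpu')['model_state_dict']
model.load_state_dict(checkpoint)
model.eval()  # inference mode: disables dropout, uses running batch-norm statistics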
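A score "averaged over all classes" can differ noticeably from one pooled over all test images. This sketch assumes the per-class reading: take the median error within each class, then the unweighted mean of those medians. The (class, error) records are a hypothetical input format, since eval.npy as described above does not necessarily carry class labels.

```python
import statistics
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple


def class_averaged_median(records: Iterable[Tuple[str, float]]) -> float:
    """Median error within each class, then the unweighted mean over classes."""
    by_class: Dict[str, List[float]] = defaultdict(list)
    for cls, err in records:
        by_class[cls].append(err)
    per_class = [statistics.median(errs) for errs in by_class.values()]
    return sum(per_class) / len(per_class)
```

With errors 1, 3, 5 for one class and 10, 20 for another, the class-averaged median is (3 + 15) / 2 = 9, whereas the pooled median of all five errors is 5, so each class counts equally regardless of how many test images it has.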

ModelNet10-SO(3)

Download the checkpoint here. It achieves an Acc@15 of 0.721, averaged over all 10 classes. It was trained using the procedure outlined in the paper (100 training views).

Use the following code to load the weights:
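import torch
from src.predictor import I2S

model = I2S(num_classes=10, encoder='resnet50')
# map_location lets the checkpoint load on a CPU-only machine
checkpoint = torch.load('modelnet10so3_checkpoint.pt', map_location='cpu')['model_state_dict']
model.load_state_dict(checkpoint)
model.eval()  # inference mode: disables dropout, uses running batch-norm statistics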

Citation

To cite this work, please use the following BibTeX:

@inproceedings{
	klee2023image2sphere,
	title={Image to Sphere: Learning Equivariant Features for Efficient Pose Prediction},
	author={David M. Klee and Ondrej Biza and Robert Platt and Robin Walters},
	booktitle={International Conference on Learning Representations},
	year={2023},
	url={https://openreview.net/forum?id=_2bDpAtr7PI}
}

Acknowledgements

The code for loading and warping PASCAL3D+ images is taken from this repo. The code for generating healpy grids and visualizing distributions over SO(3) is taken from this repo.