Code for "End-to-End Adaptive Sampling and Representation for Event-based Detection with Recurrent Spiking Neural Networks", ECCV 2024

Windere/EAS-SNN

EAS-SNN: End-to-End Adaptive Sampling and Representation for Event-based Detection with Recurrent Spiking Neural Networks

This is the official PyTorch implementation of the ECCV 2024 paper: EAS-SNN: End-to-End Adaptive Sampling and Representation for Event-based Detection with Recurrent Spiking Neural Networks

Summary: In this study, we discover that the neural dynamics of spiking neurons align closely with the behavior of an ideal temporal event sampler. Motivated by this, we propose a novel adaptive sampling module that leverages recurrent convolutional SNNs enhanced with temporal memory, facilitating a fully end-to-end learnable framework for event-based detection. Additionally, we introduce Residual Potential Dropout (RPD) and Spike-Aware Training (SAT) to regulate potential distribution and address performance degradation encountered in spike-based sampling modules.
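
The idea that a spiking neuron behaves like a temporal event sampler can be illustrated with a minimal sketch, assuming a single leaky integrate-and-fire (LIF) neuron that integrates per-bin event counts and "samples" (spikes) whenever its membrane potential crosses a threshold. The function name `lif_sample_times`, the `decay` factor and the hard reset are illustrative choices, not the EAS-SNN sampling module, which is a learnable recurrent convolutional SNN.

```python
from typing import List


def lif_sample_times(event_counts: List[float], decay: float = 0.9,
                     threshold: float = 1.0) -> List[int]:
    """Return the time-bin indices at which a LIF neuron fires.

    event_counts[t] is the (hypothetical) amount of event activity in time
    bin t. The membrane potential leaks by `decay` every step, integrates the
    incoming events, and is hard-reset to 0 after each spike, so dense bursts
    of events trigger samples sooner than sparse activity does.
    """
    v = 0.0
    spikes: List[int] = []
    for t, x in enumerate(event_counts):
        v = decay * v + x      # leaky integration of incoming events
        if v >= threshold:     # threshold crossing: take a sample here
            spikes.append(t)
            v = 0.0            # hard reset after firing
    return spikes


if __name__ == "__main__":
    print(lif_sample_times([0.6] * 4))  # dense activity  -> [1, 3]
    print(lif_sample_times([0.2] * 8))  # sparse activity -> [6]
```

The dense stream is sampled every second bin while the sparse one needs seven bins to reach the threshold, which is the adaptive, activity-dependent sampling behaviour the summary refers to.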

Installation

The main dependencies are listed below:

Dependency      Version
------------    ----------
spikingjelly    0.0.0.0.14
h5py            3.8.0
torchvision     0.16.1
thop            0.1.1
pytorch         2.1.1
pycocotools     2.0.6
opencv          4.7.0
numpy           1.26.0
einops          0.8.0
python          3.10.9
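
To check an existing environment against the table, a small script along these lines may help. It is a sketch that assumes the PyPI distribution names `torch` (for pytorch) and `opencv-python` (for opencv); if you installed `opencv-python-headless` or used conda packages, adjust the keys. Versions are compared component by component, ignoring local build tags such as `+cu118`.

```python
import sys
from importlib import metadata
from typing import Callable, Dict, Optional

# Versions from the dependency table. Keys are ASSUMED PyPI distribution names.
EXPECTED: Dict[str, str] = {
    "spikingjelly": "0.0.0.0.14",
    "h5py": "3.8.0",
    "torchvision": "0.16.1",
    "thop": "0.1.1",
    "torch": "2.1.1",
    "pycocotools": "2.0.6",
    "opencv-python": "4.7.0",
    "numpy": "1.26.0",
    "einops": "0.8.0",
}


def installed_version(dist: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if absent."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


def matches(have: str, want: str) -> bool:
    """Compare dot-separated components, ignoring local tags like '+cu118'.

    Only as many components as the table lists are compared, so
    '4.7.0.72' matches '4.7.0' but '0.16.10' does not match '0.16.1'.
    """
    have_parts = have.split("+")[0].split(".")
    want_parts = want.split(".")
    return have_parts[: len(want_parts)] == want_parts


def check_versions(expected: Dict[str, str],
                   get_version: Callable[[str], Optional[str]] = installed_version
                   ) -> Dict[str, str]:
    """Return {distribution: problem} for every dependency that does not match."""
    problems: Dict[str, str] = {}
    for dist, want in expected.items():
        have = get_version(dist)
        if have is None:
            problems[dist] = "missing"
        elif not matches(have, want):
            problems[dist] = f"found {have}, expected {want}"
    return problems


if __name__ == "__main__":
    print("python", sys.version.split()[0], "(table lists 3.10.9)")
    for dist, status in check_versions(EXPECTED).items():
        print(f"{dist}: {status}")
```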

You can try to install the required packages by running:

conda env create -f conda-env.yml

or

pip install -r pip-requirements.txt

Required Data

  1. The raw GEN-1 dataset can be downloaded from here

  2. The raw 1Mpx dataset can be downloaded from here

  3. The preprocessed 1Mpx dataset by RVT can be downloaded from here

  4. The raw N-Caltech 101 dataset can be downloaded from here

After unzipping the datasets, you should have the following directory structures:

    # The split N-Caltech101 dataset
    ├── N-Caltech
    │   ├── Caltech101
    │   ├── Caltech101_annotations
    │   ├── test.txt
    │   ├── train.txt
    │   └── val.txt
    
    # The raw 1Mpx/GEN-1 dataset
    ├── Root Directory
    │   ├── Raw Split Dataset
    │   │   ├── train
    │   │   │   ├── EVENT_STREAM_NAME_td.dat
    │   │   │   ├── EVENT_STREAM_NAME_bbox.npy
    │   │   │   └── ...
    │   │   ├── val
    │   │   │   ├── EVENT_STREAM_NAME_td.dat
    │   │   │   ├── EVENT_STREAM_NAME_bbox.npy
    │   │   │   └── ...
    │   │   ├── test
    │   │   │   ├── EVENT_STREAM_NAME_td.dat
    │   │   │   ├── EVENT_STREAM_NAME_bbox.npy
    │   │   │   └── ...
    
    # The processed 1Mpx Dataset
    ├── Root Directory
    │   ├── train
    │   │   ├── EVENT_STREAM_NAME
    │   │   │   ├── event_representations_v2
    │   │   │   │   ├── stacked_histogram_dt=50_nbins=10   
    │   │   │   │   │   ├── event_representations_ds2_nearest.h5  
    │   │   │   │   │   ├── objframe_idx_2_repr_idx.npy 
    │   │   │   │   │   ├── timestamps_us.npy 
    │   │   │   ├── labels_v2
    │   │   │   │   ├── labels.npz   
    │   │   │   │   ├── timestamps_us.npy 
    │   │   │   │   ├──  ...
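
Before training on the raw GEN-1 or 1Mpx data, it can be worth checking that every event stream has its label file. The sketch below assumes the raw layout above, where each `NAME_td.dat` sits next to a `NAME_bbox.npy` in the same `train`/`val`/`test` folder; `find_unpaired` is a hypothetical helper, not part of this repository.

```python
from pathlib import Path
from typing import Dict, List

SPLITS = ("train", "val", "test")


def find_unpaired(raw_root: str) -> Dict[str, List[str]]:
    """List event streams whose _td.dat or _bbox.npy partner is missing.

    Assumes <raw_root>/<split>/NAME_td.dat next to <raw_root>/<split>/NAME_bbox.npy.
    Returns {split: [problem descriptions]}; an empty dict means all pairs exist.
    """
    problems: Dict[str, List[str]] = {}
    for split in SPLITS:
        split_dir = Path(raw_root) / split
        if not split_dir.is_dir():
            problems[split] = ["split directory not found"]
            continue
        # Strip the suffixes to recover the shared EVENT_STREAM_NAME.
        td = {p.name[: -len("_td.dat")] for p in split_dir.glob("*_td.dat")}
        bbox = {p.name[: -len("_bbox.npy")] for p in split_dir.glob("*_bbox.npy")}
        issues = [f"{n}: missing _bbox.npy" for n in sorted(td - bbox)]
        issues += [f"{n}: missing _td.dat" for n in sorted(bbox - td)]
        if issues:
            problems[split] = issues
    return problems


if __name__ == "__main__":
    # Hypothetical path: point this at the folder that contains train/val/test.
    print(find_unpaired("/path/to/raw_split_dataset"))
```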

Usage

  1. First, install all required packages and cd to the 'tools' directory.

  2. Run one of the following commands to train an EAS-SNN model on the GEN-1 dataset:

    • Commands for training spiking YOLOX-S (non-spiking FPN + non-spiking HEAD) on the GEN-1 dataset

      CUDA_VISIBLE_DEVICES=0,1,2,3  python train_event.py -n e-yolox-s -d 4 -b 64 \
      -expn exp_name  max_epoch 30 data_num_workers 4 T 3 eval_interval 10 embedding arsnn \
      basic_lr_per_img 0.000015625 seed 80 data_name gen1 data_dir /data2/wzm/dataset/GEN1/raw/ \
      num_classes 2 scheduler fixed spike_attach True  thresh 1 readout sum embedding_depth 2 \
      embedding_ksize 5 write_zero True  use_spike True spike_fn atan
    • Commands for training spiking YOLOX-S (spiking FPN + HEAD) on the GEN-1 dataset

      CUDA_VISIBLE_DEVICES=0,1,2,3  python train_event.py -n e-yolox-s -d 4 -b 58 \
      -expn exp_name  max_epoch 30 data_num_workers 4 T 3 eval_interval 10 embedding arsnn \
      basic_lr_per_img 0.00001724 seed 80 data_name gen1 data_dir /data2/wzm/dataset/GEN1/raw/ \
      num_classes 2 scheduler fixed spike_attach True  thresh 1 readout sum embedding_depth 2 \
      embedding_ksize 5 write_zero True  use_spike full_spike spike_fn atan
    • Commands for training spiking YOLOX-S (spiking FPN + spiking HEAD) on the GEN-1 dataset

      CUDA_VISIBLE_DEVICES=0,1,2,3  python train_event.py -n e-yolox-s -d 4 -b 54  -expn exp_name \
      max_epoch 30 data_num_workers 4 T 3 eval_interval 10 embedding arsnn basic_lr_per_img 0.00001851 \
      seed 80 data_name gen1 data_dir /data2/wzm/dataset/GEN1/raw/ num_classes 2 scheduler fixed spike_attach True \
      thresh 1 readout sum embedding_depth 2 embedding_ksize 5 write_zero True  use_spike full_spike_v2 spike_fn atan
  3. Run the following command to train an EAS-SNN model on the N-Caltech dataset:

    CUDA_VISIBLE_DEVICES=0,1,2,3  python train_event.py -n e-yolox-m -d 4 -b 32 -expn exp_name \
    max_epoch 60 data_num_workers 2 eval_interval 10 embedding arsnn basic_lr_per_img  0.000009375 \
    seed 80 data_dir /data2/wzm/dataset/N-Caltech/ no_aug_epochs 0 Tm 4 T 3 scheduler fixed \
    spike_attach True  write_zero True readout sum use_spike full_spike_v2  window 0  spike_fn atan alpha 1.5
  4. Run the following command to evaluate an EAS-SNN model on the GEN-1 dataset:

    python eval_event.py -n e-yolox-m -d 4 -b 36 -c ./YOLOX_outputs/$exp_name/best_ckpt.pth --conf 0.001 \
    --eval_proh data_num_workers 4 embedding arsnn seed 80 data_name gen1 data_dir /data2/wzm/dataset/GEN1/raw/ \
    num_classes 2 Tm 4 T 3 spike_attach True thresh 1 readout sum embedding_depth 2 embedding_ksize 5 \
    write_zero True use_spike full_spike spike_fn atan
  5. The hyperparameter Ts can be modified to explore the temporal modelling capacity of SNNs, as shown in Fig. 4 of the paper.
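
The trailing `key value` pairs in the commands above (e.g. `max_epoch 30`, `spike_attach True`) look like YOLOX-style overrides of attributes on an experiment object. The sketch below is a simplified, hypothetical re-implementation of that idea, not this repository's code: the `Exp` fields are a made-up subset, and unlike YOLOX it rejects unknown keys instead of ignoring them. It also assumes, as in YOLOX, that the learning rate is `basic_lr_per_img` times the total batch size given with `-b`. Under that assumption the three GEN-1 commands (batch sizes 64, 58 and 54) all land at roughly 0.001, and the N-Caltech command at 0.0003.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Exp:
    # Hypothetical subset of experiment settings, for illustration only.
    max_epoch: int = 300
    T: int = 3
    basic_lr_per_img: float = 0.01 / 64.0
    spike_attach: bool = False
    spike_fn: str = "sigmoid"


def merge_opts(exp: Exp, opts: List[str]) -> Exp:
    """Apply trailing `key value` pairs, casting each value to the field's type."""
    if len(opts) % 2 != 0:
        raise ValueError("options must come in key/value pairs")
    for key, raw in zip(opts[0::2], opts[1::2]):
        if not hasattr(exp, key):
            raise KeyError(f"unknown option: {key}")
        current = getattr(exp, key)
        if isinstance(current, bool):  # check bool before int: bool is an int subclass
            if raw not in ("True", "False"):
                raise ValueError(f"{key} expects True or False, got {raw}")
            value = raw == "True"      # bool("False") would be True, so parse explicitly
        elif isinstance(current, (int, float)):
            value = type(current)(raw)  # e.g. int("30"), float("0.000015625")
        else:
            value = raw
        setattr(exp, key, value)
    return exp


def effective_lr(exp: Exp, batch_size: int) -> float:
    # ASSUMED YOLOX-style linear scaling with the total batch size (-b).
    return exp.basic_lr_per_img * batch_size


if __name__ == "__main__":
    exp = merge_opts(Exp(), ["max_epoch", "30", "basic_lr_per_img", "0.000015625",
                             "spike_attach", "True", "spike_fn", "atan"])
    print(exp.max_epoch, exp.spike_attach, exp.spike_fn)  # -> 30 True atan
    print(round(effective_lr(exp, 64), 6))                # -> 0.001
    # Batch size and basic_lr_per_img of the four training commands above.
    for b, per_img in [(64, 0.000015625), (58, 0.00001724),
                       (54, 0.00001851), (32, 0.000009375)]:
        print(b, per_img * b)
```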

Citation Info

If you find this work helpful, please consider citing our paper:

@article{wang2024eas,
  title={EAS-SNN: End-to-End Adaptive Sampling and Representation for Event-based Detection with Recurrent Spiking Neural Networks},
  author={Wang, Ziming and Wang, Ziling and Li, Huaning and Qin, Lang and Jiang, Runhao and Ma, De and Tang, Huajin},
  journal={arXiv preprint arXiv:2403.12574},
  year={2024}
}

Acknowledgement

This project has adapted code from the following libraries:

  • YOLOX for the detection PAFPN/head
  • Tonic for event-based representations, such as the voxel grid
  • RVT for the 1Mpx preprocessed dataset and the Prophesee evaluation tool
  • ASGL and SpikingJelly for the SNN implementation and event visualization
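
For readers unfamiliar with voxel grids, here is a minimal pure-Python sketch, assuming events given as `(x, y, t, p)` tuples with polarity `p` in {0, 1} and the common scheme that splits each event linearly between its two nearest time bins. It is an independent illustration of the representation, not Tonic's implementation and not the representation code used in this repository.

```python
from typing import List, Sequence, Tuple

Event = Tuple[int, int, float, int]  # (x, y, t, polarity in {0, 1})


def voxel_grid(events: Sequence[Event], width: int, height: int,
               n_bins: int) -> List[List[List[float]]]:
    """Accumulate events into an n_bins x height x width grid.

    Timestamps are normalised to [0, n_bins - 1]; each event contributes +1
    (p = 1) or -1 (p = 0), split between the two nearest time bins in
    proportion to its distance from each, so the grid is linear in time.
    """
    grid = [[[0.0] * width for _ in range(height)] for _ in range(n_bins)]
    if not events:
        return grid
    t0 = min(e[2] for e in events)
    t1 = max(e[2] for e in events)
    span = (t1 - t0) or 1.0              # avoid dividing by zero for one timestamp
    for x, y, t, p in events:
        tn = (t - t0) / span * (n_bins - 1)  # continuous bin coordinate
        lower = int(tn)
        frac = tn - lower
        value = 1.0 if p > 0 else -1.0
        grid[lower][y][x] += value * (1.0 - frac)
        if lower + 1 < n_bins:
            grid[lower + 1][y][x] += value * frac
    return grid


if __name__ == "__main__":
    evs = [(0, 0, 0.0, 1), (0, 0, 4.0, 1), (0, 0, 1.0, 1)]
    print(voxel_grid(evs, width=1, height=1, n_bins=3))  # -> [[[1.5]], [[0.5]], [[1.0]]]
```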
