[BC-Break] The Great Refactor with config files (#45)
* simple registry

* register and run function

* base runners & lane trainer

* lane tester

* seg trainer & tester

* register datasets

* register losses

* lr_schedulers

* lr_schedulers

* optimizers

* torch scheduler

* register transforms

* register models

* remove defs

* ddp from args

* first config

* refactor datasets

* lane starting

* parse args

* debug

* erfnet

* refactor configs

* doc

* optimize args

* first seg config

* losses wrappers

* debug

* city matched

* work dir

* save-dir for final results

* debug

* tensorboard log moved

* lane configs

* doc change

* basic configs [unchecked]

* shell

* debug lstr

* debug gtav/synthia

* val_batch_size

* lstr debug

* default val_bs back to align with seg

* seg complete

* vgg16 finished

* vggs

* conversions

* add resnet50

* add resnet34

* add resnet18

* add resnet101 tusimple

* add rep-vgg a0

* enet erfnet --cfg-options

* lstr resa101

* add resnet101 baseline scnn

* RepVGG-A tusimple

* tools

* fix cfg-options

* add repvgg culane

* multi-GPU tests

* update the result of repvgg-a on culane

* add repvgg-a1 scnn

* add mobilenetv2

* InvertedResidualV3 block

* arch error

* align arch with DeepLab-LargeFOV (resnets)

* bias=False

* add option for reducer

* debug

* update names

* fix a1-culane config files

* mobilenetv2 scnn config

* mobilenet v3 test

* align torchvision

* v2

* refactor

* lane dir

* delete

* bug fixes

* lane video

* v3 fix dilation

* v3 fix padding

* update repvgg config

* seg vis

* refactor common_models to a directory

* move vgg encoder

* remove test_images

* refactor doc & vis debug

* cherry-pick from rc-swin

* repvgg config clean up

* replace

* replace all

* credits

* new video

* advanced doc

* final mods

Co-authored-by: cedricgsh <guoshaohua@sjtu.edu.cn>
voldemortX and cedricgsh authored Jan 10, 2022
1 parent b21cb07 commit b93cc1f
Showing 371 changed files with 11,706 additions and 3,910 deletions.
24 changes: 18 additions & 6 deletions README.md
@@ -1,16 +1,16 @@
# Codebase for deep self-driving perception

*PytorchAutoDrive* is a **pure Python** codebase includes semantic segmentation models, lane detection models based on **PyTorch**. Here we provide full stack supports from research (model training, testing, fair benchmarking) to application (visualization, model deployment). Poster at PyTorch Developer Day: [PytorchAutoDrive: Toolkit & Fair Benchmark for Autonomous Driving Research](https://drive.google.com/file/d/14EgcwPnKvAZJ1aWqBv6W9Msm666Wqi5a/view?usp=sharing).
*PytorchAutoDrive* is a **pure Python** codebase that includes semantic segmentation and lane detection models based on **PyTorch**. Here we provide full-stack support from research (model training, testing, fair benchmarking by simply writing configs) to application (visualization, model deployment). Poster at PyTorch Developer Day: [PytorchAutoDrive: Toolkit & Fair Benchmark for Autonomous Driving Research](https://drive.google.com/file/d/14EgcwPnKvAZJ1aWqBv6W9Msm666Wqi5a/view?usp=sharing).

*This repository is under active development; results with uploaded models are stable. For legacy code users, please check [deprecations](https://github.com/voldemortX/pytorch-auto-drive/issues/14) for changes.*

**A demo video from ERFNet:**

https://user-images.githubusercontent.com/32259501/124389349-3e0ea480-dd19-11eb-8947-cf5e9c95721a.mp4
https://user-images.githubusercontent.com/32259501/148680744-a18793cd-f437-461f-8c3a-b909c9931709.mp4

## Highlights

Various methods tested on a wide range of backbones, **modulated** and **easily understood** codes, image/keypoint loading, transformations and **visualizations**, **mixed precision training**, tensorboard logging and **deployment support** with ONNX and TensorRT.
Various methods on a wide range of backbones, **config**-based implementations, **modular** and **easily understood** code, image/keypoint loading, transformations and **visualizations**, **mixed precision training**, tensorboard logging, and **deployment support** with ONNX and TensorRT.

Models from this repo are faster to train (**single card trainable**) and often have better performance than other implementations; see the [wiki](https://github.com/voldemortX/pytorch-auto-drive/wiki/Notes) for the reasons and technical specifications of the models.

@@ -70,11 +70,17 @@ Get started with [SEGMENTATION.md](docs/SEGMENTATION.md) for semantic segmentation
Refer to [VISUALIZATION.md](docs/VISUALIZATION.md) for a visualization & inference tutorial covering image and video inputs.

## Benchmark Tools

Refer to [BENCHMARK.md](docs/BENCHMARK.md) for a benchmarking tutorial, including FPS test, FLOPs & memory count for each supported model.

## Deployment

Refer to [DEPLOY.md](docs/DEPLOY.md) for ONNX and TensorRT deployment support.

## Advanced Tutorial

Check out [ADVANCED_TUTORIAL.md](docs/ADVANCED_TUTORIAL.md) for advanced use cases and how to code in PytorchAutoDrive.

## Contributing

We welcome **Pull Requests** to fix bugs, update docs, implement new features, etc. We also welcome **Issues** to report problems and needs, or to ask questions (since your question might be more common and helpful to the community than you presume). Interested folks should check out our [roadmap](https://github.com/voldemortX/pytorch-auto-drive/issues/4).
@@ -95,8 +101,6 @@ This repository implements (or plans to implement) the following interesting papers

[RESA: Recurrent Feature-Shift Aggregator for Lane Detection](https://arxiv.org/abs/2008.13719) AAAI 2021

[Learning Lightweight Lane Detection CNNs by Self Attention Distillation](https://arxiv.org/abs/1908.00821) ICCV 2019

[Polynomial Regression Network for Variable-Number Lane Detection](http://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123630698.pdf) ECCV 2020

[End-to-end Lane Shape Prediction with Transformers](https://arxiv.org/abs/2011.04233) WACV 2021
@@ -105,8 +109,16 @@ You are also welcome to make additions to this paper list, or open-source your

## Notes:

1. Cityscapes dataset is down-sampled by 2 when training at 256 x 512, to specify different sizes, modify them in [configs.yaml](configs.yaml); similar changes can be done with other experiments.
1. The Cityscapes dataset is down-sampled by 2 when training at 256 x 512; to specify different sizes, modify them in the config files if needed.

2. Training times are measured on **a single RTX 2080Ti**, including online validation time for segmentation and test time for lane detection.

3. All segmentation results reported are from a single model, without CRF and without multi-scale testing.

## Credits:

PytorchAutoDrive is maintained by Zhengyang Feng ([VoldemortX](https://github.com/voldemortX)) and Shaohua Guo ([cedricgsh](https://github.com/cedricgsh)).

Community contributors (GitHub ID): [kalkun](https://github.com/kalkun)

People who sponsored us with hardware: Junshu Tang ([junshutang](https://github.com/junshutang))
7 changes: 4 additions & 3 deletions autotest_culane.sh
@@ -1,15 +1,16 @@
#!/bin/bash
echo experiment name: $1
echo status: $2
echo save dir: $3

cd tools/culane_evaluation
if [ "$2" = "test" ]; then
# Perform test with official scripts
./eval.sh $1
./eval.sh $1 $3
# Calculate overall F1 score
python cal_total.py --exp-name=$1
python cal_total.py --exp-name=$1 --save-dir=$3
else
# Perform validation with official scripts
./eval_validation.sh $1
./eval_validation.sh $1 $3
fi
cd ../../
3 changes: 2 additions & 1 deletion autotest_llamas.sh
@@ -1,12 +1,13 @@
#!/bin/bash
echo experiment name: $1
echo status: $2
echo save dir: $3
data_dir=../../../../dataset/llamas/labels/valid
cd tools/llamas_evaluation/

if [ "$2" = "val" ]; then
# we can provide the valid set to evaluate models
python evaluate.py --pred_dir=../../output/valid --anno_dir=${data_dir} --exp_name=$1
python evaluate.py --pred_dir=../../output/valid --anno_dir=${data_dir} --exp_name=$1 --save-dir=$3
else
echo "The test set of llamas is not publicly available."
fi
5 changes: 3 additions & 2 deletions autotest_tusimple.sh
@@ -2,12 +2,13 @@
data_dir=../../../../dataset/tusimple/
echo experiment name: $1
echo status: $2
echo save dir: $3

# Perform test/validation with official scripts
cd tools/tusimple_evaluation
if [ "$2" = "test" ]; then
python lane.py ../../output/${1}.json ${data_dir}test_label.json $1
python lane.py ../../output/${1}.json ${data_dir}test_label.json $1 $3
else
python lane.py ../../output/${1}.json ${data_dir}label_data_0531.json $1
python lane.py ../../output/${1}.json ${data_dir}label_data_0531.json $1 $3
fi
cd ../../
159 changes: 0 additions & 159 deletions configs.yaml

This file was deleted.

43 changes: 43 additions & 0 deletions configs/README.md
@@ -0,0 +1,43 @@
## Configs

Config files in *PytorchAutoDrive* (`./configs/`) are used to define models,
how they are trained, tested, visualized, *etc*.

### Registry Mechanism

Different from existing class-based registries, ours can also register functions.
For functions, you write only the static args in your config,
and pass the dynamic ones on-the-fly by:

```
REGISTRY.from_dict(
    <config dict for a function/class>,
    kwarg1=1, kwarg2=2, ...
)
```

Note that each argument must be passed as a keyword (k=v), and kwargs can overwrite entries already present in the config dict.
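
For intuition, here is a minimal sketch of a registry that behaves this way.
It is illustrative only; names and details differ from the actual *PytorchAutoDrive* implementation:

```
# Minimal illustration of the registry idea (not the real implementation).
class Registry:
    def __init__(self):
        self._modules = {}

    def register(self):
        # Used as @REGISTRY.register() on a class or function.
        def _do_register(obj):
            self._modules[obj.__name__] = obj
            return obj
        return _do_register

    def from_dict(self, cfg, **kwargs):
        cfg = dict(cfg)  # avoid mutating the caller's config
        target = self._modules[cfg.pop('name')]
        cfg.update(kwargs)  # dynamic kwargs overwrite static config entries
        return target(**cfg)


REGISTRY = Registry()


@REGISTRY.register()
def scaled_sum(a, b, scale=1.0):
    return scale * (a + b)


# Static args ('scale') live in the config; dynamic ones ('a', 'b') are passed on-the-fly:
print(REGISTRY.from_dict(dict(name='scaled_sum', scale=0.5), a=1, b=3))  # 2.0
```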

### Use An Existing Config

Modify customized options, such as the root of your datasets (in `configs/*/common/_*.py`).

### Write A New Config

Copy the config file most similar to your use case and modify it.
Note that you can simply import config parts from `common` or from other config files; it is just like writing Python.
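
For example, a derived config could look roughly like the sketch below.
The import path is the `enet_culane.py` config added in this commit; the overridden values and experiment name are purely illustrative, and exactly which module-level names a runner expects depends on the task:

```
# Hypothetical derived config: reuse everything from an existing config,
# then override only what differs (values below are illustrative).
from configs.lane_detection.baseline.enet_culane import (  # re-exported for the runner
    dataset, train_augmentation, test_augmentation,
    loss, optimizer, lr_scheduler, model, train, test
)

train = dict(train, exp_name='enet_baseline_culane_bs32', batch_size=32)
test = dict(test, exp_name='enet_baseline_culane_bs32',
            checkpoint='./checkpoints/enet_baseline_culane_bs32/model.pt')
```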

### Register A New Class/Func

Choose the appropriate registry and register your Class/Func by:

```
@REGISTRY.register()
```

Remember that you still need to import this Class/Func somewhere for the registration to take effect.
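
As a minimal sketch (here `REGISTRY` stands for whichever registry you chose above, and the class itself is hypothetical):

```
import torch.nn as nn


@REGISTRY.register()  # REGISTRY: the registry you picked, e.g. the one for models
class TinyLaneHead(nn.Module):  # hypothetical class, for illustration only
    def __init__(self, in_channels=128, num_classes=5):
        super().__init__()
        self.classifier = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x):
        return self.classifier(x)
```

A config could then refer to it as `dict(name='TinyLaneHead', in_channels=128, num_classes=5)`, provided the module defining `TinyLaneHead` is imported somewhere.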

### How To Read The Code

Since you can't just click 'go to definition' in your IDE,
we suggest searching the directory for the `name` of each Class/Function used in the configs.
70 changes: 70 additions & 0 deletions configs/lane_detection/baseline/enet_culane.py
@@ -0,0 +1,70 @@
# Data pipeline
from configs.lane_detection.common.datasets.culane_seg import dataset
from configs.lane_detection.common.datasets.train_level0_288 import train_augmentation
from configs.lane_detection.common.datasets.test_288 import test_augmentation

# Optimization pipeline
from configs.lane_detection.common.optims.segloss_5class import loss
from configs.lane_detection.common.optims.sgd05 import optimizer
from configs.lane_detection.common.optims.ep12_poly_warmup200 import lr_scheduler


train = dict(
    exp_name='enet_baseline_culane',
    workers=10,
    batch_size=20,
    checkpoint=None,
    # Device args
    world_size=0,
    dist_url='env://',
    device='cuda',

    val_num_steps=0,  # Seg IoU validation (mostly useless)
    save_dir='./checkpoints',

    input_size=(288, 800),
    original_size=(590, 1640),
    num_classes=5,
    num_epochs=12,
    collate_fn=None,  # 'dict_collate_fn' for LSTR
    seg=True,  # Seg-based method or not
)

test = dict(
    exp_name='enet_baseline_culane',
    workers=10,
    batch_size=80,
    checkpoint='./checkpoints/enet_baseline_culane/model.pt',
    # Device args
    device='cuda',

    save_dir='./checkpoints',

    seg=True,
    gap=20,
    ppl=18,
    thresh=0.3,
    collate_fn=None,  # 'dict_collate_fn' for LSTR
    input_size=(288, 800),
    original_size=(590, 1640),
    max_lane=4,
    dataset_name='culane'
)

model = dict(
    name='ENet',
    num_classes=5,
    encoder_relu=False,
    decoder_relu=True,
    dropout_1=0.01,
    dropout_2=0.1,
    encoder_only=False,
    pretrained_weights=None,
    lane_classifier_cfg=dict(
        name='EDLaneExist',
        num_output=5 - 1,
        flattened_size=4500,
        dropout=0.1,
        pool='avg'
    )
)