add simclr config #88

Open
wants to merge 4 commits into main
72 changes: 72 additions & 0 deletions configs/simclr/simclr_clas_r50.yaml
@@ -0,0 +1,72 @@
epochs: 90
output_dir: output_dir

model:
  name: Classification
  backbone:
    name: ResNet
    depth: 50
    with_pool: true
    frozen_stages: 4
  head:
    name: ClasHead
    with_avg_pool: false
    in_channels: 4096

dataloader:
  train:
    num_workers: 8
    sampler:
      batch_size: 128
      shuffle: true
      drop_last: False
    dataset:
      name: ImageNet
      dataroot: data/ILSVRC2012/train
      return_label: True
      transforms:
        - name: RandomResizedCrop
          size: 224
        - name: RandomHorizontalFlip
        - name: Transpose
        - name: NormalizeImage
          scale: 1.0/255.0
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
  val:
    num_workers: 8
    sampler:
      batch_size: 128
      shuffle: false
      drop_last: false
    dataset:
      name: ImageNet
      dataroot: data/ILSVRC2012/val
      return_label: True
      transforms:
        - name: Resize
          size: 256
        - name: CenterCrop
          size: 224
        - name: Transpose
        - name: NormalizeImage
          scale: 1.0/255.0
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]


lr_scheduler:
  name: MultiStepDecay
  learning_rate: 1.6
  milestones: [60, 80]

optimizer:
  name: Momentum
  weight_decay: 0.0

log_config:
  name: LogHook
  interval: 50

custom_config:
  - name: EvaluateHook
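The config above is a linear-evaluation setup: frozen_stages: 4 freezes the whole ResNet-50 backbone and only the linear ClasHead is trained for 90 epochs. As a minimal sketch (not part of this PR, and assuming PASSL maps MultiStepDecay onto Paddle's built-in scheduler with its default decay factor of 0.1), the learning-rate schedule the lr_scheduler block describes looks like this:

```
import paddle

# lr = 1.6 for epochs 0-59, 0.16 for epochs 60-79, 0.016 for epochs 80-89
# (assumes the default gamma=0.1; the config does not set it explicitly).
scheduler = paddle.optimizer.lr.MultiStepDecay(
    learning_rate=1.6, milestones=[60, 80], gamma=0.1)

for epoch in range(90):
    # ... one epoch of training the linear head on frozen backbone features ...
    scheduler.step()  # stepped once per epoch, matching the milestone units above
```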
6 changes: 3 additions & 3 deletions docs/Train_SimCLR_model.md
@@ -41,7 +41,7 @@ python -m paddle.distributed.launch --gpus="0,1,2,3,4,5,6,7" tools/train.py -c c
##### Cifar10

```
-python -m paddle.distributed.launch --gpus="0,1,2,3,4,5,6,7" tools/train.py -c configs/simclr/simclr_r50_IM.yaml
+python -m paddle.distributed.launch --gpus="0,1,2,3,4,5,6,7" tools/train.py -c configs/simclr/simclr_r18_cifar10.yaml
```


@@ -68,12 +68,12 @@ python tools/passl2ppclas/convert.py --type res50 --checkpoint ${CHECKPOINT} --o

#### Train:
```
-python -m paddle.distributed.launch --gpus="0,1,2,3,4,5,6,7" tools/train.py -c configs/moco/moco_clas_r50.yaml --pretrained ${WEIGHT_FILE}
+python -m paddle.distributed.launch --gpus="0,1,2,3,4,5,6,7" tools/train.py -c configs/simclr/simclr_clas_r50.yaml --pretrained ${WEIGHT_FILE}
```

Contributor: The model link in the table is missing a trailing "s", which prevents the download. The correct link is:
https://passl.bj.bcebos.com/models/simclr_r50_ep100_ckpt.pdparams

#### Evaluate:
```
-python -m paddle.distributed.launch --gpus="0,1,2,3,4,5,6,7" tools/train.py -c configs/moco/moco_clas_r50.yaml --load ${CLS_WEGHT_FILE} --evaluate-only
+python -m paddle.distributed.launch --gpus="0,1,2,3,4,5,6,7" tools/train.py -c configs/simclr/simclr_clas_r50.yaml --load ${CLS_WEGHT_FILE} --evaluate-only
```

The trained linear weights in conjunction with the backbone weights can be found at [SimCLR linear](https://passl.bj.bcebos.com/models/simclr_r50_ep100_ckpt.pdparams).
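As a quick local sanity check (not part of the docs change), the downloaded .pdparams file can be opened with plain Paddle before it is passed to --pretrained or --load; the exact key layout depends on how PASSL saved the checkpoint, so this only inspects it:

```
import paddle

# Example path; point this at the downloaded checkpoint.
ckpt = paddle.load("simclr_r50_ep100_ckpt.pdparams")

# PASSL checkpoints are dict-like, but the exact nesting (a flat parameter dict
# vs. an entry such as "state_dict") depends on how they were saved, so just
# list the top-level keys to confirm the download is intact.
print(list(ckpt.keys())[:10])
```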
6 changes: 4 additions & 2 deletions passl/modeling/backbones/resnetsimclr.py
@@ -14,8 +14,10 @@

import paddle
import paddle.nn as nn
-from .resnetcifar import ResNet
-from .resnetcifar import BasicBlock, BottleneckBlock
+# from .resnetcifar import ResNet
+# from .resnetcifar import BasicBlock, BottleneckBlock
+from .resnetimagenet import ResNet
+from .resnetimagenet import BasicBlock, BottleneckBlock

from .builder import BACKBONES
from ...modules import init, freeze
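For context (not part of the diff): the practical difference between resnetcifar and resnetimagenet is mainly the stem. CIFAR-style ResNets typically start with a single 3x3, stride-1 convolution and no max-pool so that 32x32 inputs are not downsampled too early, while ImageNet-style ResNets use a 7x7, stride-2 stem followed by max-pooling. A generic Paddle sketch of the two stems, independent of PASSL's actual classes:

```
import paddle.nn as nn

# ImageNet-style stem: 7x7 stride-2 conv + 3x3 stride-2 max-pool (for 224x224 inputs).
imagenet_stem = nn.Sequential(
    nn.Conv2D(3, 64, kernel_size=7, stride=2, padding=3, bias_attr=False),
    nn.BatchNorm2D(64),
    nn.ReLU(),
    nn.MaxPool2D(kernel_size=3, stride=2, padding=1),
)

# CIFAR-style stem: a single 3x3 stride-1 conv, keeping the 32x32 resolution.
cifar_stem = nn.Sequential(
    nn.Conv2D(3, 64, kernel_size=3, stride=1, padding=1, bias_attr=False),
    nn.BatchNorm2D(64),
    nn.ReLU(),
)
```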