ResNet-50 training #222

Open: wants to merge 8 commits into base: main
30 changes: 21 additions & 9 deletions projects/classification/config.py
@@ -20,7 +20,7 @@
# Path to dataset, could be overwritten by command line argument
_C.DATA.DATA_PATH = ""
# Dataset name
-_C.DATA.DATASET = "cifar100"
+_C.DATA.DATASET = "imagenet"
# Input image size
_C.DATA.IMG_SIZE = 224
# Interpolation to resize image (random, bilinear, bicubic)
@@ -40,7 +40,7 @@
# -----------------------------------------------------------------------------
_C.MODEL = CN()
# Model arch
-_C.MODEL.ARCH = "swin_tiny_patch4_window7_224"
+_C.MODEL.ARCH = "resnet50"
Review comment (Contributor):
The config here can be overridden by an external .yaml config. As I recall, there are reference examples in the /configs folder.
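
The override pattern described in this comment can be sketched as follows. This is a minimal, stdlib-only illustration that uses plain dictionaries in place of the project's `CN` config nodes; the helper name `merge_overrides` and the two dictionaries are hypothetical. If the project uses yacs (which the `CN()` usage suggests but this diff does not confirm), the corresponding call would be `merge_from_file` on the config node.

```python
import copy

# Defaults as they might remain in config.py (values from the pre-PR side of the diff).
DEFAULTS = {
    "DATA": {"DATASET": "cifar100", "IMG_SIZE": 224},
    "MODEL": {"ARCH": "swin_tiny_patch4_window7_224"},
}


def merge_overrides(base, overrides, path=""):
    """Return a copy of `base` with `overrides` applied recursively.

    Unknown keys raise KeyError so that a typo in an override file
    fails loudly instead of being silently ignored.
    """
    merged = copy.deepcopy(base)
    for key, value in overrides.items():
        full_key = f"{path}.{key}" if path else key
        if key not in merged:
            raise KeyError(f"Non-existent config key: {full_key}")
        if isinstance(merged[key], dict) and isinstance(value, dict):
            merged[key] = merge_overrides(merged[key], value, full_key)
        else:
            merged[key] = value
    return merged


# Overrides as they would look after parsing an external .yaml file.
RESNET50_OVERRIDES = {
    "DATA": {"DATASET": "imagenet"},
    "MODEL": {"ARCH": "resnet50"},
}

config = merge_overrides(DEFAULTS, RESNET50_OVERRIDES)
print(config["MODEL"]["ARCH"], config["DATA"]["DATASET"], config["DATA"]["IMG_SIZE"])
```

With this approach the defaults in `config.py` stay untouched for other models, and each experiment carries only the keys it changes.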

# Pretrained weight from checkpoint
_C.MODEL.PRETRAINED = False
# Path to a specific weights to load, e.g., "./checkpoints/swin_tiny_pretrained_model"
@@ -90,13 +90,15 @@

# Optimizer
_C.TRAIN.OPTIMIZER = CN()
-_C.TRAIN.OPTIMIZER.NAME = "adamw"
-# Optimizer Epsilon
-_C.TRAIN.OPTIMIZER.EPS = 1e-8
-# Optimizer Betas
-_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
+_C.TRAIN.OPTIMIZER.NAME = "sgd"
+# # Optimizer Epsilon
+# _C.TRAIN.OPTIMIZER.EPS = 1e-8
Review comment (Contributor):
There is no need to comment this part out. Take a closer look at the build_optimizer function in optimizer.py; it already performs some checks on these settings.
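
The point of this comment seems to be that defaults for several optimizers can coexist in the config, because the builder reads only the keys relevant to the selected one. The actual `build_optimizer` in `optimizer.py` is not shown in this diff, so the following is a hypothetical sketch of that kind of dispatch; the function name `select_optimizer_kwargs` and the dictionary layout are assumptions.

```python
def select_optimizer_kwargs(opt_cfg, base_lr, weight_decay):
    """Pick only the hyperparameters that the chosen optimizer understands."""
    name = opt_cfg["NAME"].lower()
    common = {"lr": base_lr, "weight_decay": weight_decay}
    if name == "sgd":
        # EPS and BETAS may be present in the config but are not read here.
        return {**common, "momentum": opt_cfg["MOMENTUM"], "nesterov": opt_cfg["NESTEROV"]}
    if name == "adamw":
        # MOMENTUM and NESTEROV are ignored for AdamW.
        return {**common, "eps": opt_cfg["EPS"], "betas": opt_cfg["BETAS"]}
    raise ValueError(f"Unsupported optimizer: {name}")


# All defaults stay defined; only NAME changes between experiments.
optimizer_cfg = {
    "NAME": "sgd",
    "EPS": 1e-8,
    "BETAS": (0.9, 0.999),
    "MOMENTUM": 0.9,
    "NESTEROV": True,
}

sgd_kwargs = select_optimizer_kwargs(optimizer_cfg, base_lr=0.01, weight_decay=2.0e-05)
print(sgd_kwargs)
```

Under this assumption, switching `NAME` to `"sgd"` leaves the `EPS` and `BETAS` defaults harmless, so they can remain in `config.py` for configs that still use AdamW.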

+# # Optimizer Betas
+# _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
# SGD momentum
_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
+# # NESTEROV
+_C.TRAIN.OPTIMIZER.NESTEROV = True

# -----------------------------------------------------------------------------
# Augmentation settings
@@ -110,12 +112,22 @@
_C.AUG.REPROB = 0.25
# Random erase mode
_C.AUG.REMODE = "pixel"
+# Scale
+_C.AUG.SCALE = [0.08, 1.0]
Review comment (Contributor):
For these settings, I would still recommend overriding them through an external .yaml file.

+# Ratio
+_C.RATIO = [0.75, 1.0+1/3]
+# Hflip
+_C.HFLIP = 0.5
+# Vflip
+_C.VFLIP = 0.0
+# Interpolation
+_C.INTERPLOATION = 'random'
# Random erase count
_C.AUG.RECOUNT = 1
# Mixup alpha, mixup enabled if > 0
-_C.AUG.MIXUP = 0.8
+_C.AUG.MIXUP = 0.0
# Cutmix alpha, cutmix enabled if > 0
-_C.AUG.CUTMIX = 1.0
+_C.AUG.CUTMIX = 0.0
# Cutmix min/max ratio, overrides alpha and enables cutmix if set
_C.AUG.CUTMIX_MINMAX = None
# Probability of performing mixup or cutmix when either/both is enabled
58 changes: 58 additions & 0 deletions projects/classification/configs/resnet50_default_settings.yaml
@@ -0,0 +1,58 @@
DATA:
  BATCH_SIZE: 256
  DATASET: imagenet
  DATA_PATH: /home/ubuntu/work/oneflow/datasets
  IMG_SIZE: 224
  INTERPOLATION: bicubic
  ZIP_MODE: False
  CACHE_MODE: "part"
  PIN_MEMORY: True
  NUM_WORKERS: 8

MODEL:
  PRETRAINED: False
  RESUME: ""
  LABEL_SMOOTHING: 0.1

TRAIN:
  START_EPOCH: 0
  EPOCHS: 300
  WARMUP_EPOCHS: 3
  WARMUP_LR: 0.0001
  MIN_LR: 1.0e-06
  WEIGHT_DECAY: 2.0e-05
  BASE_LR: 0.01
  CLIP_GRAD: None
  AUTO_RESUME: True
  ACCUMULATION_STEPS: 0

  LR_SCHEDULER:
    NAME: cosine
    MILESTONES: None

  OPTIMIZER:
    NAME: sgd
    MOMENTUM: 0.9
    NESTEROV: True


AUG:
  COLOR_JITTER: 0.4
  AUTO_AUGMENT: rand-m9-mstd0.5-inc1
  REPROB: 0.6
  REMODE: pixel
  RECOUNT: 1
  MIXUP: 0.
  CUTMIX: 0.
  CUTMIX_MINMAX: None

TEST:
  CROP: True
  SEQUENTIAL: False

TAG: default
SAVE_FREQ: 1
PRINT_FREQ: 10
SEED: 42
EVAL_MODE: False
THROUGHPUT_MODE: False
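
Several values in this file (`CLIP_GRAD`, `MILESTONES`, `CUTMIX_MINMAX`) are written as `None`, which YAML treats as the string "None" rather than as a null value. Config loaders in the yacs style pass string values through `ast.literal_eval` when merging, which is what turns them back into Python `None`. Below is a minimal sketch of that decoding step, assuming this project relies on such behaviour; the helper name `decode_value` and the sample dictionary are hypothetical.

```python
from ast import literal_eval


def decode_value(value):
    """Convert a string parsed from YAML into a Python literal when possible."""
    if not isinstance(value, str):
        return value  # already a bool, int or float from the YAML parser
    try:
        return literal_eval(value)
    except (ValueError, SyntaxError):
        return value  # plain strings such as "cosine" stay as they are


# Values as a YAML parser would hand them over for a few keys of this file.
raw = {
    "CLIP_GRAD": "None",
    "NAME": "cosine",
    "MIXUP": 0.0,
    "DATA_PATH": "/home/ubuntu/work/oneflow/datasets",
}
decoded = {key: decode_value(value) for key, value in raw.items()}
print(decoded)
```

If the loader does not perform such a conversion, writing `null` (or `~`) in the YAML file is the unambiguous way to express a missing value.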