From cc349ad8f6121bd7344640433d91db500e0a169c Mon Sep 17 00:00:00 2001 From: 59rentainhe <596106517@qq.com> Date: Tue, 12 Apr 2022 06:51:30 +0000 Subject: [PATCH 1/4] add basic resnet50 models --- .../configs/cnn_default_settings.yaml | 2 +- projects/classification/models/build.py | 6 +- .../classification/models/fast_resnet50.py | 281 ++++++++++++++++++ 3 files changed, 287 insertions(+), 2 deletions(-) create mode 100644 projects/classification/models/fast_resnet50.py diff --git a/projects/classification/configs/cnn_default_settings.yaml b/projects/classification/configs/cnn_default_settings.yaml index 0dea275e..afce0d0b 100644 --- a/projects/classification/configs/cnn_default_settings.yaml +++ b/projects/classification/configs/cnn_default_settings.yaml @@ -1,7 +1,7 @@ DATA: BATCH_SIZE: 16 DATASET: imagenet - DATA_PATH: /DATA/disk1/ImageNet/extract + DATA_PATH: /dataset/extract IMG_SIZE: 224 INTERPOLATION: bicubic ZIP_MODE: False diff --git a/projects/classification/models/build.py b/projects/classification/models/build.py index fc50582e..ff6d2636 100644 --- a/projects/classification/models/build.py +++ b/projects/classification/models/build.py @@ -1,7 +1,11 @@ from flowvision.models import ModelCreator +from .fast_resnet50 import fast_resnet50 def build_model(config): model_arch = config.MODEL.ARCH - model = ModelCreator.create_model(model_arch, pretrained=config.MODEL.PRETRAINED) + if model_arch == "fast_resnet50": + model = fast_resnet50() + else: + model = ModelCreator.create_model(model_arch, pretrained=config.MODEL.PRETRAINED) return model diff --git a/projects/classification/models/fast_resnet50.py b/projects/classification/models/fast_resnet50.py new file mode 100644 index 00000000..a2a7080d --- /dev/null +++ b/projects/classification/models/fast_resnet50.py @@ -0,0 +1,281 @@ +from typing import Type, Any, Callable, Union, List, Optional + +import oneflow as flow +import oneflow.nn as nn +from oneflow import Tensor + + +def conv3x3( + in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1 +) -> nn.Conv2d: + """3x3 convolution with padding""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation, + ) + + +def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError("BasicBlock only supports groups=1 and base_width=64") + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU() + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = 
self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion: int = 4 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.0)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU() + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + def __init__( + self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + num_classes: int = 1000, + zero_init_residual: bool = False, + groups: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: Optional[List[bool]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError( + "replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation) + ) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d( + 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False + ) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU() + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer( + block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0] + ) + self.layer3 = self._make_layer( + block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1] + ) + self.layer4 = self._make_layer( + block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2] + ) + self.avgpool = nn.AvgPool2d((7, 7)) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] + + def _make_layer( + self, + block: Type[Union[BasicBlock, Bottleneck]], + planes: int, + blocks: int, + stride: int = 1, + dilate: bool = False, + ) -> nn.Sequential: + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append( + block( + self.inplanes, + planes, + stride, + downsample, + self.groups, + self.base_width, + previous_dilation, + norm_layer, + ) + ) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer, + ) + ) + + return nn.Sequential(*layers) + + def _forward_impl(self, x: Tensor) -> Tensor: + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = flow.flatten(x, 1) + x = self.fc(x) + + return x + + def forward(self, x: Tensor) -> Tensor: + return self._forward_impl(x) + + +def _resnet( + arch: str, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + pretrained: bool, + progress: bool, + **kwargs: Any +) -> ResNet: + model = ResNet(block, layers, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) + model.load_state_dict(state_dict) + return model + + +def fast_resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: + return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) \ No newline at end of file From 3077fa7c6a8f079fce1014e791f4b5a04dc6fef2 Mon Sep 17 00:00:00 2001 From: 59rentainhe <596106517@qq.com> Date: Tue, 12 Apr 2022 07:01:58 +0000 Subject: [PATCH 2/4] reverse --- .../configs/cnn_default_settings.yaml | 2 +- projects/classification/models/build.py | 6 +- .../classification/models/fast_resnet50.py | 281 ------------------ 3 files changed, 2 insertions(+), 287 deletions(-) delete mode 100644 projects/classification/models/fast_resnet50.py diff --git a/projects/classification/configs/cnn_default_settings.yaml b/projects/classification/configs/cnn_default_settings.yaml index afce0d0b..0dea275e 100644 --- a/projects/classification/configs/cnn_default_settings.yaml +++ b/projects/classification/configs/cnn_default_settings.yaml @@ -1,7 +1,7 @@ DATA: BATCH_SIZE: 16 DATASET: imagenet - DATA_PATH: /dataset/extract + DATA_PATH: /DATA/disk1/ImageNet/extract IMG_SIZE: 224 INTERPOLATION: bicubic ZIP_MODE: False diff --git a/projects/classification/models/build.py b/projects/classification/models/build.py index ff6d2636..fc50582e 100644 --- a/projects/classification/models/build.py +++ b/projects/classification/models/build.py @@ -1,11 +1,7 @@ from flowvision.models import ModelCreator -from .fast_resnet50 import fast_resnet50 def build_model(config): model_arch = config.MODEL.ARCH - if model_arch == "fast_resnet50": - 
model = fast_resnet50() - else: - model = ModelCreator.create_model(model_arch, pretrained=config.MODEL.PRETRAINED) + model = ModelCreator.create_model(model_arch, pretrained=config.MODEL.PRETRAINED) return model diff --git a/projects/classification/models/fast_resnet50.py b/projects/classification/models/fast_resnet50.py deleted file mode 100644 index a2a7080d..00000000 --- a/projects/classification/models/fast_resnet50.py +++ /dev/null @@ -1,281 +0,0 @@ -from typing import Type, Any, Callable, Union, List, Optional - -import oneflow as flow -import oneflow.nn as nn -from oneflow import Tensor - - -def conv3x3( - in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1 -) -> nn.Conv2d: - """3x3 convolution with padding""" - return nn.Conv2d( - in_planes, - out_planes, - kernel_size=3, - stride=stride, - padding=dilation, - groups=groups, - bias=False, - dilation=dilation, - ) - - -def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: - """1x1 convolution""" - return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) - - -class BasicBlock(nn.Module): - expansion: int = 1 - - def __init__( - self, - inplanes: int, - planes: int, - stride: int = 1, - downsample: Optional[nn.Module] = None, - groups: int = 1, - base_width: int = 64, - dilation: int = 1, - norm_layer: Optional[Callable[..., nn.Module]] = None, - ) -> None: - super(BasicBlock, self).__init__() - if norm_layer is None: - norm_layer = nn.BatchNorm2d - if groups != 1 or base_width != 64: - raise ValueError("BasicBlock only supports groups=1 and base_width=64") - if dilation > 1: - raise NotImplementedError("Dilation > 1 not supported in BasicBlock") - # Both self.conv1 and self.downsample layers downsample the input when stride != 1 - self.conv1 = conv3x3(inplanes, planes, stride) - self.bn1 = norm_layer(planes) - self.relu = nn.ReLU() - self.conv2 = conv3x3(planes, planes) - self.bn2 = norm_layer(planes) - self.downsample = downsample - self.stride = stride - - def forward(self, x: Tensor) -> Tensor: - identity = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.relu(out) - - return out - - -class Bottleneck(nn.Module): - expansion: int = 4 - - def __init__( - self, - inplanes: int, - planes: int, - stride: int = 1, - downsample: Optional[nn.Module] = None, - groups: int = 1, - base_width: int = 64, - dilation: int = 1, - norm_layer: Optional[Callable[..., nn.Module]] = None, - ) -> None: - super(Bottleneck, self).__init__() - if norm_layer is None: - norm_layer = nn.BatchNorm2d - width = int(planes * (base_width / 64.0)) * groups - # Both self.conv2 and self.downsample layers downsample the input when stride != 1 - self.conv1 = conv1x1(inplanes, width) - self.bn1 = norm_layer(width) - self.conv2 = conv3x3(width, width, stride, groups, dilation) - self.bn2 = norm_layer(width) - self.conv3 = conv1x1(width, planes * self.expansion) - self.bn3 = norm_layer(planes * self.expansion) - self.relu = nn.ReLU() - self.downsample = downsample - self.stride = stride - - def forward(self, x: Tensor) -> Tensor: - identity = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - - out = self.conv3(out) - out = self.bn3(out) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = 
self.relu(out) - - return out - - -class ResNet(nn.Module): - def __init__( - self, - block: Type[Union[BasicBlock, Bottleneck]], - layers: List[int], - num_classes: int = 1000, - zero_init_residual: bool = False, - groups: int = 1, - width_per_group: int = 64, - replace_stride_with_dilation: Optional[List[bool]] = None, - norm_layer: Optional[Callable[..., nn.Module]] = None, - ) -> None: - super(ResNet, self).__init__() - if norm_layer is None: - norm_layer = nn.BatchNorm2d - self._norm_layer = norm_layer - - self.inplanes = 64 - self.dilation = 1 - if replace_stride_with_dilation is None: - # each element in the tuple indicates if we should replace - # the 2x2 stride with a dilated convolution instead - replace_stride_with_dilation = [False, False, False] - if len(replace_stride_with_dilation) != 3: - raise ValueError( - "replace_stride_with_dilation should be None " - "or a 3-element tuple, got {}".format(replace_stride_with_dilation) - ) - self.groups = groups - self.base_width = width_per_group - self.conv1 = nn.Conv2d( - 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False - ) - self.bn1 = norm_layer(self.inplanes) - self.relu = nn.ReLU() - self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) - self.layer1 = self._make_layer(block, 64, layers[0]) - self.layer2 = self._make_layer( - block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0] - ) - self.layer3 = self._make_layer( - block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1] - ) - self.layer4 = self._make_layer( - block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2] - ) - self.avgpool = nn.AvgPool2d((7, 7)) - self.fc = nn.Linear(512 * block.expansion, num_classes) - - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") - elif isinstance(m, nn.BatchNorm2d): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - - # Zero-initialize the last BN in each residual branch, - # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
- # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 - if zero_init_residual: - for m in self.modules(): - if isinstance(m, Bottleneck): - nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] - elif isinstance(m, BasicBlock): - nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] - - def _make_layer( - self, - block: Type[Union[BasicBlock, Bottleneck]], - planes: int, - blocks: int, - stride: int = 1, - dilate: bool = False, - ) -> nn.Sequential: - norm_layer = self._norm_layer - downsample = None - previous_dilation = self.dilation - if dilate: - self.dilation *= stride - stride = 1 - if stride != 1 or self.inplanes != planes * block.expansion: - downsample = nn.Sequential( - conv1x1(self.inplanes, planes * block.expansion, stride), - norm_layer(planes * block.expansion), - ) - - layers = [] - layers.append( - block( - self.inplanes, - planes, - stride, - downsample, - self.groups, - self.base_width, - previous_dilation, - norm_layer, - ) - ) - self.inplanes = planes * block.expansion - for _ in range(1, blocks): - layers.append( - block( - self.inplanes, - planes, - groups=self.groups, - base_width=self.base_width, - dilation=self.dilation, - norm_layer=norm_layer, - ) - ) - - return nn.Sequential(*layers) - - def _forward_impl(self, x: Tensor) -> Tensor: - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - x = self.maxpool(x) - - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - - x = self.avgpool(x) - x = flow.flatten(x, 1) - x = self.fc(x) - - return x - - def forward(self, x: Tensor) -> Tensor: - return self._forward_impl(x) - - -def _resnet( - arch: str, - block: Type[Union[BasicBlock, Bottleneck]], - layers: List[int], - pretrained: bool, - progress: bool, - **kwargs: Any -) -> ResNet: - model = ResNet(block, layers, **kwargs) - if pretrained: - state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) - model.load_state_dict(state_dict) - return model - - -def fast_resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: - return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) \ No newline at end of file From 32b0a7aa21176a167db6bdae8722e0d788717b11 Mon Sep 17 00:00:00 2001 From: 59rentainhe <596106517@qq.com> Date: Tue, 12 Apr 2022 08:26:21 +0000 Subject: [PATCH 3/4] init fast resnet50 project --- projects/fast_resnet50/dataset.py | 54 +++ projects/fast_resnet50/graph.py | 76 ++++ projects/fast_resnet50/lr_scheduler.py | 17 + .../fast_resnet50/models/fast_resnet50.py | 347 ++++++++++++++++++ projects/fast_resnet50/optimizer.py | 21 ++ 5 files changed, 515 insertions(+) create mode 100644 projects/fast_resnet50/dataset.py create mode 100644 projects/fast_resnet50/graph.py create mode 100644 projects/fast_resnet50/lr_scheduler.py create mode 100644 projects/fast_resnet50/models/fast_resnet50.py create mode 100644 projects/fast_resnet50/optimizer.py diff --git a/projects/fast_resnet50/dataset.py b/projects/fast_resnet50/dataset.py new file mode 100644 index 00000000..4ec51402 --- /dev/null +++ b/projects/fast_resnet50/dataset.py @@ -0,0 +1,54 @@ +import os + +from flowvision import datasets, transforms +from flowvision.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from flowvision.transforms.functional import str_to_interp_mode +from flowvision.data import Mixup + + +def build_dataset(is_train, config): + transform = build_transform(is_train, config) + if config.DATA.DATASET == "imagenet": + prefix = 
"train" if is_train else "val" + root = os.path.join(config.DATA.DATA_PATH, prefix) + dataset = datasets.ImageFolder(root, transform=transform) + + + +def build_transform(is_train, config): + resize_im = config.DATA.IMG_SIZE > 32 + if is_train: + t = [] + # this should always dispatch to transforms_imagenet_train + t.append(transforms.RandomResizedCrop( + size=(config.DATA.IMG_SIZE, config.DATA.IMG_SIZE), + interpolation=str_to_interp_mode(config.DATA.INTERPOLATION) + )) + t.append(transforms.RandomHorizontalFlip(p=0.5)) + t.append(transforms.ToTensor()) + t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) + + return transforms.Compose(t) + + t = [] + if resize_im: + if config.TEST.CROP: + size = int((256 / 224) * config.DATA.IMG_SIZE) + t.append( + transforms.Resize( + size, interpolation=str_to_interp_mode(config.DATA.INTERPOLATION) + ), + # to maintain same ratio w.r.t. 224 images + ) + t.append(transforms.CenterCrop(config.DATA.IMG_SIZE)) + else: + t.append( + transforms.Resize( + (config.DATA.IMG_SIZE, config.DATA.IMG_SIZE), + interpolation=str_to_interp_mode(config.DATA.INTERPOLATION), + ) + ) + + t.append(transforms.ToTensor()) + t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) + return transforms.Compose(t) \ No newline at end of file diff --git a/projects/fast_resnet50/graph.py b/projects/fast_resnet50/graph.py new file mode 100644 index 00000000..4447a152 --- /dev/null +++ b/projects/fast_resnet50/graph.py @@ -0,0 +1,76 @@ +import oneflow as flow + +def make_grad_scaler(): + return flow.amp.GradScaler( + init_scale=2 ** 30, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, + ) + +class TrainGraph(flow.nn.Graph): + def __init__(self, model, loss, optimizer, lr_scheduler, data_loader, config): + super().__init__() + if config.use_fp16: + # 使用 nn.Graph 的自动混合精度训练 + self.config.enable_amp(True) + self.set_grad_scalar( + flow.amp.GradScaler( + init_scale=2 ** 30, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, + ) + ) + elif config.scale_grad: + self.set_grad_scaler( + flow.amp.StaticGradScaler(flow.env.get_world_size()) + ) + + if config.fuse_add_to_output: + # 使用 nn.Graph 的add算子融合 + self.config.allow_fuse_add_to_output(True) + + if config.fuse_model_update_ops: + self.config.allow_fuse_model_update_ops(True) + + if config.conv_try_run: + # 使用 nn.Graph 的卷积试跑优化 + self.config.enable_cudnn_conv_heuristic_search_algo(False) + + if config.fuse_pad_to_conv: + # 使用 nn.Graph 的pad算子融合 + self.config.allow_fuse_pad_to_conv(True) + + + self.model = model + self.loss = loss + self.add_optimizer(optimizer, lr_sch=lr_scheduler) + self.data_loader = data_loader + + def build(self): + image, label = self.data_loader() + image = image.to("cuda") + label = label.to("cuda") + logits = self.model(image) + loss = self.cross_entropy(logits, label) + loss.backward() + return loss + + +class EvalGraph(flow.nn.Graph): + def __init__(self, model, data_loader, config): + super().__init__() + + if config.use_fp16: + # 使用 nn.Graph 的自动混合精度训练 + self.config.enable_amp(True) + + if config.fuse_add_to_output: + # 使用 nn.Graph 的add算子融合 + self.config.allow_fuse_add_to_output(True) + + self.data_loader = data_loader + self.model = model + + def build(self): + image, label = self.data_loader() + image = image.to("cuda") + label = label.to("cuda") + logits = self.model(image) + pred = logits.softmax() + return pred, label \ No newline at end of file diff --git a/projects/fast_resnet50/lr_scheduler.py b/projects/fast_resnet50/lr_scheduler.py new file 
mode 100644 index 00000000..8b00afb0 --- /dev/null +++ b/projects/fast_resnet50/lr_scheduler.py @@ -0,0 +1,17 @@ +import oneflow as flow + + +def build_lr_scheduler(config, optimizer, n_iter_per_epoch): + num_steps = int(config.train.epochs * n_iter_per_epoch) + warmup_steps = int(config.train.warmup_epochs * n_iter_per_epoch) + lr_scheduler = flow.optim.lr_scheduler.CosineDecayLR( + optimizer, decay_steps=num_steps + ) + if config.warmup_epochs > 0: + lr_scheduler = flow.optim.lr_scheduler.WarmUpLR( + lr_scheduler, + warmup_factor=0.01, + warmup_iters=warmup_steps, + warmup_method="linear" + ) + return lr_scheduler diff --git a/projects/fast_resnet50/models/fast_resnet50.py b/projects/fast_resnet50/models/fast_resnet50.py new file mode 100644 index 00000000..eea47a27 --- /dev/null +++ b/projects/fast_resnet50/models/fast_resnet50.py @@ -0,0 +1,347 @@ +import os + +import oneflow as flow +import oneflow.nn as nn +from oneflow import Tensor +from typing import Type, Any, Callable, Union, List, Optional + + +def conv3x3( + in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1 +) -> nn.Conv2d: + """3x3 convolution with padding""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation, + ) + + +def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError("BasicBlock only supports groups=1 and base_width=64") + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU() + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion: int = 4 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + fuse_bn_relu=False, + fuse_bn_add_relu=False, + ) -> None: + super(Bottleneck, self).__init__() + self.fuse_bn_relu = fuse_bn_relu + self.fuse_bn_add_relu = fuse_bn_add_relu + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.0)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + + if self.fuse_bn_relu: + self.bn1 = nn.FusedBatchNorm2d(width) + self.bn2 = nn.FusedBatchNorm2d(width) + 
else: + self.bn1 = norm_layer(width) + self.bn2 = norm_layer(width) + self.relu = nn.ReLU() + + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.conv3 = conv1x1(width, planes * self.expansion) + + if self.fuse_bn_add_relu: + self.bn3 = nn.FusedBatchNorm2d(planes * self.expansion) + else: + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU() + + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + if self.downsample is not None: + # Note self.downsample execute before self.conv1 has better performance + # when open allow_fuse_add_to_output optimizatioin in nn.Graph. + # Reference: https://github.com/Oneflow-Inc/OneTeam/issues/840#issuecomment-994903466 + # Reference: https://github.com/NVIDIA/cudnn-frontend/issues/21 + identity = self.downsample(x) + + out = self.conv1(x) + + if self.fuse_bn_relu: + out = self.bn1(out, None) + else: + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + + if self.fuse_bn_relu: + out = self.bn2(out, None) + else: + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + + if self.fuse_bn_add_relu: + out = self.bn3(out, identity) + else: + out = self.bn3(out) + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + def __init__( + self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + num_classes: int = 1000, + zero_init_residual: bool = False, + groups: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: Optional[List[bool]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + fuse_bn_relu=False, + fuse_bn_add_relu=False, + channel_last=False, + ) -> None: + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + self.fuse_bn_relu = fuse_bn_relu + self.fuse_bn_add_relu = fuse_bn_add_relu + self.channel_last = channel_last + if self.channel_last: + self.pad_input = True + else: + self.pad_input = False + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError( + "replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation) + ) + self.groups = groups + self.base_width = width_per_group + + if self.pad_input: + channel_size = 4 + else: + channel_size = 3 + if self.channel_last: + os.environ["ONEFLOW_ENABLE_NHWC"] = "1" + self.conv1 = nn.Conv2d( + channel_size, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False + ) + + if self.fuse_bn_relu: + self.bn1 = nn.FusedBatchNorm2d(self.inplanes) + else: + self.bn1 = self._norm_layer(self.inplanes) + self.relu = nn.ReLU() + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer( + block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0] + ) + self.layer3 = self._make_layer( + block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1] + ) + self.layer4 = self._make_layer( + block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2] + ) + self.avgpool = nn.AvgPool2d((7, 7), stride=(1, 1)) + + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if 
isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] + + def _make_layer( + self, + block: Type[Union[BasicBlock, Bottleneck]], + planes: int, + blocks: int, + stride: int = 1, + dilate: bool = False, + ) -> nn.Sequential: + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append( + block( + self.inplanes, + planes, + stride, + downsample, + self.groups, + self.base_width, + previous_dilation, + norm_layer, + fuse_bn_relu=self.fuse_bn_relu, + fuse_bn_add_relu=self.fuse_bn_add_relu, + ) + ) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer, + fuse_bn_relu=self.fuse_bn_relu, + fuse_bn_add_relu=self.fuse_bn_add_relu, + ) + ) + + return nn.Sequential(*layers) + + def _forward_impl(self, x: Tensor) -> Tensor: + if self.pad_input: + if self.channel_last: + # NHWC + paddings = (0, 1) + else: + # NCHW + paddings = (0, 0, 0, 0, 0, 1) + x = flow._C.pad(x, pad=paddings, mode="constant", value=0) + x = self.conv1(x) + if self.fuse_bn_relu: + x = self.bn1(x, None) + else: + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = flow.flatten(x, 1) + x = self.fc(x) + + return x + + def forward(self, x: Tensor) -> Tensor: + return self._forward_impl(x) + + +def _resnet( + arch: str, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + **kwargs: Any +) -> ResNet: + model = ResNet(block, layers, **kwargs) + return model + + +def resnet50(**kwargs: Any) -> ResNet: + r"""ResNet-5 + `"Deep Residual Learning for Image Recognition" `_. 
+    """
+    return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], **kwargs)
\ No newline at end of file
diff --git a/projects/fast_resnet50/optimizer.py b/projects/fast_resnet50/optimizer.py
new file mode 100644
index 00000000..54cc8bcb
--- /dev/null
+++ b/projects/fast_resnet50/optimizer.py
@@ -0,0 +1,21 @@
+import oneflow as flow
+
+
+def build_optimizer(config, model):
+    param_group = {"params": [p for p in model.parameters() if p.requires_grad]}
+
+    if config.train.clip_grad > 0.0:
+        assert config.train.clip_grad == 1.0, "ONLY support grad_clipping == 1.0"
+        param_group["clip_grad_max_norm"] = 1.0
+        param_group["clip_grad_norm_type"] = 2.0
+
+    opt_lower = config.train.optimizer.name.lower()
+    optimizer = None
+    if opt_lower == "sgd":
+        optimizer = flow.optim.SGD(
+            [param_group],
+            lr=config.train.base_lr,
+            momentum=config.train.optimizer.momentum,
+            weight_decay=config.train.weight_decay
+        )
+    return optimizer

From 6dc14feea9c06927440d5b0592bf6c4641690d56 Mon Sep 17 00:00:00 2001
From: 59rentainhe <596106517@qq.com>
Date: Tue, 12 Apr 2022 09:12:03 +0000
Subject: [PATCH 4/4] add dataloader

---
 projects/fast_resnet50/dataset.py | 38 ++++++++++++++++++++++++++++++-
 1 file changed, 37 insertions(+), 1 deletion(-)

diff --git a/projects/fast_resnet50/dataset.py b/projects/fast_resnet50/dataset.py
index 4ec51402..7423c440 100644
--- a/projects/fast_resnet50/dataset.py
+++ b/projects/fast_resnet50/dataset.py
@@ -1,9 +1,33 @@
 import os
 
+from oneflow.utils.data import DataLoader
+
 from flowvision import datasets, transforms
 from flowvision.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from flowvision.transforms.functional import str_to_interp_mode
-from flowvision.data import Mixup
+
+
+def build_loader(config):
+    config.defrost()
+    dataset_train, config.MODEL.NUM_CLASSES = build_dataset(
+        is_train=True, config=config
+    )
+    config.freeze()
+    dataset_val, _ = build_dataset(is_train=False, config=config)
+    data_loader_train = DataLoader(
+        dataset_train,
+        batch_size=config.DATA.BATCH_SIZE,
+        num_workers=config.DATA.NUM_WORKERS,
+        drop_last=True,
+    )
+    data_loader_val = DataLoader(
+        dataset_val,
+        batch_size=config.DATA.BATCH_SIZE,
+        shuffle=False,
+        num_workers=config.DATA.NUM_WORKERS,
+        drop_last=False,
+    )
+    return dataset_train, dataset_val, data_loader_train, data_loader_val
 
 
 def build_dataset(is_train, config):
@@ -12,7 +36,19 @@ def build_dataset(is_train, config):
         prefix = "train" if is_train else "val"
         root = os.path.join(config.DATA.DATA_PATH, prefix)
         dataset = datasets.ImageFolder(root, transform=transform)
+        nb_classes = 1000
+    elif config.DATA.DATASET == "cifar100":
+        dataset = datasets.CIFAR100(
+            root=config.DATA.DATA_PATH,
+            train=is_train,
+            transform=transform,
+            download=True,
+        )
+        nb_classes = 100
+    else:
+        raise NotImplementedError("We only support ImageNet and CIFAR100 now.")
+    return dataset, nb_classes
 
 
 
 def build_transform(is_train, config):
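
Note (not part of the patch series): the patches above add the data pipeline (dataset.py), the model (models/fast_resnet50.py), the optimizer and LR-scheduler builders, and the nn.Graph wrappers, but no entry point that ties them together yet. The eager-mode sketch below is illustrative only. The function name train_eager and the flat import paths are placeholders that assume the script lives in projects/fast_resnet50/, and the config object is assumed to expose both the upper-case fields read by dataset.py (config.DATA.*, config.MODEL.NUM_CLASSES, config.TEST.CROP) and the lower-case fields read by optimizer.py and lr_scheduler.py (config.train.*). It stays in eager mode because TrainGraph/EvalGraph expect a data loader that can be called inside Graph.build() (e.g. an OFRecord-style reader module), which the plain DataLoader returned by build_loader is not.

import oneflow as flow

from dataset import build_loader
from lr_scheduler import build_lr_scheduler
from models.fast_resnet50 import resnet50
from optimizer import build_optimizer


def train_eager(config):
    # build data, model, loss, optimizer and schedule from the pieces added above
    dataset_train, dataset_val, loader_train, loader_val = build_loader(config)
    model = resnet50(num_classes=config.MODEL.NUM_CLASSES).to("cuda")
    criterion = flow.nn.CrossEntropyLoss().to("cuda")
    optimizer = build_optimizer(config, model)
    lr_scheduler = build_lr_scheduler(config, optimizer, len(loader_train))

    for epoch in range(config.train.epochs):
        model.train()
        for image, label in loader_train:
            image, label = image.to("cuda"), label.to("cuda")
            loss = criterion(model(image), label)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            lr_scheduler.step()  # the schedule is defined per iteration, not per epoch

        # simple top-1 accuracy on the validation split
        model.eval()
        correct, total = 0, 0
        with flow.no_grad():
            for image, label in loader_val:
                image, label = image.to("cuda"), label.to("cuda")
                pred = model(image).argmax(dim=1)
                correct += (pred == label).sum().item()
                total += label.shape[0]
        print(f"epoch {epoch}: top-1 accuracy {correct / total:.4f}")

Switching this loop to the TrainGraph/EvalGraph path would additionally require a graph-compatible data loader and a config carrying the use_fp16 / scale_grad / fuse_* flags read in graph.py.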