Tutorial 4: Customize Models

Customize optimizer

Assume you want to add an optimizer named MyOptimizer, which has arguments a, b, and c. You first need to implement the new optimizer in a file, e.g., mmseg/core/optimizer/my_optimizer.py:

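from mmcv.runner import OPTIMIZERS
from torch.optim import Optimizer


@OPTIMIZERS.register_module()
class MyOptimizer(Optimizer):

    def __init__(self, params, a, b, c):
        # params is supplied by the optimizer constructor at build time
        defaults = dict(a=a, b=b, c=c)
        super(MyOptimizer, self).__init__(params, defaults)

    def step(self, closure=None):
        # implement the update rule using the hyperparameters a, b and c
        ...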

Then import this module in mmseg/core/optimizer/__init__.py so that the registry can find and register the new module:

from .my_optimizer import MyOptimizer

Then you can use MyOptimizer in the optimizer field of config files. In the configs, optimizers are defined by the optimizer field as follows:

optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)

To use your own optimizer, change the field to:

optimizer = dict(type='MyOptimizer', a=a_value, b=b_value, c=c_value)

All optimizers implemented by PyTorch are already supported, and the only change needed is to the optimizer field of the config file. For example, to use Adam (note that performance may drop considerably), the modification could be as follows:

optimizer = dict(type='Adam', lr=0.0003, weight_decay=0.0001)

Users can set the arguments directly, following the PyTorch API documentation.
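The registration step above can be pictured with a stripped-down, pure-Python sketch. It is not mmcv's Registry implementation: SimpleRegistry and build_from_config are hypothetical stand-ins written for illustration, and the toy MyOptimizer only records its arguments. It shows why the import in __init__.py matters (the decorator only runs, and the class only becomes visible to configs, once its module is imported) and how a dict such as dict(type='MyOptimizer', a=..., b=..., c=...) turns into a constructor call.

```python
class SimpleRegistry:
    """Hypothetical, minimal stand-in for an mmcv-style registry."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self):
        # Used as @REGISTRY.register_module(); the decorator runs at import time
        def _register(cls):
            if cls.__name__ in self._module_dict:
                raise KeyError(f'{cls.__name__} is already registered in {self.name}')
            self._module_dict[cls.__name__] = cls
            return cls
        return _register

    def get(self, key):
        return self._module_dict.get(key)


def build_from_config(cfg, registry, **default_args):
    """Look up cfg['type'] and pass the remaining keys as keyword arguments."""
    args = dict(cfg)  # copy, so the config dict itself is not modified
    obj_type = args.pop('type')
    obj_cls = registry.get(obj_type)
    if obj_cls is None:
        # The usual cause: the file defining the class was never imported
        raise KeyError(f'{obj_type} is not registered in {registry.name}')
    for key, value in default_args.items():
        args.setdefault(key, value)  # values from the config take precedence
    return obj_cls(**args)


OPTIMIZERS = SimpleRegistry('optimizer')


@OPTIMIZERS.register_module()
class MyOptimizer:
    """Toy optimizer that only records its arguments."""

    def __init__(self, params, a, b, c):
        self.params = list(params)
        self.a, self.b, self.c = a, b, c


cfg = dict(type='MyOptimizer', a=0.1, b=0.2, c=0.3)
optimizer = build_from_config(cfg, OPTIMIZERS, params=[1.0, 2.0])
```

In mmcv the optimizer constructor likewise adds the model parameters under the params key before building from the OPTIMIZERS registry, which is why a custom optimizer's __init__ has to accept params.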

Customize optimizer constructor

Some models may need parameter-specific optimization settings, e.g., weight decay for BatchNorm layers. Users can do such fine-grained parameter tuning by customizing the optimizer constructor.

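from mmcv.utils import build_from_cfg

from mmcv.runner import OPTIMIZER_BUILDERS, OPTIMIZERS
# assumed to register CocktailOptimizer in OPTIMIZERS when imported
from .cocktail_optimizer import CocktailOptimizer


@OPTIMIZER_BUILDERS.register_module()
class CocktailOptimizerConstructor(object):

    def __init__(self, optimizer_cfg, paramwise_cfg=None):
        self.optimizer_cfg = optimizer_cfg
        self.paramwise_cfg = paramwise_cfg

    def __call__(self, model):
        optimizer_cfg = self.optimizer_cfg.copy()
        # Apply self.paramwise_cfg here, e.g. by building parameter groups
        # with their own weight decay; this skeleton passes all parameters as-is.
        optimizer_cfg['params'] = model.parameters()
        my_optimizer = build_from_cfg(optimizer_cfg, OPTIMIZERS)
        return my_optimizer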
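To make the runner use a custom constructor, the optimizer config has to name it. The sketch below assumes mmcv 1.x, whose build_optimizer pops the constructor and paramwise_cfg keys from the optimizer dict (defaulting to DefaultOptimizerConstructor) and builds the constructor with optimizer_cfg and paramwise_cfg, matching the __init__ signature above. The paramwise_cfg contents and the split_optimizer_cfg helper are hypothetical; the helper only mirrors that splitting step so it can be checked without mmcv installed.

```python
import copy

# Optimizer config that routes building through the custom constructor.
# The paramwise_cfg contents are hypothetical: they only need to match
# whatever CocktailOptimizerConstructor knows how to interpret.
optimizer = dict(
    type='SGD',
    lr=0.01,
    momentum=0.9,
    weight_decay=0.0005,
    constructor='CocktailOptimizerConstructor',
    paramwise_cfg=dict(norm_decay_mult=0.))


def split_optimizer_cfg(cfg):
    """Hypothetical helper mirroring how mmcv 1.x's build_optimizer splits cfg.

    The returned dict is the config used to build the constructor, whose
    __init__ receives optimizer_cfg and paramwise_cfg.
    """
    optimizer_cfg = copy.deepcopy(cfg)
    constructor_type = optimizer_cfg.pop('constructor', 'DefaultOptimizerConstructor')
    paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
    return dict(type=constructor_type, optimizer_cfg=optimizer_cfg,
                paramwise_cfg=paramwise_cfg)


constructor_cfg = split_optimizer_cfg(optimizer)
```

If all you need is to disable weight decay on normalization layers, mmcv's DefaultOptimizerConstructor already reads a norm_decay_mult key from paramwise_cfg, so setting paramwise_cfg=dict(norm_decay_mult=0.) without a constructor key may be enough.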

Develop new components

There are two main types of components in MMSegmentation.

  • backbone: usually a stack of convolutional layers that extracts feature maps, e.g., ResNet, HRNet.
  • head: the component that decodes feature maps into a semantic segmentation map.

Add new backbones

Here we show how to develop a new backbone, using MobileNet as an example.

  1. Create a new file mmseg/models/backbones/mobilenet.py.
import torch.nn as nn

from ..builder import BACKBONES


@BACKBONES.register_module()
class MobileNet(nn.Module):

    def __init__(self, arg1, arg2):
        super(MobileNet, self).__init__()

    def forward(self, x):  # should return a tuple of feature maps
        pass

    def init_weights(self, pretrained=None):
        pass
  2. Import the module in mmseg/models/backbones/__init__.py.
from .mobilenet import MobileNet
  3. Use it in your config file.
model = dict(
    ...
    backbone=dict(
        type='MobileNet',
        arg1=xxx,
        arg2=xxx),
    ...)

Add new heads

In MMSegmentation, we provide a base class, BaseDecodeHead, for all segmentation heads. All newly implemented decode heads should be derived from it. Here we show how to develop a new head, using PSPNet as an example.

First, add a new decode head in mmseg/models/decode_heads/psp_head.py. PSPHead decodes the backbone features into a segmentation map. To implement a decode head, we basically need to implement three methods of the new module (__init__, init_weights and forward), as follows.

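from ..builder import HEADS
from .decode_head import BaseDecodeHead


@HEADS.register_module()
class PSPHead(BaseDecodeHead):

    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super(PSPHead, self).__init__(**kwargs)
        self.pool_scales = pool_scales
        # build the pyramid pooling branches and the bottleneck conv here

    def init_weights(self):
        super(PSPHead, self).init_weights()
        # then initialize the layers added in __init__

    def forward(self, inputs):
        pass  # pool, fuse and classify; return the segmentation logits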

Next, add the module to mmseg/models/decode_heads/__init__.py so that the corresponding registry can find and load it.
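PSPHead's pool_scales=(1, 2, 3, 6) describe a pyramid pooling module: the input feature map is average-pooled into 1×1, 2×2, 3×3 and 6×6 grids, each branch is projected to channels channels and upsampled back, and the branches are concatenated with the input before a 3×3 bottleneck convolution. The pure-Python sketch below is an illustration, not the MMSegmentation code. It shows only the adaptive pooling step on a single-channel grid, using the bin boundaries PyTorch's adaptive average pooling uses (start = floor(i·n/out), end = ceil((i+1)·n/out)), plus the channel count entering the bottleneck. The function names are hypothetical.

```python
def adaptive_avg_pool2d(grid, out_h, out_w):
    """Average-pool a 2-D list of numbers into an out_h x out_w grid.

    Bin i along an axis of length n covers [floor(i*n/out), ceil((i+1)*n/out)),
    so neighbouring bins may overlap when n is not divisible by out.
    """
    in_h, in_w = len(grid), len(grid[0])
    pooled = []
    for i in range(out_h):
        h0 = (i * in_h) // out_h
        h1 = ((i + 1) * in_h + out_h - 1) // out_h  # integer ceil
        row = []
        for j in range(out_w):
            w0 = (j * in_w) // out_w
            w1 = ((j + 1) * in_w + out_w - 1) // out_w
            cells = [grid[y][x] for y in range(h0, h1) for x in range(w0, w1)]
            row.append(sum(cells) / len(cells))
        pooled.append(row)
    return pooled


def psp_concat_channels(in_channels, channels, pool_scales):
    """Channels after concatenating the input with one branch per pool scale."""
    return in_channels + len(pool_scales) * channels


# A 6x6 single-channel "feature map" with values 0..35
grid = [[y * 6 + x for x in range(6)] for y in range(6)]
pyramid = {s: adaptive_avg_pool2d(grid, s, s) for s in (1, 2, 3, 6)}
```

With the config values in_channels=2048, channels=512 and four pool scales, the bottleneck sees 2048 + 4 × 512 = 4096 input channels.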

The config file of PSPNet is as follows:

norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained='pretrain_model/resnet50_v1c_trick-2cccc1ad.pth',
    backbone=dict(
        type='ResNetV1c',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 2, 4),
        strides=(1, 2, 1, 1),
        norm_cfg=norm_cfg,
        norm_eval=False,
        style='pytorch',
        contract_dilation=True),
    decode_head=dict(
        type='PSPHead',
        in_channels=2048,
        in_index=3,
        channels=512,
        pool_scales=(1, 2, 3, 6),
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)))
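The backbone settings above determine the resolution of the feature map that reaches the head. A quick worked calculation, assuming the usual ResNet stem that downsamples by 4 (a stride-2 convolutional stem followed by a stride-2 max pool) and the standard ResNet-50 stage widths of 256, 512, 1024 and 2048 channels: with strides=(1, 2, 1, 1) the cumulative output strides are 4, 8, 8 and 8, while dilations=(1, 1, 2, 4) enlarge the receptive field of the last two stages instead of downsampling further. So in_index=3 selects a 2048-channel map at 1/8 of the input resolution, which matches in_channels=2048. The helper names and the 512×1024 input size are assumptions for illustration.

```python
def stage_output_strides(strides, stem_stride=4):
    """Cumulative downsampling factor after each ResNet stage."""
    out, total = [], stem_stride
    for s in strides:
        total *= s
        out.append(total)
    return tuple(out)


def feature_shapes(input_hw, strides,
                   stage_channels=(256, 512, 1024, 2048), stem_stride=4):
    """(channels, height, width) of each stage output.

    Assumes the input size is divisible by every output stride.
    """
    h, w = input_hw
    return tuple(
        (c, h // os, w // os)
        for c, os in zip(stage_channels, stage_output_strides(strides, stem_stride)))


shapes = feature_shapes((512, 1024), strides=(1, 2, 1, 1))
print(shapes[3])  # -> (2048, 64, 128)
```

For comparison, a plain ResNet with strides=(1, 2, 2, 2) would give output strides 4, 8, 16 and 32, so the last stage would be four times smaller along each side.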

Add new loss

Assume you want to add a new loss, MyLoss, for segmentation decoding. To add a new loss function, you need to implement it in mmseg/models/losses/my_loss.py. The decorator weighted_loss enables the loss to be weighted element-wise.

import torch
import torch.nn as nn

from ..builder import LOSSES
from .utils import weighted_loss

@weighted_loss
def my_loss(pred, target):
    assert pred.size() == target.size() and target.numel() > 0
    loss = torch.abs(pred - target)
    return loss

@LOSSES.register_module()
class MyLoss(nn.Module):

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(MyLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss = self.loss_weight * my_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss

Then add it to mmseg/models/losses/__init__.py.

from .my_loss import MyLoss, my_loss

To use it, modify the corresponding loss_xxx field; for the decode head, this is the loss_decode field. loss_weight can be used to balance multiple losses.

loss_decode=dict(type='MyLoss', loss_weight=1.0)
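What weighted_loss adds to my_loss can be pictured with the pure-Python sketch below, which works on flat lists instead of tensors. It follows our understanding of the weight_reduce_loss helper that weighted_loss relies on in mmseg/models/losses/utils.py: multiply the element-wise loss by weight, then reduce, with avg_factor replacing the element count for 'mean' and being rejected for 'sum'. Treat it as an illustration rather than the library code; weight_reduce, weighted_loss_sketch and my_l1_loss are hypothetical names, and MyLoss would still multiply the result by loss_weight.

```python
import functools


def weight_reduce(loss, weight=None, reduction='mean', avg_factor=None):
    """Apply element-wise weights, then reduce (list-based stand-in for tensors)."""
    if weight is not None:
        loss = [l * w for l, w in zip(loss, weight)]
    if reduction == 'none':
        return loss
    if reduction == 'sum':
        if avg_factor is not None:
            raise ValueError('avg_factor can not be used with reduction="sum"')
        return sum(loss)
    if reduction == 'mean':
        denom = avg_factor if avg_factor is not None else len(loss)
        return sum(loss) / denom
    raise ValueError(f'unknown reduction: {reduction}')


def weighted_loss_sketch(loss_func):
    """Wrap an element-wise loss so it accepts weight, reduction and avg_factor."""
    @functools.wraps(loss_func)
    def wrapper(pred, target, weight=None, reduction='mean', avg_factor=None):
        loss = loss_func(pred, target)
        return weight_reduce(loss, weight, reduction, avg_factor)
    return wrapper


@weighted_loss_sketch
def my_l1_loss(pred, target):
    assert len(pred) == len(target) and len(target) > 0
    return [abs(p - t) for p, t in zip(pred, target)]


pred, target = [1.0, 2.0, 4.0], [1.0, 0.0, 0.0]
mean_loss = my_l1_loss(pred, target)  # element losses [0.0, 2.0, 4.0]
```

Passing weight=[1.0, 0.0, 0.5] with reduction='sum' keeps only the weighted terms (0 + 0 + 2), while avg_factor lets the caller divide by, for example, the number of valid pixels instead of the number of elements.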