-
Notifications
You must be signed in to change notification settings - Fork 5
自定义主干网络
Mr.Li edited this page May 30, 2022
·
6 revisions
注:优先顺序 自定义网络 > timm库
持续更新SOTA的预训练模型(>600)。官方文档
适用场景:仅更改预测类别数。
-
挑选SOTA模型
import timm from pprint import pprint model_names = timm.list_models(pretrained=True) # 查询模型名称 # model_names = timm.list_models('*resne*t*') # 支持通配符 pprint(model_names) >>> ['resnet18','resnet50',...]
-
配置文件启用Backbone = "resnet18"
适用场景:自定义网络。
- 主干网络入口
Models/Backbone/backbone.py
实现具体方法,命名为MyNet
。 - 配置文件启用Backbone = "MyNet"
基本结构
-
输入:3x224x224的图像
-
卷积层:降采样2、4、8、16、32倍,输出channel x 7 x 7特征。
-
全局池化层:全局池化,输出N维特征(N=通道数)。
-
分类层:输出最终类别维度。
# 创建 无全局池化和分类层的网络, 将输出卷积层最后特征 [batch, channel, 7, 7]
features = timm.create_model("resnet18", pretrained=True, num_classes=0, global_pool='')
# 创建 无分类层的网络, 将输出全局池化后的特征 [batch, channel]
features = timm.create_model("resnet18", pretrained=True, num_classes=0)
# 将输出 卷积层5次降采样的特征
features = timm.create_model("resnet18", pretrained=True, features_only=True)
img = torch.ones((1, 3, 224, 224))
features(img) # list
# [batch,x,112,112]
# [batch,x,56,56]
# [batch,x,28,28]
# [batch,x,14,14]
# [batch,x,7,7]