Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add convit into models #238

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open

Add convit into models #238

wants to merge 3 commits into from

Conversation

triple-Mu
Copy link
Contributor

Add convit into models

@wzy9813125
Copy link

对convit分类网络进行测试,不进行预训练,convit_tiny、convit_small、convit_base皆可正常运行。
当进行预训练时,由于网络是在imagenet数据集上进行的预训练,所以最后的分类输出种类是1000。采用的测试集案例为10分类,所以要修改最后一个Linear层的输出通道数为10。一般的分类网络vgg、alexnet等最后一层为classifier层。所以常规的使用预训练的代码为:

import oneflow as flow
from flowvision.models import ModelCreator     

net = ModelCreator.create_model("alexnet", pretrained=True)
num_fc = net.classifier[6].in_features
net.classifier[6] = flow.nn.Linear(in_features=num_fc, out_features=10)

在测试过程中,convit分类网络没有使用classifier层,取而代之的是head层,所以在使用时需要注意在加载预训练模型时需对以上常规代码进行修改,修改后可正常运行。
修改为:

net = ModelCreator.create_model("convit_tiny",pretrained = True)
num_fc = net.head.in_features
net.head = torch.nn.Linear(in_features=num_fc, out_features=10)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants