Parameter-Embedding Attacks on AI Models
[TOC]
🔗🔗🔗
Introduction to torchvision.models
: https://pytorch.org/vision/stable/models.html

Use the models and pre-trained weights from the torchvision library.
example: Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /Users/mac/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
- Lazarus 20.9MB
- test1.jpeg 864KB
- test1.jpeg 3.1MB
📄📄📄
- The lower 8 bits of the last four parameters store the chunk number
- The lower 8 bits of the fifth-from-last parameter store the remainder
- The lower 8 bits of the sixth-from-last parameter store the chunk size
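This layout can be sketched with plain IEEE-754 bit manipulation. The helpers and the example values below are illustrative stand-ins, not the actual embedding tool:

```python
import struct

def set_low8(f32: float, byte: int) -> float:
    """Overwrite the lowest 8 bits of a float32's bit pattern."""
    bits = struct.unpack("<I", struct.pack("<f", f32))[0]
    bits = (bits & 0xFFFFFF00) | (byte & 0xFF)
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def get_low8(f32: float) -> int:
    """Read back the lowest 8 bits of a float32's bit pattern."""
    return struct.unpack("<I", struct.pack("<f", f32))[0] & 0xFF

# Hypothetical bookkeeping values for a payload split into chunks.
params = [0.123, -0.456, 0.789, 1.5, -2.25, 0.0625]  # tail of a weight list
chunk_number, remainder, chunk_size = 3, 17, 255
params[-6] = set_low8(params[-6], chunk_size)    # sixth from last: chunk size
params[-5] = set_low8(params[-5], remainder)     # fifth from last: remainder
for i, b in enumerate(chunk_number.to_bytes(4, "big")):
    params[-4 + i] = set_low8(params[-4 + i], b)  # last four: chunk number
```

Because only the low 8 mantissa bits change, each stored byte perturbs a normal float32 weight by at most about 255/2^23 of its magnitude.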
Test how replacing the last N bits of the weight parameters of the last fully connected layer with harmful information affects model performance.
Experiment setup: model source: torchvision library; system: macOS.
N | 8 | 16 | 20 | 21 | 22 | 23 | 24 |
---|---|---|---|---|---|---|---|
```python
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Data preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load the ImageNet validation set
val_dataset = datasets.ImageNet(root='../../dataset', split='val', transform=transform)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=True)

# Loss function and evaluation metric
criterion = nn.CrossEntropyLoss()
```
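The Accuracy/Loss numbers in the tables below come from a standard top-1 evaluation pass. A minimal loop consistent with this setup (the function is our sketch, not the original script):

```python
import torch

@torch.no_grad()
def evaluate(model, loader, criterion, device="cpu"):
    """Return (top-1 accuracy, mean loss) over a validation loader."""
    model.eval().to(device)
    correct, total, loss_sum = 0, 0, 0.0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        loss_sum += criterion(logits, labels).item() * labels.size(0)
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    return correct / total, loss_sum / total
```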
All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images have to be loaded into a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].
All performance tests in this section are run on the last fully connected layer: we flip all of the low-order bits of fcWeightTensor["fc.weight"].data.
criterion = nn.CrossEntropyLoss()
The accuracy is exactly reproducible across runs, because inference with fixed parameters is a deterministic process.
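Flipping the low bits can be done in place by reinterpreting the float32 storage as int32. A sketch, with a stand-in tensor in place of the real checkpoint weights:

```python
import torch

def flip_low_bits_(weight: torch.Tensor, n: int) -> torch.Tensor:
    """Flip the lowest n bits of every float32 element in place."""
    assert weight.dtype == torch.float32 and 0 < n < 32
    bits = weight.view(torch.int32)  # same memory, integer view
    bits ^= (1 << n) - 1             # XOR mask flips the low n bits
    return weight

# e.g. w = torch.load("resnet18-f37072fd.pth")["fc.weight"]
w = torch.full((1000, 512), 0.25)   # stand-in with fc.weight's shape
flip_low_bits_(w, 8)
```

For n <= 23 only mantissa bits are touched, so no weight can become NaN or infinite; from n = 24 upward the flips reach the exponent, which is why accuracy collapses quickly in the tables below.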
python ./models/pretrainedModels/resnet18.py
python ./models/embeddedModels/resnet18.py
Original weight structure: a single fully connected layer at the end of the network; the weight shape is [1000, 512], each element is torch.float32, 2.048 MB of data in total.
Experimental results:
RunDate | Model | InitPara | Dataset | Accuracy | Loss |
---|---|---|---|---|---|
20240805 | resnet-18 | resnet18-f37072fd.pth | ILSVRC2012_devkit_t12 | 67.27% | 1.3528 |
20240805 | resnet-18 | resenet18_embedding_8_32.pth | ILSVRC2012_devkit_t12 | 67.27% | 1.3528 |
20240805 | resnet-18 | resenet18_embedding_16_32.pth | ILSVRC2012_devkit_t12 | 67.27% | 1.3534 |
20240805 | resnet-18 | resenet18_embedding_20_32.pth | ILSVRC2012_devkit_t12 | 67.12% | 1.3605 |
20240805 | resnet-18 | resenet18_embedding_21_32.pth | ILSVRC2012_devkit_t12 | 66.64% | 1.3775 |
20240805 | resnet-18 | resenet18_embedding_22_32.pth | ILSVRC2012_devkit_t12 | 65.21% | 1.4511 |
20240805 | resnet-18 | resenet18_embedding_23_32.pth | ILSVRC2012_devkit_t12 | 61.39% | 1.7158 |
20240805 | resnet-18 | resenet18_embedding_24_32.pth | ILSVRC2012_devkit_t12 | 49.36% | 2.9826 |
The accuracy is exactly reproducible across runs, because inference with fixed parameters is a deterministic process.
python ./models/pretrainedModels/resnet50.py
python ./models/embeddedModels/resnet50.py
Original weight structure: a single fully connected layer at the end of the network; the weight shape is [1000, 2048], each element is torch.float32, 8.192 MB of data in total.
Experimental results:
RunDate | Model | InitPara | Dataset | Accuracy | Loss |
---|---|---|---|---|---|
20240805 | resnet-50 | resnet50-11ad3fa6.pth | ILSVRC2012_devkit_t12 | 80.12% | 1.4179 |
20240805 | resnet-50 | resenet50_embedding_8_32.pth | ILSVRC2012_devkit_t12 | 80.12% | 1.4178 |
20240805 | resnet-50 | resenet50_embedding_16_32.pth | ILSVRC2012_devkit_t12 | 80.12% | 1.4184 |
20240805 | resnet-50 | resenet50_embedding_20_32.pth | ILSVRC2012_devkit_t12 | 79.94% | 1.4146 |
20240805 | resnet-50 | resenet50_embedding_21_32.pth | ILSVRC2012_devkit_t12 | 79.59% | 1.4004 |
20240805 | resnet-50 | resenet50_embedding_22_32.pth | ILSVRC2012_devkit_t12 | 77.94% | 1.3310 |
20240805 | resnet-50 | resenet50_embedding_23_32.pth | ILSVRC2012_devkit_t12 | 72.30% | 1.4355 |
20240805 | resnet-50 | resenet50_embedding_24_32.pth | ILSVRC2012_devkit_t12 | 64.40% | 1.8136 |
The accuracy is exactly reproducible across runs, because inference with fixed parameters is a deterministic process.
python ./models/pretrainedModels/resnet101.py
python ./models/embeddedModels/resnet101.py
Original weight structure: a single fully connected layer at the end of the network; the weight shape is [1000, 2048], each element is torch.float32, 8.192 MB of data in total.
Experimental results:
RunDate | Model | InitPara | Dataset | Accuracy | Loss |
---|---|---|---|---|---|
20240805 | resnet-101 | resnet101-cd907fc2.pth | ILSVRC2012_devkit_t12 | 80.94% | 0.9221 |
20240805 | resnet-101 | resnet101_embedding_8_32.pth | ILSVRC2012_devkit_t12 | 80.94% | 0.9223 |
20240805 | resnet-101 | resnet101_embedding_16_32.pth | ILSVRC2012_devkit_t12 | 80.96% | 0.9224 |
20240805 | resnet-101 | resnet101_embedding_20_32.pth | ILSVRC2012_devkit_t12 | 80.76% | 0.9240 |
20240805 | resnet-101 | resnet101_embedding_21_32.pth | ILSVRC2012_devkit_t12 | 80.48% | 0.9253 |
20240805 | resnet-101 | resnet101_embedding_22_32.pth | ILSVRC2012_devkit_t12 | 79.07% | 0.9568 |
20240805 | resnet-101 | resnet101_embedding_23_32.pth | ILSVRC2012_devkit_t12 | 74.37% | 1.2267 |
20240805 | resnet-101 | resnet101_embedding_24_32.pth | ILSVRC2012_devkit_t12 | 68.66% | 1.6655 |
All performance tests in this section are run on the last fully connected layer: we flip all of the low-order bits of fcWeightTensor["classifier.6.weight"].
criterion = nn.CrossEntropyLoss()
python ./models/pretrainedModels/vgg11.py
python ./models/embeddedModels/vgg11.py
The embedding is performed in the last fully connected layer: [4096, 1000], 16.384 MB.
Experimental results:
RunDate | Model | InitPara | Dataset | Accuracy | Loss |
---|---|---|---|---|---|
20240805 | VGG-11 | vgg11-8a719046.pth | ILSVRC2012_devkit_t12 | 66.88% | 1.3540 |
20240805 | VGG-11 | vgg11_embdedding_8_32.pth | ILSVRC2012_devkit_t12 | 66.88% | 1.3542 |
20240805 | VGG-11 | vgg11_embdedding_16_32.pth | ILSVRC2012_devkit_t12 | 66.89% | 1.3549 |
20240805 | VGG-11 | vgg11_embdedding_20_32.pth | ILSVRC2012_devkit_t12 | 66.86% | 1.3564 |
20240805 | VGG-11 | vgg11_embdedding_21_32.pth | ILSVRC2012_devkit_t12 | 66.84% | 1.3589 |
20240805 | VGG-11 | vgg11_embdedding_22_32.pth | ILSVRC2012_devkit_t12 | 66.63% | 1.3714 |
20240805 | VGG-11 | vgg11_embdedding_23_32.pth | ILSVRC2012_devkit_t12 | 66.13% | 1.4388 |
20240805 | VGG-11 | vgg11_embdedding_24_32.pth | ILSVRC2012_devkit_t12 | 64.27% | 1.9315 |
python ./models/pretrainedModels/vgg13.py
python ./models/embeddedModels/vgg13.py
The embedding is performed in the last fully connected layer: [4096, 1000], 16.384 MB.
The experimental results are as follows:
RunDate | Model | InitPara | Dataset | Accuracy | Loss |
---|---|---|---|---|---|
20240805 | VGG-13 | vgg13-19584684.pth | ILSVRC2012_devkit_t12 | 68.14% | 1.3052 |
20240805 | VGG-13 | vgg13_embdedding_8_32.pth | ILSVRC2012_devkit_t12 | 68.14% | 1.3035 |
20240805 | VGG-13 | vgg13_embdedding_16_32.pth | ILSVRC2012_devkit_t12 | 68.14% | 1.3041 |
20240805 | VGG-13 | vgg13_embdedding_20_32.pth | ILSVRC2012_devkit_t12 | 68.15% | 1.3048 |
20240805 | VGG-13 | vgg13_embdedding_21_32.pth | ILSVRC2012_devkit_t12 | 67.96% | 1.3087 |
20240805 | VGG-13 | vgg13_embdedding_22_32.pth | ILSVRC2012_devkit_t12 | 67.84% | 1.3209 |
20240805 | VGG-13 | vgg13_embdedding_23_32.pth | ILSVRC2012_devkit_t12 | 67.25% | 1.3859 |
20240805 | VGG-13 | vgg13_embdedding_24_32.pth | ILSVRC2012_devkit_t12 | 65.42% | 1.8825 |
python ./models/pretrainedModels/vgg16.py
python ./models/embeddedModels/vgg16.py
The embedding is performed in the last fully connected layer: [4096, 1000], 16.384 MB.
Experimental results:
RunDate | Model | InitPara | Dataset | Accuracy | Loss |
---|---|---|---|---|---|
20240805 | VGG-16 | vgg16-397923af.pth | ILSVRC2012_devkit_t12 | 70.02% | 1.2210 |
20240805 | VGG-16 | vgg16_embdedding_8_32.pth | ILSVRC2012_devkit_t12 | 70.02% | 1.2230 |
20240805 | VGG-16 | vgg16_embdedding_16_32.pth | ILSVRC2012_devkit_t12 | 70.03% | 1.2218 |
20240805 | VGG-16 | vgg16_embdedding_20_32.pth | ILSVRC2012_devkit_t12 | 69.98% | 1.2225 |
20240805 | VGG-16 | vgg16_embdedding_21_32.pth | ILSVRC2012_devkit_t12 | 69.93% | 1.2267 |
20240805 | VGG-16 | vgg16_embdedding_22_32.pth | ILSVRC2012_devkit_t12 | 69.72% | 1.2424 |
20240805 | VGG-16 | vgg16_embdedding_23_32.pth | ILSVRC2012_devkit_t12 | 69.35% | 1.3138 |
20240805 | VGG-16 | vgg16_embdedding_24_32.pth | ILSVRC2012_devkit_t12 | 67.52% | 1.8072 |
python ./models/pretrainedModels/vgg19.py
python ./models/embeddedModels/vgg19.py
The embedding is performed in the last fully connected layer: [4096, 1000], 16.384 MB.
Experimental results:
RunDate | Model | InitPara | Dataset | Accuracy | Loss |
---|---|---|---|---|---|
20240805 | VGG-19 | vgg19-dcbb9e9d.pth | ILSVRC2012_devkit_t12 | 70.68% | 1.1922 |
20240805 | VGG-19 | vgg19_embedding_8_32.pth | ILSVRC2012_devkit_t12 | 70.68% | 1.1923 |
20240805 | VGG-19 | vgg19_embedding_16_32.pth | ILSVRC2012_devkit_t12 | 70.68% | 1.1922 |
20240805 | VGG-19 | vgg19_embedding_20_32.pth | ILSVRC2012_devkit_t12 | 70.68% | 1.1946 |
20240805 | VGG-19 | vgg19_embedding_21_32.pth | ILSVRC2012_devkit_t12 | 70.69% | 1.1967 |
20240805 | VGG-19 | vgg19_embedding_22_32.pth | ILSVRC2012_devkit_t12 | 70.44% | 1.2112 |
20240805 | VGG-19 | vgg19_embedding_23_32.pth | ILSVRC2012_devkit_t12 | 70.09% | 1.2840 |
20240805 | VGG-19 | vgg19_embedding_24_32.pth | ILSVRC2012_devkit_t12 | 68.19% | 1.7680 |
Testing the performance impact of embedding harmful information into the parameters of the first convolutional layer, using the ImageNet validation dataset.
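Embedding payload bytes into the low 8 bits of the first convolutional layer might be sketched as follows. The helpers are illustrative; the real pipeline also stores the chunk bookkeeping described earlier:

```python
import torch

def embed_bytes_(weight: torch.Tensor, payload: bytes) -> None:
    """Overwrite the low 8 bits of the first len(payload) float32
    weights with the payload bytes, in place."""
    flat = weight.view(torch.int32).view(-1)  # integer view of the same memory
    assert weight.dtype == torch.float32 and len(payload) <= flat.numel()
    data = torch.tensor(list(payload), dtype=torch.int32)
    flat[: len(payload)] = (flat[: len(payload)] & ~0xFF) | data

def extract_bytes(weight: torch.Tensor, n: int) -> bytes:
    """Recover n payload bytes from the low 8 bits of the weights."""
    flat = weight.view(torch.int32).view(-1)
    return bytes((flat[:n] & 0xFF).tolist())

w = torch.randn(64, 3, 7, 7)  # same shape as ResNet's conv1.weight
embed_bytes_(w, b"hello")
```

Only mantissa bits change, so the embedded weights stay finite and close to their original values.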
RunDate | Model | Para | Accuracy | Loss |
---|---|---|---|---|
20240806 | ResNet-18 | InitialPara | 67.27% | 1.3528 |
20240806 | ResNet-18 | 16_32 | 67.30% | 1.3527 |
20240806 | ResNet-18 | 19_32 | 67.22% | 1.3563 |
20240806 | ResNet-18 | 20_32 | 67.12% | 1.3591 |
20240806 | ResNet-18 | 21_32 | 66.22% | 1.4028 |
20240806 | ResNet-18 | 22_32 | 57.47% | 1.8530 |
20240806 | ResNet-18 | 23_32 | 35.39% | 3.3473 |
20240806 | ResNet-18 | 24_32 | 8.98% | 7.2026 |
RunDate | Model | Para | Accuracy | Loss |
---|---|---|---|---|
20240806 | ResNet-50 | InitialPara | 80.12% | 1.4179 |
20240806 | ResNet-50 | 16_32 | 80.13% | 1.4182 |
20240806 | ResNet-50 | 19_32 | 79.87% | 1.4331 |
20240806 | ResNet-50 | 20_32 | 78.82% | 1.4833 |
20240806 | ResNet-50 | 21_32 | 76.27% | 1.5902 |
20240806 | ResNet-50 | 22_32 | 65.63% | 2.0357 |
20240806 | ResNet-50 | 23_32 | 57.54% | 2.4152 |
20240806 | ResNet-50 | 24_32 | 9.07% | 10.0144 |
RunDate | Model | Para | Accuracy | Loss |
---|---|---|---|---|
20240806 | ResNet-101 | InitialPara | 80.94% | 0.9221 |
20240806 | ResNet-101 | 16_32 | 80.95% | 0.9221 |
20240806 | ResNet-101 | 19_32 | 80.87% | 0.9239 |
20240806 | ResNet-101 | 20_32 | 80.75% | 0.9309 |
20240806 | ResNet-101 | 21_32 | 79.41% | 0.9939 |
20240806 | ResNet-101 | 22_32 | 73.24% | 1.2881 |
20240806 | ResNet-101 | 23_32 | 62.81% | 1.8137 |
20240806 | ResNet-101 | 24_32 | 37.95% | 3.3804 |
RunDate | Model | Para | Accuracy | Loss |
---|---|---|---|---|
20240806 | VGG-11 | InitialPara | 66.88% | 1.3540 |
20240806 | VGG-11 | 16_32 | 66.91% | 1.3545 |
20240806 | VGG-11 | 19_32 | 66.63% | 1.3663 |
20240806 | VGG-11 | 20_32 | 64.29% | 1.4808 |
20240806 | VGG-11 | 21_32 | 60.45% | 1.6614 |
20240806 | VGG-11 | 22_32 | 44.65% | 2.5821 |
20240806 | VGG-11 | 23_32 | 33.32% | 3.3020 |
20240806 | VGG-11 | 24_32 | 3.93% | 8.6345 |
RunDate | Model | Para | Accuracy | Loss |
---|---|---|---|---|
20240806 | VGG-13 | InitialPara | 68.14% | 1.3052 |
20240806 | VGG-13 | 16_32 | 68.15% | 1.3049 |
20240806 | VGG-13 | 19_32 | 68.11% | 1.3051 |
20240806 | VGG-13 | 20_32 | 67.65% | 1.3266 |
20240806 | VGG-13 | 21_32 | 67.07% | 1.3522 |
20240806 | VGG-13 | 22_32 | 58.91% | 1.7601 |
20240806 | VGG-13 | 23_32 | 54.99% | 1.9757 |
20240806 | VGG-13 | 24_32 | 37.88% | 3.2278 |
RunDate | Model | Para | Accuracy | Loss |
---|---|---|---|---|
20240806 | VGG-16 | InitialPara | 70.02% | 1.2210 |
20240806 | VGG-16 | 16_32 | 70.01% | 1.2216 |
20240806 | VGG-16 | 19_32 | 69.91% | 1.2270 |
20240806 | VGG-16 | 20_32 | 69.77% | 1.2306 |
20240806 | VGG-16 | 21_32 | 69.13% | 1.2609 |
20240806 | VGG-16 | 22_32 | 61.51% | 1.6538 |
20240806 | VGG-16 | 23_32 | 56.89% | 1.8700 |
20240806 | VGG-16 | 24_32 | 24.46% | 4.8595 |
RunDate | Model | Para | Accuracy | Loss |
---|---|---|---|---|
20240806 | VGG-19 | InitialPara | 70.68% | 1.1922 |
20240806 | VGG-19 | 16_32 | 70.68% | 1.1920 |
20240806 | VGG-19 | 19_32 | 70.63% | 1.1955 |
20240806 | VGG-19 | 20_32 | 70.44% | 1.2059 |
20240806 | VGG-19 | 21_32 | 69.03% | 1.2730 |
20240806 | VGG-19 | 22_32 | 65.51% | 1.4517 |
20240806 | VGG-19 | 23_32 | 55.29% | 2.0323 |
20240806 | VGG-19 | 24_32 | 40.71% | 3.0027 |
Embedding test on the lowest bits of every high-dimensional parameter (dimension greater than 2):
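Selecting every parameter with dimension greater than 2 (i.e. the convolution kernels) and flipping their lowest bits might be sketched as below. The toy model stands in for the torchvision checkpoints used in the experiments:

```python
import torch
from torch import nn

def flip_lsb_of_high_dim_(state_dict, n: int = 1):
    """Flip the lowest n bits of every float32 parameter whose
    dimension is greater than 2, in place; return the touched names."""
    touched = []
    for name, t in state_dict.items():
        if t.dim() > 2 and t.dtype == torch.float32:
            bits = t.view(torch.int32)  # integer view of the same memory
            bits ^= (1 << n) - 1
            touched.append(name)
    return touched

# tiny stand-in; only the 4-D conv kernel qualifies
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Flatten(), nn.Linear(8, 4))
touched = flip_lsb_of_high_dim_(model.state_dict())
```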
RunDate | Model | Para | Accuracy | Loss |
---|---|---|---|---|
20240806 | ResNet-18 | InitialPara | 67.27% | 1.3528 |
20240806 | ResNet-18 | 16_32 | 67.30% | 1.3528 |
20240806 | ResNet-18 | 17_32 | 67.28% | 1.3539 |
20240806 | ResNet-18 | 18_32 | 67.20% | 1.3565 |
20240806 | ResNet-18 | 19_32 | 67.13% | 1.3631 |
20240806 | ResNet-18 | 20_32 | 66.26% | 1.3953 |
20240806 | ResNet-18 | 21_32 | 61.99% | 1.5967 |
20240806 | ResNet-18 | 22_32 | 37.74% | 3.1640 |
RunDate | Model | Para | Accuracy | Loss |
---|---|---|---|---|
20240806 | ResNet-50 | InitialPara | 80.12% | 1.4179 |
20240806 | ResNet-50 | 16_32 | 80.12% | 1.4190 |
20240806 | ResNet-50 | 17_32 | 80.16% | 1.4128 |
20240806 | ResNet-50 | 18_32 | 80.14% | 1.4116 |
20240806 | ResNet-50 | 19_32 | 79.16% | 1.4404 |
20240806 | ResNet-50 | 20_32 | 75.14% | 1.5994 |
20240806 | ResNet-50 | 21_32 | 55.08% | 2.5866 |
20240806 | ResNet-50 | 22_32 | 4.92% | 6.5583 |
RunDate | Model | Para | Accuracy | Loss |
---|---|---|---|---|
20240806 | ResNet-101 | InitialPara | 80.94% | 0.9221 |
20240806 | ResNet-101 | 16_32 | 80.92% | 0.9227 |
20240806 | ResNet-101 | 17_32 | 80.92% | 0.9248 |
20240806 | ResNet-101 | 18_32 | 80.78% | 0.9305 |
20240806 | ResNet-101 | 19_32 | 80.65% | 0.9432 |
20240806 | ResNet-101 | 20_32 | 79.86% | 0.9712 |
20240806 | ResNet-101 | 21_32 | 72.17% | 1.3286 |
20240806 | ResNet-101 | 22_32 | 19.51% | 5.3451 |
RunDate | Model | Para | Accuracy | Loss |
---|---|---|---|---|
20240806 | VGG-11 | InitialPara | 66.88% | 1.3540 |
20240806 | VGG-11 | 16_32 | 66.93% | 1.3546 |
20240806 | VGG-11 | 17_32 | 66.88% | 1.3546 |
20240806 | VGG-11 | 18_32 | 66.79% | 1.3593 |
20240806 | VGG-11 | 19_32 | 66.65% | 1.3688 |
20240806 | VGG-11 | 20_32 | 64.04% | 1.4937 |
20240806 | VGG-11 | 21_32 | 59.72% | 1.7133 |
20240806 | VGG-11 | 22_32 | 42.89% | 3.0529 |
RunDate | Model | Para | Accuracy | Loss |
---|---|---|---|---|
20240806 | VGG-13 | InitialPara | 68.14% | 1.3052 |
20240806 | VGG-13 | 16_32 | 68.13% | 1.3044 |
20240806 | VGG-13 | 17_32 | 68.14% | 1.3047 |
20240806 | VGG-13 | 18_32 | 68.10% | 1.3060 |
20240806 | VGG-13 | 19_32 | 67.97% | 1.3081 |
20240806 | VGG-13 | 20_32 | 67.52% | 1.3351 |
20240806 | VGG-13 | 21_32 | 66.20% | 1.4033 |
20240806 | VGG-13 | 22_32 | 54.69% | 2.3744 |
RunDate | Model | Para | Accuracy | Loss |
---|---|---|---|---|
20240806 | VGG-16 | InitialPara | 70.02% | 1.2210 |
20240806 | VGG-16 | 16_32 | 69.98% | 1.2226 |
20240806 | VGG-16 | 17_32 | 69.96% | 1.2221 |
20240806 | VGG-16 | 18_32 | 69.98% | 1.2240 |
20240806 | VGG-16 | 19_32 | 69.82% | 1.2306 |
20240806 | VGG-16 | 20_32 | 69.57% | 1.2415 |
20240806 | VGG-16 | 21_32 | 68.36% | 1.3269 |
20240806 | VGG-16 | 22_32 | 56.11% | 2.4337 |
RunDate | Model | Para | Accuracy | Loss |
---|---|---|---|---|
20240806 | VGG-19 | InitialPara | 70.68% | 1.1922 |
20240806 | VGG-19 | 16_32 | 70.67% | 1.1925 |
20240806 | VGG-19 | 17_32 | 70.66% | 1.1936 |
20240806 | VGG-19 | 18_32 | 70.70% | 1.1930 |
20240806 | VGG-19 | 19_32 | 70.65% | 1.1992 |
20240806 | VGG-19 | 20_32 | 70.40% | 1.2225 |
20240806 | VGG-19 | 21_32 | 68.26% | 1.3580 |
20240806 | VGG-19 | 22_32 | 62.23% | 1.9975 |
When retraining (fine-tuning) a deep neural network, which layers to train depends on several factors: the similarity of the datasets, the data scale, the available compute, and the demands of the specific task.

What to retrain:
- **Fully connected layers:** Retraining the fully connected layers is the most common practice, since they relate directly to the classification task. The last layer of the pretrained model is usually replaced with a new fully connected layer matching the number of output classes of the new task.
- **High-level convolutional layers:** If the new task's data differs somewhat from the pretraining data, the last few convolutional layers are usually retrained. The features these layers learn are task-specific, and retraining them lets the model adapt better to the new data.
- **All layers:** If the new task's data differs drastically from the original data and the dataset is large enough, all layers may need retraining. This costs compute, but offers the most flexibility.
Factors to consider when retraining the low-level convolutional layers:
Low-level convolutional layers usually learn generic image features such as edges, textures, and color variations, which are shared across many vision tasks. Whether to retrain them should therefore be judged case by case.
- **Dataset similarity:** Similar datasets: if the new task's dataset is very similar to the pretraining dataset, the low-level convolutional layers generally do not need retraining, since they already capture many generic features well. Different datasets: if the new dataset differs substantially from the pretraining data, for example moving from natural images to medical images, or from photographs to hand-drawn illustrations, retraining the low-level convolutional layers helps the model adapt to the new data.
- **Dataset size:** Small datasets: the low-level convolutional layers are usually frozen to reduce the risk of overfitting and to reuse the pretrained features. Large datasets: retraining the low-level convolutional layers can be considered, especially when the new task differs considerably from the pretraining task.
- **Compute resources:** Limited compute: freezing the low-level convolutional layers reduces cost and training time, since these layers are typically compute-intensive. Ample compute: all layers, including the low-level convolutional ones, can be retrained to make full use of the resources and improve performance.
Hamming distance
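The per-parameter change statistics tabulated in the sections below are fractions of parameters at each Hamming distance over the top bits of the float32 pattern. A sketch of how such a histogram can be computed (our helper, not the original analysis script):

```python
import torch

def hamming_histogram(before: torch.Tensor, after: torch.Tensor, top_bits: int = 8):
    """Distribution of per-parameter Hamming distances over the top
    `top_bits` bits of the float32 patterns (fractions summing to 1)."""
    shift = 32 - top_bits
    mask = (1 << top_bits) - 1
    diff = ((before.view(torch.int32) ^ after.view(torch.int32)) >> shift) & mask
    diff = diff.flatten()
    dist = torch.zeros_like(diff)
    for k in range(top_bits):          # per-element popcount
        dist += (diff >> k) & 1
    hist = torch.bincount(dist.long(), minlength=top_bits + 1).float()
    return hist / hist.sum()
```

For top_bits=8 this yields the 9-bin vectors (distance 0 through 8) printed throughout this section.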
**Task transfer:** Transfer learning is widely applied; it is designed to transfer from large, broad datasets to specific, smaller ones, or between similar tasks. Transferring from ImageNet to a domain-specific image dataset is a typical use of pretrained deep learning models; these models can then be transferred to concrete small datasets such as medical image recognition or satellite image analysis.
Example target datasets: medical imaging datasets (such as chest X-ray images for pneumonia detection), vehicle recognition, and animal recognition.
From ImageNet to medical imaging
When transferring from ImageNet to a medical imaging dataset, how much the parameters change depends on several factors, including the difference between the tasks, the difference in data characteristics, and the chosen transfer learning strategy. The main considerations are:
- Differences in data characteristics. Image content and context: ImageNet contains diverse natural images covering a broad range of categories from animals to everyday objects, while medical images (CT, MRI, or X-ray) mainly depict internal human anatomy and are far more specialized in structure and content. Image format and processing: medical images may be grayscale, with resolution and contrast unlike ordinary color photographs, which must be handled during preprocessing.
- Adjustments to the model architecture. Feature-extraction layers: for models pretrained on ImageNet, such as VGG and ResNet, the convolutional layers usually need no major modification, since they extract generic features from images. Classification or decision layers: the final fully connected or classification layers often need major adjustment or complete replacement to match the output requirements of the medical task (disease classification, anomaly detection, and so on).
- Fine-tuning and retraining. Degree of fine-tuning: although the low- and mid-level feature extractors may remain relatively stable, the higher or task-specific layers usually need fine-tuning for the particular medical task. Training strategy: one can freeze the early convolutional layers and fine-tune only the higher layers, or carefully tune the whole network on the medical image data.
- Training data volume and annotation. Data volume: medical image datasets are usually much smaller than ImageNet, which can limit large-scale retraining of deep networks. Annotation: labeling medical data requires expert knowledge and is expensive, which can affect both the quality and the quantity of the training data.
- Summary: in transfer learning from ImageNet to medical imaging, the parameter changes can range from small (fine-tuning only the last few layers) to large (completely redesigning the classification layers and deeper structure). The extent depends on the complexity of the target task and the required accuracy. Transfer learning strategies generally aim to exploit the generic visual features learned on ImageNet, reducing the amount of training data needed in the medical domain while optimizing performance on the specific medical task.
Test how each model's parameters change when transferring from ImageNet-1000 to the CIFAR-100 dataset.
The experiment covers three regimes, underfitting, fitting, and overfitting, to explore how the parameters change, and also examines how the parameters of each layer change.
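One simple per-layer summary of parameter drift during fine-tuning is the fraction of parameters whose high bits changed. An illustrative helper (not the exact analysis code):

```python
import torch

def changed_fraction(before: dict, after: dict, top_bits: int = 7):
    """Per layer: fraction of float32 parameters whose top `top_bits`
    bits changed between the two state dicts."""
    shift = 32 - top_bits
    out = {}
    for name, w0 in before.items():
        w1 = after[name]
        if w0.dtype != torch.float32:
            continue
        b0 = w0.view(torch.int32) >> shift
        b1 = w1.view(torch.int32) >> shift
        out[name] = float((b0 != b1).float().mean())
    return out
```

The `before` dict would be a snapshot of the pretrained ImageNet weights, and `after` the state dict following fine-tuning on the target dataset.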
Network location: conv1
# 7bit
[9.31122449e-01 2.95493197e-02 5.95238095e-03 2.23214286e-02
2.12585034e-04 1.38180272e-03 4.67687075e-03 4.78316327e-03
0.00000000e+00]
[9.31122449e-01 2.95493197e-02 5.95238095e-03 2.23214286e-02
2.12585034e-04 1.38180272e-03 4.67687075e-03 4.78316327e-03
0.00000000e+00]
[9.31122449e-01 2.95493197e-02 5.95238095e-03 2.23214286e-02
2.12585034e-04 1.38180272e-03 4.67687075e-03 4.78316327e-03
0.00000000e+00]
# 8bit
[8.83503401e-01 5.34651361e-02 2.49787415e-02 5.84608844e-03
2.12585034e-02 6.37755102e-04 2.97619048e-03 4.67687075e-03
2.65731293e-03]
[8.83503401e-01 5.34651361e-02 2.49787415e-02 5.84608844e-03
2.12585034e-02 6.37755102e-04 2.97619048e-03 4.67687075e-03
2.65731293e-03]
[8.83503401e-01 5.34651361e-02 2.49787415e-02 5.84608844e-03
2.12585034e-02 6.37755102e-04 2.97619048e-03 4.67687075e-03
2.65731293e-03]
Epoch | 7bit | 8bit |
---|---|---|
epoch=1 | 0.9889455782312926 | 0.9677933673469387 |
epoch=10 | 0.9802295918367347 | 0.9540816326530612 |
epoch=20 | 0.9788477891156463 | 0.951530612244898 |
Network location: layer1[0].conv1
# 7bit
[9.09505208e-01 3.89268663e-02 7.05295139e-03 2.95952691e-02
5.69661458e-04 3.63498264e-03 6.13064236e-03 4.58441840e-03
0.00000000e+00]
[9.09505208e-01 3.89268663e-02 7.05295139e-03 2.95952691e-02
5.69661458e-04 3.63498264e-03 6.13064236e-03 4.58441840e-03
0.00000000e+00]
[9.09505208e-01 3.89268663e-02 7.05295139e-03 2.95952691e-02
5.69661458e-04 3.63498264e-03 6.13064236e-03 4.58441840e-03
0.00000000e+00]
# 8bit
[0.85356988 0.06382921 0.0343967 0.00558811 0.02794054 0.00203451
0.00520833 0.00488281 0.00254991]
[0.85356988 0.06382921 0.0343967 0.00558811 0.02794054 0.00203451
0.00520833 0.00488281 0.00254991]
[0.85356988 0.06382921 0.0343967 0.00558811 0.02794054 0.00203451
0.00520833 0.00488281 0.00254991]
Epoch | 7bit | 8bit |
---|---|---|
epoch = 1 | 0.9850802951388888 | 0.9573838975694443
epoch = 10 | 0.9707573784722222 | 0.930908203125
epoch = 20 | 0.9690212673611113 | 0.9275444878472223
Network location: layer1[0].conv2
# 7bit
[9.02153863e-01 4.24804688e-02 8.19227431e-03 3.05718316e-02
5.69661458e-04 4.58441840e-03 6.18489583e-03 5.26258681e-03
0.00000000e+00]
[9.02153863e-01 4.24804688e-02 8.19227431e-03 3.05718316e-02
5.69661458e-04 4.58441840e-03 6.18489583e-03 5.26258681e-03
0.00000000e+00]
[9.02153863e-01 4.24804688e-02 8.19227431e-03 3.05718316e-02
5.69661458e-04 4.58441840e-03 6.18489583e-03 5.26258681e-03
0.00000000e+00]
# 8bit
[0.85077582 0.06792535 0.03586155 0.00459798 0.02746582 0.00280762
0.00391981 0.00416395 0.0024821 ]
[0.85077582 0.06792535 0.03586155 0.00459798 0.02746582 0.00280762
0.00391981 0.00416395 0.0024821 ]
[0.85077582 0.06792535 0.03586155 0.00459798 0.02746582 0.00280762
0.00391981 0.00416395 0.0024821 ]
Epoch | 7bit | 8bit |
---|---|---|
epoch = 1 | 0.9833984375 | 0.9591606987847222
epoch = 10 | 0.9629991319444445 | 0.9101426866319444
epoch = 20 | 0.958740234375 | 0.9027777777777778
Network location: layer2[0].conv1
# 7bit
[9.10685221e-01 4.03103299e-02 7.64973958e-03 2.76014540e-02
7.18858507e-04 4.51660156e-03 3.78417969e-03 4.73361545e-03
0.00000000e+00]
[9.10685221e-01 4.03103299e-02 7.64973958e-03 2.76014540e-02
7.18858507e-04 4.51660156e-03 3.78417969e-03 4.73361545e-03
0.00000000e+00]
[9.10685221e-01 4.03103299e-02 7.64973958e-03 2.76014540e-02
7.18858507e-04 4.51660156e-03 3.78417969e-03 4.73361545e-03
0.00000000e+00]
# 8bit
[0.83897569 0.07131619 0.03822157 0.00575087 0.02943251 0.00263129
0.00550673 0.00569661 0.00246853]
[0.83897569 0.07131619 0.03822157 0.00575087 0.02943251 0.00263129
0.00550673 0.00569661 0.00246853]
[0.83897569 0.07131619 0.03822157 0.00575087 0.02943251 0.00263129
0.00550673 0.00569661 0.00246853]
Epoch | 7bit | 8bit |
---|---|---|
epoch = 1 | 0.9862467447916667 | 0.9542643229166667
epoch = 10 | 0.9640570746527778 | 0.912082248263889
epoch = 20 | 0.9599609375 | 0.906277126736111
Network location: layer2[0].conv2
# 7bit
[9.22980414e-01 3.42203776e-02 9.06711155e-03 2.28339301e-02
7.93457031e-04 3.94015842e-03 2.62451172e-03 3.54003906e-03
0.00000000e+00]
[9.22980414e-01 3.42203776e-02 9.06711155e-03 2.28339301e-02
7.93457031e-04 3.94015842e-03 2.62451172e-03 3.54003906e-03
0.00000000e+00]
[9.22980414e-01 3.42203776e-02 9.06711155e-03 2.28339301e-02
7.93457031e-04 3.94015842e-03 2.62451172e-03 3.54003906e-03
0.00000000e+00]
# 8bit
[0.87089708 0.05729167 0.03350152 0.00495741 0.02283393 0.00245497
0.00316026 0.00310601 0.00179715]
[0.87089708 0.05729167 0.03350152 0.00495741 0.02283393 0.00245497
0.00316026 0.00310601 0.00179715]
[0.87089708 0.05729167 0.03350152 0.00495741 0.02283393 0.00245497
0.00316026 0.00310601 0.00179715]
Epoch | 7bit | 8bit |
---|---|---|
epoch = 1 | 0.989101833767361 | 0.9666476779513888
epoch = 10 | 0.9629991319444443 | 0.9047309027777777
epoch = 20 | 0.9584282769097221 | 0.895107693142361
Network location (last convolutional layer): layer4[1].conv2
# 7bit
[0.85805427 0.05552928 0.01861615 0.04222361 0.00159497 0.00921419
0.00648287 0.00828467 0. ] 0.9744233025444878
[0.85805427 0.05552928 0.01861615 0.04222361 0.00159497 0.00921419
0.00648287 0.00828467 0. ] 0.9466260274251302
[0.85805427 0.05552928 0.01861615 0.04222361 0.00159497 0.00921419
0.00648287 0.00828467 0. ] 0.938822004530165
# 8bit
[0.77042262 0.0992224 0.05333328 0.01012039 0.04211129 0.00567542
0.00753487 0.00728692 0.00429281] 0.9330986870659722
[0.77042262 0.0992224 0.05333328 0.01012039 0.04211129 0.00567542
0.00753487 0.00728692 0.00429281] 0.8663889567057291
[0.77042262 0.0992224 0.05333328 0.01012039 0.04211129 0.00567542
0.00753487 0.00728692 0.00429281] 0.8497200012207031
Epoch | 7bit | 8bit |
---|---|---|
epoch = 1 | 0.9744233025444878 | 0.9330986870659722
epoch = 10 | 0.9466260274251302 | 0.8663889567057291
epoch = 20 | 0.938822004530165 | 0.8497200012207031
Network location (first convolutional layer): conv1
# 7bit
[0.89423895 0.03900935 0.0113733 0.0355017 0.00255102 0.00340136
0.00669643 0.00722789 0. ] 0.9801232993197279
[0.89423895 0.03900935 0.0113733 0.0355017 0.00255102 0.00340136
0.00669643 0.00722789 0. ] 0.9706632653061223
[0.89423895 0.03900935 0.0113733 0.0355017 0.00255102 0.00340136
0.00669643 0.00722789 0. ] 0.9680059523809523
# 8bit
[0.83269558 0.06877126 0.03380102 0.01796344 0.02816752 0.00255102
0.00520833 0.0071216 0.00372024] 0.9532312925170068
[0.83269558 0.06877126 0.03380102 0.01796344 0.02816752 0.00255102
0.00520833 0.0071216 0.00372024] 0.9448341836734694
[0.83269558 0.06877126 0.03380102 0.01796344 0.02816752 0.00255102
0.00520833 0.0071216 0.00372024] 0.9430272108843538
Epoch | 7bit | 8bit |
---|---|---|
epoch=1 | 0.9801232993197279 | 0.9532312925170068 |
epoch=10 | 0.9706632653061223 | 0.9448341836734694 |
epoch=20 | 0.9680059523809523 | 0.9430272108843538 |
Network location (middle): layer2.0.conv2
# 7bit
[9.57302517e-01 1.90972222e-02 6.80881076e-03 1.12169054e-02
6.44259983e-04 1.51909722e-03 1.68185764e-03 1.72932943e-03
0.00000000e+00] 0.9944254557291666
[9.57302517e-01 1.90972222e-02 6.80881076e-03 1.12169054e-02
6.44259983e-04 1.51909722e-03 1.68185764e-03 1.72932943e-03
0.00000000e+00] 0.985595703125
[9.57302517e-01 1.90972222e-02 6.80881076e-03 1.12169054e-02
6.44259983e-04 1.51909722e-03 1.68185764e-03 1.72932943e-03
0.00000000e+00] 0.9828762478298612
# 8bit
[9.27435981e-01 3.16636827e-02 1.99924045e-02 4.42843967e-03
1.12372504e-02 1.09185113e-03 1.66151259e-03 1.64116753e-03
8.47710503e-04] 0.9835205078124999
[9.27435981e-01 3.16636827e-02 1.99924045e-02 4.42843967e-03
1.12372504e-02 1.09185113e-03 1.66151259e-03 1.64116753e-03
8.47710503e-04] 0.958530002170139
[9.27435981e-01 3.16636827e-02 1.99924045e-02 4.42843967e-03
1.12372504e-02 1.09185113e-03 1.66151259e-03 1.64116753e-03
8.47710503e-04] 0.9520806206597222
Epoch | 7bit | 8bit |
---|---|---|
epoch=1 | 0.9944254557291666 | 0.9835205078124999 |
epoch=10 | 0.985595703125 | 0.958530002170139 |
epoch=20 | 0.9828762478298612 | 0.9520806206597222 |
Network location (last convolutional layer): layer4.2.conv2
# 7bit
[0.94582918 0.02169079 0.00946638 0.01477814 0.00126775 0.00181156
0.00257153 0.00258467 0. ] 0.9917644924587674
[0.94582918 0.02169079 0.00946638 0.01477814 0.00126775 0.00181156
0.00257153 0.00258467 0. ] 0.9833496941460502
[0.94582918 0.02169079 0.00946638 0.01477814 0.00126775 0.00181156
0.00257153 0.00258467 0. ] 0.981243133544922
# 8bit
[0.90923987 0.0398869 0.02129491 0.00730726 0.01469252 0.00155301
0.00215573 0.0024821 0.0013877 ] 0.9777289496527779
[0.90923987 0.0398869 0.02129491 0.00730726 0.01469252 0.00155301
0.00215573 0.0024821 0.0013877 ] 0.955349392361111
[0.90923987 0.0398869 0.02129491 0.00730726 0.01469252 0.00155301
0.00215573 0.0024821 0.0013877 ] 0.9500651889377171
Epoch | 7bit | 8bit |
---|---|---|
epoch=1 | 0.9917644924587674 | 0.9777289496527779 |
epoch=10 | 0.9833496941460502 | 0.955349392361111 |
epoch=20 | 0.981243133544922 | 0.9500651889377171 |
Network location (first convolutional layer): conv1.weight
# 7bit
[0.90136054 0.03528912 0.01318027 0.03093112 0.0014881 0.00350765
0.00690901 0.00733418 0. ] 0.9807610544217686
[0.90136054 0.03528912 0.01318027 0.03093112 0.0014881 0.00350765
0.00690901 0.00733418 0. ] 0.9722576530612246
[0.90136054 0.03528912 0.01318027 0.03093112 0.0014881 0.00350765
0.00690901 0.00733418 0. ] 0.9720450680272109
# 8bit
[0.84470663 0.06164966 0.03305697 0.01838861 0.02359694 0.00297619
0.00510204 0.00627126 0.0042517 ] 0.9578018707482994
[0.84470663 0.06164966 0.03305697 0.01838861 0.02359694 0.00297619
0.00510204 0.00627126 0.0042517 ] 0.9486607142857144
[0.84470663 0.06164966 0.03305697 0.01838861 0.02359694 0.00297619
0.00510204 0.00627126 0.0042517 ] 0.9482355442176872
Epoch | 7bit | 8bit |
---|---|---|
epoch=1 | 0.9807610544217686 | 0.9578018707482994 |
epoch=10 | 0.9722576530612246 | 0.9486607142857144 |
epoch=20 | 0.9720450680272109 | 0.9482355442176872 |
Middle layer of the network: layer3.4.conv2.weight
# 7bit
[9.68443129e-01 1.29089355e-02 6.91392687e-03 7.86505805e-03
1.12575955e-03 1.84800890e-04 1.27326118e-03 1.28512912e-03
0.00000000e+00] 0.9961310492621528
[9.68443129e-01 1.29089355e-02 6.91392687e-03 7.86505805e-03
1.12575955e-03 1.84800890e-04 1.27326118e-03 1.28512912e-03
0.00000000e+00] 0.98828125
[9.68443129e-01 1.29089355e-02 6.91392687e-03 7.86505805e-03
1.12575955e-03 1.84800890e-04 1.27326118e-03 1.28512912e-03
0.00000000e+00] 0.9861297607421874
# 8bit
[9.45312500e-01 2.45632595e-02 1.25952827e-02 6.40530056e-03
7.84301758e-03 6.15437826e-04 7.78198242e-04 1.20544434e-03
6.81559245e-04] 0.9888763427734376
[9.45312500e-01 2.45632595e-02 1.25952827e-02 6.40530056e-03
7.84301758e-03 6.15437826e-04 7.78198242e-04 1.20544434e-03
6.81559245e-04] 0.9668002658420138
[9.45312500e-01 2.45632595e-02 1.25952827e-02 6.40530056e-03
7.84301758e-03 6.15437826e-04 7.78198242e-04 1.20544434e-03
6.81559245e-04] 0.9606238471137153
Epoch | 7bit | 8bit |
---|---|---|
epoch=1 | 0.9961310492621528 | 0.9888763427734376 |
epoch=10 | 0.98828125 | 0.9668002658420138 |
epoch=20 | 0.9861297607421874 | 0.9606238471137153 |
Last convolutional layer of the network: layer4.2.conv3.weight
# 7bit
[9.55640793e-01 1.92956924e-02 6.96563721e-03 1.20620728e-02
8.03947449e-04 1.57070160e-03 1.72042847e-03 1.94072723e-03
0.00000000e+00] 0.9939641952514648
[9.55640793e-01 1.92956924e-02 6.96563721e-03 1.20620728e-02
8.03947449e-04 1.57070160e-03 1.72042847e-03 1.94072723e-03
0.00000000e+00] 0.9879722595214844
[9.55640793e-01 1.92956924e-02 6.96563721e-03 1.20620728e-02
8.03947449e-04 1.57070160e-03 1.72042847e-03 1.94072723e-03
0.00000000e+00] 0.9861278533935547
# 8bit
[0.92431259 0.03347301 0.01967716 0.00490284 0.01203156 0.00120449
0.00160599 0.00177288 0.00101948] 0.982365608215332
[0.92431259 0.03347301 0.01967716 0.00490284 0.01203156 0.00120449
0.00160599 0.00177288 0.00101948] 0.9650745391845703
[0.92431259 0.03347301 0.01967716 0.00490284 0.01203156 0.00120449
0.00160599 0.00177288 0.00101948] 0.9602394104003906
Epoch | 7bit | 8bit |
---|---|---|
epoch=1 | 0.9939641952514648 | 0.982365608215332 |
epoch=10 | 0.9879722595214844 | 0.9650745391845703 |
epoch=20 | 0.9861278533935547 | 0.9602394104003906 |
First convolutional layer: features.0.weight
# 7bit
[9.68171296e-01 1.67824074e-02 6.94444444e-03 5.20833333e-03
0.00000000e+00 0.00000000e+00 5.78703704e-04 2.31481481e-03
0.00000000e+00] 0.9971064814814815
[9.68171296e-01 1.67824074e-02 6.94444444e-03 5.20833333e-03
0.00000000e+00 0.00000000e+00 5.78703704e-04 2.31481481e-03
0.00000000e+00] 0.9895833333333333
[9.68171296e-01 1.67824074e-02 6.94444444e-03 5.20833333e-03
0.00000000e+00 0.00000000e+00 5.78703704e-04 2.31481481e-03
0.00000000e+00] 0.9866898148148149
# 8bit
[0.94039352 0.02951389 0.0150463 0.00694444 0.00520833 0.
0. 0.00173611 0.00115741] 0.991898148148148
[0.94039352 0.02951389 0.0150463 0.00694444 0.00520833 0.
0. 0.00173611 0.00115741] 0.982638888888889
[0.94039352 0.02951389 0.0150463 0.00694444 0.00520833 0.
0. 0.00173611 0.00115741] 0.9820601851851852
Epoch | 7bit | 8bit |
---|---|---|
epoch=1 | 0.9971064814814815 | 0.991898148148148 |
epoch=10 | 0.9895833333333333 | 0.982638888888889 |
epoch=20 | 0.9866898148148149 | 0.9820601851851852 |
Middle convolutional layer: features.18.weight
# 7bit
[0.94266595 0.02321413 0.01026874 0.01581573 0.00111389 0.00188955
0.00245327 0.00257874 0. ] 0.9919645521375867
[0.94266595 0.02321413 0.01026874 0.01581573 0.00111389 0.00188955
0.00245327 0.00257874 0. ] 0.9684257507324218
[0.94266595 0.02321413 0.01026874 0.01581573 0.00111389 0.00188955
0.00245327 0.00257874 0. ] 0.9598155551486544
# 8bit
[0.90263706 0.04272758 0.02410338 0.00727887 0.01579242 0.0015263
0.00216717 0.00236215 0.00140508] 0.9767468770345052
[0.90263706 0.04272758 0.02410338 0.00727887 0.01579242 0.0015263
0.00216717 0.00236215 0.00140508] 0.9138200547960069
[0.90263706 0.04272758 0.02410338 0.00727887 0.01579242 0.0015263
0.00216717 0.00236215 0.00140508] 0.893851810031467
Epoch | 7bit | 8bit |
---|---|---|
epoch=1 | 0.9919645521375867 | 0.9767468770345052 |
epoch=10 | 0.9684257507324218 | 0.9138200547960069 |
epoch=20 | 0.9598155551486544 | 0.893851810031467 |
Second-to-last fully connected layer: classifier.3.weight
# 7bit
[0.92424482 0.02901179 0.01598012 0.01896697 0.00189114 0.00247723
0.00368488 0.00374305 0. ] 0.9882037043571472
[0.92424482 0.02901179 0.01598012 0.01896697 0.00189114 0.00247723
0.00368488 0.00374305 0. ] 0.9683631062507629
[0.92424482 0.02901179 0.01598012 0.01896697 0.00189114 0.00247723
0.00368488 0.00374305 0. ] 0.9615017771720886
# 8bit
Epoch | 7bit | 8bit |
---|---|---|
epoch=1 | 0.9882037043571472 | |
epoch=10 | 0.9683631062507629 | |
epoch=20 | 0.9615017771720886 | |
First convolutional layer: features.0.weight
Epoch | 7bit | 8bit |
---|---|---|
epoch=1 | | |
epoch=10 | | |
epoch=20 | | |
Transfer from ImageNet to Oxford Pets (a 37-class pet dataset); some of its classes are also present in ImageNet. Small domain gap.
First convolutional layer: conv1.weight
# 7bit
[0.82706207 0.06515731 0.02338435 0.04283588 0.00946003 0.00797194
0.00977891 0.01434949 0. ]
# 8bit
[0.74340986 0.09460034 0.05729167 0.03805272 0.02965561 0.00977891
0.00744048 0.01318027 0.00659014] 0.9333545918367347
Epoch | 7bit | 8bit |
---|---|---|
epoch=20 | 0.9584396258503403 | 0.9333545918367347
Transfer from ImageNet to the FGVC-Aircraft dataset (aircraft recognition). Large domain gap.
First convolutional layer: conv1.weight
# 7bit epoch=10
[9.41326531e-01 2.69982993e-02 4.67687075e-03 1.77508503e-02
2.12585034e-04 1.70068027e-03 4.14540816e-03 3.18877551e-03
0.00000000e+00] 0.9907525510204082
# 8bit epoch=10
[8.94026361e-01 5.37840136e-02 2.16836735e-02 4.25170068e-03
1.72193878e-02 5.31462585e-04 3.40136054e-03 3.61394558e-03
1.48809524e-03] 0.9737457482993197
# 7bit epoch=20
[9.41326531e-01 2.69982993e-02 4.67687075e-03 1.77508503e-02
2.12585034e-04 1.70068027e-03 4.14540816e-03 3.18877551e-03
0.00000000e+00] 0.9854379251700681
# 8bit epoch=20
[8.94026361e-01 5.37840136e-02 2.16836735e-02 4.25170068e-03
1.72193878e-02 5.31462585e-04 3.40136054e-03 3.61394558e-03
1.48809524e-03] 0.9615221088435374
# 7bit epoch=30
[9.41326531e-01 2.69982993e-02 4.67687075e-03 1.77508503e-02
2.12585034e-04 1.70068027e-03 4.14540816e-03 3.18877551e-03
0.00000000e+00] 0.9828869047619048
# 8bit epoch=30
[8.94026361e-01 5.37840136e-02 2.16836735e-02 4.25170068e-03
1.72193878e-02 5.31462585e-04 3.40136054e-03 3.61394558e-03
1.48809524e-03] 0.9582270408163265
| | 7bit | 8bit |
|---|---|---|
| epoch=10 (56.38% accuracy; trained from scratch, improves slowly) | 0.9907525510204082 | 0.9737457482993197 |
| epoch=20 (68.95% accuracy) | 0.9854379251700681 | 0.9615221088435374 |
| epoch=30 (67.54% accuracy) | 0.9828869047619048 | 0.9582270408163265 |
Transfer from ImageNet to the GTSRB dataset (traffic sign recognition). Large domain gap.
First convolutional layer: `conv1.weight`
# 7bit epoch=5
[0.88892432 0.04230442 0.01062925 0.03390731 0.00233844 0.00403912
0.00903486 0.00882228 0. ] 0.975765306122449
# 8bit epoch=5
[0.81643282 0.07897534 0.03773384 0.01828231 0.02540391 0.00340136
0.00605867 0.00903486 0.00467687] 0.9514243197278911
# 7bit epoch=10
[0.88605442 0.04272959 0.01179847 0.03773384 0.00318878 0.00393282
0.00775935 0.00680272 0. ] 0.9783163265306122
# 8bit epoch=10
[0.81792092 0.07610544 0.03752126 0.01977041 0.02848639 0.00382653
0.00531463 0.00797194 0.00308248] 0.9513180272108843
| | 7bit | 8bit |
|---|---|---|
| epoch=5 (96.87% accuracy) | 0.975765306122449 | 0.9514243197278911 |
| epoch=10 (97.24% accuracy) | 0.9783163265306122 | 0.9513180272108843 |
Transfer from ImageNet to the PCAM dataset (medical imaging). Huge difference in the input images.
PCAM contains 327,680 color images (96 x 96 px) extracted from histopathologic scans of lymph node sections, each annotated with a binary label indicating the presence of metastatic tissue. Fundamental machine-learning advances are mostly evaluated on straightforward natural-image classification datasets, while medical imaging is becoming one of ML's major applications, so it deserves a place among ML benchmark datasets, both to challenge future work and to steer development in a direction that benefits the field.
First convolutional layer: `conv1.weight`
# high 7bit epoch=5
# low 7bit epoch=5
# 8bit epoch=5
# high 7bit epoch=10
[0.83216412 0.05644133 0.02008929 0.04804422 0.01094813 0.00722789
0.01020408 0.01488095 0. ] 0.9567389455782312
# low 7bit epoch=10
[0.74202806 0.09863946 0.05155187 0.04177296 0.03762755 0.00892857
0.01105442 0.00839711 0. ] 0.9339923469387756
# 8bit epoch=10
[0.74202806 0.09863946 0.05155187 0.0368835 0.03252551 0.0099915
0.00892857 0.01105442 0.00839711] 0.9291028911564627
| | high 7bit | low 7bit | 8bit |
|---|---|---|---|
| epoch=10 (87.17%, overfitting) | 0.9567389455782312 | 0.9339923469387756 | 0.9291028911564627 |
Embedding in the exponent bits of ResNet50's parameters.
Transfer training on PCAM:
Replace only the first convolutional layer: 69.98%
Replace the first two convolutional layers: 68.63%
Replace the first three convolutional layers: 68.70%
Replace the first four convolutional layers: 69.44%
Replace convolutional layers 2, 3, and 4: 79.04%
Transfer training on FGVCAircraft:
Replace only the first convolutional layer: 79.68%
Replace the first two convolutional layers: 79.56%
Replace the first three convolutional layers:
Replace the first four convolutional layers:
Replace convolutional layers 2, 3, and 4:
Flip the low N exponent bits of the first convolutional layer of the pretrained parameters to obtain new parameters.
With the new parameters, freeze the modified convolutional layer and retrain on ImageNet, yielding the parameters after elastic recovery.
Run transfer learning on the elastically recovered parameters and observe how the exponent bits change.
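The bit-flipping step above can be sketched with a bitwise reinterpretation of the float32 weights. This is a minimal sketch, not the experiment's actual code: the function name is hypothetical, and it assumes "low N exponent bits" means the least significant bits of the 8-bit IEEE-754 exponent field (word bits 23..30).

```python
import torch

def flip_low_exponent_bits(w: torch.Tensor, n_bits: int) -> torch.Tensor:
    """Flip the low n_bits of the IEEE-754 float32 exponent field (bits 23..30)."""
    bits = w.view(torch.int32)               # reinterpret the same bytes as int32
    mask = ((1 << n_bits) - 1) << 23         # the exponent occupies bits 23..30
    return (bits ^ mask).view(torch.float32)

# Flipping the lowest exponent bit of 1.0 halves it (exponent 127 -> 126),
# and the XOR is an involution: flipping twice restores the original tensor.
x = torch.tensor([1.0, 2.0])
y = flip_low_exponent_bits(x, 1)             # tensor([0.5000, 4.0000])
assert torch.equal(flip_low_exponent_bits(y, 1), x)
```

In the procedure described above, this would be applied to the pretrained `conv1.weight`, after which that layer is frozen (e.g. `requires_grad_(False)`) for the ImageNet retraining step.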
- First layer `conv1.weight`, 2016/9408 parameters embedded
Accuracy after embedding the low 7 bits: 43.18%
Accuracy after embedding the low 6 bits: 37.64%
Accuracy after embedding the low 5 bits: 45.94%
Accuracy after embedding the low 4 bits: 61.32% (random run 1), 55.50% (random run 2)
Accuracy after embedding the low 3 bits: 3.68% (random run 1), 5.49% (random run 2)
Accuracy after embedding the low 2 bits: 1.10%
Accuracy after embedding the low 1 bit:
Retraining on the ImageNet1K dataset after flipping the 7th bit:
At epoch=1, accuracy already recovers to 78.62%.
At epoch=5, accuracy recovers to 78.22%; in fact, by epoch=5 overfitting has already set in.
Transfer-training results:
After elastic recovery on ImageNet1k for epoch=1 and training on CIFAR100 for epoch=10, the first convolutional layer's parameters change drastically.
After elastic recovery on ImageNet1k for epoch=5 and training on CIFAR100 for epoch=10: unknown; the experiment has not been run yet.
- Third layer `layer1.0.conv2.weight`
[64, 64, 3, 3]. Embedding strategy: for each in_feature/out_feature pair, pick one of the nine kernel parameters at random; embedding rate 1/9.
Accuracy after embedding the low 7 bits: 68.12%
Accuracy after embedding the low 6 bits: 78.92%
Accuracy after embedding the low 5 bits: 77.61%
Accuracy after embedding the low 4 bits: 77.42%
Accuracy after embedding the low 3 bits: 0.43%
Accuracy after embedding the low 2 bits: 54.19%
Accuracy after embedding the low 1 bit: 78.86%
Retraining on ImageNet1k with the low-4-bit XOR converges at epoch=1, reaching 78.80% accuracy.
Transfer training on CIFAR100, epoch=10, accuracy reaches 83.27%.
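The random 1/9 kernel-position strategy above can be sketched as follows. This is a minimal sketch under stated assumptions: the function name and seed handling are mine, and it assumes the embedded bits are the lowest bits of each selected parameter's 32-bit word.

```python
import torch

def embed_random_kernel_position(w: torch.Tensor, n_bits: int, seed: int = 0) -> torch.Tensor:
    """For each (out_channel, in_channel) pair of an [O, I, 3, 3] conv weight,
    pick one of the nine kernel positions at random and XOR the low n_bits of
    its float32 bit pattern (embedding rate 1/9)."""
    O, I, kh, kw = w.shape
    g = torch.Generator().manual_seed(seed)
    pos = torch.randint(0, kh * kw, (O, I), generator=g)        # one slot per pair
    bits = w.reshape(O, I, kh * kw).view(torch.int32).clone()
    idx = pos.unsqueeze(-1)                                     # [O, I, 1]
    # XOR the low n_bits of the selected parameter in each (O, I) group
    bits.scatter_(2, idx, bits.gather(2, idx) ^ ((1 << n_bits) - 1))
    return bits.view(torch.float32).reshape(O, I, kh, kw)
```

Fixing the seed makes the selection reproducible, so the extractor can locate the same 1-in-9 positions when recovering the payload.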
Files already downloaded and verified
Files already downloaded and verified
Epoch 1/10
----------
train Loss: 2.0605 Acc: 0.5221
val Loss: 0.8561 Acc: 0.7496
Epoch 2/10
----------
train Loss: 0.7333 Acc: 0.7826
val Loss: 0.6614 Acc: 0.8008
Epoch 3/10
----------
train Loss: 0.5060 Acc: 0.8455
val Loss: 0.6052 Acc: 0.8135
Epoch 4/10
----------
train Loss: 0.3698 Acc: 0.8868
val Loss: 0.5734 Acc: 0.8288
Epoch 5/10
----------
train Loss: 0.2713 Acc: 0.9183
val Loss: 0.5752 Acc: 0.8281
Epoch 6/10
----------
train Loss: 0.2032 Acc: 0.9410
val Loss: 0.5801 Acc: 0.8323
Epoch 7/10
----------
train Loss: 0.1518 Acc: 0.9582
val Loss: 0.5896 Acc: 0.8349
Epoch 8/10
----------
train Loss: 0.1128 Acc: 0.9716
val Loss: 0.6126 Acc: 0.8305
Epoch 9/10
----------
train Loss: 0.0871 Acc: 0.9798
val Loss: 0.6230 Acc: 0.8358
Epoch 10/10
----------
train Loss: 0.0680 Acc: 0.9855
val Loss: 0.6411 Acc: 0.8327
Test Accuracy: 83.27%
Test Loss: 0.6412
Test total: 10000
Test correct: 8327
The distribution of Hamming distances between the retrained parameters and the originally embedded parameters:
[0.359375 0.30883789 0.2265625 0.06152344 0.04370117]
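Assuming the five numbers above are the fractions of embedded positions at Hamming distance 0..4 over the low 4 bits, the distribution can be computed with this sketch (the function name is hypothetical):

```python
import torch

def hamming_hist(original: torch.Tensor, retrained: torch.Tensor, n_bits: int) -> torch.Tensor:
    """Fraction of parameters whose low n_bits differ by Hamming distance k,
    for k = 0..n_bits."""
    mask = (1 << n_bits) - 1
    diff = (original.view(torch.int32) ^ retrained.view(torch.int32)) & mask
    dist = torch.zeros_like(diff)
    for k in range(n_bits):                     # popcount of the masked XOR
        dist += (diff >> k) & 1
    hist = torch.bincount(dist.flatten().long(), minlength=n_bits + 1).float()
    return hist / hist.sum()
```

The same function with `n_bits=3` or `n_bits=7` reproduces the 4- and 8-entry distributions recorded for the other embedding widths.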
[64, 64, 3, 3]. Embedding strategy: for each in_feature/out_feature pair, always choose kernel position [0, 0] of the nine; embedding rate 1/9. method1
Accuracy with low-4-bit embedding: 78.96% (much higher than with random selection)
Elastic training on ImageNet1k with the low-4-bit XOR converges at epoch=1, reaching 78.80% accuracy.
Training on CIFAR100:
epoch=5: accuracy 83.13%; Hamming-distance distribution against the originally embedded parameters:
[0.42602539 0.27978516 0.1940918 0.05224609 0.04785156]
epoch=10: accuracy 83.54%; Hamming-distance distribution against the originally embedded parameters:
[0.41943359 0.2824707 0.18554688 0.06054688 0.05200195]
Embedding strategy: embed in the center element of the kernel. method2
Accuracy with low-4-bit embedding: 68.27%
Elastic training on ImageNet1K converges at epoch=1, reaching 78.81% accuracy.
Training on CIFAR100:
epoch=5: accuracy 83.42%
[2.94677734e-01 3.08593750e-01 2.81005859e-01 8.42285156e-02 3.12500000e-02]
Embedding strategy: embed in the bottom-left corner of the kernel. method3
Accuracy with low-4-bit embedding: 78.91%
Elastic training on ImageNet1K converges at epoch=1, reaching 78.91% (almost unchanged).
Transfer training on CIFAR100, method3:
epoch=5: accuracy 83.08%
[0.44262695 0.28833008 0.18432617 0.03955078 0.04516602]
epoch=10: accuracy 83.35%
[0.409179688 0.290039062 0.192626953 0.0595703125 0.0483398438]
Transfer training on OxfordIIITPet, method3; performance is relatively poor.
epoch=2: accuracy 68.03%
[0.19116211 0.27392578 0.19628906 0.23291016 0.10571289]
epoch=5: accuracy 72.61%
[0.15258789 0.25292969 0.19311523 0.27416992 0.12719727]
epoch=10: accuracy 79.15%
[0.14306641 0.20776367 0.21801758 0.27832031 0.15283203]
epoch=20: accuracy 76.40%
[0.12524414 0.18603516 0.23657227 0.28491211 0.16723633]
Using the same data-processing functions as the other training code:
epoch=10: accuracy 73.26%
[0.14257812 0.2121582 0.22705078 0.26757812 0.15063477]
Transfer training on FGVCAircraft, method3:
epoch=20: accuracy 64.66%
[0.569824219 0.210449219 0.162597656 0.0341796875 0.0227050781]
Transfer training on PCAM, method3:
epoch=5: accuracy 87.14%
[0.50561523 0.25146484 0.17700195 0.03198242 0.03393555]
method3
Embedding strategy: flip the low 7 bits at the bottom-left corner; accuracy after embedding: 79.01%
Elastic training on ImageNet, epoch=1: accuracy 78.22%
After CIFAR100 retraining, epoch=5: accuracy 82.77%
[0, 0, 0, 0.16259766 0.35107422 0.31933594 0.13549805 0.03149414, 0] extremely poor; essentially unrecoverable
method3
Embedding strategy: flip the low 3 bits at the bottom-left corner; accuracy after embedding: 36.60%
Elastic training on ImageNet, epoch=1: accuracy 78.56%
After CIFAR100 transfer training, epoch=5: accuracy 83.03%
[0.83251953 0.07519531 0.07128906 0.02099609]
Transfer training on OxfordIIITPet, epoch=10: accuracy 74.57%
[0.81030273 0.09326172 0.05664062 0.03979492]
After transfer training on GTSRB, epoch=10: accuracy 96.42%
[0.85351562 0.07763672 0.05444336 0.0144043]
After transfer training on FGVCAircraft, epoch=20: accuracy 64.90%
[0.85107422 0.0793457 0.05664062 0.01293945]
After transfer training on PCAM, epoch=5: accuracy 86.37%
[0.50585938 0.25610352 0.18481445 0.05322266]
method3
Embedding strategy: flip the low 2 bits at the bottom-left corner; accuracy after embedding: 69.23%
Elastic training on ImageNet, epoch=1: accuracy:
After CIFAR100 transfer training, epoch=5: accuracy:
[0.83251953 0.07519531 0.07128906 0.02099609]
After transfer training on PCAM, epoch=5: accuracy:
This experiment observes the distribution of the exponent bits of the original pretrained parameters.
In an executable file, the numbers of 0 bits and 1 bits are roughly 1:1.
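Assuming each array below lists, MSB-first, the fraction of parameters whose corresponding float32 exponent bit is 1, the statistic can be computed with this sketch (the function name is hypothetical):

```python
import torch

def exponent_bit_ones(w: torch.Tensor) -> torch.Tensor:
    """Fraction of parameters with a 1 in each of the eight float32 exponent
    bits, listed MSB-first."""
    exp = (w.view(torch.int32) >> 23) & 0xFF          # 8-bit exponent field
    return torch.stack([((exp >> k) & 1).float().mean() for k in range(7, -1, -1)])
```

Small weights (magnitude well below 1) have exponents just under 127 (0b01111111), which is why the recorded arrays begin with a constant 0, 1, 1 and only the lower bits approach the 1:1 ratio seen in executables.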
- First convolutional layer `conv1.weight`
  [0 1 1 0.99978741 0.90508078 0.43633078 0.49213435 0.50127551]
- Second convolutional layer `layer1.0.conv1.weight`
  [0 1 1 0.99951172 0.89794922 0.29638672 0.54638672 0.49243164]
- Third convolutional layer `layer1.0.conv2.weight`
  [0 1 1 0.99940321 0.83940972 0.21511502 0.52267795 0.49598524]
- Fourth convolutional layer `layer1.0.conv3.weight`
  [0 1 0.99255371 0.97979736 0.87640381 0.24499512 0.56115723 0.49353027]
- Sixth convolutional layer `layer1.1.conv2.weight`
  [0 1 0.984375 0.99943034 0.85845269 0.20817057 0.5464681 0.49620226]
- First convolutional layer `conv1.weight`
  [1.06292517e-04 9.99893707e-01 9.99893707e-01 9.99255952e-01 9.18792517e-01 4.44515306e-01 4.85650510e-01 4.98511905e-01]
- Second convolutional layer `layer1.0.conv1.weight`
  []
Embed into the low three exponent bits of the `layer1.0.conv2` weights (57 B of malware, 456 bits in total); accuracy: 2.18%
Elastic training on ImageNet, epoch=1, accuracy recovers to 78.51%.
Transfer training on CIFAR100, epoch=5, accuracy reaches 83.11%.
[]; recovery failed with 6 bit errors at positions 7, 9, 92, 245, 300, 303
Transfer training on CIFAR100, epoch=7, accuracy reaches 83.56%.
[0.83740234 0.07104492 0.04321289 0.04833984]; recovery failed with 5 bit errors at positions 14, 67, 245, 300, 303
After transfer training on GTSRB, epoch=10, accuracy reaches 96.83%.
[0.88623047 0.06054688 0.03076172 0.02246094]; recovery failed with 1 bit error at position 304
Adopt a new coding scheme, encoding1.
Coding: bit 0 → 011; bit 1 → 100 (encoding1)
Each payload bit is encoded into a three-bit string A3, A2, A1:

| Position | high bit A3 | middle bit A2 | low bit A1 |
|---|---|---|---|
| Encoded value | 0 | 1 | 1 |
| True value | 0 | 0 | 0 |

That is, the decoding process (A3, !A2, !A1) yields the true value.
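A minimal sketch of encoding1 with the stated decode rule: A3, !A2 and !A1 each vote for the true bit, so a majority corrects any single flipped carrier bit. The function names are hypothetical.

```python
def encode1(bits):
    """encoding1: bit 0 -> 011, bit 1 -> 100."""
    out = []
    for b in bits:
        out.extend([1, 0, 0] if b else [0, 1, 1])
    return out

def decode1(coded):
    """Decode each (A3, A2, A1) triple by majority vote over (A3, !A2, !A1)."""
    out = []
    for i in range(0, len(coded), 3):
        a3, a2, a1 = coded[i:i + 3]
        out.append(1 if a3 + (1 - a2) + (1 - a1) >= 2 else 0)
    return out
```

Both codewords have their vote triple unanimous (000 for payload 0, 111 for payload 1), which is why one bit error per triple is always correctable.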
Embed into the low three exponent bits of the `layer1.0.conv2` weights (57 B of malware, 456 bits in total); accuracy: 79.27%
Elastic training on ImageNet, epoch=1, accuracy recovers to 78.86%.
Transfer training on CIFAR100, epoch=5, accuracy reaches 83.01%.
Recovery errors:
pos: 8 initBit: 0 extractedBit: 1 pos: 178 initBit: 0 extractedBit: 1 pos: 210 initBit: 0 extractedBit: 1
Transfer training on GTSRB, epoch=10, accuracy: 96.85%
Recovery errors:
pos: 403 initBit: 0 extractedBit: 1
encoding1 plus 11-parameter redundancy
Accuracy after embedding 46 B: 79.26%
Elastic training on ImageNet, epoch=1, accuracy recovers to 78.72%.
Transfer training on CIFAR100, epoch=5, accuracy recovers to 83.13%; extraction succeeded.
Transfer training on OxfordIIITPet, epoch=20, accuracy reaches 71.90%; extraction failed with a 1-bit error.
Transfer training on OxfordIIITPet (new), epoch=20, accuracy reaches 72.91%; extraction failed with 5 bit errors:
pos: 64 initBit: 1 extractedBit: 0 pos: 66 initBit: 1 extractedBit: 0 pos: 67 initBit: 1 extractedBit: 0 pos: 128 initBit: 1 extractedBit: 0 pos: 329 initBit: 1 extractedBit: 0
Transfer training on GTSRB, epoch=10, accuracy recovers to 96.38%; extraction succeeded.
Transfer training on FGVCAircraft, epoch=20, accuracy: 64.72%; extraction succeeded.
Transfer training on PCAM, epoch=5, accuracy recovers to 84.91%; extraction succeeded.
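The 11-parameter redundancy can be sketched on its own as straightforward repetition with a majority vote per group (function names hypothetical; in the experiment this is layered on top of encoding1):

```python
def encode_redundant(bits, copies=11):
    """Store each payload bit in `copies` carrier parameters."""
    return [b for b in bits for _ in range(copies)]

def decode_redundant(carrier, copies=11):
    """Majority vote over each group of `copies` extracted bits."""
    return [1 if sum(carrier[i:i + copies]) * 2 > copies else 0
            for i in range(0, len(carrier), copies)]
```

With 11 copies, up to 5 flipped carrier bits per payload bit are tolerated, which matches the observation that most transfer-training runs above still allow successful extraction.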