feat(quantization): update quantization codebase to 1.0 (#79)
* feat(quantization): update quantization codebase to 1.0

* refactor(quantization): refactor ci by using shell script and change quantization code for format check
lixiangyin666 authored and Jianfeng Wang committed Oct 21, 2020
1 parent 25014fa commit 1f6ad9c
Showing 17 changed files with 458 additions and 450 deletions.
21 changes: 1 addition & 20 deletions .github/workflows/ci.yml
@@ -35,23 +35,4 @@ jobs:
# Runs a set of commands using the runner's shell
- name: Format check
run: |
export PYTHONPATH=$PWD:$PYTHONPATH
CHECK_VISION=official/vision/
CHECK_NLP=official/nlp/
pip install pylint==2.5.2
pylint $CHECK_VISION $CHECK_NLP --rcfile=.pylintrc || pylint_ret=$?
echo test, and deploy your project.
if [ "$pylint_ret" ]; then
exit $pylint_ret
fi
echo "All lint steps passed!"
pip3 install flake8==3.7.9
flake8 official
echo "All flake check passed!"
pip3 install isort==4.3.21
isort --check-only -rc official
echo "All isort check passed!"
run: ./run_format_check.sh
1 change: 0 additions & 1 deletion .gitignore
@@ -1,7 +1,6 @@
*log*/
*.jpg
*.png
*.txt

# compilation and distribution
__pycache__
27 changes: 14 additions & 13 deletions official/quantization/README.md
@@ -5,9 +5,9 @@

| Model | top1 acc (float32) | FPS* (float32) | top1 acc (int8) | FPS* (int8) |
| --- | --- | --- | --- | --- |
| ResNet18 | 69.824 | 10.5 | 69.754 | 16.3 |
| ShufflenetV1 (1.5x) | 71.954 | 17.3 | 70.656 | 25.3 |
| MobilenetV2 | 72.820 | 13.1 | 71.378 | 17.4 |
| ResNet18 | 69.796 | 10.5 | 69.814 | 16.3 |
| ShufflenetV1 (1.5x) | 71.948 | 17.3 | 70.806 | 25.3 |
| MobilenetV2 | 72.808 | 13.1 | 71.228 | 17.4 |

**: FPS is measured on Intel(R) Xeon(R) Gold 6130 CPU @ 2.10GHz, single 224x224 image*

@@ -18,12 +18,12 @@

#### (Optional) Download Pretrained Models
```
wget https://data.megengine.org.cn/models/weights/mobilenet_v2_normal_72820.pkl
wget https://data.megengine.org.cn/models/weights/mobilenet_v2_qat_71378.pkl
wget https://data.megengine.org.cn/models/weights/resnet18_normal_69824.pkl
wget https://data.megengine.org.cn/models/weights/resnet18_qat_69754.pkl
wget https://data.megengine.org.cn/models/weights/shufflenet_v1_x1_5_g3_normal_71954.pkl
wget https://data.megengine.org.cn/models/weights/shufflenet_v1_x1_5_g3_qat_70656.pkl
wget https://data.megengine.org.cn/models/weights/mobilenet_v2_normal_72808.pkl
wget https://data.megengine.org.cn/models/weights/mobilenet_v2_qat_71228.pkl
wget https://data.megengine.org.cn/models/weights/resnet18_normal_69796.pkl
wget https://data.megengine.org.cn/models/weights/resnet18_qat_69814.pkl
wget https://data.megengine.org.cn/models/weights/shufflenet_v1_x1_5_g3_normal_71948.pkl
wget https://data.megengine.org.cn/models/weights/shufflenet_v1_x1_5_g3_qat_70806.pkl
```
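
Each file is a MegEngine checkpoint. A minimal loading sketch, assuming the weights are either a plain state dict or the `{"step": ..., "state_dict": ...}` wrapper this repository's training scripts save (the file name is just one of the downloads above):

```python
import megengine as mge

import models  # the model zoo shipped in official/quantization

model = models.resnet18()
checkpoint = mge.load("resnet18_normal_69796.pkl")
state_dict = checkpoint.get("state_dict", checkpoint)  # unwrap if necessary
model.load_state_dict(state_dict, strict=False)
```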

## Quantization Aware Training (QAT)
@@ -44,7 +44,7 @@ for _ in range(...):

```python
import megengine.quantization as Q
import megengine.jit as jit
from megengine.jit import trace

model = ...

@@ -53,10 +53,11 @@ Q.quantize_qat(model, qconfig=Q.ema_fakequant_qconfig)
# real quant
Q.quantize(model)

@jit.trace(symbolic=True):
@trace(symbolic=True, capture_as_const=True)
def inference_func(x):
    return model(x)

inference_func(x)
inference_func.dump(...)
```
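
Because the trace is built with `capture_as_const=True`, the quantized weights are folded into the graph as constants, so the dumped file is self-contained. A hypothetical end-to-end call, with the input shape, file name, and `arg_names` chosen purely for illustration:

```python
import numpy as np
import megengine as mge

# run the traced function once on real-shaped data so the graph is materialized
x = mge.tensor(np.random.rand(1, 3, 224, 224).astype("float32"))
inference_func(x)
inference_func.dump("quantized_model.mge", arg_names=["data"])
```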

@@ -76,7 +77,7 @@ python3 finetune.py -a resnet18 -d /path/to/imagenet --checkpoint /path/to/resne

## Step 2. Calibration
```
python3 finetune.py -a resnet18 -d /path/to/imagenet --checkpoint /path/to/resnet18.normal/checkpoint.pkl --mode calibration
python3 calibration.py -a resnet18 -d /path/to/imagenet --checkpoint /path/to/resnet18.normal/checkpoint.pkl
```
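
Calibration is post-training quantization, so no gradient steps are taken: the script wraps the model with `Q.calibration_qconfig`, switches the observers on, and runs plain forward passes so activation ranges are recorded before conversion. Stripped to its core (with `calibration_batches` standing in for the DataLoader that `calibration.py` actually builds), the flow is:

```python
import megengine.quantization as Q
from megengine.quantization.quantize import enable_observer, quantize, quantize_qat

model = ...  # a float32 model with pretrained weights loaded

model = quantize_qat(model, qconfig=Q.calibration_qconfig)
model.eval()
enable_observer(model)            # observers start recording activation statistics
for image, _ in calibration_batches:
    model(image)                  # forward only; no loss or backward needed
model = quantize(model)           # convert to a real int8 model using the observed scales
```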

## Step 3. Test QAT model on ImageNet Testset
@@ -85,7 +86,7 @@ python3 finetune.py -a resnet18 -d /path/to/imagenet --checkpoint /path/to/resne
python3 test.py -a resnet18 -d /path/to/imagenet --checkpoint /path/to/resnet18.qat/checkpoint.pkl --mode qat
```

or testing in quantized mode, which uses only cpu for inference and takes longer time
or testing in quantized mode

```
python3 test.py -a resnet18 -d /path/to/imagenet --checkpoint /path/to/resnet18.qat/checkpoint.pkl --mode quantized -n 1
8 changes: 8 additions & 0 deletions official/quantization/__init__.py
@@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
140 changes: 56 additions & 84 deletions official/quantization/calibration.py
@@ -6,112 +6,91 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""Finetune a pretrained fp32 with int8 quantization aware training(QAT)"""
"""Finetune a pretrained fp32 with int8 post train quantization(calibration)"""
import argparse
import collections
import multiprocessing as mp
import numbers
import os
import bisect
import time

# pylint: disable=import-error
import models

import megengine as mge
import megengine.data as data
import megengine.data.transform as T
import megengine.distributed as dist
import megengine.functional as F
import megengine.jit as jit
import megengine.optimizer as optim
import megengine.quantization as Q

import config
import models
from megengine.quantization.quantize import enable_observer, quantize, quantize_qat

logger = mge.get_logger(__name__)
# from imagenet_nori_dataset import ImageNetNoriDataset
from megengine.quantization.quantize import enable_observer, quantize, quantize_qat


def main():
parser = argparse.ArgumentParser()
parser.add_argument("-a", "--arch", default="resnet18", type=str)
parser.add_argument("-d", "--data", default=None, type=str)
parser.add_argument("-s", "--save", default="/data/models", type=str)
parser.add_argument("-c", "--checkpoint", default=None, type=str,
help="pretrained model to finetune")

parser.add_argument("-m", "--mode", default="qat", type=str,
choices=["normal", "qat", "quantized", "calibration"],
help="Quantization Mode\n"
"normal: no quantization, using float32\n"
"qat: quantization aware training, simulate int8\n"
"calibration: calibration\n"
"quantized: convert mode to int8 quantized, inference only")
parser.add_argument(
"-c",
"--checkpoint",
default=None,
type=str,
help="pretrained model to finetune",
)

parser.add_argument("-n", "--ngpus", default=None, type=int)
parser.add_argument("-w", "--workers", default=4, type=int)
parser.add_argument("--report-freq", default=50, type=int)
args = parser.parse_args()

world_size = mge.get_device_count("gpu") if args.ngpus is None else args.ngpus

if world_size > 1:
# start distributed training, dispatch sub-processes
mp.set_start_method("spawn")
processes = []
for rank in range(world_size):
p = mp.Process(target=worker, args=(rank, world_size, args))
p.start()
processes.append(p)

for p in processes:
p.join()
else:
worker(0, 1, args)
world_size = (
dist.helper.get_device_count_by_fork("gpu")
if args.ngpus is None
else args.ngpus
)
world_size = 1 if world_size == 0 else world_size
if world_size != 1:
logger.warning(
"Calibration only supports single GPU now, %d provided", world_size
)
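# with more than one device, dist.launcher spawns one worker process per device;
# a single-device run simply calls worker in the current process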
proc_func = dist.launcher(worker) if world_size > 1 else worker
proc_func(world_size, args)


def get_parameters(model, cfg):
if isinstance(cfg.WEIGHT_DECAY, numbers.Number):
return {"params": model.parameters(requires_grad=True),
"weight_decay": cfg.WEIGHT_DECAY}
return {
"params": model.parameters(requires_grad=True),
"weight_decay": cfg.WEIGHT_DECAY,
}

groups = collections.defaultdict(list) # weight_decay -> List[param]
for pname, p in model.named_parameters(requires_grad=True):
wd = cfg.WEIGHT_DECAY(pname, p)
groups[wd].append(p)
groups = [
{"params": params, "weight_decay": wd}
for wd, params in groups.items()
{"params": params, "weight_decay": wd} for wd, params in groups.items()
] # List[{param, weight_decay}]
return groups
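
# Usage sketch: the groups built above plug directly into an optimizer, so each
# parameter keeps its own weight decay (the momentum value is illustrative only):
#   opt = optim.SGD(get_parameters(model, cfg), lr=cfg.LEARNING_RATE, momentum=0.9)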


def worker(rank, world_size, args):
def worker(world_size, args):
# pylint: disable=too-many-statements

rank = dist.get_rank()
if world_size > 1:
# Initialize distributed process group
logger.info("init distributed process group {} / {}".format(rank, world_size))
dist.init_process_group(
master_ip="localhost",
master_port=23456,
world_size=world_size,
rank=rank,
dev=rank,
)

save_dir = os.path.join(args.save, args.arch + "." + args.mode)
save_dir = os.path.join(args.save, args.arch + "." + "calibration")
if not os.path.exists(save_dir):
os.makedirs(save_dir, exist_ok=True)
mge.set_log_file(os.path.join(save_dir, "log.txt"))

model = models.__dict__[args.arch]()
cfg = config.get_finetune_config(args.arch)

cfg.LEARNING_RATE *= world_size # scale learning rate in distributed training
total_batch_size = cfg.BATCH_SIZE * world_size
steps_per_epoch = 1280000 // total_batch_size
total_steps = steps_per_epoch * cfg.EPOCHS

# load calibration model
assert args.checkpoint
logger.info("Load pretrained weights from %s", args.checkpoint)
@@ -121,70 +100,64 @@ def worker(rank, world_size, args):

# Build valid datasets
valid_dataset = data.dataset.ImageNet(args.data, train=False)
# valid_dataset = ImageNetNoriDataset(args.data)
valid_sampler = data.SequentialSampler(
valid_dataset, batch_size=100, drop_last=False
)
valid_queue = data.DataLoader(
valid_dataset,
sampler=valid_sampler,
transform=T.Compose(
[
T.Resize(256),
T.CenterCrop(224),
T.Normalize(mean=128),
T.ToMode("CHW"),
]
[T.Resize(256), T.CenterCrop(224), T.Normalize(mean=128), T.ToMode("CHW")]
),
num_workers=args.workers,
)

# calibration
model.fc.disable_quantize()
model = quantize_qat(model, qconfig=Q.calibration_qconfig)

# calculate scale
@jit.trace(symbolic=True)
def calculate_scale(image, label):
model.eval()
enable_observer(model)
logits = model(image)
loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
acc1, acc5 = F.accuracy(logits, label, (1, 5))
loss = F.loss.cross_entropy(logits, label, label_smooth=0.1)
acc1, acc5 = F.topk_accuracy(logits, label, (1, 5))
if dist.is_distributed(): # all_reduce_mean
loss = dist.all_reduce_sum(loss, "valid_loss") / dist.get_world_size()
acc1 = dist.all_reduce_sum(acc1, "valid_acc1") / dist.get_world_size()
acc5 = dist.all_reduce_sum(acc5, "valid_acc5") / dist.get_world_size()
loss = dist.functional.all_reduce_sum(loss) / dist.get_world_size()
acc1 = dist.functional.all_reduce_sum(acc1) / dist.get_world_size()
acc5 = dist.functional.all_reduce_sum(acc5) / dist.get_world_size()
return loss, acc1, acc5

# model.fc.disable_quantize()

infer(calculate_scale, valid_queue, args)

# quantized
model = quantize(model)

# eval quantized model
@jit.trace(symbolic=True)
def eval_func(image, label):
model.eval()
logits = model(image)
loss = F.cross_entropy_with_softmax(logits, label, label_smooth=0.1)
acc1, acc5 = F.accuracy(logits, label, (1, 5))
loss = F.loss.cross_entropy(logits, label, label_smooth=0.1)
acc1, acc5 = F.topk_accuracy(logits, label, (1, 5))
if dist.is_distributed(): # all_reduce_mean
loss = dist.all_reduce_sum(loss, "valid_loss") / dist.get_world_size()
acc1 = dist.all_reduce_sum(acc1, "valid_acc1") / dist.get_world_size()
acc5 = dist.all_reduce_sum(acc5, "valid_acc5") / dist.get_world_size()
loss = dist.functional.all_reduce_sum(loss) / dist.get_world_size()
acc1 = dist.functional.all_reduce_sum(acc1) / dist.get_world_size()
acc5 = dist.functional.all_reduce_sum(acc5) / dist.get_world_size()
return loss, acc1, acc5

_, valid_acc, valid_acc5 = infer(eval_func, valid_queue, args)
logger.info("TEST %f, %f", valid_acc, valid_acc5)

# save quantized model
mge.save(
{"step": -1, "state_dict": model.state_dict()},
os.path.join(save_dir, "checkpoint-calibration.pkl")
os.path.join(save_dir, "checkpoint-calibration.pkl"),
)
logger.info(
"save in {}".format(os.path.join(save_dir, "checkpoint-calibration.pkl"))
)
logger.info("save in {}".format(os.path.join(save_dir, "checkpoint-calibration.pkl")))


def infer(model, data_queue, args):
objs = AverageMeter("Loss")
@@ -195,8 +168,8 @@ def infer(model, data_queue, args):
t = time.time()
for step, (image, label) in enumerate(data_queue):
n = image.shape[0]
image = image.astype("float32") # convert np.uint8 to float32
label = label.astype("int32")
image = mge.tensor(image, dtype="float32")
label = mge.tensor(label, dtype="int32")

loss, acc1, acc5 = model(image, label)

@@ -207,9 +180,8 @@
t = time.time()

if step % args.report_freq == 0 and dist.get_rank() == 0:
logger.info("Step %d, %s %s %s %s",
step, objs, top1, top5, total_time)

logger.info("Step %d, %s %s %s %s", step, objs, top1, top5, total_time)

# break
if step == args.report_freq:
break