Examples (a2q): updating and extending ESPCN demo #706

Merged · 17 commits · Sep 14, 2023
19 changes: 13 additions & 6 deletions src/brevitas_examples/super_resolution/README.md
@@ -1,22 +1,29 @@
# Integer-Quantized Super Resolution Experiments with Brevitas

This directory contains training scripts to demonstrate how to train integer-quantized super resolution models using [Brevitas](https://github.com/Xilinx/brevitas).
This directory contains scripts demonstrating how to train integer-quantized super resolution models using [Brevitas](https://github.com/Xilinx/brevitas).
Code is also provided to demonstrate accumulator-aware quantization (A2Q) as proposed in our ICCV 2023 paper "[A2Q: Accumulator-Aware Quantization with Guaranteed Overflow Avoidance](https://arxiv.org/abs/2308.13504)".
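As a rough illustration of the constraint (a sketch, not the Brevitas implementation), A2Q bounds the per-output-channel L1 norm of the integer weights so that a signed accumulator of a given bit width can never overflow, assuming here unsigned input activations:

```python
import torch


def a2q_l1_limit(acc_bit_width: int, input_bit_width: int) -> float:
    # Worst case |dot(q, x)| <= ||q||_1 * max|x|; with unsigned
    # input_bit_width-bit activations, max|x| = 2**input_bit_width - 1.
    # A signed acc_bit_width-bit accumulator holds magnitudes up to
    # 2**(acc_bit_width - 1) - 1, which caps the allowed L1 norm.
    return (2 ** (acc_bit_width - 1) - 1) / (2 ** input_bit_width - 1)


def satisfies_a2q(int_weights: torch.Tensor, acc_bit_width: int, input_bit_width: int) -> bool:
    # int_weights: integer-valued weights, flattened to one row per output channel
    l1_per_channel = int_weights.abs().sum(dim=1)
    return bool((l1_per_channel <= a2q_l1_limit(acc_bit_width, input_bit_width)).all())
```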

## Experiments

All models are trained on the BSD300 dataset to upsample images by 2x.
Target images are center cropped to 512x512.
Target images are cropped to 512x512.
During training, random cropping is applied, along with random vertical and horizontal flips.
During inference, center cropping is applied.
Inputs are downscaled by 2x and then used to train the model directly in RGB space.
Note that this differs from many academic works, which train only on the Y-channel in YCbCr format.
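For orientation, the preprocessing described above corresponds roughly to the following torchvision transforms (a sketch, not this repo's dataset code; the bicubic downscale is an assumption):

```python
from torchvision import transforms

crop_size, upscale_factor = 512, 2

# Training targets: random crop plus random flips on the high-resolution image
train_hr_transform = transforms.Compose([
    transforms.RandomCrop(crop_size),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor()])

# Inference targets: deterministic center crop
test_hr_transform = transforms.Compose([
    transforms.CenterCrop(crop_size), transforms.ToTensor()])

# Model inputs: 2x downscale of the cropped RGB target
lr_from_hr = transforms.Resize(
    crop_size // upscale_factor,
    interpolation=transforms.InterpolationMode.BICUBIC)
```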

| Model Name | Upscale Factor | Weight quantization | Activation quantization | Peak Signal-to-Noise Ratio |
|-----------------------------|----------------|---------------------|-------------------------|----------------------------|
| [float_espcn_x2](https://github.com/Xilinx/brevitas/releases/download/super_res_r0/float_espcn_x2-2f3821e3.pth) | x2 | float32 | float32 | 30.37 |
| [quant_espcn_x2_w8a8_base](https://github.com/Xilinx/brevitas/releases/download/super_res_r0/quant_espcn_x2_w8a8_base-7d54e29c.pth) | x2 | int8 | (u)int8 | 30.16 |
| [quant_espcn_x2_w8a8_a2q_32b](https://github.com/Xilinx/brevitas/releases/download/super_res_r0/quant_espcn_x2_w8a8_a2q_32b-0b1f361d.pth) | x2 | int8 | (u)int8 | 30.80 |
| [quant_espcn_x2_w8a8_a2q_16b](https://github.com/Xilinx/brevitas/releases/download/super_res_r0/quant_espcn_x2_w8a8_a2q_16b-3c4acd35.pth) | x2 | int8 | (u)int8 | 29.38 |
| bicubic_interp | x2 | N/A | N/A | 28.71 |
| [float_espcn_x2]() | x2 | float32 | float32 | 31.03 |
||
| [quant_espcn_x2_w8a8_base]() | x2 | int8 | (u)int8 | 30.96 |
| [quant_espcn_x2_w8a8_a2q_32b]() | x2 | int8 | (u)int8 | 30.79 |
| [quant_espcn_x2_w8a8_a2q_16b]() | x2 | int8 | (u)int8 | 30.56 |
||
| [quant_espcn_x2_w4a4_base]() | x2 | int4 | (u)int4 | 30.30 |
| [quant_espcn_x2_w4a4_a2q_32b]() | x2 | int4 | (u)int4 | 30.27 |
| [quant_espcn_x2_w4a4_a2q_13b]() | x2 | int4 | (u)int4 | 30.24 |


## Train
20 changes: 15 additions & 5 deletions src/brevitas_examples/super_resolution/eval_model.py
@@ -31,13 +31,22 @@

parser = argparse.ArgumentParser(description='PyTorch BSD300 Validation')
parser.add_argument('--data_root', help='Path to folder containing BSD300 val folder')
parser.add_argument('--model_path', default=None, help='Path to PyTorch checkpoint')
parser.add_argument('--model_path', default=None, help='Path to PyTorch checkpoint. Default = None')
parser.add_argument(
'--save_path', type=str, default='outputs/', help='Save path for exported model')
'--save_path',
type=str,
default='outputs/',
help='Save path for exported model. Default = outputs/')
parser.add_argument(
'--model', type=str, default='quant_espcn_x2_w8a8_base', help='Name of the model configuration')
parser.add_argument('--workers', type=int, default=0, help='Number of data loading workers')
parser.add_argument('--batch_size', type=int, default=16, help='Minibatch size')
'--model',
type=str,
default='quant_espcn_x2_w8a8_base',
help='Name of the model configuration. Default = quant_espcn_x2_w8a8_base')
parser.add_argument(
'--workers', type=int, default=0, help='Number of data loading workers. Default = 0')
parser.add_argument('--batch_size', type=int, default=16, help='Minibatch size. Default = 16')
parser.add_argument(
'--crop_size', type=int, default=512, help='The size to crop the image. Default = 512')
parser.add_argument('--use_pretrained', action='store_true', default=False)
parser.add_argument('--eval_acc_bw', action='store_true', default=False)
parser.add_argument('--save_model_io', action='store_true', default=False)
@@ -60,6 +69,7 @@ def main():
num_workers=args.workers,
batch_size=args.batch_size,
upscale_factor=model.upscale_factor,
crop_size=args.crop_size,
download=True)

test_psnr = evaluate_avg_psnr(testloader, model)
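For reference, evaluate_avg_psnr reports peak signal-to-noise ratio in dB; a minimal per-batch version, assuming images normalized to [0, 1] (the repo's averaging may differ):

```python
import torch


def psnr(output: torch.Tensor, target: torch.Tensor, max_val: float = 1.0) -> torch.Tensor:
    # PSNR = 10 * log10(max_val**2 / MSE)
    mse = torch.mean((output - target) ** 2)
    return 10.0 * torch.log10(max_val ** 2 / mse)
```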
23 changes: 21 additions & 2 deletions src/brevitas_examples/super_resolution/models/__init__.py
@@ -5,6 +5,7 @@
from typing import Union

from torch import hub
import torch.nn as nn

from .espcn import *

@@ -26,7 +27,23 @@
upscale_factor=2,
weight_bit_width=8,
act_bit_width=8,
acc_bit_width=16)}
acc_bit_width=16),
'quant_espcn_x2_w4a4_base':
partial(quant_espcn_base, upscale_factor=2, weight_bit_width=4, act_bit_width=4),
'quant_espcn_x2_w4a4_a2q_32b':
partial(
quant_espcn_a2q,
upscale_factor=2,
weight_bit_width=4,
act_bit_width=4,
acc_bit_width=32),
'quant_espcn_x2_w4a4_a2q_13b':
partial(
quant_espcn_a2q,
upscale_factor=2,
weight_bit_width=4,
act_bit_width=4,
acc_bit_width=13)}

root_url = 'https://github.com/Xilinx/brevitas/releases/download/super_res-r0'

@@ -40,8 +57,10 @@
def get_model_by_name(name: str, pretrained: bool = False) -> Union[FloatESPCN, QuantESPCN]:
if name not in model_impl.keys():
raise NotImplementedError(f"{name} does not exist.")
model = model_impl[name]()
model: nn.Module = model_impl[name]()
if pretrained:
if name not in model_url:
raise NotImplementedError(f"Error: {name} does not have a pre-trained checkpoint.")
checkpoint = model_url[name]
state_dict = hub.load_state_dict_from_url(checkpoint, progress=True, map_location='cpu')
model.load_state_dict(state_dict, strict=True)
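For context, a usage sketch of the registry and loader above (names come from model_impl; the pretrained path needs a released checkpoint URL in model_url):

```python
# Build a 4-bit A2Q model from the registry
model = get_model_by_name('quant_espcn_x2_w4a4_a2q_13b')

# Load released weights, when a checkpoint exists for the name
model = get_model_by_name('quant_espcn_x2_w8a8_base', pretrained=True)
```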
4 changes: 2 additions & 2 deletions src/brevitas_examples/super_resolution/models/common.py
@@ -25,8 +25,8 @@ class CommonIntWeightPerChannelQuant(Int8WeightPerTensorFloat):


class CommonIntAccumulatorAwareWeightQuant(Int8AccumulatorAwareWeightQuant):
pre_scaling_min_val = 1e-8
scaling_min_val = 1e-8
pre_scaling_min_val = 1e-10
scaling_min_val = 1e-10


class CommonIntActQuant(Int8ActPerTensorFloat):
57 changes: 25 additions & 32 deletions src/brevitas_examples/super_resolution/models/espcn.py
@@ -12,7 +12,6 @@
from .common import CommonIntWeightPerChannelQuant
from .common import CommonUintActQuant
from .common import ConstUint8ActQuant
from .common import QuantNearestNeighborConvolution

__all__ = [
"float_espcn", "quant_espcn", "quant_espcn_a2q", "quant_espcn_base", "FloatESPCN", "QuantESPCN"]
@@ -29,9 +28,7 @@ def weight_init(layer):


class FloatESPCN(nn.Module):
"""Floating-point version of FINN-Friendly Quantized Efficient Sub-Pixel Convolution
Network (ESPCN) as used in Colbert et al. (2023) - `Quantized Neural Networks for
Low-Precision Accumulation with Guaranteed Overflow Avoidance`."""
"""Floating-point version of Efficient Sub-Pixel Convolution Network (ESPCN)"""

def __init__(self, upscale_factor: int = 3, num_channels: int = 3):
super(FloatESPCN, self).__init__()
@@ -48,17 +45,14 @@ def __init__(self, upscale_factor: int = 3, num_channels: int = 3):
in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True)
self.conv3 = nn.Conv2d(
in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1, bias=True)
self.conv4 = nn.Sequential()
self.conv4.add_module("interp", nn.UpsamplingNearest2d(scale_factor=upscale_factor))
self.conv4.add_module(
"conv",
nn.Conv2d(
in_channels=32,
out_channels=num_channels,
kernel_size=3,
stride=1,
padding=1,
bias=True))
self.conv4 = nn.Conv2d(
in_channels=32,
out_channels=num_channels * pow(upscale_factor, 2),
kernel_size=3,
stride=1,
padding=1,
bias=False)
self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

self.bn1 = nn.BatchNorm2d(64)
self.bn2 = nn.BatchNorm2d(64)
@@ -75,15 +69,13 @@ def forward(self, inp: Tensor):
x = self.relu(self.bn1(self.conv1(x)))
x = self.relu(self.bn2(self.conv2(x)))
x = self.relu(self.bn3(self.conv3(x)))
x = self.conv4(x)
x = self.out(x)
x = self.pixel_shuffle(self.conv4(x))
x = self.out(x) # To mirror quant version
return x
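The new head is a standard sub-pixel convolution: conv4 emits num_channels * upscale_factor**2 feature maps and nn.PixelShuffle rearranges them into the upscaled image. A quick shape check (a sketch; the 256x256 input is arbitrary):

```python
import torch

model = FloatESPCN(upscale_factor=2, num_channels=3)
x = torch.randn(1, 3, 256, 256)
y = model(x)
# conv4 emits 3 * 2**2 = 12 maps at 256x256; PixelShuffle(2)
# rearranges them into 3 maps at 512x512
assert y.shape == (1, 3, 512, 512)
```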


class QuantESPCN(FloatESPCN):
"""FINN-Friendly Quantized Efficient Sub-Pixel Convolution Network (ESPCN) as
used in Colbert et al. (2023) - `Quantized Neural Networks for Low-Precision
Accumulation with Guaranteed Overflow Avoidance`."""
"""FINN-Friendly Quantized Efficient Sub-Pixel Convolution Network (ESPCN)"""

def __init__(
self,
@@ -130,27 +122,28 @@ def __init__(
kernel_size=3,
stride=1,
padding=1,
bias=True,
input_bit_width=act_bit_width,
input_quant=CommonUintActQuant,
weight_bit_width=weight_bit_width,
weight_accumulator_bit_width=acc_bit_width,
weight_quant=weight_quant)
# Quantizing the weights and input activations to 8-bit integers
# and not applying accumulator constraint to the final convolution
# layer (i.e., accumulator_bit_width=32).
self.conv4 = QuantNearestNeighborConvolution(
# We quantize the weights and input activations of the final layer
# to 8-bit integers. We do not apply the accumulator constraint to
# the final convolution layer. FINN does not currently support
# per-channel quantization or biases for sub-pixel convolution layers.
self.conv4 = qnn.QuantConv2d(
in_channels=32,
out_channels=num_channels,
out_channels=num_channels * pow(upscale_factor, 2),
kernel_size=3,
stride=1,
padding=1,
upscale_factor=upscale_factor,
bias=True,
signed_act=False,
act_bit_width=IO_DATA_BIT_WIDTH,
acc_bit_width=IO_ACC_BIT_WIDTH,
weight_quant=weight_quant,
weight_bit_width=IO_DATA_BIT_WIDTH)
bias=False,
input_bit_width=act_bit_width,
input_quant=CommonUintActQuant,
weight_bit_width=IO_DATA_BIT_WIDTH,
weight_quant=CommonIntWeightPerChannelQuant,
weight_scaling_per_output_channel=False)

self.bn1 = nn.BatchNorm2d(64)
self.bn2 = nn.BatchNorm2d(64)
18 changes: 13 additions & 5 deletions src/brevitas_examples/super_resolution/train_model.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

import argparse
import copy
from hashlib import sha256
import json
import os
@@ -46,10 +47,10 @@
parser.add_argument('--workers', type=int, default=0, help='Number of data loading workers')
parser.add_argument('--batch_size', type=int, default=8, help='Minibatch size')
parser.add_argument('--learning_rate', type=float, default=1e-3, help='Learning rate')
parser.add_argument('--total_epochs', type=int, default=100, help='Total number of training epochs')
parser.add_argument('--weight_decay', type=float, default=1e-4, help='Weight decay')
parser.add_argument('--total_epochs', type=int, default=500, help='Total number of training epochs')
parser.add_argument('--weight_decay', type=float, default=1e-5, help='Weight decay')
parser.add_argument('--step_size', type=int, default=1)
parser.add_argument('--gamma', type=float, default=0.98)
parser.add_argument('--gamma', type=float, default=0.999)
parser.add_argument('--eval_acc_bw', action='store_true', default=False)
parser.add_argument('--save_pth_ckpt', action='store_true', default=False)
parser.add_argument('--save_model_io', action='store_true', default=False)
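Note on the new schedule defaults: with step_size=1 the learning rate is multiplied by gamma after every epoch, so the old defaults decayed to roughly 1e-3 × 0.98^100 ≈ 1.3e-4 over 100 epochs, while the new defaults decay far more gently, to roughly 1e-3 × 0.999^500 ≈ 6.1e-4 over 500 epochs.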
@@ -81,7 +82,6 @@ def main():
args.data_root,
num_workers=args.workers,
batch_size=args.batch_size,
batch_size_test=1,
upscale_factor=model.upscale_factor,
download=True)
criterion = nn.MSELoss()
@@ -92,17 +92,25 @@
scheduler = lrs.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)

# train model
best_psnr, best_weights = 0., copy.deepcopy(model.state_dict())
for ep in range(args.total_epochs):
train_loss = train_for_epoch(trainloader, model, criterion, optimizer)
test_psnr = evaluate_avg_psnr(testloader, model)
scheduler.step()
print(f"[Epoch {ep:03d}] train_loss={train_loss:.4f}, test_psnr={test_psnr:.2f}")
if test_psnr >= best_psnr:
best_weights = copy.deepcopy(model.state_dict())
best_psnr = test_psnr
model.load_state_dict(best_weights)
model = model.to(device)
test_psnr = evaluate_avg_psnr(testloader, model)
print(f"Final test_psnr={test_psnr:.2f}")

# save checkpoint
os.makedirs(args.save_path, exist_ok=True)
if args.save_pth_ckpt:
ckpt_path = f"{args.save_path}/{args.model}.pth"
torch.save(model.state_dict(), ckpt_path)
torch.save(best_weights, ckpt_path)
with open(ckpt_path, "rb") as _file:
bytes = _file.read()
model_tag = sha256(bytes).hexdigest()[:8]