From 4e80f889dac2cf77f92d633c3b54c14739b57e4e Mon Sep 17 00:00:00 2001 From: icolbert Date: Fri, 25 Aug 2023 16:46:06 -0700 Subject: [PATCH 01/17] Updating ESPCN to use sub-pixel conv --- .../super_resolution/models/espcn.py | 78 ++++++++++--------- 1 file changed, 43 insertions(+), 35 deletions(-) diff --git a/src/brevitas_examples/super_resolution/models/espcn.py b/src/brevitas_examples/super_resolution/models/espcn.py index 3af2846f8..3d10e6581 100644 --- a/src/brevitas_examples/super_resolution/models/espcn.py +++ b/src/brevitas_examples/super_resolution/models/espcn.py @@ -12,10 +12,14 @@ from .common import CommonIntWeightPerChannelQuant from .common import CommonUintActQuant from .common import ConstUint8ActQuant -from .common import QuantNearestNeighborConvolution __all__ = [ - "float_espcn", "quant_espcn", "quant_espcn_a2q", "quant_espcn_base", "FloatESPCN", "QuantESPCN"] + "float_espcn", + "quant_espcn", + "quant_espcn_a2q", + "quant_espcn_base", + "FloatESPCN", + "QuantESPCN"] IO_DATA_BIT_WIDTH = 8 IO_ACC_BIT_WIDTH = 32 @@ -29,9 +33,7 @@ def weight_init(layer): class FloatESPCN(nn.Module): - """Floating-point version of FINN-Friendly Quantized Efficient Sub-Pixel Convolution - Network (ESPCN) as used in Colbert et al. (2023) - `Quantized Neural Networks for - Low-Precision Accumulation with Guaranteed Overflow Avoidance`.""" + """Floating-point version of Efficient Sub-Pixel Convolution Network (ESPCN)""" def __init__(self, upscale_factor: int = 3, num_channels: int = 3): super(FloatESPCN, self).__init__() @@ -45,20 +47,27 @@ def __init__(self, upscale_factor: int = 3, num_channels: int = 3): padding=2, bias=True) self.conv2 = nn.Conv2d( - in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True) + in_channels=64, + out_channels=64, + kernel_size=3, + stride=1, + padding=1, + bias=True) self.conv3 = nn.Conv2d( - in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1, bias=True) - self.conv4 = nn.Sequential() - self.conv4.add_module("interp", nn.UpsamplingNearest2d(scale_factor=upscale_factor)) - self.conv4.add_module( - "conv", - nn.Conv2d( - in_channels=32, - out_channels=num_channels, - kernel_size=3, - stride=1, - padding=1, - bias=True)) + in_channels=64, + out_channels=32, + kernel_size=3, + stride=1, + padding=1, + bias=True) + self.conv4 = nn.Conv2d( + in_channels=32, + out_channels=num_channels * pow(upscale_factor, 2), + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.pixel_shuffle = nn.PixelShuffle(upscale_factor) self.bn1 = nn.BatchNorm2d(64) self.bn2 = nn.BatchNorm2d(64) @@ -75,15 +84,13 @@ def forward(self, inp: Tensor): x = self.relu(self.bn1(self.conv1(x))) x = self.relu(self.bn2(self.conv2(x))) x = self.relu(self.bn3(self.conv3(x))) - x = self.conv4(x) - x = self.out(x) + x = self.pixel_shuffle(self.conv4(x)) + x = self.out(x) # To mirror quant version return x class QuantESPCN(FloatESPCN): - """FINN-Friendly Quantized Efficient Sub-Pixel Convolution Network (ESPCN) as - used in Colbert et al. (2023) - `Quantized Neural Networks for Low-Precision - Accumulation with Guaranteed Overflow Avoidance`.""" + """FINN-Friendly Quantized Efficient Sub-Pixel Convolution Network (ESPCN)""" def __init__( self, @@ -130,27 +137,28 @@ def __init__( kernel_size=3, stride=1, padding=1, + bias=True, input_bit_width=act_bit_width, input_quant=CommonUintActQuant, weight_bit_width=weight_bit_width, weight_accumulator_bit_width=acc_bit_width, weight_quant=weight_quant) - # Quantizing the weights and input activations to 8-bit integers - # and not applying accumulator constraint to the final convolution - # layer (i.e., accumulator_bit_width=32). - self.conv4 = QuantNearestNeighborConvolution( + # We quantize the weights and input activations of the final layer + # to 8-bit integers. We do not apply the accumulator constraint to + # the final convolution layer. FINN does not currently support + # per-tensor quantization or biases for sub-pixel convolution layers. + self.conv4 = qnn.QuantConv2d( in_channels=32, - out_channels=num_channels, + out_channels=num_channels * pow(upscale_factor, 2), kernel_size=3, stride=1, padding=1, - upscale_factor=upscale_factor, - bias=True, - signed_act=False, - act_bit_width=IO_DATA_BIT_WIDTH, - acc_bit_width=IO_ACC_BIT_WIDTH, - weight_quant=weight_quant, - weight_bit_width=IO_DATA_BIT_WIDTH) + bias=False, + input_bit_width=act_bit_width, + input_quant=CommonUintActQuant, + weight_bit_width=IO_DATA_BIT_WIDTH, + weight_quant=CommonIntWeightPerChannelQuant, + weight_scaling_per_output_channel=False) self.bn1 = nn.BatchNorm2d(64) self.bn2 = nn.BatchNorm2d(64) From 108f9ae7394efb94f00b797fe704d1c06930206b Mon Sep 17 00:00:00 2001 From: icolbert Date: Fri, 25 Aug 2023 16:48:17 -0700 Subject: [PATCH 02/17] Adding shared transform to BSD300 dataloaders --- .../super_resolution/utils/dataset.py | 41 ++++++++++++------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/src/brevitas_examples/super_resolution/utils/dataset.py b/src/brevitas_examples/super_resolution/utils/dataset.py index edadbb95a..1c7749b08 100644 --- a/src/brevitas_examples/super_resolution/utils/dataset.py +++ b/src/brevitas_examples/super_resolution/utils/dataset.py @@ -52,6 +52,9 @@ from torchvision.transforms import Compose from torchvision.transforms import Resize from torchvision.transforms import ToTensor +from torchvision.transforms import RandomCrop +from torchvision.transforms import RandomVerticalFlip +from torchvision.transforms import RandomHorizontalFlip __all__ = ["get_bsd300_dataloaders"] @@ -79,21 +82,21 @@ def load_img_rbg(filepath): class DatasetFromFolder(data.Dataset): - def __init__(self, image_dir, input_transform=None, target_transform=None): + def __init__(self, image_dir, shared_transform, input_transform, target_transform): super(DatasetFromFolder, self).__init__() self.image_filenames = [ os.path.join(image_dir, x) for x in os.listdir(image_dir) if is_valid_image_file(x)] + self.shared_transform = shared_transform self.input_transform = input_transform self.target_transform = target_transform def __getitem__(self, index): input = load_img_rbg(self.image_filenames[index]) + input = self.shared_transform(input) target = input.copy() - if self.input_transform: - input = self.input_transform(input) - if self.target_transform: - target = self.target_transform(target) + input = self.input_transform(input) + target = self.target_transform(target) return input, target def __len__(self): @@ -122,35 +125,45 @@ def calculate_valid_crop_size(crop_size, upscale_factor): return crop_size - (crop_size % upscale_factor) +def train_transforms(crop_size): + return Compose([ + RandomCrop(crop_size, pad_if_needed=True), + RandomHorizontalFlip(), + RandomVerticalFlip()]) + + +def test_transforms(crop_size): + return Compose([CenterCrop(crop_size)]) + + def input_transform(crop_size, upscale_factor): return Compose([ - CenterCrop(crop_size), Resize(crop_size // upscale_factor), ToTensor(),]) -def target_transform(crop_size): - return Compose([ - CenterCrop(crop_size), - ToTensor(),]) +def target_transform(): + return Compose([ToTensor()]) -def get_training_set(upscale_factor: int, root_dir: str, crop_size: int = 256): +def get_training_set(upscale_factor: int, root_dir: str, crop_size: int): train_dir = os.path.join(root_dir, "train") crop_size = calculate_valid_crop_size(crop_size, upscale_factor) return DatasetFromFolder( train_dir, + shared_transform=train_transforms(crop_size), input_transform=input_transform(crop_size, upscale_factor), - target_transform=target_transform(crop_size)) + target_transform=target_transform()) -def get_test_set(upscale_factor: int, root_dir: str, crop_size: int = 256): +def get_test_set(upscale_factor: int, root_dir: str, crop_size: int): test_dir = os.path.join(root_dir, "test") crop_size = calculate_valid_crop_size(crop_size, upscale_factor) return DatasetFromFolder( test_dir, + shared_transform=test_transforms(crop_size), input_transform=input_transform(crop_size, upscale_factor), - target_transform=target_transform(crop_size)) + target_transform=target_transform()) def get_bsd300_dataloaders( From 7552e92696b6f2a04338f71bb33c2031650e9b41 Mon Sep 17 00:00:00 2001 From: icolbert Date: Fri, 25 Aug 2023 16:48:54 -0700 Subject: [PATCH 03/17] Updating defaults for BSD300 dataloaders --- src/brevitas_examples/super_resolution/utils/dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/brevitas_examples/super_resolution/utils/dataset.py b/src/brevitas_examples/super_resolution/utils/dataset.py index 1c7749b08..e66a79a5c 100644 --- a/src/brevitas_examples/super_resolution/utils/dataset.py +++ b/src/brevitas_examples/super_resolution/utils/dataset.py @@ -170,10 +170,10 @@ def get_bsd300_dataloaders( data_root: str, num_workers: int = 0, batch_size: int = 32, - batch_size_test: int = 32, + batch_size_test: int = 100, pin_memory: bool = True, upscale_factor: int = 3, - crop_size: int = 512, + crop_size: int = 256, download: bool = False) -> Tuple[Type[DataLoader]]: """Function that loads BSD300 dataset from data_root folder and returns the training and testing dataloaders. If /BSD300/images does not exist, then the data is @@ -187,7 +187,7 @@ def get_bsd300_dataloaders( None, then batch_size_test = batch_size. Default: 32 pin_memory (bool): Whether or not to pin the memory for both dataloaders. Default: True upscale_factor (int): The upscale factor for the super resolution task. Default: 3 - crop_size (int): The size to crop images for upscaling. Default 512 + crop_size (int): The size to crop images for upscaling. Default: 256 download (bool): Whether or not to download the dataset. Default: False """ data_root = download_bsd300(data_root, download) From 80e7e35d0d06984ce2cc3ee951c4dfe9b6e3857b Mon Sep 17 00:00:00 2001 From: icolbert Date: Tue, 12 Sep 2023 10:42:06 -0700 Subject: [PATCH 04/17] Saving best weights rather than last weights --- .../super_resolution/train_model.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/brevitas_examples/super_resolution/train_model.py b/src/brevitas_examples/super_resolution/train_model.py index 52dfe8c72..6f1086579 100644 --- a/src/brevitas_examples/super_resolution/train_model.py +++ b/src/brevitas_examples/super_resolution/train_model.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause import argparse +import copy from hashlib import sha256 import json import os @@ -81,7 +82,6 @@ def main(): args.data_root, num_workers=args.workers, batch_size=args.batch_size, - batch_size_test=1, upscale_factor=model.upscale_factor, download=True) criterion = nn.MSELoss() @@ -92,17 +92,25 @@ def main(): scheduler = lrs.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma) # train model + best_psnr, best_weights = 0., copy.deepcopy(model.state_dict()) for ep in range(args.total_epochs): train_loss = train_for_epoch(trainloader, model, criterion, optimizer) test_psnr = evaluate_avg_psnr(testloader, model) scheduler.step() print(f"[Epoch {ep:03d}] train_loss={train_loss:.4f}, test_psnr={test_psnr:.2f}") + if test_psnr >= best_psnr: + best_weights = copy.deepcopy(model.state_dict()) + best_psnr = test_psnr + model.load_state_dict(best_weights) + model = model.to(device) + test_psnr = evaluate_avg_psnr(testloader, model) + print(f"Final test_psnr={test_psnr:.2f}") # save checkpoint os.makedirs(args.save_path, exist_ok=True) if args.save_pth_ckpt: ckpt_path = f"{args.save_path}/{args.model}.pth" - torch.save(model.state_dict(), ckpt_path) + torch.save(best_weights, ckpt_path) with open(ckpt_path, "rb") as _file: bytes = _file.read() model_tag = sha256(bytes).hexdigest()[:8] From f27fe1c62e8f3cc73954d9a1a2d540966d232173 Mon Sep 17 00:00:00 2001 From: icolbert Date: Tue, 12 Sep 2023 10:42:19 -0700 Subject: [PATCH 05/17] Updating defaults --- src/brevitas_examples/super_resolution/utils/dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/brevitas_examples/super_resolution/utils/dataset.py b/src/brevitas_examples/super_resolution/utils/dataset.py index e66a79a5c..a3ca12493 100644 --- a/src/brevitas_examples/super_resolution/utils/dataset.py +++ b/src/brevitas_examples/super_resolution/utils/dataset.py @@ -173,7 +173,7 @@ def get_bsd300_dataloaders( batch_size_test: int = 100, pin_memory: bool = True, upscale_factor: int = 3, - crop_size: int = 256, + crop_size: int = 512, download: bool = False) -> Tuple[Type[DataLoader]]: """Function that loads BSD300 dataset from data_root folder and returns the training and testing dataloaders. If /BSD300/images does not exist, then the data is @@ -184,10 +184,10 @@ def get_bsd300_dataloaders( num_workers (int): Number of workers to use for both dataloaders. Default: 0 batch_size (int): Size of batches to use for the training dataloader. Default: 32 batch_size_test (int): Size of batches to use for the testing dataloader. When - None, then batch_size_test = batch_size. Default: 32 + None, then batch_size_test = batch_size. Default: 100 pin_memory (bool): Whether or not to pin the memory for both dataloaders. Default: True upscale_factor (int): The upscale factor for the super resolution task. Default: 3 - crop_size (int): The size to crop images for upscaling. Default: 256 + crop_size (int): The size to crop images for upscaling. Default: 512 download (bool): Whether or not to download the dataset. Default: False """ data_root = download_bsd300(data_root, download) From f693deeb188cdd4f2d9f9b8a345d32f30d516354 Mon Sep 17 00:00:00 2001 From: icolbert Date: Tue, 12 Sep 2023 10:42:28 -0700 Subject: [PATCH 06/17] Update train.py --- src/brevitas_examples/super_resolution/utils/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/brevitas_examples/super_resolution/utils/train.py b/src/brevitas_examples/super_resolution/utils/train.py index 5d12e8b19..af17a86d1 100644 --- a/src/brevitas_examples/super_resolution/utils/train.py +++ b/src/brevitas_examples/super_resolution/utils/train.py @@ -15,6 +15,7 @@ def calc_average_psnr(ref_images: Tensor, gen_images: Tensor, eps: float = 1e-10 def train_for_epoch(trainloader, model, criterion, optimizer): + model.train() tot_loss = 0. for i, (images, targets) in enumerate(trainloader): optimizer.zero_grad() From 4b9b6f9f69da49f173dbbfee07675fced9442076 Mon Sep 17 00:00:00 2001 From: icolbert Date: Tue, 12 Sep 2023 10:42:49 -0700 Subject: [PATCH 07/17] Removing cache_inference_quant_inp --- src/brevitas_examples/super_resolution/utils/evaluate.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/brevitas_examples/super_resolution/utils/evaluate.py b/src/brevitas_examples/super_resolution/utils/evaluate.py index 0319700e6..39c05a434 100644 --- a/src/brevitas_examples/super_resolution/utils/evaluate.py +++ b/src/brevitas_examples/super_resolution/utils/evaluate.py @@ -29,10 +29,6 @@ def _calc_min_acc_bit_width(module: QuantWBIOL) -> Tensor: def evaluate_accumulator_bit_widths(model: nn.Module, inp: Tensor): - if isinstance(model, QuantESPCN): - # Need to cache the quantized input to the final convolution to be able to evaluate the - # accumulator bounds since we need the input bit-width, which is specified at runtime. - model.conv4.conv.cache_inference_quant_inp = True model(inp) # collect quant inputs now that caching is enabled stats = dict() for name, module in model.named_modules(): From 3617422f5c075e4339f54ca6ccba8d1c9525fcc2 Mon Sep 17 00:00:00 2001 From: icolbert Date: Tue, 12 Sep 2023 10:43:15 -0700 Subject: [PATCH 08/17] Updating defaults --- src/brevitas_examples/super_resolution/models/common.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/brevitas_examples/super_resolution/models/common.py b/src/brevitas_examples/super_resolution/models/common.py index 9d03aba51..f781f1487 100644 --- a/src/brevitas_examples/super_resolution/models/common.py +++ b/src/brevitas_examples/super_resolution/models/common.py @@ -24,9 +24,9 @@ class CommonIntWeightPerChannelQuant(Int8WeightPerTensorFloat): scaling_per_output_channel = True -class CommonIntAccumulatorAwareWeightQuant(Int8AccumulatorAwareWeightQuant): - pre_scaling_min_val = 1e-8 - scaling_min_val = 1e-8 +class CommonIntAccumulatorAwareWeightQuant(Int8AccumulatorAwareWeightQuant): + pre_scaling_min_val = 1e-10 + scaling_min_val = 1e-10 class CommonIntActQuant(Int8ActPerTensorFloat): From 3afebd99eb06e7b9524dfd3aad89d9be1e7319e6 Mon Sep 17 00:00:00 2001 From: icolbert Date: Tue, 12 Sep 2023 14:17:55 -0700 Subject: [PATCH 09/17] Adding 4b ESPCN models --- .../super_resolution/models/__init__.py | 24 ++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/src/brevitas_examples/super_resolution/models/__init__.py b/src/brevitas_examples/super_resolution/models/__init__.py index fa5bff7bd..3ea12fea0 100644 --- a/src/brevitas_examples/super_resolution/models/__init__.py +++ b/src/brevitas_examples/super_resolution/models/__init__.py @@ -3,8 +3,8 @@ from functools import partial from typing import Union - from torch import hub +import torch.nn as nn from .espcn import * @@ -26,7 +26,23 @@ upscale_factor=2, weight_bit_width=8, act_bit_width=8, - acc_bit_width=16)} + acc_bit_width=16), + 'quant_espcn_x2_w4a4_base': + partial(quant_espcn_base, upscale_factor=2, weight_bit_width=4, act_bit_width=4), + 'quant_espcn_x2_w4a4_a2q_32b': + partial( + quant_espcn_a2q, + upscale_factor=2, + weight_bit_width=4, + act_bit_width=4, + acc_bit_width=32), + 'quant_espcn_x2_w4a4_a2q_14b': + partial( + quant_espcn_a2q, + upscale_factor=2, + weight_bit_width=4, + act_bit_width=4, + acc_bit_width=14)} root_url = 'https://github.com/Xilinx/brevitas/releases/download/super_res-r0' @@ -40,8 +56,10 @@ def get_model_by_name(name: str, pretrained: bool = False) -> Union[FloatESPCN, QuantESPCN]: if name not in model_impl.keys(): raise NotImplementedError(f"{name} does not exist.") - model = model_impl[name]() + model: nn.Module = model_impl[name]() if pretrained: + if name not in model_impl: + raise NotImplementedError(f"Error: {name} does not have a pre-trained checkpoint.") checkpoint = model_url[name] state_dict = hub.load_state_dict_from_url(checkpoint, progress=True, map_location='cpu') model.load_state_dict(state_dict, strict=True) From 1d0ef69044617bb2f4ffe9e52ca38eb1265b5862 Mon Sep 17 00:00:00 2001 From: icolbert Date: Tue, 12 Sep 2023 20:53:40 -0700 Subject: [PATCH 10/17] Adding regularization penalty --- .../super_resolution/utils/train.py | 37 +++++++++++++++++-- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/src/brevitas_examples/super_resolution/utils/train.py b/src/brevitas_examples/super_resolution/utils/train.py index af17a86d1..e10ffe335 100644 --- a/src/brevitas_examples/super_resolution/utils/train.py +++ b/src/brevitas_examples/super_resolution/utils/train.py @@ -3,6 +3,8 @@ import torch from torch import Tensor +from brevitas.function import abs_binary_sign_grad +from brevitas.core.scaling.pre_scaling import AccumulatorAwareParameterPreScaling device = 'cuda' if torch.cuda.is_available() else 'cpu' @@ -16,16 +18,43 @@ def calc_average_psnr(ref_images: Tensor, gen_images: Tensor, eps: float = 1e-10 def train_for_epoch(trainloader, model, criterion, optimizer): model.train() - tot_loss = 0. - for i, (images, targets) in enumerate(trainloader): + + tot_loss, reg_penalty = 0., 0. + + def acc_reg_penalty(module: AccumulatorAwareParameterPreScaling, inp, output): + """Accumulate the regularization penalty across constrained layers""" + nonlocal reg_penalty + (weights, input_bit_width, input_is_signed) = inp + s = module.scaling_impl(weights) # s + g = abs_binary_sign_grad(module.restrict_clamp_scaling(module.value)) # g + T = module.get_upper_bound_on_l1_norm(input_bit_width, input_is_signed) # T / s + cur_penalty = torch.relu(g - (T * s)).sum() + reg_penalty += cur_penalty + return output + + # Register a forward hook to accumulate the regularization penalty + hook_fns = list() + for mod in model.modules(): + if isinstance(mod, AccumulatorAwareParameterPreScaling): + hook = mod.register_forward_hook(acc_reg_penalty) + hook_fns.append(hook) + + for _, (images, targets) in enumerate(trainloader): optimizer.zero_grad() images = images.to(device) targets = targets.to(device) outputs = model(images) - loss: Tensor = criterion(outputs, targets) + task_loss: Tensor = criterion(outputs, targets) + loss = task_loss + reg_penalty loss.backward() optimizer.step() - tot_loss += loss.item() * images.size(0) + reg_penalty = 0. # reset the accumulated regularization penalty + tot_loss += task_loss.item() * images.size(0) + + # Remove the registered forward hooks before exiting + for hook in hook_fns: + hook.remove() + avg_loss = tot_loss / len(trainloader.dataset) return avg_loss From cddc7ff0232bccb0daa39420b3c8e209a1d8bfeb Mon Sep 17 00:00:00 2001 From: icolbert Date: Wed, 13 Sep 2023 16:18:38 -0700 Subject: [PATCH 11/17] Updating defaults in CLI args --- src/brevitas_examples/super_resolution/train_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/brevitas_examples/super_resolution/train_model.py b/src/brevitas_examples/super_resolution/train_model.py index 6f1086579..f446355cb 100644 --- a/src/brevitas_examples/super_resolution/train_model.py +++ b/src/brevitas_examples/super_resolution/train_model.py @@ -47,10 +47,10 @@ parser.add_argument('--workers', type=int, default=0, help='Number of data loading workers') parser.add_argument('--batch_size', type=int, default=8, help='Minibatch size') parser.add_argument('--learning_rate', type=float, default=1e-3, help='Learning rate') -parser.add_argument('--total_epochs', type=int, default=100, help='Total number of training epochs') -parser.add_argument('--weight_decay', type=float, default=1e-4, help='Weight decay') +parser.add_argument('--total_epochs', type=int, default=500, help='Total number of training epochs') +parser.add_argument('--weight_decay', type=float, default=1e-5, help='Weight decay') parser.add_argument('--step_size', type=int, default=1) -parser.add_argument('--gamma', type=float, default=0.98) +parser.add_argument('--gamma', type=float, default=0.999) parser.add_argument('--eval_acc_bw', action='store_true', default=False) parser.add_argument('--save_pth_ckpt', action='store_true', default=False) parser.add_argument('--save_model_io', action='store_true', default=False) From 6ff6184aa35a6d7f398930a55f1d2544c5b7e150 Mon Sep 17 00:00:00 2001 From: icolbert Date: Wed, 13 Sep 2023 16:18:53 -0700 Subject: [PATCH 12/17] Adding weight for regularization penalty --- src/brevitas_examples/super_resolution/utils/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/brevitas_examples/super_resolution/utils/train.py b/src/brevitas_examples/super_resolution/utils/train.py index e10ffe335..24e859ac0 100644 --- a/src/brevitas_examples/super_resolution/utils/train.py +++ b/src/brevitas_examples/super_resolution/utils/train.py @@ -16,7 +16,7 @@ def calc_average_psnr(ref_images: Tensor, gen_images: Tensor, eps: float = 1e-10 return psnr.mean() -def train_for_epoch(trainloader, model, criterion, optimizer): +def train_for_epoch(trainloader, model, criterion, optimizer, reg_weight: float = 1e-3): model.train() tot_loss, reg_penalty = 0., 0. @@ -45,7 +45,7 @@ def acc_reg_penalty(module: AccumulatorAwareParameterPreScaling, inp, output): targets = targets.to(device) outputs = model(images) task_loss: Tensor = criterion(outputs, targets) - loss = task_loss + reg_penalty + loss = task_loss + (reg_weight * reg_penalty) loss.backward() optimizer.step() reg_penalty = 0. # reset the accumulated regularization penalty From 8dde0990c4435458d5dfa966ade95f06c0bb76c2 Mon Sep 17 00:00:00 2001 From: icolbert Date: Wed, 13 Sep 2023 16:19:10 -0700 Subject: [PATCH 13/17] Update README.md --- .../super_resolution/README.md | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/brevitas_examples/super_resolution/README.md b/src/brevitas_examples/super_resolution/README.md index 2e762393e..55f1b5d75 100644 --- a/src/brevitas_examples/super_resolution/README.md +++ b/src/brevitas_examples/super_resolution/README.md @@ -1,22 +1,29 @@ # Integer-Quantized Super Resolution Experiments with Brevitas -This directory contains training scripts to demonstrate how to train integer-quantized super resolution models using [Brevitas](https://github.com/Xilinx/brevitas). +This directory contains scripts demonstrating how to train integer-quantized super resolution models using [Brevitas](https://github.com/Xilinx/brevitas). Code is also provided to demonstrate accumulator-aware quantization (A2Q) as proposed in our ICCV 2023 paper "[A2Q: Accumulator-Aware Quantization with Guaranteed Overflow Avoidance](https://arxiv.org/abs/2308.13504)". ## Experiments All models are trained on the BSD300 dataset to upsample images by 2x. -Target images are center cropped to 512x512. +Target images are cropped to 512x512. +During training random cropping is applied, along with random vertical and horizontal flips. +During inference center cropping is applied. Inputs are then downscaled by 2x and then used to train the model directly in the RGB space. Note that this is a difference from many academic works that train only on the Y-channel in YCbCr format. | Model Name | Upscale Factor | Weight quantization | Activation quantization | Peak Signal-to-Noise Ratio | |-----------------------------|----------------|---------------------|-------------------------|----------------------------| -| [float_espcn_x2](https://github.com/Xilinx/brevitas/releases/download/super_res_r0/float_espcn_x2-2f3821e3.pth) | x2 | float32 | float32 | 30.37 | -| [quant_espcn_x2_w8a8_base](https://github.com/Xilinx/brevitas/releases/download/super_res_r0/quant_espcn_x2_w8a8_base-7d54e29c.pth) | x2 | int8 | (u)int8 | 30.16 | -| [quant_espcn_x2_w8a8_a2q_32b](https://github.com/Xilinx/brevitas/releases/download/super_res_r0/quant_espcn_x2_w8a8_a2q_32b-0b1f361d.pth) | x2 | int8 | (u)int8 | 30.80 | -| [quant_espcn_x2_w8a8_a2q_16b](https://github.com/Xilinx/brevitas/releases/download/super_res_r0/quant_espcn_x2_w8a8_a2q_16b-3c4acd35.pth) | x2 | int8 | (u)int8 | 29.38 | | bicubic_interp | x2 | N/A | N/A | 28.71 | +| [float_espcn_x2]() | x2 | float32 | float32 | 31.03 | +|| +| [quant_espcn_x2_w8a8_base]() | x2 | int8 | (u)int8 | 30.96 | +| [quant_espcn_x2_w8a8_a2q_32b]() | x2 | int8 | (u)int8 | 30.79 | +| [quant_espcn_x2_w8a8_a2q_16b]() | x2 | int8 | (u)int8 | 30.56 | +|| +| [quant_espcn_x2_w4a4_base]() | x2 | int4 | (u)int4 | 30.30 | +| [quant_espcn_x2_w4a4_a2q_32b]() | x2 | int4 | (u)int4 | 30.27 | +| [quant_espcn_x2_w4a4_a2q_14b]() | x2 | int4 | (u)int4 | 30.24 | ## Train From e9de10f49e39f7131911c7a6d590716ed1522ca3 Mon Sep 17 00:00:00 2001 From: icolbert Date: Wed, 13 Sep 2023 16:25:28 -0700 Subject: [PATCH 14/17] Pre-commit fixes --- .../super_resolution/models/__init__.py | 1 + .../super_resolution/models/common.py | 2 +- .../super_resolution/models/espcn.py | 25 ++++--------------- .../super_resolution/utils/dataset.py | 10 +++----- .../super_resolution/utils/train.py | 3 ++- 5 files changed, 13 insertions(+), 28 deletions(-) diff --git a/src/brevitas_examples/super_resolution/models/__init__.py b/src/brevitas_examples/super_resolution/models/__init__.py index 3ea12fea0..f3c9a0ad5 100644 --- a/src/brevitas_examples/super_resolution/models/__init__.py +++ b/src/brevitas_examples/super_resolution/models/__init__.py @@ -3,6 +3,7 @@ from functools import partial from typing import Union + from torch import hub import torch.nn as nn diff --git a/src/brevitas_examples/super_resolution/models/common.py b/src/brevitas_examples/super_resolution/models/common.py index f781f1487..d3022d089 100644 --- a/src/brevitas_examples/super_resolution/models/common.py +++ b/src/brevitas_examples/super_resolution/models/common.py @@ -24,7 +24,7 @@ class CommonIntWeightPerChannelQuant(Int8WeightPerTensorFloat): scaling_per_output_channel = True -class CommonIntAccumulatorAwareWeightQuant(Int8AccumulatorAwareWeightQuant): +class CommonIntAccumulatorAwareWeightQuant(Int8AccumulatorAwareWeightQuant): pre_scaling_min_val = 1e-10 scaling_min_val = 1e-10 diff --git a/src/brevitas_examples/super_resolution/models/espcn.py b/src/brevitas_examples/super_resolution/models/espcn.py index 3d10e6581..f3123fcb9 100644 --- a/src/brevitas_examples/super_resolution/models/espcn.py +++ b/src/brevitas_examples/super_resolution/models/espcn.py @@ -14,12 +14,7 @@ from .common import ConstUint8ActQuant __all__ = [ - "float_espcn", - "quant_espcn", - "quant_espcn_a2q", - "quant_espcn_base", - "FloatESPCN", - "QuantESPCN"] + "float_espcn", "quant_espcn", "quant_espcn_a2q", "quant_espcn_base", "FloatESPCN", "QuantESPCN"] IO_DATA_BIT_WIDTH = 8 IO_ACC_BIT_WIDTH = 32 @@ -47,19 +42,9 @@ def __init__(self, upscale_factor: int = 3, num_channels: int = 3): padding=2, bias=True) self.conv2 = nn.Conv2d( - in_channels=64, - out_channels=64, - kernel_size=3, - stride=1, - padding=1, - bias=True) + in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True) self.conv3 = nn.Conv2d( - in_channels=64, - out_channels=32, - kernel_size=3, - stride=1, - padding=1, - bias=True) + in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1, bias=True) self.conv4 = nn.Conv2d( in_channels=32, out_channels=num_channels * pow(upscale_factor, 2), @@ -85,7 +70,7 @@ def forward(self, inp: Tensor): x = self.relu(self.bn2(self.conv2(x))) x = self.relu(self.bn3(self.conv3(x))) x = self.pixel_shuffle(self.conv4(x)) - x = self.out(x) # To mirror quant version + x = self.out(x) # To mirror quant version return x @@ -145,7 +130,7 @@ def __init__( weight_quant=weight_quant) # We quantize the weights and input activations of the final layer # to 8-bit integers. We do not apply the accumulator constraint to - # the final convolution layer. FINN does not currently support + # the final convolution layer. FINN does not currently support # per-tensor quantization or biases for sub-pixel convolution layers. self.conv4 = qnn.QuantConv2d( in_channels=32, diff --git a/src/brevitas_examples/super_resolution/utils/dataset.py b/src/brevitas_examples/super_resolution/utils/dataset.py index a3ca12493..cbf1004be 100644 --- a/src/brevitas_examples/super_resolution/utils/dataset.py +++ b/src/brevitas_examples/super_resolution/utils/dataset.py @@ -50,11 +50,11 @@ import torch.utils.data as data from torchvision.transforms import CenterCrop from torchvision.transforms import Compose -from torchvision.transforms import Resize -from torchvision.transforms import ToTensor from torchvision.transforms import RandomCrop -from torchvision.transforms import RandomVerticalFlip from torchvision.transforms import RandomHorizontalFlip +from torchvision.transforms import RandomVerticalFlip +from torchvision.transforms import Resize +from torchvision.transforms import ToTensor __all__ = ["get_bsd300_dataloaders"] @@ -127,9 +127,7 @@ def calculate_valid_crop_size(crop_size, upscale_factor): def train_transforms(crop_size): return Compose([ - RandomCrop(crop_size, pad_if_needed=True), - RandomHorizontalFlip(), - RandomVerticalFlip()]) + RandomCrop(crop_size, pad_if_needed=True), RandomHorizontalFlip(), RandomVerticalFlip()]) def test_transforms(crop_size): diff --git a/src/brevitas_examples/super_resolution/utils/train.py b/src/brevitas_examples/super_resolution/utils/train.py index 24e859ac0..94f9c36c3 100644 --- a/src/brevitas_examples/super_resolution/utils/train.py +++ b/src/brevitas_examples/super_resolution/utils/train.py @@ -3,8 +3,9 @@ import torch from torch import Tensor -from brevitas.function import abs_binary_sign_grad + from brevitas.core.scaling.pre_scaling import AccumulatorAwareParameterPreScaling +from brevitas.function import abs_binary_sign_grad device = 'cuda' if torch.cuda.is_available() else 'cpu' From 90637c784459494c17ac38e289e041ddd2f5b24e Mon Sep 17 00:00:00 2001 From: icolbert Date: Wed, 13 Sep 2023 16:35:59 -0700 Subject: [PATCH 15/17] Adding test to verify float and quant models match --- .../brevitas_examples/test_examples_import.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/brevitas_examples/test_examples_import.py b/tests/brevitas_examples/test_examples_import.py index 3f728428c..c80dd3161 100644 --- a/tests/brevitas_examples/test_examples_import.py +++ b/tests/brevitas_examples/test_examples_import.py @@ -1,6 +1,11 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +import pytest + +from brevitas.quant.scaled_int import Int8AccumulatorAwareWeightQuant +from brevitas.quant.scaled_int import Int8WeightPerChannelFloat + def test_import_bnn_pynq(): from brevitas_examples.bnn_pynq import cnv_1w1a @@ -31,3 +36,17 @@ def test_import_stt(): from brevitas_examples.speech_to_text import quant_quartznet_perchannelscaling_4b from brevitas_examples.speech_to_text import quant_quartznet_perchannelscaling_8b from brevitas_examples.speech_to_text import quant_quartznet_pertensorscaling_8b + + +@pytest.mark.parametrize("upscale_factor", [2, 3, 4]) +@pytest.mark.parametrize("num_channels", [1, 3]) +@pytest.mark.parametrize( + "weight_quant", [Int8WeightPerChannelFloat, Int8AccumulatorAwareWeightQuant]) +def test_super_resolution_float_and_quant_models_match(upscale_factor, num_channels, weight_quant): + import brevitas.config as config + from brevitas_examples.super_resolution.models import float_espcn + from brevitas_examples.super_resolution.models import quant_espcn + config.IGNORE_MISSING_KEYS = True + float_model = float_espcn(upscale_factor, num_channels) + quant_model = quant_espcn(upscale_factor, num_channels, weight_quant=weight_quant) + quant_model.load_state_dict(float_model.state_dict()) From 1789beed248e5a934ac6ad313cadfd8f3f2221bc Mon Sep 17 00:00:00 2001 From: icolbert Date: Wed, 13 Sep 2023 16:52:45 -0700 Subject: [PATCH 16/17] Update eval_model.py --- .../super_resolution/eval_model.py | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/brevitas_examples/super_resolution/eval_model.py b/src/brevitas_examples/super_resolution/eval_model.py index a567469d1..60883dea6 100644 --- a/src/brevitas_examples/super_resolution/eval_model.py +++ b/src/brevitas_examples/super_resolution/eval_model.py @@ -31,13 +31,22 @@ parser = argparse.ArgumentParser(description='PyTorch BSD300 Validation') parser.add_argument('--data_root', help='Path to folder containing BSD300 val folder') -parser.add_argument('--model_path', default=None, help='Path to PyTorch checkpoint') +parser.add_argument('--model_path', default=None, help='Path to PyTorch checkpoint. Default = None') parser.add_argument( - '--save_path', type=str, default='outputs/', help='Save path for exported model') + '--save_path', + type=str, + default='outputs/', + help='Save path for exported model. Default = outputs/') parser.add_argument( - '--model', type=str, default='quant_espcn_x2_w8a8_base', help='Name of the model configuration') -parser.add_argument('--workers', type=int, default=0, help='Number of data loading workers') -parser.add_argument('--batch_size', type=int, default=16, help='Minibatch size') + '--model', + type=str, + default='quant_espcn_x2_w8a8_base', + help='Name of the model configuration. Default = quant_espcn_x2_w8a8_base') +parser.add_argument( + '--workers', type=int, default=0, help='Number of data loading workers. Default = 0') +parser.add_argument('--batch_size', type=int, default=16, help='Minibatch size. Default = 16') +parser.add_argument( + '--crop_size', type=int, default=512, help='The size to crop the image. Default = 512') parser.add_argument('--use_pretrained', action='store_true', default=False) parser.add_argument('--eval_acc_bw', action='store_true', default=False) parser.add_argument('--save_model_io', action='store_true', default=False) @@ -60,6 +69,7 @@ def main(): num_workers=args.workers, batch_size=args.batch_size, upscale_factor=model.upscale_factor, + crop_size=args.crop_size, download=True) test_psnr = evaluate_avg_psnr(testloader, model) From 3e0dfec4093ab3bd3192b7c3253c83176913130a Mon Sep 17 00:00:00 2001 From: icolbert Date: Wed, 13 Sep 2023 21:00:55 -0700 Subject: [PATCH 17/17] Fixing 14b -> 13b --- src/brevitas_examples/super_resolution/README.md | 2 +- src/brevitas_examples/super_resolution/models/__init__.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/brevitas_examples/super_resolution/README.md b/src/brevitas_examples/super_resolution/README.md index 55f1b5d75..9e0068c27 100644 --- a/src/brevitas_examples/super_resolution/README.md +++ b/src/brevitas_examples/super_resolution/README.md @@ -23,7 +23,7 @@ Note that this is a difference from many academic works that train only on the Y || | [quant_espcn_x2_w4a4_base]() | x2 | int4 | (u)int4 | 30.30 | | [quant_espcn_x2_w4a4_a2q_32b]() | x2 | int4 | (u)int4 | 30.27 | -| [quant_espcn_x2_w4a4_a2q_14b]() | x2 | int4 | (u)int4 | 30.24 | +| [quant_espcn_x2_w4a4_a2q_13b]() | x2 | int4 | (u)int4 | 30.24 | ## Train diff --git a/src/brevitas_examples/super_resolution/models/__init__.py b/src/brevitas_examples/super_resolution/models/__init__.py index f3c9a0ad5..872624a85 100644 --- a/src/brevitas_examples/super_resolution/models/__init__.py +++ b/src/brevitas_examples/super_resolution/models/__init__.py @@ -37,13 +37,13 @@ weight_bit_width=4, act_bit_width=4, acc_bit_width=32), - 'quant_espcn_x2_w4a4_a2q_14b': + 'quant_espcn_x2_w4a4_a2q_13b': partial( quant_espcn_a2q, upscale_factor=2, weight_bit_width=4, act_bit_width=4, - acc_bit_width=14)} + acc_bit_width=13)} root_url = 'https://github.com/Xilinx/brevitas/releases/download/super_res-r0'