diff --git a/src/brevitas/graph/standardize.py b/src/brevitas/graph/standardize.py
index 7bbbb201d..93e99eac9 100644
--- a/src/brevitas/graph/standardize.py
+++ b/src/brevitas/graph/standardize.py
@@ -59,7 +59,7 @@ def match_node(self, node: Node) -> bool:
         is_adaptive_2d_mean = ((2, 3) in node.args or [2, 3] in node.args or
                                'dim' in node.kwargs and
                                (node.kwargs['dim'] == (2, 3) or node.kwargs['dim'] == [2, 3]))
-        is_adaptive_2d_mean = is_adaptive_2d_mean and not node.kwargs['keepdim']
+        is_adaptive_2d_mean = is_adaptive_2d_mean and not node.kwargs.get('keepdim', False)
         return spr and is_adaptive_2d_mean
 
     def move_node_args_to_kwargs(self, node: Node):
diff --git a/src/brevitas/quant_tensor/int_quant_tensor.py b/src/brevitas/quant_tensor/int_quant_tensor.py
index 62c501250..0abde1f48 100644
--- a/src/brevitas/quant_tensor/int_quant_tensor.py
+++ b/src/brevitas/quant_tensor/int_quant_tensor.py
@@ -222,6 +222,9 @@ def ndim(self):
     def dim(self):
         return self.value.dim()
 
+    def mean(self, *args, **kwargs):
+        return self.value.mean(*args, **kwargs)
+
     @property
     def shape(self):
         return self.value.shape
@@ -232,6 +235,15 @@ def dim(self):
     def add(self, other):
         return self + other
 
+    def sum(self, *args, **kwargs):
+        return self.value.sum(*args, **kwargs)
+
+    def unsqueeze(self, *args, **kwargs):
+        return self.value.unsqueeze(*args, **kwargs)
+
+    def sigmoid(self):
+        return self.value.sigmoid()
+
     @staticmethod
     def cat(tensors, dim, out=None):
         if out is not None:
diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
index 1f7c06a2b..69da2ab8e 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
@@ -8,6 +8,7 @@
 import warnings
 
 import numpy as np
+import timm
 import torch
 import torch.backends.cudnn as cudnn
 import torch.nn.parallel
@@ -33,6 +34,7 @@
 from brevitas_examples.imagenet_classification.ptq.utils import get_model_config
 from brevitas_examples.imagenet_classification.ptq.utils import get_torchvision_model
 from brevitas_examples.imagenet_classification.utils import generate_dataloader
+from brevitas_examples.imagenet_classification.utils import generate_dataloader_with_transform
 from brevitas_examples.imagenet_classification.utils import SEED
 from brevitas_examples.imagenet_classification.utils import validate
 
@@ -47,10 +49,6 @@ def parse_type(v, default_type):
     return default_type(v)
 
 
-model_names = sorted(
-    name for name in torchvision.models.__dict__ if name.islower() and not name.startswith("__") and
-    callable(torchvision.models.__dict__[name]) and not name.startswith("get_"))
-
 parser = argparse.ArgumentParser(description='PyTorch ImageNet PTQ Validation')
 parser.add_argument(
     '--calibration-dir',
@@ -75,12 +73,16 @@ def parse_type(v, default_type):
 parser.add_argument('--gpu', default=None, type=int, help='GPU id to use (default: None)')
 parser.add_argument(
     '--calibration-samples', default=1000, type=int, help='Calibration size (default: 1000)')
+parser.add_argument(
+    '--repository',
+    default='torchvision',
+    choices=['torchvision', 'timm'],
+    help='Source of models (default: torchvision)')
 parser.add_argument(
     '--model-name',
     default='resnet18',
     metavar='ARCH',
-    choices=model_names,
-    help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet18)')
+    help='model architecture: (default: resnet18)')
 parser.add_argument(
     '--dtype', default='float', choices=['float', 'bfloat16'], help='Data type to use')
 parser.add_argument(
@@ -181,6 +183,11 @@ def parse_type(v, default_type):
     'weight-narrow-range',
     default=False,
     help='Narrow range for weight quantization (default: disabled)')
+add_bool_arg(
+    parser,
+    'validate-before-quantize',
+    default=False,
+    help='Run validation on the model before it is quantized')
 parser.add_argument('--gpfq-p', default=1.0, type=float, help='P parameter for GPFQ (default: 1.0)')
 parser.add_argument(
     '--quant-format',
@@ -331,30 +338,58 @@ def main():
     # Get model-specific configurations about input shapes and normalization
     model_config = get_model_config(args.model_name)
 
-    # Generate calibration and validation dataloaders
-    resize_shape = model_config['resize_shape']
-    center_crop_shape = model_config['center_crop_shape']
-    inception_preprocessing = model_config['inception_preprocessing']
-    calib_loader = generate_dataloader(
-        args.calibration_dir,
-        args.batch_size_calibration,
-        args.workers,
-        resize_shape,
-        center_crop_shape,
-        args.calibration_samples,
-        inception_preprocessing)
-    val_loader = generate_dataloader(
-        args.validation_dir,
-        args.batch_size_validation,
-        args.workers,
-        resize_shape,
-        center_crop_shape,
-        inception_preprocessing=inception_preprocessing)
-
-    # Get the model from torchvision
-    model = get_torchvision_model(args.model_name)
+    # Get the model from torchvision or timm
+    if args.repository == 'torchvision':
+        model = get_torchvision_model(args.model_name)
+    else:
+        model = timm.create_model(args.model_name, pretrained=True)
+        data_cfg = timm.data.resolve_data_config(model.pretrained_cfg)
+        transform = timm.data.create_transform(**data_cfg)
+        model_config['resize_shape'] = transform.transforms[0].size
+        model_config['center_crop_shape'] = transform.transforms[1].size[0]
     model = model.to(dtype)
 
+    # If available, use the selected GPU
+    if args.gpu is not None:
+        torch.cuda.set_device(args.gpu)
+        model = model.cuda(args.gpu)
+        cudnn.benchmark = False
+
+    # Generate calibration and validation dataloaders
+    if args.repository == 'torchvision':
+        resize_shape = model_config['resize_shape']
+        center_crop_shape = model_config['center_crop_shape']
+        inception_preprocessing = model_config['inception_preprocessing']
+
+        calib_loader = generate_dataloader(
+            args.calibration_dir,
+            args.batch_size_calibration,
+            args.workers,
+            resize_shape,
+            center_crop_shape,
+            args.calibration_samples,
+            inception_preprocessing)
+        val_loader = generate_dataloader(
+            args.validation_dir,
+            args.batch_size_validation,
+            args.workers,
+            resize_shape,
+            center_crop_shape,
+            inception_preprocessing=inception_preprocessing)
+    else:
+        calib_loader = generate_dataloader_with_transform(
+            args.calibration_dir,
+            args.batch_size_calibration,
+            args.workers,
+            transform,
+            args.calibration_samples)
+        val_loader = generate_dataloader_with_transform(
+            args.validation_dir, args.batch_size_validation, args.workers, transform)
+
+    if args.validate_before_quantize is True:
+        print("Starting validation of unquantized model")
+        validate(val_loader, model, stable=dtype != torch.bfloat16)
+
     # Preprocess the model for quantization
     if args.target_backend == 'flexml':
         # flexml requires static shapes, pass a representative input in
@@ -376,12 +411,6 @@ def main():
     else:
         raise RuntimeError(f"{args.target_backend} backend not supported.")
 
-    # If available, use the selected GPU
-    if args.gpu is not None:
-        torch.cuda.set_device(args.gpu)
-        model = model.cuda(args.gpu)
-        cudnn.benchmark = False
-
     if args.act_equalization is not None:
         print("Applying activation equalization:")
         apply_act_equalization(model, calib_loader, layerwise=args.act_equalization == 'layerwise')
diff --git a/src/brevitas_examples/imagenet_classification/ptq/utils.py b/src/brevitas_examples/imagenet_classification/ptq/utils.py
index bb622d93a..3d523f994 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/utils.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/utils.py
@@ -15,7 +15,7 @@ def get_model_config(model_name):
     config = dict()
 
     # Set-up config parameters
-    if model_name == 'inception_v3' or model_name == 'googlenet':
+    if 'inception_v3' in model_name or 'googlenet' in model_name:
         config['inception_preprocessing'] = True
     else:
         config['inception_preprocessing'] = False
diff --git a/src/brevitas_examples/imagenet_classification/utils.py b/src/brevitas_examples/imagenet_classification/utils.py
index d506b8a61..8823b9a62 100644
--- a/src/brevitas_examples/imagenet_classification/utils.py
+++ b/src/brevitas_examples/imagenet_classification/utils.py
@@ -109,6 +109,11 @@ def generate_dataset(dir, resize_shape=256, center_crop_shape=224, inception_pre
     return dataset
 
 
+def generate_dataset_with_transform(dir, transform):
+    dataset = datasets.ImageFolder(dir, transform)
+    return dataset
+
+
 def generate_dataloader(
         dir,
         batch_size,
@@ -128,3 +133,19 @@ def generate_dataloader(
         dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=True)
 
     return loader
+
+
+def generate_dataloader_with_transform(
+        dir,
+        batch_size,
+        num_workers,
+        transform,
+        subset_size=None,
+):
+    dataset = generate_dataset_with_transform(dir, transform)
+    if subset_size is not None:
+        dataset = torch.utils.data.Subset(dataset, list(range(subset_size)))
+    loader = torch.utils.data.DataLoader(
+        dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=True)
+
+    return loader
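
Usage sketch (not part of the patch): the snippet below illustrates how the new timm path is meant to fit together, mirroring the calls added to `ptq_evaluate.py`. The model name, directory path, batch size, and worker count are placeholders; `generate_dataloader_with_transform` is the helper introduced in `imagenet_classification/utils.py` above, and the exact `timm.data` call signatures may differ slightly across timm versions.

```python
import timm

from brevitas_examples.imagenet_classification.utils import generate_dataloader_with_transform

# Resolve the model together with its pretrained preprocessing, as the patch does
model = timm.create_model('vit_base_patch16_224', pretrained=True)  # placeholder model name
data_cfg = timm.data.resolve_data_config(model.pretrained_cfg)
transform = timm.data.create_transform(**data_cfg)

# Build a calibration loader directly from the resolved transform, instead of the
# resize/center-crop parameters used for torchvision models
calib_loader = generate_dataloader_with_transform(
    '/path/to/imagenet/calibration',  # placeholder directory
    batch_size=64,
    num_workers=8,
    transform=transform,
    subset_size=1000)
```

With the patch applied, the equivalent end-to-end run goes through the new CLI options, e.g. `--repository timm --model-name <timm model>`, optionally with `--validate-before-quantize` to evaluate the floating-point model before quantization.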