Skip to content

Commit

Permalink
Last changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Dec 21, 2023
1 parent f822e45 commit cad0516
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
4 changes: 4 additions & 0 deletions src/brevitas/quant_tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,8 @@ def cat(tensors, dim, out=None):

def __neg__(self):
neg_value = (-self.int(float_datatype=True) - self.zero_point) * self.scale
# In case the dtype of self.int is different from the one of the scale
neg_value = neg_value.type(self.scale.dtype)
if self.signed:
return QuantTensor(
value=neg_value,
Expand Down Expand Up @@ -447,6 +449,8 @@ def __truediv__(self, other):
def __abs__(self):
if self.signed:
abs_value = (torch.abs(self.int(float_datatype=True)) - self.zero_point) * self.scale
# In case the dtype of self.int is different from the one of the scale
abs_value = abs_value.type(self.scale.dtype)
return QuantTensor(
value=abs_value,
scale=self.scale,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,14 +307,15 @@ def main():

# Get the model from torchvision
model = get_torchvision_model(args.model_name)
model = model.to(dtype)

# Preprocess the model for quantization
if args.target_backend == 'flexml':
# flexml requires static shapes, pass a representative input in
img_shape = model_config['center_crop_shape']
model = preprocess_for_flexml_quantize(
model,
torch.ones(1, 3, img_shape, img_shape),
torch.ones(1, 3, img_shape, img_shape, dtype=dtype),
equalize_iters=args.graph_eq_iterations,
equalize_merge_bias=args.graph_eq_merge_bias,
merge_bn=not args.calibrate_bn)
Expand All @@ -330,7 +331,6 @@ def main():
if args.act_equalization is not None:
print("Applying activation equalization:")
apply_act_equalization(model, calib_loader, layerwise=args.act_equalization == 'layerwise')
model = model.to(dtype)

# Define the quantized model
quant_model = quantize_model(
Expand Down

0 comments on commit cad0516

Please sign in to comment.