Skip to content

Commit

Permalink
Merge pull request #715 from PINTO0309/multi-batch-quant
Browse files Browse the repository at this point in the history
Multi batch quant
  • Loading branch information
PINTO0309 authored Oct 19, 2024
2 parents 9a4c2e9 + b10e8c6 commit 67d5b4c
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 6 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -299,15 +299,15 @@ Video speed is adjusted approximately 50 times slower than actual speed.
docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
ghcr.io/pinto0309/onnx2tf:1.26.1
ghcr.io/pinto0309/onnx2tf:1.26.2

or

# Authentication is not required for pulls from Docker Hub.
docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
docker.io/pinto0309/onnx2tf:1.26.1
docker.io/pinto0309/onnx2tf:1.26.2

or

Expand Down
2 changes: 1 addition & 1 deletion onnx2tf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from onnx2tf.onnx2tf import convert, main

__version__ = '1.26.1'
__version__ = '1.26.2'
10 changes: 7 additions & 3 deletions onnx2tf/onnx2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1626,6 +1626,7 @@ def sanitizing(node):
mean,
std,
]

elif custom_input_op_name_np_data_path is not None:
for param in custom_input_op_name_np_data_path:
if len(param) != 4:
Expand All @@ -1652,11 +1653,14 @@ def sanitizing(node):

# representative_dataset_gen
def representative_dataset_gen():
for idx in range(data_count):
batch_size = model.inputs[0].shape[0]
if not isinstance(batch_size, int):
batch_size = 1
for idx in range(0, data_count, batch_size):
yield_data_dict = {}
for model_input_name in model_input_name_list:
calib_data, mean, std = calib_data_dict[model_input_name]
normalized_calib_data: np.ndarray = (calib_data[idx] - mean) / std
normalized_calib_data: np.ndarray = (calib_data[idx:idx+batch_size] - mean) / std
yield_data_dict[model_input_name] = tf.cast(tf.convert_to_tensor(normalized_calib_data), tf.float32)
yield yield_data_dict

Expand Down Expand Up @@ -1708,7 +1712,7 @@ def representative_dataset_gen():
inf_type_input = tf.float32
else:
inf_type_input = tf.int8

if output_quant_dtype == 'int8':
inf_type_output = tf.int8
elif output_quant_dtype == 'uint8':
Expand Down

0 comments on commit 67d5b4c

Please sign in to comment.