Fix (learned_round): use of named tensor
Giuseppe5 committed Oct 31, 2023
1 parent dedb9b8 commit c91ed5a
Showing 2 changed files with 18 additions and 3 deletions.
3 changes: 2 additions & 1 deletion src/brevitas/nn/mixin/base.py
@@ -172,8 +172,9 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]):
                 cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only)
                 self._cached_inp = cached_inp
         # Remove any naming metadata to avoid downstream errors
+        # Avoid inplace operations on the input in case of forward hooks
         if not torch._C._get_tracing_state():
-            inp.value.rename_(None)
+            inp = inp.set(value=inp.value.rename(None))
         return inp

     def pack_output(self, quant_output: QuantTensor):
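For context, PyTorch's named-tensor API comes in an in-place flavor (`rename_`) and an out-of-place one (`rename`); the fix moves to the out-of-place form, wrapped into a fresh `QuantTensor` via `set(...)`, so the caller's tensor is not mutated and other forward hooks still see its names. A minimal standalone sketch of the difference (illustrative only, not Brevitas code):

    import torch

    t = torch.randn(2, 3, names=('N', 'C'))  # tensor with named dimensions

    stripped = t.rename(None)   # out-of-place: returns a view without names
    print(t.names)              # ('N', 'C') -- the original keeps its names
    print(stripped.names)       # (None, None)

    t.rename_(None)             # in-place: mutates t itself
    print(t.names)              # (None, None)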
@@ -36,6 +36,7 @@
 from brevitas.inject.enum import FloatToIntImplType
 from brevitas.inject.enum import LearnedRoundImplType
 from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL
+from brevitas.quant_tensor import QuantTensor

 config.IGNORE_MISSING_KEYS = True

@@ -53,6 +54,19 @@ def __init__(self, store_output: False):
         self.output_store = None

     def __call__(self, module, input_batch, output_batch):
+        input_batch = input_batch[0]
+        if isinstance(input_batch, QuantTensor):
+            input_batch = input_batch.value
+
+        if hasattr(input_batch, 'names') and 'N' in input_batch.names:
+            batch_dim = input_batch.names.index('N')
+
+            input_batch.rename_(None)
+            input_batch = input_batch.transpose(0, batch_dim)
+            if self.store_output:
+                output_batch.rename_(None)
+                output_batch = output_batch.transpose(0, batch_dim)
+
         if self.store_output:
             self.output_store = output_batch
         self.input_store = input_batch
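The hook now unpacks the first positional input, unwraps a `QuantTensor` to its raw value, and, if the tensor carries a named batch dimension 'N' that is not leading, strips the names and transposes the batch to dimension 0 (names are dropped first since PyTorch's named-tensor support is experimental and many downstream ops reject named tensors). A small sketch of that dimension handling, with made-up shapes:

    import torch

    # e.g. an activation named (C, N): the batch dimension is not leading
    x = torch.randn(3, 8, names=('C', 'N'))
    batch_dim = x.names.index('N')      # 1

    x.rename_(None)                     # drop naming metadata in place
    x = x.transpose(0, batch_dim)       # batch-first: shape becomes (8, 3)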
@@ -183,9 +197,9 @@ def save_inp_out_data(
             pass
         if store_inp:
             if keep_gpu:
-                cached[0].append(data_saver.input_store[0].detach())
+                cached[0].append(data_saver.input_store.detach())
             else:
-                cached[0].append(data_saver.input_store[0].detach().cpu())
+                cached[0].append(data_saver.input_store.detach().cpu())
         if store_out:
             if keep_gpu:
                 cached[1].append(data_saver.output_store.detach())
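Because the hook itself now unpacks the first positional input, `input_store` holds the tensor directly, and the old `[0]` indexing in `save_inp_out_data` is dropped. A minimal usage sketch, assuming the `DataSaverHook` shown above (`module` and `example_input` are placeholders):

    hook = DataSaverHook(store_output=True)
    handle = module.register_forward_hook(hook)
    module(example_input)
    handle.remove()

    inp = hook.input_store    # the unpacked first input, batch-first
    out = hook.output_store   # populated only when store_output=True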
