Feat: Remove QOP Export #917

Merged · 6 commits · Mar 27, 2024
Changes from 1 commit
53 changes: 0 additions & 53 deletions src/brevitas/export/manager.py
@@ -170,61 +170,12 @@ class BaseManager(ABC):
    def export(cls, *args, **kwargs):
        return

    @classmethod
    def _gen_patches(cls, fn_dispatcher):
        patches = []
        for fn in cls._fn_to_cache:
            dispatcher = partial(fn_dispatcher, fn)
            p = patch(torch.nn.functional, fn.__name__, dispatcher)
            patches.append(p)
        return patches

    @classmethod
    def _trace_patches(cls):
        patches = cls._gen_patches(cls._trace_fn_dispatcher)
        return patches

    @classmethod
    def _cache_patches(cls):
        return cls._gen_patches(cls._cache_fn_dispatcher)

    @classmethod
    def _restore_fn_patches(cls):
        return [patch(torch.nn.functional, fn.__name__, fn) for fn in cls._fn_to_cache]

    @classmethod
    def _trace_fn_dispatcher(cls, fn, input, *args, **kwargs):
        # baseline impl
        cls._fn_to_cache.pop(0)
        return fn(input, *args, **kwargs)

    @classmethod
    def _cache_fn_dispatcher(cls, fn, input, *args, **kwargs):
        with ExitStack() as stack:
            # disable recursing into this patch
            for mgr in cls._restore_fn_patches():
                stack.enter_context(mgr)
            if isinstance(input, QuantTensor):
                inp_cache = None
                out_cache = None
                inp_cache = _CachedIO(input, metadata_only=True)
                output = fn(input, *args, **kwargs)
                if isinstance(output, QuantTensor):
                    out_cache = _CachedIO(output, metadata_only=True)
                cached_io = (inp_cache, out_cache)
                if fn in cls._cached_io_handler_map:
                    cached_io = cls._cached_io_handler_map[fn](cached_io)
                cls._fn_cache.append(cached_io)
            else:
                # could be a fn invoked within a quant module on a dequant tensor
                # or a function invoked on a float tensor. The former won't show
                # up during jit tracing as they are replaced by symbolic functions,
                # but the latter will, so we have to account for them in the _fn_cache
                output = fn(input, *args, **kwargs)
                if not config._IS_INSIDE_QUANT_LAYER:
                    cls._fn_cache.append(None)
        return output

    @classmethod
    def handler_from_module(cls, module: Module, no_inheritance=False):
        for handler in cls.handlers:
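
For context, the deleted helpers implement a function-interception trick: calls into torch.nn.functional made during a forward pass are temporarily monkey-patched so that QuantTensor metadata can be cached for export. Below is a minimal standalone sketch of that mechanism — not Brevitas code; the names dispatcher, calls, and fns_to_patch are illustrative, and it assumes the patch(...) used in the diff behaves like unittest.mock.patch.object.

    from contextlib import ExitStack
    from functools import partial
    from unittest import mock

    import torch
    import torch.nn.functional as F

    calls = []  # records which functional ops were intercepted

    def dispatcher(fn, input, *args, **kwargs):
        calls.append(fn.__name__)          # cache something about the call
        return fn(input, *args, **kwargs)  # fall through to the real op

    fns_to_patch = [F.relu]
    with ExitStack() as stack:
        for fn in fns_to_patch:
            # swap the module attribute for the duration of the block
            p = mock.patch.object(torch.nn.functional, fn.__name__, partial(dispatcher, fn))
            stack.enter_context(p)
        F.relu(torch.randn(2))  # resolves to the patched dispatcher

    print(calls)  # ['relu']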
@@ -254,8 +205,6 @@ def _cache_inp_out(cls, module, *args, **kwargs):
        module.apply(lambda m: _override_inp_caching_mode(m, enabled=True))
        module.apply(lambda m: _override_out_caching_mode(m, enabled=True))
        with ExitStack() as stack:
Collaborator: If we don't add anything to the stack, we can remove the context manager and leave the forward pass as part of the function. (A sketch of the suggested flattening follows this hunk.)
            for mgr in cls._cache_patches():
                stack.enter_context(mgr)
            _ = module.forward(*args, **kwargs)
        # Restore previous caching properties
        module.apply(lambda m: _restore_quant_metadata_caching_mode(m))
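
A hedged sketch of the flattening the reviewer suggests — hypothetical, not the merged code, reusing the helpers already defined in manager.py. Once no patches are entered, the with ExitStack() block adds nothing and the forward pass can sit directly in the function body:

    @classmethod
    def _cache_inp_out(cls, module, *args, **kwargs):
        # same caching setup as above, minus the now-empty ExitStack
        module.apply(lambda m: _override_inp_caching_mode(m, enabled=True))
        module.apply(lambda m: _override_out_caching_mode(m, enabled=True))
        _ = module.forward(*args, **kwargs)
        # Restore previous caching properties
        module.apply(lambda m: _restore_quant_metadata_caching_mode(m))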
@@ -280,8 +229,6 @@ def jit_inference_trace(
        # force requires_grad to False to let the wrapped model lambda go through tracing
        requires_grad_backup_dict = _force_requires_grad_false(module)
        with ExitStack() as stack:
            for mgr in cls._trace_patches():
Collaborator: Same as before.
                stack.enter_context(mgr)
            # wrapping with a lambda forces inlining during tracing,
            # converts everything to const and removes unused params/buffers
            traced_model = torch.jit.trace(_JitTraceExportWrapper(module), args)
2 changes: 0 additions & 2 deletions src/brevitas/export/onnx/manager.py
@@ -123,8 +123,6 @@ def export_onnx(
            module.apply(lambda m: _override_inp_caching_mode(m, enabled=False))
            # perform export pass
            with ExitStack() as stack:
                for mgr in cls._trace_patches():
Collaborator: Same as before, everything can be flattened one level down.
                    stack.enter_context(mgr)
                if export_path is not None:
                    export_target = export_path
                else: