-
Notifications
You must be signed in to change notification settings - Fork 199
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[DRAFT] Feat cailey sgd #1127
base: dev
Are you sure you want to change the base?
[DRAFT] Feat cailey sgd #1127
Conversation
@@ -166,3 +183,86 @@ def apply_exact_had_to_linear(module, had_dim=-1, output=False): | |||
|
|||
def is_pow2(n): | |||
return (n & (n - 1) == 0) and (n > 0) | |||
|
|||
|
|||
hadamard_string_16 = """++++++++++++++++ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is temporal code. These matrices will be included in hadamard.pt, when the PR is in merge state.
src/brevitas/graph/equalize.py
Outdated
regions: List[Region] = [] | ||
self.find_module(model, regions) | ||
if len(regions) > 0: | ||
_apply_rotate(model, regions) | ||
return model | ||
|
||
|
||
def _apply_rotate_fused_rotations( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't need an entire new function
@@ -0,0 +1,197 @@ | |||
# coding=utf-8 | |||
# Copyright (c) Meta Platforms, Inc. and affiliates. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a reference to the code origin
src/brevitas_examples/llm/main.py
Outdated
@@ -49,7 +53,7 @@ def set_seed(seed): | |||
torch.random.manual_seed(seed) | |||
|
|||
|
|||
def fused_rotation_no_fx(model, calibration_loader, args): | |||
def fused_rotation_no_fx(model, calibration_loader, args, fuse_rotations: bool = False): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fuse_rotation
-> optimize_rotation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The matrices can be kept unfused, even if they are no optimized (even if this is not leveraged at any point), so this name seems more general.
Reason for this PR
Changes Made in this PR
Testing Summary
Risk Highlight
Checklist
dev
branch.