Commit

Address comments
pablomlago committed Nov 27, 2024
1 parent f1a188b commit a9e9b52
Showing 7 changed files with 22 additions and 241 deletions.
@@ -203,7 +203,6 @@
 from torch.optim.optimizer import Optimizer
 from torch.optim.sgd import SGD
 from torch.utils.data.dataloader import DataLoader
-from torch.utils.data.dataloader import RandomSampler
 from tqdm import tqdm
 
 from brevitas import config
@@ -285,10 +284,6 @@ def initialize_cache(self) -> None:
     def clear_cache(self) -> None:
         pass
 
-    @abstractmethod
-    def reset_cache(self) -> None:
-        pass
-
     @abstractmethod
     def cache_to_dataset(self) -> Dataset:
         pass
@@ -699,7 +694,7 @@ def apply_learned_round(
             block_forward: Callable,
             data_loader: DataLoader,
             cache: Cache,
-            block_check_fn: Callable,
+            get_blocks_fn: Callable,
             model_prepare_fn: Optional[Callable] = None,
             model_finish_fn: Optional[Callable] = None,
             keep_gpu: bool = True) -> None:
@@ -711,7 +706,7 @@ def apply_learned_round(
         self.learned_round.insert_learned_round_quantizers(model)
 
         # Retrieve blocks using the appropriate function to check blocks
-        blocks = get_blocks(model, block_check_fn)
+        blocks = get_blocks_fn(model)
 
         print(f"Total Iterations per block {self.iters}")
         print(f"Number of blocks {len(blocks)}")
@@ -726,7 +721,6 @@ def apply_learned_round(
             model = offload_model(model)
             # Cache needs to be cleared before populating it with the inputs and outputs
             # to the block under optimization.
-            cache.clear_cache()
             self._populate_cache(
                 cache,
                 model,
Expand Down Expand Up @@ -801,7 +795,7 @@ def apply_learned_round(

# TODO: This call might not be needed, check_clear and reset_cache methods
# Reset cache after optimisation
cache.reset_cache()
cache.clear_cache()

# The original configuration of the model is restored after finishing the optimization
if model_finish_fn is not None:
@@ -80,7 +80,6 @@ def parse_lr_scheduler_class(lr_scheduler_str: str) -> Type[LRScheduler]:
             torch.optim.lr_scheduler.__dict__[lr_scheduler_key] != LRScheduler and
             isinstance(torch.optim.lr_scheduler.__dict__[lr_scheduler_key], type) and
             issubclass(torch.optim.lr_scheduler.__dict__[lr_scheduler_key], LRScheduler))]
-    print(lr_scheduler_keys)
     if len(lr_scheduler_keys) == 0:
         warnings.warn(
             f"There are no matches for LR scheduler {lr_scheduler_str}. "
@@ -26,6 +26,7 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
+import functools
 import re
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 import warnings
@@ -39,6 +40,8 @@
 from brevitas import config
 from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL
 from brevitas.quant_tensor import QuantTensor
+from brevitas_examples.common.learned_round.learned_round_optimizer import Cache
+from brevitas_examples.common.learned_round.learned_round_optimizer import get_blocks
 from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundOptimizer
 from brevitas_examples.common.learned_round.learned_round_parser import parse_learned_round
 from brevitas_examples.common.learned_round.learned_round_parser import \
@@ -62,7 +65,7 @@ def is_layer(module: nn.Module, module_name: str) -> bool:
     "blockwise": is_resnet_block,}
 
 
-class CacheVision(dict):
+class CacheVision(Cache, dict):
 
     def __init__(self) -> None:
         super().__init__()
@@ -97,12 +100,6 @@ def clear_cache(self) -> None:
         self["inputs"] = []
         self["output"] = []
 
-    def reset_cache(self) -> None:
-        del self["inputs"]
-        del self["output"]
-        self["inputs"] = []
-        self["output"] = []
-
     def sample_batch(self, indices: torch.Tensor) -> Union[Any, torch.Tensor]:
         if isinstance(self["inputs"], list):
             self["inputs"] = torch.cat(self["inputs"], dim=self.batch_dim)
@@ -166,6 +163,7 @@ def apply_learned_round(
         warnings.warn(
             f"{learned_round_mode} is not a valid learned round mode. Defaulting to layerwise.")
     block_check_fn = BLOCK_CHECK_MAP[learned_round_mode]
+    get_blocks_fn = functools.partial(get_blocks, block_check_fn=block_check_fn)
     lr_scheduler_kwargs = {
         "start_factor": 1.0,
         "end_factor": 0.0,
@@ -192,6 +190,6 @@ def apply_learned_round(
         block_forward=cnn_block_forward,
         data_loader=calibration_loader,
         cache=cache,
-        block_check_fn=block_check_fn,
+        get_blocks_fn=get_blocks_fn,
         keep_gpu=True,
     )
175 changes: 0 additions & 175 deletions src/brevitas_examples/llm/benchmark/llm_benchmark.py

This file was deleted.

32 changes: 0 additions & 32 deletions src/brevitas_examples/llm/benchmark/post_processing.py

This file was deleted.

(Diffs for the remaining two changed files did not load.)

0 comments on commit a9e9b52
