From d54c8320ca32a390f958e6f42d43389f93489e7d Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Tue, 20 Aug 2024 12:32:49 +0100 Subject: [PATCH] Feat (example/llm): Specify LLMs to use SDPA for their attn implementation --- src/brevitas_examples/llm/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index b74225a6a..0d29c1b12 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -183,7 +183,7 @@ def main(args): kwargs['torchscript'] = True print("Model loading...") - model = AutoModelForCausalLM.from_pretrained(args.model, **kwargs) + model = AutoModelForCausalLM.from_pretrained(args.model, attn_implementation="sdpa", **kwargs) print("Model loaded.") model.eval() tokenizer = AutoTokenizer.from_pretrained(args.model)