diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index b74225a6a..0d29c1b12 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -183,7 +183,7 @@ def main(args): kwargs['torchscript'] = True print("Model loading...") - model = AutoModelForCausalLM.from_pretrained(args.model, **kwargs) + model = AutoModelForCausalLM.from_pretrained(args.model, attn_implementation="sdpa", **kwargs) print("Model loaded.") model.eval() tokenizer = AutoTokenizer.from_pretrained(args.model)