Hallucinations with 8bit Whisper PEFT model - solved with full / half precision #477
-
Hey @sanchit-gandhi, I have noticed that inference speed using QLoRA (4-bit) is relatively slow too. My question is: if we have trained a model using the combination of PEFT + LoRA or QLoRA, do we have to load the model with the same bit width at inference, or, given that the adapters are learned in fp32/fp16, can we run inference in half precision with no problem? That would mean we could train a full-precision model + PEFT (using Accelerate for multiple GPUs) and use it with different dtypes at inference time.
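For reference, a minimal sketch of the training side of that setup (hypothetical base checkpoint and LoRA hyper-parameters; the point is that the LoRA adapter weights stay in full/half precision even though the base model is quantized):

```python
# Hypothetical QLoRA-style setup: 4-bit quantized base model + trainable LoRA adapters.
import torch
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import BitsAndBytesConfig, WhisperForConditionalGeneration

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# The base model is frozen and quantized to 4-bit ...
base_model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-large-v2",  # hypothetical base checkpoint
    quantization_config=bnb_config,
    device_map="auto",
)
base_model = prepare_model_for_kbit_training(base_model)

# ... while the LoRA adapters are added on top and trained in higher precision.
lora_config = LoraConfig(
    r=32,
    lora_alpha=64,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
)
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()
```

Only the adapter weights are saved at the end of training, so in principle the base model can be re-loaded in a different precision at inference time and the same adapters attached on top.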
-
Hey @sanchit-gandhi, I'm wondering whether this is resolved if, once training in 8-bit or 4-bit finishes, the adapter is merged back into the base model (at least for QLoRA):

```python
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

peft_config = PeftConfig.from_pretrained(output_dir)

# Reload the base model in fp16 and attach the trained adapters
model = AutoModelForCausalLM.from_pretrained(
    peft_config.base_model_name_or_path,
    return_dict=True,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
model = PeftModel.from_pretrained(model, output_dir)
model.eval()

# Merge LoRA and base model and save
merged_model = model.merge_and_unload()
merged_model.save_pretrained("/opt/ml/model/")

# Save tokenizer for easy inference
tokenizer = AutoTokenizer.from_pretrained(args.model_id)
tokenizer.save_pretrained("/opt/ml/model/")
```
-
Observed several instances where fine-tuning Whisper with PEFT and then running inference in 8-bit precision gives ~5x slower inference speeds vs full precision, and considerably increases Whisper's propensity to hallucinate.
Table for inference speed with batch-size=1:
I'll include code snippets below, and will update these in time to use a fine-tuned PEFT checkpoint and audio sample (currently both are private):
Code to load the PEFT model in 8-bit and pass it to the pipeline:
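A minimal sketch of this step, assuming a hypothetical fine-tuned adapter at "peft-whisper-checkpoint" (fine-tuned from openai/whisper-large-v2); wrapping generation in autocast for the 8-bit model is also an assumption:

```python
import torch
from peft import PeftConfig, PeftModel
from transformers import (
    AutomaticSpeechRecognitionPipeline,
    WhisperForConditionalGeneration,
    WhisperProcessor,
)

peft_model_id = "peft-whisper-checkpoint"  # hypothetical adapter repo / local path
peft_config = PeftConfig.from_pretrained(peft_model_id)

# Load the base Whisper model in 8-bit and attach the fine-tuned LoRA adapters
base_model = WhisperForConditionalGeneration.from_pretrained(
    peft_config.base_model_name_or_path,
    load_in_8bit=True,
    device_map="auto",
)
model = PeftModel.from_pretrained(base_model, peft_model_id)

processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path)

pipe = AutomaticSpeechRecognitionPipeline(
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
)

# Transcribe a (hypothetical) local audio file
with torch.cuda.amp.autocast():
    print(pipe("sample.wav", chunk_length_s=30)["text"])
```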
Loading the model weights and PEFT weights in fp32/fp16 for inference drastically improves inference time compared to 8-bit (with fp16 faster still than fp32), and retains the WER boost we get by fine-tuning with PEFT. There are almost no hallucinations when we run inference in full or half precision.
Code to load the PEFT model in fp16 and pass it to the pipeline:
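The fp16 variant only changes how the base model is loaded; the hypothetical checkpoint names are the same as above, and passing torch_dtype to the pipeline so the input features are cast to half precision is an assumption:

```python
import torch
from peft import PeftConfig, PeftModel
from transformers import (
    AutomaticSpeechRecognitionPipeline,
    WhisperForConditionalGeneration,
    WhisperProcessor,
)

peft_model_id = "peft-whisper-checkpoint"  # hypothetical adapter repo / local path
peft_config = PeftConfig.from_pretrained(peft_model_id)

# Load the base Whisper model in half precision and attach the fine-tuned LoRA adapters
base_model = WhisperForConditionalGeneration.from_pretrained(
    peft_config.base_model_name_or_path,
    torch_dtype=torch.float16,
    device_map="auto",
)
model = PeftModel.from_pretrained(base_model, peft_model_id)

processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path)

pipe = AutomaticSpeechRecognitionPipeline(
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch.float16,
)

print(pipe("sample.wav", chunk_length_s=30)["text"])
```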
Takeaway: PEFT is great for stable, low-resource training in 8-bit. We can then leverage the fine-tuned checkpoints for fast inference in full or half precision, which avoids the hallucination issue.