Skip to content

Commit

Permalink
Fixed issues found through demo
Browse files (browse the repository at this point in the history)
  • Loading branch information
mmoffatt2 committed Nov 8, 2024
1 parent d3b20c8 commit 362129a
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 1 deletion.
9 changes: 9 additions & 0 deletions demos/quantization_demo.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@

# Training fully quantized model
python3 train.py \
--out_dir "quantized_model" \
--n_layer "2" \
--n_head "2" \
--n_kv_group "2" \
--n_embd "60" \
--max_iters "100" \
--block_size "32" \
--eval_iters "50" \
--log_interval "20" \
--quantize_linear_method "symmetric_quant" \
--activations_quant_method "symmetric_quant" \
--dtype "bfloat16" \
Expand Down
6 changes: 5 additions & 1 deletion quantization/save_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import datetime
import numpy as np
import pickle
import ml_dtypes
from collections import OrderedDict

def parse_args():
Expand All @@ -18,7 +19,10 @@ def save_quantized_data(state_dict, out_file, file_type):
to_save = OrderedDict()
for k, v in list(state_dict.items()):
if "mlp_act" in k or "attn_act" in k or k.endswith("quantized_bias") or k.endswith("bias_norm") or k.endswith("zero_point") or k.endswith("quantized_weight") or k.endswith("weight_norm"):
to_save[k] = v.cpu().numpy()
if v.dtype == torch.bfloat16:
to_save[k] = v.cpu().float().numpy().astype(ml_dtypes.bfloat16)
else:
to_save[k] = v.cpu().numpy()

if file_type == "pkl":
with open(f"{out_file}.pkl", 'wb') as f:
Expand Down
1 change: 1 addition & 0 deletions quantization/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def main():
# Save the image
plt.savefig(image_path)
print(f'Saved image to {image_path}')
plt.close()

if __name__ == "__main__":
main()

0 comments on commit 362129a

Please sign in to comment.