fix #136. Updated eole/bin/model/average_models.py to work with safetensors model format. #137

Merged: 14 commits, Oct 30, 2024
eole/bin/model/average_models.py: 47 changes (31 additions, 16 deletions)
@@ -1,41 +1,37 @@
 #!/usr/bin/env python
 import torch
 from eole.bin import BaseBin, register_bin
+from eole.models import model_saver
+from eole.config import recursive_model_fields_set
+from safetensors.torch import load_file, save_file
+import os
+import json
 
 
-def average_models(model_files, fp32=False):
+def average_models(model_paths, fp32=False):
     vocab = None
     config = None
     avg_model = None
-    avg_generator = None
 
-    for i, model_file in enumerate(model_files):
-        m = torch.load(model_file, map_location="cpu")
-        model_weights = m["model"]
-        generator_weights = m["generator"]
+    for i, model_path in enumerate(model_paths):
+        m = model_saver.load_checkpoint(model_path)
+        model_weights = load_file(os.path.join(model_path, "model.00.safetensors"))
 
         if fp32:
             for k, v in model_weights.items():
                 model_weights[k] = v.float()
-            for k, v in generator_weights.items():
-                generator_weights[k] = v.float()
 
         if i == 0:
-            vocab, config = m["vocab"], m["config"]
+            vocab, config, optim = m["vocab"], m["config"], m["optim"]
             avg_model = model_weights
-            avg_generator = generator_weights
         else:
             for k, v in avg_model.items():
                 avg_model[k].mul_(i).add_(model_weights[k]).div_(i + 1)
 
-            for k, v in avg_generator.items():
-                avg_generator[k].mul_(i).add_(generator_weights[k]).div_(i + 1)
-
     final = {
         "vocab": vocab,
         "config": config,
-        "optim": None,
-        "generator": avg_generator,
+        "optim": optim,
         "model": avg_model,
     }
     return final
@@ -56,4 +52,23 @@ def add_args(cls, parser):
     @classmethod
     def run(cls, args):
         final = average_models(args.models, args.fp32)
-        torch.save(final, args.output)
+
+        if not os.path.isdir(args.output):
+            os.makedirs(args.output, exist_ok=True)
+
+        # this may be better implemented using model_saver classes
+        # config
+        with open(os.path.join(args.output, "config.json"), "w") as f:
+            json.dump(
+                recursive_model_fields_set(final["config"]),
+                f,
+                indent=2,
+                ensure_ascii=False,
+            )
+        # vocab
+        with open(os.path.join(args.output, "vocab.json"), "w") as f:
+            json.dump(final["vocab"], f, indent=2, ensure_ascii=False)
+        # optimizer
+        torch.save(final["optim"], os.path.join(args.output, "optimizer.pt"))
+        # model weights
+        save_file(final["model"], os.path.join(args.output, "model.00.safetensors"))
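
For reference, the averaging loop in this diff keeps a running mean: after folding in checkpoint i (0-based), each tensor holds avg = (avg * i + w_i) / (i + 1), and the in-place mul_/add_/div_ update means only one extra copy of the weights is held in memory at a time. The sketch below applies the same update directly to a list of safetensors weight files; the helper name average_safetensors and the example paths are hypothetical, and only load_file and save_file from the safetensors package are real API calls.

from safetensors.torch import load_file, save_file


def average_safetensors(paths, out_path, fp32=False):
    # Incrementally average the tensors stored in a list of safetensors files.
    avg = None
    for i, path in enumerate(paths):
        weights = load_file(path)  # dict[str, torch.Tensor]
        if fp32:
            weights = {k: v.float() for k, v in weights.items()}
        if i == 0:
            avg = weights
        else:
            for k in avg:
                # running mean: avg = (avg * i + weights[k]) / (i + 1)
                avg[k].mul_(i).add_(weights[k]).div_(i + 1)
    save_file(avg, out_path)


# Hypothetical usage with three checkpoint files:
# average_safetensors(
#     [
#         "step_1000/model.00.safetensors",
#         "step_2000/model.00.safetensors",
#         "step_3000/model.00.safetensors",
#     ],
#     "averaged/model.00.safetensors",
# )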