Skip to content

Commit

Permalink
Version 0.42.0 New Features EasyDelState,`fully automated loading…
Browse files Browse the repository at this point in the history
… and converting model`
  • Loading branch information
erfanzar committed Jan 11, 2024
1 parent 6d0a78e commit de2ec4e
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 48 deletions.
42 changes: 15 additions & 27 deletions lib/python/EasyDel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,39 +114,27 @@
get_mem as get_mem
)

from .transform.llama import (
llama_from_pretrained as llama_from_pretrained,
llama_convert_flax_to_pt as llama_convert_flax_to_pt,
llama_convert_hf_to_flax_load as llama_convert_hf_to_flax_load,
llama_convert_hf_to_flax as llama_convert_hf_to_flax,
llama_easydel_to_hf as llama_easydel_to_hf
)
from .transform.mpt import (
mpt_convert_flax_to_pt_1b as mpt_convert_flax_to_pt_1b,
mpt_convert_pt_to_flax_1b as mpt_convert_pt_to_flax_1b,
mpt_convert_pt_to_flax_7b as mpt_convert_pt_to_flax_7b,
mpt_convert_flax_to_pt_7b as mpt_convert_flax_to_pt_7b,
mpt_from_pretrained as mpt_from_pretrained
)

from .transform.falcon import (
falcon_convert_pt_to_flax_7b as falcon_convert_pt_to_flax_7b,
from .transform import (
huggingface_to_easydel as huggingface_to_easydel,
easystate_to_huggingface_model as easystate_to_huggingface_model,
easystate_to_torch as easystate_to_torch,
falcon_convert_flax_to_pt_7b as falcon_convert_flax_to_pt_7b,
falcon_from_pretrained as falcon_from_pretrained,
falcon_convert_hf_to_flax as falcon_convert_hf_to_flax,
falcon_easydel_to_hf as falcon_easydel_to_hf
)
from .transform.mistral import (
mistral_convert_hf_to_flax as mistral_convert_hf_to_flax,
mpt_convert_pt_to_flax_1b as mpt_convert_pt_to_flax_1b,
mpt_convert_pt_to_flax_7b as mpt_convert_pt_to_flax_7b,
mpt_convert_flax_to_pt_7b as mpt_convert_flax_to_pt_7b,
mpt_from_pretrained as mpt_from_pretrained,
mistral_convert_hf_to_flax_load as mistral_convert_hf_to_flax_load,
mistral_convert_flax_to_pt as mistral_convert_flax_to_pt,
mistral_from_pretrained as mistral_from_pretrained,
mistral_convert_pt_to_flax as mistral_convert_pt_to_flax,
mistral_easydel_to_hf as mistral_easydel_to_hf
)

from .transform.easydel_transform import (
huggingface_to_easydel as huggingface_to_easydel
falcon_convert_pt_to_flax_7b as falcon_convert_pt_to_flax_7b,
mistral_convert_hf_to_flax as mistral_convert_hf_to_flax,
mpt_convert_flax_to_pt_1b as mpt_convert_flax_to_pt_1b,
llama_convert_flax_to_pt as llama_convert_flax_to_pt,
llama_convert_hf_to_flax_load as llama_convert_hf_to_flax_load,
llama_convert_hf_to_flax as llama_convert_hf_to_flax,
llama_from_pretrained as llama_from_pretrained
)
from .etils import (
EasyDelOptimizers as EasyDelOptimizers,
Expand Down
41 changes: 31 additions & 10 deletions lib/python/EasyDel/transform/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,31 @@
from .llama import llama_from_pretrained, llama_convert_flax_to_pt, llama_convert_hf_to_flax_load, \
llama_convert_hf_to_flax, llama_easydel_to_hf
from .mpt import mpt_convert_flax_to_pt_1b, mpt_convert_pt_to_flax_1b, mpt_convert_pt_to_flax_7b, \
mpt_convert_flax_to_pt_7b, mpt_from_pretrained
from .falcon import falcon_convert_pt_to_flax_7b, falcon_convert_flax_to_pt_7b, falcon_from_pretrained, \
falcon_convert_hf_to_flax, falcon_easydel_to_hf
from .mistral import mistral_convert_hf_to_flax, mistral_convert_hf_to_flax_load, \
mistral_convert_flax_to_pt, \
mistral_from_pretrained, mistral_convert_pt_to_flax, mistral_easydel_to_hf
from .easydel_transform import easystate_to_huggingface_model, easystate_to_torch
from .llama import (
llama_from_pretrained,
llama_convert_flax_to_pt,
llama_convert_hf_to_flax_load,
llama_convert_hf_to_flax,
)
from .mpt import (
mpt_convert_flax_to_pt_1b,
mpt_convert_pt_to_flax_1b,
mpt_convert_pt_to_flax_7b,
mpt_convert_flax_to_pt_7b,
mpt_from_pretrained
)
from .falcon import (
falcon_convert_pt_to_flax_7b,
falcon_convert_flax_to_pt_7b,
falcon_from_pretrained,
falcon_convert_hf_to_flax,
)
from .mistral import (
mistral_convert_hf_to_flax,
mistral_convert_hf_to_flax_load,
mistral_convert_flax_to_pt,
mistral_from_pretrained,
)

from .easydel_transform import (
huggingface_to_easydel,
easystate_to_huggingface_model,
easystate_to_torch
)
11 changes: 0 additions & 11 deletions lib/python/EasyDel/transform/falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,17 +147,6 @@ def falcon_convert_flax_to_pt_7b(state_dict_flax, num_hidden_layers: int, device
return state_dict


def falcon_easydel_to_hf(path, config: FalconConfig):
"""
Takes path to easydel saved ckpt and return the model in pytorch (Transformers Huggingface)
"""
torch_params = load_and_convert_checkpoint_to_torch(path)
edited_params = {}
for k, v in torch_params.items():
edited_params[k.replace('.kernel', '.weight').replace('.embedding', '.weight')] = v
model = FalconForCausalLM(config=config)
model.load_state_dict(edited_params)
return model


def falcon_convert_hf_to_flax(state_dict: Dict[str, torch.Tensor], config: FalconConfig, device):
Expand Down

0 comments on commit de2ec4e

Please sign in to comment.