diff --git a/src/adapters/models/mt5/adapter_model.py b/src/adapters/models/mt5/adapter_model.py index 7b763d1854..b8673422c6 100644 --- a/src/adapters/models/mt5/adapter_model.py +++ b/src/adapters/models/mt5/adapter_model.py @@ -2,6 +2,7 @@ import torch +from transformers.modeling_utils import PreTrainedModel from transformers.models.mt5.modeling_mt5 import ( MT5_INPUTS_DOCSTRING, MT5_START_DOCSTRING, @@ -9,7 +10,6 @@ MT5PreTrainedModel, ) from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward -from transformers.modeling_utils import PreTrainedModel from ...composition import adjust_tensors_for_parallel from ...heads import ( @@ -19,9 +19,9 @@ QuestionAnsweringHead, Seq2SeqLMHead, ) +from ...loading import PredictionHeadLoader from ...model_mixin import EmbeddingAdaptersWrapperMixin from ...wrappers import init -from ...loading import PredictionHeadLoader logger = logging.getLogger(__name__)