Add a fix for custom code tokenizers in pipelines #32300
Merged 10 commits on Aug 27, 2024
src/transformers/pipelines/__init__.py (5 additions, 1 deletion)
@@ -904,7 +904,11 @@ def pipeline(

     model_config = model.config
     hub_kwargs["_commit_hash"] = model.config._commit_hash
-    load_tokenizer = type(model_config) in TOKENIZER_MAPPING or model_config.tokenizer_class is not None
+    load_tokenizer = (
+        type(model_config) in TOKENIZER_MAPPING
+        or model_config.tokenizer_class is not None
+        or isinstance(tokenizer, str)
+    )
     load_feature_extractor = type(model_config) in FEATURE_EXTRACTOR_MAPPING or feature_extractor is not None
     load_image_processor = type(model_config) in IMAGE_PROCESSOR_MAPPING or image_processor is not None

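For context on the change above: with a custom code model, type(model_config) is a dynamically loaded config class that is generally not registered in TOKENIZER_MAPPING, so when the config also leaves tokenizer_class unset, load_tokenizer evaluated to False and a tokenizer passed as a repo name string was never loaded. The new isinstance(tokenizer, str) clause covers that case. A minimal sketch of the call this enables (repo name taken from the tests below; trust_remote_code is required to run custom code):

from transformers import pipeline

# Before this fix, a tokenizer passed as a repo name string was skipped for
# custom code models and pipeline construction failed; with it, the tokenizer
# is loaded from the repo like any other.
generator = pipeline(
    "text-generation",
    model="Rocketknight1/fake-custom-model-test",
    tokenizer="Rocketknight1/fake-custom-model-test",
    trust_remote_code=True,
)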
tests/pipelines/test_pipelines_common.py (39 additions, 0 deletions)
@@ -26,10 +26,13 @@
 from requests.exceptions import HTTPError

 from transformers import (
+    AutomaticSpeechRecognitionPipeline,
     AutoModelForSequenceClassification,
     AutoTokenizer,
     DistilBertForSequenceClassification,
+    MaskGenerationPipeline,
     TextClassificationPipeline,
+    TextGenerationPipeline,
     TFAutoModelForSequenceClassification,
     pipeline,
 )
@@ -859,6 +862,42 @@ def new_forward(*args, **kwargs):

         self.assertEqual(self.COUNT, 1)

+    @require_torch
+    def test_custom_code_with_string_tokenizer(self):
+        # This test checks for an edge case: tokenizer loading used to fail when using a custom code model
+        # with a separate tokenizer that was passed as a repo name rather than a tokenizer object.
+        # See https://github.com/huggingface/transformers/issues/31669
+        text_generator = pipeline(
+            "text-generation",
+            model="Rocketknight1/fake-custom-model-test",
+            tokenizer="Rocketknight1/fake-custom-model-test",
+            trust_remote_code=True,
+        )
+
+        self.assertIsInstance(text_generator, TextGenerationPipeline)  # Assert successful loading

+    @require_torch
+    def test_custom_code_with_string_feature_extractor(self):
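+        # Same edge case as the tokenizer test above, but with the feature
+        # extractor passed as a repo name string for a custom code model.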
+        speech_recognizer = pipeline(
+            "automatic-speech-recognition",
+            model="Rocketknight1/fake-custom-wav2vec2",
+            feature_extractor="Rocketknight1/fake-custom-wav2vec2",
+            trust_remote_code=True,
+        )
+
+        self.assertIsInstance(speech_recognizer, AutomaticSpeechRecognitionPipeline)  # Assert successful loading

+    @require_torch
+    def test_custom_code_with_string_preprocessor(self):
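+        # Same edge case again, with the processor passed as a repo name string
+        # for a custom code model.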
+        mask_generator = pipeline(
+            "mask-generation",
+            model="Rocketknight1/fake-custom-sam",
+            processor="Rocketknight1/fake-custom-sam",
+            trust_remote_code=True,
+        )
+
+        self.assertIsInstance(mask_generator, MaskGenerationPipeline)  # Assert successful loading


 @require_torch
 @is_staging_test