Skip to content

Commit

Permalink
Merge pull request #3504 from flairNLP/GH-3503-classifier-load
Browse files Browse the repository at this point in the history
Move error message to main load function
  • Loading branch information
alanakbik authored Jul 23, 2024
2 parents 9c4e1d2 + 2f3e82e commit d94b890
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 9 deletions.
7 changes: 0 additions & 7 deletions flair/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,6 @@ def hf_download(model_name: str) -> str:
)
except HTTPError:
# output information
logger.error("-" * 80)
logger.error(
f"ERROR: The key '{model_name}' was neither found on the ModelHub nor is this a valid path to a file on your system!"
)
logger.error(" -> Please check https://huggingface.co/models?filter=flair for all available models.")
logger.error(" -> Alternatively, point to a model file on your local drive.")
logger.error("-" * 80)
Path(flair.cache_root / "models" / model_folder).rmdir() # remove folder again if not valid
raise

Expand Down
12 changes: 11 additions & 1 deletion flair/nn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,17 @@ def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "Model":
continue

# if the model cannot be fetched, load as a file
state = model_path if isinstance(model_path, dict) else load_torch_state(str(model_path))
try:
state = model_path if isinstance(model_path, dict) else load_torch_state(str(model_path))
except Exception:
log.error("-" * 80)
log.error(
f"ERROR: The key '{model_path}' was neither found on the ModelHub nor is this a valid path to a file on your system!"
)
log.error(" -> Please check https://huggingface.co/models?filter=flair for all available models.")
log.error(" -> Alternatively, point to a model file on your local drive.")
log.error("-" * 80)
raise ValueError(f"Could not find any model with name '{model_path}'")

# try to get model class from state
cls_name = state.pop("__cls__", None)
Expand Down
2 changes: 1 addition & 1 deletion flair/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ def train_custom(
if inspect.isclass(sampler):
sampler = sampler()
# set dataset to sample from
sampler.set_dataset(train_data) # type: ignore[union-attr]
sampler.set_dataset(train_data)
shuffle = False

# this field stores the names of all dynamic embeddings in the model (determined after first forward pass)
Expand Down
7 changes: 7 additions & 0 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ def test_load_universal_dependencies_conllu_corpus(tasks_base_path):
_assert_universal_dependencies_conllu_dataset(corpus.train)


@pytest.mark.integration()
def test_hipe_2022_corpus(tasks_base_path):
# This test covers the complete HIPE 2022 dataset.
# https://github.com/hipe-eval/HIPE-2022-data
Expand Down Expand Up @@ -681,6 +682,7 @@ def test_hipe_2022(dataset_version="v2.1", add_document_separator=True):
test_hipe_2022(dataset_version="v2.1", add_document_separator=False)


@pytest.mark.integration()
def test_icdar_europeana_corpus(tasks_base_path):
# This test covers the complete ICDAR Europeana corpus:
# https://github.com/stefan-it/historic-domain-adaptation-icdar
Expand All @@ -698,6 +700,7 @@ def check_number_sentences(reference: int, actual: int, split_name: str):
check_number_sentences(len(corpus.test), gold_stats[language]["test"], "test")


@pytest.mark.integration()
def test_masakhane_corpus(tasks_base_path):
# This test covers the complete MasakhaNER dataset, including support for v1 and v2.
supported_versions = ["v1", "v2"]
Expand Down Expand Up @@ -781,6 +784,7 @@ def check_number_sentences(reference: int, actual: int, split_name: str, languag
check_number_sentences(len(corpus.test), gold_stats["test"], "test", language, version)


@pytest.mark.integration()
def test_nermud_corpus(tasks_base_path):
# This test covers the NERMuD dataset. Official stats can be found here:
# https://github.com/dhfbk/KIND/tree/main/evalita-2023
Expand Down Expand Up @@ -808,6 +812,7 @@ def test_german_ler_corpus(tasks_base_path):
assert len(corpus.test) == 6673, "Mismatch in number of sentences for test split"


@pytest.mark.integration()
def test_masakha_pos_corpus(tasks_base_path):
# This test covers the complete MasakhaPOS dataset.
supported_versions = ["v1"]
Expand Down Expand Up @@ -876,6 +881,7 @@ def check_number_sentences(reference: int, actual: int, split_name: str, languag
check_number_sentences(len(corpus.test), gold_stats["test"], "test", language, version)


@pytest.mark.integration()
def test_german_mobie(tasks_base_path):
corpus = flair.datasets.NER_GERMAN_MOBIE()

Expand Down Expand Up @@ -960,6 +966,7 @@ def test_jsonl_corpus_loads_metadata(tasks_base_path):
assert dataset.sentences[2].get_metadata("from") == 125


@pytest.mark.integration()
def test_ontonotes_download():
from urllib.parse import urlparse

Expand Down

0 comments on commit d94b890

Please sign in to comment.