diff --git a/flair/file_utils.py b/flair/file_utils.py index dfb0049b78..f7f20a20f3 100644 --- a/flair/file_utils.py +++ b/flair/file_utils.py @@ -171,13 +171,6 @@ def hf_download(model_name: str) -> str: ) except HTTPError: # output information - logger.error("-" * 80) - logger.error( - f"ERROR: The key '{model_name}' was neither found on the ModelHub nor is this a valid path to a file on your system!" - ) - logger.error(" -> Please check https://huggingface.co/models?filter=flair for all available models.") - logger.error(" -> Alternatively, point to a model file on your local drive.") - logger.error("-" * 80) Path(flair.cache_root / "models" / model_folder).rmdir() # remove folder again if not valid raise diff --git a/flair/nn/model.py b/flair/nn/model.py index 96b2c2d925..88f51f443b 100644 --- a/flair/nn/model.py +++ b/flair/nn/model.py @@ -151,7 +151,17 @@ def load(cls, model_path: Union[str, Path, Dict[str, Any]]) -> "Model": continue # if the model cannot be fetched, load as a file - state = model_path if isinstance(model_path, dict) else load_torch_state(str(model_path)) + try: + state = model_path if isinstance(model_path, dict) else load_torch_state(str(model_path)) + except Exception: + log.error("-" * 80) + log.error( + f"ERROR: The key '{model_path}' was neither found on the ModelHub nor is this a valid path to a file on your system!" + ) + log.error(" -> Please check https://huggingface.co/models?filter=flair for all available models.") + log.error(" -> Alternatively, point to a model file on your local drive.") + log.error("-" * 80) + raise ValueError(f"Could not find any model with name '{model_path}'") # try to get model class from state cls_name = state.pop("__cls__", None) diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py index 6d9c3ec54b..fb8590841b 100644 --- a/flair/trainers/trainer.py +++ b/flair/trainers/trainer.py @@ -473,7 +473,7 @@ def train_custom( if inspect.isclass(sampler): sampler = sampler() # set dataset to sample from - sampler.set_dataset(train_data) # type: ignore[union-attr] + sampler.set_dataset(train_data) shuffle = False # this field stores the names of all dynamic embeddings in the model (determined after first forward pass) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 52fec1c5ea..2d0391b264 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -418,6 +418,7 @@ def test_load_universal_dependencies_conllu_corpus(tasks_base_path): _assert_universal_dependencies_conllu_dataset(corpus.train) +@pytest.mark.integration() def test_hipe_2022_corpus(tasks_base_path): # This test covers the complete HIPE 2022 dataset. # https://github.com/hipe-eval/HIPE-2022-data @@ -681,6 +682,7 @@ def test_hipe_2022(dataset_version="v2.1", add_document_separator=True): test_hipe_2022(dataset_version="v2.1", add_document_separator=False) +@pytest.mark.integration() def test_icdar_europeana_corpus(tasks_base_path): # This test covers the complete ICDAR Europeana corpus: # https://github.com/stefan-it/historic-domain-adaptation-icdar @@ -698,6 +700,7 @@ def check_number_sentences(reference: int, actual: int, split_name: str): check_number_sentences(len(corpus.test), gold_stats[language]["test"], "test") +@pytest.mark.integration() def test_masakhane_corpus(tasks_base_path): # This test covers the complete MasakhaNER dataset, including support for v1 and v2. supported_versions = ["v1", "v2"] @@ -781,6 +784,7 @@ def check_number_sentences(reference: int, actual: int, split_name: str, languag check_number_sentences(len(corpus.test), gold_stats["test"], "test", language, version) +@pytest.mark.integration() def test_nermud_corpus(tasks_base_path): # This test covers the NERMuD dataset. Official stats can be found here: # https://github.com/dhfbk/KIND/tree/main/evalita-2023 @@ -808,6 +812,7 @@ def test_german_ler_corpus(tasks_base_path): assert len(corpus.test) == 6673, "Mismatch in number of sentences for test split" +@pytest.mark.integration() def test_masakha_pos_corpus(tasks_base_path): # This test covers the complete MasakhaPOS dataset. supported_versions = ["v1"] @@ -876,6 +881,7 @@ def check_number_sentences(reference: int, actual: int, split_name: str, languag check_number_sentences(len(corpus.test), gold_stats["test"], "test", language, version) +@pytest.mark.integration() def test_german_mobie(tasks_base_path): corpus = flair.datasets.NER_GERMAN_MOBIE() @@ -960,6 +966,7 @@ def test_jsonl_corpus_loads_metadata(tasks_base_path): assert dataset.sentences[2].get_metadata("from") == 125 +@pytest.mark.integration() def test_ontonotes_download(): from urllib.parse import urlparse