diff --git a/flair/datasets/__init__.py b/flair/datasets/__init__.py
index 2837e017c0..d480c93552 100644
--- a/flair/datasets/__init__.py
+++ b/flair/datasets/__init__.py
@@ -171,6 +171,7 @@
 # Expose all sequence labeling datasets
 from .sequence_labeling import (
     BIOSCOPE,
+    CLEANCONLL,
     CONLL_03,
     CONLL_03_DUTCH,
     CONLL_03_GERMAN,
@@ -465,6 +466,7 @@
     "CONLL_03_DUTCH",
     "CONLL_03_GERMAN",
     "CONLL_03_SPANISH",
+    "CLEANCONLL",
     "CONLL_2000",
     "FEWNERD",
     "KEYPHRASE_INSPEC",
diff --git a/flair/datasets/sequence_labeling.py b/flair/datasets/sequence_labeling.py
index 55e50723d1..fa0962f2f4 100644
--- a/flair/datasets/sequence_labeling.py
+++ b/flair/datasets/sequence_labeling.py
@@ -3,8 +3,11 @@
 import logging
 import os
 import re
+import gzip
 import shutil
 import tarfile
+import tempfile
+import zipfile
 from collections import defaultdict
 from collections.abc import Iterable, Iterator
 from pathlib import Path
@@ -15,6 +18,7 @@
     cast,
 )
 
+import requests
 from torch.utils.data import ConcatDataset, Dataset
 
 import flair
@@ -1421,6 +1425,240 @@ def __init__(
         )
 
 
+class CLEANCONLL(ColumnCorpus):
+    def __init__(
+        self,
+        base_path: Optional[Union[str, Path]] = None,
+        in_memory: bool = True,
+        **corpusargs,
+    ) -> None:
+        """Initialize the CleanCoNLL corpus.
+
+        Args:
+            base_path: Base directory for the dataset. If None, defaults to flair.cache_root / "datasets".
+            in_memory: If True, keeps dataset in memory for faster training.
+        """
+        # Set the base path for the dataset
+        base_path = flair.cache_root / "datasets" if not base_path else Path(base_path)
+
+        # Define column format
+        columns = {0: "text", 1: "pos", 2: "nel", 3: "ner*", 4: "ner"}
+
+        # Define dataset name
+        dataset_name = self.__class__.__name__.lower()
+
+        # Define data folder path
+        data_folder = base_path / dataset_name
+
+        # Check if the train data file exists, otherwise download and prepare the dataset
+        train_set = data_folder / "cleanconll.train"
+
+        if not train_set.exists():
+            print("CleanCoNLL files not found, so downloading and creating them.")
+
+            # Download and prepare the dataset
+            self.download_and_prepare_data(data_folder)
+
+        else:
+            print("Found files for CleanCoNLL in:", data_folder)
+
+        # Initialize the parent class with the specified parameters
+        super().__init__(
+            data_folder,
+            columns,
+            encoding="utf-8",
+            in_memory=in_memory,
+            document_separator_token="-DOCSTART-",
+            **corpusargs,
+        )
+
+    @staticmethod
+    def download_and_prepare_data(data_folder: Path):
+        def parse_patch(patch_file_path):
+            """Parses a patch file and returns a structured representation of the changes."""
+            changes = []
+            current_change = None
+
+            with open(patch_file_path, encoding="utf-8") as patch_file:
+                for line in patch_file:
+                    # Check if the line is a change, delete or add command (like 17721c17703,17705 or 5728d5727)
+                    if line and (line[0].isdigit() and ("c" in line or "d" in line or "a" in line)):
+                        if current_change:
+                            # Append the previous change block to the changes list
+                            changes.append(current_change)
+
+                        # Start a new change block
+                        current_change = {"command": line, "original": [], "new": []}
+
+                    # Capture original lines (those marked with "<")
+                    elif line.startswith("<"):
+                        if current_change:
+                            current_change["original"].append(line[2:])  # Remove the "< " part
+
+                    # Capture new lines (those marked with ">")
+                    elif line.startswith(">"):
+                        if current_change:
+                            current_change["new"].append(line[2:])  # Remove the "> " part
+
+                # Append the last change block to the changes list
+                if current_change:
+                    changes.append(current_change)
+
+            return changes
+
+        def parse_line_range(line_range_str):
+            """Utility function to parse a line range string like '17703,17705' or '5727' and returns a tuple (start, end)."""
+            parts = line_range_str.split(",")
+            if len(parts) == 1:
+                start = int(parts[0]) - 1
+                return (start, start + 1)
+            else:
+                start = int(parts[0]) - 1
+                end = int(parts[1])
+                return (start, end)
+
+        def apply_patch_to_file(original_file, changes, output_file_path):
+            """Applies the patch instructions to the content of the original file."""
+            with open(original_file, encoding="utf-8") as f:
+                original_lines = f.readlines()
+
+            modified_lines = original_lines[:]  # Make a copy of original lines
+
+            # Apply each change in reverse order (important to avoid index shift issues)
+            for change in reversed(changes):
+                command = change["command"]
+
+                # Determine the type of the change: `c` for change, `d` for delete, `a` for add
+                if "c" in command:
+                    # Example command: 17721c17703,17705
+                    original_line_range, new_line_range = command.split("c")
+                    original_line_range = parse_line_range(original_line_range)
+                    modified_lines[original_line_range[0] : original_line_range[1]] = change["new"]
+
+                elif "d" in command:
+                    # Example command: 5728d5727
+                    original_line_number = int(command.split("d")[0]) - 1
+                    del modified_lines[original_line_number]
+
+                elif "a" in command:
+                    # Example command: 1000a1001,1002
+                    original_line_number = int(command.split("a")[0]) - 1
+                    insertion_point = original_line_number + 1
+                    for new_token in reversed(change["new"]):
+                        modified_lines.insert(insertion_point, new_token)
+
+            # Write the modified content to the output file
+            with open(output_file_path, "w", encoding="utf-8") as output_file:
+                output_file.writelines(modified_lines)
+
+        def apply_patch(file_path, patch_path, output_path):
+            patch_instructions = parse_patch(patch_path)
+            apply_patch_to_file(file_path, patch_instructions, output_path)
+
+        def extract_tokens(file_path: Path, output_path: Path):
+            with open(file_path, encoding="utf-8") as f_in, open(output_path, "w", encoding="utf-8") as f_out:
+                for line in f_in:
+                    # Strip whitespace to check if the line is empty
+                    stripped_line = line.strip()
+                    if stripped_line:
+                        # Write the first token followed by a newline if the line is not empty
+                        f_out.write(stripped_line.split()[0] + "\n")
+                    else:
+                        # Write an empty line if the line is empty
+                        f_out.write("\n")
+
+        def merge_annotations(tokens_file, annotations_file, output_file):
+            with (
+                open(tokens_file, encoding="utf-8") as tokens_file,
+                open(annotations_file, encoding="utf-8") as annotations_file,
+                open(output_file, "w", encoding="utf-8") as output_file,
+            ):
+                tokens = tokens_file.readlines()
+                annotations = annotations_file.readlines()
+
+                for token, annotation in zip(tokens, annotations):
+                    # Strip the leading '[TOKEN]\t' from the annotation
+                    stripped_annotation = "\t".join(annotation.strip().split("\t")[1:])
+                    output_file.write(token.strip() + "\t" + stripped_annotation + "\n")
+
+        # Create a temporary directory
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            tmpdir = Path(tmpdirname)
+
+            github_url = "https://github.com/flairNLP/CleanCoNLL/archive/main.zip"
+            zip_path = cached_path(github_url, tmpdir)
+            unpack_file(zip_path, tmpdir, "zip", False)
+            cleanconll_data_root = tmpdir / "CleanCoNLL-main"
+
+            # Check the contents of the temporary directory
+            print(f"Contents of the temporary directory: {list(tmpdir.iterdir())}")
+
+            conll03_dir = data_folder / "original_conll-03"
+            if conll03_dir.exists() and conll03_dir.is_dir() and "train.txt" in [f.name for f in conll03_dir.iterdir()]:
+                print(f"Original CoNLL03 files detected here: {conll03_dir}")
+
+            else:
+                conll_url = "https://data.deepai.org/conll2003.zip"
+
+                conll03_dir.mkdir(parents=True, exist_ok=True)
+                print(f"Downloading the original CoNLL03 from {conll_url} into {conll03_dir} ...")
+
+                zip_path = conll03_dir / "conll2003.zip"
+                response = requests.get(conll_url)
+                zip_path.write_bytes(response.content)
+
+                with zipfile.ZipFile(zip_path, "r") as zip_ref:
+                    zip_ref.extractall(conll03_dir)
+
+            conll03_train = conll03_dir / "train.txt"
+            conll03_dev = conll03_dir / "valid.txt"
+            conll03_test = conll03_dir / "test.txt"
+
+            patch_dir = cleanconll_data_root / "data" / "patch_files"
+            tokens_dir = cleanconll_data_root / "data" / "tokens_updated"
+            tokens_dir.mkdir(parents=True, exist_ok=True)
+
+            # Extract only the tokens from the original CoNLL03 files
+            extract_tokens(conll03_train, tokens_dir / "train_tokens.txt")
+            extract_tokens(conll03_dev, tokens_dir / "valid_tokens.txt")
+            extract_tokens(conll03_test, tokens_dir / "test_tokens.txt")
+
+            # Apply the downloaded patch files to apply our token modifications (e.g. line breaks)
+            apply_patch(
+                tokens_dir / "train_tokens.txt",
+                patch_dir / "train_tokens.patch",
+                tokens_dir / "train_tokens_updated.txt",
+            )
+            apply_patch(
+                tokens_dir / "valid_tokens.txt", patch_dir / "dev_tokens.patch", tokens_dir / "dev_tokens_updated.txt"
+            )
+            apply_patch(
+                tokens_dir / "test_tokens.txt", patch_dir / "test_tokens.patch", tokens_dir / "test_tokens_updated.txt"
+            )
+
+            # Merge the updated token files with the CleanCoNLL annotations
+            cleanconll_annotations_dir = cleanconll_data_root / "data" / "cleanconll_annotations"
+            data_folder.mkdir(parents=True, exist_ok=True)
+
+            merge_annotations(
+                tokens_dir / "train_tokens_updated.txt",
+                cleanconll_annotations_dir / "cleanconll_annotations.train",
+                data_folder / "cleanconll.train",
+            )
+            merge_annotations(
+                tokens_dir / "dev_tokens_updated.txt",
+                cleanconll_annotations_dir / "cleanconll_annotations.dev",
+                data_folder / "cleanconll.dev",
+            )
+            merge_annotations(
+                tokens_dir / "test_tokens_updated.txt",
+                cleanconll_annotations_dir / "cleanconll_annotations.test",
+                data_folder / "cleanconll.test",
+            )
+
+            print("Done with creating. CleanCoNLL files are placed here:", data_folder)
+
+
 class CONLL_2000(ColumnCorpus):
     def __init__(
         self,
@@ -1451,7 +1689,6 @@ def __init__(
         if not data_file.is_file():
             cached_path(f"{conll_2000_path}train.txt.gz", Path("datasets") / dataset_name)
             cached_path(f"{conll_2000_path}test.txt.gz", Path("datasets") / dataset_name)
-            import gzip
 
             with (
                 gzip.open(flair.cache_root / "datasets" / dataset_name / "train.txt.gz", "rb") as f_in,
diff --git a/flair/embeddings/transformer.py b/flair/embeddings/transformer.py
index a27b57ab96..fdb16eea28 100644
--- a/flair/embeddings/transformer.py
+++ b/flair/embeddings/transformer.py
@@ -278,11 +278,11 @@ def _reconstruct_word_ids_from_subtokens(embedding, tokens: list[str], subtokens
 
     special_tokens = []
     # check if special tokens exist to circumvent error message
-    if embedding.tokenizer._bos_token:
+    if embedding.tokenizer.bos_token is not None:
         special_tokens.append(embedding.tokenizer.bos_token)
-    if embedding.tokenizer._cls_token:
+    if embedding.tokenizer.cls_token is not None:
         special_tokens.append(embedding.tokenizer.cls_token)
-    if embedding.tokenizer._sep_token:
+    if embedding.tokenizer.sep_token is not None:
         special_tokens.append(embedding.tokenizer.sep_token)
 
     # iterate over subtokens and reconstruct tokens
@@ -1354,9 +1354,10 @@ def from_params(cls, params):
     def to_params(self):
         config_dict = self.model.config.to_dict()
 
-        # do not switch the attention implementation upon reload.
-        config_dict["attn_implementation"] = self.model.config._attn_implementation
-        config_dict.pop("_attn_implementation_autoset", None)
+        if hasattr(self.model.config, "_attn_implementation"):
+            # do not switch the attention implementation upon reload.
+            config_dict["attn_implementation"] = self.model.config._attn_implementation
+            config_dict.pop("_attn_implementation_autoset", None)
 
         super_params = super().to_params()
 
diff --git a/flair/models/tars_model.py b/flair/models/tars_model.py
index 4f5cb85731..f4171fdb27 100644
--- a/flair/models/tars_model.py
+++ b/flair/models/tars_model.py
@@ -383,7 +383,7 @@ def __init__(
 
         # transformer separator
         self.separator = str(self.tars_embeddings.tokenizer.sep_token)
-        if self.tars_embeddings.tokenizer._bos_token:
+        if self.tars_embeddings.tokenizer.bos_token is not None:
             self.separator += str(self.tars_embeddings.tokenizer.bos_token)
 
         self.prefix = prefix
@@ -718,9 +718,11 @@ def __init__(
         )
 
         # transformer separator
-        self.separator = str(self.tars_embeddings.tokenizer.sep_token)
-        if self.tars_embeddings.tokenizer._bos_token:
-            self.separator += str(self.tars_embeddings.tokenizer.bos_token)
+        self.separator = (
+            self.tars_embeddings.tokenizer.sep_token if self.tars_embeddings.tokenizer.sep_token is not None else ""
+        )
+        if self.tars_embeddings.tokenizer.bos_token is not None:
+            self.separator += self.tars_embeddings.tokenizer.bos_token
 
         self.prefix = prefix
         self.num_negative_labels_to_sample = num_negative_labels_to_sample
diff --git a/requirements.txt b/requirements.txt
index bb5ecafd45..2704114ace 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -20,6 +20,6 @@ tabulate>=0.8.10
 torch>=1.5.0,!=1.8
 tqdm>=4.63.0
 transformer-smaller-training-vocab>=0.2.3
-transformers[sentencepiece]>=4.18.0,<5.0.0
+transformers[sentencepiece]>=4.25.0,<5.0.0
 wikipedia-api>=0.5.7
 bioc<3.0.0,>=2.0.0
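
For reference, a minimal usage sketch of the new corpus class (not part of the diff), assuming flair is installed with this change applied. It relies only on what the diff defines: the `CLEANCONLL` export, the on-demand download in `download_and_prepare_data`, and the column format that maps column 4 to `"ner"`.

```python
# Minimal usage sketch: load the new CleanCoNLL corpus and inspect it.
from flair.datasets import CLEANCONLL

# On first use this downloads the CleanCoNLL patch files and the original
# CoNLL-03 data, then assembles cleanconll.{train,dev,test} in the cache folder.
corpus = CLEANCONLL()

print(corpus)  # train/dev/test sentence counts

# Column 4 is registered as "ner" in the column format, so the NER label
# dictionary can be built the usual flair way.
label_dict = corpus.make_label_dictionary(label_type="ner")
print(label_dict)
```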