Merge pull request #3557 from flairNLP/add_cleanconll
Add CleanCoNLL object
helpmefindaname authored Dec 6, 2024
2 parents 1f1a45b + 79287c3 commit 8ae1ab8
Showing 5 changed files with 254 additions and 12 deletions.
2 changes: 2 additions & 0 deletions flair/datasets/__init__.py
@@ -171,6 +171,7 @@
# Expose all sequence labeling datasets
from .sequence_labeling import (
BIOSCOPE,
CLEANCONLL,
CONLL_03,
CONLL_03_DUTCH,
CONLL_03_GERMAN,
@@ -465,6 +466,7 @@
"CONLL_03_DUTCH",
"CONLL_03_GERMAN",
"CONLL_03_SPANISH",
"CLEANCONLL",
"CONLL_2000",
"FEWNERD",
"KEYPHRASE_INSPEC",
239 changes: 238 additions & 1 deletion flair/datasets/sequence_labeling.py
@@ -3,8 +3,11 @@
import logging
import os
import re
import gzip
import shutil
import tarfile
import tempfile
import zipfile
from collections import defaultdict
from collections.abc import Iterable, Iterator
from pathlib import Path
@@ -15,6 +18,7 @@
cast,
)

import requests
from torch.utils.data import ConcatDataset, Dataset

import flair
@@ -1421,6 +1425,240 @@ def __init__(
)


class CLEANCONLL(ColumnCorpus):
def __init__(
self,
base_path: Optional[Union[str, Path]] = None,
in_memory: bool = True,
**corpusargs,
) -> None:
"""Initialize the CleanCoNLL corpus.
Args:
base_path: Base directory for the dataset. If None, defaults to flair.cache_root / "datasets".
in_memory: If True, keeps dataset in memory for faster training.
"""
# Set the base path for the dataset
base_path = flair.cache_root / "datasets" if not base_path else Path(base_path)

# Define column format
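        # (token text, POS tag, an entity-linking ("nel") column, and two NER annotation layers, "ner*" and "ner")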
columns = {0: "text", 1: "pos", 2: "nel", 3: "ner*", 4: "ner"}

# Define dataset name
dataset_name = self.__class__.__name__.lower()

# Define data folder path
data_folder = base_path / dataset_name

# Check if the train data file exists, otherwise download and prepare the dataset
train_set = data_folder / "cleanconll.train"

if not train_set.exists():
print("CleanCoNLL files not found, so downloading and creating them.")

# Download and prepare the dataset
self.download_and_prepare_data(data_folder)

else:
print("Found files for CleanCoNLL in:", data_folder)

# Initialize the parent class with the specified parameters
super().__init__(
data_folder,
columns,
encoding="utf-8",
in_memory=in_memory,
document_separator_token="-DOCSTART-",
**corpusargs,
)

@staticmethod
def download_and_prepare_data(data_folder: Path):
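        # Rough pipeline (helpers defined below): download the CleanCoNLL repository
        # (patch files + annotations) and, if needed, the original CoNLL-03 data,
        # extract the token column from the original files, patch the tokenization,
        # and merge the patched tokens with the CleanCoNLL annotation columns.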
def parse_patch(patch_file_path):
"""Parses a patch file and returns a structured representation of the changes."""
changes = []
current_change = None

with open(patch_file_path, encoding="utf-8") as patch_file:
for line in patch_file:
# Check if the line is a change, delete or add command (like 17721c17703,17705 or 5728d5727)
if line and (line[0].isdigit() and ("c" in line or "d" in line or "a" in line)):
if current_change:
# Append the previous change block to the changes list
changes.append(current_change)

# Start a new change block
current_change = {"command": line, "original": [], "new": []}

# Capture original lines (those marked with "<")
elif line.startswith("<"):
if current_change:
current_change["original"].append(line[2:]) # Remove the "< " part

# Capture new lines (those marked with ">")
elif line.startswith(">"):
if current_change:
current_change["new"].append(line[2:]) # Remove the "> " part

# Append the last change block to the changes list
if current_change:
changes.append(current_change)

return changes

def parse_line_range(line_range_str):
"""Utility function to parse a line range string like '17703,17705' or '5727' and returns a tuple (start, end)."""
parts = line_range_str.split(",")
if len(parts) == 1:
start = int(parts[0]) - 1
return (start, start + 1)
else:
start = int(parts[0]) - 1
end = int(parts[1])
return (start, end)

def apply_patch_to_file(original_file, changes, output_file_path):
"""Applies the patch instructions to the content of the original file."""
with open(original_file, encoding="utf-8") as f:
original_lines = f.readlines()

modified_lines = original_lines[:] # Make a copy of original lines

# Apply each change in reverse order (important to avoid index shift issues)
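            # (patch commands are ordered by ascending line number, so going back to front
            # keeps the line indices of the not-yet-applied changes valid)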
for change in reversed(changes):
command = change["command"]

# Determine the type of the change: `c` for change, `d` for delete, `a` for add
if "c" in command:
# Example command: 17721c17703,17705
original_line_range, new_line_range = command.split("c")
original_line_range = parse_line_range(original_line_range)
modified_lines[original_line_range[0] : original_line_range[1]] = change["new"]

elif "d" in command:
# Example command: 5728d5727
original_line_number = int(command.split("d")[0]) - 1
del modified_lines[original_line_number]

elif "a" in command:
# Example command: 1000a1001,1002
original_line_number = int(command.split("a")[0]) - 1
insertion_point = original_line_number + 1
for new_token in reversed(change["new"]):
modified_lines.insert(insertion_point, new_token)

# Write the modified content to the output file
with open(output_file_path, "w", encoding="utf-8") as output_file:
output_file.writelines(modified_lines)

def apply_patch(file_path, patch_path, output_path):
patch_instructions = parse_patch(patch_path)
apply_patch_to_file(file_path, patch_instructions, output_path)

def extract_tokens(file_path: Path, output_path: Path):
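            # Keep only the first (token) column of each line and preserve blank lines
            # as sentence separators, e.g. a line like "EU NNP B-NP B-ORG" becomes "EU".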
with open(file_path, encoding="utf-8") as f_in, open(output_path, "w", encoding="utf-8") as f_out:
for line in f_in:
# Strip whitespace to check if the line is empty
stripped_line = line.strip()
if stripped_line:
# Write the first token followed by a newline if the line is not empty
f_out.write(stripped_line.split()[0] + "\n")
else:
# Write an empty line if the line is empty
f_out.write("\n")

def merge_annotations(tokens_file, annotations_file, output_file):
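            # The CleanCoNLL annotation files repeat the token as their first tab-separated
            # column; drop it and prepend the (patched) token instead, producing the
            # five-column format declared above (text, pos, nel, ner*, ner).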
with (
open(tokens_file, encoding="utf-8") as tokens_file,
open(annotations_file, encoding="utf-8") as annotations_file,
open(output_file, "w", encoding="utf-8") as output_file,
):
tokens = tokens_file.readlines()
annotations = annotations_file.readlines()

for token, annotation in zip(tokens, annotations):
# Strip the leading '[TOKEN]\t' from the annotation
stripped_annotation = "\t".join(annotation.strip().split("\t")[1:])
output_file.write(token.strip() + "\t" + stripped_annotation + "\n")

# Create a temporary directory
with tempfile.TemporaryDirectory() as tmpdirname:
tmpdir = Path(tmpdirname)

github_url = "https://github.com/flairNLP/CleanCoNLL/archive/main.zip"
zip_path = cached_path(github_url, tmpdir)
unpack_file(zip_path, tmpdir, "zip", False)
cleanconll_data_root = tmpdir / "CleanCoNLL-main"

# Check the contents of the temporary directory
print(f"Contents of the temporary directory: {list(tmpdir.iterdir())}")

conll03_dir = data_folder / "original_conll-03"
if conll03_dir.exists() and conll03_dir.is_dir() and "train.txt" in [f.name for f in conll03_dir.iterdir()]:
print(f"Original CoNLL03 files detected here: {conll03_dir}")

else:
conll_url = "https://data.deepai.org/conll2003.zip"

conll03_dir.mkdir(parents=True, exist_ok=True)
print(f"Downloading the original CoNLL03 from {conll_url} into {conll03_dir} ...")

zip_path = conll03_dir / "conll2003.zip"
response = requests.get(conll_url)
zip_path.write_bytes(response.content)

with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(conll03_dir)

conll03_train = conll03_dir / "train.txt"
conll03_dev = conll03_dir / "valid.txt"
conll03_test = conll03_dir / "test.txt"

patch_dir = cleanconll_data_root / "data" / "patch_files"
tokens_dir = cleanconll_data_root / "data" / "tokens_updated"
tokens_dir.mkdir(parents=True, exist_ok=True)

# Extract only the tokens from the original CoNLL03 files
extract_tokens(conll03_train, tokens_dir / "train_tokens.txt")
extract_tokens(conll03_dev, tokens_dir / "valid_tokens.txt")
extract_tokens(conll03_test, tokens_dir / "test_tokens.txt")

            # Apply the downloaded patch files so that our token modifications (e.g. changed line breaks) are reproduced
apply_patch(
tokens_dir / "train_tokens.txt",
patch_dir / "train_tokens.patch",
tokens_dir / "train_tokens_updated.txt",
)
apply_patch(
tokens_dir / "valid_tokens.txt", patch_dir / "dev_tokens.patch", tokens_dir / "dev_tokens_updated.txt"
)
apply_patch(
tokens_dir / "test_tokens.txt", patch_dir / "test_tokens.patch", tokens_dir / "test_tokens_updated.txt"
)

# Merge the updated token files with the CleanCoNLL annotations
cleanconll_annotations_dir = cleanconll_data_root / "data" / "cleanconll_annotations"
data_folder.mkdir(parents=True, exist_ok=True)

merge_annotations(
tokens_dir / "train_tokens_updated.txt",
cleanconll_annotations_dir / "cleanconll_annotations.train",
data_folder / "cleanconll.train",
)
merge_annotations(
tokens_dir / "dev_tokens_updated.txt",
cleanconll_annotations_dir / "cleanconll_annotations.dev",
data_folder / "cleanconll.dev",
)
merge_annotations(
tokens_dir / "test_tokens_updated.txt",
cleanconll_annotations_dir / "cleanconll_annotations.test",
data_folder / "cleanconll.test",
)

print("Done with creating. CleanCoNLL files are placed here:", data_folder)


class CONLL_2000(ColumnCorpus):
def __init__(
self,
@@ -1451,7 +1689,6 @@ def __init__(
if not data_file.is_file():
cached_path(f"{conll_2000_path}train.txt.gz", Path("datasets") / dataset_name)
cached_path(f"{conll_2000_path}test.txt.gz", Path("datasets") / dataset_name)
-            import gzip

with (
gzip.open(flair.cache_root / "datasets" / dataset_name / "train.txt.gz", "rb") as f_in,
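
A minimal usage sketch for the new corpus class (not part of the diff; assumes a flair installation that includes this change):

    from flair.datasets import CLEANCONLL

    # The first call downloads and rebuilds the data as described above,
    # then loads it as a ColumnCorpus with train/dev/test splits.
    corpus = CLEANCONLL()
    print(corpus)

    # e.g. build a label dictionary for the "ner" column
    ner_dictionary = corpus.make_label_dictionary(label_type="ner")
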
13 changes: 7 additions & 6 deletions flair/embeddings/transformer.py
@@ -278,11 +278,11 @@ def _reconstruct_word_ids_from_subtokens(embedding, tokens: list[str], subtokens

special_tokens = []
# check if special tokens exist to circumvent error message
-    if embedding.tokenizer._bos_token:
+    if embedding.tokenizer.bos_token is not None:
special_tokens.append(embedding.tokenizer.bos_token)
-    if embedding.tokenizer._cls_token:
+    if embedding.tokenizer.cls_token is not None:
special_tokens.append(embedding.tokenizer.cls_token)
-    if embedding.tokenizer._sep_token:
+    if embedding.tokenizer.sep_token is not None:
special_tokens.append(embedding.tokenizer.sep_token)

# iterate over subtokens and reconstruct tokens
@@ -1354,9 +1354,10 @@ def from_params(cls, params):
def to_params(self):
config_dict = self.model.config.to_dict()

-        # do not switch the attention implementation upon reload.
-        config_dict["attn_implementation"] = self.model.config._attn_implementation
-        config_dict.pop("_attn_implementation_autoset", None)
+        if hasattr(self.model.config, "_attn_implementation"):
+            # do not switch the attention implementation upon reload.
+            config_dict["attn_implementation"] = self.model.config._attn_implementation
+            config_dict.pop("_attn_implementation_autoset", None)

super_params = super().to_params()

10 changes: 6 additions & 4 deletions flair/models/tars_model.py
@@ -383,7 +383,7 @@ def __init__(

# transformer separator
self.separator = str(self.tars_embeddings.tokenizer.sep_token)
-        if self.tars_embeddings.tokenizer._bos_token:
+        if self.tars_embeddings.tokenizer.bos_token is not None:
self.separator += str(self.tars_embeddings.tokenizer.bos_token)

self.prefix = prefix
@@ -718,9 +718,11 @@ def __init__(
)

# transformer separator
-        self.separator = str(self.tars_embeddings.tokenizer.sep_token)
-        if self.tars_embeddings.tokenizer._bos_token:
-            self.separator += str(self.tars_embeddings.tokenizer.bos_token)
+        self.separator = (
+            self.tars_embeddings.tokenizer.sep_token if self.tars_embeddings.tokenizer.sep_token is not None else ""
+        )
+        if self.tars_embeddings.tokenizer.bos_token is not None:
+            self.separator += self.tars_embeddings.tokenizer.bos_token

self.prefix = prefix
self.num_negative_labels_to_sample = num_negative_labels_to_sample
2 changes: 1 addition & 1 deletion requirements.txt
@@ -20,6 +20,6 @@ tabulate>=0.8.10
torch>=1.5.0,!=1.8
tqdm>=4.63.0
transformer-smaller-training-vocab>=0.2.3
-transformers[sentencepiece]>=4.18.0,<5.0.0
+transformers[sentencepiece]>=4.25.0,<5.0.0
wikipedia-api>=0.5.7
bioc<3.0.0,>=2.0.0
