Add cross-validation to training #7

Merged (8 commits) on Nov 22, 2023
Binary file modified alexi/models/crf.joblib.gz
Binary file not shown.
Binary file modified alexi/models/crf.vl.joblib.gz
Binary file not shown.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
5 changes: 3 additions & 2 deletions scripts/check_tags.py
@@ -3,8 +3,9 @@
 import argparse
 import csv
 from pathlib import Path

+from alexi.analyse import Bloc, group_iob
+from alexi.segment import Bullet
-from alexi.analyse import group_iob, Bloc


 def make_argparse() -> argparse.ArgumentParser:
@@ -13,7 +14,7 @@ def make_argparse() -> argparse.ArgumentParser:
     return parser


-def check_bloc(bloc: Bloc, lineno: int) -> None:
+def check_bloc(bloc: Bloc, lineno: int) -> list[str]:
     prev_bio = "O"
     prev_segment = ""
     errors = []
19 changes: 19 additions & 0 deletions scripts/grid_search.sh
@@ -0,0 +1,19 @@
+#!/bin/sh
+
+set -e
+
+C1="0.01 0.05 0.1 0.2 0.5 1.0"
+C2="0.01 0.05 0.1 0.2 0.5 1.0"
+FEATURES="literal delta vl vsl"
+OUTDIR=grid-$(date +%Y%m%d-%H:%M)
+
+mkdir -p $OUTDIR
+for f in $FEATURES; do
+    for c1 in $C1; do
+        for c2 in $C2; do
+            echo "Training $f L1 $c1 L2 $c2"
+            python scripts/train_crf.py -x 0 --c1 $c1 --c2 $c2 --features $f \
+                --scores $OUTDIR/scores-${f}-${c1}-${c2}.csv data/*.csv 2>&1
+        done
+    done
+done
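
grid_search.sh sweeps the L1/L2 regularization coefficients and feature extractors, writing one scores CSV per combination through train_crf.py's --scores option (a Label column, a cross-fold Average, then one column per fold; see the train_crf.py diff below). The following sketch ranks the runs by macro F1 afterwards; it is a sketch only, assuming the directory layout the script produces, and it uses pandas, which is not among this project's declared dependencies:

# Sketch: rank grid-search runs by macro F1 (the "ALL" row of each scores CSV).
# Assumes a grid-YYYYMMDD-HH:MM directory produced by scripts/grid_search.sh.
import re
from pathlib import Path

import pandas as pd  # assumption: pandas is not a declared dependency here

results = []
for path in Path("grid-20231122-12:00").glob("scores-*.csv"):
    # File names follow scores-{features}-{c1}-{c2}.csv per the script above
    m = re.match(r"scores-(\w+)-([\d.]+)-([\d.]+)\.csv", path.name)
    if m is None:
        continue
    df = pd.read_csv(path)
    macro_f1 = df.loc[df["Label"] == "ALL", "Average"].item()
    results.append((m.group(1), float(m.group(2)), float(m.group(3)), macro_f1))

# Print the five best (features, c1, c2) combinations
for features, c1, c2, f1 in sorted(results, key=lambda r: -r[-1])[:5]:
    print(f"{features}\tc1={c1}\tc2={c2}\tmacro F1={f1:.3f}")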
177 changes: 123 additions & 54 deletions scripts/train_crf.py
@@ -1,20 +1,32 @@
 """Entrainer un CRF pour segmentation/identification"""

 import argparse
+import csv
 import itertools
 import logging
+import os
 from pathlib import Path
-from typing import Iterable, Iterator, Optional
+from typing import Iterable, Iterator

 import joblib  # type: ignore
+import numpy as np
 import sklearn_crfsuite as crfsuite  # type: ignore
+from sklearn.metrics import make_scorer  # type: ignore
+from sklearn.model_selection import KFold, cross_validate  # type: ignore
+from sklearn_crfsuite import metrics

 from alexi.segment import load, page2features, page2labels, split_pages

 LOGGER = logging.getLogger("train-crf")


 def make_argparse():
     parser = argparse.ArgumentParser(description=__doc__)
     parser.add_argument(
-        "--niter", default=100, type=int, help="Nombre d'iterations d'entrainement"
+        "csvs", nargs="+", help="Fichiers CSV d'entrainement", type=Path
+    )
+    parser.add_argument(
+        "--niter", default=200, type=int, help="Nombre d'iterations d'entrainement"
     )
     parser.add_argument("--features", default="vsl", help="Extracteur de traits")
     parser.add_argument("--labels", default="literal", help="Transformateur de classes")
@@ -23,78 +35,135 @@ def make_argparse():
     )
     parser.add_argument("-n", default=2, type=int, help="Largeur du contexte de traits")
     parser.add_argument(
-        "--c1", default=0.5, type=float, help="Coefficient de regularisation L1"
+        "--c1", default=0.2, type=float, help="Coefficient de regularisation L1"
     )
     parser.add_argument(
+        "--c2", default=0.01, type=float, help="Coefficient de regularisation L2"
+    )
+    parser.add_argument("--seed", default=1381, type=int, help="Graine aléatoire")
+    parser.add_argument(
+        "--min-count",
+        default=10,
+        type=int,
+        help="Seuil d'évaluation pour chaque classification",
+    )
+    parser.add_argument(
-        "--c2", default=0.1, type=float, help="Coefficient de regularisation L2"
+        "-x",
+        "--cross-validation-folds",
+        default=1,
+        type=int,
+        help="Faire la validation croisée pour évaluer le modèle.",
     )
     parser.add_argument("-o", "--outfile", help="Fichier destination pour modele")
+    parser.add_argument("-s", "--scores", help="Fichier destination pour évaluations")
     return parser


 def filter_tab(words: Iterable[dict]) -> Iterator[dict]:
     """Enlever les mots dans des tableaux car on va s'en occuper autrement."""
     for w in words:
         if "Tableau" in w["segment"]:
             continue
         if "Table" in w["tagstack"]:
             continue
         yield w


-def train(
-    train_set: Iterable[dict],
-    dev_set: Optional[Iterable[dict]] = None,
-    features="vsl",
-    labels="literal",
-    n=2,
-    niter=69,
-    c1=0.1,
-    c2=0.1,
-) -> crfsuite.CRF:
-    train_pages = list(split_pages(filter_tab(train_set)))
-    X_train = [page2features(s, features, n) for s in train_pages]
-    y_train = [page2labels(s, labels) for s in train_pages]
-
-    params = {
-        "c1": c1,
-        "c2": c2,
-        "algorithm": "lbfgs",
-        "max_iterations": niter,
-        "all_possible_transitions": True,
+def run_cv(args: argparse.Namespace, params: dict, X, y):
+    if args.cross_validation_folds == 0:
+        args.cross_validation_folds = os.cpu_count()
+        LOGGER.debug("Using 1 fold per CPU")
+    LOGGER.info("Running cross-validation in %d folds", args.cross_validation_folds)
+    counts: dict[str, int] = {}
+    for c in itertools.chain.from_iterable(y):
+        if c.startswith("B-"):
+            count = counts.setdefault(c, 0)
+            counts[c] = count + 1
+    labels = []
+    for c, n in counts.items():
+        if n < args.min_count:
+            LOGGER.debug("Label %s count %d (excluded)", c, n)
+        else:
+            LOGGER.debug("Label %s count %d", c, n)
+            labels.append(c)
+    labels.sort()
+    LOGGER.info("Evaluating on: %s", ",".join(labels))
+    crf = crfsuite.CRF(**params)
+    scoring = {
+        "macro_f1": make_scorer(
+            metrics.flat_f1_score, labels=labels, average="macro", zero_division=0.0
+        ),
+        "micro_f1": make_scorer(
+            metrics.flat_f1_score, labels=labels, average="micro", zero_division=0.0
+        ),
     }
-    crf = crfsuite.CRF(**params, verbose=True)
-    if dev_set is not None:
-        dev_pages = list(split_pages(filter_tab(dev_set)))
-        X_dev = [page2features(s, features, n) for s in dev_pages]
-        y_dev = [page2labels(s, labels) for s in dev_pages]
-        crf.fit(X_train, y_train, X_dev=X_dev, y_dev=y_dev)
-    else:
-        crf.fit(X_train, y_train)
-    return crf
+    for name in labels:
+        scoring[name] = make_scorer(
+            metrics.flat_f1_score,
+            labels=[name],
+            average="micro",
+            zero_division=0.0,
+        )
+    scores = cross_validate(
+        crf,
+        X,
+        y,
+        cv=KFold(args.cross_validation_folds, shuffle=True, random_state=args.seed),
+        scoring=scoring,
+        return_estimator=True,
+        n_jobs=os.cpu_count(),
+    )
+    LOGGER.info("Macro F1: %.3f", scores["test_macro_f1"].mean())
+    LOGGER.info("Micro F1: %.3f", scores["test_micro_f1"].mean())
+    if args.outfile:
+        for idx, xcrf in enumerate(scores["estimator"]):
+            joblib.dump(
+                (xcrf, args.n, args.features, args.labels),
+                args.outfile + f"_{idx + 1}.gz",
+            )
+    if args.scores:
+        with open(args.scores, "wt") as outfh:
+            fieldnames = [
+                "Label",
+                "Average",
+                *range(1, args.cross_validation_folds + 1),
+            ]
+            writer = csv.DictWriter(outfh, fieldnames=fieldnames)
+            writer.writeheader()
+
+            def makerow(name, scores):
+                row = {"Label": name, "Average": np.mean(scores)}
+                for idx, score in enumerate(scores):
+                    row[idx + 1] = score
+                return row
+
+            writer.writerow(makerow("ALL", scores["test_macro_f1"]))
+            for name in labels:
+                writer.writerow(makerow(name, scores[f"test_{name}"]))


 def main():
     parser = make_argparse()
     args = parser.parse_args()
-    train_set = itertools.chain(
-        load(Path("data/train").glob("*.csv")),
-        load([Path("test/data/pdf_structure.csv")]),
-        load([Path("test/data/pdf_figures.csv")]),
-    )
-    dev_set = load(Path("data/dev").glob("*.csv"))
-    if args.train_dev:
-        train_set = itertools.chain(train_set, dev_set)
-        dev_set = None
-    crf = train(
-        train_set,
-        dev_set,
-        features=args.features,
-        labels=args.labels,
-        n=args.n,
-        niter=args.niter,
-        c1=args.c1,
-        c2=args.c2,
-    )
-    if args.outfile:
-        joblib.dump((crf, args.n, args.features, args.labels), args.outfile)
+    logging.basicConfig(level=logging.INFO)
+    data = load(args.csvs)
+    pages = list(split_pages(filter_tab(data)))
+    X = [page2features(s, args.features, args.n) for s in pages]
+    y = [page2labels(s, args.labels) for s in pages]
+    params = {
+        "c1": args.c1,
+        "c2": args.c2,
+        "algorithm": "lbfgs",
+        "max_iterations": args.niter,
+        "all_possible_transitions": True,
+    }
+    if args.cross_validation_folds == 1:
+        crf = crfsuite.CRF(**params, verbose=True)
+        crf.fit(X, y)
+        if args.outfile:
+            joblib.dump((crf, args.n, args.features, args.labels), args.outfile)
+    else:
+        run_cv(args, params, X, y)


 if __name__ == "__main__":
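
The core of this change is reusing scikit-learn's model-selection machinery on sequence data: make_scorer wraps sklearn_crfsuite.metrics.flat_f1_score (which flattens the per-page label sequences before scoring), and cross_validate handles fold splitting and parallelism. Here is a condensed, self-contained sketch of that pattern on tiny invented data; the toy features and labels are illustrative only, not this project's real feature extraction:

# Minimal sketch of the cross-validation pattern used by run_cv above.
# The data is invented for illustration; the real script builds X and y
# with page2features/page2labels from the training CSVs.
import sklearn_crfsuite as crfsuite
from sklearn.metrics import make_scorer
from sklearn.model_selection import KFold, cross_validate
from sklearn_crfsuite import metrics

# One entry per "page": a list of per-word feature dicts, and its label sequence.
X = [[{"word": w} for w in page] for page in (["a", "b"], ["c"], ["d", "e"], ["f"])]
y = [["B-Titre", "I-Titre"], ["O"], ["B-Titre", "O"], ["O"]]

crf = crfsuite.CRF(algorithm="lbfgs", c1=0.2, c2=0.01, max_iterations=10)
scoring = {
    # flat_f1_score flattens the nested label sequences before computing F1
    "macro_f1": make_scorer(
        metrics.flat_f1_score, labels=["B-Titre"], average="macro", zero_division=0.0
    )
}
scores = cross_validate(
    crf, X, y, cv=KFold(2, shuffle=True, random_state=1381), scoring=scoring
)
print("Macro F1 per fold:", scores["test_macro_f1"])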
1 change: 1 addition & 0 deletions setup.cfg
@@ -13,6 +13,7 @@ install_requires =
     pdfplumber @ git+https://github.com/dhdaines/pdfplumber.git
     whoosh
     sklearn-crfsuite @ git+https://github.com/MeMartijn/updated-sklearn-crfsuite.git
+    scikit-learn
     joblib
     lxml
 packages = alexi
Binary file added test/data/model.gz
Binary file not shown.
2 changes: 1 addition & 1 deletion test/test_analyse.py
@@ -6,7 +6,7 @@
 from alexi.format import format_xml

 DATADIR = Path(__file__).parent / "data"
-TRAINDIR = Path(__file__).parent.parent / "data" / "train"
+TRAINDIR = Path(__file__).parent.parent / "data"

 IOBTEST = [
     "<Titre>Titre incomplet</Titre>",
2 changes: 1 addition & 1 deletion test/test_format.py
@@ -6,7 +6,7 @@
 from alexi.format import format_html, format_text

 DATADIR = Path(__file__).parent / "data"
-TRAINDIR = Path(__file__).parent.parent / "data" / "train"
+TRAINDIR = Path(__file__).parent.parent / "data"


 def test_format_html():
2 changes: 1 addition & 1 deletion test/test_segment.py
@@ -17,7 +17,7 @@ def test_segment():
         del word["segment"]
         writer.writerow(word)
     testfh.seek(0, 0)
-    seg = Segmenteur()
+    seg = Segmenteur(DATADIR / "model.gz")
     reader = csv.DictReader(testfh)
     words = list(seg(reader))
     assert len(words) > 0
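
For reference, the model file the test now passes to Segmenteur appears to use the same joblib dump format train_crf.py writes: a (crf, n, features, labels) tuple. A sketch of inspecting such a file, assuming only what the diff shows (Segmenteur's own loading code is not visible in this PR):

# Sketch: inspect a model file produced by scripts/train_crf.py.
# The tuple layout matches the joblib.dump calls in the diff above.
import joblib

crf, n, features, labels = joblib.load("test/data/model.gz")
print(type(crf).__name__, n, features, labels)  # e.g.: CRF 2 vsl literal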