Skip to content

Commit

Permalink
Added TER and BLEU for early stopping (#73)
Browse files Browse the repository at this point in the history
  • Loading branch information
aaaallleen authored Jul 22, 2024
1 parent a96cf21 commit 5120fdb
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 1 deletion.
35 changes: 34 additions & 1 deletion eole/utils/earlystopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,43 @@ def _caller(self, stats):
return stats.accuracy()


class BLEUScorer(Scorer):
def __init__(self):
super(BLEUScorer, self).__init__(float("-inf"), "bleu")

def is_improving(self, stats):
return stats.computed_metric("BLEU") > self.best_score

def is_decreasing(self, stats):
return stats.computed_metric("BLEU") < self.best_score

def _caller(self, stats):
return stats.computed_metric("BLEU")


class TERScorer(Scorer):
def __init__(self):
super(TERScorer, self).__init__(float("inf"), "bleu")

def is_improving(self, stats):
return stats.computed_metric("TER") < self.best_score

def is_decreasing(self, stats):
return stats.computed_metric("TER") > self.best_score

def _caller(self, stats):
return stats.computed_metric("TER")


DEFAULT_SCORERS = [PPLScorer(), AccuracyScorer()]


SCORER_BUILDER = {"ppl": PPLScorer, "accuracy": AccuracyScorer}
SCORER_BUILDER = {
"ppl": PPLScorer,
"accuracy": AccuracyScorer,
"BLEU": BLEUScorer,
"TER": TERScorer,
}


def scorers_from_config(config):
Expand Down
6 changes: 6 additions & 0 deletions eole/utils/statistics.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
""" Statistics calculation utility """

import time
import math
import sys
Expand Down Expand Up @@ -101,6 +102,11 @@ def update(self, stat, update_n_src_words=False):
if update_n_src_words:
self.n_src_words += stat.n_src_words

def computed_metric(self, metric):
"""check if metric(TER/BLEU) is computed and return it"""
assert metric in self.computed_metrics, "Metric {} not found".format(metric)
return self.computed_metrics[metric]

def accuracy(self):
"""compute accuracy"""
return 100 * (self.n_correct / self.n_words)
Expand Down

0 comments on commit 5120fdb

Please sign in to comment.