Skip to content

Commit

Permalink
Added WMT22 data (closes #215) (#216)
Browse files Browse the repository at this point in the history
* Added WMT22 data
* allow langpair-specific overrides to select the default annotator
* added both refs for wmt21/dev
  • Loading branch information
mjpost authored Oct 18, 2022
1 parent e416ee2 commit 5166cf7
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 25 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,7 @@ data/
build
dist
__pycache__
sacrebleu.egg-info
.sacrebleu
*~
.DS_Store
50 changes: 49 additions & 1 deletion sacrebleu/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,51 @@

DATASETS = {
# wmt
"wmt22": WMTXMLDataset(
"wmt22",
data=["https://github.com/wmt-conference/wmt22-news-systems/archive/refs/tags/v1.1.tar.gz"],
description="Official evaluation and system data for WMT22.",
md5=["0840978b9b50b9ac3b2b081e37d620b9"],
langpairs={
"cs-en": {
"path": "wmt22-news-systems-1.1/xml/wmttest2022.cs-en.all.xml",
"refs": ["B"],
},
"cs-uk": ["wmt22-news-systems-1.1/xml/wmttest2022.cs-uk.all.xml"],
"de-en": ["wmt22-news-systems-1.1/xml/wmttest2022.de-en.all.xml"],
"de-fr": ["wmt22-news-systems-1.1/xml/wmttest2022.de-fr.all.xml"],
"en-cs": {
"path": "wmt22-news-systems-1.1/xml/wmttest2022.en-cs.all.xml",
"refs": ["B"],
},
"en-de": ["wmt22-news-systems-1.1/xml/wmttest2022.en-de.all.xml"],
"en-hr": ["wmt22-news-systems-1.1/xml/wmttest2022.en-hr.all.xml"],
"en-ja": ["wmt22-news-systems-1.1/xml/wmttest2022.en-ja.all.xml"],
"en-liv": ["wmt22-news-systems-1.1/xml/wmttest2022.en-liv.all.xml"],
"en-ru": ["wmt22-news-systems-1.1/xml/wmttest2022.en-ru.all.xml"],
"en-uk": ["wmt22-news-systems-1.1/xml/wmttest2022.en-uk.all.xml"],
"en-zh": ["wmt22-news-systems-1.1/xml/wmttest2022.en-zh.all.xml"],
"fr-de": ["wmt22-news-systems-1.1/xml/wmttest2022.fr-de.all.xml"],
"ja-en": ["wmt22-news-systems-1.1/xml/wmttest2022.ja-en.all.xml"],
"liv-en": {
"path": "wmt22-news-systems-1.1/xml/wmttest2022.liv-en.all.xml",
# no translator because data is English-original
"refs": [""],
},
"ru-en": ["wmt22-news-systems-1.1/xml/wmttest2022.ru-en.all.xml"],
"ru-sah": {
"path": "wmt22-news-systems-1.1/xml/wmttest2022.ru-sah.all.xml",
# no translator because data is Yakut-original
"refs": [""],
},
"sah-ru": ["wmt22-news-systems-1.1/xml/wmttest2022.sah-ru.all.xml"],
"uk-cs": ["wmt22-news-systems-1.1/xml/wmttest2022.uk-cs.all.xml"],
"uk-en": ["wmt22-news-systems-1.1/xml/wmttest2022.uk-en.all.xml"],
"zh-en": ["wmt22-news-systems-1.1/xml/wmttest2022.zh-en.all.xml"],
},
# the default reference to use with this dataset
refs=["A"],
),
"wmt21/systems": WMTXMLDataset(
"wmt21/systems",
data=["https://github.com/wmt-conference/wmt21-news-systems/archive/refs/tags/v1.3.tar.gz"],
Expand Down Expand Up @@ -101,8 +146,9 @@
"xh-zu": ["wmt21-news-systems-1.3/xml/florestest2021.xh-zu.all.xml"],
"zu-xh": ["wmt21-news-systems-1.3/xml/florestest2021.zu-xh.all.xml"],
},
# the reference to use with this dataset
refs=["A"],
),

"wmt21": WMTXMLDataset(
"wmt21",
data=["http://data.statmt.org/wmt21/translation-task/test.tgz"],
Expand Down Expand Up @@ -210,6 +256,8 @@
"en-is": ["dev/xml/newsdev2021.en-is.xml"],
"is-en": ["dev/xml/newsdev2021.is-en.xml"],
},
# datasets are bidirectional in origin, so use both refs
refs=["A", ""],
),
"wmt20/tworefs": FakeSGMLDataset(
"wmt20/tworefs",
Expand Down
78 changes: 55 additions & 23 deletions sacrebleu/dataset/wmt_xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,19 @@
from collections import defaultdict


def _get_field_by_translator(translator):
if not translator:
return "ref"
else:
return f"ref:{translator}"

class WMTXMLDataset(Dataset):
"""
The 2021+ WMT dataset format. Everything is contained in a single file.
Can be parsed with the lxml parser.
"""

@staticmethod
def _unwrap_wmt21_or_later(raw_file, allowed_refs=[]):
def _unwrap_wmt21_or_later(raw_file):
"""
Unwraps the XML file from wmt21 or later.
This script is adapted from https://github.com/wmt-conference/wmt-format-tools
Expand All @@ -37,27 +42,20 @@ def _unwrap_wmt21_or_later(raw_file, allowed_refs=[]):
for ref_doc in tree.getroot().findall(".//ref"):
ref_langs.add(ref_doc.get("lang"))
translator = ref_doc.get("translator")
if len(allowed_refs) == 0 or translator in allowed_refs:
translators.add(translator)
translators.add(translator)

assert (
len(src_langs) == 1
), f"Multiple source languages found in the file: {raw_file}"
assert (
len(ref_langs) == 1
), f"Multiple reference languages found in the file: {raw_file}"
), f"Found {len(ref_langs)} reference languages found in the file: {raw_file}"

src = []
docids = []
orig_langs = []

def get_field_by_translator(translator):
if not translator:
return "ref"
else:
return f"ref:{translator}"

refs = {get_field_by_translator(translator): [] for translator in translators}
refs = { _get_field_by_translator(translator): [] for translator in translators }

systems = defaultdict(list)

Expand Down Expand Up @@ -97,7 +95,7 @@ def get_sents(doc):
if not any([value.get(seg_id, "") for value in trans_to_ref.values()]):
continue
for translator in translators:
refs[get_field_by_translator(translator)].append(
refs[_get_field_by_translator(translator)].append(
trans_to_ref.get(translator, {translator: {}}).get(seg_id, "")
)
src.append(src_sents[seg_id])
Expand All @@ -109,22 +107,31 @@ def get_sents(doc):

return {"src": src, **refs, "docid": docids, "origlang": orig_langs, **systems}

def _get_langpair_path(self, langpair):
    """
    Returns the path of the raw file for this language pair.

    The language-pair entry is either a list of relative paths (the first one
    is used) or, as in WMT22, a dict whose "path" key holds the relative path
    (the dict form exists so that a language pair can carry extra overrides).

    :param langpair: The language pair, e.g. "en-de".
    :return: The relative path joined onto the dataset's raw directory.
    """
    langpair_data = self._get_langpair_metadata(langpair)[langpair]
    # isinstance (rather than an exact type comparison) also accepts dict subclasses
    if isinstance(langpair_data, dict):
        rel_path = langpair_data["path"]
    else:
        rel_path = langpair_data[0]
    return os.path.join(self._rawdir, rel_path)

def process_to_text(self, langpair=None):
"""Processes raw files to plain text files.
:param langpair: The language pair to process. e.g. "en-de". If None, all files will be processed.
"""
# ensure that the dataset is downloaded
self.maybe_download()
langpairs = self._get_langpair_metadata(langpair)

for langpair, files in langpairs.items():
rawfile = os.path.join(
self._rawdir, files[0]
) # all source and reference data in one file for wmt21 and later
for langpair in sorted(self._get_langpair_metadata(langpair).keys()):
# The data type can be a list of paths, or a dict, containing the "path"
# and an override on which labeled reference to use (key "refs")
rawfile = self._get_langpair_path(langpair)

with smart_open(rawfile) as fin:
fields = self._unwrap_wmt21_or_later(fin, allowed_refs=self.kwargs.get("refs", []))
fields = self._unwrap_wmt21_or_later(fin)

for fieldname in fields:
textfile = self._get_txt_file_path(langpair, fieldname)
Expand All @@ -137,11 +144,37 @@ def process_to_text(self, langpair=None):
for line in fields[fieldname]:
print(self._clean(line), file=fout)

def _get_langpair_allowed_refs(self, langpair):
    """
    Returns the preferred reference field names for this language pair.

    The preference can be set in the language pair block (as in WMT22, via the
    "refs" key of a dict entry), and backs off to the test-set-level default
    (the "refs" keyword argument), or to nothing if neither is given.

    There is one exception. In the metadata, sometimes there is no translator field
    listed (e.g., wmt22:liv-en). In this case, the reference is set to "", and the
    field "ref" is returned.

    :param langpair: The language pair, e.g. "en-de".
    :return: A list of field names, each either "ref" or "ref:<translator>".
    """
    defaults = self.kwargs.get("refs", [])
    langpair_data = self._get_langpair_metadata(langpair)[langpair]
    # Only dict-style entries can carry a per-langpair override
    if isinstance(langpair_data, dict):
        translators = langpair_data.get("refs", defaults)
    else:
        translators = defaults

    return [_get_field_by_translator(translator) for translator in translators]

def get_reference_files(self, langpair):
    """
    Returns the requested reference files.

    Which references are requested is defined as a default at the test-set
    level, and can be overridden per language pair.

    :param langpair: The language pair, e.g. "en-de".
    :return: The file paths whose field name is one of the allowed references.
    """
    allowed_refs = self._get_langpair_allowed_refs(langpair)
    all_files = self.get_files(langpair)
    all_fields = self.fieldnames(langpair)
    # Iterate through the (file path, label) pairs, keeping permitted labels only.
    # NOTE(review): assumes get_files() and fieldnames() return parallel lists -- confirm.
    return [
        path for path, field in zip(all_files, all_fields) if field in allowed_refs
    ]

Expand All @@ -157,10 +190,9 @@ def fieldnames(self, langpair):
:return: a list of field names
"""
self.maybe_download()
meta = self._get_langpair_metadata(langpair)[langpair]
rawfile = os.path.join(self._rawdir, meta[0])
rawfile = self._get_langpair_path(langpair)

with smart_open(rawfile) as fin:
fields = self._unwrap_wmt21_or_later(fin, allowed_refs=self.kwargs.get("refs", []))
fields = self._unwrap_wmt21_or_later(fin)

return list(fields.keys())
38 changes: 37 additions & 1 deletion test/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,40 @@ def test_source_and_references():
"""
for ds in dataset.DATASETS.values():
for pair in ds.langpairs:
assert len(list(ds.source(pair))) == len(list(ds.references(pair)))
src_len = len(list(ds.source(pair)))
ref_len = len(list(ds.references(pair)))
assert src_len == ref_len, f"source/reference failure for {ds.name}:{pair} len(source)={src_len} len(references)={ref_len}"


def test_wmt22_references():
    """
    WMT21 added the ability to specify which reference to use (among many in the XML).
    The default was "A" for everything.
    WMT22 added the ability to override this default on a per-langpair basis, by
    replacing the langpair list of paths with a dict that holds the path (key "path")
    and the annotator override (key "refs").
    """
    wmt22 = dataset.DATASETS["wmt22"]

    # make sure CS-EN exposes all of its reference fields (B and C, but no A)
    cs_en_fields = wmt22.fieldnames("cs-en")
    for ref in ["ref:B", "ref:C"]:
        assert ref in cs_en_fields
    assert "ref:A" not in cs_en_fields

    # make sure ref:B is the one used by default
    assert wmt22._get_langpair_allowed_refs("cs-en") == ["ref:B"]

    # similar check for another dataset: there should be no default ("A"),
    # and the only ref found should be the unannotated one.
    # fieldnames() is called once here and the result reused for both checks.
    liv_en_fields = wmt22.fieldnames("liv-en")
    assert "ref:A" not in liv_en_fields
    assert "ref" in liv_en_fields

    # and that ref:A is the default for all languages where it wasn't overridden
    for langpair, langpair_data in wmt22.langpairs.items():
        if isinstance(langpair_data, dict):
            assert wmt22._get_langpair_allowed_refs(langpair) != ["ref:A"]
        else:
            assert wmt22._get_langpair_allowed_refs(langpair) == ["ref:A"]


0 comments on commit 5166cf7

Please sign in to comment.