Skip to content

Commit

Permalink
Show domain where present (WMT22+)
Browse files Browse the repository at this point in the history
  • Loading branch information
mjpost committed Apr 12, 2024
1 parent 78e25e2 commit 1f70cd2
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion sacrebleu/dataset/wmt_xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def _unwrap_wmt21_or_later(raw_file):
src = []
docids = []
orig_langs = []
domains = []

refs = { _get_field_by_translator(translator): [] for translator in translators }

Expand All @@ -63,6 +64,8 @@ def _unwrap_wmt21_or_later(raw_file):
for doc in tree.getroot().findall(".//doc"):
docid = doc.attrib["id"]
origlang = doc.attrib["origlang"]
# The "domain" attribute is only present in WMT22 and later; None otherwise
domain = doc.attrib.get("domain", None)

# Skip the testsuite
if "testsuite" in doc.attrib:
Expand Down Expand Up @@ -103,9 +106,15 @@ def get_sents(doc):
systems[system_name].append(hyps[system_name][seg_id])
docids.append(docid)
orig_langs.append(origlang)
if domain is not None:
domains.append(domain)
src_sent_count += 1

return {"src": src, **refs, "docid": docids, "origlang": orig_langs, **systems}
data = {"src": src, **refs, "docid": docids, "origlang": orig_langs, **systems}
if len(domains):
data["domain"] = domains

return data

def _get_langpair_path(self, langpair):
"""
Expand Down

0 comments on commit 1f70cd2

Please sign in to comment.