diff --git a/sacrebleu/dataset/wmt_xml.py b/sacrebleu/dataset/wmt_xml.py index d5eb5d8..4f78bcc 100644 --- a/sacrebleu/dataset/wmt_xml.py +++ b/sacrebleu/dataset/wmt_xml.py @@ -54,6 +54,7 @@ def _unwrap_wmt21_or_later(raw_file): src = [] docids = [] orig_langs = [] + domains = [] refs = { _get_field_by_translator(translator): [] for translator in translators } @@ -63,6 +64,8 @@ def _unwrap_wmt21_or_later(raw_file): for doc in tree.getroot().findall(".//doc"): docid = doc.attrib["id"] origlang = doc.attrib["origlang"] + # The "domain" attribute is only present in WMT22 and later test sets; None when absent. + domain = doc.attrib.get("domain", None) # Skip the testsuite if "testsuite" in doc.attrib: @@ -103,9 +106,15 @@ def get_sents(doc): systems[system_name].append(hyps[system_name][seg_id]) docids.append(docid) orig_langs.append(origlang) + if domain is not None: + domains.append(domain) src_sent_count += 1 - return {"src": src, **refs, "docid": docids, "origlang": orig_langs, **systems} + data = {"src": src, **refs, "docid": docids, "origlang": orig_langs, **systems} + if len(domains): + data["domain"] = domains + + return data def _get_langpair_path(self, langpair): """