From 163b5941040510a7c3dfd54698629cdd6317a515 Mon Sep 17 00:00:00 2001
From: Martin Popel
Date: Sun, 24 Mar 2024 22:08:14 +0100
Subject: [PATCH] allow `--detail` and `--subset` to be used also with the new
 XML test sets

---
 sacrebleu/dataset/wmt_xml.py |  29 +++++-----
 sacrebleu/utils.py           | 102 ++++++++++++++++++++++++-----------
 2 files changed, 83 insertions(+), 48 deletions(-)

diff --git a/sacrebleu/dataset/wmt_xml.py b/sacrebleu/dataset/wmt_xml.py
index 4f78bcc..1aedf1d 100644
--- a/sacrebleu/dataset/wmt_xml.py
+++ b/sacrebleu/dataset/wmt_xml.py
@@ -26,10 +26,12 @@ def _unwrap_wmt21_or_later(raw_file):
         This script is adapted from https://github.com/wmt-conference/wmt-format-tools

         :param raw_file: The raw xml file to unwrap.
-        :return: Dictionary which contains the following fields:
+        :return: Dictionary which contains the following fields
+            (each a list with values for each sentence):
             - `src`: The source sentences.
             - `docid`: ID indicating which document the sentences belong to.
             - `origlang`: The original language of the document.
+            - `domain`: Domain of the document.
             - `ref:{translator}`: The references produced by each translator.
             - `ref`: An alias for the references from the first translator.
         """
@@ -60,13 +62,8 @@ def _unwrap_wmt21_or_later(raw_file):

         systems = defaultdict(list)

-        src_sent_count, doc_count = 0, 0
+        src_sent_count, doc_count, seen_domain = 0, 0, False
         for doc in tree.getroot().findall(".//doc"):
-            docid = doc.attrib["id"]
-            origlang = doc.attrib["origlang"]
-            # present wmt22++
-            domain = doc.attrib.get("domain", None)
-
             # Skip the testsuite
             if "testsuite" in doc.attrib:
                 continue
@@ -104,17 +101,17 @@ def get_sents(doc):
                 src.append(src_sents[seg_id])
                 for system_name in hyps.keys():
                     systems[system_name].append(hyps[system_name][seg_id])
-                docids.append(docid)
-                orig_langs.append(origlang)
-                if domain is not None:
-                    domains.append(domain)
+                docids.append(doc.attrib["id"])
+                orig_langs.append(doc.attrib["origlang"])
+                # The "domain" attribute is present in WMT22 and later, but missing in WMT21
+                domains.append(doc.get("domain"))
+                seen_domain = doc.get("domain") is not None
                 src_sent_count += 1

-        data = {"src": src, **refs, "docid": docids, "origlang": orig_langs, **systems}
-        if len(domains):
-            data["domain"] = domains
-
-        return data
+        fields = {"src": src, **refs, "docid": docids, "origlang": orig_langs, **systems}
+        if seen_domain:
+            fields["domain"] = domains
+        return fields

     def _get_langpair_path(self, langpair):
         """
diff --git a/sacrebleu/utils.py b/sacrebleu/utils.py
index 56e6fca..cddc3a7 100644
--- a/sacrebleu/utils.py
+++ b/sacrebleu/utils.py
@@ -488,7 +488,7 @@ def get_available_testsets_for_langpair(langpair: str) -> List[str]:


 def get_available_origlangs(test_sets, langpair) -> List[str]:
-    """Return a list of origlang values in according to the raw SGM files."""
+    """Return a list of origlang values according to the raw XML/SGM files."""
     if test_sets is None:
         return []

@@ -496,6 +496,10 @@ def get_available_origlangs(test_sets, langpair) -> List[str]:
     for test_set in test_sets.split(','):
         dataset = DATASETS[test_set]
         rawfile = os.path.join(SACREBLEU_DIR, test_set, 'raw', dataset.langpairs[langpair][0])
+        from .dataset.wmt_xml import WMTXMLDataset
+        if isinstance(dataset, WMTXMLDataset):
+            for origlang in dataset._unwrap_wmt21_or_later(rawfile)['origlang']:
+                origlangs.add(origlang)
         if rawfile.endswith('.sgm'):
             with smart_open(rawfile) as fin:
                 for line in fin:
@@ -505,6 +509,25 @@ def get_available_origlangs(test_sets, langpair) -> List[str]:
     return sorted(list(origlangs))


+def get_available_subsets(test_sets, langpair) -> List[str]:
+    """Return a list of domain values according to the raw XML files and domain/country values from the SGM files."""
+    if test_sets is None:
+        return []
+
+    subsets = set()
+    for test_set in test_sets.split(','):
+        dataset = DATASETS[test_set]
+        from .dataset.wmt_xml import WMTXMLDataset
+        if isinstance(dataset, WMTXMLDataset):
+            rawfile = os.path.join(SACREBLEU_DIR, test_set, 'raw', dataset.langpairs[langpair][0])
+            fields = dataset._unwrap_wmt21_or_later(rawfile)
+            if 'domain' in fields:
+                subsets |= set(fields['domain'])
+        elif test_set in SUBSETS:
+            subsets |= set("country:" + v.split("-")[0] for v in SUBSETS[test_set].values())
+            subsets |= set(v.split("-")[1] for v in SUBSETS[test_set].values())
+    return sorted(list(subsets))
+
 def filter_subset(systems, test_sets, langpair, origlang, subset=None):
     """Filter sentences with a given origlang (or subset) according to the raw SGM files."""
     if origlang is None and subset is None:
@@ -516,37 +539,49 @@ def filter_subset(systems, test_sets, langpair, origlang, subset=None):
     re_id = re.compile(r'.* docid="([^"]+)".*\n')

     indices_to_keep = []
-
     for test_set in test_sets.split(','):
         dataset = DATASETS[test_set]
         rawfile = os.path.join(SACREBLEU_DIR, test_set, 'raw', dataset.langpairs[langpair][0])
-        if not rawfile.endswith('.sgm'):
-            raise Exception(f'--origlang and --subset supports only *.sgm files, not {rawfile!r}')
-        if subset is not None:
-            if test_set not in SUBSETS:
-                raise Exception('No subset annotation available for test set ' + test_set)
-            doc_to_tags = SUBSETS[test_set]
-        number_sentences_included = 0
-        with smart_open(rawfile) as fin:
-            include_doc = False
-            for line in fin:
-                if line.startswith('