Skip to content

Commit

Permalink
allow --detail and --subset to be used also with the new XML test…
Browse files Browse the repository at this point in the history
… sets
  • Loading branch information
martinpopel authored and jkawamoto committed Jul 16, 2024
1 parent f434710 commit 163b594
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 48 deletions.
29 changes: 13 additions & 16 deletions sacrebleu/dataset/wmt_xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@ def _unwrap_wmt21_or_later(raw_file):
This script is adapted from https://github.com/wmt-conference/wmt-format-tools
:param raw_file: The raw xml file to unwrap.
:return: Dictionary which contains the following fields:
:return: Dictionary which contains the following fields
(each a list with values for each sentence):
- `src`: The source sentences.
- `docid`: ID indicating which document the sentences belong to.
- `origlang`: The original language of the document.
- `domain`: Domain of the document.
- `ref:{translator}`: The references produced by each translator.
- `ref`: An alias for the references from the first translator.
"""
Expand Down Expand Up @@ -60,13 +62,8 @@ def _unwrap_wmt21_or_later(raw_file):

systems = defaultdict(list)

src_sent_count, doc_count = 0, 0
src_sent_count, doc_count, seen_domain = 0, 0, False
for doc in tree.getroot().findall(".//doc"):
docid = doc.attrib["id"]
origlang = doc.attrib["origlang"]
# present wmt22++
domain = doc.attrib.get("domain", None)

# Skip the testsuite
if "testsuite" in doc.attrib:
continue
Expand Down Expand Up @@ -104,17 +101,17 @@ def get_sents(doc):
src.append(src_sents[seg_id])
for system_name in hyps.keys():
systems[system_name].append(hyps[system_name][seg_id])
docids.append(docid)
orig_langs.append(origlang)
if domain is not None:
domains.append(domain)
docids.append(doc.attrib["id"])
orig_langs.append(doc.attrib["origlang"])
# The "domain" attribute is missing in WMT21 and WMT22
domains.append(doc.get("domain"))
seen_domain = doc.get("domain") is not None
src_sent_count += 1

data = {"src": src, **refs, "docid": docids, "origlang": orig_langs, **systems}
if len(domains):
data["domain"] = domains

return data
fields = {"src": src, **refs, "docid": docids, "origlang": orig_langs, **systems}
if seen_domain:
fields["domain"] = domains
return fields

def _get_langpair_path(self, langpair):
"""
Expand Down
102 changes: 70 additions & 32 deletions sacrebleu/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,14 +488,18 @@ def get_available_testsets_for_langpair(langpair: str) -> List[str]:


def get_available_origlangs(test_sets, langpair) -> List[str]:
"""Return a list of origlang values in according to the raw SGM files."""
"""Return a list of origlang values according to the raw XML/SGM files."""
if test_sets is None:
return []

origlangs = set()
for test_set in test_sets.split(','):
dataset = DATASETS[test_set]
rawfile = os.path.join(SACREBLEU_DIR, test_set, 'raw', dataset.langpairs[langpair][0])
from .dataset.wmt_xml import WMTXMLDataset
if isinstance(dataset, WMTXMLDataset):
for origlang in dataset._unwrap_wmt21_or_later(rawfile)['origlang']:
origlangs.add(origlang)
if rawfile.endswith('.sgm'):
with smart_open(rawfile) as fin:
for line in fin:
Expand All @@ -505,6 +509,25 @@ def get_available_origlangs(test_sets, langpair) -> List[str]:
return sorted(list(origlangs))


def get_available_subsets(test_sets, langpair) -> List[str]:
"""Return a list of domain values according to the raw XML files and domain/country values from the SGM files."""
if test_sets is None:
return []

subsets = set()
for test_set in test_sets.split(','):
dataset = DATASETS[test_set]
from .dataset.wmt_xml import WMTXMLDataset
if isinstance(dataset, WMTXMLDataset):
rawfile = os.path.join(SACREBLEU_DIR, test_set, 'raw', dataset.langpairs[langpair][0])
fields = dataset._unwrap_wmt21_or_later(rawfile)
if 'domain' in fields:
subsets |= set(fields['domain'])
elif test_set in SUBSETS:
subsets |= set("country:" + v.split("-")[0] for v in SUBSETS[test_set].values())
subsets |= set(v.split("-")[1] for v in SUBSETS[test_set].values())
return sorted(list(subsets))

def filter_subset(systems, test_sets, langpair, origlang, subset=None):
"""Filter sentences with a given origlang (or subset) according to the raw SGM files."""
if origlang is None and subset is None:
Expand All @@ -516,37 +539,49 @@ def filter_subset(systems, test_sets, langpair, origlang, subset=None):
re_id = re.compile(r'.* docid="([^"]+)".*\n')

indices_to_keep = []

for test_set in test_sets.split(','):
dataset = DATASETS[test_set]
rawfile = os.path.join(SACREBLEU_DIR, test_set, 'raw', dataset.langpairs[langpair][0])
if not rawfile.endswith('.sgm'):
raise Exception(f'--origlang and --subset supports only *.sgm files, not {rawfile!r}')
if subset is not None:
if test_set not in SUBSETS:
raise Exception('No subset annotation available for test set ' + test_set)
doc_to_tags = SUBSETS[test_set]
number_sentences_included = 0
with smart_open(rawfile) as fin:
include_doc = False
for line in fin:
if line.startswith('<doc '):
if origlang is None:
include_doc = True
from .dataset.wmt_xml import WMTXMLDataset
if isinstance(dataset, WMTXMLDataset):
fields = dataset._unwrap_wmt21_or_later(rawfile)
for doc_origlang, doc_domain in zip(fields['origlang'], fields['domain']):
if origlang is None:
include_doc = True
else:
if origlang.startswith('non-'):
include_doc = doc_origlang != origlang[4:]
else:
doc_origlang = re_origlang.sub(r'\1', line)
if origlang.startswith('non-'):
include_doc = doc_origlang != origlang[4:]
include_doc = doc_origlang == origlang
if subset is not None and (doc_domain is None or not re.search(subset, doc_domain)):
include_doc = False
indices_to_keep.append(include_doc)
elif rawfile.endswith('.sgm'):
if subset is not None:
if test_set not in SUBSETS:
raise Exception('No subset annotation available for test set ' + test_set)
doc_to_tags = SUBSETS[test_set]
with smart_open(rawfile) as fin:
include_doc = False
for line in fin:
if line.startswith('<doc '):
if origlang is None:
include_doc = True
else:
include_doc = doc_origlang == origlang

if subset is not None:
doc_id = re_id.sub(r'\1', line)
if not re.search(subset, doc_to_tags.get(doc_id, '')):
include_doc = False
if line.startswith('<seg '):
indices_to_keep.append(include_doc)
number_sentences_included += 1 if include_doc else 0
doc_origlang = re_origlang.sub(r'\1', line)
if origlang.startswith('non-'):
include_doc = doc_origlang != origlang[4:]
else:
include_doc = doc_origlang == origlang

if subset is not None:
doc_id = re_id.sub(r'\1', line)
if not re.search(subset, doc_to_tags.get(doc_id, '')):
include_doc = False
if line.startswith('<seg '):
indices_to_keep.append(include_doc)
else:
raise Exception(f'--origlang and --subset supports only WMT *.xml and *.sgm files, not {rawfile!r}')
return [[sentence for sentence, keep in zip(sys, indices_to_keep) if keep] for sys in systems]


Expand All @@ -565,8 +600,9 @@ def print_subset_results(metrics, full_system, full_refs, args):
subsets = [None]
if args.subset is not None:
subsets += [args.subset]
elif all(t in SUBSETS for t in args.test_set.split(',')):
subsets += COUNTRIES + DOMAINS
else:
subsets += get_available_subsets(args.test_set, args.langpair)

for subset in subsets:
system, *refs = filter_subset(
[full_system, *full_refs], args.test_set, args.langpair, origlang, subset)
Expand All @@ -575,9 +611,11 @@ def print_subset_results(metrics, full_system, full_refs, args):
continue

key = f'origlang={origlang}'
if subset in COUNTRIES:
key += f' country={subset}'
elif subset in DOMAINS:
if subset is None:
key += f' domain=ALL'
elif subset.startswith('country:'):
key += f' country={subset[8:]}'
else:
key += f' domain={subset}'

for metric in metrics.values():
Expand Down

0 comments on commit 163b594

Please sign in to comment.