Skip to content

Commit

Permalink
add more test
Browse files Browse the repository at this point in the history
  • Loading branch information
lfoppiano committed Nov 15, 2023
1 parent 49c3a39 commit f2de694
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 21 deletions.
39 changes: 19 additions & 20 deletions scripts/xml2csv_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import csv
import os
from pathlib import Path
from typing import List

from supermat.supermat_tei_parser import process_file_to_json

Expand All @@ -11,7 +10,7 @@
paragraph_id = 'paragraph_id'


def write_output(data, output_path, format, header):
def write_output(output_path, data, header, format="csv"):
delimiter = '\t' if format == 'tsv' else ','
fw = csv.writer(open(output_path, encoding='utf-8', mode='w'), delimiter=delimiter, quotechar='"')
fw.writerow(header)
Expand Down Expand Up @@ -41,7 +40,7 @@ def get_texts(data_sorted):

if __name__ == '__main__':
parser = argparse.ArgumentParser(
description="Converter XML (Supermat) to a tabular values (CSV, TSV) for entity extraction (no relation information are used)")
description="Converter XML (Supermat) to CSV for entity extraction (no relation information are used)")

parser.add_argument("--input",
help="Input file or directory",
Expand All @@ -53,21 +52,22 @@ def get_texts(data_sorted):
action="store_true",
default=False,
help="Process input directory recursively. If input is a file, this parameter is ignored.")
parser.add_argument("--format",
default='csv',
choices=['tsv', 'csv'],
help="Output format.")
parser.add_argument("--entity-type",
default="material",
required=False,
help="Select which entity type to extract.")
parser.add_argument("--use-paragraphs",
default=False,
action="store_true",
help="Uses paragraphs instead of sentences. By default this script assumes that the XML is at sentence level.")
help="Uses paragraphs instead of sentences. "
"By default this script assumes that the XML is at sentence level.")

args = parser.parse_args()

input = args.input
output = args.output
recursive = args.recursive
format = args.format
ent_type = args.entity_type
use_paragraphs = args.use_paragraphs

if os.path.isdir(input):
Expand All @@ -87,43 +87,42 @@ def get_texts(data_sorted):
texts_data.extend(text_data)

if os.path.isdir(str(output)):
output_path_text = os.path.join(output, "output-text") + "." + format
output_path_expected = os.path.join(output, "output-" + ent_type) + "." + format
output_path_text = os.path.join(output, "output-text") + ".csv"
output_path_expected = os.path.join(output, "output-" + ent_type) + ".csv"
else:
parent_dir = Path(output).parent
output_path_text = os.path.join(parent_dir, "output-text" + "." + format)
output_path_expected = os.path.join(parent_dir, "output-" + ent_type + "." + format)
output_path_text = os.path.join(parent_dir, "output-text" + ".csv")
output_path_expected = os.path.join(parent_dir, "output-" + ent_type + ".csv")

header = ["id", "filename", "pid", ent_type]

for idx, data in enumerate(entities_data):
data[0] = idx

write_output(entities_data, output_path_expected, format, header)
write_output(output_path_expected, entities_data, header)

header = ["id", "filename", "pid", "text"]
for idx, data in enumerate(texts_data):
data[0] = idx
write_output(texts_data, output_path_text, format, header)
write_output(output_path_text, texts_data, header)

elif os.path.isfile(input):
input_path = Path(input)
file_data = process_file_to_json(input_path, use_paragraphs=use_paragraphs)
output_filename = input_path.stem

output_path_text = os.path.join(output, str(output_filename) + "-text" + "." + format)
output_path_text = os.path.join(output, str(output_filename) + "-text" + ".csv")
texts_data = get_texts(file_data)
for idx, data in enumerate(texts_data):
data[0] = idx

header = ["id", "filename", "pid", "text"]
write_output(texts_data, output_path_text, format, header)
write_output(output_path_text, texts_data, header)

ent_type = "material"
output_path_expected = os.path.join(output, str(output_filename) + "-" + ent_type + "." + format)
output_path_expected = os.path.join(output, str(output_filename) + "-" + ent_type + ".csv")
ent_data_no_duplicates = get_entity_data(file_data, ent_type)
for idx, data in enumerate(ent_data_no_duplicates):
data[0] = idx

header = ["id", "filename", "pid", ent_type]
write_output(ent_data_no_duplicates, output_path_expected, format, header)
write_output(output_path_expected, ent_data_no_duplicates, header)
18 changes: 18 additions & 0 deletions tests/test_data/kotesample.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
<?xml version="1.0" encoding="UTF-8"?>
<tei xmlns="http://www.tei-c.org/ns/1.0">
<text xml:lang="en">
<body>
<p>
<s>We used Daphne 7474 for the <rs xml:id="x10000" type="me_method">resistivity</rs> measurement as a pressure-transmitting medium. 13) Applied pressure was estimated from the <rs type="tc">T c</rs> of the <rs type="material">lead</rs> manometer.</s>
</p>
<p>
<s>a)-3(c) show the temperature dependences of ρ for <rs type="material" xml:id="x92">Sr 2 ScFePO 3</rs> at high pressures up to 4.17 GPa.</s>
<s>At <rs corresp="#x95,#x97" type="pressure">ambient pressure</rs>, we estimated <rs type="tc">T onset c</rs> = <rs corresp="#x92,#x11-164" type="tcValue" xml:id="x95">15.6 K</rs> and <rs type="tc">T zero c</rs> = <rs corresp="#x92,#x11-164" type="tcValue" xml:id="x97">8.4 K</rs>.</s>
<s>This system exhibits a large supercon- ducting transition width probably owing to the low bulk density of the sample.</s>
<s>We should note that a clear <rs xml:id="x10001" type="me_method">dia- magnetic signal</rs> is observed below <rs corresp="#x92,#x10000" type="tcValue" xml:id="x99">15 − 17 K</rs>. 9) The tran- sition width was large in run 1; it was sharp in run 2 using a different part of the same sample.</s>
<s>Since the ab- solute value of ρ changed unnaturally in run 2, ρ is nor- malized at 20 K.</s>
<s>As seen in the figure, the <rs type="tc">T c</rs> of this compound decreases significantly under pressure, and the zero-<rs type="me_method" xml:id="x11-164">resistance</rs> state goes beyond the observed tempera- ture range above 4 GPa.</s>
</p>
</body>
</text>
</tei>
25 changes: 24 additions & 1 deletion tests/test_supermat_tei_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import bs4
from bs4 import BeautifulSoup

from src.supermat.supermat_tei_parser import get_nodes, process_paragraphs
from src.supermat.supermat_tei_parser import get_nodes, process_paragraphs, process_file_to_json


def test_get_sentences_nodes_input_with_sentences_grouped():
Expand Down Expand Up @@ -128,3 +128,26 @@ def test_process_paragraphs_input_sentences_2():

second = passages[1]
assert len(second['spans']) == 6


def test_process_file():
file_path = os.path.join(os.path.dirname(__file__), "test_data", 'kotesample.xml')

json = process_file_to_json(file_path)

span_map = {}

passages = json['passages']
for passage in passages:
for span in passage['spans']:
if 'id' in span:
if span['id'] not in span_map:
span_map[span['id']] = span
else:
print("error")

assert len(span_map.keys()) == 8

relations = json['relations']

assert len(relations) == 8

0 comments on commit f2de694

Please sign in to comment.