Commit e42beeb

update logic

lfoppiano committed Nov 16, 2023
1 parent f2de694
Showing 3 changed files with 135 additions and 193 deletions.
3 changes: 2 additions & 1 deletion requirements.txt
@@ -7,4 +7,5 @@ pandas
 regex
 bump-my-version
 supermat
-pytest
+pytest
+tqdm
323 changes: 132 additions & 191 deletions scripts/xml2csv.py
@@ -3,200 +3,140 @@
 import os
 import sys
 from pathlib import Path
+from typing import List

-from bs4 import BeautifulSoup, Tag
+from tqdm import tqdm

-from src.supermat.supermat_tei_parser import get_nodes
+from src.supermat.supermat_tei_parser import process_file_to_json
 from src.supermat.utils import get_in_paths_from_directory

-paragraph_id = 'paragraph_id'


 def process_file(finput, use_paragraphs=False):
-    filename = Path(finput).name.split(".superconductors")[0]
-    with open(finput, encoding='utf-8') as fp:
-        doc = fp.read()
-
-    soup = BeautifulSoup(doc, 'xml')
-
-    paragraphs_grouped = get_nodes(soup, group_by_paragraph=True, use_paragraphs=use_paragraphs)
-
-    dic_dest_relationships = {}
-    dic_source_relationships = {}
-    ient = 1
-    i = 0
-    for para_id, paragraph in enumerate(paragraphs_grouped):
-        if use_paragraphs:
-            j = 0
-            for item in paragraph.contents:
-                if type(item) is Tag:
-                    if 'type' not in item.attrs:
-                        raise Exception("RS without type is invalid. Stopping")
-                    entity_class = item.attrs['type']
-                    entity_text = item.text
-
-                    if len(item.attrs) > 0:
-                        if 'xml:id' in item.attrs:
-                            if item.attrs['xml:id'] not in dic_dest_relationships:
-                                dic_dest_relationships[item.attrs['xml:id']] = [i + 1, j + 1, ient, entity_text,
-                                                                                entity_class, para_id, 'N/A']
-
-                        if 'corresp' in item.attrs:
-                            if (i + 1, j + 1) not in dic_source_relationships:
-                                dic_source_relationships[i + 1, j + 1] = [item.attrs['corresp'].replace('#', ''), ient,
-                                                                          entity_text, entity_class, para_id, 'N/A']
-                    j += 1
-                ient += 1
-            i += 1
-        else:
-            for sent_id, sentence in enumerate(paragraph):
-                j = 0
-                for item in sentence.contents:
-                    if type(item) is Tag:
-                        if 'type' not in item.attrs:
-                            raise Exception("RS without type is invalid. Stopping")
-                        entity_class = item.attrs['type']
-                        entity_text = item.text
-
-                        if len(item.attrs) > 0:
-                            if 'xml:id' in item.attrs:
-                                if item.attrs['xml:id'] not in dic_dest_relationships:
-                                    dic_dest_relationships[item.attrs['xml:id']] = [i + 1, j + 1, ient, entity_text,
-                                                                                    entity_class, para_id, sent_id]
-
-                            if 'corresp' in item.attrs:
-                                if (i + 1, j + 1) not in dic_source_relationships:
-                                    dic_source_relationships[i + 1, j + 1] = [item.attrs['corresp'].replace('#', ''),
-                                                                              ient,
-                                                                              entity_text, entity_class, para_id,
-                                                                              sent_id]
-                        j += 1
-                    ient += 1
-                i += 1
-
-    output = []
-    output_idx = []
-
-    struct = {
-        'id': None,
-        'filename': None,
-        'paragraph_id': None,
-        'material': None,
-        'tcValue': None,
-        'pressure': None,
-        'me_method': None,
-        'sentence': None
-    }
-    mapping = {}
-
-    for label in list(struct.keys()):
-        if label not in mapping:
-            mapping[label] = {}
-
-    for par_num, token_num in dic_source_relationships:
-        source_item = dic_source_relationships[par_num, token_num]
-        source_entity_id = source_item[1]
-        source_id = str(par_num) + '-' + str(token_num)
-        source_text = source_item[2]
-        source_label = source_item[3]
-
-        # destination_xml_id: Use this to pick up information from dic_dest_relationship
-        destination_xml_id = source_item[0]
-
-        for des in destination_xml_id.split(","):
-            destination_item = dic_dest_relationships[str(des)]
-
-            destination_id = destination_item[2]
-            destination_text = destination_item[3]
-            destination_label = destination_item[4]
-            destination_para = destination_item[5]
-            destination_sent = destination_item[6]
-            if destination_label != label:
-                continue
-
-            # try:
-            #     relationship_name = get_relationship_name(source_label, destination_label)
-            # except Exception as e:
-            #     return []
-
-            if source_label not in mapping:
-                mapping[source_label] = {}
-
-            if destination_id in mapping[destination_label]:
-                indexes_in_output_table = mapping[destination_label][destination_id]
-                for index_in_output_table in indexes_in_output_table:
-                    if source_label in output[index_in_output_table]:
-                        row_copy = {key: value for key, value in output[index_in_output_table].items()}
-                        row_copy[destination_label] = destination_text
-                        row_copy[source_label] = source_text
-                        row_copy['filename'] = filename
-                        row_copy["paragraph_id"] = destination_para
-                        row_copy["sentence_id"] = destination_sent
-                        output.append(row_copy)
-                        # output.append({destination_label: destination_text, source_label: source_text})
-                    else:
-                        output[index_in_output_table][source_label] = source_text
-            elif source_entity_id in mapping[source_label]:
-                indexes_in_output_table = mapping[source_label][source_entity_id]
-                for index_in_output_table in indexes_in_output_table:
-                    if destination_label in output[index_in_output_table]:
-                        # output.append({destination_label: destination_text, source_label: source_text})
-                        # if source_label in output[index_in_output_table]:
-                        #     output.append({destination_label: destination_text, source_label: source_text})
-                        # else:
-                        row_copy = {key: value for key, value in output[index_in_output_table].items()}
-                        row_copy[source_label] = source_text
-                        row_copy[destination_label] = destination_text
-                        row_copy['filename'] = filename
-                        row_copy["paragraph_id"] = destination_para
-                        row_copy["sentence_id"] = destination_sent
-                        output.append(row_copy)
-                    else:
-                        output[index_in_output_table][destination_label] = destination_text
-            else:
-                output.append(
-                    {
-                        destination_label: destination_text,
-                        source_label: source_text,
-                        'filename': filename,
-                        paragraph_id: destination_para,
-                        "sentence_id": destination_sent
-                    }
-                )
-                output_idx.append(
-                    {
-                        destination_label: destination_id,
-                        source_label: source_id,
-                        'filename': filename,
-                        paragraph_id: destination_para,
-                        "sentence_id": destination_sent
-                    }
-                )
-
-                current_index = len(output) - 1
-                if destination_id not in mapping[destination_label]:
-                    mapping[destination_label][destination_id] = set()
-                    mapping[destination_label][destination_id].add(current_index)
-                else:
-                    mapping[destination_label][destination_id].add(current_index)
-
-                if source_entity_id not in mapping[source_label]:
-                    mapping[source_label][source_entity_id] = set()
-                    mapping[source_label][source_entity_id].add(current_index)
-                else:
-                    mapping[source_label][source_entity_id].add(current_index)
-
-    return output
+    json = process_file_to_json(finput, use_paragraphs=use_paragraphs)
+
+    data_list = []
+    spans_map = {}
+    spans_links_map = {}
+    spans_links_reverse_map = {}
+
+    filename = Path(finput).name
+    passages = json['passages']
+    for passage_id, passage in tqdm(enumerate(passages)):
+        passage_common_parts = {
+            'id': passage_id,
+            'filename': filename,
+            'passage_id': str(passage['group_id']) + "|" + str(passage['id']) if 'group_id' in passage else str(
+                passage['id']),
+            'text': passage['text']
+        }
+
+        spans_passage_map, spans_passage_links_map, spans_passage_links_reverse_map = get_span_maps(passage)
+        spans_map = {**spans_map, **spans_passage_map}
+        spans_links_map = {**spans_links_map, **spans_passage_links_map}
+        spans_links_reverse_map = {**spans_links_reverse_map, **spans_passage_links_reverse_map}
+
+        spans_by_type = {}
+        for t in ['tcValue', 'material', 'pressure', 'me_method']:
+            spans_by_type[t] = {key: submap for key, submap in spans_passage_map.items() if
+                                'type' in submap and submap['type'] == t}
+
+        records_in_passage = []
+        for span_id, span in spans_by_type['tcValue'].items():
+            # for span_id, span in spans_by_type[t].items():
+
+            outbound = spans_links_map[span_id] if span_id in spans_links_map.keys() else []
+            inbound = spans_links_reverse_map[span_id] if span_id in spans_links_reverse_map.keys() else []
+
+            for out in list(set(outbound + inbound)):
+                if out in spans_map:
+                    span_out = spans_map[out]
+                    records_to_update = list(filter(
+                        lambda rip: span['type'] in rip and span_id == rip[span['type']][0] and span_out[
+                            'type'] not in rip,
+                        records_in_passage))
+
+                    if len(records_to_update) == 0:
+                        records_in_passage.append({'tcValue': (span_id, spans_map[span_id]['text']),
+                                                   span_out['type']: (span_out['id'], span_out['text'])})
+                    else:
+                        for record_to_update in records_to_update:
+                            record_to_update[span_out['type']] = (span_out['id'], span_out['text'])
+
+        for re in records_in_passage:
+            ids = [re[key][0] for key in re.keys()]
+            is_same_passage = all(id_value in spans_passage_map.keys() for id_value in ids)
+            for key, value in passage_common_parts.items():
+                if not is_same_passage:
+                    if key != "text":
+                        re[key] = value
+                else:
+                    re[key] = value
+
+        data_list.extend(records_in_passage)
+
+    spans_by_type = {}
+    ## Recover possible items not tcValue, that were not linked before
+    for ent_type in ['material', 'pressure', 'me_method']:
+        spans_by_type[ent_type] = {key: submap for key, submap in spans_map.items() if
+                                   'type' in submap and submap['type'] == ent_type}
+        for span_id, span in spans_by_type[ent_type].items():
+            outbound = spans_links_map[span_id] if span_id in spans_links_map.keys() else []
+            inbound = spans_links_reverse_map[span_id] if span_id in spans_links_reverse_map.keys() else []
+
+            for out in list(set(outbound + inbound)):
+                if out in spans_map:
+                    span_out = spans_map[out]
+                    records_to_update = list(filter(
+                        lambda rip: span['type'] not in rip and span_out['type'] in rip and
+                                    span_out['id'] == rip[span_out['type']][0],
+                        data_list))
+
+                    if len(records_to_update) > 0:
+                        for record_to_update in records_to_update:
+                            record_to_update[span['type']] = (span['id'], span['text'])
+                            ## This is incorrect for paragraphs
+                            record_to_update['text'] = ""
+
+    for re in data_list:
+        for column in re.keys():
+            if column not in ['id', 'filename', 'passage_id', 'text']:
+                re[column] = re[column][1]
+
+    return data_list
+
+
+def get_span_maps(passage):
+    span_links_map = {}
+    span_links_reverse_map = {}
+    span_map = {}
+
+    for span in passage['spans']:
+        if 'id' in span:
+            if span['id'] not in span_map:
+                span_map[span['id']] = span
+            if 'links' in span:
+                for link in span['links']:
+                    target_id = link['targetId']
+                    source_id = span['id']
+                    if target_id not in span_links_map:
+                        span_links_map[target_id] = [source_id]
+                    else:
+                        span_links_map[target_id].append(source_id)
+
+                    if source_id not in span_links_reverse_map:
+                        span_links_reverse_map[source_id] = [target_id]
+                    else:
+                        span_links_reverse_map[source_id].append(target_id)
+        else:
+            print("error")
+
+    return span_map, span_links_map, span_links_reverse_map


-def write_output(data, output_path, format, use_paragraphs=False):
-    # passage_id_column_name = 'paragraph_id' if use_paragraphs else 'sentence_id'
-    delimiter = '\t' if format == 'tsv' else ','
-    fw = csv.writer(open(output_path, encoding='utf-8', mode='w'), delimiter=delimiter, quotechar='"')
-    columns = ['id', 'filename', 'paragraph_id', 'sentence_id', 'material', 'tcValue', 'pressure', 'me_method']
-    fw.writerow(columns)
-    for d in data:
-        fw.writerow([d[c] if c in d else '' for c in columns])
+def write_output(output_path: List, data: List[dict], columns: List):
+    with open(output_path, encoding='utf-8', mode='w') as fo:
+        fw = csv.writer(fo, delimiter=",", quotechar='"')
+        fw.writerow(columns)
+        for d in data:
+            fw.writerow([d[c] if c in d else '' for c in columns])


if __name__ == '__main__':
@@ -233,12 +173,12 @@ def write_output(data, output_path, format, use_paragraphs=False):
     if os.path.isdir(input):
         path_list = get_in_paths_from_directory(input, ".xml", recursive=recursive)

-        data_sorted = []
+        output_data = []
         for path in path_list:
             print("Processing: ", path)
-            file_data = process_file(path, use_paragraphs=use_paragraphs)
-            data = sorted(file_data, key=lambda k: k[paragraph_id])
-            data_sorted.extend(data)
+            file_output_data = process_file(path, use_paragraphs=use_paragraphs)
+            # data = sorted(file_data, key=lambda k: k['passage_id'])
+            output_data.extend(file_output_data)

         if os.path.isdir(str(output)):
             output_path = os.path.join(output, "output") + "." + format

@@ -248,14 +188,15 @@ def write_output(data, output_path, format, use_paragraphs=False):

     elif os.path.isfile(input):
         input_path = Path(input)
-        data = process_file(input_path, use_paragraphs=use_paragraphs)
-        data_sorted = sorted(data, key=lambda k: k[paragraph_id])
+        output_data = process_file(input_path, use_paragraphs=use_paragraphs)
+        # data_sorted = sorted(data, key=lambda k: k['paragraph_id'])
         output_filename = input_path.stem
         output_path = os.path.join(output, str(output_filename) + "." + format)
     else:
         print("The input should be either a file or a directory")
         sys.exit(-1)

-    data = [{**record, **{"id": idx}} for idx, record in enumerate(data_sorted)]
+    data = [{**record, **{"id": idx}} for idx, record in enumerate(output_data)]

-    write_output(data, output_path, format)
+    columns = ['id', 'filename', 'passage_id', 'material', 'tcValue', 'pressure', 'me_method', 'text']
+    write_output(output_path, data, columns)
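
Net effect of the refactor above: process_file now delegates TEI parsing to process_file_to_json and links each tcValue span to related material, pressure, and me_method spans through the maps built by get_span_maps. A minimal usage sketch of the new entry points follows; the import path and file names are hypothetical, and it assumes the repository root is on sys.path:

    # Sketch only: 'data/Sample.superconductors.tei.xml' is a placeholder input.
    from pathlib import Path

    from scripts.xml2csv import process_file, write_output

    records = process_file(Path("data/Sample.superconductors.tei.xml"), use_paragraphs=False)
    # Same id enrichment as the __main__ block of the script:
    records = [{**record, "id": idx} for idx, record in enumerate(records)]

    columns = ['id', 'filename', 'passage_id', 'material', 'tcValue', 'pressure', 'me_method', 'text']
    write_output("output.csv", records, columns)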
2 changes: 1 addition & 1 deletion src/supermat/supermat_tei_parser.py
@@ -148,7 +148,7 @@ def process_paragraphs(paragraph_list: list) -> [List, List]:
             if 'corresp' in item.attrs:
                 if 'id' not in span or span['id'] == "":
                     id_str = str(i + 1) + "," + str(j + 1)
-                    span_id = get_hash(id_str)
+                    span['id'] = get_hash(id_str)
                 if span['id'] not in dic_source_relationships:
                     dic_source_relationships[span['id']] = [item.attrs['corresp'].replace('#', ''),
                                                             ient,
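The one-line change in supermat_tei_parser.py fixes an assignment bug: the generated hash was previously bound to a throwaway local, span_id, so span['id'] could still be empty when the dic_source_relationships lookup that follows ran. A minimal illustration of the corrected pattern, with a stand-in for supermat's get_hash:

    import hashlib


    def get_hash_stub(value: str) -> str:
        # Stand-in for supermat's get_hash; only the assignment target matters here.
        return hashlib.md5(value.encode()).hexdigest()


    span = {'id': ''}
    i, j = 0, 0
    id_str = str(i + 1) + "," + str(j + 1)

    # Before the fix: span_id = get_hash_stub(id_str) left span['id'] empty.
    # After the fix, the generated id is stored on the span itself:
    span['id'] = get_hash_stub(id_str)
    assert span['id'] != ''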
