diff --git a/ingredient_extraction/train/train.py b/ingredient_extraction/train/train.py index be63527e..dfff84ba 100644 --- a/ingredient_extraction/train/train.py +++ b/ingredient_extraction/train/train.py @@ -3,6 +3,7 @@ import html import os from pathlib import Path +from typing import Dict, List import evaluate import numpy as np @@ -30,7 +31,7 @@ label_list = list(id2label.values()) -def convert_pipeline_output_to_html(text: str, output: list[dict]): +def convert_pipeline_output_to_html(text: str, output: List[dict]): html_str = "" previous_idx = 0 for item in output: @@ -157,7 +158,7 @@ def tokenize_and_align_labels(examples, tokenizer): def display_labeled_sequence( - tokens: list[str], labels: list[int], id2label: dict[int, str] + tokens: List[str], labels: List[int], id2label: Dict[int, str] ): assert len(tokens) == len(labels) output = []