diff --git a/interest/utils.py b/interest/utils.py index 4f9a12e..c601162 100644 --- a/interest/utils.py +++ b/interest/utils.py @@ -156,3 +156,19 @@ def save_filtered_articles(input_file: Any, article_id: str, print('output_fp', output_fp) with open(output_fp, "w", encoding=ENCODING) as json_file: json.dump(data, json_file, indent=4) + + +def get_file_name_without_extension(full_path: str) -> str: + """ + Extracts the file name without extension from a full path. + + Args: + full_path (str): The full path of the file. + + Returns: + str: The file name without extension. + + """ + base_name = os.path.basename(full_path) + file_name_without_ext = os.path.splitext(base_name)[0] + return file_name_without_ext diff --git a/scripts/step4_generate_output.py b/scripts/step4_generate_output.py index b71c94c..ad89f61 100644 --- a/scripts/step4_generate_output.py +++ b/scripts/step4_generate_output.py @@ -9,14 +9,14 @@ from pandas import DataFrame from interest.settings import SPACY_MODEL from interest.article_final_selection.process_article import ArticleProcessor -from interest.utils import read_config +from interest.utils import read_config, get_file_name_without_extension from interest.output_generator.text_formater import (TextFormatter, SEGMENTED_TEXT_FORMATTER) FILE_PATH_FIELD = "file_path" -ARTICLE_ID_FIELD = "article_id" TITLE_FIELD = "title" +ARTICLE_ID_FIELD = "article_id" BODY_FIELD = "body" LABEL_FIELD = "label" SELECTED_FIELD = "selected" @@ -130,14 +130,11 @@ def find_articles_in_file(filepath: str, formatter: TextFormatter) -> ( SENTENCES_PER_SEGMENT = str(read_config(args.config_path, SENTENCE_PER_SEGMENT_KEY)) - result_df = pd.DataFrame(columns=[FILE_PATH_FIELD, ARTICLE_ID_FIELD, - TITLE_FIELD, BODY_FIELD, LABEL_FIELD]) - text_formatter = TextFormatter(str(output_unit), int(SENTENCES_PER_SEGMENT), spacy_model=SPACY_MODEL) for articles_filepath in args.input_dir.rglob(args.glob): df = find_articles_in_file(articles_filepath, text_formatter) - result_df = pd.concat([result_df, df], ignore_index=True) + file_name = get_file_name_without_extension(articles_filepath) + df.to_csv(os.path.join(args.output_dir, 'articles_to_label_'+file_name+'.csv'), index = False) - result_df.to_csv(os.path.join(args.output_dir, 'articles_to_label.csv'))