diff --git a/chatterbot/trainers.py b/chatterbot/trainers.py
index 18cdf5ff5..48d2ab7ce 100644
--- a/chatterbot/trainers.py
+++ b/chatterbot/trainers.py
@@ -2,7 +2,6 @@
 import sys
 import csv
 import time
-from multiprocessing import Pool, Manager
 from dateutil import parser as date_parser
 from chatterbot.conversation import Statement
 from chatterbot.tagging import PosLemmaTagger
@@ -174,41 +173,6 @@ def train(self, *corpus_paths):
         self.chatbot.storage.create_many(statements_to_create)
 
 
-def read_file(files, queue, preprocessors, tagger):
-
-    statements_from_file = []
-
-    for tsv_file in files:
-        with open(tsv_file, 'r', encoding='utf-8') as tsv:
-            reader = csv.reader(tsv, delimiter='\t')
-
-            previous_statement_text = None
-            previous_statement_search_text = ''
-
-            for row in reader:
-                if len(row) > 0:
-                    statement = Statement(
-                        text=row[3],
-                        in_response_to=previous_statement_text,
-                        conversation='training',
-                        created_at=date_parser.parse(row[0]),
-                        persona=row[1]
-                    )
-
-                    for preprocessor in preprocessors:
-                        statement = preprocessor(statement)
-
-                    statement.search_text = tagger.get_bigram_pair_string(statement.text)
-                    statement.search_in_response_to = previous_statement_search_text
-
-                    previous_statement_text = statement.text
-                    previous_statement_search_text = statement.search_text
-
-                    statements_from_file.append(statement)
-
-    queue.put(tuple(statements_from_file))
-
-
 class UbuntuCorpusTrainer(Trainer):
     """
     Allow chatbots to be trained with the data from the Ubuntu Dialog Corpus.
@@ -337,9 +301,6 @@ def train(self):
             '**', '**', '*.tsv'
         )
 
-        manager = Manager()
-        queue = manager.Queue()
-
         def chunks(items, items_per_chunk):
             for start_index in range(0, len(items), items_per_chunk):
                 end_index = start_index + items_per_chunk
@@ -349,55 +310,40 @@ def chunks(items, items_per_chunk):
 
         file_groups = tuple(chunks(file_list, 10000))
 
-        argument_groups = tuple(
-            (
-                file_names,
-                queue,
-                self.chatbot.preprocessors,
-                tagger,
-            ) for file_names in file_groups
-        )
-
-        pool_batches = chunks(argument_groups, 9)
-
-        total_batches = len(file_groups)
-        batch_number = 0
-
         start_time = time.time()
 
-        with Pool() as pool:
-            for pool_batch in pool_batches:
-                pool.starmap(read_file, pool_batch)
+        for tsv_files in file_groups:
+
+            statements_from_file = []
 
-                while True:
+            for tsv_file in tsv_files:
+                with open(tsv_file, 'r', encoding='utf-8') as tsv:
+                    reader = csv.reader(tsv, delimiter='\t')
 
-                    if queue.empty():
-                        break
+                    previous_statement_text = None
+                    previous_statement_search_text = ''
 
-                    batch_number += 1
+                    for row in reader:
+                        if len(row) > 0:
+                            statement = Statement(
+                                text=row[3],
+                                in_response_to=previous_statement_text,
+                                conversation='training',
+                                created_at=date_parser.parse(row[0]),
+                                persona=row[1]
+                            )
 
-                    print('Training with batch {} with {} batches remaining...'.format(
-                        batch_number,
-                        total_batches - batch_number
-                    ))
+                            for preprocessor in self.chatbot.preprocessors:
+                                statement = preprocessor(statement)
 
-                    self.chatbot.storage.create_many(queue.get())
+                            statement.search_text = tagger.get_bigram_pair_string(statement.text)
+                            statement.search_in_response_to = previous_statement_search_text
 
-                    elapsed_time = time.time() - start_time
-                    time_per_batch = elapsed_time / batch_number
-                    remaining_time = time_per_batch * (total_batches - batch_number)
+                            previous_statement_text = statement.text
+                            previous_statement_search_text = statement.search_text
 
-                    print('{:.0f} hours {:.0f} minutes {:.0f} seconds elapsed.'.format(
-                        elapsed_time // 3600 % 24,
-                        elapsed_time // 60 % 60,
-                        elapsed_time % 60
-                    ))
+                            statements_from_file.append(statement)
 
-                    print('{:.0f} hours {:.0f} minutes {:.0f} seconds remaining.'.format(
-                        remaining_time // 3600 % 24,
-                        remaining_time // 60 % 60,
-                        remaining_time % 60
-                    ))
-                    print('---')
+            self.chatbot.storage.create_many(statements_from_file)
 
         print('Training took', time.time() - start_time, 'seconds.')
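
After this change, UbuntuCorpusTrainer.train() reads the extracted *.tsv files serially in groups of 10,000 and bulk-inserts each group with storage.create_many(), instead of fanning the work out to a multiprocessing Pool and draining a shared Manager queue. A minimal usage sketch of the refactored trainer follows; the bot name and the idea of calling it from a standalone script are assumptions for illustration, not part of this diff:

    # Sketch: training a bot with the serial UbuntuCorpusTrainer.
    # Assumes the ChatterBot package layout otherwise unchanged by this diff.
    from chatterbot import ChatBot
    from chatterbot.trainers import UbuntuCorpusTrainer

    bot = ChatBot('example')  # hypothetical bot name
    trainer = UbuntuCorpusTrainer(bot)

    # Iterates the extracted *.tsv files in 10,000-file groups,
    # calling storage.create_many() once per group (per this diff).
    trainer.train()

Note that dropping the Pool/Manager machinery also removes the per-batch progress and ETA printing; the only remaining timing output is the total elapsed time printed at the end of train().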