diff --git a/chatterbot/chatterbot.py b/chatterbot/chatterbot.py index 054309543..cec453386 100644 --- a/chatterbot/chatterbot.py +++ b/chatterbot/chatterbot.py @@ -13,12 +13,6 @@ class ChatBot(object): def __init__(self, name, **kwargs): self.name = name - primary_search_algorithm = IndexedTextSearch(self, **kwargs) - - self.search_algorithms = { - primary_search_algorithm.name: primary_search_algorithm - } - storage_adapter = kwargs.get('storage_adapter', 'chatterbot.storage.SQLStorageAdapter') logic_adapters = kwargs.get('logic_adapters', [ @@ -33,6 +27,12 @@ def __init__(self, name, **kwargs): self.storage = utils.initialize_class(storage_adapter, **kwargs) + primary_search_algorithm = IndexedTextSearch(self, **kwargs) + + self.search_algorithms = { + primary_search_algorithm.name: primary_search_algorithm + } + for adapter in logic_adapters: utils.validate_adapter_class(adapter, LogicAdapter) logic_adapter = utils.initialize_class(adapter, self, **kwargs) diff --git a/chatterbot/comparisons.py b/chatterbot/comparisons.py index 206403fb9..978f76354 100644 --- a/chatterbot/comparisons.py +++ b/chatterbot/comparisons.py @@ -2,13 +2,16 @@ This module contains various text-comparison algorithms designed to compare one statement to another. """ -from chatterbot import languages from difflib import SequenceMatcher import spacy class Comparator: + def __init__(self, language): + + self.language = language + def __call__(self, statement_a, statement_b): return self.compare(statement_a, statement_b) @@ -59,10 +62,8 @@ class SpacySimilarity(Comparator): Calculate the similarity of two statements using Spacy models. """ - def __init__(self): - super().__init__() - - self.language = languages.ENG + def __init__(self, language): + super().__init__(language) self.nlp = spacy.load(self.language.ISO_639_1) @@ -105,10 +106,8 @@ class JaccardSimilarity(Comparator): .. 
_`Jaccard similarity index`: https://en.wikipedia.org/wiki/Jaccard_index """ - def __init__(self): - super().__init__() - - self.language = languages.ENG + def __init__(self, language): + super().__init__(language) self.nlp = spacy.load(self.language.ISO_639_1) @@ -134,11 +133,3 @@ def compare(self, statement_a, statement_b): ratio = numerator / denominator return ratio - - -# ---------------------------------------- # - - -levenshtein_distance = LevenshteinDistance() -spacy_similarity = SpacySimilarity() -jaccard_similarity = JaccardSimilarity() diff --git a/chatterbot/search.py b/chatterbot/search.py index bd946f4b1..f3af44b7c 100644 --- a/chatterbot/search.py +++ b/chatterbot/search.py @@ -5,7 +5,7 @@ class IndexedTextSearch: """ :param statement_comparison_function: The dot-notated import path to a statement comparison function. - Defaults to ``levenshtein_distance``. + Defaults to ``LevenshteinDistance``. :param search_page_size: The maximum number of records to load into memory at a time when searching. @@ -15,13 +15,17 @@ class IndexedTextSearch: name = 'indexed_text_search' def __init__(self, chatbot, **kwargs): - from chatterbot.comparisons import levenshtein_distance + from chatterbot.comparisons import LevenshteinDistance self.chatbot = chatbot - self.compare_statements = kwargs.get( + statement_comparison_function = kwargs.get( 'statement_comparison_function', - levenshtein_distance + LevenshteinDistance + ) + + self.compare_statements = statement_comparison_function( + language=self.chatbot.storage.tagger.language ) self.search_page_size = kwargs.get( diff --git a/docs/chatterbot.rst b/docs/chatterbot.rst index 0f267233d..172a688a3 100644 --- a/docs/chatterbot.rst +++ b/docs/chatterbot.rst @@ -58,7 +58,7 @@ which specifies the import path to the adapter class. 
logic_adapters=[ { 'import_path': 'my.logic.AdapterClass1', - 'statement_comparison_function': chatterbot.comparisons.levenshtein_distance + 'statement_comparison_function': chatterbot.comparisons.LevenshteinDistance, 'response_selection_method': chatterbot.response_selection.get_first_response }, { diff --git a/docs/comparisons.rst b/docs/comparisons.rst index 1c26ede1f..8934aa299 100644 --- a/docs/comparisons.rst +++ b/docs/comparisons.rst @@ -43,9 +43,9 @@ is shown below. .. code-block:: python from chatterbot import ChatBot - from chatterbot.comparisons import levenshtein_distance + from chatterbot.comparisons import LevenshteinDistance chatbot = ChatBot( # ... - statement_comparison_function=levenshtein_distance + statement_comparison_function=LevenshteinDistance ) diff --git a/docs/logic/index.rst b/docs/logic/index.rst index 8313ad611..80da8af17 100644 --- a/docs/logic/index.rst +++ b/docs/logic/index.rst @@ -67,7 +67,7 @@ Setting parameters logic_adapters=[ { "import_path": "chatterbot.logic.BestMatch", - "statement_comparison_function": chatterbot.comparisons.levenshtein_distance, + "statement_comparison_function": chatterbot.comparisons.LevenshteinDistance, "response_selection_method": chatterbot.response_selection.get_first_response } ] diff --git a/tests/test_benchmarks.py b/tests/test_benchmarks.py index 4cea9a68f..1ce5367c8 100644 --- a/tests/test_benchmarks.py +++ b/tests/test_benchmarks.py @@ -95,7 +95,7 @@ def test_levenshtein_distance_comparisons(self): """ self.chatbot.logic_adapters[0] = BestMatch( self.chatbot, - statement_comparison_function=comparisons.levenshtein_distance, + statement_comparison_function=comparisons.LevenshteinDistance, response_selection_method=response_selection.get_first_response ) @@ -110,7 +110,7 @@ def test_spacy_similarity_comparisons(self): """ self.chatbot.logic_adapters[0] = BestMatch( self.chatbot, - statement_comparison_function=comparisons.spacy_similarity, + 
statement_comparison_function=comparisons.SpacySimilarity, response_selection_method=response_selection.get_first_response ) @@ -155,7 +155,7 @@ def test_levenshtein_distance_comparisons(self): """ self.chatbot.logic_adapters[0] = BestMatch( self.chatbot, - statement_comparison_function=comparisons.levenshtein_distance, + statement_comparison_function=comparisons.LevenshteinDistance, response_selection_method=response_selection.get_first_response ) @@ -170,7 +170,7 @@ def test_spacy_similarity_comparisons(self): """ self.chatbot.logic_adapters[0] = BestMatch( self.chatbot, - statement_comparison_function=comparisons.spacy_similarity, + statement_comparison_function=comparisons.SpacySimilarity, response_selection_method=response_selection.get_first_response ) diff --git a/tests/test_comparisons.py b/tests/test_comparisons.py index 03299f91d..e9caa4a6c 100644 --- a/tests/test_comparisons.py +++ b/tests/test_comparisons.py @@ -5,10 +5,18 @@ from unittest import TestCase from chatterbot.conversation import Statement from chatterbot import comparisons +from chatterbot import languages class LevenshteinDistanceTestCase(TestCase): + def setUp(self): + super().setUp() + + self.compare = comparisons.LevenshteinDistance( + language=languages.ENG + ) + def test_levenshtein_distance_statement_false(self): """ Falsy values should match by zero. 
@@ -16,7 +24,7 @@ def test_levenshtein_distance_statement_false(self): statement = Statement(text='') other_statement = Statement(text='Hello') - value = comparisons.levenshtein_distance(statement, other_statement) + value = self.compare(statement, other_statement) self.assertEqual(value, 0) @@ -27,7 +35,7 @@ def test_levenshtein_distance_other_statement_false(self): statement = Statement(text='Hello') other_statement = Statement(text='') - value = comparisons.levenshtein_distance(statement, other_statement) + value = self.compare(statement, other_statement) self.assertEqual(value, 0) @@ -39,7 +47,7 @@ def test_levenshtein_distance_statement_integer(self): statement = Statement(text=2) other_statement = Statement(text='Hello') - value = comparisons.levenshtein_distance(statement, other_statement) + value = self.compare(statement, other_statement) self.assertEqual(value, 0) @@ -50,13 +58,20 @@ def test_exact_match_different_capitalization(self): statement = Statement(text='Hi HoW ArE yOu?') other_statement = Statement(text='hI hOw are YoU?') - value = comparisons.levenshtein_distance(statement, other_statement) + value = self.compare(statement, other_statement) self.assertEqual(value, 1) class SpacySimilarityTests(TestCase): + def setUp(self): + super().setUp() + + self.compare = comparisons.SpacySimilarity( + language=languages.ENG + ) + def test_exact_match_different_stopwords(self): """ Test sentences with different stopwords. 
@@ -64,7 +79,7 @@ def test_exact_match_different_stopwords(self): statement = Statement(text='What is matter?') other_statement = Statement(text='What is the matter?') - value = comparisons.spacy_similarity(statement, other_statement) + value = self.compare(statement, other_statement) self.assertAlmostEqual(value, 0.9, places=1) @@ -75,13 +90,20 @@ def test_exact_match_different_capitalization(self): statement = Statement(text='Hi HoW ArE yOu?') other_statement = Statement(text='hI hOw are YoU?') - value = comparisons.spacy_similarity(statement, other_statement) + value = self.compare(statement, other_statement) self.assertAlmostEqual(value, 0.8, places=1) class JaccardSimilarityTestCase(TestCase): + def setUp(self): + super().setUp() + + self.compare = comparisons.JaccardSimilarity( + language=languages.ENG + ) + def test_exact_match_different_capitalization(self): """ Test that text capitalization is ignored. @@ -89,6 +111,6 @@ def test_exact_match_different_capitalization(self): statement = Statement(text='Hi HoW ArE yOu?') other_statement = Statement(text='hI hOw are YoU?') - value = comparisons.jaccard_similarity(statement, other_statement) + value = self.compare(statement, other_statement) self.assertEqual(value, 1) diff --git a/tests/test_search.py b/tests/test_search.py index 6dcf3b2f7..08f77564d 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -60,7 +60,7 @@ def setUp(self): super().setUp() self.search_algorithm = IndexedTextSearch( self.chatbot, - statement_comparison_function=comparisons.spacy_similarity + statement_comparison_function=comparisons.SpacySimilarity ) def test_get_closest_statement(self): @@ -107,7 +107,7 @@ def setUp(self): super().setUp() self.search_algorithm = IndexedTextSearch( self.chatbot, - statement_comparison_function=comparisons.levenshtein_distance + statement_comparison_function=comparisons.LevenshteinDistance ) def test_get_closest_statement(self):