Skip to content

Commit

Permalink
Allow language paramters to be set for comparison methods
Browse files Browse the repository at this point in the history
  • Loading branch information
gunthercox committed Apr 6, 2019
1 parent f002e39 commit 21dfaf0
Show file tree
Hide file tree
Showing 9 changed files with 61 additions and 44 deletions.
12 changes: 6 additions & 6 deletions chatterbot/chatterbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,6 @@ class ChatBot(object):
def __init__(self, name, **kwargs):
self.name = name

primary_search_algorithm = IndexedTextSearch(self, **kwargs)

self.search_algorithms = {
primary_search_algorithm.name: primary_search_algorithm
}

storage_adapter = kwargs.get('storage_adapter', 'chatterbot.storage.SQLStorageAdapter')

logic_adapters = kwargs.get('logic_adapters', [
Expand All @@ -33,6 +27,12 @@ def __init__(self, name, **kwargs):

self.storage = utils.initialize_class(storage_adapter, **kwargs)

primary_search_algorithm = IndexedTextSearch(self, **kwargs)

self.search_algorithms = {
primary_search_algorithm.name: primary_search_algorithm
}

for adapter in logic_adapters:
utils.validate_adapter_class(adapter, LogicAdapter)
logic_adapter = utils.initialize_class(adapter, self, **kwargs)
Expand Down
25 changes: 8 additions & 17 deletions chatterbot/comparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
This module contains various text-comparison algorithms
designed to compare one statement to another.
"""
from chatterbot import languages
from difflib import SequenceMatcher
import spacy


class Comparator:

def __init__(self, language):

self.language = language

def __call__(self, statement_a, statement_b):
return self.compare(statement_a, statement_b)

Expand Down Expand Up @@ -59,10 +62,8 @@ class SpacySimilarity(Comparator):
Calculate the similarity of two statements using Spacy models.
"""

def __init__(self):
super().__init__()

self.language = languages.ENG
def __init__(self, language):
super().__init__(language)

self.nlp = spacy.load(self.language.ISO_639_1)

Expand Down Expand Up @@ -105,10 +106,8 @@ class JaccardSimilarity(Comparator):
.. _`Jaccard similarity index`: https://en.wikipedia.org/wiki/Jaccard_index
"""

def __init__(self):
super().__init__()

self.language = languages.ENG
def __init__(self, language):
super().__init__(language)

self.nlp = spacy.load(self.language.ISO_639_1)

Expand All @@ -134,11 +133,3 @@ def compare(self, statement_a, statement_b):
ratio = numerator / denominator

return ratio


# ---------------------------------------- #


levenshtein_distance = LevenshteinDistance()
spacy_similarity = SpacySimilarity()
jaccard_similarity = JaccardSimilarity()
12 changes: 8 additions & 4 deletions chatterbot/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ class IndexedTextSearch:
"""
:param statement_comparison_function: The dot-notated import path
to a statement comparison function.
Defaults to ``levenshtein_distance``.
Defaults to ``LevenshteinDistance``.
:param search_page_size:
The maximum number of records to load into memory at a time when searching.
Expand All @@ -15,13 +15,17 @@ class IndexedTextSearch:
name = 'indexed_text_search'

def __init__(self, chatbot, **kwargs):
from chatterbot.comparisons import levenshtein_distance
from chatterbot.comparisons import LevenshteinDistance

self.chatbot = chatbot

self.compare_statements = kwargs.get(
statement_comparison_function = kwargs.get(
'statement_comparison_function',
levenshtein_distance
LevenshteinDistance
)

self.compare_statements = statement_comparison_function(
language=self.chatbot.storage.tagger.language
)

self.search_page_size = kwargs.get(
Expand Down
2 changes: 1 addition & 1 deletion docs/chatterbot.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ which specifies the import path to the adapter class.
logic_adapters=[
{
'import_path': 'my.logic.AdapterClass1',
'statement_comparison_function': chatterbot.comparisons.levenshtein_distance
'statement_comparison_function': chatterbot.comparisons.LevenshteinDistance
'response_selection_method': chatterbot.response_selection.get_first_response
},
{
Expand Down
4 changes: 2 additions & 2 deletions docs/comparisons.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ is shown below.
.. code-block:: python
from chatterbot import ChatBot
from chatterbot.comparisons import levenshtein_distance
from chatterbot.comparisons import LevenshteinDistance
chatbot = ChatBot(
# ...
statement_comparison_function=levenshtein_distance
statement_comparison_function=LevenshteinDistance
)
2 changes: 1 addition & 1 deletion docs/logic/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ Setting parameters
logic_adapters=[
{
"import_path": "chatterbot.logic.BestMatch",
"statement_comparison_function": chatterbot.comparisons.levenshtein_distance,
"statement_comparison_function": chatterbot.comparisons.LevenshteinDistance,
"response_selection_method": chatterbot.response_selection.get_first_response
}
]
Expand Down
8 changes: 4 additions & 4 deletions tests/test_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def test_levenshtein_distance_comparisons(self):
"""
self.chatbot.logic_adapters[0] = BestMatch(
self.chatbot,
statement_comparison_function=comparisons.levenshtein_distance,
statement_comparison_function=comparisons.LevenshteinDistance,
response_selection_method=response_selection.get_first_response
)

Expand All @@ -110,7 +110,7 @@ def test_spacy_similarity_comparisons(self):
"""
self.chatbot.logic_adapters[0] = BestMatch(
self.chatbot,
statement_comparison_function=comparisons.spacy_similarity,
statement_comparison_function=comparisons.SpacySimilarity,
response_selection_method=response_selection.get_first_response
)

Expand Down Expand Up @@ -155,7 +155,7 @@ def test_levenshtein_distance_comparisons(self):
"""
self.chatbot.logic_adapters[0] = BestMatch(
self.chatbot,
statement_comparison_function=comparisons.levenshtein_distance,
statement_comparison_function=comparisons.LevenshteinDistance,
response_selection_method=response_selection.get_first_response
)

Expand All @@ -170,7 +170,7 @@ def test_spacy_similarity_comparisons(self):
"""
self.chatbot.logic_adapters[0] = BestMatch(
self.chatbot,
statement_comparison_function=comparisons.spacy_similarity,
statement_comparison_function=comparisons.SpacySimilarity,
response_selection_method=response_selection.get_first_response
)

Expand Down
36 changes: 29 additions & 7 deletions tests/test_comparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,26 @@
from unittest import TestCase
from chatterbot.conversation import Statement
from chatterbot import comparisons
from chatterbot import languages


class LevenshteinDistanceTestCase(TestCase):

def setUp(self):
super().setUp()

self.compare = comparisons.LevenshteinDistance(
language=languages.ENG
)

def test_levenshtein_distance_statement_false(self):
"""
Falsy values should match by zero.
"""
statement = Statement(text='')
other_statement = Statement(text='Hello')

value = comparisons.levenshtein_distance(statement, other_statement)
value = self.compare(statement, other_statement)

self.assertEqual(value, 0)

Expand All @@ -27,7 +35,7 @@ def test_levenshtein_distance_other_statement_false(self):
statement = Statement(text='Hello')
other_statement = Statement(text='')

value = comparisons.levenshtein_distance(statement, other_statement)
value = self.compare(statement, other_statement)

self.assertEqual(value, 0)

Expand All @@ -39,7 +47,7 @@ def test_levenshtein_distance_statement_integer(self):
statement = Statement(text=2)
other_statement = Statement(text='Hello')

value = comparisons.levenshtein_distance(statement, other_statement)
value = self.compare(statement, other_statement)

self.assertEqual(value, 0)

Expand All @@ -50,21 +58,28 @@ def test_exact_match_different_capitalization(self):
statement = Statement(text='Hi HoW ArE yOu?')
other_statement = Statement(text='hI hOw are YoU?')

value = comparisons.levenshtein_distance(statement, other_statement)
value = self.compare(statement, other_statement)

self.assertEqual(value, 1)


class SpacySimilarityTests(TestCase):

def setUp(self):
super().setUp()

self.compare = comparisons.SpacySimilarity(
language=languages.ENG
)

def test_exact_match_different_stopwords(self):
"""
Test sentences with different stopwords.
"""
statement = Statement(text='What is matter?')
other_statement = Statement(text='What is the matter?')

value = comparisons.spacy_similarity(statement, other_statement)
value = self.compare(statement, other_statement)

self.assertAlmostEqual(value, 0.9, places=1)

Expand All @@ -75,20 +90,27 @@ def test_exact_match_different_capitalization(self):
statement = Statement(text='Hi HoW ArE yOu?')
other_statement = Statement(text='hI hOw are YoU?')

value = comparisons.spacy_similarity(statement, other_statement)
value = self.compare(statement, other_statement)

self.assertAlmostEqual(value, 0.8, places=1)


class JaccardSimilarityTestCase(TestCase):

def setUp(self):
super().setUp()

self.compare = comparisons.JaccardSimilarity(
language=languages.ENG
)

def test_exact_match_different_capitalization(self):
"""
Test that text capitalization is ignored.
"""
statement = Statement(text='Hi HoW ArE yOu?')
other_statement = Statement(text='hI hOw are YoU?')

value = comparisons.jaccard_similarity(statement, other_statement)
value = self.compare(statement, other_statement)

self.assertEqual(value, 1)
4 changes: 2 additions & 2 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def setUp(self):
super().setUp()
self.search_algorithm = IndexedTextSearch(
self.chatbot,
statement_comparison_function=comparisons.spacy_similarity
statement_comparison_function=comparisons.SpacySimilarity
)

def test_get_closest_statement(self):
Expand Down Expand Up @@ -107,7 +107,7 @@ def setUp(self):
super().setUp()
self.search_algorithm = IndexedTextSearch(
self.chatbot,
statement_comparison_function=comparisons.levenshtein_distance
statement_comparison_function=comparisons.LevenshteinDistance
)

def test_get_closest_statement(self):
Expand Down

0 comments on commit 21dfaf0

Please sign in to comment.