Skip to content

Commit

Permalink
Merge pull request #7423 from hotzenklotz/conf_mat_response_selector
Browse files Browse the repository at this point in the history
Use response selector keys for confusion matrix labelling
  • Loading branch information
dakshvar22 committed Dec 15, 2020
2 parents bb2247c + 021bcb1 commit a9fd74f
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 38 deletions.
1 change: 1 addition & 0 deletions changelog/7423.improvement.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Use response selector keys (sub-intents) as labels for plotting the confusion matrix during NLU evaluation to improve readability.
31 changes: 18 additions & 13 deletions rasa/nlu/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
ENTITY_ATTRIBUTE_TYPE,
ENTITY_ATTRIBUTE_GROUP,
ENTITY_ATTRIBUTE_ROLE,
RESPONSE_KEY_ATTRIBUTE,
)
from rasa.model import get_model
from rasa.nlu import config, training_data, utils
Expand Down Expand Up @@ -62,7 +63,7 @@

ResponseSelectionEvaluationResult = namedtuple(
"ResponseSelectionEvaluationResult",
"intent_target " "response_target " "response_prediction " "message " "confidence",
"response_key response_prediction_full_intent message confidence",
)

EntityEvaluationResult = namedtuple(
Expand Down Expand Up @@ -235,10 +236,10 @@ def remove_empty_response_examples(
for r in response_results:
# substitute None values with empty string
# to enable sklearn evaluation
if r.response_prediction is None:
r = r._replace(response_prediction="")
if r.response_prediction_full_intent is None:
r = r._replace(response_prediction_full_intent="")

if r.response_target != "" and r.response_target is not None:
if r.response_key != "" and r.response_key is not None:
filtered.append(r)

return filtered
Expand Down Expand Up @@ -373,7 +374,7 @@ def evaluate_response_selections(
)

target_responses, predicted_responses = _targets_predictions_from(
response_selection_results, "response_target", "response_prediction"
response_selection_results, "response_key", "response_prediction_full_intent"
)

if report_folder:
Expand Down Expand Up @@ -409,9 +410,8 @@ def evaluate_response_selections(
predictions = [
{
"text": res.message,
"intent_target": res.intent_target,
"response_target": res.response_target,
"response_predicted": res.response_prediction,
"response_key": res.response_key,
"response_predicted": res.response_prediction_full_intent,
"confidence": res.confidence,
}
for res in response_selection_results
Expand Down Expand Up @@ -1050,13 +1050,16 @@ def get_eval_data(
response_prediction_key, {}
).get(OPEN_UTTERANCE_PREDICTION_KEY, {})

response_target = example.get("response", "")
response_prediction_full_intent = selector_properties.get(
response_prediction_key, {}
).get("full_retrieval_intent", {})

response_key = example.get_combined_intent_response_key()

response_selection_results.append(
ResponseSelectionEvaluationResult(
intent_target,
response_target,
response_prediction.get("name"),
response_key,
response_prediction_full_intent,
result.get("text", {}),
response_prediction.get("confidence"),
)
Expand Down Expand Up @@ -1486,7 +1489,9 @@ def compute_metrics(
response_selection_metrics = {}
if response_selection_results:
response_selection_metrics = _compute_metrics(
response_selection_results, "response_target", "response_prediction"
response_selection_results,
"response_key",
"response_prediction_full_intent",
)

return (
Expand Down
34 changes: 9 additions & 25 deletions tests/nlu/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,18 +470,10 @@ def test_response_evaluation_report(tmpdir_factory):

response_results = [
ResponseSelectionEvaluationResult(
"chitchat",
"It's sunny in Berlin",
"It's sunny in Berlin",
"What's the weather",
0.65432,
"weather/ask_weather", "weather/ask_weather", "What's the weather", 0.65432,
),
ResponseSelectionEvaluationResult(
"chitchat",
"My name is Mr.bot",
"My name is Mr.bot",
"What's your name?",
0.98765,
"faq/ask_name", "faq/ask_name", "What's your name?", 0.98765,
),
]

Expand All @@ -499,14 +491,13 @@ def test_response_evaluation_report(tmpdir_factory):

prediction = {
"text": "What's your name?",
"intent_target": "chitchat",
"response_target": "My name is Mr.bot",
"response_predicted": "My name is Mr.bot",
"response_key": "faq/ask_name",
"response_predicted": "faq/ask_name",
"confidence": 0.98765,
}

assert len(report.keys()) == 5
assert report["My name is Mr.bot"] == name_query_results
assert report["faq/ask_name"] == name_query_results
assert result["predictions"][1] == prediction


Expand Down Expand Up @@ -590,23 +581,16 @@ def test_empty_intent_removal():

def test_empty_response_removal():
response_results = [
ResponseSelectionEvaluationResult(None, None, "What's the weather", 0.65432,),
ResponseSelectionEvaluationResult(
"chitchat", None, "It's sunny in Berlin", "What's the weather", 0.65432
),
ResponseSelectionEvaluationResult(
"chitchat",
"My name is Mr.bot",
"My name is Mr.bot",
"What's your name?",
0.98765,
"faq/ask_name", "faq/ask_name", "What's your name?", 0.98765,
),
]
response_results = remove_empty_response_examples(response_results)

assert len(response_results) == 1
assert response_results[0].intent_target == "chitchat"
assert response_results[0].response_target == "My name is Mr.bot"
assert response_results[0].response_prediction == "My name is Mr.bot"
assert response_results[0].response_key == "faq/ask_name"
assert response_results[0].response_prediction_full_intent == "faq/ask_name"
assert response_results[0].confidence == 0.98765
assert response_results[0].message == "What's your name?"

Expand Down

0 comments on commit a9fd74f

Please sign in to comment.