-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* RL Agent Fixes #215 * Update docstring * Black formatting * Address review comments * black * Address review comments * Black
- Loading branch information
Showing
3 changed files
with
259 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
"""MovieBot agent for reinforcement learning of dialogue policy.""" | ||
import logging | ||
from typing import Any, Dict, List, Tuple | ||
|
||
from dialoguekit.core import AnnotatedUtterance, Intent | ||
from dialoguekit.participant import DialogueParticipant | ||
from moviebot.agent.agent import MovieBotAgent | ||
from moviebot.core.core_types import DialogueOptions | ||
from moviebot.core.utterance.utterance import UserUtterance | ||
from moviebot.dialogue_manager.dialogue_act import DialogueAct | ||
from moviebot.nlu.neural_nlu import NeuralNLU | ||
from reinforcement_learning.agent.rl_dialogue_manager import DialogueManagerRL | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class MovieBotAgentRL(MovieBotAgent): | ||
def __init__(self, config: Dict[str, Any] = None) -> None: | ||
"""Initializes a MovieBot agent for reinforcement learning. | ||
Args: | ||
config: Configuration. Defaults to None. | ||
""" | ||
super().__init__( | ||
config=config, | ||
) | ||
|
||
self.dialogue_manager = DialogueManagerRL( | ||
self.data_config, self.isBot, self.new_user | ||
) | ||
|
||
if config.get("nlu_type", "") == "neural": | ||
self.nlu = NeuralNLU(None) | ||
|
||
def initialize(self) -> None: | ||
"""Initializes the agent.""" | ||
self.dialogue_manager.initialize() | ||
|
||
def generate_utterance( | ||
self, | ||
agent_dacts: List[DialogueAct], | ||
options: DialogueOptions = {}, | ||
user_fname: str = None, | ||
recommended_item: Dict[str, Any] = None, | ||
) -> Tuple[AnnotatedUtterance, DialogueOptions]: | ||
"""Generates an utterance object based on agent dialogue acts. | ||
Args: | ||
agent_dacts: Agent dialogue acts. | ||
options: Dialogue options that are provided to the user along with | ||
the utterance. | ||
user_fname: User's first name. Defaults to None. | ||
recommended_item: Recommended item. Defaults to None. | ||
Returns: | ||
An annotated utterance and the associated options. | ||
""" | ||
metadata = {"options": options} | ||
if recommended_item: | ||
metadata.update({"recommended_item": recommended_item}) | ||
|
||
agent_response, options = self.nlg.generate_output( | ||
agent_dacts, | ||
self.dialogue_manager.get_state(), | ||
user_fname=user_fname, | ||
) | ||
agent_intent = Intent( | ||
";".join([da.intent.value.label for da in agent_dacts]) | ||
) | ||
|
||
if not self.isBot: | ||
logger.debug( | ||
str(self.dialogue_manager.dialogue_state_tracker.dialogue_state) | ||
) | ||
else: | ||
record_data = self.dialogue_manager.get_state().to_dict() | ||
metadata.update({"record_data": record_data}) | ||
|
||
utterance = AnnotatedUtterance( | ||
intent=agent_intent, | ||
text=agent_response, | ||
participant=DialogueParticipant.AGENT, | ||
annotations=[], | ||
metadata=metadata, | ||
) | ||
|
||
return utterance, options | ||
|
||
def get_user_dialogue_acts( | ||
self, | ||
user_utterance: UserUtterance, | ||
utterance_options: DialogueOptions, | ||
) -> List[DialogueAct]: | ||
"""Generates dialogue acts associated to a given user utterance. | ||
Args: | ||
user_utterance: User utterance. | ||
utterance_options: Dialogue options that are provided to the user | ||
along with the utterance. | ||
Returns: | ||
List of dialogue acts. | ||
""" | ||
return self.nlu.generate_dacts( | ||
user_utterance, | ||
utterance_options, | ||
self.dialogue_manager.get_state(), | ||
) | ||
|
||
def welcome(self, user_fname: str = None) -> None: | ||
"""Sends a welcome message to the user. | ||
This method is not used for reinforcement learning. | ||
""" | ||
pass | ||
|
||
def goodbye(self) -> None: | ||
"""Sends a goodbye message to the user. | ||
This method is not used for reinforcement learning. | ||
""" | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
"""Dialogue manager used when training a dialogue policy with RL.""" | ||
|
||
|
||
import logging | ||
from typing import Any, Dict, List | ||
|
||
from moviebot.core.intents import AgentIntents | ||
from moviebot.core.intents.user_intents import UserIntents | ||
from moviebot.dialogue_manager.dialogue_act import DialogueAct | ||
from moviebot.dialogue_manager.dialogue_manager import DialogueManager | ||
from moviebot.nlu.annotation.item_constraint import ItemConstraint | ||
from moviebot.nlu.annotation.operator import Operator | ||
from moviebot.nlu.annotation.slots import Slots | ||
|
||
|
||
class DialogueManagerRL(DialogueManager): | ||
def initialize(self) -> None: | ||
"""Initializes the dialogue manager.""" | ||
self.dialogue_state_tracker.initialize() | ||
|
||
def recommend(self) -> Dict[str, Any]: | ||
"""Recommends a movie and updates state. | ||
Returns: | ||
Recommended movie. | ||
""" | ||
# accesses the database to fetch results if required | ||
recommended_movies = self.recommender.recommend_items(self.get_state()) | ||
|
||
self.dialogue_state_tracker.update_state_db( | ||
recommended_movies, | ||
self.recommender.get_previous_recommend_items(), | ||
) | ||
if recommended_movies: | ||
self.dialogue_state_tracker.dialogue_state.item_in_focus = ( | ||
recommended_movies[0] | ||
) | ||
return recommended_movies[0] | ||
return None | ||
|
||
def replace_placeholders( | ||
self, | ||
dialogue_act: DialogueAct, | ||
user_dialogue_acts: List[DialogueAct], | ||
recommendation: Dict[str, Any], | ||
) -> DialogueAct: | ||
"""Replaces the placeholders in the dialogue act with actual values. | ||
Args: | ||
dialogue_act: Dialogue act. | ||
user_dialogue_acts: User dialogue acts. | ||
recommendation: Recommended movie. | ||
Returns: | ||
Dialogue act with placeholders replaced. | ||
""" | ||
if ( | ||
dialogue_act.intent == AgentIntents.RECOMMEND | ||
or dialogue_act.intent == AgentIntents.CONTINUE_RECOMMENDATION | ||
): | ||
dialogue_act.params = [ | ||
ItemConstraint( | ||
Slots.TITLE.value, | ||
Operator.EQ, | ||
recommendation[Slots.TITLE.value], | ||
) | ||
] | ||
elif dialogue_act.intent == AgentIntents.INFORM: | ||
for user_dact in user_dialogue_acts: | ||
if user_dact.intent == UserIntents.INQUIRE: | ||
params = [ | ||
ItemConstraint( | ||
param.slot, | ||
Operator.EQ, | ||
recommendation[ | ||
Slots.TITLE.value | ||
if param.slot == Slots.MORE_INFO.value | ||
else param.slot | ||
], | ||
) | ||
for param in user_dact.params | ||
] or [ | ||
ItemConstraint( | ||
"deny", | ||
Operator.EQ, | ||
recommendation[Slots.TITLE.value], | ||
) | ||
] | ||
dialogue_act.params = params | ||
elif dialogue_act.intent == AgentIntents.COUNT_RESULTS: | ||
results = self.get_state().database_result | ||
dialogue_act.params = [ | ||
ItemConstraint( | ||
"count", Operator.EQ, len(results) if results else 0 | ||
) | ||
] | ||
return dialogue_act | ||
|
||
def get_filled_dialogue_acts( | ||
self, dialogue_acts: List[DialogueAct] | ||
) -> List[DialogueAct]: | ||
"""Returns the dialogue acts with filled placeholders. | ||
For example, if the agent replies with a dialogue act with the intent | ||
RECOMMEND, the title will be added as a constraint to the dialogue act: | ||
DialogueAct( | ||
AgentIntents.RECOMMEND, | ||
[ItemConstraint("title", Operator.EQ,"The Matrix")] | ||
) | ||
Args: | ||
dialogue_acts: Dialogue acts. | ||
Raises: | ||
Exception: If the dialogue act cannot be filled. | ||
Returns: | ||
Dialogue acts with filled placeholders. | ||
""" | ||
filled_dialogue_acts = [] | ||
|
||
recommendation = ( | ||
self.recommend() | ||
if AgentIntents.RECOMMEND in [dact.intent for dact in dialogue_acts] | ||
else self.get_state().item_in_focus | ||
) | ||
|
||
user_dacts = self.get_state().last_user_dacts | ||
for dialogue_act in dialogue_acts: | ||
try: | ||
_dialogue_act = self.replace_placeholders( | ||
dialogue_act, user_dacts, recommendation | ||
) | ||
filled_dialogue_acts.append(_dialogue_act) | ||
except Exception as e: | ||
logging.error(e, exc_info=True) | ||
return filled_dialogue_acts |