Skip to content

Commit

Permalink
Update run script to support A2C
Browse files Browse the repository at this point in the history
  • Loading branch information
NoB0 committed Dec 6, 2023
1 parent 30b1659 commit 3c91016
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 7 deletions.
8 changes: 7 additions & 1 deletion docs/source/reinforcement_learning.rst
Original file line number Diff line number Diff line change
Expand Up @@ -100,5 +100,11 @@ Execute the following command to test a dialogue policy:

.. code-block:: bash
python -m reinforcement_learning.run_dialogue_policy --agent-config <path to IAI MovieBot configuration> --artifact-name <W&B artifact name> --model-path <path to saved model in W&B>
python -m reinforcement_learning.run_dialogue_policy --agent-config <path to IAI MovieBot configuration> --artifact-name <W&B artifact name> --model-path <path to saved model in W&B> --policy-type <dqn OR a2c>
For more information about the arguments, execute the following command:

.. code-block:: bash
python -m reinforcement_learning.run_dialogue_policy -h
12 changes: 9 additions & 3 deletions moviebot/dialogue_manager/dialogue_policy/a2c_dialogue_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
NeuralDialoguePolicy,
)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class A2CDialoguePolicy(NeuralDialoguePolicy):
def __init__(
Expand All @@ -30,7 +32,9 @@ def __init__(
num_timesteps: The number of timesteps. Defaults to None.
n_envs: The number of environments. Defaults to 1.
"""
super().__init__(input_size, hidden_size, output_size, possible_actions)
super().__init__(
input_size, hidden_size, output_size, possible_actions
)

self.n_envs = n_envs

Expand Down Expand Up @@ -63,7 +67,9 @@ def __init__(
self.critic_optimizer, total_iters=num_timesteps
)

def forward(self, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(
self, state: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass.
Args:
Expand Down Expand Up @@ -194,7 +200,7 @@ def load_policy(cls, path: str) -> A2CDialoguePolicy:
Returns:
The loaded policy.
"""
state_dict = torch.load(path)
state_dict = torch.load(path, map_location=DEVICE)
policy = cls(
state_dict["input_size"],
state_dict["hidden_size"],
Expand Down
8 changes: 7 additions & 1 deletion reinforcement_learning/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,5 +93,11 @@ A dialogue policy can be tested with a human user via the terminal. Instead of i
Execute the following command to test a dialogue policy:

```bash
python -m reinforcement_learning.run_dialogue_policy --agent-config <path to IAI MovieBot configuration> --artifact-name <W&B artifact name> --model-path <path to saved model in W&B>
python -m reinforcement_learning.run_dialogue_policy --agent-config <path to IAI MovieBot configuration> --artifact-name <W&B artifact name> --model-path <path to saved model in W&B> --policy-type <dqn OR a2c>
```

For more information about the arguments, execute the following command:

```bash
python -m reinforcement_learning.run_dialogue_policy -h
```
24 changes: 22 additions & 2 deletions reinforcement_learning/run_dialogue_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
import torch

import wandb
from moviebot.dialogue_manager.dialogue_policy import DQNDialoguePolicy
from moviebot.dialogue_manager.dialogue_policy import (
A2CDialoguePolicy,
DQNDialoguePolicy,
)
from moviebot.domain.movie_domain import MovieDomain
from reinforcement_learning.environment import DialogueEnvMovieBot
from reinforcement_learning.utils import define_possible_actions, get_config
Expand All @@ -26,16 +29,24 @@ def parse_args(args: str = None) -> argparse.Namespace:
parser = argparse.ArgumentParser(prog="run_dialogue_policy.py")
parser.add_argument(
"--agent-config",
required=True,
help="Path to the agent configuration file.",
)
parser.add_argument(
"--artifact-name",
required=True,
help="W&B artifact name.",
)
parser.add_argument(
"--model-path",
required=True,
help="Path to the model file in W&B.",
)
parser.add_argument(
"--policy-type",
required=True,
help="Type of the policy, either 'dqn' or 'a2c'.",
)
return parser.parse_args(args)


Expand All @@ -57,7 +68,16 @@ def parse_args(args: str = None) -> argparse.Namespace:
artifact = run.use_artifact(args.artifact_name, type="model")
policy_artifact_filepath = artifact.get_path(args.model_path)
policy_artifact_file = policy_artifact_filepath.download()
policy = DQNDialoguePolicy.load_policy(policy_artifact_file)

if args.policy_type == "dqn":
policy = DQNDialoguePolicy.load_policy(policy_artifact_file)
elif args.policy_type == "a2c":
policy = A2CDialoguePolicy.load_policy(policy_artifact_file)
else:
raise ValueError(
f"Unknown policy type '{args.policy_type}'. "
"Supported types are 'dqn' and 'a2c'."
)

# Create the environment
env: DialogueEnvMovieBot = gym.make(
Expand Down

0 comments on commit 3c91016

Please sign in to comment.