diff --git a/moviebot/dialogue_manager/dialogue_policy/a2c_dialogue_policy.py b/moviebot/dialogue_manager/dialogue_policy/a2c_dialogue_policy.py
index 726ac9d..4b8fc87 100644
--- a/moviebot/dialogue_manager/dialogue_policy/a2c_dialogue_policy.py
+++ b/moviebot/dialogue_manager/dialogue_policy/a2c_dialogue_policy.py
@@ -32,9 +32,7 @@ def __init__(
             num_timesteps: The number of timesteps. Defaults to None.
             n_envs: The number of environments. Defaults to 1.
         """
-        super().__init__(
-            input_size, hidden_size, output_size, possible_actions
-        )
+        super().__init__(input_size, hidden_size, output_size, possible_actions)
         self.n_envs = n_envs
 
@@ -67,9 +65,7 @@ def __init__(
             self.critic_optimizer, total_iters=num_timesteps
         )
 
-    def forward(
-        self, state: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """Forward pass.
 
         Args:
diff --git a/moviebot/dialogue_manager/dialogue_policy/dqn_dialogue_policy.py b/moviebot/dialogue_manager/dialogue_policy/dqn_dialogue_policy.py
index 367d878..e7000fc 100644
--- a/moviebot/dialogue_manager/dialogue_policy/dqn_dialogue_policy.py
+++ b/moviebot/dialogue_manager/dialogue_policy/dqn_dialogue_policy.py
@@ -29,9 +29,7 @@ def __init__(
             output_size: The size of the output vector.
             possible_actions: The list of possible actions.
         """
-        super().__init__(
-            input_size, hidden_size, output_size, possible_actions
-        )
+        super().__init__(input_size, hidden_size, output_size, possible_actions)
         self.model = torch.nn.Sequential(
             torch.nn.Linear(input_size, hidden_size),
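
Note: both files' hunks collapse calls that previously wrapped below 80 columns into single lines that still fit within 88, consistent with raising the formatter's line-length limit; no behavior changes. For orientation, the reformatted forward signature shows that the A2C policy returns a pair of tensors, conventionally the actor's action logits and the critic's state-value estimate, and the second hunk's context shows the critic's learning rate being annealed by a LinearLR scheduler over num_timesteps. A minimal sketch of that shape, assuming a standard two-headed actor-critic layout (only the signature is visible in the diff; the layer structure below is an illustrative assumption, not the repository's code):

import torch
from typing import Tuple

class A2CSketch(torch.nn.Module):
    """Hypothetical actor-critic network matching the reformatted signature."""

    def __init__(self, input_size: int, hidden_size: int, output_size: int) -> None:
        super().__init__()
        # Actor head: maps the dialogue state to logits over possible actions.
        self.actor = torch.nn.Sequential(
            torch.nn.Linear(input_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, output_size),
        )
        # Critic head: maps the same state to a scalar value estimate.
        self.critic = torch.nn.Sequential(
            torch.nn.Linear(input_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, 1),
        )

    def forward(self, state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Returns (action logits, state value), matching the two-tensor
        # return annotation in the diff above.
        return self.actor(state), self.critic(state)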
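
Similarly, the DQN hunk is truncated and shows only the first layer of the Q-network, torch.nn.Linear(input_size, hidden_size). A minimal sketch of a complete network of that form, assuming one hidden layer with a ReLU activation (everything past the first Linear is an assumption):

import torch

class DQNSketch(torch.nn.Module):
    """Hypothetical completion of the model whose first layer is visible above."""

    def __init__(self, input_size: int, hidden_size: int, output_size: int) -> None:
        super().__init__()
        self.model = torch.nn.Sequential(
            torch.nn.Linear(input_size, hidden_size),   # visible in the diff
            torch.nn.ReLU(),                            # assumed activation
            torch.nn.Linear(hidden_size, output_size),  # one Q-value per action
        )

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        # Maps a dialogue-state vector to a vector of Q-values, one per
        # possible action.
        return self.model(state)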