Skip to content

Commit

Permalink
Fix agent testing
Browse files Browse the repository at this point in the history
  • Loading branch information
washingtonsk8 committed Feb 17, 2021
1 parent 0aaae03 commit afff8c6
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 0 deletions.
1 change: 1 addition & 0 deletions rl_fast_forward/REINFORCE/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def train(self, envs, dataset_name, n_epochs, model_path=None):
avg_epoch_losses.append(epoch_losses.detach().cpu())
avg_critic_losses.append(critic_losses.detach().cpu())

self.writer.add_scalar('Rewards_{}/_overall_avg_reward'.format(dataset_name), np.mean([scores[exp_key] for exp_key in exp_keys]), i_epoch)

self.optimizer.zero_grad()
epoch_losses.backward() # Computes the derivative of loss with respect to theta (dLoss/dTheta)
Expand Down
1 change: 1 addition & 0 deletions rl_fast_forward/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def test(args):
test_set = [exp for exp in experiments if Experiment2VideoMapping(exp).split_type in ['validation', 'test']]

test_envs = {exp_key: VideoEnvironment(args.semantic_encoder_model_filename, args.user_document_filename, experiment_name=exp_key, batch_size=args.batch_size) for exp_key in test_set}
env = test_envs[test_set[0]]

json_sf = {'info': {'version': 'v1.1_{}'.format(dt.now().strftime('%Y%m%d_%H%M%S')), 'dataset': args.dataset}, 'data': {}}

Expand Down

0 comments on commit afff8c6

Please sign in to comment.