
Commit

reformat and track grad norm
KuoHaoZeng committed Feb 21, 2024
1 parent 92d06ea commit fead94c
Showing 118 changed files with 1,139 additions and 550 deletions.
172 changes: 109 additions & 63 deletions allenact/algorithms/onpolicy_sync/engine.py

Large diffs are not rendered by default.
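Since the engine.py diff is collapsed here, the following is a minimal sketch of what "tracking the gradient norm" typically amounts to in an on-policy training loop: compute the global norm of the gradients right before the optimizer step and log it with the other training scalars. The names `model`, `optimizer`, and `max_grad_norm` are illustrative assumptions, not the actual engine.py change.

import torch

def step_and_record_grad_norm(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    max_grad_norm: float = 0.5,
) -> float:
    # clip_grad_norm_ returns the total gradient norm computed *before*
    # clipping, which is exactly the scalar worth logging each update.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    optimizer.zero_grad()
    return float(total_norm)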

10 changes: 8 additions & 2 deletions allenact/algorithms/onpolicy_sync/losses/a2cacktr.py
@@ -1,4 +1,5 @@
 """Implementation of A2C and ACKTR losses."""
+
 from typing import cast, Tuple, Dict, Optional
 
 import torch
@@ -99,7 +100,9 @@ def loss(  # type: ignore
         **kwargs,
     ):
         losses_per_step = self.loss_per_step(
-            step_count=step_count, batch=batch, actor_critic_output=actor_critic_output,
+            step_count=step_count,
+            batch=batch,
+            actor_critic_output=actor_critic_output,
         )
         losses = {
             key: (loss.mean(), weight)
@@ -169,4 +172,7 @@ def __init__(
         )
 
 
-A2CConfig = dict(value_loss_coef=0.5, entropy_coef=0.01,)
+A2CConfig = dict(
+    value_loss_coef=0.5,
+    entropy_coef=0.01,
+)
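
As a hedged usage note: this config dict is typically unpacked into the loss constructor when assembling a training pipeline. The import path and class name below follow this module's path but are assumptions insofar as they are not shown in the diff.

from allenact.algorithms.onpolicy_sync.losses.a2cacktr import A2C, A2CConfig

a2c_loss = A2C(**A2CConfig)  # value_loss_coef=0.5, entropy_coef=0.01
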
@@ -68,4 +68,6 @@ def loss(  # type: ignore
             torch.log((probs_tensor * expert_group_actions_mask).sum(-1))
         ).mean()
 
-        return total_loss, {"grouped_action_cross_entropy": total_loss.item(),}
+        return total_loss, {
+            "grouped_action_cross_entropy": total_loss.item(),
+        }
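
A tiny illustration (with assumed shapes) of the grouped-action cross entropy in the hunk above: the policy's probabilities are summed over every action in the expert's group before taking the log, so any action inside the group counts as matching the expert.

import torch

probs = torch.tensor([[0.1, 0.2, 0.3, 0.4]])       # action probabilities
group_mask = torch.tensor([[0.0, 1.0, 1.0, 0.0]])  # expert's action group
loss = -torch.log((probs * group_mask).sum(-1)).mean()  # -log(0.5) ≈ 0.693
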
12 changes: 8 additions & 4 deletions allenact/algorithms/onpolicy_sync/losses/imitation.py
@@ -149,7 +149,9 @@ def loss(  # type: ignore
                 ready_actions[group_name] = expert_action
 
                 current_loss, expert_successes = self.group_loss(
-                    cd, expert_action, expert_action_masks,
+                    cd,
+                    expert_action,
+                    expert_action_masks,
                 )
 
                 should_report_loss = (
@@ -204,7 +206,9 @@ def loss(  # type: ignore
             )
         return (
             total_loss,
-            {"expert_cross_entropy": total_loss.item(), **losses}
-            if should_report_loss
-            else {},
+            (
+                {"expert_cross_entropy": total_loss.item(), **losses}
+                if should_report_loss
+                else {}
+            ),
         )
28 changes: 17 additions & 11 deletions allenact/algorithms/onpolicy_sync/losses/ppo.py
@@ -115,15 +115,17 @@ def add_trailing_dims(t: torch.Tensor):
                 "action": (action_loss, None),
                 "entropy": (dist_entropy.mul_(-1.0), self.entropy_coef),  # type: ignore
             },
-            {
-                "ratio": ratio,
-                "ratio_clamped": clamped_ratio,
-                "ratio_used": torch.where(
-                    cast(torch.Tensor, use_clamped), clamped_ratio, ratio
-                ),
-            }
-            if self.show_ratios
-            else {},
+            (
+                {
+                    "ratio": ratio,
+                    "ratio_clamped": clamped_ratio,
+                    "ratio_used": torch.where(
+                        cast(torch.Tensor, use_clamped), clamped_ratio, ratio
+                    ),
+                }
+                if self.show_ratios
+                else {}
+            ),
         )
 
     def loss(  # type: ignore
@@ -135,7 +137,9 @@ def loss(  # type: ignore
         **kwargs
     ):
         losses_per_step, ratio_info = self.loss_per_step(
-            step_count=step_count, batch=batch, actor_critic_output=actor_critic_output,
+            step_count=step_count,
+            batch=batch,
+            actor_critic_output=actor_critic_output,
        )
         losses = {
             key: (loss.mean(), weight)
@@ -210,7 +214,9 @@ def loss(  # type: ignore
 
         return (
             value_loss,
-            {"value": value_loss.item(),},
+            {
+                "value": value_loss.item(),
+            },
         )
 
 
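For context on the ratio bookkeeping being reformatted above, here is a minimal, self-contained sketch of a PPO clipped surrogate that reports the same three diagnostics. The tensor names and clip parameter are assumptions for illustration, not allenact's exact implementation.

import torch

def clipped_surrogate(log_probs, old_log_probs, adv, clip_param=0.1):
    ratio = torch.exp(log_probs - old_log_probs)  # pi_new / pi_old per action
    clamped_ratio = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param)
    # Clipping "binds" wherever the clamped objective is the smaller one.
    use_clamped = clamped_ratio * adv < ratio * adv
    ratio_used = torch.where(use_clamped, clamped_ratio, ratio)
    action_loss = -(ratio_used * adv).mean()
    return action_loss, {
        "ratio": ratio,
        "ratio_clamped": clamped_ratio,
        "ratio_used": ratio_used,
    }
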
79 changes: 46 additions & 33 deletions allenact/algorithms/onpolicy_sync/runner.py
@@ -1,4 +1,5 @@
 """Defines the reinforcement learning `OnPolicyRunner`."""
+
 import copy
 import enum
 import glob
@@ -542,9 +543,9 @@ def start_train(
                 config=self.config,
                 callback_sensors=self._get_callback_sensors,
                 results_queue=self.queues["results"],
-                checkpoints_queue=self.queues["checkpoints"]
-                if self.running_validation
-                else None,
+                checkpoints_queue=(
+                    self.queues["checkpoints"] if self.running_validation else None
+                ),
                 checkpoints_dir=self.checkpoint_dir(),
                 seed=self.seed,
                 deterministic_cudnn=self.deterministic_cudnn,
@@ -555,9 +556,9 @@ def start_train(
                 distributed_port=distributed_port,
                 max_sampler_processes_per_worker=max_sampler_processes_per_worker,
                 save_ckpt_after_every_pipeline_stage=save_ckpt_after_every_pipeline_stage,
-                initial_model_state_dict=initial_model_state_dict
-                if model_hash is None
-                else model_hash,
+                initial_model_state_dict=(
+                    initial_model_state_dict if model_hash is None else model_hash
+                ),
                 first_local_worker_id=worker_ids[0],
                 distributed_preemption_threshold=self.distributed_preemption_threshold,
                 valid_on_initial_weights=valid_on_initial_weights,
@@ -782,9 +783,11 @@ def checkpoint_dir(
         self, start_time_str: Optional[str] = None, create_if_none: bool = True
     ):
         path_parts = [
-            self.config.tag()
-            if self.extra_tag == ""
-            else os.path.join(self.config.tag(), self.extra_tag),
+            (
+                self.config.tag()
+                if self.extra_tag == ""
+                else os.path.join(self.config.tag(), self.extra_tag)
+            ),
             start_time_str or self.local_start_time_str,
         ]
         if self.save_dir_fmt == SaveDirFormat.NESTED:
@@ -816,9 +819,11 @@ def log_writer_path(self, start_time_str: str) -> str:
             )
             path = os.path.join(
                 self.output_dir,
-                self.config.tag()
-                if self.extra_tag == ""
-                else os.path.join(self.config.tag(), self.extra_tag),
+                (
+                    self.config.tag()
+                    if self.extra_tag == ""
+                    else os.path.join(self.config.tag(), self.extra_tag)
+                ),
                 start_time_str,
                 "train_tb",
             )
@@ -827,9 +832,11 @@ def log_writer_path(self, start_time_str: str) -> str:
             path = os.path.join(
                 self.output_dir,
                 "tb",
-                self.config.tag()
-                if self.extra_tag == ""
-                else os.path.join(self.config.tag(), self.extra_tag),
+                (
+                    self.config.tag()
+                    if self.extra_tag == ""
+                    else os.path.join(self.config.tag(), self.extra_tag)
+                ),
                 start_time_str,
             )
         if self.mode == TEST_MODE_STR:
@@ -850,19 +857,23 @@ def metric_path(self, start_time_str: str) -> str:
             return os.path.join(
                 self.output_dir,
                 "metrics",
-                self.config.tag()
-                if self.extra_tag == ""
-                else os.path.join(self.config.tag(), self.extra_tag),
+                (
+                    self.config.tag()
+                    if self.extra_tag == ""
+                    else os.path.join(self.config.tag(), self.extra_tag)
+                ),
                 start_time_str,
             )
         else:
             raise NotImplementedError
 
     def save_project_state(self):
         path_parts = [
-            self.config.tag()
-            if self.extra_tag == ""
-            else os.path.join(self.config.tag(), self.extra_tag),
+            (
+                self.config.tag()
+                if self.extra_tag == ""
+                else os.path.join(self.config.tag(), self.extra_tag)
+            ),
             self.local_start_time_str,
         ]
         if self.save_dir_fmt == SaveDirFormat.NESTED:
@@ -1091,12 +1102,12 @@ def update_keys_metric(
                     f" AllenAct, please report this issue at https://github.com/allenai/allenact/issues."
                 )
             else:
-                scalar_name_to_total_storage_experience[
-                    scalar_name
-                ] = total_exp_for_storage
-                scalar_name_to_total_experiences_key[
-                    scalar_name
-                ] = storage_uuid_to_total_experiences_key[storage_uuid]
+                scalar_name_to_total_storage_experience[scalar_name] = (
+                    total_exp_for_storage
+                )
+                scalar_name_to_total_experiences_key[scalar_name] = (
+                    storage_uuid_to_total_experiences_key[storage_uuid]
+                )
 
         assert all_equal(
             checkpoint_file_name
@@ -1156,9 +1167,9 @@ def update_keys_metric(
                         stage_component_uuid,
                     )
                     callback_metric_means[approx_eps_key] = eps
-                    scalar_name_to_total_experiences_key[
-                        approx_eps_key
-                    ] = storage_uuid_to_total_experiences_key[storage_uuid]
+                    scalar_name_to_total_experiences_key[approx_eps_key] = (
+                        storage_uuid_to_total_experiences_key[storage_uuid]
+                    )
 
                 if log_writer is not None:
                     log_writer.add_scalar(
@@ -1358,9 +1369,11 @@ def log_and_close(
                             self.process_valid_package(
                                 log_writer=log_writer,
                                 pkg=package,
-                                all_results=eval_results
-                                if self._collect_valid_results
-                                else None,
+                                all_results=(
+                                    eval_results
+                                    if self._collect_valid_results
+                                    else None
+                                ),
                             )
 
                         if metrics_file is not None:
27 changes: 20 additions & 7 deletions allenact/algorithms/onpolicy_sync/storage.py
@@ -121,7 +121,8 @@ def empty(self) -> bool:
 class MiniBatchStorageMixin(abc.ABC):
     @abc.abstractmethod
     def batched_experience_generator(
-        self, num_mini_batch: int,
+        self,
+        num_mini_batch: int,
     ) -> Generator[Dict[str, Any], None, None]:
         raise NotImplementedError
 
@@ -183,7 +184,8 @@ def initialize(
         self.action_space = action_space
 
         self.memory_first_last: Memory = self.create_memory(
-            spec=self.memory_specification, num_samplers=num_samplers,
+            spec=self.memory_specification,
+            num_samplers=num_samplers,
         ).to(self.device)
         for key in self.memory_specification:
             self.flattened_to_unflattened["memory"][key] = [key]
@@ -249,7 +251,10 @@ def observations(self) -> Memory:
         return self._observations_full.slice(dim=0, start=0, stop=self.step + 1)
 
     @staticmethod
-    def create_memory(spec: Optional[FullMemorySpecType], num_samplers: int,) -> Memory:
+    def create_memory(
+        spec: Optional[FullMemorySpecType],
+        num_samplers: int,
+    ) -> Memory:
         if spec is None:
             return Memory()
 
@@ -290,7 +295,9 @@ def to(self, device: torch.device):
         self.device = device
 
     def insert_observations(
-        self, observations: ObservationType, time_step: int,
+        self,
+        observations: ObservationType,
+        time_step: int,
     ):
         self.insert_tensors(
             storage=self._observations_full,
@@ -300,7 +307,9 @@ def insert_observations(
         )
 
     def insert_memory(
-        self, memory: Optional[Memory], time_step: int,
+        self,
+        memory: Optional[Memory],
+        time_step: int,
     ):
         if memory is None:
             assert len(self.memory_first_last) == 0
@@ -519,7 +528,10 @@ def before_updates(
     ):
         assert len(kwargs) == 0
         self.compute_returns(
-            next_value=next_value, use_gae=use_gae, gamma=gamma, tau=tau,
+            next_value=next_value,
+            use_gae=use_gae,
+            gamma=gamma,
+            tau=tau,
         )
 
         self._advantages = self.returns[:-1] - self.value_preds[:-1]
@@ -587,7 +599,8 @@ def compute_returns(
         )
 
     def batched_experience_generator(
-        self, num_mini_batch: int,
+        self,
+        num_mini_batch: int,
     ):
         assert self._before_update_called, (
             "self._before_update_called() must be called before"
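
The before_updates hunk above calls compute_returns(next_value, use_gae, gamma, tau) and then derives advantages from returns minus value predictions. Below is a rough, self-contained sketch of GAE-style return computation consistent with that signature; the tensor shapes, names, and mask convention are assumptions, not allenact's exact implementation.

import torch

def compute_returns_gae(rewards, values, masks, next_value, gamma=0.99, tau=0.95):
    # rewards, values, masks: [T, ...]; next_value: [...] matching values[0].
    # masks[t] is assumed to be 1.0 while the episode is still running.
    T = rewards.shape[0]
    values_ext = torch.cat([values, next_value.unsqueeze(0)], dim=0)
    returns = torch.zeros_like(rewards)
    gae = torch.zeros_like(next_value)
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * values_ext[t + 1] * masks[t] - values_ext[t]
        gae = delta + gamma * tau * masks[t] * gae
        returns[t] = gae + values_ext[t]
    return returns  # advantages are then returns - values
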
24 changes: 15 additions & 9 deletions allenact/algorithms/onpolicy_sync/vector_sampled_tasks.py
Expand Up @@ -376,7 +376,9 @@ def _task_sampling_loop_worker(
             else:
                 connection_write_fn(
                     sp_vector_sampled_tasks.command_at(
-                        sampler_index=sampler_index, command=command, data=data,
+                        sampler_index=sampler_index,
+                        command=command,
+                        data=data,
                     )
                 )
         else:
@@ -500,7 +502,9 @@ def get_observations(self):
         List of observations for each of the unpaused tasks.
         """
-        return self.call(["get_observations"] * self.num_unpaused_tasks,)
+        return self.call(
+            ["get_observations"] * self.num_unpaused_tasks,
+        )
 
     def command_at(
         self, sampler_index: int, command: str, data: Optional[Any] = None
@@ -689,9 +693,9 @@ def pause_at(self, sampler_index: int) -> None:
         for i in range(
             sampler_index + 1, len(self.sampler_index_to_process_ind_and_subprocess_ind)
         ):
-            other_process_and_sub_process_inds = self.sampler_index_to_process_ind_and_subprocess_ind[
-                i
-            ]
+            other_process_and_sub_process_inds = (
+                self.sampler_index_to_process_ind_and_subprocess_ind[i]
+            )
             if other_process_and_sub_process_inds[0] == process_ind:
                 other_process_and_sub_process_inds[1] -= 1
             else:
@@ -988,9 +992,9 @@ def _task_sampling_loop_generator_fn(
                     )
                     if step_result.info is None:
                         step_result = step_result.clone({"info": {}})
-                    step_result.info[
-                        COMPLETE_TASK_CALLBACK_KEY
-                    ] = task_callback_data
+                    step_result.info[COMPLETE_TASK_CALLBACK_KEY] = (
+                        task_callback_data
+                    )
 
                 if auto_resample_when_done:
                     current_task = task_sampler.next_task()
@@ -1140,7 +1144,9 @@ def get_observations(self):
         List of observations for each of the unpaused tasks.
         """
-        return self.call(["get_observations"] * self.num_unpaused_tasks,)
+        return self.call(
+            ["get_observations"] * self.num_unpaused_tasks,
+        )
 
     def next_task_at(self, index_process: int) -> List[RLStepResult]:
         """Move to the the next Task from the TaskSampler in index_process

0 comments on commit fead94c
