Commit
black==23.3.0
jordis-ai2 committed Jul 11, 2024
1 parent 92d06ea commit 698ee02
Showing 120 changed files with 799 additions and 344 deletions.
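
The edits below are the mechanical result of reformatting with Black pinned at 23.3.0 (per the commit title); no runtime behavior is intended to change. The patterns that dominate the diff are: spaces removed around ** for simple operands, calls whose argument list ends with a trailing comma expanded to one argument per line, long right-hand sides wrapped in parentheses, and blank lines at the start of a block removed. A minimal illustrative sketch of the first two rules (the helper function and values here are hypothetical, not code from allenact):

import random


def create_model(config, device):
    # Hypothetical helper, used only to show the formatting rules.
    return config, device


# Before (older Black style, as on the removed lines below):
#   seed = random.randint(0, (2 ** 31) - 1)
#   model = create_model(
#       config={"hidden_size": 512}, device="cpu",
#   )

# After (Black 23.x style, as on the added lines below):
seed = random.randint(0, (2**31) - 1)
model = create_model(
    config={"hidden_size": 512},
    device="cpu",
)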
57 changes: 37 additions & 20 deletions allenact/algorithms/onpolicy_sync/engine.py
@@ -17,6 +17,7 @@
 import torch.multiprocessing as mp  # type: ignore
 import torch.nn as nn
 import torch.optim as optim
+
 # noinspection PyProtectedMember
 from torch._C._distributed_c10d import ReduceOp

@@ -196,16 +197,17 @@ def __init__(

         create_model_kwargs = {}
         if self.machine_params.sensor_preprocessor_graph is not None:
-            self.sensor_preprocessor_graph = self.machine_params.sensor_preprocessor_graph.to(
-                self.device
+            self.sensor_preprocessor_graph = (
+                self.machine_params.sensor_preprocessor_graph.to(self.device)
             )
             create_model_kwargs[
                 "sensor_preprocessor_graph"
             ] = self.sensor_preprocessor_graph

         set_seed(self.seed)
         self.actor_critic = cast(
-            ActorCriticModel, self.config.create_model(**create_model_kwargs),
+            ActorCriticModel,
+            self.config.create_model(**create_model_kwargs),
         ).to(self.device)

         if initial_model_state_dict is not None:
@@ -343,7 +345,7 @@ def worker_seeds(nprocesses: int, initial_seed: Optional[int]) -> List[int]:
         if initial_seed is not None:
             rstate = random.getstate()
             random.seed(initial_seed)
-        seeds = [random.randint(0, (2 ** 31) - 1) for _ in range(nprocesses)]
+        seeds = [random.randint(0, (2**31) - 1) for _ in range(nprocesses)]
         if initial_seed is not None:
             random.setstate(rstate)
         return seeds
@@ -400,7 +402,8 @@ def checkpoint_load(
             ckpt = torch.load(os.path.abspath(ckpt), map_location="cpu")

         ckpt = cast(
-            Dict[str, Union[Dict[str, Any], torch.Tensor, float, int, str, List]], ckpt,
+            Dict[str, Union[Dict[str, Any], torch.Tensor, float, int, str, List]],
+            ckpt,
         )

         self.actor_critic.load_state_dict(ckpt["model_state_dict"])  # type:ignore
@@ -414,7 +417,9 @@ def checkpoint_load(

     # aggregates task metrics currently in queue
     def aggregate_task_metrics(
-        self, logging_pkg: LoggingPackage, num_tasks: int = -1,
+        self,
+        logging_pkg: LoggingPackage,
+        num_tasks: int = -1,
     ) -> LoggingPackage:
         if num_tasks > 0:
             if len(self.single_process_metrics) != num_tasks:
@@ -473,7 +478,6 @@ def initialize_storage_and_viz(
         storage_to_initialize: Optional[Sequence[ExperienceStorage]],
         visualizer: Optional[VizSuite] = None,
     ):
-
         keep: Optional[List] = None
         if visualizer is not None or (
             storage_to_initialize is not None
@@ -652,7 +656,8 @@ def collect_step_across_all_task_samplers(
     ) -> int:
         rollout_storage = cast(RolloutStorage, uuid_to_storage[rollout_storage_uuid])
         actions, actor_critic_output, memory, _ = self.act(
-            rollout_storage=rollout_storage, dist_wrapper_class=dist_wrapper_class,
+            rollout_storage=rollout_storage,
+            dist_wrapper_class=dist_wrapper_class,
         )

         # Flatten actions
@@ -687,7 +692,9 @@ def collect_step_across_all_task_samplers(
         observations, rewards, dones, infos = [list(x) for x in zip(*outputs)]

         rewards = torch.tensor(
-            rewards, dtype=torch.float, device=self.device,  # type:ignore
+            rewards,
+            dtype=torch.float,
+            device=self.device,  # type:ignore
         )

         # We want rewards to have dimensions [sampler, reward]
@@ -701,7 +708,9 @@ def collect_step_across_all_task_samplers(
         masks = (
             1.0
             - torch.tensor(
-                dones, dtype=torch.float32, device=self.device,  # type:ignore
+                dones,
+                dtype=torch.float32,
+                device=self.device,  # type:ignore
             )
         ).view(
             -1, 1
@@ -802,7 +811,6 @@ def step_count(self) -> int:
             return 0
         return self.training_pipeline.current_stage.steps_taken_in_stage

-
     def compute_losses_track_them_and_backprop(
         self,
         stage: PipelineStage,
@@ -942,7 +950,6 @@ def single_batch_generator(streaming_storage: StreamingStorageMixin):
             bsize = batch["bsize"]

             if actor_critic_output_for_batch is None:
-
                 try:
                     actor_critic_output_for_batch, _ = self.actor_critic(
                         observations=batch["observations"],
@@ -1216,8 +1223,10 @@ def __init__(
                 "offpolicy_epoch_done", self.store
             )
             # Flag for finished worker in current epoch with custom component
-            self.insufficient_data_for_update = torch.distributed.PrefixStore(  # type:ignore
-                "insufficient_data_for_update", self.store
+            self.insufficient_data_for_update = (
+                torch.distributed.PrefixStore(  # type:ignore
+                    "insufficient_data_for_update", self.store
+                )
             )
         else:
             self.num_workers_done = None
@@ -1243,7 +1252,7 @@ def advance_seed(
         if seed is None:
             return seed
         seed = (seed ^ (self.training_pipeline.total_steps + 1)) % (
-            2 ** 31 - 1
+            2**31 - 1
         )  # same seed for all workers

         if (not return_same_seed_per_worker) and (
@@ -1375,7 +1384,9 @@ def step_count(self, val: int) -> None:

     @property
     def log_interval(self):
-        return self.training_pipeline.current_stage.training_settings.metric_accumulate_interval
+        return (
+            self.training_pipeline.current_stage.training_settings.metric_accumulate_interval
+        )

     @property
     def approx_steps(self):
@@ -1416,7 +1427,8 @@ def tracking_callback(type: TrackingInfoType, info: Dict[str, Any], n: int):
         )

         actions, actor_critic_output, memory, step_observation = super().act(
-            rollout_storage=rollout_storage, dist_wrapper_class=dist_wrapper_class,
+            rollout_storage=rollout_storage,
+            dist_wrapper_class=dist_wrapper_class,
         )

         self.step_count += self.num_active_samplers
@@ -1474,14 +1486,18 @@ def backprop_step(
                     else:  # local_global_batch_size_tuple is not None, since we're distributed:
                         p.grad = p.grad * local_to_global_batch_size_ratio
                     reductions.append(
-                        dist.all_reduce(p.grad, async_op=True,)  # sum
+                        dist.all_reduce(
+                            p.grad,
+                            async_op=True,
+                        )  # sum
                     )  # synchronize
                     all_params.append(p)
             for reduction, p in zip(reductions, all_params):
                 reduction.wait()

         nn.utils.clip_grad_norm_(
-            self.actor_critic.parameters(), max_norm=max_grad_norm,  # type: ignore
+            self.actor_critic.parameters(),
+            max_norm=max_grad_norm,  # type: ignore
         )

         self.optimizer.step()  # type: ignore
@@ -2097,7 +2113,8 @@ def run_eval(
         lengths: List[int]
         if self.num_active_samplers > 0:
             lengths = self.vector_tasks.command(
-                "sampler_attr", ["length"] * self.num_active_samplers,
+                "sampler_attr",
+                ["length"] * self.num_active_samplers,
             )
             npending = sum(lengths)
         else:
9 changes: 7 additions & 2 deletions allenact/algorithms/onpolicy_sync/losses/a2cacktr.py
@@ -99,7 +99,9 @@ def loss(  # type: ignore
         **kwargs,
     ):
         losses_per_step = self.loss_per_step(
-            step_count=step_count, batch=batch, actor_critic_output=actor_critic_output,
+            step_count=step_count,
+            batch=batch,
+            actor_critic_output=actor_critic_output,
         )
         losses = {
             key: (loss.mean(), weight)
@@ -169,4 +171,7 @@ def __init__(
         )


-A2CConfig = dict(value_loss_coef=0.5, entropy_coef=0.01,)
+A2CConfig = dict(
+    value_loss_coef=0.5,
+    entropy_coef=0.01,
+)
Changes to an additional file (file name not captured in this excerpt)
@@ -68,4 +68,6 @@ def loss(  # type: ignore
             torch.log((probs_tensor * expert_group_actions_mask).sum(-1))
         ).mean()

-        return total_loss, {"grouped_action_cross_entropy": total_loss.item(),}
+        return total_loss, {
+            "grouped_action_cross_entropy": total_loss.item(),
+        }
4 changes: 3 additions & 1 deletion allenact/algorithms/onpolicy_sync/losses/imitation.py
@@ -149,7 +149,9 @@ def loss(  # type: ignore
                 ready_actions[group_name] = expert_action

                 current_loss, expert_successes = self.group_loss(
-                    cd, expert_action, expert_action_masks,
+                    cd,
+                    expert_action,
+                    expert_action_masks,
                 )

                 should_report_loss = (
9 changes: 6 additions & 3 deletions allenact/algorithms/onpolicy_sync/losses/ppo.py
@@ -67,7 +67,6 @@ def loss_per_step(
     ) -> Tuple[
         Dict[str, Tuple[torch.Tensor, Optional[float]]], Dict[str, torch.Tensor]
     ]:  # TODO tuple output
-
         actions = cast(torch.LongTensor, batch["actions"])
         values = actor_critic_output.values

@@ -135,7 +134,9 @@ def loss(  # type: ignore
         **kwargs
     ):
         losses_per_step, ratio_info = self.loss_per_step(
-            step_count=step_count, batch=batch, actor_critic_output=actor_critic_output,
+            step_count=step_count,
+            batch=batch,
+            actor_critic_output=actor_critic_output,
         )
         losses = {
             key: (loss.mean(), weight)
@@ -210,7 +211,9 @@ def loss(  # type: ignore

         return (
             value_loss,
-            {"value": value_loss.item(),},
+            {
+                "value": value_loss.item(),
+            },
         )


3 changes: 0 additions & 3 deletions allenact/algorithms/onpolicy_sync/runner.py
@@ -1053,7 +1053,6 @@ def update_keys_metric(
                 (stage_component_uuid, storage_uuid),
                 info_tracker,
             ) in pkg.info_trackers.items():
-
                 if stage_component_uuid is not None:
                     storage_uuid_to_stage_component_uuids[storage_uuid].add(
                         stage_component_uuid
@@ -1313,7 +1312,6 @@ def log_and_close(
                 if pkg_mode == TRAIN_MODE_STR:
                     collected.append(package)
                     if len(collected) >= nworkers:
-
                         collected = sorted(
                             collected,
                             key=lambda pkg: (
@@ -1479,7 +1477,6 @@ def get_checkpoint_files(
         checkpoint_path_dir_or_pattern: str,
         approx_ckpt_step_interval: Optional[int] = None,
     ):
-
         if os.path.isdir(checkpoint_path_dir_or_pattern):
             # The fragment is a path to a directory, lets use this directory
             # as the base dir to search for checkpoints
27 changes: 20 additions & 7 deletions allenact/algorithms/onpolicy_sync/storage.py
@@ -121,7 +121,8 @@ def empty(self) -> bool:
 class MiniBatchStorageMixin(abc.ABC):
     @abc.abstractmethod
     def batched_experience_generator(
-        self, num_mini_batch: int,
+        self,
+        num_mini_batch: int,
     ) -> Generator[Dict[str, Any], None, None]:
         raise NotImplementedError

@@ -183,7 +184,8 @@ def initialize(
         self.action_space = action_space

         self.memory_first_last: Memory = self.create_memory(
-            spec=self.memory_specification, num_samplers=num_samplers,
+            spec=self.memory_specification,
+            num_samplers=num_samplers,
         ).to(self.device)
         for key in self.memory_specification:
             self.flattened_to_unflattened["memory"][key] = [key]
@@ -249,7 +251,10 @@ def observations(self) -> Memory:
         return self._observations_full.slice(dim=0, start=0, stop=self.step + 1)

     @staticmethod
-    def create_memory(spec: Optional[FullMemorySpecType], num_samplers: int,) -> Memory:
+    def create_memory(
+        spec: Optional[FullMemorySpecType],
+        num_samplers: int,
+    ) -> Memory:
         if spec is None:
             return Memory()

@@ -290,7 +295,9 @@ def to(self, device: torch.device):
         self.device = device

     def insert_observations(
-        self, observations: ObservationType, time_step: int,
+        self,
+        observations: ObservationType,
+        time_step: int,
     ):
         self.insert_tensors(
             storage=self._observations_full,
@@ -300,7 +307,9 @@ def insert_observations(
         )

     def insert_memory(
-        self, memory: Optional[Memory], time_step: int,
+        self,
+        memory: Optional[Memory],
+        time_step: int,
     ):
         if memory is None:
             assert len(self.memory_first_last) == 0
@@ -519,7 +528,10 @@ def before_updates(
     ):
         assert len(kwargs) == 0
         self.compute_returns(
-            next_value=next_value, use_gae=use_gae, gamma=gamma, tau=tau,
+            next_value=next_value,
+            use_gae=use_gae,
+            gamma=gamma,
+            tau=tau,
         )

         self._advantages = self.returns[:-1] - self.value_preds[:-1]
@@ -587,7 +599,8 @@ def compute_returns(
         )

     def batched_experience_generator(
-        self, num_mini_batch: int,
+        self,
+        num_mini_batch: int,
     ):
         assert self._before_update_called, (
             "self._before_update_called() must be called before"
(The remaining changed files in this commit are not shown in this excerpt.)
