black formatted
jordis-ai2 committed Jul 11, 2024
1 parent 8ba6edb commit 927d990
Showing 19 changed files with 132 additions and 125 deletions.
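Nearly every hunk below is the same mechanical rewrite: the Black release used for this commit splits long subscripted assignments inside the brackets, replacing the parenthesized right-hand-side style of newer Black releases. A minimal sketch of the recurring pattern (deletions first, matching the rendered diff); the dict and key are borrowed from the engine.py hunk below, and the value is a placeholder:

# Sketch of the recurring rewrite in this commit; `to_track` and the key come
# from the engine.py hunk below, `num_mini_batch` is a placeholder value.
to_track = {}
num_mini_batch = 4

# Deleted style: right-hand side wrapped in parentheses (newer Black).
to_track["rollout_num_mini_batch"] = (
    num_mini_batch
)

# Added style: the subscript itself is split across lines (the older Black
# release used here).
to_track[
    "rollout_num_mini_batch"
] = num_mini_batch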
42 changes: 22 additions & 20 deletions allenact/algorithms/onpolicy_sync/engine.py
@@ -202,9 +202,9 @@ def __init__(
self.sensor_preprocessor_graph = (
self.machine_params.sensor_preprocessor_graph.to(self.device)
)
create_model_kwargs["sensor_preprocessor_graph"] = (
self.sensor_preprocessor_graph
)
create_model_kwargs[
"sensor_preprocessor_graph"
] = self.sensor_preprocessor_graph

set_seed(self.seed)
self.actor_critic = cast(
@@ -290,9 +290,9 @@ def __init__(
self.optimizer: Optional[optim.optimizer.Optimizer] = None
# noinspection PyProtectedMember
self.lr_scheduler: Optional[_LRScheduler] = None
-self.insufficient_data_for_update: Optional[torch.distributed.PrefixStore] = (
-    None
-)
+self.insufficient_data_for_update: Optional[
+    torch.distributed.PrefixStore
+] = None

# Training pipeline will be instantiated during training and inference.
# During inference however, it will be instantiated anew on each run of `run_eval`
@@ -1061,9 +1061,9 @@ def single_batch_generator(streaming_storage: StreamingStorageMixin):
to_track["lr"] = self.optimizer.param_groups[0]["lr"]

if training_settings.num_mini_batch is not None:
to_track["rollout_num_mini_batch"] = (
training_settings.num_mini_batch
)
to_track[
"rollout_num_mini_batch"
] = training_settings.num_mini_batch

for k, v in to_track.items():
# We need to set the bsize to 1 for `worker_batch_size` below as we're trying to record the
@@ -1095,13 +1095,13 @@ def single_batch_generator(streaming_storage: StreamingStorageMixin):
)
)

-stage.stage_component_uuid_to_stream_memory[stage_component.uuid] = (
-    detach_recursively(
-        input=stage.stage_component_uuid_to_stream_memory[
-            stage_component.uuid
-        ],
-        inplace=True,
-    )
-)
+stage.stage_component_uuid_to_stream_memory[
+    stage_component.uuid
+] = detach_recursively(
+    input=stage.stage_component_uuid_to_stream_memory[
+        stage_component.uuid
+    ],
+    inplace=True,
+)

def close(self, verbose=True):
@@ -1850,10 +1850,12 @@ def run_pipeline(self, valid_on_initial_weights: bool = False):
# a pipeline stage completes is controlled above
checkpoint_file_name = None
 if should_save_checkpoints and (
-        self.training_pipeline.total_steps - self.last_save
-        >= cur_stage_training_settings.save_interval
+    self.training_pipeline.total_steps - self.last_save
+    >= cur_stage_training_settings.save_interval
 ):
-    checkpoint_file_name = self._save_checkpoint_then_send_checkpoint_for_validation_and_update_last_save_counter()
+    checkpoint_file_name = (
+        self._save_checkpoint_then_send_checkpoint_for_validation_and_update_last_save_counter()
+    )
already_saved_checkpoint = True

if (
@@ -1905,7 +1907,7 @@ def train(
checkpoint_file_name = download_checkpoint_from_wandb(
checkpoint_path_dir_or_pattern,
ckpt_dir,
-    only_allow_one_ckpt=True
+    only_allow_one_ckpt=True,
)
self.checkpoint_load(checkpoint_file_name, restart_pipeline)

22 changes: 12 additions & 10 deletions allenact/algorithms/onpolicy_sync/runner.py
@@ -1102,12 +1102,12 @@ def update_keys_metric(
f" AllenAct, please report this issue at https://github.com/allenai/allenact/issues."
)
else:
-    scalar_name_to_total_storage_experience[scalar_name] = (
-        total_exp_for_storage
-    )
-    scalar_name_to_total_experiences_key[scalar_name] = (
-        storage_uuid_to_total_experiences_key[storage_uuid]
-    )
+    scalar_name_to_total_storage_experience[
+        scalar_name
+    ] = total_exp_for_storage
+    scalar_name_to_total_experiences_key[
+        scalar_name
+    ] = storage_uuid_to_total_experiences_key[storage_uuid]

if any(checkpoint_file_name):
ckpt_to_store = None
@@ -1174,9 +1174,9 @@ def update_keys_metric(
stage_component_uuid,
)
callback_metric_means[approx_eps_key] = eps
-scalar_name_to_total_experiences_key[approx_eps_key] = (
-    storage_uuid_to_total_experiences_key[storage_uuid]
-)
+scalar_name_to_total_experiences_key[
+    approx_eps_key
+] = storage_uuid_to_total_experiences_key[storage_uuid]

if log_writer is not None:
log_writer.add_scalar(
@@ -1502,7 +1502,9 @@ def get_checkpoint_files(
if "wandb://" == checkpoint_path_dir_or_pattern[:8]:
eval_dir = "wandb_ckpts_to_eval/{}".format(self.local_start_time_str)
os.makedirs(eval_dir, exist_ok=True)
-return download_checkpoint_from_wandb(checkpoint_path_dir_or_pattern, eval_dir, only_allow_one_ckpt=False)
+return download_checkpoint_from_wandb(
+    checkpoint_path_dir_or_pattern, eval_dir, only_allow_one_ckpt=False
+)

if os.path.isdir(checkpoint_path_dir_or_pattern):
# The fragment is a path to a directory, lets use this directory
6 changes: 3 additions & 3 deletions allenact/algorithms/onpolicy_sync/vector_sampled_tasks.py
@@ -990,9 +990,9 @@ def _task_sampling_loop_generator_fn(
)
if step_result.info is None:
step_result = step_result.clone({"info": {}})
-step_result.info[COMPLETE_TASK_CALLBACK_KEY] = (
-    task_callback_data
-)
+step_result.info[
+    COMPLETE_TASK_CALLBACK_KEY
+] = task_callback_data

if auto_resample_when_done:
current_task = task_sampler.next_task()
3 changes: 2 additions & 1 deletion allenact/base_abstractions/distributions.py
@@ -210,7 +210,8 @@ def log_prob(


class TrackingCallback(Protocol):
-    def __call__(self, type: TrackingInfoType, info: Dict[str, Any], n: int): ...
+    def __call__(self, type: TrackingInfoType, info: Dict[str, Any], n: int):
+        ...


class TeacherForcingDistr(Distr):
18 changes: 9 additions & 9 deletions allenact/embodiedai/aux_losses/losses.py
@@ -514,21 +514,21 @@ def get_aux_loss(
beliefs.device
) # (T+k, k, N, 1)

-pred_masks[num_steps - 1 :] = (
-    False  # GRU(b_t, a_{t:t+k-1}) is invalid when t >= T, as we don't have real z_{t+1}
-)
+pred_masks[
+    num_steps - 1 :
+] = False  # GRU(b_t, a_{t:t+k-1}) is invalid when t >= T, as we don't have real z_{t+1}
 for j in range(1, self.planning_steps + 1):  # for j-step predictions
-    pred_masks[: j - 1, j - 1] = (
-        False  # Remove the upper triangle above the diagonal (but I think this is unnecessary for valid_masks)
-    )
+    pred_masks[
+        : j - 1, j - 1
+    ] = False  # Remove the upper triangle above the diagonal (but I think this is unnecessary for valid_masks)
for n in range(num_sampler):
has_zeros_batch = torch.where(masks[:, n] == 0)[0]
# in j-step prediction, timesteps z -> z + j are disallowed as those are the first j timesteps of a new episode
# z-> z-1 because of pred_masks being offset by 1
for z in has_zeros_batch:
-    pred_masks[z - 1 : z - 1 + j, j - 1, n] = (
-        False  # can affect j timesteps
-    )
+    pred_masks[
+        z - 1 : z - 1 + j, j - 1, n
+    ] = False  # can affect j timesteps

# instead of the whole range, we actually are only comparing a window i:i+k for each query/target i - for each, select the appropriate k
# we essentially gather diagonals from this full mask, t of them, k long
24 changes: 12 additions & 12 deletions allenact/embodiedai/mapping/mapping_utils/map_builders.py
@@ -289,9 +289,9 @@ def update(
:vr, (width_div_2 - vr_div_2) : (width_div_2 + vr_div_2), :
]

to_return["egocentric_local_context"] = (
egocentric_local_context.cpu().numpy()
)
to_return[
"egocentric_local_context"
] = egocentric_local_context.cpu().numpy()

return to_return

@@ -443,15 +443,15 @@ def build_ground_truth_map(self, object_hulls: Sequence[ObjectHull2d]):
if ot in self.object_type_to_index:
ind = self.object_type_to_index[ot]

-self.ground_truth_semantic_map[:, :, ind : (ind + 1)] = (
-    cv2.fillConvexPoly(
-        img=np.array(
-            self.ground_truth_semantic_map[:, :, ind : (ind + 1)],
-            dtype=np.uint8,
-        ),
-        points=self._xzs_to_colrows(np.array(object_hull.hull_points)),
-        color=255,
-    )
-)
+self.ground_truth_semantic_map[
+    :, :, ind : (ind + 1)
+] = cv2.fillConvexPoly(
+    img=np.array(
+        self.ground_truth_semantic_map[:, :, ind : (ind + 1)],
+        dtype=np.uint8,
+    ),
+    points=self._xzs_to_colrows(np.array(object_hull.hull_points)),
+    color=255,
+)

def update(
5 changes: 1 addition & 4 deletions allenact/embodiedai/models/basic_models.py
@@ -475,10 +475,7 @@ def adapt_result(
nsteps: int,
nsamplers: int,
nagents: int,
-) -> Tuple[
-    torch.FloatTensor,
-    torch.FloatTensor,
-]:
+) -> Tuple[torch.FloatTensor, torch.FloatTensor,]:
output_dims = (nsteps, nsamplers) + ((nagents, -1) if obs_agent else (-1,))
hidden_dims = (self.num_recurrent_layers, nsamplers) + (
(nagents, -1) if mem_agent else (-1,)
17 changes: 11 additions & 6 deletions allenact/utils/experiment_utils.py
@@ -252,9 +252,9 @@ def __init__(
self.mode = mode

self.training_steps: int = training_steps
-self.storage_uuid_to_total_experiences: Dict[str, int] = (
-    storage_uuid_to_total_experiences
-)
+self.storage_uuid_to_total_experiences: Dict[
+    str, int
+] = storage_uuid_to_total_experiences
self.pipeline_stage = pipeline_stage

self.metrics_tracker = ScalarMeanTracker()
@@ -763,7 +763,10 @@ def add_stage_component(self, stage_component: StageComponent):
self.stage_component_uuid_to_stream_memory[stage_component.uuid] = Memory()

def __setattr__(self, key: str, value: Any):
if key not in ["training_settings", "callback_to_change_engine_attributes"] and self.training_settings.has_key(key):
if key not in [
"training_settings",
"callback_to_change_engine_attributes",
] and self.training_settings.has_key(key):
raise NotImplementedError(
f"Cannot set {key} in {self.__name__}, update the"
f" `training_settings` attribute of {self.__name__} instead."
@@ -1190,7 +1193,9 @@ def current_stage_losses(
}


-def download_checkpoint_from_wandb(checkpoint_path_dir_or_pattern, all_ckpt_dir, only_allow_one_ckpt=False):
+def download_checkpoint_from_wandb(
+    checkpoint_path_dir_or_pattern, all_ckpt_dir, only_allow_one_ckpt=False
+):
api = wandb.Api()
run_token = checkpoint_path_dir_or_pattern.split("//")[1]
ckpt_steps = checkpoint_path_dir_or_pattern.split("//")[2:]
@@ -1215,4 +1220,4 @@ def download_checkpoint_from_wandb(checkpoint_path_dir_or_pattern, all_ckpt_dir,
ckpt_dir = "{}/ckpt-{}.pt".format(all_ckpt_dir, steps)
shutil.move("tmp/ckpt.pt", ckpt_dir)
shutil.rmtree("tmp")
-return ckpt_dir
\ No newline at end of file
+return ckpt_dir
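For context on the wandb hunks above: checkpoint patterns beginning with "wandb://" are routed through download_checkpoint_from_wandb by both engine.py (training resume, only_allow_one_ckpt=True) and runner.py (evaluation, only_allow_one_ckpt=False). A hedged usage sketch; the "wandb://<run token>//<step>" layout is inferred from the split("//") indexing above, and the token and step values are hypothetical:

# Hypothetical call; the run token and step below are invented for
# illustration. Only the function name and keyword come from the diff.
ckpt_file = download_checkpoint_from_wandb(
    "wandb://my-entity/my-project/abc123//50000",
    "wandb_ckpts_to_eval/2024-07-11",
    only_allow_one_ckpt=True,  # training resume expects exactly one checkpoint
)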
6 changes: 3 additions & 3 deletions allenact/utils/viz_utils.py
@@ -732,9 +732,9 @@ def __init__(
self.actor_critic_source,
) = self._setup_sources()

-self.data: Dict[str, List[Dict]] = (
-    {}
-)  # dict of episode id to list of dicts with collected data
+self.data: Dict[
+    str, List[Dict]
+] = {}  # dict of episode id to list of dicts with collected data
self.last_it2epid: List[str] = []

def _setup_sources(self):
24 changes: 12 additions & 12 deletions allenact_plugins/babyai_plugin/babyai_models.py
@@ -191,13 +191,13 @@ def forward_loop(
for sampler_needing_reset_ind in time_ind_to_which_need_instr_reset[
time_ind
]:
-current_instr_embeddings_list[sampler_needing_reset_ind] = (
-    unique_instr_embeddings[
-        reset_multi_ind_to_index[
-            (time_ind, sampler_needing_reset_ind)
-        ]
-    ]
-)
+current_instr_embeddings_list[
+    sampler_needing_reset_ind
+] = unique_instr_embeddings[
+    reset_multi_ind_to_index[
+        (time_ind, sampler_needing_reset_ind)
+    ]
+]

instr_embeddings_list.append(
torch.stack(current_instr_embeddings_list, dim=0)
@@ -352,13 +352,13 @@ def forward(
for sampler_needing_reset_ind in time_ind_to_which_need_instr_reset[
time_ind
]:
-current_instr_embeddings_list[sampler_needing_reset_ind] = (
-    unique_instr_embeddings[
-        reset_multi_ind_to_index[
-            (time_ind, sampler_needing_reset_ind)
-        ]
-    ]
-)
+current_instr_embeddings_list[
+    sampler_needing_reset_ind
+] = unique_instr_embeddings[
+    reset_multi_ind_to_index[
+        (time_ind, sampler_needing_reset_ind)
+    ]
+]

instr_embeddings_list.append(
torch.stack(current_instr_embeddings_list, dim=0)
12 changes: 6 additions & 6 deletions allenact_plugins/ithor_plugin/ithor_environment.py
@@ -301,9 +301,9 @@ def teleport_agent_to(
break
if not reachable:
self.last_action = "TeleportFull"
self.last_event.metadata["errorMessage"] = (
"Target position was not initially reachable."
)
self.last_event.metadata[
"errorMessage"
] = "Target position was not initially reachable."
self.last_action_success = False
return
self.controller.step(
@@ -681,9 +681,9 @@ def step(
self.teleport_agent_to(**start_location, force_action=True) # type: ignore
self.last_action = action
self.last_action_success = False
self.last_event.metadata["errorMessage"] = (
"Moved to location outside of initially reachable points."
)
self.last_event.metadata[
"errorMessage"
] = "Moved to location outside of initially reachable points."
elif "RandomizeHideSeekObjects" in action:
last_position = self.get_agent_location()
self.controller.step(action_dict)
12 changes: 6 additions & 6 deletions allenact_plugins/ithor_plugin/ithor_task_samplers.py
@@ -39,9 +39,9 @@ def __init__(
self.scene_counter: Optional[int] = None
self.scene_order: Optional[List[str]] = None
self.scene_id: Optional[int] = None
-self.scene_period: Optional[Union[str, int]] = (
-    scene_period  # default makes a random choice
-)
+self.scene_period: Optional[
+    Union[str, int]
+] = scene_period  # default makes a random choice
self.max_tasks: Optional[int] = None
self.reset_tasks = max_tasks

@@ -174,9 +174,9 @@ def next_task(
)

task_info["start_pose"] = copy.copy(pose)
task_info["id"] = (
f"{scene}__{'_'.join(list(map(str, self.env.get_key(pose))))}__{task_info['object_type']}"
)
task_info[
"id"
] = f"{scene}__{'_'.join(list(map(str, self.env.get_key(pose))))}__{task_info['object_type']}"

self._last_sampled_task = ObjectNaviThorGridTask(
env=self.env,
6 changes: 3 additions & 3 deletions allenact_plugins/ithor_plugin/ithor_tasks.py
@@ -214,9 +214,9 @@ def query_expert(self, **kwargs) -> Tuple[int, bool]:
if standing == 1
)

-self._CACHED_LOCATIONS_FROM_WHICH_OBJECT_IS_VISIBLE[key] = (
-    locations_from_which_object_is_visible
-)
+self._CACHED_LOCATIONS_FROM_WHICH_OBJECT_IS_VISIBLE[
+    key
+] = locations_from_which_object_is_visible

self._subsampled_locations_from_which_obj_visible = (
self._CACHED_LOCATIONS_FROM_WHICH_OBJECT_IS_VISIBLE[key]
