From dbeed9724bc9099be173c86871df7d41f3b7e58c Mon Sep 17 00:00:00 2001 From: Amethyst Reese Date: Sat, 2 Mar 2024 17:31:19 -0800 Subject: [PATCH] apply Black 2024 style in fbcode (4/16) Summary: Formats the covered files with pyfmt. paintitblack Reviewed By: aleivag Differential Revision: D54447727 fbshipit-source-id: 8844b1caa08de94d04ac4df3c768dbf8c865fd2f --- examples/albef/data/transforms.py | 8 ++-- examples/mdetr/data/dataset.py | 48 ++++++++++++------- examples/mdetr/data/transforms.py | 2 +- examples/mdetr/loss.py | 2 +- examples/mdetr/tests/test_postprocessors.py | 6 +-- examples/mugen/data/mugen_dataset.py | 4 +- examples/mugen/retrieval/definitions.py | 4 +- tests/diffusion_labs/test_dalle2.py | 4 +- .../models/vae/encoder_decoder.py | 6 +-- torchmultimodal/models/flava/model.py | 46 +++++++++++------- torchmultimodal/models/two_tower.py | 6 +-- torchmultimodal/modules/layers/transformer.py | 1 - .../modules/losses/blip2_losses.py | 5 +- 13 files changed, 85 insertions(+), 57 deletions(-) diff --git a/examples/albef/data/transforms.py b/examples/albef/data/transforms.py index 8d6eaf5d8..472d93b0b 100644 --- a/examples/albef/data/transforms.py +++ b/examples/albef/data/transforms.py @@ -69,9 +69,11 @@ def __init__( self.transform = Sequential( Truncate(max_seq_len=max_seq_len) if truncate else torch.nn.Identity(), ToTensor(padding_value=self.pad_token_id), - PadTransform(max_length=max_seq_len, pad_value=self.pad_token_id) - if pad_to_max_seq_len - else torch.nn.Identity(), + ( + PadTransform(max_length=max_seq_len, pad_value=self.pad_token_id) + if pad_to_max_seq_len + else torch.nn.Identity() + ), ) def pre_process(self, text: str) -> str: diff --git a/examples/mdetr/data/dataset.py b/examples/mdetr/data/dataset.py index dbbacd222..7fee64e9c 100644 --- a/examples/mdetr/data/dataset.py +++ b/examples/mdetr/data/dataset.py @@ -101,9 +101,11 @@ def __getitem__(self, idx): self.type_to_id[coco_img["question_type"]], dtype=torch.long ) target["answer_type_mask"] = { - f"answer_{k}": torch.BoolTensor([True]) - if coco_img["question_type"] == k - else torch.BoolTensor([False]) + f"answer_{k}": ( + torch.BoolTensor([True]) + if coco_img["question_type"] == k + else torch.BoolTensor([False]) + ) for k in self.type_to_id.keys() } target["answer_type_mask"]["answer_type"] = torch.BoolTensor([True]) @@ -113,9 +115,11 @@ def __getitem__(self, idx): else: answer = coco_img["answer"] target["answer_attr"] = torch.as_tensor( - self.answer2id_by_type["answer_attr"][answer] - if coco_img["question_type"] == "attr" - else -100, + ( + self.answer2id_by_type["answer_attr"][answer] + if coco_img["question_type"] == "attr" + else -100 + ), dtype=torch.long, ) @@ -124,9 +128,11 @@ def __getitem__(self, idx): else: answer = coco_img["answer"] target["answer_global"] = torch.as_tensor( - self.answer2id_by_type["answer_global"][answer] - if coco_img["question_type"] == "global" - else -100, + ( + self.answer2id_by_type["answer_global"][answer] + if coco_img["question_type"] == "global" + else -100 + ), dtype=torch.long, ) @@ -135,9 +141,11 @@ def __getitem__(self, idx): else: answer = coco_img["answer"] target["answer_rel"] = torch.as_tensor( - self.answer2id_by_type["answer_rel"][answer] - if coco_img["question_type"] == "rel" - else -100, + ( + self.answer2id_by_type["answer_rel"][answer] + if coco_img["question_type"] == "rel" + else -100 + ), dtype=torch.long, ) @@ -146,9 +154,11 @@ def __getitem__(self, idx): else: answer = coco_img["answer"] target["answer_cat"] = torch.as_tensor( - 
self.answer2id_by_type["answer_cat"][answer] - if coco_img["question_type"] == "cat" - else -100, + ( + self.answer2id_by_type["answer_cat"][answer] + if coco_img["question_type"] == "cat" + else -100 + ), dtype=torch.long, ) @@ -157,9 +167,11 @@ def __getitem__(self, idx): else: answer = coco_img["answer"] target["answer_obj"] = torch.as_tensor( - self.answer2id_by_type["answer_obj"][answer] - if coco_img["question_type"] == "obj" - else -100, + ( + self.answer2id_by_type["answer_obj"][answer] + if coco_img["question_type"] == "obj" + else -100 + ), dtype=torch.long, ) return img, target diff --git a/examples/mdetr/data/transforms.py b/examples/mdetr/data/transforms.py index 0193a7a4e..515b119bb 100644 --- a/examples/mdetr/data/transforms.py +++ b/examples/mdetr/data/transforms.py @@ -334,7 +334,7 @@ def create_positive_map(tokenized, tokens_positive): """construct a map such that positive_map[i,j] = True iff box i is associated to token j""" positive_map = torch.zeros((len(tokens_positive), 256), dtype=torch.float) for j, tok_list in enumerate(tokens_positive): - for (beg, end) in tok_list: + for beg, end in tok_list: beg_pos = tokenized.char_to_token(beg) end_pos = tokenized.char_to_token(end - 1) if beg_pos is None: diff --git a/examples/mdetr/loss.py b/examples/mdetr/loss.py index 7c85b2d30..63834b6b3 100644 --- a/examples/mdetr/loss.py +++ b/examples/mdetr/loss.py @@ -110,7 +110,7 @@ def construct_positive_map( for i, ((idx_src, idx_tgt), tgt) in enumerate(zip(indices, target_tokens)): cur_tokens = [tgt[j] for j in idx_tgt] for j, tok_list in enumerate(cur_tokens): - for (beg, end) in tok_list: + for beg, end in tok_list: beg_pos = char_to_token(tokenized, i, beg) end_pos = char_to_token(tokenized, i, end - 1) diff --git a/examples/mdetr/tests/test_postprocessors.py b/examples/mdetr/tests/test_postprocessors.py index dfdbc2ebe..ac864b9f2 100644 --- a/examples/mdetr/tests/test_postprocessors.py +++ b/examples/mdetr/tests/test_postprocessors.py @@ -94,9 +94,9 @@ def batched_pos_map(self, n_tokens, n_classes, pos_map): batched_pos_map = torch.zeros((n_boxes, n_classes + 1), dtype=torch.bool) cur_count = 0 for sample in pos_map: - batched_pos_map[ - cur_count : cur_count + len(sample), : sample.shape[1] - ] = sample + batched_pos_map[cur_count : cur_count + len(sample), : sample.shape[1]] = ( + sample + ) cur_count += len(sample) assert cur_count == len(batched_pos_map) diff --git a/examples/mugen/data/mugen_dataset.py b/examples/mugen/data/mugen_dataset.py index 7b138e22d..932ab295c 100644 --- a/examples/mugen/data/mugen_dataset.py +++ b/examples/mugen/data/mugen_dataset.py @@ -38,7 +38,9 @@ class MUGENDatasetArgs: True # render smap for mugen (and shield) as bounding boxes ) bbox_smap_for_monsters: bool = True # render smap for monsters as bounding boxes - use_manual_annotation: bool = False # if True will only use videos with manual annotation and skip those without + use_manual_annotation: bool = ( + False # if True will only use videos with manual annotation and skip those without + ) use_auto_annotation: bool = ( True # if True will only use videos with auto annotation and skip those without ) diff --git a/examples/mugen/retrieval/definitions.py b/examples/mugen/retrieval/definitions.py index 569e48f0d..1232e82fc 100644 --- a/examples/mugen/retrieval/definitions.py +++ b/examples/mugen/retrieval/definitions.py @@ -81,7 +81,9 @@ class EvaluationArgs: datamodule_args: DataModuleArgs = DataModuleArgs() lightningmodule_args: LightningModuleArgs = LightningModuleArgs() videoclip_args: 
VideoCLIPArgs = VideoCLIPArgs() - checkpoint_path: str = "https://pytorch.s3.amazonaws.com/models/multimodal/mugen/videoclip_lightning_mugen.pt" + checkpoint_path: str = ( + "https://pytorch.s3.amazonaws.com/models/multimodal/mugen/videoclip_lightning_mugen.pt" + ) accelerator: str = "auto" diff --git a/tests/diffusion_labs/test_dalle2.py b/tests/diffusion_labs/test_dalle2.py index 46cbc04d2..653f8a68d 100644 --- a/tests/diffusion_labs/test_dalle2.py +++ b/tests/diffusion_labs/test_dalle2.py @@ -46,7 +46,5 @@ def test_dalle2_image_transform(): actual = transform({"x": image})["x"].sum() normalized128 = 128 / 255 * 2 - 1 normalized0 = -1 - expected = torch.tensor( - normalized128 * img_size**2 + 2 * normalized0 * img_size**2 - ) + expected = torch.tensor(normalized128 * img_size**2 + 2 * normalized0 * img_size**2) assert_expected(actual, expected, rtol=0, atol=1e-4) diff --git a/torchmultimodal/diffusion_labs/models/vae/encoder_decoder.py b/torchmultimodal/diffusion_labs/models/vae/encoder_decoder.py index 3dad909c3..cdaa17823 100644 --- a/torchmultimodal/diffusion_labs/models/vae/encoder_decoder.py +++ b/torchmultimodal/diffusion_labs/models/vae/encoder_decoder.py @@ -85,9 +85,9 @@ def __init__( block_out, num_res_blocks, dropout, - needs_downsample=True - if level_idx != num_resolutions - 1 - else False, + needs_downsample=( + True if level_idx != num_resolutions - 1 else False + ), norm_groups=norm_groups, norm_eps=norm_eps, ) diff --git a/torchmultimodal/models/flava/model.py b/torchmultimodal/models/flava/model.py index 670b00305..46bc28bd0 100644 --- a/torchmultimodal/models/flava/model.py +++ b/torchmultimodal/models/flava/model.py @@ -195,20 +195,28 @@ def forward( if not skip_unmasked_mm_encoder: # Unmasked multimodal embedding is not currently used by any of the FLAVA losses. 
multimodal_outputs = self.encode_mm( - image_outputs.hidden_states[-1] # type: ignore - if image_outputs.hidden_states # type: ignore - else None, - text_outputs.hidden_states[-1] # type: ignore - if text_outputs.hidden_states # type: ignore - else None, + ( + image_outputs.hidden_states[-1] # type: ignore + if image_outputs.hidden_states # type: ignore + else None + ), + ( + text_outputs.hidden_states[-1] # type: ignore + if text_outputs.hidden_states # type: ignore + else None + ), ) multimodal_masked_outputs = self.encode_mm( - image_masked_outputs.hidden_states[-1] - if image_masked_outputs.hidden_states - else None, - text_masked_outputs.hidden_states[-1] - if text_masked_outputs.hidden_states - else None, + ( + image_masked_outputs.hidden_states[-1] + if image_masked_outputs.hidden_states + else None + ), + ( + text_masked_outputs.hidden_states[-1] + if text_masked_outputs.hidden_states + else None + ), ) return FLAVAOutput( @@ -266,9 +274,9 @@ def _encode_data_to_embeddings( Union[Tuple[TransformerOutput, Tensor], Optional[TransformerOutput]], ], ) -> Union[Tuple[TransformerOutput, Tensor], Optional[TransformerOutput]]: - output: Union[ - Tuple[TransformerOutput, Tensor], TransformerOutput - ] = TransformerOutput() + output: Union[Tuple[TransformerOutput, Tensor], TransformerOutput] = ( + TransformerOutput() + ) if data is not None and selected_head_encoder in encoder_options: output = encode_callable(data) @@ -355,9 +363,11 @@ def forward( text_sequence=flava_output.text.last_hidden_state, image_masked_sequence=flava_output.image_masked.last_hidden_state, text_masked_sequence=flava_output.text_masked.last_hidden_state, - multimodal_sequence=flava_output.multimodal.last_hidden_state - if not skip_unmasked_mm_encoder - else None, + multimodal_sequence=( + flava_output.multimodal.last_hidden_state + if not skip_unmasked_mm_encoder + else None + ), multimodal_masked_sequence=flava_output.multimodal_masked.last_hidden_state, itm_labels=itm_labels, mim_labels=image_labels, diff --git a/torchmultimodal/models/two_tower.py b/torchmultimodal/models/two_tower.py index 58dc97751..50bdcccb9 100644 --- a/torchmultimodal/models/two_tower.py +++ b/torchmultimodal/models/two_tower.py @@ -55,9 +55,9 @@ def __init__( raise ValueError( "Towers should be shared if channel mapping is passed in" ) - self.shared_tower_id_to_channel_mapping: Optional[ - Dict[str, Dict[str, str]] - ] = shared_tower_id_to_channel_mapping + self.shared_tower_id_to_channel_mapping: Optional[Dict[str, Dict[str, str]]] = ( + shared_tower_id_to_channel_mapping + ) def forward(self, channel_to_input: Dict[str, Tensor]) -> TwoTowerOutput: tower_embeddings = OrderedDict() diff --git a/torchmultimodal/modules/layers/transformer.py b/torchmultimodal/modules/layers/transformer.py index 0f02c7dd8..530a3b0e0 100644 --- a/torchmultimodal/modules/layers/transformer.py +++ b/torchmultimodal/modules/layers/transformer.py @@ -217,7 +217,6 @@ def forward( attention_mask: Optional[Tensor] = None, return_hidden_states: bool = False, ) -> TransformerOutput: - """ Args: hidden_states (Tensor): input to the transformer encoder of shape bsz x seq_len x d_model diff --git a/torchmultimodal/modules/losses/blip2_losses.py b/torchmultimodal/modules/losses/blip2_losses.py index 3bf0ecf10..6e66d85d3 100644 --- a/torchmultimodal/modules/losses/blip2_losses.py +++ b/torchmultimodal/modules/losses/blip2_losses.py @@ -309,7 +309,10 @@ def forward( # calculate similarities assert model_output.text_features is not None - (sim_i2t, sim_t2i,) = 
compute_image_text_similarity( + ( + sim_i2t, + sim_t2i, + ) = compute_image_text_similarity( model_output.image_features, model_output.text_features, temp=self.temp,
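
For context, the hunks above all apply one recurring Black 2024 stable-style rule: a multiline conditional expression used as a call argument, or a long assignment right-hand side, is wrapped in its own parentheses instead of hanging bare off the opening bracket. The sketch below is illustrative only; the names (lookup, answer, question_type) are hypothetical and not taken from the patch.

import torch

lookup = {"yes": 1}
answer, question_type = "yes", "attr"

# Pre-2024 style: the multiline conditional hangs directly inside the call.
value_old_style = torch.as_tensor(
    lookup[answer]
    if question_type == "attr"
    else -100,
    dtype=torch.long,
)

# 2024 style: the conditional gets its own wrapping parentheses, as in the hunks above.
value_new_style = torch.as_tensor(
    (
        lookup[answer]
        if question_type == "attr"
        else -100
    ),
    dtype=torch.long,
)

# Behavior is unchanged; only the formatting differs.
assert torch.equal(value_old_style, value_new_style)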