apply Black 2024 style in fbcode (4/16)
Summary:
Formats the covered files with pyfmt.

paintitblack

Reviewed By: aleivag

Differential Revision: D54447727

fbshipit-source-id: 8844b1caa08de94d04ac4df3c768dbf8c865fd2f
amyreese authored and facebook-github-bot committed Mar 3, 2024
1 parent c261d71 commit dbeed97
Showing 13 changed files with 85 additions and 57 deletions.
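The hunks below are mechanical rewrites under Black's 2024 stable style (the default in open-source Black from release 24.1.0 onward): conditional expressions that span multiple lines gain their own parentheses, long right-hand sides and long inline comments are split behind parentheses, redundant parentheses around for-loop targets are dropped, and a stray blank line at the top of a block is removed. A minimal sketch of the two most common rewrites, using hypothetical names rather than code from this diff:

# Hypothetical example, not taken from this commit: the two rewrites that
# account for most hunks below, as Black >= 24.1.0 emits them once a line
# exceeds the length limit.

condition_is_met = True
token_spans = [(0, 3), (4, 7)]

# 2024 style: a conditional expression that has to be split across lines is
# wrapped in its own parentheses (the pre-2024 style left it bare, which is
# what the removed lines below show).
selected_value = (
    "a deliberately long stand-in for the true branch of the conditional expression"
    if condition_is_met
    else None
)

# The commit also drops redundant parentheses around for-loop targets,
# turning "for (beg, end) in token_spans:" into:
for beg, end in token_spans:
    print(beg, end)
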
8 changes: 5 additions & 3 deletions examples/albef/data/transforms.py
@@ -69,9 +69,11 @@ def __init__(
         self.transform = Sequential(
             Truncate(max_seq_len=max_seq_len) if truncate else torch.nn.Identity(),
             ToTensor(padding_value=self.pad_token_id),
-            PadTransform(max_length=max_seq_len, pad_value=self.pad_token_id)
-            if pad_to_max_seq_len
-            else torch.nn.Identity(),
+            (
+                PadTransform(max_length=max_seq_len, pad_value=self.pad_token_id)
+                if pad_to_max_seq_len
+                else torch.nn.Identity()
+            ),
         )

     def pre_process(self, text: str) -> str:

48 changes: 30 additions & 18 deletions examples/mdetr/data/dataset.py
@@ -101,9 +101,11 @@ def __getitem__(self, idx):
                 self.type_to_id[coco_img["question_type"]], dtype=torch.long
             )
             target["answer_type_mask"] = {
-                f"answer_{k}": torch.BoolTensor([True])
-                if coco_img["question_type"] == k
-                else torch.BoolTensor([False])
+                f"answer_{k}": (
+                    torch.BoolTensor([True])
+                    if coco_img["question_type"] == k
+                    else torch.BoolTensor([False])
+                )
                 for k in self.type_to_id.keys()
             }
             target["answer_type_mask"]["answer_type"] = torch.BoolTensor([True])
@@ -113,9 +115,11 @@ def __getitem__(self, idx):
             else:
                 answer = coco_img["answer"]
             target["answer_attr"] = torch.as_tensor(
-                self.answer2id_by_type["answer_attr"][answer]
-                if coco_img["question_type"] == "attr"
-                else -100,
+                (
+                    self.answer2id_by_type["answer_attr"][answer]
+                    if coco_img["question_type"] == "attr"
+                    else -100
+                ),
                 dtype=torch.long,
             )

@@ -124,9 +128,11 @@ def __getitem__(self, idx):
             else:
                 answer = coco_img["answer"]
             target["answer_global"] = torch.as_tensor(
-                self.answer2id_by_type["answer_global"][answer]
-                if coco_img["question_type"] == "global"
-                else -100,
+                (
+                    self.answer2id_by_type["answer_global"][answer]
+                    if coco_img["question_type"] == "global"
+                    else -100
+                ),
                 dtype=torch.long,
             )

@@ -135,9 +141,11 @@ def __getitem__(self, idx):
             else:
                 answer = coco_img["answer"]
             target["answer_rel"] = torch.as_tensor(
-                self.answer2id_by_type["answer_rel"][answer]
-                if coco_img["question_type"] == "rel"
-                else -100,
+                (
+                    self.answer2id_by_type["answer_rel"][answer]
+                    if coco_img["question_type"] == "rel"
+                    else -100
+                ),
                 dtype=torch.long,
             )

@@ -146,9 +154,11 @@ def __getitem__(self, idx):
             else:
                 answer = coco_img["answer"]
             target["answer_cat"] = torch.as_tensor(
-                self.answer2id_by_type["answer_cat"][answer]
-                if coco_img["question_type"] == "cat"
-                else -100,
+                (
+                    self.answer2id_by_type["answer_cat"][answer]
+                    if coco_img["question_type"] == "cat"
+                    else -100
+                ),
                 dtype=torch.long,
             )

@@ -157,9 +167,11 @@ def __getitem__(self, idx):
             else:
                 answer = coco_img["answer"]
             target["answer_obj"] = torch.as_tensor(
-                self.answer2id_by_type["answer_obj"][answer]
-                if coco_img["question_type"] == "obj"
-                else -100,
+                (
+                    self.answer2id_by_type["answer_obj"][answer]
+                    if coco_img["question_type"] == "obj"
+                    else -100
+                ),
                 dtype=torch.long,
             )
         return img, target

2 changes: 1 addition & 1 deletion examples/mdetr/data/transforms.py
@@ -334,7 +334,7 @@ def create_positive_map(tokenized, tokens_positive):
     """construct a map such that positive_map[i,j] = True iff box i is associated to token j"""
     positive_map = torch.zeros((len(tokens_positive), 256), dtype=torch.float)
     for j, tok_list in enumerate(tokens_positive):
-        for (beg, end) in tok_list:
+        for beg, end in tok_list:
             beg_pos = tokenized.char_to_token(beg)
             end_pos = tokenized.char_to_token(end - 1)
             if beg_pos is None:

2 changes: 1 addition & 1 deletion examples/mdetr/loss.py
@@ -110,7 +110,7 @@ def construct_positive_map(
     for i, ((idx_src, idx_tgt), tgt) in enumerate(zip(indices, target_tokens)):
         cur_tokens = [tgt[j] for j in idx_tgt]
         for j, tok_list in enumerate(cur_tokens):
-            for (beg, end) in tok_list:
+            for beg, end in tok_list:
                 beg_pos = char_to_token(tokenized, i, beg)
                 end_pos = char_to_token(tokenized, i, end - 1)

6 changes: 3 additions & 3 deletions examples/mdetr/tests/test_postprocessors.py
@@ -94,9 +94,9 @@ def batched_pos_map(self, n_tokens, n_classes, pos_map):
         batched_pos_map = torch.zeros((n_boxes, n_classes + 1), dtype=torch.bool)
         cur_count = 0
         for sample in pos_map:
-            batched_pos_map[
-                cur_count : cur_count + len(sample), : sample.shape[1]
-            ] = sample
+            batched_pos_map[cur_count : cur_count + len(sample), : sample.shape[1]] = (
+                sample
+            )
             cur_count += len(sample)
         assert cur_count == len(batched_pos_map)

4 changes: 3 additions & 1 deletion examples/mugen/data/mugen_dataset.py
@@ -38,7 +38,9 @@ class MUGENDatasetArgs:
         True  # render smap for mugen (and shield) as bounding boxes
     )
     bbox_smap_for_monsters: bool = True  # render smap for monsters as bounding boxes
-    use_manual_annotation: bool = False  # if True will only use videos with manual annotation and skip those without
+    use_manual_annotation: bool = (
+        False  # if True will only use videos with manual annotation and skip those without
+    )
     use_auto_annotation: bool = (
         True  # if True will only use videos with auto annotation and skip those without
     )

4 changes: 3 additions & 1 deletion examples/mugen/retrieval/definitions.py
@@ -81,7 +81,9 @@ class EvaluationArgs:
     datamodule_args: DataModuleArgs = DataModuleArgs()
     lightningmodule_args: LightningModuleArgs = LightningModuleArgs()
     videoclip_args: VideoCLIPArgs = VideoCLIPArgs()
-    checkpoint_path: str = "https://pytorch.s3.amazonaws.com/models/multimodal/mugen/videoclip_lightning_mugen.pt"
+    checkpoint_path: str = (
+        "https://pytorch.s3.amazonaws.com/models/multimodal/mugen/videoclip_lightning_mugen.pt"
+    )
     accelerator: str = "auto"

4 changes: 1 addition & 3 deletions tests/diffusion_labs/test_dalle2.py
@@ -46,7 +46,5 @@ def test_dalle2_image_transform():
     actual = transform({"x": image})["x"].sum()
     normalized128 = 128 / 255 * 2 - 1
     normalized0 = -1
-    expected = torch.tensor(
-        normalized128 * img_size**2 + 2 * normalized0 * img_size**2
-    )
+    expected = torch.tensor(normalized128 * img_size**2 + 2 * normalized0 * img_size**2)
     assert_expected(actual, expected, rtol=0, atol=1e-4)

6 changes: 3 additions & 3 deletions torchmultimodal/diffusion_labs/models/vae/encoder_decoder.py
@@ -85,9 +85,9 @@ def __init__(
                     block_out,
                     num_res_blocks,
                     dropout,
-                    needs_downsample=True
-                    if level_idx != num_resolutions - 1
-                    else False,
+                    needs_downsample=(
+                        True if level_idx != num_resolutions - 1 else False
+                    ),
                     norm_groups=norm_groups,
                     norm_eps=norm_eps,
                 )

46 changes: 28 additions & 18 deletions torchmultimodal/models/flava/model.py
@@ -195,20 +195,28 @@ def forward(
         if not skip_unmasked_mm_encoder:
             # Unmasked multimodal embedding is not currently used by any of the FLAVA losses.
             multimodal_outputs = self.encode_mm(
-                image_outputs.hidden_states[-1]  # type: ignore
-                if image_outputs.hidden_states  # type: ignore
-                else None,
-                text_outputs.hidden_states[-1]  # type: ignore
-                if text_outputs.hidden_states  # type: ignore
-                else None,
+                (
+                    image_outputs.hidden_states[-1]  # type: ignore
+                    if image_outputs.hidden_states  # type: ignore
+                    else None
+                ),
+                (
+                    text_outputs.hidden_states[-1]  # type: ignore
+                    if text_outputs.hidden_states  # type: ignore
+                    else None
+                ),
             )
         multimodal_masked_outputs = self.encode_mm(
-            image_masked_outputs.hidden_states[-1]
-            if image_masked_outputs.hidden_states
-            else None,
-            text_masked_outputs.hidden_states[-1]
-            if text_masked_outputs.hidden_states
-            else None,
+            (
+                image_masked_outputs.hidden_states[-1]
+                if image_masked_outputs.hidden_states
+                else None
+            ),
+            (
+                text_masked_outputs.hidden_states[-1]
+                if text_masked_outputs.hidden_states
+                else None
+            ),
         )

         return FLAVAOutput(
@@ -266,9 +274,9 @@ def _encode_data_to_embeddings(
             Union[Tuple[TransformerOutput, Tensor], Optional[TransformerOutput]],
         ],
     ) -> Union[Tuple[TransformerOutput, Tensor], Optional[TransformerOutput]]:
-        output: Union[
-            Tuple[TransformerOutput, Tensor], TransformerOutput
-        ] = TransformerOutput()
+        output: Union[Tuple[TransformerOutput, Tensor], TransformerOutput] = (
+            TransformerOutput()
+        )

         if data is not None and selected_head_encoder in encoder_options:
             output = encode_callable(data)
@@ -355,9 +363,11 @@ def forward(
             text_sequence=flava_output.text.last_hidden_state,
             image_masked_sequence=flava_output.image_masked.last_hidden_state,
             text_masked_sequence=flava_output.text_masked.last_hidden_state,
-            multimodal_sequence=flava_output.multimodal.last_hidden_state
-            if not skip_unmasked_mm_encoder
-            else None,
+            multimodal_sequence=(
+                flava_output.multimodal.last_hidden_state
+                if not skip_unmasked_mm_encoder
+                else None
+            ),
             multimodal_masked_sequence=flava_output.multimodal_masked.last_hidden_state,
             itm_labels=itm_labels,
             mim_labels=image_labels,

6 changes: 3 additions & 3 deletions torchmultimodal/models/two_tower.py
@@ -55,9 +55,9 @@ def __init__(
             raise ValueError(
                 "Towers should be shared if channel mapping is passed in"
             )
-        self.shared_tower_id_to_channel_mapping: Optional[
-            Dict[str, Dict[str, str]]
-        ] = shared_tower_id_to_channel_mapping
+        self.shared_tower_id_to_channel_mapping: Optional[Dict[str, Dict[str, str]]] = (
+            shared_tower_id_to_channel_mapping
+        )

     def forward(self, channel_to_input: Dict[str, Tensor]) -> TwoTowerOutput:
         tower_embeddings = OrderedDict()

1 change: 0 additions & 1 deletion torchmultimodal/modules/layers/transformer.py
@@ -217,7 +217,6 @@ def forward(
         attention_mask: Optional[Tensor] = None,
         return_hidden_states: bool = False,
     ) -> TransformerOutput:
-
         """
         Args:
             hidden_states (Tensor): input to the transformer encoder of shape bsz x seq_len x d_model

5 changes: 4 additions & 1 deletion torchmultimodal/modules/losses/blip2_losses.py
@@ -309,7 +309,10 @@ def forward(

         # calculate similarities
         assert model_output.text_features is not None
-        (sim_i2t, sim_t2i,) = compute_image_text_similarity(
+        (
+            sim_i2t,
+            sim_t2i,
+        ) = compute_image_text_similarity(
             model_output.image_features,
             model_output.text_features,
             temp=self.temp,
