fix blip2 gpu test (facebookresearch#502)
Summary:
Pull Request resolved: facebookresearch#502

As the title states:

1. Add an attn_mask fixture for the OSS unit test (a short usage sketch follows this list).
2. Remove the images fixture from both the internal and OSS tests, since it is not in use.
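
A minimal sketch of how a pytest fixture like the new attn_mask is consumed; the assertion-only test body below is illustrative and not part of this commit:

import pytest
import torch


class TestBLIP2Stage1Loss:
    @pytest.fixture
    def attn_mask(self):
        # 2 sequences x 4 tokens; 1.0 marks attended positions, 0.0 marks padding
        return torch.Tensor([[1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0]])

    def test_attn_mask_shape(self, attn_mask):
        # pytest injects the fixture by matching the argument name "attn_mask"
        assert attn_mask.shape == torch.Size([2, 4])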

Reviewed By: ebsmothers

Differential Revision: D50578298

fbshipit-source-id: d3f4e8b92cdb8d67eb9755fbef37678919a02b7e
Peng Chen authored and facebook-github-bot committed Oct 24, 2023
1 parent 49cf907 commit 759a8e4
Showing 1 changed file with 4 additions and 7 deletions.
tests/modules/losses/test_blip2_loss.py: 4 additions & 7 deletions
@@ -76,10 +76,6 @@ def vit():
 
 
 class TestBLIP2Stage1Loss:
-    @pytest.fixture
-    def images(self):
-        return torch.ones(4, 3, 2, 2)
-
     @pytest.fixture
     def input_ids(self):
         return torch.ones(4, 4).long()
@@ -139,6 +135,10 @@ def blip2(self, dim_q, dim_kv, qformer_model_for_clm, vit):
         blip2.eval()
         return blip2
 
+    @pytest.fixture
+    def attn_mask(self):
+        return torch.Tensor([[1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0]])
+
     def test_local_loss(self, all_attn_mask, blip2_output, blip2, dim_q, input_ids):
         blip2_loss = Blip2Phase1Loss(dim_q=dim_q)
         init_weights_with_constant(blip2_loss)
@@ -201,7 +201,6 @@ def _model_worker(
     sync_file: str,
     world_size: int,
     global_batch_size: int,
-    all_images: torch.Tensor,
     all_input_ids: torch.Tensor,
     all_attn_mask: torch.Tensor,
     blip2_output: Blip2Output,
@@ -216,7 +215,6 @@ def _model_worker(
     all_attn_mask = torch.ones([4, 4])
 
     # Split inputs across GPUs
-    local_images = torch.split(all_images, local_batch_size)[gpu_id].cuda(gpu_id)
     local_input_ids = torch.split(all_input_ids, local_batch_size)[gpu_id].cuda(
         gpu_id
     )
@@ -256,7 +254,6 @@ def _model_worker(
     loss = loss_fn(
         model_output=local_blip2_output,
         blip2=blip2,
-        images=local_images,
         input_ids=local_input_ids,
         attention_mask=local_attn_mask,
     ).total_loss