Skip to content

Commit

Permalink
Fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
melihyilmaz committed Feb 13, 2024
1 parent f53e90b commit c040ccd
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions tests/unit_tests/test_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ def test_beam_search_decode():

# Test _get_topk_beams() with finished beams in the batch.
model = Spec2Pep(n_beams=1, residues="massivekb", min_peptide_len=3)

# Sizes and other variables.
batch = 2 # B
beam = model.n_beams # S
Expand All @@ -440,33 +440,33 @@ def test_beam_search_decode():
size=(batch, length, vocab, beam), fill_value=torch.nan
)
scores = einops.rearrange(scores, "B L V S -> (B S) L V")
tokens = torch.zeros(batch * beam, length, dtype=torch.int64)
tokens = torch.zeros(batch * beam, length, dtype=torch.int64)

# Simulate non-zero amino acid-level probability scores.
scores[:, :step+1, :] = torch.rand(batch, step+1, vocab)
scores[:, : step + 1, :] = torch.rand(batch, step + 1, vocab)
scores[:, step, range(1, 4)] = torch.tensor([1.0, 2.0, 3.0])
# Simulate one finished and one unfinished beam in the same batch
tokens[0,:step] = torch.tensor([4, 14, 4, 28])
tokens[1,:step] = torch.tensor([4, 14, 4, 1])
# Set finished beams array to allow decoding from only one beam

# Simulate one finished and one unfinished beam in the same batch.
tokens[0, :step] = torch.tensor([4, 14, 4, 28])
tokens[1, :step] = torch.tensor([4, 14, 4, 1])

# Set finished beams array to allow decoding from only one beam.
test_finished_beams = torch.tensor([True, False])

new_tokens, new_scores = model._get_topk_beams(
tokens, scores, test_finished_beams, batch, step
)
# Only the second peptide should have a new token predicted

# Only the second peptide should have a new token predicted.
expected_tokens = torch.tensor(
[
[4, 14, 4, 28, 0],
[4, 14, 4, 1, 3],

]
)

assert torch.equal(new_tokens[:, : step + 1], expected_tokens)

assert torch.equal(new_tokens[:, : step + 1], expected_tokens)


def test_eval_metrics():
"""
Expand Down

0 comments on commit c040ccd

Please sign in to comment.