Skip to content

Commit

Permalink
Drop subprocess test and test the core logic instead (#540)
Browse files Browse the repository at this point in the history
  • Loading branch information
farhadrgh authored and polinabinder1 committed Dec 17, 2024
1 parent f0dcad2 commit 492c00c
Showing 1 changed file with 12 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,11 @@
# limitations under the License.

import glob
import os
import shlex
import subprocess
from pathlib import Path
from typing import get_args

import pandas as pd
import pytest
import torch
from lightning.fabric.plugins.environments.lightning import find_free_network_port
from torch.utils.data import DataLoader

from bionemo.core.data.load import load
Expand Down Expand Up @@ -147,7 +142,12 @@ def test_esm2_fine_tune_data_module_val_dataloader(data_module):
@pytest.mark.parametrize("prediction_interval", get_args(IntervalT))
@pytest.mark.skipif(check_gpu_memory(30), reason="Skipping test due to insufficient GPU memory")
def test_infer_runs(
tmpdir, dummy_protein_csv, dummy_protein_sequences, precision, padded_tokenized_sequences, prediction_interval
tmpdir,
dummy_protein_csv,
dummy_protein_sequences,
precision,
prediction_interval,
padded_tokenized_sequences,
):
data_path = dummy_protein_csv
result_dir = tmpdir / "results"
Expand Down Expand Up @@ -188,35 +188,9 @@ def test_infer_runs(
# token_logits are [sequence, batch, num_tokens]
assert results["token_logits"].shape[:-1] == (min_seq_len, len(dummy_protein_sequences))


@pytest.mark.skipif(check_gpu_memory(40), reason="Skipping test due to insufficient GPU memory")
@pytest.mark.parametrize("checkpoint_path", [esm2_3b_checkpoint_path, esm2_650m_checkpoint_path])
def test_infer_cli(tmpdir, dummy_protein_csv, checkpoint_path):
    """Run the ``infer_esm2`` command in a subprocess and check that it exits with status 0.

    The command is invoked with bf16-mixed precision and with hidden states, embeddings,
    logits and input ids all requested, writing to ``results/esm2_infer_results.pt`` under
    ``tmpdir``. Only the exit status is checked; the results file is not inspected.

    Args:
        tmpdir: pytest temporary directory; used as the working directory of the subprocess
            and as the parent of the ``results`` directory.
        dummy_protein_csv: Path passed to ``--data-path``.
        checkpoint_path: Path passed to ``--checkpoint-path`` (parametrized).
    """
    # Clear the GPU cache before starting the test
    torch.cuda.empty_cache()

    result_dir = Path(tmpdir.mkdir("results"))
    results_path = result_dir / "esm2_infer_results.pt"

    # Run the child with a copy of the current environment, with MASTER_PORT pinned to a
    # free port (presumably to avoid clashing with other distributed jobs on the host).
    env = dict(os.environ)
    env["MASTER_PORT"] = str(find_free_network_port())

    # Build argv directly instead of formatting a string and re-splitting it: paths that
    # contain whitespace or quote characters would otherwise be tokenised incorrectly.
    cmd = [
        "infer_esm2",
        "--checkpoint-path",
        str(checkpoint_path),
        "--data-path",
        str(dummy_protein_csv),
        "--results-path",
        str(results_path),
        "--precision",
        "bf16-mixed",
        "--include-hiddens",
        "--include-embeddings",
        "--include-logits",
        "--include-input-ids",
    ]

    result = subprocess.run(
        cmd,
        cwd=tmpdir,
        env=env,
        capture_output=True,
    )
    # Surface the child's stderr so a failure is diagnosable from the test report.
    assert result.returncode == 0, (
        f"Failed with: {shlex.join(cmd)}\n"
        f"stderr:\n{result.stderr.decode(errors='replace')}"
    )
# test 1:1 mapping between input sequence and results
# this does not apply to "batch" prediction_interval mode since the order of batches may not be consistent
# due to distributed processing. To address this, we optionally include input_ids in the predictions, allowing
# for accurate mapping post-inference.
if prediction_interval == "epoch":
assert torch.equal(padded_tokenized_sequences, results["input_ids"])

0 comments on commit 492c00c

Please sign in to comment.