Skip to content

Commit

Permalink
updated the dataset
Browse files (browse the repository at this point in the history)
  • Loading branch information
farhadrgh committed Dec 11, 2024
1 parent 5fe66cd commit 2ed23cb
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/train.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -13,12 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import os
import tempfile
from pathlib import Path
from typing import Sequence, Tuple

import lightning.pytorch as pl
import pandas as pd
from lightning.pytorch.callbacks import Callback, RichModelSummary
from lightning.pytorch.loggers import TensorBoardLogger
from megatron.core.optimizer.optimizer_config import OptimizerConfig
Expand Down Expand Up @@ -149,7 +150,8 @@ def train_model(

if __name__ == "__main__":
# set the results directory
experiment_results_dir = tempfile.TemporaryDirectory().name
temp_dir = tempfile.mkdtemp()
print(f"Temporary directory created at: {temp_dir}")

# create a List[Tuple] with (sequence, target) values
artificial_sequence_data = [
Expand All @@ -165,9 +167,14 @@ def train_model(
"SGSKASSDSQDANQCCTSCEDNAPATSYCVECSEPLCETCVEAHQRVKYTKDHTVRSTGPAKT",
]
data = [(seq, len(seq) / 100.0) for seq in artificial_sequence_data]
# Create a DataFrame
df = pd.DataFrame(data, columns=["sequences", "labels"])
# Save the DataFrame to a CSV file
csv_file = os.path.join(temp_dir, "protein_dataset.csv")
df.to_csv(csv_file, index=False)

# we are training and validating on the same dataset for simplicity
dataset = InMemorySingleValueDataset(data)
dataset = InMemorySingleValueDataset(csv_file)
data_module = ESM2FineTuneDataModule(train_dataset=dataset, valid_dataset=dataset)

experiment_name = "finetune_regressor"
Expand All @@ -181,7 +188,7 @@ def train_model(

checkpoint, metrics, trainer = train_model(
experiment_name=experiment_name,
experiment_dir=Path(experiment_results_dir), # new checkpoint will land in a subdir of this
experiment_dir=Path(temp_dir), # new checkpoint will land in a subdir of this
config=config, # same config as before since we are just continuing training
data_module=data_module,
n_steps_train=n_steps_train,
Expand Down

0 comments on commit 2ed23cb

Please sign in to comment.