Skip to content

Commit

Permalink
random seed fixed at the beginning of each example python script (#277)
Browse files Browse the repository at this point in the history
  • Loading branch information
allaffa authored Sep 6, 2024
1 parent a92b460 commit 3a5bdf9
Show file tree
Hide file tree
Showing 16 changed files with 85 additions and 27 deletions.
12 changes: 8 additions & 4 deletions examples/ani1_x/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@
from mpi4py import MPI
import argparse

import glob
import numpy as np

import random
import numpy as np

import torch
from torch import tensor
from torch_geometric.data import Data

# Fix the seed of every global RNG this script has imported so that runs are
# reproducible. torch.manual_seed() only seeds torch's generator;
# random.shuffle() draws from Python's `random` module and NumPy keeps its own
# global state, so both are seeded explicitly as well.
random_state = 0
random.seed(random_state)
np.random.seed(random_state)
torch.manual_seed(random_state)

from torch_geometric.data import Data
from torch_geometric.transforms import Distance, Spherical, LocalCartesian

import hydragnn
Expand Down Expand Up @@ -132,6 +134,8 @@ def convert_trajectories_to_graphs(self):
flush=True,
)

random.shuffle(self.dataset)

def iter_data_buckets(self, h5filename, keys=["wb97x_dz.energy"]):
"""Iterate over buckets of data in ANI HDF5 file.
Yields dicts with atomic numbers (shape [Na,]), coordinates (shape [Nc, Na, 3])
Expand Down
8 changes: 3 additions & 5 deletions examples/csce/train_gap.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,8 @@

import logging
import sys
from tqdm import tqdm
from mpi4py import MPI
from itertools import chain
import argparse
import time

import hydragnn
from hydragnn.utils.print_utils import print_distributed, iterate_tqdm, log
Expand All @@ -35,10 +32,11 @@
except ImportError:
pass

import torch_geometric.data
import torch
import torch.distributed as dist

# Fix the seed of torch's global RNG so that runs are reproducible.
# NOTE(review): only torch is seeded here; Python's `random` and NumPy keep
# their default (unseeded) state - confirm nothing in this script relies on them.
random_state = 0
torch.manual_seed(random_state)

csce_node_types = {"C": 0, "F": 1, "H": 2, "N": 3, "O": 4, "S": 5}

Expand Down
4 changes: 4 additions & 0 deletions examples/ising_model/train_ising.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@
import torch
import torch.distributed as dist

# Fix the seed of torch's global RNG so that runs are reproducible.
# NOTE(review): only torch is seeded here; Python's `random` and NumPy keep
# their default (unseeded) state - confirm nothing in this script relies on them.
random_state = 0
torch.manual_seed(random_state)

import warnings

## For create_configurations
Expand Down
5 changes: 4 additions & 1 deletion examples/lsms/lsms.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from hydragnn.utils.lsmsdataset import LSMSDataset
from hydragnn.utils.serializeddataset import SerializedWriter, SerializedDataset
from hydragnn.preprocess.load_data import split_dataset
from hydragnn.utils.print_utils import log

try:
from hydragnn.utils.adiosdataset import AdiosWriter, AdiosDataset
Expand All @@ -21,6 +20,10 @@
import torch
import torch.distributed as dist

# Fix the seed of torch's global RNG so that runs are reproducible.
# NOTE(review): only torch is seeded here; Python's `random` and NumPy keep
# their default (unseeded) state - confirm nothing this script calls relies on
# them (e.g. shuffling done with `random.shuffle`).
random_state = 0
torch.manual_seed(random_state)


def info(*args, logtype="info", sep=" "):
    """Log the given values as one message through the ``logging`` module.

    Args:
        *args: Values to log; each is converted with ``str`` before joining.
        logtype: Name of the module-level ``logging`` function to call
            (for example ``"info"``, ``"debug"`` or ``"warning"``).
        sep: Separator placed between the stringified values.
    """
    getattr(logging, logtype)(sep.join(map(str, args)))
Expand Down
5 changes: 5 additions & 0 deletions examples/md17/md17.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import os, json

import torch

# Fix the seed of torch's global RNG so that runs are reproducible.
# NOTE(review): only torch is seeded here; Python's `random` and NumPy keep
# their default (unseeded) state - confirm nothing in this script relies on them.
random_state = 0
torch.manual_seed(random_state)

import torch_geometric

# deprecated in torch_geometric 2.0
Expand Down
11 changes: 7 additions & 4 deletions examples/mptrj/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
from mpi4py import MPI
import argparse

import glob

import random
import numpy as np

import torch
from torch import tensor

# Fix the seed of every global RNG this script has imported so that runs are
# reproducible. torch.manual_seed() only seeds torch's generator;
# random.shuffle() draws from Python's `random` module and NumPy keeps its own
# global state, so both are seeded explicitly as well.
random_state = 0
random.seed(random_state)
np.random.seed(random_state)
torch.manual_seed(random_state)

from torch_geometric.data import Data

from torch_geometric.transforms import Distance, Spherical, LocalCartesian
Expand Down Expand Up @@ -165,6 +166,8 @@ def __init__(
flush=True,
)

random.shuffle(self.dataset)

def check_forces_values(self, forces):

# Calculate the L2 norm for each row
Expand Down
5 changes: 5 additions & 0 deletions examples/multidataset/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
import argparse

import torch

# Fix the seed of torch's global RNG so that runs are reproducible.
# NOTE(review): only torch is seeded here; Python's `random` and NumPy keep
# their default (unseeded) state - confirm nothing in this script relies on them.
random_state = 0
torch.manual_seed(random_state)

import numpy as np

import hydragnn
Expand Down
4 changes: 4 additions & 0 deletions examples/multidataset_hpo/gfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
## FIXME
torch.backends.cudnn.enabled = False

# Fix the seed of torch's global RNG so that runs are reproducible.
# NOTE(review): only torch is seeded here; Python's `random` and NumPy keep
# their default (unseeded) state - confirm nothing in this script relies on them.
random_state = 0
torch.manual_seed(random_state)


def info(*args, logtype="info", sep=" "):
    """Log the given values as one message through the ``logging`` module.

    Args:
        *args: Values to log; each is converted with ``str`` before joining.
        logtype: Name of the module-level ``logging`` function to call
            (for example ``"info"``, ``"debug"`` or ``"warning"``).
        sep: Separator placed between the stringified values.
    """
    getattr(logging, logtype)(sep.join(map(str, args)))
Expand Down
4 changes: 4 additions & 0 deletions examples/multidataset_hpo/gfm_deephyper_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

torch.backends.cudnn.enabled = False

# Fix the seed of torch's global RNG so that runs are reproducible.
# NOTE(review): only torch is seeded here; Python's `random` and NumPy keep
# their default (unseeded) state - confirm nothing in this script relies on them.
random_state = 0
torch.manual_seed(random_state)

# deprecated in torch_geometric 2.0
try:
from torch_geometric.loader import DataLoader
Expand Down
6 changes: 5 additions & 1 deletion examples/multidataset_hpo/gfm_deephyper_multi_perlmutter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

torch.backends.cudnn.enabled = False

# Fix the seed of torch's global RNG so that runs are reproducible.
# `random_state` is also passed as the `random_state` keyword argument further
# down in this script, so changing it here changes that value too.
# NOTE(review): only torch is seeded here; Python's `random` and NumPy keep
# their default (unseeded) state - confirm nothing in this script relies on them.
random_state = 0
torch.manual_seed(random_state)

# deprecated in torch_geometric 2.0
try:
from torch_geometric.loader import DataLoader
Expand Down Expand Up @@ -163,7 +167,7 @@ def run(trial, dequed=None):
evaluator,
acq_func="UCB",
multi_point_strategy="cl_min", # Constant liar strategy
random_state=42,
random_state=random_state,
# Location where to store the results
log_dir=log_name,
# Number of threads used to update surrogate model of BO
Expand Down
4 changes: 4 additions & 0 deletions examples/ogb/train_gap.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@
import torch
import torch.distributed as dist

# Fix the seed of torch's global RNG so that runs are reproducible.
# NOTE(review): only torch is seeded here; Python's `random` and NumPy keep
# their default (unseeded) state - confirm nothing in this script relies on them.
random_state = 0
torch.manual_seed(random_state)

deepspeed_available = True
try:
import deepspeed
Expand Down
11 changes: 6 additions & 5 deletions examples/open_catalyst_2020/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@
from mpi4py import MPI
import argparse

import glob

import random
import numpy as np

import torch
from torch import tensor
from torch_geometric.data import Data

# Fix the seed of every global RNG this script has imported so that runs are
# reproducible. torch.manual_seed() only seeds torch's generator;
# random.shuffle() draws from Python's `random` module and NumPy keeps its own
# global state, so both are seeded explicitly as well.
random_state = 0
random.seed(random_state)
np.random.seed(random_state)
torch.manual_seed(random_state)

import hydragnn
from hydragnn.utils.time_utils import Timer
Expand Down Expand Up @@ -105,6 +104,8 @@ def __init__(
flush=True,
)

random.shuffle(self.dataset)

def check_forces_values(self, forces):

# Calculate the L2 norm for each row
Expand Down
11 changes: 8 additions & 3 deletions examples/open_catalyst_2022/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@
from mpi4py import MPI
import argparse

import glob
import numpy as np

import random
import numpy as np

import torch
from torch import tensor

# Fix the seed of every global RNG this script has imported so that runs are
# reproducible. torch.manual_seed() only seeds torch's generator;
# random.shuffle() draws from Python's `random` module and NumPy keeps its own
# global state, so both are seeded explicitly as well.
random_state = 0
random.seed(random_state)
np.random.seed(random_state)
torch.manual_seed(random_state)

from torch_geometric.data import Data

from torch_geometric.transforms import Distance, Spherical, LocalCartesian
Expand Down Expand Up @@ -112,6 +115,8 @@ def __init__(
flush=True,
)

random.shuffle(self.dataset)

def ase_to_torch_geom(self, atoms):
# set the atomic numbers, positions, and cell
atomic_numbers = torch.Tensor(atoms.get_atomic_numbers()).unsqueeze(1)
Expand Down
13 changes: 9 additions & 4 deletions examples/qm7x/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,18 @@

import numpy as np

from torch_geometric.data import Data
from torch_geometric.transforms import RadiusGraph, Distance
import torch

# Fix the seed of every global RNG used by this script so that runs are
# reproducible. torch.manual_seed() only seeds torch's generator;
# random.shuffle() draws from Python's `random` module, so without an explicit
# random.seed() its order changes between runs (and between processes).
# NumPy keeps its own global state and is seeded as well.
import random

random_state = 0
random.seed(random_state)
np.random.seed(random_state)
torch.manual_seed(random_state)

import torch.distributed as dist

from torch_geometric.data import Data
from torch_geometric.transforms import RadiusGraph, Distance


try:
from hydragnn.utils.adiosdataset import AdiosWriter, AdiosDataset
except ImportError:
Expand Down Expand Up @@ -118,8 +125,6 @@ def read_setids(self, dirpath, setids_files):
mol_ids = list(fMOL.keys())

if self.dist:
## Random shuffle dirlist to avoid the same test/validation set
random.seed(43)
random.shuffle(mol_ids)

x = torch.tensor(len(mol_ids), requires_grad=False).to(get_device())
Expand Down
5 changes: 5 additions & 0 deletions examples/qm9/qm9.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import os, json

import torch

# Fix the seed of torch's global RNG so that runs are reproducible.
# NOTE(review): only torch is seeded here; Python's `random` and NumPy keep
# their default (unseeded) state - confirm nothing in this script relies on them.
random_state = 0
torch.manual_seed(random_state)

import torch_geometric

# deprecated in torch_geometric 2.0
Expand Down
4 changes: 4 additions & 0 deletions hydragnn/preprocess/load_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import os
import socket

import random

import torch
import torch.distributed as dist
import torch_geometric
Expand Down Expand Up @@ -304,7 +306,9 @@ def split_dataset(
):
if not stratify_splitting:
perc_val = (1 - perc_train) / 2
dataset = list(dataset)
data_size = len(dataset)
random.shuffle(dataset)
trainset = dataset[: int(data_size * perc_train)]
valset = dataset[
int(data_size * perc_train) : int(data_size * (perc_train + perc_val))
Expand Down

0 comments on commit 3a5bdf9

Please sign in to comment.