utils renamed and black formatting applied
allaffa committed Aug 19, 2024
1 parent a92b460 · commit ed1c31b
Showing 74 changed files with 522 additions and 475 deletions.
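
The four diffs reproduced below follow two patterns repeated across the commit: utility modules move into topical subpackages (hydragnn.utils.datasets, hydragnn.utils.profiling_and_tracing, hydragnn.utils.print, hydragnn.utils.descriptors_and_embeddings, hydragnn.preprocess.graph_samples_checks_and_updates), and the example data directory dataset/ is renamed to datasets/. As a quick reference, the sketch below maps the old import paths to the new ones, compiled only from the hunks shown on this page; the remaining files in the commit may contain further renames not listed here.

import hydragnn.utils.profiling_and_tracing.tracer as tr                      # was hydragnn.utils.tracer
from hydragnn.utils.profiling_and_tracing.time_utils import Timer             # was hydragnn.utils.time_utils
from hydragnn.utils.datasets.abstractbasedataset import AbstractBaseDataset   # was hydragnn.utils.abstractbasedataset
from hydragnn.utils.datasets.distdataset import DistDataset                   # was hydragnn.utils.distdataset
from hydragnn.utils.datasets.pickledataset import SimplePickleDataset         # was hydragnn.utils.pickledataset
from hydragnn.utils.distributed import nsplit                                 # was hydragnn.utils
from hydragnn.utils.print.print_utils import iterate_tqdm, log                # was hydragnn.utils.print_utils
from hydragnn.utils.descriptors_and_embeddings.smiles_utils import (          # was hydragnn.utils.smiles_utils
    generate_graphdata_from_smilestr,
    get_node_attribute_name,
)
from hydragnn.preprocess.graph_samples_checks_and_updates import (            # was hydragnn.preprocess.utils
    gather_deg,
    RadiusGraph,
    RadiusGraphPBC,
)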
examples/alexandria/find_json_files.py (2 changes: 1 addition & 1 deletion)
@@ -24,7 +24,7 @@ def find_json_files(url):

url_root = "https://alexandria.icams.rub.de/data" # Replace with the actual URL

dirpath = "dataset/compressed_data"
dirpath = "datasets/compressed_data"

if os.path.exists(dirpath) and os.path.isdir(dirpath):
shutil.rmtree(dirpath)
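
The unchanged context lines of this hunk delete any stale download directory before the JSON files are re-fetched. Below is a minimal sketch of that reset pattern against the renamed path; the final os.makedirs call is an assumption added for completeness, since the re-creation step falls outside the hunk shown.

import os
import shutil

dirpath = "datasets/compressed_data"  # renamed from "dataset/compressed_data" in this commit

# Remove a stale copy of the download directory, as in the context lines above.
if os.path.exists(dirpath) and os.path.isdir(dirpath):
    shutil.rmtree(dirpath)

# Re-create the empty directory before downloading; this step is assumed and
# does not appear in the hunk shown.
os.makedirs(dirpath, exist_ok=True)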
examples/alexandria/train.py (42 changes: 24 additions & 18 deletions)
@@ -15,17 +15,23 @@
from torch_geometric.transforms import Distance, Spherical, LocalCartesian

import hydragnn
from hydragnn.utils.time_utils import Timer
from hydragnn.utils.profiling_and_tracing.time_utils import Timer
from hydragnn.utils.model import print_model
from hydragnn.utils.abstractbasedataset import AbstractBaseDataset
from hydragnn.utils.distdataset import DistDataset
from hydragnn.utils.pickledataset import SimplePickleWriter, SimplePickleDataset
from hydragnn.preprocess.utils import gather_deg
from hydragnn.preprocess.utils import RadiusGraph, RadiusGraphPBC
from hydragnn.utils.datasets.abstractbasedataset import AbstractBaseDataset
from hydragnn.utils.datasets.distdataset import DistDataset
from hydragnn.utils.datasets.pickledataset import (
SimplePickleWriter,
SimplePickleDataset,
)
from hydragnn.preprocess.graph_samples_checks_and_updates import gather_deg
from hydragnn.preprocess.graph_samples_checks_and_updates import (
RadiusGraph,
RadiusGraphPBC,
)
from hydragnn.preprocess.load_data import split_dataset

import hydragnn.utils.tracer as tr
from hydragnn.utils.print_utils import iterate_tqdm, log
import hydragnn.utils.profiling_and_tracing.tracer as tr
from hydragnn.utils.print.print_utils import iterate_tqdm, log

from generate_dictionaries_pure_elements import (
generate_dictionary_bulk_energies,
@@ -38,7 +44,7 @@
pass

import subprocess
from hydragnn.utils import nsplit
from hydragnn.utils.distributed import nsplit


def info(*args, logtype="info", sep=" "):
@@ -244,7 +250,7 @@ def get_magmoms_array_from_structure(structure):

def process_file_content(self, filepath):
"""
Download a file from a dataset of the Alexandria database with the respective index
Download a file from a datasets of the Alexandria database with the respective index
and write it to the LMDB file with the respective index.
Parameters
@@ -311,7 +317,7 @@ def get(self, idx):
type=bool,
default=True,
)
parser.add_argument("--ddstore", action="store_true", help="ddstore dataset")
parser.add_argument("--ddstore", action="store_true", help="ddstore datasets")
parser.add_argument("--ddstore_width", type=int, help="ddstore width", default=None)
parser.add_argument("--shmem", action="store_true", help="shmem")
parser.add_argument("--log", help="log name")
@@ -321,14 +327,14 @@ def get(self, idx):
group = parser.add_mutually_exclusive_group()
group.add_argument(
"--adios",
help="Adios dataset",
help="Adios datasets",
action="store_const",
dest="format",
const="adios",
)
group.add_argument(
"--pickle",
help="Pickle dataset",
help="Pickle datasets",
action="store_const",
dest="format",
const="pickle",
@@ -341,7 +347,7 @@ def get(self, idx):
node_feature_names = ["atomic_number", "cartesian_coordinates", "forces"]
node_feature_dims = [1, 3, 3]
dirpwd = os.path.dirname(os.path.abspath(__file__))
datadir = os.path.join(dirpwd, "dataset")
datadir = os.path.join(dirpwd, "datasets")
##################################################################################################################
input_filename = os.path.join(dirpwd, args.inputfile)
##################################################################################################################
@@ -403,7 +409,7 @@ def get(self, idx):
## adios
if args.format == "adios":
fname = os.path.join(
os.path.dirname(__file__), "./dataset/%s.bp" % modelname
os.path.dirname(__file__), "./datasets/%s.bp" % modelname
)
adwriter = AdiosWriter(fname, comm)
adwriter.add("trainset", trainset)
@@ -417,7 +423,7 @@ def get(self, idx):
## pickle
elif args.format == "pickle":
basedir = os.path.join(
os.path.dirname(__file__), "dataset", "%s.pickle" % modelname
os.path.dirname(__file__), "datasets", "%s.pickle" % modelname
)
attrs = dict()
attrs["pna_deg"] = deg
@@ -462,14 +468,14 @@ def get(self, idx):
"ddstore": args.ddstore,
"ddstore_width": args.ddstore_width,
}
fname = os.path.join(os.path.dirname(__file__), "./dataset/%s.bp" % modelname)
fname = os.path.join(os.path.dirname(__file__), "./datasets/%s.bp" % modelname)
trainset = AdiosDataset(fname, "trainset", comm, **opt, var_config=var_config)
valset = AdiosDataset(fname, "valset", comm, **opt, var_config=var_config)
testset = AdiosDataset(fname, "testset", comm, **opt, var_config=var_config)
elif args.format == "pickle":
info("Pickle load")
basedir = os.path.join(
os.path.dirname(__file__), "dataset", "%s.pickle" % modelname
os.path.dirname(__file__), "datasets", "%s.pickle" % modelname
)
trainset = SimplePickleDataset(
basedir=basedir, label="trainset", var_config=var_config
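
Beyond the import renames, the hunks above show how train.py picks its dataset backend: mutually exclusive --adios and --pickle flags store a constant into args.format, and the script later branches on that value to build either AdiosDataset or SimplePickleDataset objects from the renamed datasets/ directory. The sketch below covers only the flag handling, using the standard library; the "adios" default is an assumption here, since the actual parser.set_defaults call for this file sits outside the hunks shown.

import argparse

# Mutually exclusive backend selection, modeled on the --adios / --pickle
# group in the hunks above. Both flags write into the same "format" attribute.
parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group()
group.add_argument(
    "--adios", help="Adios datasets", action="store_const", dest="format", const="adios"
)
group.add_argument(
    "--pickle", help="Pickle datasets", action="store_const", dest="format", const="pickle"
)
parser.set_defaults(format="adios")  # assumed default; not visible in the hunks above

args = parser.parse_args(["--pickle"])
print(args.format)  # -> "pickle"; with no flag it falls back to "adios"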
examples/ani1_x/train.py (47 changes: 24 additions & 23 deletions)
@@ -1,41 +1,42 @@
import os, re, json
import os, json
import logging
import sys
from mpi4py import MPI
import argparse

import glob

import random
import numpy as np

import torch
from torch import tensor
from torch_geometric.data import Data

from torch_geometric.transforms import Distance, Spherical, LocalCartesian

import hydragnn
from hydragnn.utils.time_utils import Timer
from hydragnn.utils.profiling_and_tracing.time_utils import Timer
from hydragnn.utils.model import print_model
from hydragnn.utils.abstractbasedataset import AbstractBaseDataset
from hydragnn.utils.distdataset import DistDataset
from hydragnn.utils.pickledataset import SimplePickleWriter, SimplePickleDataset
from hydragnn.preprocess.utils import gather_deg
from hydragnn.preprocess.utils import RadiusGraph, RadiusGraphPBC
from hydragnn.utils.datasets.abstractbasedataset import AbstractBaseDataset
from hydragnn.utils.datasets.distdataset import DistDataset
from hydragnn.utils.datasets.pickledataset import (
SimplePickleWriter,
SimplePickleDataset,
)
from hydragnn.preprocess.graph_samples_checks_and_updates import gather_deg
from hydragnn.preprocess.graph_samples_checks_and_updates import (
RadiusGraph,
RadiusGraphPBC,
)
from hydragnn.preprocess.load_data import split_dataset

import hydragnn.utils.tracer as tr
import hydragnn.utils.profiling_and_tracing.tracer as tr

from hydragnn.utils.print_utils import iterate_tqdm, log
from hydragnn.utils.print.print_utils import log

try:
from hydragnn.utils.adiosdataset import AdiosWriter, AdiosDataset
except ImportError:
pass

import subprocess
from hydragnn.utils import nsplit
from hydragnn.utils.distributed import nsplit

import h5py

@@ -189,7 +190,7 @@ def get(self, idx):
type=bool,
default=True,
)
parser.add_argument("--ddstore", action="store_true", help="ddstore dataset")
parser.add_argument("--ddstore", action="store_true", help="ddstore datasets")
parser.add_argument("--ddstore_width", type=int, help="ddstore width", default=None)
parser.add_argument("--shmem", action="store_true", help="shmem")
parser.add_argument("--log", help="log name")
@@ -199,14 +200,14 @@ def get(self, idx):
group = parser.add_mutually_exclusive_group()
group.add_argument(
"--adios",
help="Adios dataset",
help="Adios datasets",
action="store_const",
dest="format",
const="adios",
)
group.add_argument(
"--pickle",
help="Pickle dataset",
help="Pickle datasets",
action="store_const",
dest="format",
const="pickle",
@@ -219,7 +220,7 @@ def get(self, idx):
node_feature_names = ["atomic_number", "cartesian_coordinates", "forces"]
node_feature_dims = [1, 3, 3]
dirpwd = os.path.dirname(os.path.abspath(__file__))
datadir = os.path.join(dirpwd, "dataset")
datadir = os.path.join(dirpwd, "datasets")
##################################################################################################################
input_filename = os.path.join(dirpwd, args.inputfile)
##################################################################################################################
@@ -281,7 +282,7 @@ def get(self, idx):
## adios
if args.format == "adios":
fname = os.path.join(
os.path.dirname(__file__), "./dataset/%s.bp" % modelname
os.path.dirname(__file__), "./datasets/%s.bp" % modelname
)
adwriter = AdiosWriter(fname, comm)
adwriter.add("trainset", trainset)
@@ -295,7 +296,7 @@ def get(self, idx):
## pickle
elif args.format == "pickle":
basedir = os.path.join(
os.path.dirname(__file__), "dataset", "%s.pickle" % modelname
os.path.dirname(__file__), "datasets", "%s.pickle" % modelname
)
attrs = dict()
attrs["pna_deg"] = deg
@@ -340,14 +341,14 @@ def get(self, idx):
"ddstore": args.ddstore,
"ddstore_width": args.ddstore_width,
}
fname = os.path.join(os.path.dirname(__file__), "./dataset/%s.bp" % modelname)
fname = os.path.join(os.path.dirname(__file__), "./datasets/%s.bp" % modelname)
trainset = AdiosDataset(fname, "trainset", comm, **opt, var_config=var_config)
valset = AdiosDataset(fname, "valset", comm, **opt, var_config=var_config)
testset = AdiosDataset(fname, "testset", comm, **opt, var_config=var_config)
elif args.format == "pickle":
info("Pickle load")
basedir = os.path.join(
os.path.dirname(__file__), "dataset", "%s.pickle" % modelname
os.path.dirname(__file__), "datasets", "%s.pickle" % modelname
)
trainset = SimplePickleDataset(
basedir=basedir, label="trainset", var_config=var_config
examples/csce/train_gap.py (51 changes: 28 additions & 23 deletions)
@@ -16,17 +16,20 @@
import time

import hydragnn
from hydragnn.utils.print_utils import print_distributed, iterate_tqdm, log
from hydragnn.utils.time_utils import Timer
from hydragnn.utils.distdataset import DistDataset
from hydragnn.utils.pickledataset import SimplePickleWriter, SimplePickleDataset
from hydragnn.utils.smiles_utils import (
from hydragnn.utils.print.print_utils import print_distributed, iterate_tqdm, log
from hydragnn.utils.profiling_and_tracing.time_utils import Timer
from hydragnn.utils.datasets.distdataset import DistDataset
from hydragnn.utils.datasets.pickledataset import (
SimplePickleWriter,
SimplePickleDataset,
)
from hydragnn.utils.descriptors_and_embeddings.smiles_utils import (
get_node_attribute_name,
generate_graphdata_from_smilestr,
)
from hydragnn.preprocess.utils import gather_deg
from hydragnn.utils import nsplit
import hydragnn.utils.tracer as tr
from hydragnn.preprocess.graph_samples_checks_and_updates import gather_deg
from hydragnn.utils.distributed import nsplit
import hydragnn.utils.profiling_and_tracing.tracer as tr

import numpy as np

@@ -163,42 +166,42 @@ def __getitem__(self, idx):
group = parser.add_mutually_exclusive_group()
group.add_argument(
"--adios",
help="Adios dataset",
help="Adios datasets",
action="store_const",
dest="format",
const="adios",
)
group.add_argument(
"--pickle",
help="Pickle dataset",
help="Pickle datasets",
action="store_const",
dest="format",
const="pickle",
)
group.add_argument(
"--csv", help="CSV dataset", action="store_const", dest="format", const="csv"
"--csv", help="CSV datasets", action="store_const", dest="format", const="csv"
)
parser.set_defaults(format="adios")
group1 = parser.add_mutually_exclusive_group()
group1.add_argument(
"--shmem",
help="shmem dataset",
help="shmem datasets",
action="store_const",
dest="dataset",
dest="datasets",
const="shmem",
)
group1.add_argument(
"--ddstore",
help="ddstore dataset",
help="ddstore datasets",
action="store_const",
dest="dataset",
dest="datasets",
const="ddstore",
)
group1.add_argument(
"--simple",
help="no special dataset",
help="no special datasets",
action="store_const",
dest="dataset",
dest="datasets",
const="simple",
)
parser.set_defaults(dataset="simple")
@@ -208,7 +211,7 @@ def __getitem__(self, idx):
graph_feature_names = ["GAP"]
graph_feature_dim = [1]
dirpwd = os.path.dirname(os.path.abspath(__file__))
datafile = os.path.join(dirpwd, "dataset/csce_gap_synth.csv")
datafile = os.path.join(dirpwd, "datasets/csce_gap_synth.csv")
##################################################################################################################
inputfilesubstr = args.inputfilesubstr
input_filename = os.path.join(dirpwd, "csce_" + inputfilesubstr + ".json")
@@ -295,7 +298,7 @@ def __getitem__(self, idx):
config["pna_deg"] = deg

## pickle
basedir = os.path.join(os.path.dirname(__file__), "dataset", "pickle")
basedir = os.path.join(os.path.dirname(__file__), "datasets", "pickle")
attrs = dict()
attrs["pna_deg"] = deg
SimplePickleWriter(
@@ -318,7 +321,7 @@ def __getitem__(self, idx):
use_subdir=True,
)

fname = os.path.join(os.path.dirname(__file__), "dataset", "csce_gap.bp")
fname = os.path.join(os.path.dirname(__file__), "datasets", "csce_gap.bp")
adwriter = AdiosWriter(fname, comm)
adwriter.add("trainset", trainset)
adwriter.add("valset", valset)
@@ -346,20 +349,22 @@ def __getitem__(self, idx):

opt = {"preload": False, "shmem": shmem, "ddstore": ddstore}
fname = fname = os.path.join(
os.path.dirname(__file__), "dataset", "csce_gap.bp"
os.path.dirname(__file__), "datasets", "csce_gap.bp"
)
trainset = AdiosDataset(fname, "trainset", comm, **opt)
valset = AdiosDataset(fname, "valset", comm)
testset = AdiosDataset(fname, "testset", comm)
comm.Barrier()
elif args.format == "csv":
fname = os.path.join(os.path.dirname(__file__), "dataset", "csce_gap_synth.csv")
fname = os.path.join(
os.path.dirname(__file__), "datasets", "csce_gap_synth.csv"
)
fact = CSCEDatasetFactory(fname, args.sampling, var_config=var_config)
trainset = CSCEDataset(fact, "trainset")
valset = CSCEDataset(fact, "valset")
testset = CSCEDataset(fact, "testset")
elif args.format == "pickle":
basedir = os.path.join(os.path.dirname(__file__), "dataset", "pickle")
basedir = os.path.join(os.path.dirname(__file__), "datasets", "pickle")
trainset = SimplePickleDataset(basedir, "trainset")
valset = SimplePickleDataset(basedir, "valset")
testset = SimplePickleDataset(basedir, "testset")
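
The group1 block in this diff uses the same store_const idiom to select a data-store mode (shmem, ddstore, or plain). For that pattern to yield a usable default, the dest of the three flags and the attribute set by parser.set_defaults must agree; the sketch below keeps both spelled "dataset", matching the unchanged set_defaults context line, while the commit renames the flags' dest to "datasets", so treat the attribute name as illustrative and check the final script.

import argparse

# Three-way data-store switch modeled on the group1 block above. The attribute
# name "dataset" is illustrative; see the note about the dest rename.
parser = argparse.ArgumentParser()
group1 = parser.add_mutually_exclusive_group()
group1.add_argument(
    "--shmem", help="shmem datasets", action="store_const", dest="dataset", const="shmem"
)
group1.add_argument(
    "--ddstore", help="ddstore datasets", action="store_const", dest="dataset", const="ddstore"
)
group1.add_argument(
    "--simple", help="no special datasets", action="store_const", dest="dataset", const="simple"
)
parser.set_defaults(dataset="simple")

print(parser.parse_args([]).dataset)             # -> "simple" (default)
print(parser.parse_args(["--ddstore"]).dataset)  # -> "ddstore"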
(Diffs for the remaining 70 changed files are not shown here.)
