Skip to content

Commit

Permalink
Add option to specify default torch device
Browse files Browse the repository at this point in the history
  • Loading branch information
ruicoelhopedro committed May 21, 2024
1 parent 0f934a8 commit 925e9db
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
9 changes: 9 additions & 0 deletions piglot/bin/piglot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import argparse
import shutil
from yaml import safe_dump
import torch
from piglot.objectives import read_objective
from piglot.optimisers import read_optimiser
from piglot.parameter import read_parameters
Expand Down Expand Up @@ -31,6 +32,13 @@ def parse_args():
type=str,
help='Configuration file to use',
)
# PyTorch compute device
parser.add_argument(
'--device',
type=str,
default='cpu',
help='Default device to use with PyTorch',
)

return parser.parse_args()

Expand All @@ -40,6 +48,7 @@ def main(config_path: str = None):
if config_path is None:
args = parse_args()
config_path = args.config
torch.set_default_device(args.device)
config = parse_config_file(config_path)
# Build output directory with a copy of the configuration file
output_dir = config["output"]
Expand Down
15 changes: 13 additions & 2 deletions piglot/optimisers/botorch/bayes.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,17 @@
}


def get_default_torch_device() -> str:
"""Utility to return the default PyTorch device (pre Pytorch v2.3).
Returns
-------
str
Name of the default PyTorch device.
"""
return str(torch.tensor([0.0]).device)


def default_acquisition(
composite: bool,
multi_objective: bool,
Expand Down Expand Up @@ -147,7 +158,7 @@ def __init__(
seed: int = 1,
load_file: str = None,
export: str = None,
device: str = 'cpu',
device: str = None,
reference_point: List[float] = None,
nadir_scale: float = 0.1,
skip_initial: bool = False,
Expand All @@ -168,7 +179,7 @@ def __init__(
self.load_file = load_file
self.export = export
self.n_test = n_test
self.device = device
self.device = get_default_torch_device() if device is None else device
self.skip_initial = bool(skip_initial)
self.partitioning: FastNondominatedPartitioning = None
self.adjusted_ref_point = reference_point is None
Expand Down

0 comments on commit 925e9db

Please sign in to comment.