From 925e9dbb3bc40bbff71247af2564dfe3a6691a81 Mon Sep 17 00:00:00 2001 From: Rui Coelho Date: Tue, 21 May 2024 19:20:16 +0100 Subject: [PATCH] Add option to specify default torch device --- piglot/bin/piglot.py | 9 +++++++++ piglot/optimisers/botorch/bayes.py | 15 +++++++++++++-- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/piglot/bin/piglot.py b/piglot/bin/piglot.py index d51bcf6..eb236be 100644 --- a/piglot/bin/piglot.py +++ b/piglot/bin/piglot.py @@ -4,6 +4,7 @@ import argparse import shutil from yaml import safe_dump +import torch from piglot.objectives import read_objective from piglot.optimisers import read_optimiser from piglot.parameter import read_parameters @@ -31,6 +32,13 @@ def parse_args(): type=str, help='Configuration file to use', ) + # PyTorch compute device + parser.add_argument( + '--device', + type=str, + default='cpu', + help='Default device to use with PyTorch', + ) return parser.parse_args() @@ -40,6 +48,7 @@ def main(config_path: str = None): if config_path is None: args = parse_args() config_path = args.config + torch.set_default_device(args.device) config = parse_config_file(config_path) # Build output directory with a copy of the configuration file output_dir = config["output"] diff --git a/piglot/optimisers/botorch/bayes.py b/piglot/optimisers/botorch/bayes.py index 1a233ec..09600cb 100644 --- a/piglot/optimisers/botorch/bayes.py +++ b/piglot/optimisers/botorch/bayes.py @@ -76,6 +76,17 @@ } +def get_default_torch_device() -> str: + """Utility to return the default PyTorch device (pre Pytorch v2.3). + + Returns + ------- + str + Name of the default PyTorch device. + """ + return str(torch.tensor([0.0]).device) + + def default_acquisition( composite: bool, multi_objective: bool, @@ -147,7 +158,7 @@ def __init__( seed: int = 1, load_file: str = None, export: str = None, - device: str = 'cpu', + device: str = None, reference_point: List[float] = None, nadir_scale: float = 0.1, skip_initial: bool = False, @@ -168,7 +179,7 @@ def __init__( self.load_file = load_file self.export = export self.n_test = n_test - self.device = device + self.device = get_default_torch_device() if device is None else device self.skip_initial = bool(skip_initial) self.partitioning: FastNondominatedPartitioning = None self.adjusted_ref_point = reference_point is None