Skip to content

Commit

Permalink
Add support for non numeric config data in Dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
MCFlowMace committed Aug 15, 2023
1 parent e499484 commit b74a44d
Showing 1 changed file with 22 additions and 3 deletions.
25 changes: 22 additions & 3 deletions hercules/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from scipy.interpolate import interp1d
import dill as pickle
from pathlib import Path
import numbers

from .constants import PY_DATA_NAME

Expand All @@ -28,11 +29,12 @@ class Dataset:

_class_version = '2.0'

def __init__(self, directory, config_list):
def __init__(self, directory, config_list, interpolate=True):

self._directory = Path(directory)
self._directory.mkdir(parents=True, exist_ok=True)
self._version = self._class_version
self._interpolate_axes = interpolate
self._make_index(config_list)

def _make_index(self, config_list):
Expand All @@ -51,7 +53,7 @@ def _make_index(self, config_list):
self._config_data_keys = list(config_list.get_config_data_keys()) # maps a list index to a key
config_list_internal = config_list.get_internal_list()

self._axes = [np.empty(len(config_list_internal)) for k in self._config_data_keys]
self._initialize_axes(config_list_internal)

for i, sim_config in enumerate(config_list_internal):
path = sim_config.sim_name
Expand All @@ -66,7 +68,22 @@ def _make_index(self, config_list):
for i in range(len(self._axes)):
self._axes[i] = np.sort(np.unique(self._axes[i]))

self._interpolate_all()
if self._interpolate_axes:
self._interpolate_all()

def _initialize_axes(self, config_list_internal):
self._axes = []
config_data0 = config_list_internal[0].get_config_data()
for k in self._config_data_keys:
var0 = config_data0[k]

if isinstance(var0, numbers.Number):
self._axes.append(np.empty(len(config_list_internal), dtype=type(var0)))
else:
self._axes.append([i for i in range(len(config_list_internal))])
if self._interpolate:
print('Warning! Interpolation of axes not possible for non-numeric types! Deactivating interpolation')
self._interpolate_axes = False

def _interpolate_all(self):

Expand Down Expand Up @@ -103,6 +120,8 @@ def get_path(self, params, method='interpolated'):
raise ValueError(f'params has len {len(params)} but dataset expects len {len(self._axes)}!')

if method == 'interpolated':
if not self._interpolate:
raise ValueError('Dataset is not interpolated!')
key = [self._axes_int[i](params[i]).item() for i in range(len(params))]
elif method == 'index':
key = [self._axes[i][params[i]] for i in range(len(params))]
Expand Down

0 comments on commit b74a44d

Please sign in to comment.