diff --git a/psiflow/data/dataset.py b/psiflow/data/dataset.py index fb5c936..f74fafc 100644 --- a/psiflow/data/dataset.py +++ b/psiflow/data/dataset.py @@ -38,6 +38,11 @@ @psiflow.serializable class Dataset: + """ + A class representing a dataset of atomic structures. + + This class provides methods for manipulating and analyzing collections of atomic structures. + """ extxyz: psiflow._DataFuture def __init__( @@ -45,6 +50,16 @@ def __init__( states: Optional[list[Union[AppFuture, Geometry]], AppFuture], extxyz: Optional[psiflow._DataFuture] = None, ): + """ + Initialize a Dataset. + + Args: + states: List of Geometry instances or AppFutures representing geometries. + extxyz: Optional DataFuture representing an existing extxyz file. + + Note: + Either states or extxyz should be provided, not both. + """ if extxyz is not None: assert states is None self.extxyz = extxyz @@ -62,9 +77,21 @@ def __init__( ).outputs[0] def length(self) -> AppFuture: + """ + Get the number of structures in the dataset. + + Returns: + AppFuture: Future representing the number of structures. + """ return count_frames(inputs=[self.extxyz]) def shuffle(self): + """ + Shuffle the order of structures in the dataset. + + Returns: + Dataset: A new Dataset with shuffled structures. + """ extxyz = shuffle( inputs=[self.extxyz], outputs=[psiflow.context().new_file("data_", ".xyz")], @@ -75,6 +102,15 @@ def __getitem__( self, index: Union[int, slice, list[int], AppFuture], ) -> Union[Dataset, AppFuture]: + """ + Get a subset of the dataset or a single structure. + + Args: + index: Integer, slice, list of integers, or AppFuture representing indices. + + Returns: + Union[Dataset, AppFuture]: A new Dataset or an AppFuture of a single Geometry. + """ if isinstance(index, int): future = read_frames( [index], @@ -91,6 +127,15 @@ def __getitem__( return Dataset(None, extxyz) def save(self, path: Union[Path, str]) -> AppFuture: + """ + Save the dataset to a file. + + Args: + path: Path to save the dataset. + + Returns: + AppFuture: Future representing the completion of the save operation. + """ path = psiflow.resolve_and_check(Path(path)) _ = copy_data_future( inputs=[self.extxyz], @@ -98,9 +143,24 @@ def save(self, path: Union[Path, str]) -> AppFuture: ) def geometries(self) -> AppFuture: + """ + Get all geometries in the dataset. + + Returns: + AppFuture: Future representing a list of Geometry instances. + """ return read_frames(inputs=[self.extxyz]) def __add__(self, dataset: Dataset) -> Dataset: + """ + Concatenate two datasets. + + Args: + dataset: Another Dataset to add to this one. + + Returns: + Dataset: A new Dataset containing structures from both datasets. + """ extxyz = join_frames( inputs=[self.extxyz, dataset.extxyz], outputs=[psiflow.context().new_file("data_", ".xyz")], @@ -108,6 +168,15 @@ def __add__(self, dataset: Dataset) -> Dataset: return Dataset(None, extxyz) def subtract_offset(self, **atomic_energies: Union[float, AppFuture]) -> Dataset: + """ + Subtract atomic energy offsets from the dataset. + + Args: + **atomic_energies: Atomic energies for each element. + + Returns: + Dataset: A new Dataset with adjusted energies. + """ assert len(atomic_energies) > 0 extxyz = apply_offset( True, @@ -118,6 +187,15 @@ def subtract_offset(self, **atomic_energies: Union[float, AppFuture]) -> Dataset return Dataset(None, extxyz) def add_offset(self, **atomic_energies) -> Dataset: + """ + Add atomic energy offsets to the dataset. + + Args: + **atomic_energies: Atomic energies for each element. + + Returns: + Dataset: A new Dataset with adjusted energies. + """ assert len(atomic_energies) > 0 extxyz = apply_offset( False, @@ -128,9 +206,21 @@ def add_offset(self, **atomic_energies) -> Dataset: return Dataset(None, extxyz) def elements(self): + """ + Get the set of elements present in the dataset. + + Returns: + AppFuture: Future representing a set of element symbols. + """ return get_elements(inputs=[self.extxyz]) def reset(self): + """ + Reset all structures in the dataset. + + Returns: + Dataset: A new Dataset with reset structures. + """ extxyz = reset_frames( inputs=[self.extxyz], outputs=[psiflow.context().new_file("data_", ".xyz")], @@ -138,6 +228,12 @@ def reset(self): return Dataset(None, extxyz) def clean(self): + """ + Clean all structures in the dataset. + + Returns: + Dataset: A new Dataset with cleaned structures. + """ extxyz = clean_frames( inputs=[self.extxyz], outputs=[psiflow.context().new_file("data_", ".xyz")], @@ -150,6 +246,17 @@ def get( atom_indices: Optional[list[int]] = None, elements: Optional[list[str]] = None, ): + """ + Extract specified quantities from the dataset. + + Args: + *quantities: Names of quantities to extract. + atom_indices: Optional list of atom indices to consider. + elements: Optional list of element symbols to consider. + + Returns: + Union[AppFuture, tuple[AppFuture, ...]]: Future(s) representing the extracted quantities. + """ result = extract_quantities( quantities, atom_indices, @@ -166,6 +273,16 @@ def evaluate( computable: Computable, batch_size: Optional[int] = None, ) -> Dataset: + """ + Evaluate a Computable on the dataset. + + Args: + computable: Computable object to evaluate. + batch_size: Optional batch size for evaluation. + + Returns: + Dataset: A new Dataset with evaluation results. + """ if batch_size is not None: outputs = computable.compute(self, batch_size=batch_size) else: @@ -182,6 +299,15 @@ def filter( self, quantity: str, ) -> Dataset: + """ + Filter the dataset based on a specified quantity. + + Args: + quantity: The quantity to filter on. + + Returns: + Dataset: A new Dataset containing only structures that pass the filter. + """ assert quantity in QUANTITIES extxyz = app_filter( quantity, @@ -191,6 +317,12 @@ def filter( return Dataset(None, extxyz) def not_null(self) -> Dataset: + """ + Remove null states from the dataset. + + Returns: + Dataset: A new Dataset without null states. + """ extxyz = not_null( inputs=[self.extxyz], outputs=[psiflow.context().new_file("data_", ".xyz")], @@ -198,6 +330,12 @@ def not_null(self) -> Dataset: return Dataset(None, extxyz) def align_axes(self): + """ + Adopt a canonical orientation for all (periodic) structures in the dataset. + + Returns: + Dataset: A new Dataset with aligned structures. + """ extxyz = align_axes( inputs=[self.extxyz], outputs=[psiflow.context().new_file("data_", ".xyz")], @@ -205,6 +343,16 @@ def align_axes(self): return Dataset(None, extxyz) def split(self, fraction, shuffle=True): # auto-shuffles + """ + Split the dataset into training and validation sets. + + Args: + fraction: Fraction of data to use for training. + shuffle: Whether to shuffle before splitting. + + Returns: + tuple[Dataset, Dataset]: Training and validation datasets. + """ train, valid = get_train_valid_indices( self.length(), fraction, @@ -215,6 +363,15 @@ def split(self, fraction, shuffle=True): # auto-shuffles def assign_identifiers( self, identifier: Union[int, AppFuture, None] = None ) -> AppFuture: + """ + Assign identifiers to structures in the dataset. + + Args: + identifier: Starting identifier or AppFuture representing it. + + Returns: + AppFuture: Future representing the next available identifier. + """ result = assign_identifiers( identifier, inputs=[self.extxyz], @@ -228,6 +385,15 @@ def load( cls, path_xyz: Union[Path, str], ) -> Dataset: + """ + Load a dataset from a file. + + Args: + path_xyz: Path to the XYZ file. + + Returns: + Dataset: Loaded dataset. + """ path_xyz = psiflow.resolve_and_check(Path(path_xyz)) assert path_xyz.exists() # needs to be locally accessible return cls(None, extxyz=File(str(path_xyz))) @@ -235,6 +401,18 @@ def load( @typeguard.typechecked def _concatenate_multiple(*args: list[np.ndarray]) -> list[np.ndarray]: + """ + Concatenate multiple lists of arrays. + + Args: + *args: Lists of numpy arrays to concatenate. + + Returns: + list[np.ndarray]: List of concatenated arrays. + + Note: + This function is wrapped as a Parsl app and executed using the default_threads executor. + """ narrays = len(args[0]) for arg in args: assert isinstance(arg, list) @@ -254,6 +432,19 @@ def _aggregate_multiple( *arrays_list, coefficients: Optional[np.ndarray] = None, ) -> list[np.ndarray]: + """ + Aggregate multiple lists of arrays with optional coefficients. + + Args: + *arrays_list: Lists of arrays to aggregate. + coefficients: Optional coefficients for weighted aggregation. + + Returns: + list[np.ndarray]: List of aggregated arrays. + + Note: + This function is wrapped as a Parsl app and executed using the default_threads executor. + """ if coefficients is None: coefficients = np.ones(len(arrays_list)) else: @@ -280,6 +471,24 @@ def batch_apply( reduce_func: Optional[PythonApp] = None, **app_kwargs, ) -> AppFuture: + """ + Apply a set of apps to batches of data. + + Args: + apply_apps: Tuple of PythonApps or Callables to apply. + arg: Dataset or list of Geometries to process. + batch_size: Size of each batch. + length: Total number of items to process. + outputs: List of output files. + reduce_func: Optional function to reduce results. + **app_kwargs: Additional keyword arguments for the apps. + + Returns: + AppFuture: Future representing the result of batch application. + + Note: + This function is wrapped as a Parsl join_app. + """ nbatches = math.ceil(length / batch_size) batches = [psiflow.context().new_file("data_", ".xyz") for _ in range(nbatches)] future = batch_frames(batch_size, inputs=[arg.extxyz], outputs=batches) @@ -301,6 +510,18 @@ def batch_apply( @python_app(executors=["default_threads"]) def get_length(arg): + """ + Get the length of the input argument. + + Args: + arg: Input to get the length of. + + Returns: + int: Length of the input. + + Note: + This function is wrapped as a Parsl app and executed using the default_threads executor. + """ if isinstance(arg, list): return len(arg) else: @@ -315,6 +536,19 @@ def compute( reduce_func: Union[PythonApp, Callable] = aggregate_multiple, batch_size: Optional[int] = None, ) -> Union[list[AppFuture], AppFuture]: + """ + Compute results by applying apps to the input data. + + Args: + arg: Input data to compute on. + *apply_apps: Apps to apply to the data. + outputs_: Names of output quantities. + reduce_func: Function to reduce results. + batch_size: Optional batch size for processing. + + Returns: + Union[list[AppFuture], AppFuture]: Future(s) representing computation results. + """ if outputs_ is not None and not isinstance(outputs_, list): outputs_ = [outputs_] if batch_size is not None: @@ -359,6 +593,13 @@ def compute( @typeguard.typechecked class Computable: + """ + Base class for computable objects. + + Attributes: + outputs (ClassVar[tuple[str, ...]]): Names of output quantities. + batch_size (ClassVar[Optional[int]]): Default batch size for computation. + """ outputs: ClassVar[tuple[str, ...]] = () batch_size: ClassVar[Optional[int]] = None @@ -368,6 +609,17 @@ def compute( outputs: Union[str, list[str], None] = None, batch_size: Optional[int] = -1, # if -1: take class default ) -> Union[list[AppFuture], AppFuture]: + """ + Compute results for the given input. + + Args: + arg: Input data to compute on. + outputs: Names of output quantities. + batch_size: Batch size for computation. + + Returns: + Union[list[AppFuture], AppFuture]: Future(s) representing computation results. + """ if outputs is None: outputs = list(self.__class__.outputs) if batch_size == -1: