diff --git a/docs/source/index.rst b/docs/source/index.rst index a44a3b6..833417c 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -19,9 +19,10 @@ GPax is a small Python package for physics-based Gaussian processes (GPs) built :caption: Package Content models - hypo acquisition kernels + priors + hypo utils .. toctree:: diff --git a/docs/source/models.rst b/docs/source/models.rst index b419fd7..f9ab107 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -96,3 +96,13 @@ Multi-Task Learning :undoc-members: :member-order: bysource :show-inheritance: + +Structured Probabilistic Models +------------------------------- +.. autoclass:: gpax.models.spm.sPM + :members: + :inherited-members: + :undoc-members: + :member-order: bysource + :show-inheritance: + diff --git a/docs/source/priors.rst b/docs/source/priors.rst new file mode 100644 index 0000000..2f1ca05 --- /dev/null +++ b/docs/source/priors.rst @@ -0,0 +1,23 @@ +Priors +====== + +.. autofunction:: gpax.utils.normal_dist + +.. autofunction:: gpax.utils.lognormal_dist + +.. autofunction:: gpax.utils.halfnormal_dist + +.. autofunction:: gpax.utils.gamma_dist + +.. autofunction:: gpax.utils.uniform_dist + +.. autofunction:: gpax.utils.place_normal_prior + +.. autofunction:: gpax.utils.place_lognormal_prior + +.. autofunction:: gpax.utils.place_halfnormal_prior + +.. autofunction:: gpax.utils.place_uniform_prior + +.. autofunction:: gpax.utils.place_gamma_prior + diff --git a/docs/source/utils.rst b/docs/source/utils.rst index 4555849..3334c48 100644 --- a/docs/source/utils.rst +++ b/docs/source/utils.rst @@ -1,40 +1,17 @@ Utilities ========= -Priors ------- +Automatic function setters +-------------------------- -.. autofunction:: gpax.utils.normal_dist +.. autofunction:: gpax.utils.set_fn -.. autofunction:: gpax.utils.lognormal_dist - -.. autofunction:: gpax.utils.halfnormal_dist - -.. autofunction:: gpax.utils.gamma_dist - -.. autofunction:: gpax.utils.uniform_dist - -.. autofunction:: gpax.utils.place_normal_prior - -.. autofunction:: gpax.utils.place_lognormal_prior - -.. autofunction:: gpax.utils.place_halfnormal_prior - -.. autofunction:: gpax.utils.place_uniform_prior - -.. autofunction:: gpax.utils.place_gamma_prior +.. autofunction:: gpax.utils.set_kernel_fn Other utilities --------------- -.. autoclass:: gpax.models.spm.sPM - :members: - :inherited-members: - :undoc-members: - :member-order: bysource - :show-inheritance: - .. autofunction:: gpax.utils.dviz .. autofunction:: gpax.utils.get_keys diff --git a/gpax/__init__.py b/gpax/__init__.py index f48d684..84de056 100644 --- a/gpax/__init__.py +++ b/gpax/__init__.py @@ -1,4 +1,5 @@ from .__version__ import version as __version__ +from . import priors from . import utils from . import kernels from . import acquisition @@ -7,6 +8,6 @@ vi_iBNN, viDKL, viGP, sPM, viMTDKL, VarNoiseGP, UIGP, MeasuredNoiseGP, viSparseGP, BNN) -__all__ = ["utils", "kernels", "mtkernels", "acquisition", "ExactGP", "vExactGP", "DKL", +__all__ = ["priors", "utils", "kernels", "mtkernels", "acquisition", "ExactGP", "vExactGP", "DKL", "viDKL", "iBNN", "vi_iBNN", "MultiTaskGP", "viMTDKL", "viGP", "sPM", "VarNoiseGP", "UIGP", "MeasuredNoiseGP", "viSparseGP", "CoregGP", "BNN", "sample_next", "__version__"] diff --git a/gpax/priors/__init__.py b/gpax/priors/__init__.py new file mode 100644 index 0000000..3ef70f0 --- /dev/null +++ b/gpax/priors/__init__.py @@ -0,0 +1 @@ +from .priors import * \ No newline at end of file diff --git a/gpax/utils/priors.py b/gpax/priors/priors.py similarity index 67% rename from gpax/utils/priors.py rename to gpax/priors/priors.py index 77b5b7a..fb965cc 100644 --- a/gpax/utils/priors.py +++ b/gpax/priors/priors.py @@ -4,20 +4,16 @@ Utility functions for setting priors -Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com) +Created by Maxim Ziatdinov (email: maxim.ziatdinov@gmail.com) """ import inspect -import re -from typing import Union, Dict, Type, List, Callable, Optional +from typing import Union, Dict, Type, Callable import numpyro -import jax import jax.numpy as jnp -from ..kernels.kernels import square_scaled_distance, add_jitter, _sqrt - def place_normal_prior(param_name: str, loc: float = 0.0, scale: float = 1.0): """ @@ -183,137 +179,6 @@ def uniform_dist(low: float = None, return numpyro.distributions.Uniform(low, high) -def set_fn(func: Callable) -> Callable: - """ - Transforms the given deterministic function to use a params dictionary - for its parameters, excluding the first one (assumed to be the dependent variable). - - Args: - - func (Callable): The deterministic function to be transformed. - - Returns: - - Callable: The transformed function where parameters are accessed - from a `params` dictionary. - """ - # Extract parameter names excluding the first one (assumed to be the dependent variable) - params_names = list(inspect.signature(func).parameters.keys())[1:] - - # Create the transformed function definition - transformed_code = f"def {func.__name__}(x, params):\n" - - # Retrieve the source code of the function and indent it to be a valid function body - source = inspect.getsource(func).split("\n", 1)[1] - source = " " + source.replace("\n", "\n ") - - # Replace each parameter name with its dictionary lookup using regex - for name in params_names: - source = re.sub(rf'\b{name}\b', f'params["{name}"]', source) - - # Combine to get the full source - transformed_code += source - - # Define the transformed function in the local namespace - local_namespace = {} - exec(transformed_code, globals(), local_namespace) - - # Return the transformed function - return local_namespace[func.__name__] - - -def set_kernel_fn(func: Callable, - independent_vars: List[str] = ["X", "Z"], - jit_decorator: bool = True, - docstring: Optional[str] = None) -> Callable: - """ - Transforms the given kernel function to use a params dictionary for its hyperparameters. - The resultant function will always add jitter before returning the computed kernel. - - Args: - func (Callable): The kernel function to be transformed. - independent_vars (List[str], optional): List of independent variable names in the function. Defaults to ["X", "Z"]. - jit_decorator (bool, optional): @jax.jit decorator to be applied to the transformed function. Defaults to True. - docstring (Optional[str], optional): Docstring to be added to the transformed function. Defaults to None. - - Returns: - Callable: The transformed kernel function where hyperparameters are accessed from a `params` dictionary. - """ - - # Extract parameter names excluding the independent variables - params_names = [k for k, v in inspect.signature(func).parameters.items() if v.default == v.empty] - for var in independent_vars: - params_names.remove(var) - - transformed_code = "" - if jit_decorator: - transformed_code += "@jit" + "\n" - - additional_args = "noise: int = 0, jitter: float = 1e-6, **kwargs" - transformed_code += f"def {func.__name__}({', '.join(independent_vars)}, params: Dict[str, jnp.ndarray], {additional_args}):\n" - - if docstring: - transformed_code += ' """' + docstring + '"""\n' - - source = inspect.getsource(func).split("\n", 1)[1] - lines = source.split("\n") - - for idx, line in enumerate(lines): - # Convert all parameter names to their dictionary lookup throughout the function body - for name in params_names: - lines[idx] = re.sub(rf'\b{name}\b', f'params["{name}"]', lines[idx]) - - # Combine lines back and then split again by return - modified_source = '\n'.join(lines) - pre_return, return_statement = modified_source.split('return', 1) - - # Append custom jitter code - custom_code = f" {pre_return.strip()}\n k = {return_statement.strip()}\n" - custom_code += """ - if X.shape == Z.shape: - k += (noise + jitter) * jnp.eye(X.shape[0]) - return k - """ - - transformed_code += custom_code - - local_namespace = {"jit": jax.jit} - exec(transformed_code, globals(), local_namespace) - - return local_namespace[func.__name__] - - -def _set_noise_kernel_fn(func: Callable) -> Callable: - """ - Modifies the GPax kernel function to append "_noise" after "k" in dictionary keys it accesses. - - Args: - func (Callable): Original function. - - Returns: - Callable: Modified function. - """ - - # Get the source code of the function - source = inspect.getsource(func) - - # Split the source into decorators, definition, and body - decorators_and_def, body = source.split("\n", 1) - - # Replace all occurrences of params["k with params["k_noise in the body - modified_body = re.sub(r'params\["k', 'params["k_noise', body) - - # Combine decorators, definition, and modified body - modified_source = f"{decorators_and_def}\n{modified_body}" - - # Define local namespace including the jit decorator - local_namespace = {"jit": jax.jit} - - # Execute the modified source to redefine the function in the provided namespace - exec(modified_source, globals(), local_namespace) - - # Return the modified function - return local_namespace[func.__name__] - - def auto_priors(func: Callable, params_begin_with: int, dist_type: str = 'normal', loc: float = 0.0, scale: float = 1.0) -> Callable: """ Generates a function that, when invoked, samples from normal or log-normal distributions diff --git a/gpax/utils/__init__.py b/gpax/utils/__init__.py index 245c13a..f05f555 100644 --- a/gpax/utils/__init__.py +++ b/gpax/utils/__init__.py @@ -1,3 +1,3 @@ from .utils import * -from .priors import * -from .priors import _set_noise_kernel_fn +from .fn import * +from .fn import _set_noise_kernel_fn diff --git a/gpax/utils/fn.py b/gpax/utils/fn.py new file mode 100644 index 0000000..3119e09 --- /dev/null +++ b/gpax/utils/fn.py @@ -0,0 +1,149 @@ +""" +fn.py +===== + +Utilities for setting up custom mean and kernel functions + +Created by Maxim Ziatdinov (email: maxim.ziatdinov@gmail.com) +""" + +import inspect +import re + +from typing import List, Callable, Optional, Dict + +import jax +import jax.numpy as jnp + +from ..kernels.kernels import square_scaled_distance, add_jitter, _sqrt + + +def set_fn(func: Callable) -> Callable: + """ + Transforms the given deterministic function to use a params dictionary + for its parameters, excluding the first one (assumed to be the dependent variable). + + Args: + - func (Callable): The deterministic function to be transformed. + + Returns: + - Callable: The transformed function where parameters are accessed + from a `params` dictionary. + """ + # Extract parameter names excluding the first one (assumed to be the dependent variable) + params_names = list(inspect.signature(func).parameters.keys())[1:] + + # Create the transformed function definition + transformed_code = f"def {func.__name__}(x, params):\n" + + # Retrieve the source code of the function and indent it to be a valid function body + source = inspect.getsource(func).split("\n", 1)[1] + source = " " + source.replace("\n", "\n ") + + # Replace each parameter name with its dictionary lookup using regex + for name in params_names: + source = re.sub(rf'\b{name}\b', f'params["{name}"]', source) + + # Combine to get the full source + transformed_code += source + + # Define the transformed function in the local namespace + local_namespace = {} + exec(transformed_code, globals(), local_namespace) + + # Return the transformed function + return local_namespace[func.__name__] + + +def set_kernel_fn(func: Callable, + independent_vars: List[str] = ["X", "Z"], + jit_decorator: bool = True, + docstring: Optional[str] = None) -> Callable: + """ + Transforms the given kernel function to use a params dictionary for its hyperparameters. + The resultant function will always add jitter before returning the computed kernel. + + Args: + func (Callable): The kernel function to be transformed. + independent_vars (List[str], optional): List of independent variable names in the function. Defaults to ["X", "Z"]. + jit_decorator (bool, optional): @jax.jit decorator to be applied to the transformed function. Defaults to True. + docstring (Optional[str], optional): Docstring to be added to the transformed function. Defaults to None. + + Returns: + Callable: The transformed kernel function where hyperparameters are accessed from a `params` dictionary. + """ + + # Extract parameter names excluding the independent variables + params_names = [k for k, v in inspect.signature(func).parameters.items() if v.default == v.empty] + for var in independent_vars: + params_names.remove(var) + + transformed_code = "" + if jit_decorator: + transformed_code += "@jit" + "\n" + + additional_args = "noise: int = 0, jitter: float = 1e-6, **kwargs" + transformed_code += f"def {func.__name__}({', '.join(independent_vars)}, params: Dict[str, jnp.ndarray], {additional_args}):\n" + + if docstring: + transformed_code += ' """' + docstring + '"""\n' + + source = inspect.getsource(func).split("\n", 1)[1] + lines = source.split("\n") + + for idx, line in enumerate(lines): + # Convert all parameter names to their dictionary lookup throughout the function body + for name in params_names: + lines[idx] = re.sub(rf'\b{name}\b', f'params["{name}"]', lines[idx]) + + # Combine lines back and then split again by return + modified_source = '\n'.join(lines) + pre_return, return_statement = modified_source.split('return', 1) + + # Append custom jitter code + custom_code = f" {pre_return.strip()}\n k = {return_statement.strip()}\n" + custom_code += """ + if X.shape == Z.shape: + k += (noise + jitter) * jnp.eye(X.shape[0]) + return k + """ + + transformed_code += custom_code + + local_namespace = {"jit": jax.jit} + exec(transformed_code, globals(), local_namespace) + + return local_namespace[func.__name__] + + +def _set_noise_kernel_fn(func: Callable) -> Callable: + """ + Modifies the GPax kernel function to append "_noise" after "k" in dictionary keys it accesses. + + Args: + func (Callable): Original function. + + Returns: + Callable: Modified function. + """ + + # Get the source code of the function + source = inspect.getsource(func) + + # Split the source into decorators, definition, and body + decorators_and_def, body = source.split("\n", 1) + + # Replace all occurrences of params["k with params["k_noise in the body + modified_body = re.sub(r'params\["k', 'params["k_noise', body) + + # Combine decorators, definition, and modified body + modified_source = f"{decorators_and_def}\n{modified_body}" + + # Define local namespace including the jit decorator + local_namespace = {"jit": jax.jit} + + # Execute the modified source to redefine the function in the provided namespace + exec(modified_source, globals(), local_namespace) + + # Return the modified function + return local_namespace[func.__name__] diff --git a/tests/test_func_setter.py b/tests/test_func_setter.py new file mode 100644 index 0000000..2937045 --- /dev/null +++ b/tests/test_func_setter.py @@ -0,0 +1,78 @@ +import sys +import jax.numpy as jnp +from numpy.testing import assert_equal, assert_ + +sys.path.insert(0, "../gpax/") + + +from gpax.utils.fn import set_fn, set_kernel_fn, _set_noise_kernel_fn + + +def linear_kernel_test(X, Z, k_scale): + # Dummy kernel functions for testing purposes + return k_scale * jnp.dot(X, Z.T) + + +def rbf_test(X, Z, k_length, k_scale): + # Dummy kernel functions for testing purposes + scaled_X = X / k_length + scaled_Z = Z / k_length + X2 = (scaled_X ** 2).sum(1, keepdims=True) + Z2 = (scaled_Z ** 2).sum(1, keepdims=True) + XZ = jnp.matmul(scaled_X, scaled_Z.T) + r2 = X2 - 2 * XZ + Z2.T + + return k_scale * jnp.exp(-0.5 * r2) + + +def sample_function(x, a, b): + return a + b * x + + +def test_set_fn(): + transformed_fn = set_fn(sample_function) + result = transformed_fn(2, {"a": 1, "b": 3}) + assert result == 7 # Expected output: 1 + 3*2 = 7 + + +def test_set_kernel_fn(): + + # Convert the dummy kernel functions + new_linear_kernel = set_kernel_fn(linear_kernel_test) + new_rbf = set_kernel_fn(rbf_test) + + X = jnp.array([[1, 2], [3, 4], [5, 6]]) + Z = jnp.array([[1, 2], [3, 4]]) + params_linear = {"k_scale": 1.0} + params_rbf = {"k_length": 1.0, "k_scale": 1.0} + + # Assert the transformed function is working correctly + assert_(jnp.array_equal(linear_kernel_test(X, Z, 1.0), new_linear_kernel(X, Z, params_linear))) + assert_(jnp.array_equal(rbf_test(X, Z, 1.0, 1.0), new_rbf(X, Z, params_rbf))) + + +def test_set_kernel_fn_with_jitter(): + + jitter = 1e-5 + + # Convert the dummy kernel functions + new_linear_kernel = set_kernel_fn(linear_kernel_test) + new_rbf = set_kernel_fn(rbf_test) + + X = jnp.array([[1, 2], [3, 4], [5, 6]]) + params_linear = {"k_scale": 1.0} + params_rbf = {"k_length": 1.0, "k_scale": 1.0} + + # Assert the transformed function is working correctly + assert_(jnp.array_equal(linear_kernel_test(X, X, 1.0) + jitter * jnp.eye(X.shape[0]), new_linear_kernel(X, X, params_linear, jitter=jitter))) + assert_(jnp.array_equal(rbf_test(X, X, 1.0, 1.0) + jitter * jnp.eye(X.shape[0]), new_rbf(X, X, params_rbf, jitter=jitter))) + + +def test_set_noise_kernel_fn(): + from gpax.kernels import RBFKernel + + X = jnp.array([[1, 2], [3, 4], [5, 6]]) + params_i = {"k_length": jnp.array([1.0]), "k_scale": jnp.array(1.0)} + params = {"k_noise_length": jnp.array([1.0]), "k_noise_scale": jnp.array(1.0)} + noise_rbf = _set_noise_kernel_fn(RBFKernel) + assert_(jnp.array_equal(noise_rbf(X, X, params), RBFKernel(X, X, params_i))) diff --git a/tests/test_utilpriors.py b/tests/test_priors.py similarity index 69% rename from tests/test_utilpriors.py rename to tests/test_priors.py index 60ee3c1..ab4a2b3 100644 --- a/tests/test_utilpriors.py +++ b/tests/test_priors.py @@ -6,10 +6,9 @@ sys.path.insert(0, "../gpax/") -from gpax.utils import place_normal_prior, place_halfnormal_prior, place_uniform_prior, place_gamma_prior, place_lognormal_prior -from gpax.utils import uniform_dist, normal_dist, halfnormal_dist, lognormal_dist, gamma_dist -from gpax.utils import auto_lognormal_priors, auto_normal_priors, auto_lognormal_kernel_priors, auto_normal_kernel_priors, auto_priors -from gpax.utils import set_fn, set_kernel_fn, _set_noise_kernel_fn +from gpax.priors import place_normal_prior, place_halfnormal_prior, place_uniform_prior, place_gamma_prior, place_lognormal_prior +from gpax.priors import uniform_dist, normal_dist, halfnormal_dist, lognormal_dist, gamma_dist +from gpax.priors import auto_lognormal_priors, auto_normal_priors, auto_lognormal_kernel_priors, auto_normal_kernel_priors, auto_priors def linear_kernel_test(X, Z, k_scale): @@ -17,18 +16,6 @@ def linear_kernel_test(X, Z, k_scale): return k_scale * jnp.dot(X, Z.T) -def rbf_test(X, Z, k_length, k_scale): - # Dummy kernel functions for testing purposes - scaled_X = X / k_length - scaled_Z = Z / k_length - X2 = (scaled_X ** 2).sum(1, keepdims=True) - Z2 = (scaled_Z ** 2).sum(1, keepdims=True) - XZ = jnp.matmul(scaled_X, scaled_Z.T) - r2 = X2 - 2 * XZ + Z2.T - - return k_scale * jnp.exp(-0.5 * r2) - - def sample_function(x, a, b): return a + b * x @@ -162,45 +149,6 @@ def test_get_gamma_dist_error(): uniform_dist() # Neither concentration, nor input_vec -def test_set_fn(): - transformed_fn = set_fn(sample_function) - result = transformed_fn(2, {"a": 1, "b": 3}) - assert result == 7 # Expected output: 1 + 3*2 = 7 - - -def test_set_kernel_fn(): - - # Convert the dummy kernel functions - new_linear_kernel = set_kernel_fn(linear_kernel_test) - new_rbf = set_kernel_fn(rbf_test) - - X = jnp.array([[1, 2], [3, 4], [5, 6]]) - Z = jnp.array([[1, 2], [3, 4]]) - params_linear = {"k_scale": 1.0} - params_rbf = {"k_length": 1.0, "k_scale": 1.0} - - # Assert the transformed function is working correctly - assert_(jnp.array_equal(linear_kernel_test(X, Z, 1.0), new_linear_kernel(X, Z, params_linear))) - assert_(jnp.array_equal(rbf_test(X, Z, 1.0, 1.0), new_rbf(X, Z, params_rbf))) - - -def test_set_kernel_fn_with_jitter(): - - jitter = 1e-5 - - # Convert the dummy kernel functions - new_linear_kernel = set_kernel_fn(linear_kernel_test) - new_rbf = set_kernel_fn(rbf_test) - - X = jnp.array([[1, 2], [3, 4], [5, 6]]) - params_linear = {"k_scale": 1.0} - params_rbf = {"k_length": 1.0, "k_scale": 1.0} - - # Assert the transformed function is working correctly - assert_(jnp.array_equal(linear_kernel_test(X, X, 1.0) + jitter * jnp.eye(X.shape[0]), new_linear_kernel(X, X, params_linear, jitter=jitter))) - assert_(jnp.array_equal(rbf_test(X, X, 1.0, 1.0) + jitter * jnp.eye(X.shape[0]), new_rbf(X, X, params_rbf, jitter=jitter))) - - @pytest.mark.parametrize("prior_type", ["normal", "lognormal"]) def test_auto_priors(prior_type): prior_fn = auto_priors(sample_function, 1, prior_type, loc=2.0, scale=1.0) @@ -234,13 +182,3 @@ def test_auto_normal_kernel_priors(autopriors): with numpyro.handlers.trace() as tr: priors_fn() assert_('k_scale' in tr) - - -def test_set_noise_kernel_fn(): - from gpax.kernels import RBFKernel - - X = jnp.array([[1, 2], [3, 4], [5, 6]]) - params_i = {"k_length": jnp.array([1.0]), "k_scale": jnp.array(1.0)} - params = {"k_noise_length": jnp.array([1.0]), "k_noise_scale": jnp.array(1.0)} - noise_rbf = _set_noise_kernel_fn(RBFKernel) - assert_(jnp.array_equal(noise_rbf(X, X, params), RBFKernel(X, X, params_i))) diff --git a/tests/test_utils.py b/tests/test_utils.py index 4d4ef69..8e7844c 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -10,8 +10,6 @@ sys.path.insert(0, "../gpax/") from gpax.utils import preprocess_sparse_image, split_dict, random_sample_dict, get_keys, initialize_inducing_points -from gpax.utils import place_normal_prior, place_halfnormal_prior, place_uniform_prior, place_gamma_prior, gamma_dist, uniform_dist, normal_dist, halfnormal_dist -from gpax.utils import set_fn, auto_normal_priors def test_sparse_img_processing():