Skip to content

Commit

Permalink
Merge pull request #94 from ziatdinovmax/util
Browse files Browse the repository at this point in the history
Move 'priors' out of 'utils' and turn them into a separate module
  • Loading branch information
ziatdinovmax authored Mar 20, 2024
2 parents 64bbec2 + cb28ab8 commit b03480a
Show file tree
Hide file tree
Showing 12 changed files with 276 additions and 235 deletions.
3 changes: 2 additions & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ GPax is a small Python package for physics-based Gaussian processes (GPs) built
:caption: Package Content

models
hypo
acquisition
kernels
priors
hypo
utils

.. toctree::
Expand Down
10 changes: 10 additions & 0 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,13 @@ Multi-Task Learning
:undoc-members:
:member-order: bysource
:show-inheritance:

Structured Probabilistic Models
-------------------------------
.. autoclass:: gpax.models.spm.sPM
:members:
:inherited-members:
:undoc-members:
:member-order: bysource
:show-inheritance:

23 changes: 23 additions & 0 deletions docs/source/priors.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
Priors
======

.. autofunction:: gpax.utils.normal_dist

.. autofunction:: gpax.utils.lognormal_dist

.. autofunction:: gpax.utils.halfnormal_dist

.. autofunction:: gpax.utils.gamma_dist

.. autofunction:: gpax.utils.uniform_dist

.. autofunction:: gpax.utils.place_normal_prior

.. autofunction:: gpax.utils.place_lognormal_prior

.. autofunction:: gpax.utils.place_halfnormal_prior

.. autofunction:: gpax.utils.place_uniform_prior

.. autofunction:: gpax.utils.place_gamma_prior

31 changes: 4 additions & 27 deletions docs/source/utils.rst
Original file line number Diff line number Diff line change
@@ -1,40 +1,17 @@
Utilities
=========

Priors
------
Automatic function setters
--------------------------

.. autofunction:: gpax.utils.normal_dist
.. autofunction:: gpax.utils.set_fn

.. autofunction:: gpax.utils.lognormal_dist

.. autofunction:: gpax.utils.halfnormal_dist

.. autofunction:: gpax.utils.gamma_dist

.. autofunction:: gpax.utils.uniform_dist

.. autofunction:: gpax.utils.place_normal_prior

.. autofunction:: gpax.utils.place_lognormal_prior

.. autofunction:: gpax.utils.place_halfnormal_prior

.. autofunction:: gpax.utils.place_uniform_prior

.. autofunction:: gpax.utils.place_gamma_prior
.. autofunction:: gpax.utils.set_kernel_fn


Other utilities
---------------

.. autoclass:: gpax.models.spm.sPM
:members:
:inherited-members:
:undoc-members:
:member-order: bysource
:show-inheritance:

.. autofunction:: gpax.utils.dviz

.. autofunction:: gpax.utils.get_keys
Expand Down
3 changes: 2 additions & 1 deletion gpax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .__version__ import version as __version__
from . import priors
from . import utils
from . import kernels
from . import acquisition
Expand All @@ -7,6 +8,6 @@
vi_iBNN, viDKL, viGP, sPM, viMTDKL, VarNoiseGP, UIGP,
MeasuredNoiseGP, viSparseGP, BNN)

__all__ = ["utils", "kernels", "mtkernels", "acquisition", "ExactGP", "vExactGP", "DKL",
__all__ = ["priors", "utils", "kernels", "mtkernels", "acquisition", "ExactGP", "vExactGP", "DKL",
"viDKL", "iBNN", "vi_iBNN", "MultiTaskGP", "viMTDKL", "viGP", "sPM", "VarNoiseGP",
"UIGP", "MeasuredNoiseGP", "viSparseGP", "CoregGP", "BNN", "sample_next", "__version__"]
1 change: 1 addition & 0 deletions gpax/priors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .priors import *
139 changes: 2 additions & 137 deletions gpax/utils/priors.py → gpax/priors/priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,16 @@
Utility functions for setting priors
Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com)
Created by Maxim Ziatdinov (email: maxim.ziatdinov@gmail.com)
"""

import inspect
import re

from typing import Union, Dict, Type, List, Callable, Optional
from typing import Union, Dict, Type, Callable

import numpyro
import jax
import jax.numpy as jnp

from ..kernels.kernels import square_scaled_distance, add_jitter, _sqrt


def place_normal_prior(param_name: str, loc: float = 0.0, scale: float = 1.0):
"""
Expand Down Expand Up @@ -183,137 +179,6 @@ def uniform_dist(low: float = None,
return numpyro.distributions.Uniform(low, high)


def set_fn(func: Callable) -> Callable:
"""
Transforms the given deterministic function to use a params dictionary
for its parameters, excluding the first one (assumed to be the dependent variable).
Args:
- func (Callable): The deterministic function to be transformed.
Returns:
- Callable: The transformed function where parameters are accessed
from a `params` dictionary.
"""
# Extract parameter names excluding the first one (assumed to be the dependent variable)
params_names = list(inspect.signature(func).parameters.keys())[1:]

# Create the transformed function definition
transformed_code = f"def {func.__name__}(x, params):\n"

# Retrieve the source code of the function and indent it to be a valid function body
source = inspect.getsource(func).split("\n", 1)[1]
source = " " + source.replace("\n", "\n ")

# Replace each parameter name with its dictionary lookup using regex
for name in params_names:
source = re.sub(rf'\b{name}\b', f'params["{name}"]', source)

# Combine to get the full source
transformed_code += source

# Define the transformed function in the local namespace
local_namespace = {}
exec(transformed_code, globals(), local_namespace)

# Return the transformed function
return local_namespace[func.__name__]


def set_kernel_fn(func: Callable,
independent_vars: List[str] = ["X", "Z"],
jit_decorator: bool = True,
docstring: Optional[str] = None) -> Callable:
"""
Transforms the given kernel function to use a params dictionary for its hyperparameters.
The resultant function will always add jitter before returning the computed kernel.
Args:
func (Callable): The kernel function to be transformed.
independent_vars (List[str], optional): List of independent variable names in the function. Defaults to ["X", "Z"].
jit_decorator (bool, optional): @jax.jit decorator to be applied to the transformed function. Defaults to True.
docstring (Optional[str], optional): Docstring to be added to the transformed function. Defaults to None.
Returns:
Callable: The transformed kernel function where hyperparameters are accessed from a `params` dictionary.
"""

# Extract parameter names excluding the independent variables
params_names = [k for k, v in inspect.signature(func).parameters.items() if v.default == v.empty]
for var in independent_vars:
params_names.remove(var)

transformed_code = ""
if jit_decorator:
transformed_code += "@jit" + "\n"

additional_args = "noise: int = 0, jitter: float = 1e-6, **kwargs"
transformed_code += f"def {func.__name__}({', '.join(independent_vars)}, params: Dict[str, jnp.ndarray], {additional_args}):\n"

if docstring:
transformed_code += ' """' + docstring + '"""\n'

source = inspect.getsource(func).split("\n", 1)[1]
lines = source.split("\n")

for idx, line in enumerate(lines):
# Convert all parameter names to their dictionary lookup throughout the function body
for name in params_names:
lines[idx] = re.sub(rf'\b{name}\b', f'params["{name}"]', lines[idx])

# Combine lines back and then split again by return
modified_source = '\n'.join(lines)
pre_return, return_statement = modified_source.split('return', 1)

# Append custom jitter code
custom_code = f" {pre_return.strip()}\n k = {return_statement.strip()}\n"
custom_code += """
if X.shape == Z.shape:
k += (noise + jitter) * jnp.eye(X.shape[0])
return k
"""

transformed_code += custom_code

local_namespace = {"jit": jax.jit}
exec(transformed_code, globals(), local_namespace)

return local_namespace[func.__name__]


def _set_noise_kernel_fn(func: Callable) -> Callable:
"""
Modifies the GPax kernel function to append "_noise" after "k" in dictionary keys it accesses.
Args:
func (Callable): Original function.
Returns:
Callable: Modified function.
"""

# Get the source code of the function
source = inspect.getsource(func)

# Split the source into decorators, definition, and body
decorators_and_def, body = source.split("\n", 1)

# Replace all occurrences of params["k with params["k_noise in the body
modified_body = re.sub(r'params\["k', 'params["k_noise', body)

# Combine decorators, definition, and modified body
modified_source = f"{decorators_and_def}\n{modified_body}"

# Define local namespace including the jit decorator
local_namespace = {"jit": jax.jit}

# Execute the modified source to redefine the function in the provided namespace
exec(modified_source, globals(), local_namespace)

# Return the modified function
return local_namespace[func.__name__]


def auto_priors(func: Callable, params_begin_with: int, dist_type: str = 'normal', loc: float = 0.0, scale: float = 1.0) -> Callable:
"""
Generates a function that, when invoked, samples from normal or log-normal distributions
Expand Down
4 changes: 2 additions & 2 deletions gpax/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .utils import *
from .priors import *
from .priors import _set_noise_kernel_fn
from .fn import *
from .fn import _set_noise_kernel_fn
Loading

0 comments on commit b03480a

Please sign in to comment.