-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added KF with sparse sites and related test (#18)
* Added KF with sparse sites and related test * Update tests/integration/test_kalman_filter_with_sparse_sites.py Co-authored-by: Vincent Adam <vincent.adam.ml@gmail.com> * Incorporated suggested changes * Update markovflow/kalman_filter.py Co-authored-by: Vincent Adam <vincent.adam.ml@gmail.com> * Update markovflow/kalman_filter.py Co-authored-by: Vincent Adam <vincent.adam.ml@gmail.com> * Incorporated PR changes * sparse observations saved * passing sparse sites rather than dense sites Co-authored-by: Vincent Adam <vincent.adam.ml@gmail.com>
- Loading branch information
1 parent
08dac0b
commit 06f21e0
Showing
2 changed files
with
188 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import numpy as np | ||
import tensorflow as tf | ||
import pytest | ||
from gpflow.config import default_float | ||
|
||
from markovflow.kernels.matern import Matern12 | ||
from markovflow.mean_function import LinearMeanFunction | ||
from markovflow.models.gaussian_process_regression import GaussianProcessRegression | ||
from markovflow.kalman_filter import KalmanFilterWithSparseSites, UnivariateGaussianSitesNat | ||
from markovflow.likelihoods import MultivariateGaussian | ||
|
||
@pytest.fixture( | ||
name="time_step_homogeneous", params=[(0.01, True), (0.01, False), (0.001, True), (0.001, False)], | ||
) | ||
def _time_step_homogeneous_fixture(request): | ||
return request.param | ||
|
||
|
||
@pytest.fixture(name="kalman_gpr_setup") | ||
def _setup(batch_shape, time_step_homogeneous): | ||
""" | ||
Create a Gaussian Process model and an equivalent kalman filter model | ||
with more latent states than observations. | ||
FIXME: Currently batch_shape isn't used. | ||
""" | ||
dt, homogeneous = time_step_homogeneous | ||
|
||
time_grid = np.arange(0.0, 1.0, dt) | ||
if not homogeneous: | ||
time_grid = np.sort(np.random.choice(time_grid, 50, replace=False)) | ||
|
||
time_points = time_grid[::10] | ||
observations = np.sin(12 * time_points[..., None]) + np.random.randn(len(time_points), 1) * 0.1 | ||
|
||
input_data = ( | ||
tf.constant(time_points, dtype=default_float()), | ||
tf.constant(observations, dtype=default_float()), | ||
) | ||
|
||
observation_covariance = 1.0 # Same as GPFlow default | ||
kernel = Matern12(lengthscale=1.0, variance=1.0, output_dim=observations.shape[-1]) | ||
kernel.set_state_mean(tf.random.normal((1,), dtype=default_float())) | ||
gpr_model = GaussianProcessRegression( | ||
input_data=input_data, | ||
kernel=kernel, | ||
mean_function=LinearMeanFunction(1.1), | ||
chol_obs_covariance=tf.constant([[np.sqrt(observation_covariance)]], dtype=default_float()), | ||
) | ||
|
||
prior_ssm = kernel.state_space_model(time_grid) | ||
emission_model = kernel.generate_emission_model(time_grid) | ||
observations_index = tf.where(tf.equal(time_grid[..., None], time_points))[:, 0][..., None] | ||
|
||
observations -= gpr_model.mean_function(time_points) | ||
|
||
nat1 = observations / observation_covariance | ||
nat2 = (-0.5 / observation_covariance) * tf.ones_like(nat1)[..., None] | ||
lognorm = tf.zeros_like(nat1) | ||
sites = UnivariateGaussianSitesNat(nat1=nat1, nat2=nat2, log_norm=lognorm) | ||
|
||
kf_sparse_sites = KalmanFilterWithSparseSites(prior_ssm, emission_model, sites, time_grid.shape[0], | ||
observations_index, observations) | ||
|
||
return gpr_model, kf_sparse_sites | ||
|
||
def test_kalman_loglikelihood(with_tf_random_seed, kalman_gpr_setup): | ||
gpr_model, kf_sparse_sites = kalman_gpr_setup | ||
|
||
np.testing.assert_allclose(gpr_model.log_likelihood(), kf_sparse_sites.log_likelihood()) |