From 06f21e08f379f167861c89ed75da08e0e34a9ed3 Mon Sep 17 00:00:00 2001 From: Prakhar Verma Date: Sat, 3 Sep 2022 18:00:40 +0300 Subject: [PATCH] Added KF with sparse sites and related test (#18) * Added KF with sparse sites and related test * Update tests/integration/test_kalman_filter_with_sparse_sites.py Co-authored-by: Vincent Adam * Incorporated suggested changes * Update markovflow/kalman_filter.py Co-authored-by: Vincent Adam * Update markovflow/kalman_filter.py Co-authored-by: Vincent Adam * Incorporated PR changes * sparse observations saved * passing sparse sites rather than dense sites Co-authored-by: Vincent Adam --- markovflow/kalman_filter.py | 121 +++++++++++++++++- .../test_kalman_filter_with_sparse_sites.py | 69 ++++++++++ 2 files changed, 188 insertions(+), 2 deletions(-) create mode 100644 tests/integration/test_kalman_filter_with_sparse_sites.py diff --git a/markovflow/kalman_filter.py b/markovflow/kalman_filter.py index e26f2e6..9fb4d1d 100644 --- a/markovflow/kalman_filter.py +++ b/markovflow/kalman_filter.py @@ -96,9 +96,8 @@ def _k_inv_post(self): # The emission matrix is tiled across the time_points, so for a time invariant matrix # this is equivalent to Gᵀ Σ⁻¹ G = (I_N ⊗ HᵀR⁻¹H), likelihood_precision = SymmetricBlockTriDiagonal(h_t_r_h) - _k_inv_prior = self.prior_ssm.precision # K⁻¹ + GᵀΣ⁻¹G - return _k_inv_prior + likelihood_precision + return self._k_inv_prior + likelihood_precision @property def _log_det_observation_precision(self): @@ -495,3 +494,121 @@ def _log_det_observation_precision(self): def observations(self): """ Observation vector """ return self.sites.means + + +@tf_scope_class_decorator +class KalmanFilterWithSparseSites(BaseKalmanFilter): + r""" + Performs a Kalman filter on a :class:`~markovflow.state_space_model.StateSpaceModel` + and :class:`~markovflow.emission_model.EmissionModel`, with Gaussian sites, over a time grid. + """ + + def __init__(self, state_space_model: StateSpaceModel, emission_model: EmissionModel, sites: GaussianSites, + num_grid_points: int, observations_index: tf.Tensor, observations: tf.Tensor): + """ + :param state_space_model: Parameterises the latent chain. + :param emission_model: Maps the latent chain to the observations. + :param sites: Gaussian sites over the observations. + :param num_grid_points: number of grid points. + :param observations_index: Index of the observations in the time grid with shape (N,). + :param observations: Sparse observations with shape (N, output_dim). + """ + self.sites = sites + self.observations_index = observations_index + self.sparse_observations = observations + self.grid_shape = tf.TensorShape((num_grid_points, 1)) + super().__init__(state_space_model, emission_model) + + @property + def _r_inv(self): + """ + Precisions of the observation model over the time grid. + """ + data_sites_precision = self.sites.precisions + return self.sparse_to_dense(data_sites_precision, output_shape=self.grid_shape + (1,)) + + @property + def _log_det_observation_precision(self): + """ + Sum of log determinant of the precisions of the observation model. It only calculates for the data_sites as + other sites precision is anyways zero. + """ + return tf.reduce_sum(tf.linalg.logdet(self._r_inv_data), axis=-1) + + @property + def observations(self): + """ Sparse observation vector """ + return self.sparse_observations + + @property + def _r_inv_data(self): + """ + Precisions of the observation model for only the data sites. + """ + return self.sites.precisions + + def sparse_to_dense(self, tensor: tf.Tensor, output_shape: tf.TensorShape) -> tf.Tensor: + """ + Convert a sparse tensor to a dense one on the basis of observations index, output tensor is of the output_shape. + """ + return tf.scatter_nd(self.observations_index, tensor, output_shape) + + def dense_to_sparse(self, tensor: tf.Tensor) -> tf.Tensor: + """ + Convert a dense tensor to a sparse one on the basis of observations index. + """ + tensor_shape = tensor.shape + expand_dims = len(tensor_shape) == 3 + + tensor = tf.gather_nd(tf.reshape(tensor, (-1, 1)), self.observations_index) + if expand_dims: + tensor = tf.expand_dims(tensor, axis=-1) + return tensor + + def log_likelihood(self) -> tf.Tensor: + r""" + Construct a TensorFlow function to compute the likelihood. + + For more mathematical details, look at the log_likelihood function of the parent class. + The main difference from the parent class are that the vector of observations is now sparse. + + :return: The likelihood as a scalar tensor (we sum over the `batch_shape`). + """ + # K⁻¹ + GᵀΣ⁻¹G = LLᵀ. + l_post = self._k_inv_post.cholesky + num_data = self.observations_index.shape[0] + + # Hμ [..., num_transitions + 1, output_dim] + marginal = self.emission.project_state_to_f(self.prior_ssm.marginal_means) + + # y = obs - Hμ [..., num_transitions + 1, output_dim] + disp = self.sparse_to_dense(self.observations, marginal.shape) - marginal + disp_data = self.sparse_observations - self.dense_to_sparse(marginal) + + # cst is the constant term for a gaussian log likelihood + cst = ( + -0.5 * np.log(2 * np.pi) * tf.cast(self.emission.output_dim * num_data, default_float()) + ) + + term1 = -0.5 * tf.reduce_sum( + input_tensor=tf.einsum("...op,...p,...o->...o", self._r_inv_data, disp_data, disp_data), axis=[-1, -2] + ) + + # term 2 is: ½|L⁻¹(GᵀΣ⁻¹)y|² + # (GᵀΣ⁻¹)y [..., num_transitions + 1, state_dim] + obs_proj = self._back_project_y_to_state(disp) + + # ½|L⁻¹(GᵀΣ⁻¹)y|² [...] + term2 = 0.5 * tf.reduce_sum( + input_tensor=tf.square(l_post.solve(obs_proj, transpose_left=False)), axis=[-1, -2] + ) + + ## term 3 is: ½log |K⁻¹| - log |L| + ½ log |Σ⁻¹| + # where log |Σ⁻¹| = num_data * log|R⁻¹| + term3 = ( + 0.5 * self.prior_ssm.log_det_precision() + - l_post.abs_log_det() + + 0.5 * self._log_det_observation_precision + ) + + return tf.reduce_sum(cst + term1 + term2 + term3) diff --git a/tests/integration/test_kalman_filter_with_sparse_sites.py b/tests/integration/test_kalman_filter_with_sparse_sites.py new file mode 100644 index 0000000..d88407e --- /dev/null +++ b/tests/integration/test_kalman_filter_with_sparse_sites.py @@ -0,0 +1,69 @@ +import numpy as np +import tensorflow as tf +import pytest +from gpflow.config import default_float + +from markovflow.kernels.matern import Matern12 +from markovflow.mean_function import LinearMeanFunction +from markovflow.models.gaussian_process_regression import GaussianProcessRegression +from markovflow.kalman_filter import KalmanFilterWithSparseSites, UnivariateGaussianSitesNat +from markovflow.likelihoods import MultivariateGaussian + +@pytest.fixture( + name="time_step_homogeneous", params=[(0.01, True), (0.01, False), (0.001, True), (0.001, False)], +) +def _time_step_homogeneous_fixture(request): + return request.param + + +@pytest.fixture(name="kalman_gpr_setup") +def _setup(batch_shape, time_step_homogeneous): + """ + Create a Gaussian Process model and an equivalent kalman filter model + with more latent states than observations. + FIXME: Currently batch_shape isn't used. + """ + dt, homogeneous = time_step_homogeneous + + time_grid = np.arange(0.0, 1.0, dt) + if not homogeneous: + time_grid = np.sort(np.random.choice(time_grid, 50, replace=False)) + + time_points = time_grid[::10] + observations = np.sin(12 * time_points[..., None]) + np.random.randn(len(time_points), 1) * 0.1 + + input_data = ( + tf.constant(time_points, dtype=default_float()), + tf.constant(observations, dtype=default_float()), + ) + + observation_covariance = 1.0 # Same as GPFlow default + kernel = Matern12(lengthscale=1.0, variance=1.0, output_dim=observations.shape[-1]) + kernel.set_state_mean(tf.random.normal((1,), dtype=default_float())) + gpr_model = GaussianProcessRegression( + input_data=input_data, + kernel=kernel, + mean_function=LinearMeanFunction(1.1), + chol_obs_covariance=tf.constant([[np.sqrt(observation_covariance)]], dtype=default_float()), + ) + + prior_ssm = kernel.state_space_model(time_grid) + emission_model = kernel.generate_emission_model(time_grid) + observations_index = tf.where(tf.equal(time_grid[..., None], time_points))[:, 0][..., None] + + observations -= gpr_model.mean_function(time_points) + + nat1 = observations / observation_covariance + nat2 = (-0.5 / observation_covariance) * tf.ones_like(nat1)[..., None] + lognorm = tf.zeros_like(nat1) + sites = UnivariateGaussianSitesNat(nat1=nat1, nat2=nat2, log_norm=lognorm) + + kf_sparse_sites = KalmanFilterWithSparseSites(prior_ssm, emission_model, sites, time_grid.shape[0], + observations_index, observations) + + return gpr_model, kf_sparse_sites + +def test_kalman_loglikelihood(with_tf_random_seed, kalman_gpr_setup): + gpr_model, kf_sparse_sites = kalman_gpr_setup + + np.testing.assert_allclose(gpr_model.log_likelihood(), kf_sparse_sites.log_likelihood())