diff --git a/dli_gpr.py b/dli_gpr.py index 2830567..48f4db6 100644 --- a/dli_gpr.py +++ b/dli_gpr.py @@ -140,6 +140,13 @@ def conditional_distribution(self, new_y): #print(k.T @ precision @ k) return mean, cov + def conditional_log_prob(self, new_y, new_t): + """Compute log probability of out of sample points based on conditional distribution + """ + mean, cov = self.conditional_distribution(new_y) + mvn = dist.MultivariateNormal(mean, covariance_matrix = cov + self.sigma * torch.eye(len(new_y))) + return mvn.log_prob(new_t) + class dli_gpr: """ Implementation of Gaussian Process regression with non-isotropic noise @@ -195,18 +202,11 @@ def initialize_variables(self, jitter=1e-5): self.scale_tril = torch.cholesky(self.scale + torch.eye(self.n) * jitter) # precision - self.beta = self.cluster_sizes * self.n / torch.sum(self.cluster_sizes) - self.beta_inverse = 1/self.beta - - # compute inverse cluster sizes for gamma prior - #inverse_cluster_sizes = 1./self.cluster_sizes - - # rate parameter for Gamma prior (beta inverse ~ variance of observations. Want lower beta_inverse for bigger clusters) - # initialize beta so that the MEAN of the gamma prior is proportional to the inverse weights of best linear unbiased estimator - #self.beta_inverse = inverse_cluster_sizes / torch.sum(inverse_cluster_sizes) * self.n - - # beta is proportional to the expected precision - #self.beta = 1./self.beta_inverse + inverse_cluster_sizes = 1./self.cluster_sizes + self.beta_inverse = inverse_cluster_sizes / torch.sum(inverse_cluster_sizes) * self.n + self.beta = 1./self.beta_inverse + #self.beta = self.cluster_sizes * self.n / torch.sum(self.cluster_sizes) + #self.beta_inverse = 1/self.beta def model(self): """Generative process""" @@ -284,4 +284,11 @@ def conditional_distribution(self, new_y): return mean, cov + def conditional_log_prob(self, new_y, new_t): + """Compute log probability of out of sample points based on conditional distribution + """ + mean, cov = self.conditional_distribution(new_y) + mvn = dist.MultivariateNormal(mean, covariance_matrix = cov) + return mvn.log_prob(new_t) +