Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deep GP fit on Step Data #75

Open
Aadesh-1404 opened this issue Jul 6, 2022 · 1 comment
Open

Deep GP fit on Step Data #75

Aadesh-1404 opened this issue Jul 6, 2022 · 1 comment
Labels
bug Something isn't working

Comments

@Aadesh-1404
Copy link

Aadesh-1404 commented Jul 6, 2022

Describe the bug

I want to fit a Deep GP on step data, so I am using the method shown in the GPflux tutorial on the motorcycle dataset. However, the resulting fit is not what I expected, i.e. it does not look like the one shown in Prof. Neil Lawrence's blog. I can obtain the expected fit using PyDeepGP. I have attached the code I used for both GPflux and PyDeepGP.

To reproduce
Steps to reproduce the behaviour:

GPflux Implementation ```
try:
    import gpflux
except ModuleNotFoundError:
    %pip install gpflux
    import gpflux

from gpflux.architectures import Config, build_constant_input_dim_deep_gp
from gpflux.models import DeepGP

try:
    import tensorflow as tf
except ModuleNotFoundError:
    %pip install tensorflow
    import tensorflow as tf
    
import numpy as np
import pandas as pd
import gpflow
import gpflux

from gpflux.architectures import Config, build_constant_input_dim_deep_gp
from gpflux.models import DeepGP

tf.keras.backend.set_floatx("float64")
tf.get_logger().setLevel("INFO")


## Data
# Synthetic step function: `num_low` points with target 0 followed by
# `num_high` points with target 1.

num_low = 25
num_high = 25
# NOTE: gap is negative, so -gap/2 = 0.05 and gap/2 = -0.05. The low half
# therefore spans [-1, 0.05] and the high half spans [-0.05, 1], i.e. the two
# halves overlap on [-0.05, 0.05] instead of being separated by a gap.
gap = -.1
# NOTE(review): `noise` is not used anywhere in the code shown here.
noise = 0.0001
# Inputs: both halves stacked and flattened to a 1-D array of
# num_low + num_high values.
x = np.vstack((np.linspace(-1, -gap/2.0, num_low)[:, np.newaxis],
                np.linspace(gap/2.0, 1, num_high)[:, np.newaxis])).reshape(-1,)
# Targets: column vector of zeros followed by ones.
y = np.vstack((np.zeros((num_low, 1)), np.ones((num_high, 1))))
# Standardised targets (zero mean, unit variance), flattened to 1-D.
# NOTE(review): `yhat` is computed but the fit below is run on the raw `y`.
scale = np.sqrt(y.var())
offset = y.mean()
yhat = ((y-offset)/scale).reshape(-1,)

## Model
# One inducing point per training input (x.shape[0] == num_low + num_high).
config = Config(
    num_inducing=x.shape[0], inner_layer_qsqrt_factor=1e-3, likelihood_noise_variance=1e-3, whiten=True
)
# Four-layer deep GP built from the inputs reshaped to a single column.
deep_gp: DeepGP = build_constant_input_dim_deep_gp(
    np.array(x.reshape(-1, 1)), num_layers=4, config=config)

# Keras model used for training; it is fed a dict with "inputs" and "targets".
training_model: tf.keras.Model =deep_gp.as_training_model()

training_model.compile(optimizer=tf.optimizers.Adam(learning_rate=0.01))

# Training callbacks: learning-rate reduction monitored on "loss", GPflux's
# TensorBoard callback, and weight-only checkpoints written under "ckpts/".
callbacks = [ tf.keras.callbacks.ReduceLROnPlateau("loss", factor=0.95, patience=3, min_lr=1e-6, verbose=0),
                      gpflux.callbacks.TensorBoard(),
                      tf.keras.callbacks.ModelCheckpoint(filepath="ckpts/", save_weights_only=True, verbose=0),]


# Fit for 1000 epochs with mini-batches of 6 points, silently (verbose=0).
# The targets passed here are the unnormalised `y`, not `yhat`.
history = training_model.fit(
    {"inputs": x.reshape(-1, 1),
     "targets": y.reshape(-1, 1)},
    batch_size=6,
    epochs=1000,
    callbacks=callbacks,
    verbose=0,
)

## Predict
def plot(model, X, Y, ax=None):
    """Draw the model's predictive mean with a +/- 2 standard deviation band.

    The model is evaluated on 50 evenly spaced points covering the range of
    ``X`` extended by 1.0 on each side. The training points are drawn as
    black crosses, the mean as a line and the band as a shaded region.

    Args:
        model: Callable taking an ``(N, 1)`` array and returning an object
            whose ``f_mean`` and ``f_var`` attributes expose ``.numpy()``.
        X: Training inputs (array with ``min``/``max``).
        Y: Training targets (array with ``min``/``max``); used for the
            y-axis limits and the scatter of observations.
        ax: Axes to draw on. When ``None`` a new figure and axes are created
            via ``plt.subplots()``.
    """
    if ax is None:
        _, ax = plt.subplots()

    n_grid = 50
    pad = 1.0
    grid = np.linspace(X.min() - pad, X.max() + pad, n_grid)

    # The model expects a column of inputs; plotting uses the flat grid.
    prediction = model(grid.reshape(-1, 1))
    mean = prediction.f_mean.numpy().squeeze()
    half_width = 2 * np.sqrt(prediction.f_var.numpy().squeeze())

    ax.set_ylim(Y.min() - 0.5, Y.max() + 0.5)
    ax.plot(X, Y, "kx", alpha=0.5)
    ax.plot(grid, mean, "C1")
    # The x-limits are fixed rather than derived from the grid.
    ax.set_xlim(-2, 2)
    ax.fill_between(grid, mean - half_width, mean + half_width, color="C1", alpha=0.3)


# Obtain the prediction model from the deep GP and plot its fit over the
# unnormalised training data (inputs and targets reshaped to single columns).
prediction_model = deep_gp.as_prediction_model()
plot(prediction_model, x.reshape(-1, 1), y.reshape(-1, 1))

Plot obtained as a result of the above code:

deep_gp_step_tut2

Expected behaviour

PyDeepGP Implementation ```
try:
    import deepgp
except ModuleNotFoundError:
    %pip install git+https://github.com/SheffieldML/PyDeepGP.git
    import deepgp

try:
    import GPy
except ModuleNotFoundError:
    %pip install -qq GPy
    import GPy

try:
    import tinygp
except ModuleNotFoundError:
    %pip install -q tinygp
    import tinygp

import seaborn as sns
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from tinygp import kernels, GaussianProcess
from jax.config import config

import numpy as np

try:
    import jaxopt
except ModuleNotFoundError:
    %pip install jaxopt
    import jaxopt
config.update("jax_enable_x64", True)


# Synthetic step data: `num_low` points with target 0, then `num_high` points
# with target 1.
num_low = 25
num_high = 25
# NOTE: gap is negative, so -gap/2 = 0.05 and gap/2 = -0.05; the low half
# spans [-1, 0.05] and the high half spans [-0.05, 1], so they overlap.
gap = -0.1
# NOTE(review): `noise` is not used anywhere in the code shown here.
noise = 0.0001
# Training inputs, flattened to 1-D.
x = jnp.vstack(
    (jnp.linspace(-1, -gap / 2.0, num_low)[:, jnp.newaxis], jnp.linspace(gap / 2.0, 1, num_high)[:, jnp.newaxis])
).reshape(
    -1,
)
# Training targets as a column of zeros followed by ones.
y = jnp.vstack((jnp.zeros((num_low, 1)), jnp.ones((num_high, 1))))
# Standardised targets.
# NOTE(review): `yhat` is not used below; the model is given the raw `y`.
scale = jnp.sqrt(y.var())
offset = y.mean()
yhat = ((y - offset) / scale).reshape(
    -1,
)
# Test inputs covering [-2, 2], built from the same two overlapping halves.
# NOTE: because the halves overlap, `xnew` is not monotonically increasing
# (it steps back from 0.05 to -0.05 at the join).
xnew = jnp.vstack(
    (jnp.linspace(-2, -gap / 2.0, 25)[:, jnp.newaxis], jnp.linspace(gap / 2.0, 2, 25)[:, jnp.newaxis])
).reshape(
    -1,
)


# Three hidden layers, each one-dimensional.
num_hidden = 3
latent_dim = 1

# NOTE: this rebinds the name `kernels`, shadowing `kernels` imported from
# tinygp above.
# NOTE(review): list multiplication repeats the SAME RBF kernel object
# `num_hidden` times rather than creating independent kernels — confirm this
# sharing is intended.
kernels = [*[GPy.kern.RBF(latent_dim, ARD=True)] * num_hidden]  # hidden kernels
kernels.append(GPy.kern.RBF(np.array(x.reshape(-1, 1)).shape[1]))  # we append a kernel for the input layer

# The first positional argument lists the layer dimensions: the output
# dimension, then `num_hidden` latent dimensions, then the input dimension.
m = deepgp.DeepGP(

    [y.reshape(-1, 1).shape[1], *[latent_dim] * num_hidden, x.reshape(-1, 1).shape[1]],
    X=np.array(x.reshape(-1, 1)),  # training input
    Y=np.array(y.reshape(-1, 1)),  # training output
    inits=[*["PCA"] * num_hidden, "PCA"],  # initialise layers
    kernels=kernels,
    num_inducing=x.shape[0],  # one inducing point per training input
    back_constraint=False,
)
m.initialize_parameter()



def optimise_dgp(model, messages=True, *, max_iters=10000, noise_variance=1.0):
    """Optimise a deep GP after resetting the Gaussian noise of every layer.

    The model parameters are re-initialised, then each layer's likelihood
    variance is constrained to be positive and reset to ``noise_variance``
    (for reasons pertaining to stability) before the optimiser is run.

    Args:
        model: Deep GP exposing ``initialize_parameter()``, ``layers`` and
            ``optimize(messages=..., max_iters=...)``.
        messages: Forwarded to ``model.optimize`` as ``messages``.
        max_iters: Maximum number of optimiser iterations, forwarded to
            ``model.optimize``. Defaults to the previously hard-coded 10000.
        noise_variance: Value assigned to every layer's likelihood variance
            before optimisation. Defaults to the previously hard-coded 1.0.
    """
    model.initialize_parameter()
    for layer in model.layers:
        layer.likelihood.variance.constrain_positive(warning=False)
        # A small variance may cause collapse, so start from a large value.
        layer.likelihood.variance = noise_variance
    model.optimize(messages=messages, max_iters=max_iters)


# Reset the per-layer noise and optimise the deep GP, printing progress.
optimise_dgp(m, messages=True)


# Predictive mean and variance at the test inputs (as a single column).
mu_dgp, var_dgp = m.predict(xnew.reshape(-1, 1))


plt.figure()

# NOTE(review): `latexify`, `marksize` and `is_latexify_enabled` are not
# defined or imported in this snippet — presumably they come from a plotting
# helper module; confirm before running.
latexify(width_scale_factor=2, fig_height=1.75)
plt.plot(xnew, mu_dgp, "blue")
plt.scatter(x, y, c="r", s=marksize)
# Shaded band of mean +/- 1.96 standard deviations.
plt.fill_between(
    xnew.flatten(),
    mu_dgp.flatten() - 1.96 * jnp.sqrt(var_dgp.flatten()),
    mu_dgp.flatten() + 1.96 * jnp.sqrt(var_dgp.flatten()),
    alpha=0.3,
    color="C1",
)
sns.despine()
# Smaller legend font when latexify is enabled.
legendsize = 4.5 if is_latexify_enabled() else 9
plt.legend(labels=["Mean", "Data", "Confidence"], loc=2, prop={"size": legendsize}, frameon=False)
plt.xlabel("$x$")
plt.ylabel("$y$")
# NOTE(review): despine() was already called above; this second call looks
# redundant.
sns.despine()
plt.show()

Plot obtained from above code
deep_gp_step_pydeepgp

System information

  • OS: Ubuntu 20.04.2 LTS
  • Python version: 3.10.4
  • GPflux version: 0.3.0
  • TensorFlow version: 2.8.2
  • GPflow version: 2.5.2
@Aadesh-1404 Aadesh-1404 added the bug Something isn't working label Jul 6, 2022
@sebastianober
Copy link
Collaborator

Hi @Aadesh-1404 ,

This is actually a difference in the models - the DGP model used in PyDeepGP is from Damianou and Lawrence (2013) and uses latent variables in the intermediate layers, whereas the model we follow is from Salimbeni and Deisenroth (2017), and doesn't use latent variables. This means that the fits will be different. However, you should be able to get similar fits (but not exactly the same) using GPflux by following the tutorial https://secondmind-labs.github.io/GPflux/notebooks/deep_cde.html

References
Damianou and Lawrence (2013): http://proceedings.mlr.press/v31/damianou13a
Salimbeni and Deisenroth (2017): https://arxiv.org/abs/1705.08933

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

No branches or pull requests

2 participants