Saving hybrid tf.keras nn with gpflux.layers.GPLayer #103

Open
ems1111 opened this issue Jul 2, 2024 · 0 comments

ems1111 commented Jul 2, 2024

I cannot figure out how to save the model and reload it with the same performance it had when I trained it.
Can someone comment on how I might save this and then reload it? I feel like it should be more straightforward than it is.

CODE:

import numpy as np
import tensorflow as tf
import gpflow
import gpflux

# X_train (the training inputs, 78 features per row) is defined elsewhere.

# Ensure inducing points have the correct dimensions
num_inducing = 2500
inducing_variable = gpflow.inducing_variables.InducingPoints(
    np.random.randn(num_inducing, 62)  # Match the dense layer output
)

# Define kernel
kernel = (
    gpflow.kernels.SquaredExponential() * gpflow.kernels.White(variance=0.01)
    + gpflow.kernels.White(variance=0.01)
)

# Define likelihood
likelihood = gpflow.likelihoods.Gaussian(0.01)
likelihood_container = gpflux.layers.TrackableLayer()
likelihood_container.likelihood = likelihood

# Hybrid model with explicit layer names
gp_hybrid_model = tf.keras.Sequential([
    tf.keras.layers.Dense(156, activation="relu", kernel_initializer="he_normal",
                          bias_initializer="he_normal", name="dense_gp_hybrid_1"),
    tf.keras.layers.Dropout(0.35),
    tf.keras.layers.Dense(62, activation="relu", kernel_initializer="he_normal",
                          bias_initializer="he_normal", name="dense_gp_hybrid_2"),
    gpflux.layers.GPLayer(
        kernel=kernel,
        inducing_variable=inducing_variable,
        num_data=X_train.shape[0],
        num_latent_gps=31,
        mean_function=gpflow.mean_functions.Zero(),
        name="gp_layer",
    ),
    likelihood_container,
])

# Build the model to initialize the layers
gp_hybrid_model.build(input_shape=(None, 78))
