-
Notifications
You must be signed in to change notification settings - Fork 0
/
builders.py
140 lines (124 loc) · 4.08 KB
/
builders.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import torch
import torch.autograd as autograd
from ortho.orthopoly import OrthonormalPolynomial
from ortho.builders import OrthoBuilder
from ortho.measure import MaximalEntropyDensity
from ortho.basis_functions import OrthonormalBasis
from mercergp.MGP import MercerKernel, MercerGP
from mercergp.eigenvalue_gen import EigenvalueGenerator, FavardEigenvalues
from mercergp.likelihood import MercerLikelihood
import matplotlib.pyplot as plt
from typing import Callable
def _basis_derivatives_at_zero(basis: Callable):
    """Evaluate ``basis`` and its first two derivatives with respect to x at x = 0.

    Returns ``(f0, df0, d2f0)``. ``autograd.functional.jacobian`` appends the
    input's shape (here the length-1 axis of ``x``) to the output shape; that
    trailing axis is squeezed away so all three tensors share ``f0``'s shape
    (the caller asserts this).
    """
    x = torch.Tensor([0.0])
    x.requires_grad = True
    f0 = basis(x)

    def first_derivative(z: torch.Tensor) -> torch.Tensor:
        # create_graph=True keeps the first derivative itself differentiable,
        # which is what allows it to be differentiated a second time below.
        # (Without it, autograd.grad(df0, x) fails: df0 has no graph.)
        return autograd.functional.jacobian(basis, z, create_graph=True).squeeze(2)

    df0 = autograd.functional.jacobian(basis, x).squeeze(2)
    d2f0 = autograd.functional.jacobian(first_derivative, x).squeeze(2)
    return f0, df0, d2f0


def train_favard_params(
    parameters: dict,
    order: int,
    input_sample: torch.Tensor,
    output_sample: torch.Tensor,
    weight_function: Callable,
    optimiser: torch.optim.Optimizer,
    dim=1,
) -> dict:
    """
    Fit the hyperparameters of a Mercer GP whose eigenvalues come from
    ``FavardEigenvalues``.

    An orthonormal basis of the given order is built from ``input_sample`` and
    ``weight_function``. The basis and its first and second derivatives at 0
    parameterise a ``FavardEigenvalues`` generator. A ``MercerLikelihood`` is
    then fitted on (``input_sample``, ``output_sample``) using ``optimiser``.

    param parameters: hyperparameter dictionary; a shallow copy is handed to
        ``MercerLikelihood.fit``.
    param order: number of basis functions / eigenvalues.
    param input_sample: inputs, used both to build the basis and to fit.
    param output_sample: targets for the likelihood fit.
    param weight_function: weight function w(x) used to build the basis.
    param optimiser: optimiser passed through to ``MercerLikelihood``.
    param dim: currently unused; kept for interface compatibility.
    returns: a shallow copy of ``parameters`` after fitting, with every
        tensor value detached from the autograd graph.
    """
    basis = (
        OrthoBuilder(order)
        .set_sample(input_sample)
        .set_weight_function(weight_function)
        .get_orthonormal_basis()
    )

    f0, df0, d2f0 = _basis_derivatives_at_zero(basis)
    assert (
        f0.shape == df0.shape == d2f0.shape
    ), "derivative and second derivative are the wrong shape : ASSRT STATEMENT"

    eigenvalue_generator = FavardEigenvalues(order, f0, df0, d2f0)
    mgp_likelihood = MercerLikelihood(
        order,
        optimiser,
        basis,
        input_sample,
        output_sample,
        eigenvalue_generator,
    )

    # Shallow copy on purpose: the tensor objects are shared with the caller's
    # dict. NOTE(review): the optimiser was presumably built over those same
    # tensors, so cloning them here would disconnect them from it — confirm.
    new_parameters = parameters.copy()
    # fit's return value is unused; the fitted values are read back from
    # new_parameters, so fit presumably updates the dict in place.
    mgp_likelihood.fit(new_parameters)

    # Detach fitted tensors so the returned values carry no autograd history.
    return {
        key: value.detach() if isinstance(value, torch.Tensor) else value
        for key, value in new_parameters.items()
    }
def build_favard_gp(
    parameters: dict,
    order: int,
    input_sample: torch.Tensor,
    weight_function: Callable,
    eigenvalue_generator: EigenvalueGenerator,
    dim=1,
) -> MercerGP:
    """
    Build a Mercer Gaussian process on an orthonormal basis derived from the
    measure implied by ``input_sample`` and ``weight_function``.

    param parameters: hyperparameter dictionary; fed to
        ``eigenvalue_generator`` and also handed to ``MercerKernel``.
    param order: number of basis functions / eigenvalues.
    param input_sample: sample used by ``OrthoBuilder`` to construct the basis.
    param weight_function: a function w(x). Per the original author, the
        square root w^{1/2} is applied when constructing the basis.
    param eigenvalue_generator: callable mapping ``parameters`` to a tensor
        of ``order`` eigenvalues.
    param dim: input dimension, passed through to ``MercerGP``.
    returns: the assembled ``MercerGP``.
    """
    builder = (
        OrthoBuilder(order)
        .set_sample(input_sample)
        .set_weight_function(weight_function)
    )
    orthonormal_basis = builder.get_orthonormal_basis()

    spectrum = eigenvalue_generator(parameters)
    mercer_kernel = MercerKernel(order, orthonormal_basis, spectrum, parameters)
    return MercerGP(orthonormal_basis, order, dim, mercer_kernel)
if __name__ == "__main__":
    # Smoke test: check that a Favard eigenvalue generator can be built from a
    # sample-derived orthonormal basis and its derivatives at zero.
    order = 10

    # The original script used ``input_sample`` and ``weight_function``
    # without defining them, so it failed with a NameError. A seeded standard
    # normal sample and a Gaussian weight are used as stand-ins.
    # NOTE(review): assumes OrthoBuilder accepts a 1-D sample tensor and a
    # weight function mapping tensors to tensors — confirm against ortho.builders.
    torch.manual_seed(0)
    input_sample = torch.randn(1000)

    def weight_function(z: torch.Tensor) -> torch.Tensor:
        return torch.exp(-(z ** 2) / 2)

    basis = (
        OrthoBuilder(order)
        .set_sample(input_sample)
        .set_weight_function(weight_function)
        .get_orthonormal_basis()
    )

    x = torch.Tensor([0.0])
    x.requires_grad = True
    f0 = basis(x)
    df0 = autograd.functional.jacobian(basis, x).squeeze(2)
    # Differentiate the (graph-retaining) first derivative again to obtain the
    # true second derivative, instead of the former f0.clone() placeholder.
    d2f0 = autograd.functional.jacobian(
        lambda z: autograd.functional.jacobian(basis, z, create_graph=True).squeeze(2),
        x,
    ).squeeze(2)
    assert (
        f0.shape == df0.shape == d2f0.shape
    ), "derivative and second derivative are the wrong shape : ASSRT STATEMENT"

    # build the Favard eigenvalue generator
    eigenvalue_generator = FavardEigenvalues(order, f0, df0, d2f0)
    print("f0, df0, d2f0 shapes:", f0.shape, df0.shape, d2f0.shape)