-
Notifications
You must be signed in to change notification settings - Fork 0
/
function.py
37 lines (30 loc) · 1.24 KB
/
function.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
class MODEL:
    """Wrap a density and expose its density, log-density and log-density gradient.

    At least one of ``density`` / ``log_density`` must be supplied; the missing
    one is derived from the other via ``torch.exp`` / ``torch.log``. If
    ``log_density_gradient`` is not supplied, the gradient of the log-density
    with respect to the input is computed with autograd.

    Args:
        density: callable ``x -> p(x)`` returning a torch tensor, or None.
        log_density: callable ``x -> log p(x)`` returning a torch tensor, or None.
        log_density_gradient: callable ``x -> d/dx log p(x)``, or None to use
            autograd on ``log_density``.

    Raises:
        TypeError: if neither ``density`` nor ``log_density`` is given.
    """

    def __init__(self, density=None, log_density=None,
                 log_density_gradient=None):
        if density is None and log_density is None:
            raise TypeError(
                "MODEL requires at least one of 'density' or 'log_density'")

        if density is not None and log_density is not None:
            self.density_func = density
            self.log_density_func = log_density
        elif density is None:
            # Only the log-density was given: p(x) = exp(log p(x)).
            self.density_func = lambda x: torch.exp(log_density(x))
            self.log_density_func = log_density
        else:
            # Only the density was given: log p(x) = log(p(x)).
            self.density_func = density
            self.log_density_func = lambda x: torch.log(density(x))

        if log_density_gradient is not None:
            self.log_density_gradient_func = log_density_gradient
        else:
            self.log_density_gradient_func = self._autograd_log_density_gradient

    def _autograd_log_density_gradient(self, x):
        """Return d/dx log p(x) computed with autograd, detached from any graph."""
        # Work on a detached copy so the caller's tensor is never modified
        # and its requires_grad flag is left untouched.
        x = torch.as_tensor(x).detach().clone()
        if not (x.is_floating_point() or x.is_complex()):
            # Autograd is only defined for floating/complex dtypes.
            x = x.to(torch.get_default_dtype())
        x.requires_grad_(True)
        # enable_grad keeps this working when called inside torch.no_grad().
        with torch.enable_grad():
            log_p = self.log_density_func(x)
            # Summing is a no-op for a scalar log-density. For a batched
            # log-density it gives the gradient of the total, which equals the
            # per-element gradient when each output depends only on its own
            # input element.
            (grad,) = torch.autograd.grad(log_p.sum(), x)
        return grad.detach()

    def log_gradient(self, x):
        """Return the gradient of the log-density with respect to ``x``."""
        return self.log_density_gradient_func(x)

    def density(self, x):
        """Return the density evaluated at ``x``."""
        return self.density_func(x)

    def log_density(self, x):
        """Return the log-density evaluated at ``x``."""
        return self.log_density_func(x)