-
Notifications
You must be signed in to change notification settings - Fork 0
/
transform.py
60 lines (53 loc) · 2.02 KB
/
transform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from typing import Any, List, MutableMapping
import datetime
import numpy as np
from .. import shared
class Log(shared.Plugin):
    """Log-transform selected variables around the analysis step.

    Before the analysis, each selected variable is clamped to ``minimum`` and
    replaced in place by its logarithm (base 10 if ``log10`` else natural).
    Optionally, observations of those variables are converted as well, treating
    each observation as log-normally distributed with the given mean and
    standard deviation. After the analysis the variables are transformed back
    in place.
    """

    def __init__(
        self,
        *variable_names: str,
        transform_obs: bool = True,
        minimum: float = -np.inf,
        log10: bool = True,
    ):
        """
        Args:
            *variable_names: names of the variables to transform; duplicates
                are ignored.
            transform_obs: also transform observations that fall within the
                index range of a transformed variable.
            minimum: lower bound applied to variable values (and to observed
                means) before taking the logarithm. The default, -inf, applies
                no clamping, so zero or negative values become -inf or nan.
            log10: use the base-10 logarithm if True, the natural logarithm
                otherwise.
        """
        self.variable_names = frozenset(variable_names)
        # Filled by initialize(): one metadata mapping per transformed variable.
        self.variable_metadata: List[Any] = []
        self.transform_obs = transform_obs
        self.minimum = minimum
        self.log10 = log10
        self.forward = np.log10 if self.log10 else np.log
        self.backward = (lambda x: 10.0 ** x) if self.log10 else np.exp

    def initialize(self, variables: MutableMapping[str, Any], *args, **kwargs):
        """Collect the metadata of every variable to transform.

        Each selected entry of ``variables`` must provide "start" and "stop"
        (compared against observation indices) and "data" (an array that is
        overwritten in place). NOTE(review): "data" is presumably a view into
        the model state; confirm against the caller.

        Raises:
            KeyError: if a requested variable name is not in ``variables``.
        """
        collected = []
        for name in self.variable_names:
            try:
                collected.append(variables[name])
            except KeyError:
                raise KeyError(
                    f"Log transform: variable {name!r} not found"
                ) from None
        # Rebuild rather than append, so that initializing more than once does
        # not register (and later transform) the same variable repeatedly.
        self.variable_metadata = collected

    def before_analysis(
        self,
        time: datetime.datetime,
        state: np.ndarray,
        iobs: np.ndarray,
        obs: np.ndarray,
        obs_sds: np.ndarray,
        *args,
        **kwargs,
    ):
        """Log-transform the selected variables and, optionally, their observations.

        Args:
            time: not used by this plugin.
            state: not used directly; variables are updated through their
                "data" entries.
            iobs: index of each observation. An observation belongs to a
                variable if its index lies in [start, stop).
            obs: observed values; for affected observations these are
                overwritten in place with the log-space mean.
            obs_sds: observation standard deviations; for affected
                observations these are overwritten in place with the
                log-space standard deviation.
        """
        for metadata in self.variable_metadata:
            if self.transform_obs:
                affected = (iobs >= metadata["start"]) & (iobs < metadata["stop"])
                if affected.any():
                    self._transform_observations(obs, obs_sds, affected)
            # Clamp first so that the logarithm sees values >= minimum.
            metadata["data"][...] = self.forward(
                np.maximum(metadata["data"], self.minimum)
            )

    def _transform_observations(
        self, obs: np.ndarray, obs_sds: np.ndarray, affected: np.ndarray
    ):
        """Replace the mean/sd of the ``affected`` observations by log-space values.

        Each observation is assumed to be log-normally distributed. The
        parameters (mu, sigma) of that distribution are chosen so that its mean
        and standard deviation match the observed mean and sd.
        """
        mean = np.maximum(obs[affected], self.minimum)
        sd = obs_sds[affected]
        # For a log-normal with natural-log parameters (mu, sigma):
        #   mean = exp(mu + sigma^2 / 2)  and  var = (exp(sigma^2) - 1) * mean^2
        # Solving these for sigma^2 and mu gives the two lines below.
        sigma2 = np.log((sd / mean) ** 2 + 1.0)
        mu = np.log(mean) - 0.5 * sigma2
        sigma = np.sqrt(sigma2)
        if self.log10:
            # log10(x) = ln(x) / ln(10), so the mean and sd of the log scale
            # by the same factor.
            mu /= np.log(10.0)
            sigma /= np.log(10.0)
        obs[affected] = mu
        obs_sds[affected] = sigma

    def after_analysis(self, *args, **kwargs):
        """Undo the log transform of the selected variables in place.

        Values that were clamped to ``minimum`` before the analysis are not
        restored to their original values.
        """
        for metadata in self.variable_metadata:
            metadata["data"][...] = self.backward(metadata["data"])