
Commit

added test
Simon Kamuk Christiansen committed Dec 11, 2024
1 parent 5ba72ce commit 5c7567d
Showing 3 changed files with 320 additions and 48 deletions.
84 changes: 36 additions & 48 deletions neural_lam/models/base_graph_model.py
@@ -108,8 +108,10 @@ def prepare_clamping_parames(
         )

         # Constant parameters for clamping
-        sigmoid_sharpness = softplus_sharpness = 1
-        sigmoid_center = softplus_center = 0
+        self.sigmoid_sharpness = 1
+        self.softplus_sharpness = 1
+        self.sigmoid_center = 0
+        self.softplus_center = 0

         normalize_clamping_lim = (
             lambda x, feature_idx: (x - self.state_mean[feature_idx])
@@ -129,6 +131,11 @@

         for feature_idx, feature in enumerate(state_feature_names):
             if feature in lower_lims and feature in upper_lims:
+                assert (
+                    lower_lims[feature] < upper_lims[feature]
+                ), f'Invalid clamping limits for feature "{feature}",\
+                    lower: {lower_lims[feature]}, larger than\
+                    upper: {upper_lims[feature]}'
                 sigmoid_lower_upper_idx.append(feature_idx)
                 sigmoid_lower_lims.append(
                     normalize_clamping_lim(lower_lims[feature], feature_idx)
@@ -147,87 +154,68 @@
                 normalize_clamping_lim(upper_lims[feature], feature_idx)
             )

-        # Convert to tensors
-        # self.register_buffer(
-        #     "sigmoid_lower_lims",
-        #     torch.tensor(sigmoid_lower_lims),
-        #     persistent=False,
-        # )
-        # self.register_buffer(
-        #     "sigmoid_upper_lims",
-        #     torch.tensor(sigmoid_upper_lims),
-        #     persistent=False,
-        # )
-        # self.register_buffer(
-        #     "softplus_lower_lims",
-        #     torch.tensor(softplus_lower_lims),
-        #     persistent=False,
-        # )
-        # self.register_buffer(
-        #     "softplus_upper_lims",
-        #     torch.tensor(softplus_upper_lims),
-        #     persistent=False,
-        # )
-        sigmoid_lower_lims = torch.tensor(sigmoid_lower_lims)
-        sigmoid_upper_lims = torch.tensor(sigmoid_upper_lims)
-        softplus_lower_lims = torch.tensor(softplus_lower_lims)
-        softplus_upper_lims = torch.tensor(softplus_upper_lims)
+        self.sigmoid_lower_lims = torch.tensor(sigmoid_lower_lims)
+        self.sigmoid_upper_lims = torch.tensor(sigmoid_upper_lims)
+        self.softplus_lower_lims = torch.tensor(softplus_lower_lims)
+        self.softplus_upper_lims = torch.tensor(softplus_upper_lims)

         self.clamp_lower_upper_idx = torch.tensor(sigmoid_lower_upper_idx)
         self.clamp_lower_idx = torch.tensor(softplus_lower_idx)
         self.clamp_upper_idx = torch.tensor(softplus_upper_idx)

         # Define clamping functions
         self.clamp_lower_upper = lambda x: (
-            sigmoid_lower_lims
-            + (sigmoid_upper_lims - sigmoid_lower_lims)
-            * torch.sigmoid(sigmoid_sharpness * (x - sigmoid_center))
+            self.sigmoid_lower_lims
+            + (self.sigmoid_upper_lims - self.sigmoid_lower_lims)
+            * torch.sigmoid(self.sigmoid_sharpness * (x - self.sigmoid_center))
         )
         self.clamp_lower = lambda x: (
-            softplus_lower_lims
+            self.softplus_lower_lims
             + torch.nn.functional.softplus(
-                x - softplus_center, beta=softplus_sharpness
+                x - self.softplus_center, beta=self.softplus_sharpness
             )
         )
         self.clamp_upper = lambda x: (
-            softplus_upper_lims
+            self.softplus_upper_lims
             - torch.nn.functional.softplus(
-                softplus_center - x, beta=softplus_sharpness
+                self.softplus_center - x, beta=self.softplus_sharpness
             )
         )

         # Define inverse clamping functions
         def inverse_softplus(x, beta=1, threshold=20):
             # If x*beta is above threshold, returns linear function
             # for numerical stability
-            under_lim = x * beta <= threshold
-            x[under_lim] = (
-                torch.log(
-                    torch.clamp_min(torch.expm1(x[under_lim] * beta), 1e-6)
-                )
-                / beta
+            non_linear_part = (
+                torch.log(torch.clamp_min(torch.expm1(x * beta), 1e-6)) / beta
             )
+            x = torch.where(x * beta <= threshold, non_linear_part, x)

             return x

         def inverse_sigmoid(x):
             x_clamped = torch.clamp(x, min=1e-6, max=1 - 1e-6)
             return torch.log(x_clamped / (1 - x_clamped))

         self.inverse_clamp_lower_upper = lambda x: (
-            sigmoid_center
+            self.sigmoid_center
             + inverse_sigmoid(
-                (x - sigmoid_lower_lims)
-                / (sigmoid_upper_lims - sigmoid_lower_lims)
+                (x - self.sigmoid_lower_lims)
+                / (self.sigmoid_upper_lims - self.sigmoid_lower_lims)
             )
-            / sigmoid_sharpness
+            / self.sigmoid_sharpness
         )
         self.inverse_clamp_lower = lambda x: (
-            inverse_softplus(x - softplus_lower_lims, beta=softplus_sharpness)
-            + softplus_center
+            inverse_softplus(
+                x - self.softplus_lower_lims, beta=self.softplus_sharpness
+            )
+            + self.softplus_center
         )
         self.inverse_clamp_upper = lambda x: (
-            -inverse_softplus(softplus_upper_lims - x, beta=softplus_sharpness)
-            + softplus_center
+            -inverse_softplus(
+                self.softplus_upper_lims - x, beta=self.softplus_sharpness
+            )
+            + self.softplus_center
         )

     def clamp_prediction(self, state_delta, prev_state):
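For reference, a minimal standalone sketch (not part of the commit) of how the clamping functions defined above behave: clamp_lower_upper smoothly squashes an unconstrained value into the configured limits via a sigmoid, and the inverse maps a clamped value back, so the two approximately round-trip. The example limits (0 and 100) and input values below are made up for illustration; constants match the diff (sharpness=1, center=0).

# Standalone sketch of the clamping defined in prepare_clamping_parames,
# with made-up limits. Not the model code itself.
import torch

sharpness, center = 1, 0
lower, upper = torch.tensor([0.0]), torch.tensor([100.0])

def clamp_lower_upper(x):
    # Smoothly map any real x into the open interval (lower, upper)
    return lower + (upper - lower) * torch.sigmoid(sharpness * (x - center))

def inverse_sigmoid(x):
    x_clamped = torch.clamp(x, min=1e-6, max=1 - 1e-6)
    return torch.log(x_clamped / (1 - x_clamped))

def inverse_clamp_lower_upper(x):
    # Map a clamped value back to the unconstrained space
    return center + inverse_sigmoid((x - lower) / (upper - lower)) / sharpness

def inverse_softplus(x, beta=1, threshold=20):
    # The rewritten version from the diff: above threshold, softplus is
    # effectively linear, so x is passed through unchanged
    non_linear = torch.log(torch.clamp_min(torch.expm1(x * beta), 1e-6)) / beta
    return torch.where(x * beta <= threshold, non_linear, x)

x = torch.tensor([-5.0, 0.0, 5.0])
y = clamp_lower_upper(x)                # strictly inside (0, 100)
print(y, inverse_clamp_lower_upper(y))  # second tensor ≈ x

z = torch.nn.functional.softplus(torch.tensor([-2.0, 0.5, 30.0]))
print(inverse_softplus(z))              # ≈ [-2.0, 0.5, 30.0]

Note that the rewritten inverse_softplus builds a new tensor with torch.where instead of writing into x under a boolean mask, so it no longer modifies its input in place.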
1 change: 1 addition & 0 deletions tests/datastore_examples/mdp/danra_100m_winds/config.yaml
@@ -15,3 +15,4 @@ training:
       r2m: 0
     upper:
       r2m: 100.0
+      u100m: 100.0
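For context, this hunk extends the per-feature clamping limits at the tail of the example datastore config. A sketch of what the full block plausibly looks like; only the lines visible in the diff are certain, and the enclosing key name is an assumption:

# Hedged sketch of the config block; the output_clamping key is assumed,
# as the diff only shows the last lines of the file.
training:
  output_clamping:
    lower:
      r2m: 0
    upper:
      r2m: 100.0
      u100m: 100.0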