diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py
index 342a570..212aefc 100644
--- a/neural_lam/models/base_graph_model.py
+++ b/neural_lam/models/base_graph_model.py
@@ -108,8 +108,10 @@ def prepare_clamping_parames(
         )
 
         # Constant parameters for clamping
-        sigmoid_sharpness = softplus_sharpness = 1
-        sigmoid_center = softplus_center = 0
+        self.sigmoid_sharpness = 1
+        self.softplus_sharpness = 1
+        self.sigmoid_center = 0
+        self.softplus_center = 0
 
         normalize_clamping_lim = (
             lambda x, feature_idx: (x - self.state_mean[feature_idx])
@@ -129,6 +131,11 @@ def prepare_clamping_parames(
         for feature_idx, feature in enumerate(state_feature_names):
             if feature in lower_lims and feature in upper_lims:
+                assert (
+                    lower_lims[feature] < upper_lims[feature]
+                ), f'Invalid clamping limits for feature "{feature}",\
+                    lower: {lower_lims[feature]}, larger than\
+                    upper: {upper_lims[feature]}'
                 sigmoid_lower_upper_idx.append(feature_idx)
                 sigmoid_lower_lims.append(
                     normalize_clamping_lim(lower_lims[feature], feature_idx)
                 )
@@ -147,31 +154,10 @@ def prepare_clamping_parames(
                 normalize_clamping_lim(upper_lims[feature], feature_idx)
             )
 
-        # Convert to tensors
-        # self.register_buffer(
-        #     "sigmoid_lower_lims",
-        #     torch.tensor(sigmoid_lower_lims),
-        #     persistent=False,
-        # )
-        # self.register_buffer(
-        #     "sigmoid_upper_lims",
-        #     torch.tensor(sigmoid_upper_lims),
-        #     persistent=False,
-        # )
-        # self.register_buffer(
-        #     "softplus_lower_lims",
-        #     torch.tensor(softplus_lower_lims),
-        #     persistent=False,
-        # )
-        # self.register_buffer(
-        #     "softplus_upper_lims",
-        #     torch.tensor(softplus_upper_lims),
-        #     persistent=False,
-        # )
-        sigmoid_lower_lims = torch.tensor(sigmoid_lower_lims)
-        sigmoid_upper_lims = torch.tensor(sigmoid_upper_lims)
-        softplus_lower_lims = torch.tensor(softplus_lower_lims)
-        softplus_upper_lims = torch.tensor(softplus_upper_lims)
+        self.sigmoid_lower_lims = torch.tensor(sigmoid_lower_lims)
+        self.sigmoid_upper_lims = torch.tensor(sigmoid_upper_lims)
+        self.softplus_lower_lims = torch.tensor(softplus_lower_lims)
+        self.softplus_upper_lims = torch.tensor(softplus_upper_lims)
 
         self.clamp_lower_upper_idx = torch.tensor(sigmoid_lower_upper_idx)
         self.clamp_lower_idx = torch.tensor(softplus_lower_idx)
@@ -179,20 +165,20 @@ def prepare_clamping_parames(
 
         # Define clamping functions
         self.clamp_lower_upper = lambda x: (
-            sigmoid_lower_lims
-            + (sigmoid_upper_lims - sigmoid_lower_lims)
-            * torch.sigmoid(sigmoid_sharpness * (x - sigmoid_center))
+            self.sigmoid_lower_lims
+            + (self.sigmoid_upper_lims - self.sigmoid_lower_lims)
+            * torch.sigmoid(self.sigmoid_sharpness * (x - self.sigmoid_center))
         )
         self.clamp_lower = lambda x: (
-            softplus_lower_lims
+            self.softplus_lower_lims
             + torch.nn.functional.softplus(
-                x - softplus_center, beta=softplus_sharpness
+                x - self.softplus_center, beta=self.softplus_sharpness
             )
         )
         self.clamp_upper = lambda x: (
-            softplus_upper_lims
+            self.softplus_upper_lims
             - torch.nn.functional.softplus(
-                softplus_center - x, beta=softplus_sharpness
+                self.softplus_center - x, beta=self.softplus_sharpness
             )
         )
@@ -200,13 +186,11 @@ def prepare_clamping_parames(
         def inverse_softplus(x, beta=1, threshold=20):
             # If x*beta is above threshold, returns linear function
             # for numerical stability
-            under_lim = x * beta <= threshold
-            x[under_lim] = (
-                torch.log(
-                    torch.clamp_min(torch.expm1(x[under_lim] * beta), 1e-6)
-                )
-                / beta
+            non_linear_part = (
+                torch.log(torch.clamp_min(torch.expm1(x * beta), 1e-6)) / beta
             )
+            x = torch.where(x * beta <= threshold, non_linear_part, x)
+
             return x
 
         def inverse_sigmoid(x):
@@ -214,20 +198,24 @@ def inverse_sigmoid(x):
             return torch.log(x_clamped / (1 - x_clamped))
 
         self.inverse_clamp_lower_upper = lambda x: (
-            sigmoid_center
+            self.sigmoid_center
             + inverse_sigmoid(
-                (x - sigmoid_lower_lims)
-                / (sigmoid_upper_lims - sigmoid_lower_lims)
+                (x - self.sigmoid_lower_lims)
+                / (self.sigmoid_upper_lims - self.sigmoid_lower_lims)
             )
-            / sigmoid_sharpness
+            / self.sigmoid_sharpness
         )
         self.inverse_clamp_lower = lambda x: (
-            inverse_softplus(x - softplus_lower_lims, beta=softplus_sharpness)
-            + softplus_center
+            inverse_softplus(
+                x - self.softplus_lower_lims, beta=self.softplus_sharpness
+            )
+            + self.softplus_center
         )
         self.inverse_clamp_upper = lambda x: (
-            -inverse_softplus(softplus_upper_lims - x, beta=softplus_sharpness)
-            + softplus_center
+            -inverse_softplus(
+                self.softplus_upper_lims - x, beta=self.softplus_sharpness
+            )
+            + self.softplus_center
         )
 
     def clamp_prediction(self, state_delta, prev_state):
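Reviewer note (not part of the patch): the hunks above mostly move the clamping limits and constants onto `self` so that the new test below can reach them; the clamping math itself is unchanged. As a minimal standalone sketch of that math, with made-up limits and the patch's defaults (`sharpness=1`, `center=0`), the two-sided clamp is a scaled sigmoid and its inverse is a logit:

```python
# Standalone sketch of the two-sided clamp; the limits are made up.
import torch

lower = torch.tensor([0.0])
upper = torch.tensor([100.0])


def clamp_lower_upper(x):
    # Scaled sigmoid: maps all of R strictly into (lower, upper)
    return lower + (upper - lower) * torch.sigmoid(x)


def inverse_clamp_lower_upper(y):
    # Logit of the rescaled value, clamped away from 0 and 1 as in the patch
    y_unit = torch.clamp((y - lower) / (upper - lower), 1e-6, 1 - 1e-6)
    return torch.log(y_unit / (1 - y_unit))


x = torch.linspace(-5.0, 5.0, 11)
y = clamp_lower_upper(x)
assert ((lower < y) & (y < upper)).all()  # never reaches the bounds
assert torch.allclose(inverse_clamp_lower_upper(y), x, atol=1e-4)
```

The round trip matters because `test_clamping.py` below asserts that a zero delta leaves a state unchanged, which requires `inverse(clamp(x)) ≈ x` for states inside the bounds.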
diff --git a/tests/datastore_examples/mdp/danra_100m_winds/config.yaml b/tests/datastore_examples/mdp/danra_100m_winds/config.yaml
index a57266f..d311c12 100644
--- a/tests/datastore_examples/mdp/danra_100m_winds/config.yaml
+++ b/tests/datastore_examples/mdp/danra_100m_winds/config.yaml
@@ -15,3 +15,4 @@ training:
       r2m: 0
     upper:
       r2m: 100.0
+      u100m: 100.0
diff --git a/tests/test_clamping.py b/tests/test_clamping.py
new file mode 100644
index 0000000..457be63
--- /dev/null
+++ b/tests/test_clamping.py
@@ -0,0 +1,283 @@
+# Standard library
+from pathlib import Path
+
+# Third-party
+import torch
+
+# First-party
+from neural_lam import config as nlconfig
+from neural_lam.create_graph import create_graph_from_datastore
+from neural_lam.datastore.mdp import MDPDatastore
+from neural_lam.models.graph_lam import GraphLAM
+from tests.conftest import init_datastore_example
+
+
+def test_clamping():
+    datastore = init_datastore_example(MDPDatastore.SHORT_NAME)
+
+    graph_name = "1level"
+
+    graph_dir_path = Path(datastore.root_path) / "graph" / graph_name
+
+    if not graph_dir_path.exists():
+        create_graph_from_datastore(
+            datastore=datastore,
+            output_root_path=str(graph_dir_path),
+            n_max_levels=1,
+        )
+
+    class ModelArgs:
+        output_std = False
+        loss = "mse"
+        restore_opt = False
+        n_example_pred = 1
+        graph = graph_name
+        hidden_dim = 4
+        hidden_layers = 1
+        processor_layers = 2
+        mesh_aggr = "sum"
+        lr = 1.0e-3
+        val_steps_to_log = [1, 3]
+        metrics_watch = []
+        num_past_forcing_steps = 1
+        num_future_forcing_steps = 1
+
+    model_args = ModelArgs()
+
+    config = nlconfig.NeuralLAMConfig(
+        datastore=nlconfig.DatastoreSelection(
+            kind=datastore.SHORT_NAME, config_path=datastore.root_path
+        ),
+        training=nlconfig.TrainingConfig(
+            output_clamping=nlconfig.OutputClamping(
+                lower={"t2m": 0.0, "r2m": 0.0},
+                upper={"r2m": 100.0, "u100m": 100.0},
+            )
+        ),
+    )
+
+    model = GraphLAM(
+        args=model_args,
+        datastore=datastore,
+        config=config,
+    )
+
+    features = datastore.get_vars_names(category="state")
+    original_state = torch.zeros(1, 1, len(features))
+    zero_delta = original_state.clone()
+
+    # Get a state well within the bounds
+    original_state[:, :, model.clamp_lower_upper_idx] = (
+        model.sigmoid_lower_lims + model.sigmoid_upper_lims
+    ) / 2
+    original_state[:, :, model.clamp_lower_idx] = model.softplus_lower_lims + 10
+    original_state[:, :, model.clamp_upper_idx] = model.softplus_upper_lims - 10
+
+    # Get a delta that tries to push the state out of bounds
+    delta = torch.ones_like(zero_delta)
+    delta[:, :, model.clamp_lower_upper_idx] = (
+        model.sigmoid_upper_lims - model.sigmoid_lower_lims
+    ) / 3
+    delta[:, :, model.clamp_lower_idx] = -5
+    delta[:, :, model.clamp_upper_idx] = 5
+
+    # Check that a delta of 0 gives unchanged state
+    zero_prediction = model.clamp_prediction(zero_delta, original_state)
+    assert (abs(original_state - zero_prediction) < 1e-6).all().item()
+
+    # Make predictions towards bounds for each feature
+    prediction = zero_prediction.clone()
+    n_loops = 100
+    for i in range(n_loops):
+        prediction = model.clamp_prediction(delta, prediction)
+
+    # Check that unclamped states are as expected
+    # delta is 1, so they should be 1 * n_loops
+    assert (
+        (
+            abs(
+                prediction[
+                    :,
+                    :,
+                    list(
+                        set(range(len(features)))
+                        - set(model.clamp_lower_upper_idx.tolist())
+                        - set(model.clamp_lower_idx.tolist())
+                        - set(model.clamp_upper_idx.tolist())
+                    ),
+                ]
+                - n_loops
+            )
+            < 1e-6
+        )
+        .all()
+        .item()
+    )
+
+    # Check that clamped states are within bounds
+    # They should stay strictly inside, but allow equality due to precision
+    assert (
+        (
+            model.sigmoid_lower_lims
+            <= prediction[:, :, model.clamp_lower_upper_idx]
+            <= model.sigmoid_upper_lims
+        )
+        .all()
+        .item()
+    )
+    assert (
+        (model.softplus_lower_lims <= prediction[:, :, model.clamp_lower_idx])
+        .all()
+        .item()
+    )
+    assert (
+        (prediction[:, :, model.clamp_upper_idx] <= model.softplus_upper_lims)
+        .all()
+        .item()
+    )
+
+    # Check that prediction is within bounds in original non-normalized space
+    unscaled_prediction = prediction * model.state_std + model.state_mean
+    features_idx = {f: i for i, f in enumerate(features)}
+    lower_lims = {
+        features_idx[f]: lim
+        for f, lim in config.training.output_clamping.lower.items()
+    }
+    upper_lims = {
+        features_idx[f]: lim
+        for f, lim in config.training.output_clamping.upper.items()
+    }
+    assert (
+        (
+            torch.tensor(list(lower_lims.values()))
+            <= unscaled_prediction[:, :, list(lower_lims.keys())]
+        )
+        .all()
+        .item()
+    )
+    assert (
+        (
+            unscaled_prediction[:, :, list(upper_lims.keys())]
+            <= torch.tensor(list(upper_lims.values()))
+        )
+        .all()
+        .item()
+    )
+
+    # Check that a prediction from a state starting outside the bounds is
+    # also pushed within bounds. 3 deltas are enough to push the initial
+    # state out of bounds, so 5 is well outside
+    invalid_state = original_state + 5 * delta
+    assert (
+        not (
+            model.sigmoid_lower_lims
+            <= invalid_state[:, :, model.clamp_lower_upper_idx]
+            <= model.sigmoid_upper_lims
+        )
+        .any()
+        .item()
+    )
+    assert (
+        not (
+            model.softplus_lower_lims
+            <= invalid_state[:, :, model.clamp_lower_idx]
+        )
+        .any()
+        .item()
+    )
+    assert (
+        not (
+            invalid_state[:, :, model.clamp_upper_idx]
+            <= model.softplus_upper_lims
+        )
+        .any()
+        .item()
+    )
+    invalid_prediction = model.clamp_prediction(zero_delta, invalid_state)
+    assert (
+        (
+            model.sigmoid_lower_lims
+            <= invalid_prediction[:, :, model.clamp_lower_upper_idx]
+            <= model.sigmoid_upper_lims
+        )
+        .all()
+        .item()
+    )
+    assert (
+        (
+            model.softplus_lower_lims
+            <= invalid_prediction[:, :, model.clamp_lower_idx]
+        )
+        .all()
+        .item()
+    )
+    assert (
+        (
+            invalid_prediction[:, :, model.clamp_upper_idx]
+            <= model.softplus_upper_lims
+        )
+        .all()
+        .item()
+    )
+
+    # Above tests only check the upper sigmoid limit.
+    # Repeat to check lower sigmoid limit
+
+    # Make predictions towards bounds for each feature
+    prediction = zero_prediction.clone()
+    n_loops = 100
+    for i in range(n_loops):
+        prediction = model.clamp_prediction(-delta, prediction)
+
+    # Check that clamped states are within bounds
+    assert (
+        (
+            model.sigmoid_lower_lims
+            <= prediction[:, :, model.clamp_lower_upper_idx]
+            <= model.sigmoid_upper_lims
+        )
+        .all()
+        .item()
+    )
+
+    # Check that prediction is within bounds in original non-normalized space
+    unscaled_prediction = prediction * model.state_std + model.state_mean
+    assert (
+        (
+            torch.tensor(list(lower_lims.values()))
+            <= unscaled_prediction[:, :, list(lower_lims.keys())]
+        )
+        .all()
+        .item()
+    )
+    assert (
+        (
+            unscaled_prediction[:, :, list(upper_lims.keys())]
+            <= torch.tensor(list(upper_lims.values()))
+        )
+        .all()
+        .item()
+    )
+
+    # Check that a prediction from a state starting outside the bounds is
+    # also pushed within bounds. 3 deltas are enough to push the initial
+    # state out of bounds, so 5 is well outside
+    invalid_state = original_state - 5 * delta
+    assert (
+        not (
+            model.sigmoid_lower_lims
+            <= invalid_state[:, :, model.clamp_lower_upper_idx]
+            <= model.sigmoid_upper_lims
+        )
+        .any()
+        .item()
+    )
+    invalid_prediction = model.clamp_prediction(zero_delta, invalid_state)
+    assert (
+        (
+            model.sigmoid_lower_lims
+            <= invalid_prediction[:, :, model.clamp_lower_upper_idx]
+            <= model.sigmoid_upper_lims
+        )
+        .all()
+        .item()
+    )
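Reviewer note (not part of the patch): the `inverse_softplus` rewrite matters beyond style. The old masked assignment (`x[under_lim] = ...`) wrote results back into the caller's tensor in place, whereas `torch.where` builds a new tensor and leaves the input untouched. A standalone sketch of the new logic, with made-up inputs:

```python
# Standalone sketch of the rewritten inverse_softplus; inputs are made up.
import torch


def inverse_softplus(x, beta=1, threshold=20):
    # Invert softplus analytically; above the threshold softplus is linear
    # to within float precision, so pass x through unchanged there
    non_linear_part = (
        torch.log(torch.clamp_min(torch.expm1(x * beta), 1e-6)) / beta
    )
    return torch.where(x * beta <= threshold, non_linear_part, x)


x = torch.tensor([0.5, 2.0, 25.0])
y = torch.nn.functional.softplus(x)
assert torch.allclose(inverse_softplus(y), x, atol=1e-5)  # round trip
assert torch.allclose(y, torch.nn.functional.softplus(x))  # y not mutated
```

With the old in-place version, the second assertion would fail: calling `inverse_softplus(y)` overwrote `y`, so `inverse_clamp_lower` and `inverse_clamp_upper` could silently corrupt the tensor passed to them.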