Skip to content

Commit

Permalink
Final changes (#4)
Browse files Browse the repository at this point in the history
* catch FileNotFoundError on first run without data

* plot trajectories first

* catch FileNotFoundError for pendulum

---------

Co-authored-by: jrettberg <johannes.rettberg@itm.uni-stuttgart.de>
  • Loading branch information
JohannesRettberg and jrettberg authored Aug 15, 2024
1 parent 8c5c02e commit 0cb6ffa
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 16 deletions.
22 changes: 11 additions & 11 deletions examples/disc_brake/disc_brake_param_dep_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,17 @@ def main(
yscale="log",
)

# %% plot trajectories
use_train_data = False
idx_gen = "rand"
aphin_vis.plot_time_trajectories_all(
disc_brake_data,
disc_brake_data_id,
use_train_data=use_train_data,
idx_gen=idx_gen,
result_dir=result_dir,
)

# %% 3D plots
# save each test parameter set as csv
for i, mu_ in enumerate(disc_brake_data.TEST.Mu):
Expand Down Expand Up @@ -368,17 +379,6 @@ def main(
close_on_end=True,
)

# %% plot trajectories
use_train_data = False
idx_gen = "rand"
aphin_vis.plot_time_trajectories_all(
disc_brake_data,
disc_brake_data_id,
use_train_data=use_train_data,
idx_gen=idx_gen,
result_dir=result_dir,
)


# parameter variation for multiple experiment runs
# requires calc_various_experiments = True
Expand Down
16 changes: 13 additions & 3 deletions examples/mass_spring_damper/mass_spring_damper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
save_weights,
write_to_experiment_overview,
save_evaluation_times,
save_training_times
save_training_times,
)
from aphin.utils.print_matrices import print_matrices

Expand Down Expand Up @@ -75,7 +75,13 @@ def main(config_path_to_file=None):
cache_path = os.path.join(data_dir, msd_cfg["sim_name"])

# %% Load data
msd_data = Dataset.from_data(cache_path)
try:
msd_data = Dataset.from_data(cache_path)
except FileNotFoundError:
raise FileNotFoundError(
f"File could not be found. If this is the first time you run this example, please execute the data generating script `./state_space_ph/mass_spring_damper_data_gen.py` first."
)

# split into train and test data
msd_data.train_test_split(test_size=msd_cfg["test_size"], seed=msd_cfg["seed"])
# scale data
Expand Down Expand Up @@ -162,7 +168,11 @@ def main(config_path_to_file=None):

msd_data.calculate_errors(msd_data_id, domain_split_vals=[1, 1])
use_train_data = False
aphin_vis.plot_errors(msd_data, use_train_data)
aphin_vis.plot_errors(
msd_data,
use_train_data,
save_name=os.path.join(result_dir, "rms_error"),
)

msd_data.calculate_errors(msd_data_id, save_to_txt=True, result_dir=result_dir)

Expand Down
9 changes: 7 additions & 2 deletions examples/pendulum/pendulum.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
save_weights,
write_to_experiment_overview,
save_training_times,
save_evaluation_times
save_evaluation_times,
)
from aphin.utils.print_matrices import print_matrices

Expand Down Expand Up @@ -78,7 +78,12 @@ def main(config_path_to_file=None):
r = pd_cfg[experiment]["r"]
n_f = pd_cfg["n_n"] * pd_cfg["n_dn"]
cache_path = os.path.join(data_dir, "pendulum.npz")
pendulum_data = Dataset.from_data(cache_path)
try:
pendulum_data = Dataset.from_data(cache_path)
except FileNotFoundError:
raise FileNotFoundError(
f"File could not be found. If this is the first time you run this example, please execute the data generating script `./pendulum_data_generation.py` first."
)
pendulum_data.train_test_split(test_size=0.333, seed=pd_cfg["seed"])
pendulum_data.truncate_time(trunc_time_ratio=pd_cfg["trunc_time_ratio"])
pendulum_data.states_to_features()
Expand Down

0 comments on commit 0cb6ffa

Please sign in to comment.