From 02088bc2baea4ecfe323701ece565aa9961ceeee Mon Sep 17 00:00:00 2001 From: eruijsena Date: Thu, 28 Sep 2023 14:28:37 +0200 Subject: [PATCH 1/8] add plots to show the change in maxContrib state --- devtools/conda-envs/full_env.yaml | 1 + reeds/function_libs/analysis/sampling.py | 68 +++++++++++++++++ .../visualization/sampling_plots.py | 73 ++++++++++++++++++- 3 files changed, 141 insertions(+), 1 deletion(-) diff --git a/devtools/conda-envs/full_env.yaml b/devtools/conda-envs/full_env.yaml index 5ad5c940..bbe807d6 100644 --- a/devtools/conda-envs/full_env.yaml +++ b/devtools/conda-envs/full_env.yaml @@ -19,6 +19,7 @@ dependencies: - pandas - pytables - matplotlib + - plotly - mpmath #Docs diff --git a/reeds/function_libs/analysis/sampling.py b/reeds/function_libs/analysis/sampling.py index 707e0c80..e22fc875 100644 --- a/reeds/function_libs/analysis/sampling.py +++ b/reeds/function_libs/analysis/sampling.py @@ -6,6 +6,7 @@ import pandas as pd import reeds.function_libs.visualization.sampling_plots +from pygromos.files.repdat import Repdat def undersampling_occurence_potential_threshold_densityClustering(ene_trajs: List[pd.DataFrame], @@ -349,7 +350,74 @@ def sampling_analysis(ene_trajs: List[pd.DataFrame], return final_results, out_path +def analyse_state_transitions(repdat: Repdat, normalize: bool = False, bidirectional: bool = False): + """ + Count the number of times a transition occurs between pairs of states, based on the repdat info. + + Parameters + ---------- + repdat: Repdat + Redat object containing the information regarding the replica exchange trials in the + RE-EDS simulation. + normalize: bool, optional + Normalize the transitions by the total number of outgoing transitions per state + bidirectional: bool, optional + Count the transitions symmetrically (state A to B together with state B to A) + Returns + ------- + np.ndarray + number of transitions between all pairs of states + """ + if normalize and bidirectional: + raise Exception("Transitions cannot be normalized w.r.t leaving state and bidirectional") + + num_replicas = len(repdat.system.s) + repdat_eoffs = repdat.system.state_eir + num_states = len(repdat_eoffs) + eoffs = [[repdat_eoffs[state][s] for state in repdat_eoffs] for s in range(len(repdat_eoffs[1]))] + + # Add end-state potentials as individual columns (this is pretty slow) + + Vi = repdat.DATA["state_potentials"].apply(pd.Series) + values_to_subtract = np.array([eoffs[int(s_val_index-1)] for s_val_index in repdat.DATA['ID'].values]) + + corrected_Vi_array = Vi.values - values_to_subtract + corrected_Vi = pd.DataFrame(corrected_Vi_array, columns=[f"Vr{i+1}" for i in range(5)]) + max_contrib = corrected_Vi.idxmin(axis=1) + max_contrib.name = "Vmin" + enhanced_repdat = pd.concat([repdat.DATA, corrected_Vi, max_contrib], axis=1) + + # Initialize transition counts to zero for all pairs of states + transition_counts = np.zeros((num_states, num_states)) + + for replica in range(1, num_replicas+1): + state_trajectory = enhanced_repdat.query(f"coord_ID == {replica}")[["Vmin", "run"]].reset_index(drop=True) + + # Count the transitions between different states + for i in range(len(state_trajectory) - 1): + current_state = int(state_trajectory["Vmin"][i][-1]) # Take the i in Vri + next_state = int(state_trajectory["Vmin"][i + 1][-1]) + current_run = state_trajectory["run"][i] # To check in case the trajectory is spliced + next_run = state_trajectory["run"][i+1] + if next_run == current_run +1 and current_state != next_state: + transition_counts[current_state-1][next_state-1] += 1 + + if normalize: + # Normalize by total number of transitions per state + tot_trans = np.sum(transition_counts, axis=1) + transition_counts = transition_counts / tot_trans[:, np.newaxis] + + elif bidirectional: + # Consider exchanges in both directions together + bidirectional_counts = np.zeros((num_states, num_states)) + for state1 in range(len(transition_counts)): + for state2 in range(len(transition_counts[state1])): + bidirectional_counts[state1][state2] += transition_counts[state1][state2] + bidirectional_counts[state2][state1] += transition_counts[state1][state2] + transition_counts = bidirectional_counts + + return transition_counts def detect_undersampling(ene_trajs: List[pd.DataFrame], state_potential_treshold: List[float], diff --git a/reeds/function_libs/visualization/sampling_plots.py b/reeds/function_libs/visualization/sampling_plots.py index 64a5a054..f98086ab 100644 --- a/reeds/function_libs/visualization/sampling_plots.py +++ b/reeds/function_libs/visualization/sampling_plots.py @@ -2,11 +2,15 @@ import numpy as np from matplotlib import pyplot as plt +from matplotlib.colors import to_rgba + +import plotly.graph_objects as go +from plotly.colors import convert_to_RGB_255 from reeds.function_libs.visualization import plots_style as ps from reeds.function_libs.visualization.utils import nice_s_vals -import reeds.function_libs.visualization.plots_style as ps + def plot_sampling_convergence(ene_trajs, opt_trajs, outfile, title = None, trim_beg = 0.1): """ @@ -377,3 +381,70 @@ def plot_stateOccurence_matrix(data: dict, if (not out_dir is None): fig.savefig(out_dir + '/sampling_maxContrib_matrix.png', bbox_inches='tight') plt.close() + +def plot_state_transitions(state_transitions: np.ndarray, title: str = None, colors: List[str] = ps.active_qualitative_map, out_path: str = None): + """ + Make a Sankey plot showing the flows between states. + + Parameters + ---------- + state_transitions : np.ndarray + num_states * num_states 2D array containing the number of transitions between states + title: str, optional + printed title of the plot + colors: List[str], optional + if you don't like the default colors + out_path: str, optional + path to save the image to. if none, the image is returned as a plotly figure + + Returns + ------- + None or fig + plotly figure if if was not saved + """ + num_states = len(state_transitions) + + def v_distribute(total_transitions): + # Vertically distribute states in plot based on total number of transitions + box_sizes = total_transitions / total_transitions.sum() + box_vplace = [np.sum(box_sizes[:i]) + box_sizes[i]/2 for i in range(len(box_sizes))] + return box_vplace + + y_placements = v_distribute(np.sum(state_transitions, axis=1)) + v_distribute(np.sum(state_transitions, axis=0)) + + # Convert colors to plotly format and make them transparent + rgba_colors = [] + for color in colors: + rgba = to_rgba(color) + rgba_plotly = convert_to_RGB_255(rgba[:-1]) + # Add opacity + rgba_plotly = rgba_plotly + (0.8,) + # Make string + rgba_colors.append("rgba" + str(rgba_plotly)) + + # Indices 0..n-1 are the source and n..2n-1 are the target. + fig = go.Figure(data=[go.Sankey( + node = dict( + thickness = 20, + line = dict(color = "black", width = 2), + label = [f"state {i+1}" for i in range(num_states)]*2, + color = rgba_colors[:num_states]*2, + x = [0.1,0.1,0.1,0.1,0.1,1,1,1,1,1], + y = y_placements + ), + link = dict( + arrowlen = 30, + source = np.array([[i]*num_states for i in range(num_states)]).flatten(), + target = np.array([[i for i in range(num_states, 2*num_states)] for _ in range(num_states)]).flatten(), + value = state_transitions.flatten(), + color = np.array([[c]*num_states for c in rgba_colors[:num_states]]).flatten() + ), + arrangement="fixed", + )]) + + fig.update_layout(title_text=title, font_size=20, title_x=0.5) + + if out_path: + fig.write_image(out_path) + else: + return fig \ No newline at end of file From aa468457c7f7843973ad68066bf1139f295ed0b2 Mon Sep 17 00:00:00 2001 From: eruijsena Date: Fri, 29 Sep 2023 17:05:27 +0200 Subject: [PATCH 2/8] make it work for more than 5 states --- reeds/function_libs/analysis/sampling.py | 16 +++++----- .../visualization/sampling_plots.py | 29 +++++++++++-------- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/reeds/function_libs/analysis/sampling.py b/reeds/function_libs/analysis/sampling.py index e22fc875..98910871 100644 --- a/reeds/function_libs/analysis/sampling.py +++ b/reeds/function_libs/analysis/sampling.py @@ -353,7 +353,6 @@ def sampling_analysis(ene_trajs: List[pd.DataFrame], def analyse_state_transitions(repdat: Repdat, normalize: bool = False, bidirectional: bool = False): """ Count the number of times a transition occurs between pairs of states, based on the repdat info. - Parameters ---------- repdat: Repdat @@ -363,7 +362,6 @@ def analyse_state_transitions(repdat: Repdat, normalize: bool = False, bidirecti Normalize the transitions by the total number of outgoing transitions per state bidirectional: bool, optional Count the transitions symmetrically (state A to B together with state B to A) - Returns ------- np.ndarray @@ -371,7 +369,7 @@ def analyse_state_transitions(repdat: Repdat, normalize: bool = False, bidirecti """ if normalize and bidirectional: raise Exception("Transitions cannot be normalized w.r.t leaving state and bidirectional") - + num_replicas = len(repdat.system.s) repdat_eoffs = repdat.system.state_eir num_states = len(repdat_eoffs) @@ -383,7 +381,7 @@ def analyse_state_transitions(repdat: Repdat, normalize: bool = False, bidirecti values_to_subtract = np.array([eoffs[int(s_val_index-1)] for s_val_index in repdat.DATA['ID'].values]) corrected_Vi_array = Vi.values - values_to_subtract - corrected_Vi = pd.DataFrame(corrected_Vi_array, columns=[f"Vr{i+1}" for i in range(5)]) + corrected_Vi = pd.DataFrame(corrected_Vi_array, columns=[f"Vr{i+1}" for i in range(num_states)]) max_contrib = corrected_Vi.idxmin(axis=1) max_contrib.name = "Vmin" enhanced_repdat = pd.concat([repdat.DATA, corrected_Vi, max_contrib], axis=1) @@ -392,13 +390,15 @@ def analyse_state_transitions(repdat: Repdat, normalize: bool = False, bidirecti transition_counts = np.zeros((num_states, num_states)) for replica in range(1, num_replicas+1): - state_trajectory = enhanced_repdat.query(f"coord_ID == {replica}")[["Vmin", "run"]].reset_index(drop=True) + state_repdat = enhanced_repdat.query(f"coord_ID == {replica}") + + state_trajectory = state_repdat[["Vmin", "run"]].reset_index(drop=True) # Count the transitions between different states for i in range(len(state_trajectory) - 1): - current_state = int(state_trajectory["Vmin"][i][-1]) # Take the i in Vri - next_state = int(state_trajectory["Vmin"][i + 1][-1]) - current_run = state_trajectory["run"][i] # To check in case the trajectory is spliced + current_state = int("".join([char for char in state_trajectory["Vmin"][i] if char.isdigit()])) # Take the i in Vri + next_state = int("".join([char for char in state_trajectory["Vmin"][i + 1] if char.isdigit()])) + current_run = state_trajectory["run"][i] # Check whether you are actually comparing consecutive exchanges next_run = state_trajectory["run"][i+1] if next_run == current_run +1 and current_state != next_state: transition_counts[current_state-1][next_state-1] += 1 diff --git a/reeds/function_libs/visualization/sampling_plots.py b/reeds/function_libs/visualization/sampling_plots.py index f98086ab..b99bb018 100644 --- a/reeds/function_libs/visualization/sampling_plots.py +++ b/reeds/function_libs/visualization/sampling_plots.py @@ -1,8 +1,8 @@ -from typing import List +from typing import Union, List import numpy as np from matplotlib import pyplot as plt -from matplotlib.colors import to_rgba +from matplotlib.colors import Colormap, to_rgba import plotly.graph_objects as go from plotly.colors import convert_to_RGB_255 @@ -382,7 +382,7 @@ def plot_stateOccurence_matrix(data: dict, fig.savefig(out_dir + '/sampling_maxContrib_matrix.png', bbox_inches='tight') plt.close() -def plot_state_transitions(state_transitions: np.ndarray, title: str = None, colors: List[str] = ps.active_qualitative_map, out_path: str = None): +def plot_state_transitions(state_transitions: np.ndarray, title: str = None, colors: Union[List[str], Colormap] = ps.qualitative_tab_map, out_path: str = None): """ Make a Sankey plot showing the flows between states. @@ -392,26 +392,30 @@ def plot_state_transitions(state_transitions: np.ndarray, title: str = None, col num_states * num_states 2D array containing the number of transitions between states title: str, optional printed title of the plot - colors: List[str], optional + colors: Union[List[str], Colormap], optional if you don't like the default colors out_path: str, optional path to save the image to. if none, the image is returned as a plotly figure - Returns ------- None or fig plotly figure if if was not saved """ num_states = len(state_transitions) + + if isinstance(colors, Colormap): + colors = [colors(i) for i in np.linspace(0, 1, num_states)] + elif len(colors) < num_states: + raise Exception("Insufficient colors to plot all states") def v_distribute(total_transitions): - # Vertically distribute states in plot based on total number of transitions + # Vertically distribute nodes in plot based on total number of transitions per state box_sizes = total_transitions / total_transitions.sum() box_vplace = [np.sum(box_sizes[:i]) + box_sizes[i]/2 for i in range(len(box_sizes))] return box_vplace - + y_placements = v_distribute(np.sum(state_transitions, axis=1)) + v_distribute(np.sum(state_transitions, axis=0)) - + # Convert colors to plotly format and make them transparent rgba_colors = [] for color in colors: @@ -421,15 +425,16 @@ def v_distribute(total_transitions): rgba_plotly = rgba_plotly + (0.8,) # Make string rgba_colors.append("rgba" + str(rgba_plotly)) - + # Indices 0..n-1 are the source and n..2n-1 are the target. fig = go.Figure(data=[go.Sankey( node = dict( + pad = 5, thickness = 20, line = dict(color = "black", width = 2), label = [f"state {i+1}" for i in range(num_states)]*2, color = rgba_colors[:num_states]*2, - x = [0.1,0.1,0.1,0.1,0.1,1,1,1,1,1], + x = [0.1]*num_states + [1]*num_states, y = y_placements ), link = dict( @@ -441,10 +446,10 @@ def v_distribute(total_transitions): ), arrangement="fixed", )]) + fig.update_layout(title_text=title, font_size=20, title_x=0.5, height=max(600, num_states*100)) - fig.update_layout(title_text=title, font_size=20, title_x=0.5) - if out_path: fig.write_image(out_path) + return None else: return fig \ No newline at end of file From 3b0f1fa90eafb2ad475ad5db1dda467e0afa6b2c Mon Sep 17 00:00:00 2001 From: eruijsena Date: Fri, 29 Sep 2023 17:06:20 +0200 Subject: [PATCH 3/8] add option to consider only high s values --- reeds/function_libs/analysis/sampling.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/reeds/function_libs/analysis/sampling.py b/reeds/function_libs/analysis/sampling.py index 98910871..059df169 100644 --- a/reeds/function_libs/analysis/sampling.py +++ b/reeds/function_libs/analysis/sampling.py @@ -350,7 +350,7 @@ def sampling_analysis(ene_trajs: List[pd.DataFrame], return final_results, out_path -def analyse_state_transitions(repdat: Repdat, normalize: bool = False, bidirectional: bool = False): +def analyse_state_transitions(repdat: Repdat, min_s: int = None, normalize: bool = False, bidirectional: bool = False): """ Count the number of times a transition occurs between pairs of states, based on the repdat info. Parameters @@ -358,6 +358,8 @@ def analyse_state_transitions(repdat: Repdat, normalize: bool = False, bidirecti repdat: Repdat Redat object containing the information regarding the replica exchange trials in the RE-EDS simulation. + min_s: int, optional + Index of the lowest s_value to consider for the transitions. If None, consider all s values. normalize: bool, optional Normalize the transitions by the total number of outgoing transitions per state bidirectional: bool, optional @@ -390,7 +392,10 @@ def analyse_state_transitions(repdat: Repdat, normalize: bool = False, bidirecti transition_counts = np.zeros((num_states, num_states)) for replica in range(1, num_replicas+1): - state_repdat = enhanced_repdat.query(f"coord_ID == {replica}") + if min_s: + state_repdat = enhanced_repdat.query(f"coord_ID == {replica} & ID <= {min_s}") + else: + state_repdat = enhanced_repdat.query(f"coord_ID == {replica}") state_trajectory = state_repdat[["Vmin", "run"]].reset_index(drop=True) From 3b8b5b88dbdee7fbc47d5243df0b0f4c3bd35392 Mon Sep 17 00:00:00 2001 From: eruijsena Date: Fri, 29 Sep 2023 17:08:14 +0200 Subject: [PATCH 4/8] add plotly to test-env --- devtools/conda-envs/test_env.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/devtools/conda-envs/test_env.yaml b/devtools/conda-envs/test_env.yaml index 92c7e794..9aaef8c0 100644 --- a/devtools/conda-envs/test_env.yaml +++ b/devtools/conda-envs/test_env.yaml @@ -19,6 +19,7 @@ dependencies: - pandas - pytables - matplotlib + - plotlty # Pip-only installs From 609f26ac632e8efdfcd395ca20e31c1b8ece9d52 Mon Sep 17 00:00:00 2001 From: eruijsena Date: Fri, 29 Sep 2023 17:08:42 +0200 Subject: [PATCH 5/8] spelling --- devtools/conda-envs/test_env.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/devtools/conda-envs/test_env.yaml b/devtools/conda-envs/test_env.yaml index 9aaef8c0..88a08745 100644 --- a/devtools/conda-envs/test_env.yaml +++ b/devtools/conda-envs/test_env.yaml @@ -19,7 +19,7 @@ dependencies: - pandas - pytables - matplotlib - - plotlty + - plotly # Pip-only installs From 6de0513e633f6eed016e7188acbe2218d3daac70 Mon Sep 17 00:00:00 2001 From: eruijsena Date: Mon, 2 Oct 2023 12:51:24 +0200 Subject: [PATCH 6/8] analyse transitions with expanded repdat --- reeds/function_libs/analysis/sampling.py | 35 ++++++++---------------- reeds/submodules/pygromos | 2 +- 2 files changed, 12 insertions(+), 25 deletions(-) diff --git a/reeds/function_libs/analysis/sampling.py b/reeds/function_libs/analysis/sampling.py index 059df169..c095a197 100644 --- a/reeds/function_libs/analysis/sampling.py +++ b/reeds/function_libs/analysis/sampling.py @@ -6,7 +6,7 @@ import pandas as pd import reeds.function_libs.visualization.sampling_plots -from pygromos.files.repdat import Repdat +from pygromos.files.repdat import ExpandedRepdat def undersampling_occurence_potential_threshold_densityClustering(ene_trajs: List[pd.DataFrame], @@ -350,14 +350,14 @@ def sampling_analysis(ene_trajs: List[pd.DataFrame], return final_results, out_path -def analyse_state_transitions(repdat: Repdat, min_s: int = None, normalize: bool = False, bidirectional: bool = False): +def analyse_state_transitions(repdat: ExpandedRepdat, min_s: int = None, normalize: bool = False, bidirectional: bool = False): """ Count the number of times a transition occurs between pairs of states, based on the repdat info. Parameters ---------- - repdat: Repdat - Redat object containing the information regarding the replica exchange trials in the - RE-EDS simulation. + repdat: ExpandedRepdat + ExpandedRepdat object (created from a Repdat) which contains all the exchange information of a + RE-EDS simulation plus the potential energies of the end-states min_s: int, optional Index of the lowest s_value to consider for the transitions. If None, consider all s values. normalize: bool, optional @@ -373,37 +373,24 @@ def analyse_state_transitions(repdat: Repdat, min_s: int = None, normalize: bool raise Exception("Transitions cannot be normalized w.r.t leaving state and bidirectional") num_replicas = len(repdat.system.s) - repdat_eoffs = repdat.system.state_eir - num_states = len(repdat_eoffs) - eoffs = [[repdat_eoffs[state][s] for state in repdat_eoffs] for s in range(len(repdat_eoffs[1]))] - - # Add end-state potentials as individual columns (this is pretty slow) - - Vi = repdat.DATA["state_potentials"].apply(pd.Series) - values_to_subtract = np.array([eoffs[int(s_val_index-1)] for s_val_index in repdat.DATA['ID'].values]) - - corrected_Vi_array = Vi.values - values_to_subtract - corrected_Vi = pd.DataFrame(corrected_Vi_array, columns=[f"Vr{i+1}" for i in range(num_states)]) - max_contrib = corrected_Vi.idxmin(axis=1) - max_contrib.name = "Vmin" - enhanced_repdat = pd.concat([repdat.DATA, corrected_Vi, max_contrib], axis=1) + num_states = len(repdat.system.state_eir) # Initialize transition counts to zero for all pairs of states transition_counts = np.zeros((num_states, num_states)) for replica in range(1, num_replicas+1): + # Get exchange data per state if min_s: - state_repdat = enhanced_repdat.query(f"coord_ID == {replica} & ID <= {min_s}") + state_repdat = repdat.DATA.query(f"coord_ID == {replica} & ID <= {min_s}") else: - state_repdat = enhanced_repdat.query(f"coord_ID == {replica}") - - state_trajectory = state_repdat[["Vmin", "run"]].reset_index(drop=True) + state_repdat = repdat.DATA.query(f"coord_ID == {replica}") + state_trajectory = state_repdat[["Vmin", "run"]].reset_index(drop=True).copy() # Count the transitions between different states for i in range(len(state_trajectory) - 1): current_state = int("".join([char for char in state_trajectory["Vmin"][i] if char.isdigit()])) # Take the i in Vri next_state = int("".join([char for char in state_trajectory["Vmin"][i + 1] if char.isdigit()])) - current_run = state_trajectory["run"][i] # Check whether you are actually comparing consecutive exchanges + current_run = state_trajectory["run"][i] # Check that you are actually comparing consecutive exchanges next_run = state_trajectory["run"][i+1] if next_run == current_run +1 and current_state != next_state: transition_counts[current_state-1][next_state-1] += 1 diff --git a/reeds/submodules/pygromos b/reeds/submodules/pygromos index 47837858..ad25a499 160000 --- a/reeds/submodules/pygromos +++ b/reeds/submodules/pygromos @@ -1 +1 @@ -Subproject commit 4783785811265b169f26d16e881e662a9d58316d +Subproject commit ad25a49907d5a09caa8b3b40b709da50d015ab64 From 170fd1b30f139b4711cb48d626678fda51b05315 Mon Sep 17 00:00:00 2001 From: eruijsena Date: Mon, 2 Oct 2023 15:49:11 +0200 Subject: [PATCH 7/8] update link to pygromos --- reeds/submodules/pygromos | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reeds/submodules/pygromos b/reeds/submodules/pygromos index ad25a499..d9c692ac 160000 --- a/reeds/submodules/pygromos +++ b/reeds/submodules/pygromos @@ -1 +1 @@ -Subproject commit ad25a49907d5a09caa8b3b40b709da50d015ab64 +Subproject commit d9c692acbfa77c27c6366acd87ae5f6dbff1fe7f From 1060309592ef4bc7d46bf15661c1c1f9c08b6ccb Mon Sep 17 00:00:00 2001 From: eruijsena Date: Tue, 3 Oct 2023 14:53:37 +0200 Subject: [PATCH 8/8] updated submodule --- reeds/submodules/pygromos | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reeds/submodules/pygromos b/reeds/submodules/pygromos index d9c692ac..37b32f2b 160000 --- a/reeds/submodules/pygromos +++ b/reeds/submodules/pygromos @@ -1 +1 @@ -Subproject commit d9c692acbfa77c27c6366acd87ae5f6dbff1fe7f +Subproject commit 37b32f2bb897cb1b3b0a0224b010fbf03da7c7b6