Skip to content

Commit

Permalink
minor changes to vis.plotter and wrappers.diagnostics
Browse files Browse the repository at this point in the history
see changes for more details
  • Loading branch information
ray-chew committed May 14, 2024
1 parent 034cda8 commit 5ac0be0
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
16 changes: 11 additions & 5 deletions vis/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self, fig, nhi, nhj, cbar=True, set_label=True):
self.set_label = set_label

def phys_panel(
self, axs, data, title="", extent=None, xlabel="", ylabel="", v_extent=None
self, axs, data, title="", extent=None, xlabel="", ylabel="", v_extent=None,
):
"""
Plots a physical depiction of the input data.
Expand Down Expand Up @@ -268,6 +268,7 @@ def error_bar_plot(
fs=(10.0, 6.0),
ylabel="",
fontsize=8,
show_grid=True
):
"""
Bar plot of errors.
Expand Down Expand Up @@ -298,6 +299,8 @@ def error_bar_plot(
y-axis label, by default ""
fontsize : int, optional
by default 8
show_grid : bool, optional
toggles grid in output, by default True
"""

data = pd.DataFrame(pmf_diff, index=idx_name, columns=["values"])
Expand Down Expand Up @@ -333,7 +336,8 @@ def error_bar_plot(
fontsize=fontsize,
)

plt.grid()
if show_grid:
plt.grid()

plt.xlabel("first grid pair index", fontsize=fontsize + 3)

Expand Down Expand Up @@ -375,6 +379,7 @@ def error_bar_split_plot(
bs,
ts,
ts_ticks,
color,
fs=(3.5, 3.5),
title="",
output_fig=False,
Expand All @@ -396,10 +401,11 @@ def error_bar_split_plot(
ax2.set_ylim(0, bs)
ax1.set_ylim(ts[0], ts[1])
ax1.set_yticks(ts_ticks)
ax1.ticklabel_format(style='plain')

bars1 = ax1.bar(XX.index, XX.values, color=("C0"))
bars2 = ax2.bar(XX.index, XX.values, color=("C0", "C1", "C2", "r"))
ax1.bar_label(bars1, padding=3)
bars1 = ax1.bar(XX.index, XX.values, color=color)
bars2 = ax2.bar(XX.index, XX.values, color=color)
ax1.bar_label(bars1, padding=3, fmt = '%d')
ax2.bar_label(bars2, padding=3)

for tick in ax2.get_xticklabels():
Expand Down
8 changes: 6 additions & 2 deletions wrappers/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,13 @@ def __write(self):

def __gen_percentage_errs(self):
"""Computes the relative and maximum errors in percentage"""
max_idx = np.argmax(np.abs(self.pmf_refs))
if hasattr(self, "max_val"):
max_val = self.max_val
else:
max_idx = np.argmax(np.abs(self.pmf_refs))
max_val = self.pmf_refs[max_idx]
self.max_errs = self.__get_max_diff(
self.pmf_sums, self.pmf_refs, np.array(self.pmf_refs[max_idx])
self.pmf_sums, self.pmf_refs, max_val
)
self.rel_errs = self.__get_rel_diff(self.pmf_sums, self.pmf_refs)

Expand Down

0 comments on commit 5ac0be0

Please sign in to comment.