Skip to content

Commit

Permalink
Add L-H threshold comparison plot and improve bootstrap comparison la…
Browse files Browse the repository at this point in the history
…bels. Make plot_radial_build only shows values to 3dp
  • Loading branch information
chris-ashe committed Nov 20, 2024
1 parent 181627a commit c9564c5
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 6 deletions.
144 changes: 139 additions & 5 deletions process/io/plot_proc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2955,18 +2955,18 @@ def plot_bootstrap_comparison(axis, mfile_data, scan):

# Plot average, standard deviation, and median as text
axis.text(
0.65, 0.9, f"Average: {avg_bootstrap:.4f}", transform=axis.transAxes, fontsize=9
1.02, 0.2, f"Average: {avg_bootstrap:.4f}", transform=axis.transAxes, fontsize=9
)
axis.text(
0.65,
0.85,
1.02,
0.15,
f"Standard Dev: {std_bootstrap:.4f}",
transform=axis.transAxes,
fontsize=9,
)
axis.text(
0.65,
0.8,
1.02,
0.1,
f"Median: {median_bootstrap:.4f}",
transform=axis.transAxes,
fontsize=9,
Expand All @@ -2977,6 +2977,137 @@ def plot_bootstrap_comparison(axis, mfile_data, scan):
axis.set_xlim([0.5, 1.5])
axis.set_xticks([])
axis.set_xticklabels([])
axis.set_facecolor("#f0f0f0")


def plot_h_threshold_comparison(
axis: plt.Axes, mfile_data: mf.MFile, scan: int
) -> None:
"""
Function to plot a scatter box plot of L-H threshold power comparisons.
Arguments:
axis (plt.Axes): Axis object to plot to.
mfile_data (mf.MFile): MFILE data object.
scan (int): Scan number to use.
"""
iter_nominal = mfile_data.data["pthrmw(1)"].get_scan(scan)
iter_upper = mfile_data.data["pthrmw(2)"].get_scan(scan)
iter_lower = mfile_data.data["pthrmw(3)"].get_scan(scan)
iter_1997_1 = mfile_data.data["pthrmw(4)"].get_scan(scan)
iter_1997_2 = mfile_data.data["pthrmw(5)"].get_scan(scan)
martin_nominal = mfile_data.data["pthrmw(6)"].get_scan(scan)
martin_upper = mfile_data.data["pthrmw(7)"].get_scan(scan)
martin_lower = mfile_data.data["pthrmw(8)"].get_scan(scan)
snipes_nominal = mfile_data.data["pthrmw(9)"].get_scan(scan)
snipes_upper = mfile_data.data["pthrmw(10)"].get_scan(scan)
snipes_lower = mfile_data.data["pthrmw(11)"].get_scan(scan)
snipes_closed_nominal = mfile_data.data["pthrmw(12)"].get_scan(scan)
snipes_closed_upper = mfile_data.data["pthrmw(13)"].get_scan(scan)
snipes_closed_lower = mfile_data.data["pthrmw(14)"].get_scan(scan)
hubbard_nominal = mfile_data.data["pthrmw(15)"].get_scan(scan)
hubbard_lower = mfile_data.data["pthrmw(16)"].get_scan(scan)
hubbard_upper = mfile_data.data["pthrmw(17)"].get_scan(scan)
hubbard_2017 = mfile_data.data["pthrmw(18)"].get_scan(scan)
martin_aspect_nominal = mfile_data.data["pthrmw(19)"].get_scan(scan)
martin_aspect_upper = mfile_data.data["pthrmw(20)"].get_scan(scan)
martin_aspect_lower = mfile_data.data["pthrmw(21)"].get_scan(scan)

# Data for the box plot
data = {
"ITER 1996 Nominal": iter_nominal,
"ITER 1996 Upper": iter_upper,
"ITER 1996 Lower": iter_lower,
"ITER 1997 (1)": iter_1997_1,
"ITER 1997 (2)": iter_1997_2,
"Martin Nominal": martin_nominal,
"Martin Upper": martin_upper,
"Martin Lower": martin_lower,
"Snipes Nominal": snipes_nominal,
"Snipes Upper": snipes_upper,
"Snipes Lower": snipes_lower,
"Snipes Closed Divertor Nominal": snipes_closed_nominal,
"Snipes Closed Divertor Upper": snipes_closed_upper,
"Snipes Closed Divertor Lower": snipes_closed_lower,
"Hubbard Nominal (I-mode)": hubbard_nominal,
"Hubbard Lower (I-mode)": hubbard_lower,
"Hubbard Upper (I-mode)": hubbard_upper,
"Hubbard 2017 (I-mode)": hubbard_2017,
"Martin Aspect Corrected Nominal": martin_aspect_nominal,
"Martin Aspect Corrected Upper": martin_aspect_upper,
"Martin Aspect Corrected Lower": martin_aspect_lower,
}

# Create the violin plot
axis.violinplot(data.values(), showextrema=False)

# Create the box plot
axis.boxplot(
data.values(), showfliers=True, showmeans=True, meanline=True, widths=0.3
)

# Scatter plot for each data point
colors = plt.cm.plasma(np.linspace(0, 1, len(data.values())))
x_values = np.random.normal(loc=1, scale=0.01, size=len(data.values()))
for index, (key, value) in enumerate(data.items()):
if "ITER 1996" in key:
color = "blue"
elif "ITER 1997" in key:
color = "cyan"
elif "Martin" in key and "Aspect" not in key:
color = "green"
elif "Snipes" in key and "Closed" not in key:
color = "red"
elif "Snipes Closed" in key:
color = "orange"
elif "Martin Aspect" in key:
color = "yellow"
elif "Hubbard" in key and "2017" not in key:
color = "purple"
elif "Hubbard 2017" in key:
color = "magenta"
else:
color = colors[index]
axis.scatter(x_values[index], value, color=color, label=key, alpha=1.0)
axis.legend(loc="upper left", bbox_to_anchor=(-1.1, 1), ncol=2)

# Calculate average, standard deviation, and median
data_values = list(data.values())
avg_threshold = np.mean(data_values)
std_threshold = np.std(data_values)
median_threshold = np.median(data_values)

# Plot average, standard deviation, and median as text
axis.text(
-0.45,
0.15,
f"Average: {avg_threshold:.4f}",
transform=axis.transAxes,
fontsize=9,
)
axis.text(
-0.45,
0.1,
f"Standard Dev: {std_threshold:.4f}",
transform=axis.transAxes,
fontsize=9,
)
axis.text(
-0.45,
0.05,
f"Median: {median_threshold:.4f}",
transform=axis.transAxes,
fontsize=9,
)

axis.set_title("L-H Threshold Comparison")
axis.set_ylabel("L-H threshold power [MW]")
axis.set_xlim([0.5, 1.5])
axis.set_xticks([])
axis.set_xticklabels([])

# Add background color
axis.set_facecolor("#f0f0f0")


def main_plot(
Expand Down Expand Up @@ -3083,6 +3214,9 @@ def main_plot(
plot_9 = fig4.add_subplot(221)
plot_bootstrap_comparison(plot_9, m_file_data, scan)

plot_10 = fig4.add_subplot(224)
plot_h_threshold_comparison(plot_10, m_file_data, scan)


def main(args=None):
# TODO The use of globals here isn't ideal, but is required to get main()
Expand Down
3 changes: 2 additions & 1 deletion process/io/plot_radial_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,8 @@ def main(args=None):
radial_build[kk, :],
left=lower,
height=0.8,
label=f"{radial_labels[kk]}" + f"\n {radial_build[kk][0]} m" * args.numbers,
label=f"{radial_labels[kk]}"
+ f"\n {radial_build[kk][0]:.3f} m" * args.numbers,
color=radial_color[kk],
edgecolor="black",
linewidth=0.05,
Expand Down

0 comments on commit c9564c5

Please sign in to comment.