Skip to content

Commit

Permalink
Updated formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
LimbeckKat committed Jan 9, 2025
1 parent 17822f1 commit 21143c7
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 13 deletions.
8 changes: 6 additions & 2 deletions magnipy/diversipy.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,9 @@ def get_magnitude_functions(self):
# │ Diversity Summaries │
# ╰──────────────────────────────────────────────────────────╯

def MagAreas(self, integration="trapz", absolute_area=True, scale=True, plot=False):
def MagAreas(
self, integration="trapz", absolute_area=True, scale=True, plot=False
):
"""
Compute the areas under the magnitude functions for all datasets.
Expand Down Expand Up @@ -366,7 +368,9 @@ def plot_MagDiffs_heatmap(self):
import matplotlib.pyplot as plt
import pandas as pd

df = pd.DataFrame(self._MagDiffs, columns=self._names, index=self._names)
df = pd.DataFrame(
self._MagDiffs, columns=self._names, index=self._names
)
sns.heatmap(df, annot=False, cmap="rocket_r")
plt.show()
return None
44 changes: 33 additions & 11 deletions magnipy/utils/tutorial_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ def get_Xs():
# Sample 100 points each from Gaussians at (0.5, 0.5) and (1.5, 1.5)
mean2 = [[0.5, 0.5], [1.5, 1.5]]
cov2 = np.eye(2) * 0.02
X3 = np.concatenate([sample_points_gaussian(mean, cov2, 100) for mean in mean2])
X3 = np.concatenate(
[sample_points_gaussian(mean, cov2, 100) for mean in mean2]
)

# Sample 200 points from a Gaussian centered at (0, 0.5)
mean1 = [0.5, 0.5]
Expand Down Expand Up @@ -134,7 +136,9 @@ def show_magnitude_function(df, ts):
def plot_df(df, title):
fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
ax.scatter(df["x"], df["y"], df["z"], c=df["z"], cmap="viridis", marker="o", s=5)
ax.scatter(
df["x"], df["y"], df["z"], c=df["z"], cmap="viridis", marker="o", s=5
)
plt.title(title)
plt.show()

Expand All @@ -143,7 +147,9 @@ def plot_dfs(dfs, titles):
n = len(dfs)

# Create a figure with 1 row and n columns of 3D subplots
fig, axes = plt.subplots(1, n, figsize=(18, 6), subplot_kw={"projection": "3d"})
fig, axes = plt.subplots(
1, n, figsize=(18, 6), subplot_kw={"projection": "3d"}
)

for idx in range(0, n):
df = dfs[idx]
Expand Down Expand Up @@ -188,12 +194,18 @@ def plot_matrix_heatmaps(matrices, distance=True, metric="Euclidean Distance"):
vmin = min(matrices[0].min(), matrices[1].min(), matrices[2].min())
vmax = max(matrices[0].max(), matrices[1].max(), matrices[2].max())

rando_heatmap = axs[0].imshow(matrices[0], cmap="viridis", vmin=vmin, vmax=vmax)
rando_heatmap = axs[0].imshow(
matrices[0], cmap="viridis", vmin=vmin, vmax=vmax
)
axs[0].set_title("Random")
axs[0].set_ylabel("Index of Datapoint")
blob_heatmap = axs[1].imshow(matrices[1], cmap="viridis", vmin=vmin, vmax=vmax)
blob_heatmap = axs[1].imshow(
matrices[1], cmap="viridis", vmin=vmin, vmax=vmax
)
axs[1].set_title("Blobs / Clusters")
swiss_heatmap = axs[2].imshow(matrices[2], cmap="viridis", vmin=vmin, vmax=vmax)
swiss_heatmap = axs[2].imshow(
matrices[2], cmap="viridis", vmin=vmin, vmax=vmax
)
axs[2].set_title("Swiss Roll")

fig.colorbar(
Expand All @@ -208,7 +220,9 @@ def plot_matrix_heatmaps(matrices, distance=True, metric="Euclidean Distance"):

def plot_weights(dfs, ts, weights, titles):
# scaling colorbar
vmin = min(weights[0][:, 0].min(), weights[1][:, 0].min(), weights[2][:, 0].min())
vmin = min(
weights[0][:, 0].min(), weights[1][:, 0].min(), weights[2][:, 0].min()
)
vmax = max(
weights[0][:, -1].max(),
weights[1][:, -1].max(),
Expand Down Expand Up @@ -272,7 +286,9 @@ def plot_weights(dfs, ts, weights, titles):
"$t=1/2 * t_{{conv}}$ \n1/2 of the Convergence Scale"
)
else:
axes[idx, t_idx].set_title("$t=t_{{conv}}$ \nConvergence Scale")
axes[idx, t_idx].set_title(
"$t=t_{{conv}}$ \nConvergence Scale"
)

# Adjust layout and show the figure
cbar = fig.colorbar(
Expand Down Expand Up @@ -430,7 +446,9 @@ def plot_simulation_progression(Xs, colors, size):

fig, ax = plt.subplots(1, 3, figsize=(15, 5))

ax[0].scatter(Xs[0][:, 0], Xs[0][:, 1], c=colors[0], cmap="viridis", alpha=0.6)
ax[0].scatter(
Xs[0][:, 0], Xs[0][:, 1], c=colors[0], cmap="viridis", alpha=0.6
)
ax[0].set_xlim(x_int)
ax[0].set_ylim(y_int)
ax[0].set_title("Beginning of Simulation (X0)")
Expand All @@ -444,7 +462,9 @@ def plot_simulation_progression(Xs, colors, size):
ax[1].set_title("Midway Through Simulation")
ax[1].set_xlim(x_int)
ax[1].set_ylim(y_int)
ax[2].scatter(Xs[-1][:, 0], Xs[-1][:, 1], c=colors[-1], cmap="viridis", alpha=0.6)
ax[2].scatter(
Xs[-1][:, 0], Xs[-1][:, 1], c=colors[-1], cmap="viridis", alpha=0.6
)
ax[2].set_title("End of Simulation")
ax[2].set_xlim(x_int)
ax[2].set_ylim(y_int)
Expand Down Expand Up @@ -482,7 +502,9 @@ def create_animation(Xs, colors, div, is_dropping: bool, metric="magdiff"):

# Initial figure
fig, ax = plt.subplots(figsize=(10, 5))
scat = ax.scatter(Xs[0][:, 0], Xs[0][:, 1], c=colors[0], cmap="viridis", alpha=0.6)
scat = ax.scatter(
Xs[0][:, 0], Xs[0][:, 1], c=colors[0], cmap="viridis", alpha=0.6
)

def update(frame):
fig.clear()
Expand Down

0 comments on commit 21143c7

Please sign in to comment.