diff --git a/Extra/plot_fig1.ipynb b/Extra/plot_fig1.ipynb
index 6e62a04..09e44bc 100644
--- a/Extra/plot_fig1.ipynb
+++ b/Extra/plot_fig1.ipynb
@@ -25,13 +25,13 @@
"\n",
"\n",
"# Set the global default size of the axis labels\n",
- "plt.rcParams['axes.labelsize'] = 20\n",
+ "plt.rcParams[\"axes.labelsize\"] = 20\n",
"# Set the global default size of the tick labels\n",
- "plt.rcParams['xtick.labelsize'] = 15\n",
- "plt.rcParams['ytick.labelsize'] = 15\n",
- "plt.rcParams['axes.titlesize'] = 25\n",
- "plt.rcParams['legend.fontsize'] = 15\n",
- "plt.rcParams['xtick.major.size'] = 7 # length in points\n"
+ "plt.rcParams[\"xtick.labelsize\"] = 15\n",
+ "plt.rcParams[\"ytick.labelsize\"] = 15\n",
+ "plt.rcParams[\"axes.titlesize\"] = 25\n",
+ "plt.rcParams[\"legend.fontsize\"] = 15\n",
+ "plt.rcParams[\"xtick.major.size\"] = 7 # length in points"
]
},
{
@@ -40,10 +40,10 @@
"metadata": {},
"outputs": [],
"source": [
- "#pallete = [\"#F8B195\",\"#355C7D\",\"#F67280\",\"#C06C84\",\"#6C5B7B\"]\n",
- "pallete = [\"#355C7D\",\"#F67280\",\"#F8B195\",\"#C06C84\",\"#6C5B7B\"]\n",
- "mpl.rcParams['axes.prop_cycle'] = mpl.cycler(color= pallete)\n",
- "cmap = mcolors.LinearSegmentedColormap.from_list('my_cmap',[pallete[0],pallete[1]])"
+ "# pallete = [\"#F8B195\",\"#355C7D\",\"#F67280\",\"#C06C84\",\"#6C5B7B\"]\n",
+ "pallete = [\"#355C7D\", \"#F67280\", \"#F8B195\", \"#C06C84\", \"#6C5B7B\"]\n",
+ "mpl.rcParams[\"axes.prop_cycle\"] = mpl.cycler(color=pallete)\n",
+ "cmap = mcolors.LinearSegmentedColormap.from_list(\"my_cmap\", [pallete[0], pallete[1]])"
]
},
{
@@ -97,7 +97,7 @@
"\n",
"x = contagion_process(A, gamma, c, x0, tmin=0, tmax=100, random_seed=2)\n",
"\n",
- "infected_color = 'C1' \n",
+ "infected_color = \"C1\"\n",
"susceptible_color = \"white\"\n",
"subgraph_color = \"black\"\n",
"graph_color = (0.1, 0.1, 0.1, 0.1)\n",
@@ -111,7 +111,7 @@
"pos = xgi.pca_transform(xgi.pairwise_spring_layout(H, seed=5, k=0.3))\n",
"node_fc = [infected_color if x[t, i] else susceptible_color for i in H.nodes]\n",
"node_ec = [subgraph_node_lc if n in nbrs else graph_node_lc for n in H.nodes]\n",
- "node_fc[12] = 'C0'\n",
+ "node_fc[12] = \"C0\"\n",
"\n",
"dyad_color = [subgraph_color if e in sg else graph_color for e in H.edges]\n",
"\n",
@@ -245,7 +245,7 @@
"\n",
"# simple contagion\n",
"c1_mean = c1_samples.mean(axis=0)\n",
- "plt.plot(nus, c1, \"-\", color='C0', label=\"Simple contagion\")\n",
+ "plt.plot(nus, c1, \"-\", color=\"C0\", label=\"Simple contagion\")\n",
"# plt.scatter(nus, c1_mean, linewidth=0.5, color=colors[2])\n",
"\n",
"err_c1 = np.zeros((2, n))\n",
@@ -255,11 +255,11 @@
" x, y = interval\n",
" err_c1[0, i] = max(c1_mean[i] - x, 0)\n",
" err_c1[1, i] = max(y - c1_mean[i], 0)\n",
- "plt.errorbar(nus, c1_mean, err_c1, color='C0', fmt=\"o\")\n",
+ "plt.errorbar(nus, c1_mean, err_c1, color=\"C0\", fmt=\"o\")\n",
"\n",
"# threshold contagion, tau=2\n",
"c2_mean = c2_samples.mean(axis=0)\n",
- "plt.plot(nus, c2, \"-\", color='C1', label=\"Complex contagion\")\n",
+ "plt.plot(nus, c2, \"-\", color=\"C1\", label=\"Complex contagion\")\n",
"# plt.scatter(nus, c2_mean, linewidth=0.5, color=colors[1])\n",
"\n",
"err_c2 = np.zeros((2, n))\n",
@@ -269,7 +269,7 @@
" x, y = interval\n",
" err_c2[0, i] = max(c2_mean[i] - x, 0)\n",
" err_c2[1, i] = max(y - c2_mean[i], 0)\n",
- "plt.errorbar(nus, c2_mean, err_c2, color='C1', fmt=\"o\")\n",
+ "plt.errorbar(nus, c2_mean, err_c2, color=\"C1\", fmt=\"o\")\n",
"\n",
"plt.xticks(np.arange(0, n, 5))\n",
"plt.xlabel(r\"$\\nu$\")\n",
@@ -310,7 +310,7 @@
"with open(\"Data/zkc_infer_vs_tmax.json\") as file:\n",
" data = json.load(file)\n",
"\n",
- "#colors = [\"steelblue\", \"darksalmon\", \"mediumseagreen\"]\n",
+ "# colors = [\"steelblue\", \"darksalmon\", \"mediumseagreen\"]\n",
"colors = pallete\n",
"\n",
"tmax = data[\"tmax\"]\n",
@@ -320,22 +320,22 @@
"\n",
"fig = plt.figure(figsize=(4, 3))\n",
"\n",
- "#plt.semilogx(tmax, sps[0].mean(axis=1), color=colors[2], label=\"Simple contagion\")\n",
- "plt.semilogx(tmax, sps[0].mean(axis=1), color='C0', label=\"Simple contagion\")\n",
- "plt.semilogx(tmax, sps[1].mean(axis=1), color='C1', label=\"Complex contagion\")\n",
+ "# plt.semilogx(tmax, sps[0].mean(axis=1), color=colors[2], label=\"Simple contagion\")\n",
+ "plt.semilogx(tmax, sps[0].mean(axis=1), color=\"C0\", label=\"Simple contagion\")\n",
+ "plt.semilogx(tmax, sps[1].mean(axis=1), color=\"C1\", label=\"Complex contagion\")\n",
"plt.fill_between(\n",
" tmax,\n",
" sps[0].mean(axis=1) - sps[0].std(axis=1),\n",
" sps[0].mean(axis=1) + sps[0].std(axis=1),\n",
" alpha=0.3,\n",
- " color='C0',\n",
+ " color=\"C0\",\n",
")\n",
"plt.fill_between(\n",
" tmax,\n",
" sps[1].mean(axis=1) - sps[1].std(axis=1),\n",
" sps[1].mean(axis=1) + sps[1].std(axis=1),\n",
" alpha=0.3,\n",
- " color='C1',\n",
+ " color=\"C1\",\n",
")\n",
"plt.ylabel(\"F-Score\")\n",
"plt.xlabel(r\"$t_{max}$\")\n",
@@ -373,7 +373,7 @@
"sps = np.array(data[\"sps\"], dtype=float)\n",
"fce = np.array(data[\"fce\"], dtype=float)\n",
"\n",
- "#cmap = cmr.gem\n",
+ "# cmap = cmr.gem\n",
"cmap = cmap\n",
"\n",
"sps_summary = sps.mean(axis=2)\n",
@@ -428,7 +428,9 @@
}
],
"source": [
- "fig, ((ax1, ax2),(ax3, ax4)) = plt.subplots(2,2,figsize=(8,6), sharey=False, sharex=False)\n",
+ "fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(\n",
+ " 2, 2, figsize=(8, 6), sharey=False, sharex=False\n",
+ ")\n",
"\n",
"\"\"\"\n",
"Panel 1: Network Viz\n",
@@ -450,7 +452,7 @@
"\n",
"x = contagion_process(A, gamma, c, x0, tmin=0, tmax=100, random_seed=2)\n",
"\n",
- "infected_color = 'C1' \n",
+ "infected_color = \"C1\"\n",
"susceptible_color = \"white\"\n",
"subgraph_color = \"black\"\n",
"graph_color = (0.1, 0.1, 0.1, 0.1)\n",
@@ -464,12 +466,11 @@
"pos = xgi.pca_transform(xgi.pairwise_spring_layout(H, seed=5, k=0.3))\n",
"node_fc = [infected_color if x[t, i] else susceptible_color for i in H.nodes]\n",
"node_ec = [subgraph_node_lc if n in nbrs else graph_node_lc for n in H.nodes]\n",
- "node_fc[12] = 'C0'\n",
+ "node_fc[12] = \"C0\"\n",
"\n",
"dyad_color = [subgraph_color if e in sg else graph_color for e in H.edges]\n",
"\n",
"\n",
- "\n",
"xgi.draw(\n",
" H,\n",
" pos=pos,\n",
@@ -478,7 +479,7 @@
" dyad_color=dyad_color,\n",
" node_ec=node_ec,\n",
" node_lw=0.5,\n",
- " ax = ax1\n",
+ " ax=ax1,\n",
")\n",
"\n",
"# plt.savefig(\"Figures/Fig1/zkc_network.svg\", dpi=1000)\n",
@@ -517,7 +518,7 @@
"\n",
"# simple contagion\n",
"c1_mean = c1_samples.mean(axis=0)\n",
- "ax2.plot(nus, c1, \"-\", color='C0', label=\"Simple contagion\")\n",
+ "ax2.plot(nus, c1, \"-\", color=\"C0\", label=\"Simple contagion\")\n",
"# ax2.scatter(nus, c1_mean, linewidth=0.5, color=colors[2])\n",
"\n",
"err_c1 = np.zeros((2, n))\n",
@@ -527,11 +528,11 @@
" x, y = interval\n",
" err_c1[0, i] = max(c1_mean[i] - x, 0)\n",
" err_c1[1, i] = max(y - c1_mean[i], 0)\n",
- "ax2.errorbar(nus, c1_mean, err_c1, color='C0', fmt=\"o\")\n",
+ "ax2.errorbar(nus, c1_mean, err_c1, color=\"C0\", fmt=\"o\")\n",
"\n",
"# threshold contagion, tau=2\n",
"c2_mean = c2_samples.mean(axis=0)\n",
- "ax2.plot(nus, c2, \"-\", color='C1', label=\"Complex contagion\")\n",
+ "ax2.plot(nus, c2, \"-\", color=\"C1\", label=\"Complex contagion\")\n",
"# ax2.scatter(nus, c2_mean, linewidth=0.5, color=colors[1])\n",
"\n",
"err_c2 = np.zeros((2, n))\n",
@@ -541,7 +542,7 @@
" x, y = interval\n",
" err_c2[0, i] = max(c2_mean[i] - x, 0)\n",
" err_c2[1, i] = max(y - c2_mean[i], 0)\n",
- "ax2.errorbar(nus, c2_mean, err_c2, color='C1', fmt=\"o\")\n",
+ "ax2.errorbar(nus, c2_mean, err_c2, color=\"C1\", fmt=\"o\")\n",
"\n",
"ax2.set_xticks(np.arange(0, n, 5))\n",
"ax2.set_xlabel(r\"$\\nu$\")\n",
@@ -563,7 +564,7 @@
"with open(\"Data/zkc_infer_vs_tmax.json\") as file:\n",
" data = json.load(file)\n",
"\n",
- "#colors = [\"steelblue\", \"darksalmon\", \"mediumseagreen\"]\n",
+ "# colors = [\"steelblue\", \"darksalmon\", \"mediumseagreen\"]\n",
"colors = pallete\n",
"\n",
"tmax = data[\"tmax\"]\n",
@@ -572,22 +573,22 @@
"fce = np.array(data[\"fce\"], dtype=float)\n",
"\n",
"\n",
- "#ax3.semilogx(tmax, sps[0].mean(axis=1), color=colors[2], label=\"Simple contagion\")\n",
- "ax3.semilogx(tmax, sps[0].mean(axis=1), color='C0', label=\"Simple contagion\")\n",
- "ax3.semilogx(tmax, sps[1].mean(axis=1), color='C1', label=\"Complex contagion\")\n",
+ "# ax3.semilogx(tmax, sps[0].mean(axis=1), color=colors[2], label=\"Simple contagion\")\n",
+ "ax3.semilogx(tmax, sps[0].mean(axis=1), color=\"C0\", label=\"Simple contagion\")\n",
+ "ax3.semilogx(tmax, sps[1].mean(axis=1), color=\"C1\", label=\"Complex contagion\")\n",
"ax3.fill_between(\n",
" tmax,\n",
" sps[0].mean(axis=1) - sps[0].std(axis=1),\n",
" sps[0].mean(axis=1) + sps[0].std(axis=1),\n",
" alpha=0.3,\n",
- " color='C0',\n",
+ " color=\"C0\",\n",
")\n",
"ax3.fill_between(\n",
" tmax,\n",
" sps[1].mean(axis=1) - sps[1].std(axis=1),\n",
" sps[1].mean(axis=1) + sps[1].std(axis=1),\n",
" alpha=0.3,\n",
- " color='C1',\n",
+ " color=\"C1\",\n",
")\n",
"ax3.set_ylabel(\"F-Score\")\n",
"ax3.set_xlabel(r\"$t_{max}$\")\n",
@@ -596,7 +597,6 @@
"sns.despine()\n",
"\n",
"\n",
- "\n",
"with open(\"Data/zkc_frac_vs_beta.json\") as file:\n",
" data = json.load(file)\n",
"beta = np.array(data[\"beta\"], dtype=float)\n",
@@ -605,7 +605,7 @@
"sps = np.array(data[\"sps\"], dtype=float)\n",
"fce = np.array(data[\"fce\"], dtype=float)\n",
"\n",
- "#cmap = cmr.gem\n",
+ "# cmap = cmr.gem\n",
"cmap = cmap\n",
"\n",
"sps_summary = sps.mean(axis=2)\n",
@@ -625,11 +625,8 @@
"ax4.set_xticks([0, 0.5, 1], [0, 0.5, 1])\n",
"ax4.set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1])\n",
"\n",
- "cbar = plt.colorbar(c,ax = ax4)\n",
- "cbar.set_label(r\"F-Score\", fontsize=12, rotation=270, labelpad=15)\n",
- "\n",
- "\n",
- "\n"
+ "cbar = plt.colorbar(c, ax=ax4)\n",
+ "cbar.set_label(r\"F-Score\", fontsize=12, rotation=270, labelpad=15)"
]
},
{
diff --git a/Extra/plot_fig2.ipynb b/Extra/plot_fig2.ipynb
index 3abb144..dbe7437 100644
--- a/Extra/plot_fig2.ipynb
+++ b/Extra/plot_fig2.ipynb
@@ -15,6 +15,7 @@
"import xgi\n",
"\n",
"import fig_settings as fs\n",
+ "\n",
"cmap = fs.cmap\n",
"\n",
"\n",
@@ -28,7 +29,7 @@
"outputs": [],
"source": [
"fs.set_fonts({\"font.family\": \"sans-serif\"})\n",
- "#cmap = cmr.gem"
+ "# cmap = cmr.gem"
]
},
{
diff --git a/Figures/Fig1/figure1_4panel.pdf b/Figures/Fig1/figure1_4panel.pdf
deleted file mode 100644
index a6ddcae..0000000
Binary files a/Figures/Fig1/figure1_4panel.pdf and /dev/null differ
diff --git a/Figures/Fig1/figure1_4panel.png b/Figures/Fig1/figure1_4panel.png
deleted file mode 100644
index 2b74ba8..0000000
Binary files a/Figures/Fig1/figure1_4panel.png and /dev/null differ
diff --git a/Figures/Fig1/illustration.pdf b/Figures/Fig1/illustration.pdf
index ce16ae9..64558b7 100644
Binary files a/Figures/Fig1/illustration.pdf and b/Figures/Fig1/illustration.pdf differ
diff --git a/Figures/Fig1/illustration.png b/Figures/Fig1/illustration.png
new file mode 100644
index 0000000..1a5993b
Binary files /dev/null and b/Figures/Fig1/illustration.png differ
diff --git a/Figures/Fig1/illustration.svg b/Figures/Fig1/illustration.svg
deleted file mode 100644
index 71e4032..0000000
--- a/Figures/Fig1/illustration.svg
+++ /dev/null
@@ -1,3547 +0,0 @@
-
-
-
-
diff --git a/Figures/Fig1/zkc_frac_vs_beta.png b/Figures/Fig1/zkc_frac_vs_beta.png
deleted file mode 100644
index c08daf8..0000000
Binary files a/Figures/Fig1/zkc_frac_vs_beta.png and /dev/null differ
diff --git a/Figures/Fig1/zkc_frac_vs_beta.svg b/Figures/Fig1/zkc_frac_vs_beta.svg
deleted file mode 100644
index a4c7e1a..0000000
--- a/Figures/Fig1/zkc_frac_vs_beta.svg
+++ /dev/null
@@ -1,886 +0,0 @@
-
-
-
diff --git a/Figures/Fig1/zkc_infer_contagion_function.png b/Figures/Fig1/zkc_infer_contagion_function.png
deleted file mode 100644
index 8016603..0000000
Binary files a/Figures/Fig1/zkc_infer_contagion_function.png and /dev/null differ
diff --git a/Figures/Fig1/zkc_infer_contagion_function.svg b/Figures/Fig1/zkc_infer_contagion_function.svg
deleted file mode 100644
index 83ff8a7..0000000
--- a/Figures/Fig1/zkc_infer_contagion_function.svg
+++ /dev/null
@@ -1,1369 +0,0 @@
-
-
-
diff --git a/Figures/Fig1/zkc_infer_vs_tmax.png b/Figures/Fig1/zkc_infer_vs_tmax.png
deleted file mode 100644
index 8ada19d..0000000
Binary files a/Figures/Fig1/zkc_infer_vs_tmax.png and /dev/null differ
diff --git a/Figures/Fig1/zkc_infer_vs_tmax.svg b/Figures/Fig1/zkc_infer_vs_tmax.svg
deleted file mode 100644
index f42a9fb..0000000
--- a/Figures/Fig1/zkc_infer_vs_tmax.svg
+++ /dev/null
@@ -1,1754 +0,0 @@
-
-
-
diff --git a/Figures/Fig1/zkc_network.png b/Figures/Fig1/zkc_network.png
deleted file mode 100644
index 36f38ef..0000000
Binary files a/Figures/Fig1/zkc_network.png and /dev/null differ
diff --git a/Figures/Fig1/zkc_network.svg b/Figures/Fig1/zkc_network.svg
deleted file mode 100644
index edeea25..0000000
--- a/Figures/Fig1/zkc_network.svg
+++ /dev/null
@@ -1,546 +0,0 @@
-
-
-
diff --git a/Figures/Fig2/generative_models_fce.pdf b/Figures/Fig2/generative_models_fce.pdf
deleted file mode 100644
index 928fd07..0000000
Binary files a/Figures/Fig2/generative_models_fce.pdf and /dev/null differ
diff --git a/Figures/Fig2/generative_models_fce.png b/Figures/Fig2/generative_models_fce.png
deleted file mode 100644
index bed675f..0000000
Binary files a/Figures/Fig2/generative_models_fce.png and /dev/null differ
diff --git a/Figures/Fig2/generative_models_sps.pdf b/Figures/Fig2/generative_models_sps.pdf
index e157740..34aa2f6 100644
Binary files a/Figures/Fig2/generative_models_sps.pdf and b/Figures/Fig2/generative_models_sps.pdf differ
diff --git a/Figures/Fig2/generative_models_sps.png b/Figures/Fig2/generative_models_sps.png
index 59eaf58..91b830c 100644
Binary files a/Figures/Fig2/generative_models_sps.png and b/Figures/Fig2/generative_models_sps.png differ
diff --git a/fig_settings.py b/fig_settings.py
index d3fbd6a..531c1f4 100644
--- a/fig_settings.py
+++ b/fig_settings.py
@@ -5,32 +5,20 @@
@author: John Meluso
"""
-import os
-import matplotlib.pylab as pylab
-import matplotlib.pyplot as plt
-import matplotlib as mpl
import cmasher as cmr
+import matplotlib as mpl
+import matplotlib.pylab as pylab
-#color styling
-def set_colors(n_colors = 2):
- global cmap
+# color styling
+def set_colors(n_colors=2):
+ global cmap
global pallette
- cmap = 'cmr.redshift'
+ cmap = "cmr.redshift"
qualitative_cmap = cmr.get_sub_cmap(cmap, 0.2, 0.8, N=n_colors)
pallette = qualitative_cmap.colors
- mpl.rcParams['axes.prop_cycle'] = mpl.cycler(color= pallette)
-
-
-def set_fontsize():
- plt.rcParams['axes.labelsize'] = 30
- # Set the global default size of the tick labels
- plt.rcParams['xtick.labelsize'] = 15
- plt.rcParams['ytick.labelsize'] = 15
- plt.rcParams['axes.titlesize'] = 25
- plt.rcParams['legend.fontsize'] = 25
-
+ mpl.rcParams["axes.prop_cycle"] = mpl.cycler(color=pallette)
def set_fonts(extra_params={}):
@@ -38,68 +26,13 @@ def set_fonts(extra_params={}):
"font.family": "Sans-Serif",
"font.sans-serif": ["Tahoma", "DejaVu Sans", "Lucida Grande", "Verdana"],
"mathtext.fontset": "cm",
- "legend.fontsize": 10,
- "axes.labelsize": 12,
- "axes.titlesize": 16,
- "xtick.labelsize": 12,
- "ytick.labelsize": 12,
- "figure.titlesize": 16,
+ "legend.fontsize": 12,
+ "axes.labelsize": 15,
+ "axes.titlesize": 15,
+ "xtick.labelsize": 15,
+ "ytick.labelsize": 15,
+ "figure.titlesize": 15,
}
for key, value in extra_params.items():
params[key] = value
pylab.rcParams.update(params)
-
-
-def fig_size(frac_width, frac_height, n_cols=1, n_rows=1):
- # Set default sizes
- page_width = 8.5
- page_height = 11
- side_margins = 1
- tb_margins = 1
- middle_margin = 0.25
- mid_marg_width = middle_margin * (n_cols - 1)
- mid_marg_height = middle_margin * (n_rows - 1)
-
- # Width logic
- if frac_width == 1:
- width = page_width - side_margins
- else:
- width = (page_width - side_margins - mid_marg_width) * frac_width
-
- # Height logic
- if frac_height == 1:
- height = page_height - tb_margins
- else:
- height = (page_height - tb_margins - mid_marg_height) * frac_height
-
- return (width, height)
-
-def get_formats():
- return ["eps", "jpg", "pdf", "png", "tif"]
-
-
-def set_border(ax, top=False, bottom=False, left=False, right=False):
- ax.spines["top"].set_visible(top)
- ax.spines["right"].set_visible(right)
- ax.spines["bottom"].set_visible(bottom)
- ax.spines["left"].set_visible(left)
-
-
-def save_publication_fig(name, dpi=1200, **kwargs):
- save_fig(name, dpi, fig_type="publication", **kwargs)
-
-
-def save_presentation_fig(name, dpi=1200, **kwargs):
- save_fig(name, dpi, fig_type="presentation", **kwargs)
-
-
-def save_fig(name, dpi=1200, fig_type=None, **kwargs):
- for ff in get_formats():
- if fig_type:
- path = f"../figures/{fig_type}/{ff}"
- else:
- path = f"../figures/{ff}"
- if not os.path.exists(path):
- os.makedirs(path)
- fname = f"{path}/{name}.{ff}"
- plt.savefig(fname, format=ff, dpi=dpi, **kwargs)
diff --git a/plot_fig1.py b/plot_fig1.py
index d319b4a..d1fa547 100644
--- a/plot_fig1.py
+++ b/plot_fig1.py
@@ -1,25 +1,27 @@
-
import json
+
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import xgi
+from matplotlib.gridspec import GridSpec
+
import fig_settings as fs
from lcs import *
-
-
-fs.set_fonts()
+fs.set_fonts({"font.family": "sans-serif"})
fs.set_colors()
-fs.set_fontsize()
cmap = fs.cmap
-fig, ((ax1, ax2),(ax3, ax4)) = plt.subplots(2,2,figsize=(16,12), sharey=False, sharex=False)
+fig = plt.figure(figsize=(8, 6))
+gs = GridSpec(2, 2, hspace=0.4, wspace=0.4)
"""
Panel 1: Network Viz
"""
+ax1 = fig.add_subplot(gs[0])
+
el = zkc(format="edgelist")
H = xgi.Hypergraph(el)
A = zkc()
@@ -37,7 +39,7 @@
x = contagion_process(A, gamma, c, x0, tmin=0, tmax=100, random_seed=2)
-infected_color = 'C0'
+infected_color = "C0"
susceptible_color = "white"
subgraph_color = "black"
graph_color = (0.1, 0.1, 0.1, 0.1)
@@ -51,28 +53,28 @@
pos = xgi.pca_transform(xgi.pairwise_spring_layout(H, seed=5, k=0.3))
node_fc = [infected_color if x[t, i] else susceptible_color for i in H.nodes]
node_ec = [subgraph_node_lc if n in nbrs else graph_node_lc for n in H.nodes]
-node_fc[12] = 'C1'
+node_fc[12] = "C1"
dyad_color = [subgraph_color if e in sg else graph_color for e in H.edges]
-
xgi.draw(
H,
pos=pos,
node_size=7.5,
node_fc=node_fc,
dyad_color=dyad_color,
+ dyad_lw=0.5,
node_ec=node_ec,
node_lw=0.5,
- ax = ax1
+ ax=ax1,
)
"""
Panel 2:
"""
-
+ax2 = fig.add_subplot(gs[1])
with open("Data/zkc_infer_contagion_functions.json") as file:
data = json.load(file)
@@ -97,7 +99,7 @@
# simple contagion
c1_mean = c1_samples.mean(axis=0)
-ax2.plot(nus, c1, "-", color='C0', label="Simple contagion")
+ax2.plot(nus, c1, "-", color="C0", label="Simple contagion")
err_c1 = np.zeros((2, n))
c1_mode = np.zeros(n)
@@ -106,11 +108,11 @@
x, y = interval
err_c1[0, i] = max(c1_mean[i] - x, 0)
err_c1[1, i] = max(y - c1_mean[i], 0)
-ax2.errorbar(nus, c1_mean, err_c1, color='C0', fmt="o")
+ax2.errorbar(nus, c1_mean, err_c1, color="C0", fmt="o")
# threshold contagion, tau=2
c2_mean = c2_samples.mean(axis=0)
-ax2.plot(nus, c2, "-", color='C1', label="Complex contagion")
+ax2.plot(nus, c2, "-", color="C1", label="Complex contagion")
err_c2 = np.zeros((2, n))
c2_mode = np.zeros(n)
@@ -119,7 +121,7 @@
x, y = interval
err_c2[0, i] = max(c2_mean[i] - x, 0)
err_c2[1, i] = max(y - c2_mean[i], 0)
-ax2.errorbar(nus, c2_mean, err_c2, color='C1', fmt="o")
+ax2.errorbar(nus, c2_mean, err_c2, color="C1", fmt="o")
ax2.set_xticks(np.arange(0, n, 5))
ax2.set_xlabel(r"$\nu$")
@@ -133,6 +135,11 @@
sns.despine()
+""""
+Panel 3: recovery vs. tmax
+"""
+ax3 = fig.add_subplot(gs[2])
+
with open("Data/zkc_infer_vs_tmax.json") as file:
data = json.load(file)
@@ -144,21 +151,21 @@
fce = np.array(data["fce"], dtype=float)
-ax3.semilogx(tmax, sps[0].mean(axis=1), color='C0', label="Simple contagion")
-ax3.semilogx(tmax, sps[1].mean(axis=1), color='C1', label="Complex contagion")
+ax3.semilogx(tmax, sps[0].mean(axis=1), color="C0", label="Simple contagion")
+ax3.semilogx(tmax, sps[1].mean(axis=1), color="C1", label="Complex contagion")
ax3.fill_between(
tmax,
sps[0].mean(axis=1) - sps[0].std(axis=1),
sps[0].mean(axis=1) + sps[0].std(axis=1),
alpha=0.3,
- color='C0',
+ color="C0",
)
ax3.fill_between(
tmax,
sps[1].mean(axis=1) - sps[1].std(axis=1),
sps[1].mean(axis=1) + sps[1].std(axis=1),
alpha=0.3,
- color='C1',
+ color="C1",
)
ax3.set_ylabel("F-Score")
ax3.set_xlabel(r"$t_{max}$")
@@ -166,7 +173,10 @@
ax3.legend(loc="upper left")
sns.despine()
-
+"""
+Panel 4: heatmap of recover vs. beta and f
+"""
+ax4 = fig.add_subplot(gs[3])
with open("Data/zkc_frac_vs_beta.json") as file:
data = json.load(file)
@@ -196,9 +206,10 @@
ax4.set_yticks([0, 0.25, 0.5, 0.75, 1], [0, 0.25, 0.5, 0.75, 1])
-cbar_ax = fig.add_axes([0.92, 0.11, 0.02, 0.35]) # x, y, width, height
+cbar_ax = fig.add_axes([0.91, 0.11, 0.015, 0.32]) # x, y, width, height
cbar = plt.colorbar(c, cax=cbar_ax)
-cbar.set_label(r"F-Score", fontsize=12, rotation=270, labelpad=15)
+cbar.set_label(r"F-Score", fontsize=15, rotation=270, labelpad=15)
+cbar_ax.set_yticks([0, 0.5, 1], [0, 0.5, 1], fontsize=15)
-plt.savefig("Figures/Fig1/figure1_4panel.png", dpi=1000)
-plt.savefig("Figures/Fig1/figure1_4panel.pdf", dpi=1000)
+plt.savefig("Figures/Fig1/illustration.png", dpi=1000)
+plt.savefig("Figures/Fig1/illustration.pdf", dpi=1000)
diff --git a/plot_fig2.py b/plot_fig2.py
index a4865bb..4624ae0 100644
--- a/plot_fig2.py
+++ b/plot_fig2.py
@@ -1,14 +1,13 @@
-
import json
import matplotlib.pyplot as plt
import numpy as np
import xgi
+from matplotlib.gridspec import GridSpec
import fig_settings as fs
from lcs import *
-fs.set_fonts()
fs.set_colors()
fs.set_fonts({"font.family": "sans-serif"})
cmap = fs.cmap
@@ -39,6 +38,7 @@
]
convert_to_log = [False, False, True, False, False]
+
def visualize_networks(i, ax):
n = 50
match i:
@@ -65,7 +65,7 @@ def visualize_networks(i, ax):
H = xgi.Hypergraph(e)
- node_size = 3
+ node_size = 5
dyad_lw = 0.5
node_lw = 0.5
@@ -83,8 +83,9 @@ def visualize_networks(i, ax):
xgi.draw(H, ax=ax, pos=pos, node_size=node_size, node_lw=node_lw, dyad_lw=dyad_lw)
+fig = plt.figure(figsize=(16, 10))
+gs = GridSpec(len(cfs) + 1, len(models), wspace=0.2, hspace=0.2)
-fig, axes = plt.subplots(len(cfs) + 1, len(models), figsize=(14, 8))
for i, m in enumerate(models):
with open(f"Data/{m.lower()}.json") as file:
data = json.load(file)
@@ -97,7 +98,8 @@ def visualize_networks(i, ax):
for j, cf in enumerate(cfs):
sps_summary = sps[j].mean(axis=2).T
- im = axes[j + 1, i].imshow(
+ ax = fig.add_subplot(gs[j + 1, i])
+ im = ax.imshow(
to_imshow_orientation(sps_summary),
extent=(min(var), max(var), min(b), max(b)),
vmin=0,
@@ -105,26 +107,33 @@ def visualize_networks(i, ax):
aspect="auto",
cmap=cmap,
)
- axes[j + 1, i].set_xlim([min(var), max(var)])
- axes[j + 1, i].set_ylim([min(b), max(b)])
- axes[j + 1, i].set_xticks(xticks[i], xticklabels[i])
- axes[j + 1, i].set_yticks([0, 0.5, 1], [0, 0.5, 1])
+ ax.set_xlim([min(var), max(var)])
+ ax.set_ylim([min(b), max(b)])
+ ax.set_xticks(xticks[i], xticklabels[i])
+ ax.set_yticks([0, 0.5, 1], [0, 0.5, 1])
if i == 0:
- axes[j + 1, i].set_ylabel(f"{cfs[j]}\n" + r"$\beta$")
+ ax.set_ylabel(f"{cfs[j]}\n" + r"$\beta$")
+ else:
+ ax.set_yticks([], [])
if j + 1 == len(cfs):
- axes[j + 1, i].set_xlabel(labels[i], fontsize=16)
+ ax.set_xlabel(labels[i])
+ else:
+ ax.set_xticks([], [])
-fig.subplots_adjust(bottom=0.15, top=0.95, left=0.1, right=0.8, wspace=0.3, hspace=0.3)
-cbar_ax = fig.add_axes([0.82, 0.15, 0.02, 0.8])
+# fig.subplots_adjust(bottom=0.15, top=0.95, left=0.1, right=0.8, wspace=0.1, hspace=0.3)
+# cbar_ax = fig.add_axes([0.82, 0.15, 0.02, 0.8])
+cbar_ax = fig.add_axes([0.91, 0.11, 0.015, 0.57])
cbar = fig.colorbar(im, cax=cbar_ax)
-cbar.set_label(r"F-Score", fontsize=16, rotation=270, labelpad=25)
+cbar.set_label(r"F-Score", fontsize=15, rotation=270, labelpad=25)
+cbar_ax.set_yticks([0, 0.5, 1], [0, 0.5, 1], fontsize=15)
for i, m in enumerate(models):
- visualize_networks(i, axes[0, i])
- axes[0, i].set_title(titles[i])
+ ax = fig.add_subplot(gs[0, i])
+ visualize_networks(i, ax)
+ ax.set_title(titles[i])
plt.savefig("Figures/Fig2/generative_models_sps.png", dpi=1000)
-plt.savefig("Figures/Fig2/generative_models_sps.pdf", dpi=1000)
\ No newline at end of file
+plt.savefig("Figures/Fig2/generative_models_sps.pdf", dpi=1000)