diff --git a/docs/examples/grid_cond_gfn.ipynb b/docs/examples/grid_cond_gfn.ipynb index 958771f6..6dd155d4 100644 --- a/docs/examples/grid_cond_gfn.ipynb +++ b/docs/examples/grid_cond_gfn.ipynb @@ -53,55 +53,55 @@ "import torch\n", "import matplotlib.pyplot as pp\n", "from matplotlib import cm\n", - "from matplotlib.colors import ListedColormap, LinearSegmentedColormap\n", + "from matplotlib.colors import ListedColormap\n", "import matplotlib as mpl\n", "import pickle\n", "import gzip\n", "import numpy as np\n", "\n", - " # Some plotting routines\n", - "def binavg(x, n=100, var=False, \n", - " bounds=False, logx=False):\n", + "\n", + "# Some plotting routines\n", + "def binavg(x, n=100, var=False, bounds=False, logx=False):\n", " if len(x) < n:\n", - " return np.linspace(0, len(x), len(x)), x, 0*x+1, 0 * x, np.stack([x, x]).T\n", + " return np.linspace(0, len(x), len(x)), x, 0 * x + 1, 0 * x, np.stack([x, x]).T\n", " if logx:\n", - " bins = np.logspace(0, np.log(len(x))/np.log(10), n)\n", + " bins = np.logspace(0, np.log(len(x)) / np.log(10), n)\n", " idx = np.digitize(np.arange(len(x)), bins)\n", " else:\n", " bins = np.linspace(0, len(x), n)\n", - " idx = np.int32(np.linspace(0, n, len(x)+1))[:-1]\n", + " idx = np.int32(np.linspace(0, n, len(x) + 1))[:-1]\n", " counts = np.bincount(idx, minlength=n)\n", " _y = y = np.bincount(idx, x, minlength=n) / counts\n", - " bins = bins[counts>0]\n", - " y = y[counts>0]\n", + " bins = bins[counts > 0]\n", + " y = y[counts > 0]\n", " r = (bins, y, counts)\n", " if var:\n", " var = np.sqrt(np.bincount(idx, x**2, minlength=n) / np.bincount(idx, minlength=n) - _y**2)\n", - " r += (var[counts>0], )\n", + " r += (var[counts > 0],)\n", " if bounds:\n", - " r += (np.float32([(x[idx==i].min(),x[idx==i].max()) for i in range(n) if counts[i] > 0]), ) \n", + " r += (np.float32([(x[idx == i].min(), x[idx == i].max()) for i in range(n) if counts[i] > 0]),)\n", " return r\n", - " \n", + "\n", + "\n", "def smooth_plot(x, n=100, fill_var=False, fill_minmax=False, logx=False, **kw):\n", " bins, y, counts, var, bounds = binavg(x, n, var=True, bounds=True, logx=logx)\n", - " if 'bins' in kw:\n", - " bins = kw.pop('bins')[counts>0]\n", - " l, = pp.plot(bins, y, **kw)\n", + " if \"bins\" in kw:\n", + " bins = kw.pop(\"bins\")[counts > 0]\n", + " (l,) = pp.plot(bins, y, **kw)\n", " if fill_var:\n", - " pp.fill_between(bins, y-var, y+var, alpha=0.25, color=l.get_c())\n", + " pp.fill_between(bins, y - var, y + var, alpha=0.25, color=l.get_c())\n", " elif fill_minmax:\n", " pp.fill_between(bins, bounds[:, 0], bounds[:, 1], alpha=0.25, color=l.get_c())\n", " return l\n", "\n", "\n", - "top = cm.get_cmap('Blues_r', 128)\n", - "bottom = cm.get_cmap('Oranges', 128)\n", + "top = cm.get_cmap(\"Blues_r\", 128)\n", + "bottom = cm.get_cmap(\"Oranges\", 128)\n", "\n", - "newcolors = np.vstack((top(np.linspace(0, 1, 128)),\n", - " bottom(np.linspace(0, 1, 128))))\n", - "newcmp = ListedColormap(newcolors, name='OrangeBlue')\n", + "newcolors = np.vstack((top(np.linspace(0, 1, 128)), bottom(np.linspace(0, 1, 128))))\n", + "newcmp = ListedColormap(newcolors, name=\"OrangeBlue\")\n", "\n", - "mpl.rc('image', cmap=newcmp)" + "mpl.rc(\"image\", cmap=newcmp)" ] }, { @@ -142,12 +142,12 @@ "H = 64\n", "env = gfn.GridEnv(H, funcs=[gfn.branin, gfn.currin])\n", "s, r, pos = env.state_info()\n", - "f, ax = pp.subplots(1,2, figsize=(7,3))\n", + "f, ax = pp.subplots(1, 2, figsize=(7, 3))\n", "for i in range(2):\n", " pp.sca(ax[i])\n", - " pp.imshow(np.concatenate([r[:, i], [0]]).reshape((H,H)))\n", - " pp.xlabel('$x_1$')\n", - " pp.ylabel('$x_2$')\n", + " pp.imshow(np.concatenate([r[:, i], [0]]).reshape((H, H)))\n", + " pp.xlabel(\"$x_1$\")\n", + " pp.ylabel(\"$x_2$\")\n", "pp.colorbar()\n", "pp.tight_layout()" ] @@ -188,14 +188,14 @@ "H = 64\n", "env = gfn.GridEnv(64, funcs=[gfn.branin, gfn.currin])\n", "s, r, pos = env.state_info()\n", - "f, ax = pp.subplots(1,11, figsize=(11,2))\n", + "f, ax = pp.subplots(1, 11, figsize=(11, 2))\n", "for i in range(11):\n", " pp.sca(ax[i])\n", " w = i / 10\n", - " y = r[:, 0] * w + r[:, 1] * (1-w)\n", - " pp.imshow(np.concatenate([y, [0]]).reshape((H,H)))\n", - " pp.axis('off')\n", - " pp.title(f'$\\omega_1={w}$')\n", + " y = r[:, 0] * w + r[:, 1] * (1 - w)\n", + " pp.imshow(np.concatenate([y, [0]]).reshape((H, H)))\n", + " pp.axis(\"off\")\n", + " pp.title(f\"$\\omega_1={w}$\")\n", "pp.tight_layout()" ] }, @@ -235,14 +235,14 @@ "H = 64\n", "env = gfn.GridEnv(64, funcs=[gfn.branin, gfn.currin])\n", "s, r, pos = env.state_info()\n", - "f, ax = pp.subplots(1,11, figsize=(11,2))\n", + "f, ax = pp.subplots(1, 11, figsize=(11, 2))\n", "for i in range(11):\n", " pp.sca(ax[i])\n", " w = i / 10\n", - " y = (r[:, 0] * w + r[:, 1] * (1-w)) ** 8\n", - " pp.imshow(np.concatenate([y, [0]]).reshape((H,H)))\n", - " pp.axis('off')\n", - " pp.title(f'$\\omega_1={w}$')\n", + " y = (r[:, 0] * w + r[:, 1] * (1 - w)) ** 8\n", + " pp.imshow(np.concatenate([y, [0]]).reshape((H, H)))\n", + " pp.axis(\"off\")\n", + " pp.title(f\"$\\omega_1={w}$\")\n", "pp.tight_layout()" ] }, @@ -284,7 +284,7 @@ "H = 32\n", "env = gfn.GridEnv(H, funcs=[gfn.branin, gfn.currin])\n", "s, r, pos = env.state_info()\n", - "f, ax = pp.subplots(1,2,figsize=(9,4))\n", + "f, ax = pp.subplots(1, 2, figsize=(9, 4))\n", "for axi in range(2):\n", " pp.sca(ax[axi])\n", " pp.scatter(r[:, 0], r[:, 1], s=2)\n", @@ -295,12 +295,12 @@ " if d == 0:\n", " pareto.append(re[:, i])\n", " pareto = np.float32(pareto)\n", - " pp.scatter(pareto[:, 0], pareto[:, 1], marker='s', s=5, color='red')\n", - " pp.xlim(*[(0.2,1.02), (0.7,1.02)][axi])\n", - " pp.ylim(*[(0.1,1.02), (0.4,1.02)][axi])\n", - " pp.title(['the Pareto front', 'zooming in'][axi])\n", - " pp.xlabel('$R_{\\mathrm{Branin}}$')\n", - " pp.ylabel('$R_{\\mathrm{Currin}}$')" + " pp.scatter(pareto[:, 0], pareto[:, 1], marker=\"s\", s=5, color=\"red\")\n", + " pp.xlim(*[(0.2, 1.02), (0.7, 1.02)][axi])\n", + " pp.ylim(*[(0.1, 1.02), (0.4, 1.02)][axi])\n", + " pp.title([\"the Pareto front\", \"zooming in\"][axi])\n", + " pp.xlabel(\"$R_{\\mathrm{Branin}}$\")\n", + " pp.ylabel(\"$R_{\\mathrm{Currin}}$\")" ] }, { @@ -337,15 +337,16 @@ ], "source": [ "import scipy.stats as stats\n", - "f, ax = pp.subplots(1,2,figsize=(9,2.5))\n", + "\n", + "f, ax = pp.subplots(1, 2, figsize=(9, 2.5))\n", "pp.sca(ax[0])\n", - "pp.plot(np.linspace(0,1,100), stats.dirichlet.pdf([np.linspace(0,1,100),1-np.linspace(0,1,100)], [1.5,1.5]))\n", - "pp.ylabel('$p(\\\\omega_1)$')\n", - "pp.xlabel('$\\\\omega_1$')\n", + "pp.plot(np.linspace(0, 1, 100), stats.dirichlet.pdf([np.linspace(0, 1, 100), 1 - np.linspace(0, 1, 100)], [1.5, 1.5]))\n", + "pp.ylabel(\"$p(\\\\omega_1)$\")\n", + "pp.xlabel(\"$\\\\omega_1$\")\n", "pp.sca(ax[1])\n", - "pp.plot(np.linspace(0,8,100), stats.gamma.pdf(np.linspace(0,8,100), 2, scale=1))\n", - "pp.ylabel('$p(\\\\beta)$')\n", - "pp.xlabel('$\\\\beta$')\n", + "pp.plot(np.linspace(0, 8, 100), stats.gamma.pdf(np.linspace(0, 8, 100), 2, scale=1))\n", + "pp.ylabel(\"$p(\\\\beta)$\")\n", + "pp.xlabel(\"$\\\\beta$\")\n", "pp.tight_layout()" ] }, @@ -384,18 +385,18 @@ } ], "source": [ - "results = pickle.load(gzip.open('results/example_branincurrin.pkl.gz', 'rb'))\n", - "f, ax = pp.subplots(2,11,figsize=(18,4))\n", - "H = results['args'].horizon\n", + "results = pickle.load(gzip.open(\"results/example_branincurrin.pkl.gz\", \"rb\"))\n", + "f, ax = pp.subplots(2, 11, figsize=(18, 4))\n", + "H = results[\"args\"].horizon\n", "env = gfn.GridEnv(H, funcs=[gfn.branin, gfn.currin])\n", "s, r, pos = env.state_info()\n", "row = 0\n", "for col in range(11):\n", - " coef, temp = results['cond_confs'][col * 5]\n", + " coef, temp = results[\"cond_confs\"][col * 5]\n", " pp.sca(ax[0, col])\n", - " pp.imshow(np.concatenate([results['final_distribution'][:, col * 5], [0]]).reshape((H,H)))\n", + " pp.imshow(np.concatenate([results[\"final_distribution\"][:, col * 5], [0]]).reshape((H, H)))\n", " pp.sca(ax[1, col])\n", - " pp.imshow(np.concatenate([(r[:, 0] * coef[0] + r[:, 1] * coef[1])**temp, [0]]).reshape((H,H)))" + " pp.imshow(np.concatenate([(r[:, 0] * coef[0] + r[:, 1] * coef[1]) ** temp, [0]]).reshape((H, H)))" ] }, { @@ -431,17 +432,17 @@ } ], "source": [ - "f, ax = pp.subplots(2,11,figsize=(18,4))\n", - "H = results['args'].horizon\n", + "f, ax = pp.subplots(2, 11, figsize=(18, 4))\n", + "H = results[\"args\"].horizon\n", "env = gfn.GridEnv(H, funcs=[gfn.branin, gfn.currin])\n", "s, r, pos = env.state_info()\n", "row = 0\n", "for col in range(11):\n", - " coef, temp = results['cond_confs'][col * 5 + 2]\n", + " coef, temp = results[\"cond_confs\"][col * 5 + 2]\n", " pp.sca(ax[0, col])\n", - " pp.imshow(np.concatenate([results['final_distribution'][:, col * 5 + 2], [0]]).reshape((H,H)))\n", + " pp.imshow(np.concatenate([results[\"final_distribution\"][:, col * 5 + 2], [0]]).reshape((H, H)))\n", " pp.sca(ax[1, col])\n", - " pp.imshow(np.concatenate([(r[:, 0] * coef[0] + r[:, 1] * coef[1])**temp, [0]]).reshape((H,H)))" + " pp.imshow(np.concatenate([(r[:, 0] * coef[0] + r[:, 1] * coef[1]) ** temp, [0]]).reshape((H, H)))" ] }, { @@ -477,17 +478,17 @@ } ], "source": [ - "f, ax = pp.subplots(2,11,figsize=(18,4))\n", - "H = results['args'].horizon\n", + "f, ax = pp.subplots(2, 11, figsize=(18, 4))\n", + "H = results[\"args\"].horizon\n", "env = gfn.GridEnv(H, funcs=[gfn.branin, gfn.currin])\n", "s, r, pos = env.state_info()\n", "row = 0\n", "for col in range(11):\n", - " coef, temp = results['cond_confs'][col * 5 + 3]\n", + " coef, temp = results[\"cond_confs\"][col * 5 + 3]\n", " pp.sca(ax[0, col])\n", - " pp.imshow(np.concatenate([results['final_distribution'][:, col * 5 + 3], [0]]).reshape((H,H)))\n", + " pp.imshow(np.concatenate([results[\"final_distribution\"][:, col * 5 + 3], [0]]).reshape((H, H)))\n", " pp.sca(ax[1, col])\n", - " pp.imshow(np.concatenate([(r[:, 0] * coef[0] + r[:, 1] * coef[1])**temp, [0]]).reshape((H,H)))" + " pp.imshow(np.concatenate([(r[:, 0] * coef[0] + r[:, 1] * coef[1]) ** temp, [0]]).reshape((H, H)))" ] }, { @@ -526,28 +527,28 @@ "data = results\n", "errs = []\n", "logerrs = []\n", - "for (coef, t), dist in zip(data['cond_confs'], data['final_distribution'].T):\n", - " unnorm_p = (r[:, 0]*coef[0]+r[:,1]*coef[1])**t\n", + "for (coef, t), dist in zip(data[\"cond_confs\"], data[\"final_distribution\"].T):\n", + " unnorm_p = (r[:, 0] * coef[0] + r[:, 1] * coef[1]) ** t\n", " Z = unnorm_p.sum()\n", " p = unnorm_p / Z\n", " errs.append(abs(dist - p).mean())\n", " logp = np.log(unnorm_p) - np.log(Z)\n", " logerrs.append(abs(np.log(dist) - logp).mean())\n", - "f, ax = pp.subplots(1,2)\n", + "f, ax = pp.subplots(1, 2)\n", "pp.sca(ax[0])\n", - "pp.imshow(np.float32(errs).reshape((11,5)))\n", - "pp.xticks(range(5), [1,2,4,8,16])\n", - "pp.xlabel('$\\\\beta$')\n", - "pp.yticks(range(0,11,2), np.arange(0,11,2)/10)\n", - "pp.ylabel('$\\\\omega_1$')\n", - "pp.colorbar(label='abs prob error')\n", + "pp.imshow(np.float32(errs).reshape((11, 5)))\n", + "pp.xticks(range(5), [1, 2, 4, 8, 16])\n", + "pp.xlabel(\"$\\\\beta$\")\n", + "pp.yticks(range(0, 11, 2), np.arange(0, 11, 2) / 10)\n", + "pp.ylabel(\"$\\\\omega_1$\")\n", + "pp.colorbar(label=\"abs prob error\")\n", "pp.sca(ax[1])\n", - "pp.imshow(np.float32(logerrs).reshape((11,5)))\n", - "pp.xticks(range(5), [1,2,4,8,16])\n", - "pp.xlabel('$\\\\beta$')\n", - "pp.yticks(range(0,11,2), np.arange(0,11,2)/10)\n", - "pp.ylabel('$\\\\omega_1$')\n", - "pp.colorbar(label='abs log prob error')\n", + "pp.imshow(np.float32(logerrs).reshape((11, 5)))\n", + "pp.xticks(range(5), [1, 2, 4, 8, 16])\n", + "pp.xlabel(\"$\\\\beta$\")\n", + "pp.yticks(range(0, 11, 2), np.arange(0, 11, 2) / 10)\n", + "pp.ylabel(\"$\\\\omega_1$\")\n", + "pp.colorbar(label=\"abs log prob error\")\n", "pp.tight_layout()\n", "pp.subplots_adjust(wspace=0.75)" ] @@ -579,7 +580,7 @@ } ], "source": [ - "print(f'{100 - 100 * stats.gamma.cdf(8, 2, scale=1): .2f}%')" + "print(f\"{100 - 100 * stats.gamma.cdf(8, 2, scale=1): .2f}%\")" ] }, { @@ -615,17 +616,17 @@ } ], "source": [ - "f, ax = pp.subplots(2,11,figsize=(18,4))\n", - "H = results['args'].horizon\n", + "f, ax = pp.subplots(2, 11, figsize=(18, 4))\n", + "H = results[\"args\"].horizon\n", "env = gfn.GridEnv(H, funcs=[gfn.branin, gfn.currin])\n", "s, r, pos = env.state_info()\n", "row = 0\n", "for col in range(11):\n", - " coef, temp = results['cond_confs'][col * 5 + 4]\n", + " coef, temp = results[\"cond_confs\"][col * 5 + 4]\n", " pp.sca(ax[0, col])\n", - " pp.imshow(np.concatenate([results['final_distribution'][:, col * 5 + 4], [0]]).reshape((H,H)))\n", + " pp.imshow(np.concatenate([results[\"final_distribution\"][:, col * 5 + 4], [0]]).reshape((H, H)))\n", " pp.sca(ax[1, col])\n", - " pp.imshow(np.concatenate([(r[:, 0] * coef[0] + r[:, 1] * coef[1])**temp, [0]]).reshape((H,H)))" + " pp.imshow(np.concatenate([(r[:, 0] * coef[0] + r[:, 1] * coef[1]) ** temp, [0]]).reshape((H, H)))" ] }, { @@ -650,13 +651,18 @@ }, "outputs": [], "source": [ - "agent = gfn.FlowNet_TBAgent(data['args'], [env])\n", - "for a,b in zip(agent.parameters(), data['params']):\n", + "agent = gfn.FlowNet_TBAgent(data[\"args\"], [env])\n", + "for a, b in zip(agent.parameters(), data[\"params\"]):\n", " a.data = torch.tensor(b)\n", "preds = []\n", - "for coef, t in data['cond_confs']:\n", + "for coef, t in data[\"cond_confs\"]:\n", " env.reset(coef, t)\n", - " preds.append((agent.Z(torch.tensor(env.cond_obs).float()).item(), np.log(((r[:, 0]*coef[0]+r[:,1]*coef[1])**t).sum())))" + " preds.append(\n", + " (\n", + " agent.Z(torch.tensor(env.cond_obs).float()).item(),\n", + " np.log(((r[:, 0] * coef[0] + r[:, 1] * coef[1]) ** t).sum()),\n", + " )\n", + " )" ] }, { @@ -685,44 +691,46 @@ ], "source": [ "preds = np.float32(preds)\n", - "f, ax = pp.subplots(1,4,figsize=(12,4))\n", + "f, ax = pp.subplots(1, 4, figsize=(12, 4))\n", "pp.sca(ax[0])\n", - "err = abs(preds[:, 0] - preds[:,1])\n", - "pp.imshow(err.reshape((11,5)))\n", - "pp.xticks(range(5), [1,2,4,8,16])\n", - "pp.xlabel('$\\\\beta$')\n", - "pp.yticks(range(0,11,2), np.arange(0,11,2)/10)\n", - "pp.ylabel('$\\\\omega_1$')\n", + "err = abs(preds[:, 0] - preds[:, 1])\n", + "pp.imshow(err.reshape((11, 5)))\n", + "pp.xticks(range(5), [1, 2, 4, 8, 16])\n", + "pp.xlabel(\"$\\\\beta$\")\n", + "pp.yticks(range(0, 11, 2), np.arange(0, 11, 2) / 10)\n", + "pp.ylabel(\"$\\\\omega_1$\")\n", "pp.colorbar()\n", - "pp.title('abs log Z err')\n", + "pp.title(\"abs log Z err\")\n", "\n", "pp.sca(ax[1])\n", - "err = np.log(abs(preds[:, 0] - preds[:,1])) / np.log(10)\n", - "pp.imshow(err.reshape((11,5)))\n", - "pp.xticks(range(5), [1,2,4,8,16])\n", - "pp.xlabel('$\\\\beta$')\n", - "pp.yticks(range(0,11,2), np.arange(0,11,2)/10)\n", - "#pp.ylabel('$\\\\omega_1$')\n", - "pp.title('abs log Z err (log-colored)')\n", - "cb = pp.colorbar(ticks=np.linspace(-2.5,0.5,7))\n", - "cb.ax.set_yticklabels([f'${m}\\\\times 10^{{{int(e)}}}$' for i in np.linspace(-2.5,0.5,7) for m,e in [f'{10**i:.0E}'.split('E')]])\n", + "err = np.log(abs(preds[:, 0] - preds[:, 1])) / np.log(10)\n", + "pp.imshow(err.reshape((11, 5)))\n", + "pp.xticks(range(5), [1, 2, 4, 8, 16])\n", + "pp.xlabel(\"$\\\\beta$\")\n", + "pp.yticks(range(0, 11, 2), np.arange(0, 11, 2) / 10)\n", + "# pp.ylabel('$\\\\omega_1$')\n", + "pp.title(\"abs log Z err (log-colored)\")\n", + "cb = pp.colorbar(ticks=np.linspace(-2.5, 0.5, 7))\n", + "cb.ax.set_yticklabels(\n", + " [f\"${m}\\\\times 10^{{{int(e)}}}$\" for i in np.linspace(-2.5, 0.5, 7) for m, e in [f\"{10**i:.0E}\".split(\"E\")]]\n", + ")\n", "\n", "pp.sca(ax[2])\n", "err = np.log(preds[:, 0])\n", - "pp.imshow(err.reshape((11,5)))\n", - "pp.xticks(range(5), [1,2,4,8,16])\n", - "pp.xlabel('$\\\\beta$')\n", - "pp.yticks(range(0,11,2), np.arange(0,11,2)/10)\n", - "pp.title('predicted log Z')\n", + "pp.imshow(err.reshape((11, 5)))\n", + "pp.xticks(range(5), [1, 2, 4, 8, 16])\n", + "pp.xlabel(\"$\\\\beta$\")\n", + "pp.yticks(range(0, 11, 2), np.arange(0, 11, 2) / 10)\n", + "pp.title(\"predicted log Z\")\n", "pp.colorbar()\n", "\n", "pp.sca(ax[3])\n", "err = np.log(preds[:, 1])\n", - "pp.imshow(err.reshape((11,5)))\n", - "pp.xticks(range(5), [1,2,4,8,16])\n", - "pp.xlabel('$\\\\beta$')\n", - "pp.yticks(range(0,11,2), np.arange(0,11,2)/10)\n", - "pp.title('True log Z')\n", + "pp.imshow(err.reshape((11, 5)))\n", + "pp.xticks(range(5), [1, 2, 4, 8, 16])\n", + "pp.xlabel(\"$\\\\beta$\")\n", + "pp.yticks(range(0, 11, 2), np.arange(0, 11, 2) / 10)\n", + "pp.title(\"True log Z\")\n", "pp.colorbar()\n", "pp.tight_layout()" ] @@ -759,8 +767,8 @@ } ], "source": [ - "l = smooth_plot(np.float32(data['losses'])[:,0], 100, fill_minmax=True)\n", - "pp.yscale('log')" + "l = smooth_plot(np.float32(data[\"losses\"])[:, 0], 100, fill_minmax=True)\n", + "pp.yscale(\"log\")" ] }, { @@ -806,9 +814,10 @@ ], "source": [ "import matplotlib.colors as colors\n", - "x, y = np.arange(len(data['losses'])), np.float32(data['losses'])[:,0]\n", - "pp.hist2d(x,y, [np.linspace(0, len(x), 100), np.logspace(-4.3, 2, 100)], norm=colors.LogNorm())\n", - "pp.yscale('log')\n", + "\n", + "x, y = np.arange(len(data[\"losses\"])), np.float32(data[\"losses\"])[:, 0]\n", + "pp.hist2d(x, y, [np.linspace(0, len(x), 100), np.logspace(-4.3, 2, 100)], norm=colors.LogNorm())\n", + "pp.yscale(\"log\")\n", "pp.colorbar()" ] }, @@ -842,12 +851,13 @@ ], "source": [ "import grid_cond_gfn as gfn\n", + "\n", "hps = gfn.parser.parse_args([])\n", - "hps.save_path = None \n", + "hps.save_path = None\n", "hps.progress = True\n", "hps.horizon = 16\n", - "hps.n_train_steps = 1000 # The more steps the better\n", - "hps.opt = 'adam'\n", + "hps.n_train_steps = 1000 # The more steps the better\n", + "hps.opt = \"adam\"\n", "hps.mbsize = 128\n", "hps.learning_rate = 1e-2\n", "hps.n_hid = 64\n", @@ -883,18 +893,18 @@ } ], "source": [ - "f, ax = pp.subplots(2,5,figsize=(5*2,4))\n", - "H = results['args'].horizon\n", + "f, ax = pp.subplots(2, 5, figsize=(5 * 2, 4))\n", + "H = results[\"args\"].horizon\n", "env = gfn.GridEnv(H, funcs=[gfn.branin, gfn.currin])\n", "s, r, pos = env.state_info()\n", "row = 0\n", "for col in range(5):\n", - " coef, temp = results['cond_confs'][5 * 5 + col]\n", + " coef, temp = results[\"cond_confs\"][5 * 5 + col]\n", " pp.sca(ax[0, col])\n", - " pp.imshow(np.concatenate([results['final_distribution'][:, 5 * 5 + col], [0]]).reshape((H,H)))\n", + " pp.imshow(np.concatenate([results[\"final_distribution\"][:, 5 * 5 + col], [0]]).reshape((H, H)))\n", " pp.title(f\"$\\\\beta={temp}$\")\n", " pp.sca(ax[1, col])\n", - " pp.imshow(np.concatenate([(r[:, 0] * coef[0] + r[:, 1] * coef[1])**temp, [0]]).reshape((H,H)))" + " pp.imshow(np.concatenate([(r[:, 0] * coef[0] + r[:, 1] * coef[1]) ** temp, [0]]).reshape((H, H)))" ] }, { @@ -921,18 +931,18 @@ } ], "source": [ - "f, ax = pp.subplots(2,11,figsize=(18,4))\n", - "H = results['args'].horizon\n", + "f, ax = pp.subplots(2, 11, figsize=(18, 4))\n", + "H = results[\"args\"].horizon\n", "env = gfn.GridEnv(H, funcs=[gfn.branin, gfn.currin])\n", "s, r, pos = env.state_info()\n", "row = 0\n", "for col in range(11):\n", - " coef, temp = results['cond_confs'][col * 5 + 2]\n", + " coef, temp = results[\"cond_confs\"][col * 5 + 2]\n", " pp.sca(ax[0, col])\n", - " pp.imshow(np.concatenate([results['final_distribution'][:, col * 5 + 2], [0]]).reshape((H,H)))\n", + " pp.imshow(np.concatenate([results[\"final_distribution\"][:, col * 5 + 2], [0]]).reshape((H, H)))\n", " pp.title(f\"$\\\\omega_1={coef[0]:.1f}$\")\n", " pp.sca(ax[1, col])\n", - " pp.imshow(np.concatenate([(r[:, 0] * coef[0] + r[:, 1] * coef[1])**temp, [0]]).reshape((H,H)))" + " pp.imshow(np.concatenate([(r[:, 0] * coef[0] + r[:, 1] * coef[1]) ** temp, [0]]).reshape((H, H)))" ] }, { @@ -967,8 +977,8 @@ } ], "source": [ - "smooth_plot(np.float32(results['losses'])[:,0], 100, fill_minmax=True)\n", - "pp.yscale('log')" + "smooth_plot(np.float32(results[\"losses\"])[:, 0], 100, fill_minmax=True)\n", + "pp.yscale(\"log\")" ] }, { @@ -982,12 +992,12 @@ }, "outputs": [], "source": [ - "H = results['args'].horizon\n", + "H = results[\"args\"].horizon\n", "env = gfn.GridEnv(H, funcs=[gfn.branin, gfn.currin])\n", - "model = gfn.FlowNet_TBAgent(results['args'], [env])\n", - "for a, b in zip(model.parameters(), results['params']):\n", + "model = gfn.FlowNet_TBAgent(results[\"args\"], [env])\n", + "for a, b in zip(model.parameters(), results[\"params\"]):\n", " a.data = torch.tensor(b)\n", - "s, r, pos = env.state_info()\n" + "s, r, pos = env.state_info()" ] }, { @@ -1022,17 +1032,17 @@ } ], "source": [ - "f, ax = pp.subplots(3,11,figsize=(18,6))\n", + "f, ax = pp.subplots(3, 11, figsize=(18, 6))\n", "for col in range(11):\n", - " #coef, temp = results['cond_confs'][col * 5]\n", + " # coef, temp = results['cond_confs'][col * 5]\n", " coef = col / 10\n", - " env.cond_obs = np.float32([coef, 1-coef, 2])\n", + " env.cond_obs = np.float32([coef, 1 - coef, 2])\n", " with torch.no_grad():\n", " pred = model.model(torch.tensor([env.obs(p) for p in pos]).float())\n", " for j in range(3):\n", " pp.sca(ax[j, col])\n", - " pp.imshow(np.concatenate([pred[:, j].numpy(), [0]]).reshape((H,H)), vmin=-2, vmax=2)\n", - " pp.axis('off')\n", + " pp.imshow(np.concatenate([pred[:, j].numpy(), [0]]).reshape((H, H)), vmin=-2, vmax=2)\n", + " pp.axis(\"off\")\n", "pp.tight_layout()" ] }, @@ -1068,18 +1078,18 @@ } ], "source": [ - "f, ax = pp.subplots(3,11,figsize=(18,6))\n", + "f, ax = pp.subplots(3, 11, figsize=(18, 6))\n", "sm = torch.nn.Softmax(1)\n", "for col in range(11):\n", - " #coef, temp = results['cond_confs'][col * 5]\n", + " # coef, temp = results['cond_confs'][col * 5]\n", " coef = col / 10\n", - " env.cond_obs = np.float32([coef, 1-coef, 2])\n", + " env.cond_obs = np.float32([coef, 1 - coef, 2])\n", " with torch.no_grad():\n", " pred = sm(model.model(torch.tensor([env.obs(p) for p in pos]).float()))\n", " for j in range(3):\n", " pp.sca(ax[j, col])\n", - " pp.imshow(np.concatenate([pred[:, j].numpy(), [0]]).reshape((H,H)), vmin=0, vmax=1)\n", - " pp.axis('off')\n", + " pp.imshow(np.concatenate([pred[:, j].numpy(), [0]]).reshape((H, H)), vmin=0, vmax=1)\n", + " pp.axis(\"off\")\n", "pp.tight_layout()" ] } diff --git a/pyproject.toml b/pyproject.toml index 6cfd7224..2e5fc8d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,20 +79,18 @@ dependencies = [ [project.optional-dependencies] dev = [ - "bandit[toml]", - "black", - "isort", - "mypy", - "pip-compile-cross-platform", - "pre-commit", - "pytest", - "pytest-cov", - "ruff", - "tox", - "typeguard", - "types-pkg_resources", - # Security pin - "gitpython>=3.1.30", + 'bandit[toml]', + 'black', + 'isort', + 'mypy', + 'pip-compile-cross-platform', + 'pre-commit', + 'pytest', + 'pytest-cov', + 'ruff', + 'tox', + 'typeguard', + 'gitpython>=3.1.30', ] [[project.authors]]