From 34723219a33c998813e86cc3cdc12d51b556e822 Mon Sep 17 00:00:00 2001 From: Javier Duarte Date: Mon, 15 Apr 2024 11:37:02 -0700 Subject: [PATCH] fix bdt --- notebooks/02_Tabular_Data_BDT.ipynb | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/notebooks/02_Tabular_Data_BDT.ipynb b/notebooks/02_Tabular_Data_BDT.ipynb index 281b16d..079ef99 100644 --- a/notebooks/02_Tabular_Data_BDT.ipynb +++ b/notebooks/02_Tabular_Data_BDT.ipynb @@ -323,7 +323,9 @@ "metadata": {}, "outputs": [], "source": [ - "feature_names = data.columns[1:-2] # we skip the first and last two columns because they are the ID, weight, and label\n", + "feature_names = list(data.columns[1:-2]) # we skip the first and last two columns because they are the ID, weight, and label\n", + "\n", + "print(len(feature_names))\n", "\n", "train = xgb.DMatrix(\n", " data=data_train[feature_names], label=data_train.Label.cat.codes, missing=-999.0, feature_names=feature_names\n", @@ -589,9 +591,16 @@ "outputs": [], "source": [ "plt.figure()\n", + "\n", + "mask_b = np.array(data_test.Label == \"b\")\n", + "mask_s = np.array(data_test.Label == \"s\")\n", + "\n", + "DER_mass_MMC = np.array(data_test.DER_mass_MMC)\n", + "DER_mass_transverse_met_lep = np.array(data_test.DER_mass_transverse_met_lep)\n", + "\n", "plt.plot(\n", - " data_test.DER_mass_MMC[data_test.Label == \"b\"],\n", - " data_test.DER_mass_transverse_met_lep[data_test.Label == \"b\"],\n", + " DER_mass_MMC[mask_b],\n", + " DER_mass_transverse_met_lep[mask_b],\n", " \"o\",\n", " markersize=2,\n", " color=\"midnightblue\",\n", @@ -600,8 +609,8 @@ " label=\"Background\",\n", ")\n", "plt.plot(\n", - " data_test.DER_mass_MMC[data_test.Label == \"s\"],\n", - " data_test.DER_mass_transverse_met_lep[data_test.Label == \"s\"],\n", + " DER_mass_MMC[mask_s],\n", + " DER_mass_transverse_met_lep[mask_s],\n", " \"o\",\n", " markersize=2,\n", " color=\"firebrick\",\n", @@ -650,8 +659,8 @@ "plt.contourf(x_grid, y_grid, z_grid, levels=[0, 0.5, 1], cmap=cm, alpha=0.25)\n", "# overlaid with test data points\n", "plt.plot(\n", - " data_test.DER_mass_MMC[data_test.Label == \"b\"],\n", - " data_test.DER_mass_transverse_met_lep[data_test.Label == \"b\"],\n", + " DER_mass_MMC[mask_b],\n", + " DER_mass_transverse_met_lep[mask_b],\n", " \"o\",\n", " markersize=2,\n", " color=\"midnightblue\",\n", @@ -660,8 +669,8 @@ " label=\"Background\",\n", ")\n", "plt.plot(\n", - " data_test.DER_mass_MMC[data_test.Label == \"s\"],\n", - " data_test.DER_mass_transverse_met_lep[data_test.Label == \"s\"],\n", + " DER_mass_MMC[mask_s],\n", + " DER_mass_transverse_met_lep[mask_s],\n", " \"o\",\n", " markersize=2,\n", " color=\"firebrick\",\n",