Skip to content

Commit

Permalink
fix bdt
Browse files Browse the repository at this point in the history
  • Loading branch information
jmduarte committed Apr 15, 2024
1 parent 921b832 commit 3472321
Showing 1 changed file with 18 additions and 9 deletions.
27 changes: 18 additions & 9 deletions notebooks/02_Tabular_Data_BDT.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,9 @@
"metadata": {},
"outputs": [],
"source": [
"feature_names = data.columns[1:-2] # we skip the first and last two columns because they are the ID, weight, and label\n",
"feature_names = list(data.columns[1:-2]) # we skip the first and last two columns because they are the ID, weight, and label\n",
"\n",
"print(len(feature_names))\n",
"\n",
"train = xgb.DMatrix(\n",
" data=data_train[feature_names], label=data_train.Label.cat.codes, missing=-999.0, feature_names=feature_names\n",
Expand Down Expand Up @@ -589,9 +591,16 @@
"outputs": [],
"source": [
"plt.figure()\n",
"\n",
"mask_b = np.array(data_test.Label == \"b\")\n",
"mask_s = np.array(data_test.Label == \"s\")\n",
"\n",
"DER_mass_MMC = np.array(data_test.DER_mass_MMC)\n",
"DER_mass_transverse_met_lep = np.array(data_test.DER_mass_transverse_met_lep)\n",
"\n",
"plt.plot(\n",
" data_test.DER_mass_MMC[data_test.Label == \"b\"],\n",
" data_test.DER_mass_transverse_met_lep[data_test.Label == \"b\"],\n",
" DER_mass_MMC[mask_b],\n",
" DER_mass_transverse_met_lep[mask_b],\n",
" \"o\",\n",
" markersize=2,\n",
" color=\"midnightblue\",\n",
Expand All @@ -600,8 +609,8 @@
" label=\"Background\",\n",
")\n",
"plt.plot(\n",
" data_test.DER_mass_MMC[data_test.Label == \"s\"],\n",
" data_test.DER_mass_transverse_met_lep[data_test.Label == \"s\"],\n",
" DER_mass_MMC[mask_s],\n",
" DER_mass_transverse_met_lep[mask_s],\n",
" \"o\",\n",
" markersize=2,\n",
" color=\"firebrick\",\n",
Expand Down Expand Up @@ -650,8 +659,8 @@
"plt.contourf(x_grid, y_grid, z_grid, levels=[0, 0.5, 1], cmap=cm, alpha=0.25)\n",
"# overlaid with test data points\n",
"plt.plot(\n",
" data_test.DER_mass_MMC[data_test.Label == \"b\"],\n",
" data_test.DER_mass_transverse_met_lep[data_test.Label == \"b\"],\n",
" DER_mass_MMC[mask_b],\n",
" DER_mass_transverse_met_lep[mask_b],\n",
" \"o\",\n",
" markersize=2,\n",
" color=\"midnightblue\",\n",
Expand All @@ -660,8 +669,8 @@
" label=\"Background\",\n",
")\n",
"plt.plot(\n",
" data_test.DER_mass_MMC[data_test.Label == \"s\"],\n",
" data_test.DER_mass_transverse_met_lep[data_test.Label == \"s\"],\n",
" DER_mass_MMC[mask_s],\n",
" DER_mass_transverse_met_lep[mask_s],\n",
" \"o\",\n",
" markersize=2,\n",
" color=\"firebrick\",\n",
Expand Down

0 comments on commit 3472321

Please sign in to comment.