From 4ac71b6a8ffb526797570a257a51e2cbe9ae5722 Mon Sep 17 00:00:00 2001 From: perib Date: Mon, 31 Jul 2023 17:02:16 -0700 Subject: [PATCH] correctly get n_splits, tutorial update --- Tutorial/1_Estimators_Overview.ipynb | 319 ++++++++++++------ tpot2/tpot_estimator/estimator.py | 2 +- .../tpot_estimator/steady_state_estimator.py | 2 +- 3 files changed, 211 insertions(+), 112 deletions(-) diff --git a/Tutorial/1_Estimators_Overview.ipynb b/Tutorial/1_Estimators_Overview.ipynb index ee6e4624..6e164775 100644 --- a/Tutorial/1_Estimators_Overview.ipynb +++ b/Tutorial/1_Estimators_Overview.ipynb @@ -5,27 +5,28 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "All TPOT estimators can be created with the TPOTEstimator class. \n", - "The TPOTClassifier and TPOTRegressor are set default parameters for the TPOTEstimator for Classification and Regression." + "TPOT1 and TPOTEstimator use a standard evolutionary algorithm that evaluates exactly population_size individuals each generation. The next generation does not start until the previous one has completely finished evaluating. This leads to underutilized CPU time as the cores are waiting for the last individuals to finish training. \n", + "\n", + "TPOTEstimatorSteadyState will generate and evaluate the next individual as soon as an individual finishes evaluation. The number of individuals being evaluated is determined by the n_jobs parameter. There is no longer a concept of generations. The population_size parameter now refers to the size of the list of evaluated parents. When an individual is evaluated, the selection method updates the list of parents. 
Then " ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Generation: 100%|██████████| 5/5 [00:35<00:00, 7.17s/it]\n" + "Evaluations: : 111it [00:30, 3.64it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "0.9947089947089948\n" + "1.0\n" ] } ], @@ -34,36 +35,39 @@ "import sklearn\n", "import sklearn.datasets\n", "\n", - "est = tpot2.TPOTEstimator( population_size=30,\n", - " generations=5,\n", + "est = tpot2.TPOTEstimatorSteadyState( population_size=30,\n", + " initial_population_size=30,\n", " scorers=['roc_auc_ovr'], #scorers can be a list of strings or a list of scorers. These get evaluated during cross validation. \n", " scorers_weights=[1],\n", + "\n", " classification=True,\n", - " n_jobs=1, \n", - " early_stop=5, #how many generations with no improvement to stop after\n", - " \n", - " #List of other objective functions. All objective functions take in an untrained GraphPipeline and return a score or a list of scores\n", + " n_jobs=1,\n", + " #List of other objective functions. All objective functions take in an untrained GraphPipeline and return a score or a list of scores\n", " other_objective_functions= [ ],\n", " \n", " #List of weights for the other objective functions. Must be the same length as other_objective_functions. By default, bigger is better is set to True. 
\n", " other_objective_functions_weights=[],\n", + "\n", + " max_eval_time_seconds=15,\n", + " max_time_seconds=30,\n", " verbose=2)\n", "\n", + "\n", "scorer = sklearn.metrics.get_scorer('roc_auc_ovo')\n", "X, y = sklearn.datasets.load_iris(return_X_y=True)\n", "X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(X, y, train_size=0.75, test_size=0.25)\n", "est.fit(X_train, y_train)\n", - "print(scorer(est, X_test, y_test))" + "print(scorer(est, X_test, y_test))\n" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [ { "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAnYAAAHWCAYAAAD6oMSKAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAV+klEQVR4nO3de6zX9Z3n8de5UOAcRDniHg53bQ/1BoJitaK13a7StVkTesFO/MMm22YzxWnSWOnsdrOdacbZzrRNOwnbxHbMGGPialndJTqNtmbaRlREq1AuBaYgIJdT8RxRzgHhXPaPOmSsYFHBA+95PBIS8v39ft/P+/f74+SZ7+/7+34bhoaGhgIAwCmvcbgHAADg+BB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEH
QBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUETzcA8AcDwNDAyku7s7XV1d6erqyou7d+e1/fszODCQxqamjBw9OmdNmJD29va0t7enra0tTU1Nwz02wHHRMDQ0NDTcQwC8Wz09PVm1alV+/atf5UBvb4b6+zNm//6c3t2dEf39aRwaymBDQw41N2dvW1v2jR6dhubmjGptzcyLL85FF12UcePGDffbAHhXhB1wStu5c2cef+yxbNm0KSP6+jJ12/Z0dHfn9N7ejBgYOOrrDjU1ZW9ra3a1tWXb1Ck51NKSszs7M++qq9LR0fEevgOA40fYAaek/v7+LF++PCuXL8+YPXvyga3bMnnPnjQNDr7tfQ00NuaF8ePzz9OmZt/48bl03rzMmzcvzc3OVgFOLcIOOOXs3r07Dy1blp4XduTcTZvSuWNHGo/Dn7LBhoZsmjQpv+nsTNvkSbnu+uszYcKE4zAxwHtD2AGnlK1bt+aBe+9Ny85duWT9+ozt6zvua7zS0pJnzjsvfRMnZsENCzNt2rTjvgbAiSDsgFPG1q1b83/uuSdnbt2WD61bl+Z38LXrsepvbMyKC85P99Sp+fSf/Im4A04JrmMHnBJ2796dB+69N21bt+XytWtPaNQlSfPgYD68Zm3atm3LA/fel927d5/Q9QCOB2EHnPT6+/vz0LJladm5K5etW3dczqc7Fo1DQ7ls7bqM3rUz/7hsWfr7+9+TdQHeKWEHnPSWL1+enhd25JL160/4kbo/1Dw4mEvWrU/3jh15/PHH39O1Ad4uYQec1Hbu3JmVy5fn3E2bTsgPJY7F6X19+eDGTXnqsceya9euYZkB4FgIO+Ck9vhjj2XMnj3p3LFjWOeYsWNHxuzZk+WPPTascwC8FWEHnLR6enqyZdOmfGDrtvfsvLqjaRwayvu3bsuWjRvT09MzrLMAHI2wA05aq1atyoi+vkzes2e4R0mSTNmzJ819fVm9evVwjwJwRMIOOCkNDAzk17/6VaZu2/6ObhN2IjQNDmba9u1Z/cwzGXiL+9ACDBdhBxxRc3NzZs+effjfXXfdddTnPv/887nvvvve1v4ff/zxw/seM2ZMzj333MyePTtf/vKXkyTd3d050Nubju7ud/U+kuTV/v78+caN+fcrV+ZTzz2b/7x2Tbbs78uKl1/On61f97b21fHS7+fq/oO5nn766dx6661JkhdffDGXXXZZ5syZk1/84he58cYb3/V7eOqppzJ37tyMGDEiDz744LveH1CTO1wDR3TGGWfkueeeO6bn/kvYLVy48E2PDQwMpKmp6U3br7jiisP7/+hHP5olS5bkwgsvPPx4V1dXBg4ezBn79r2j+f+1r23cmA+2tuTRuXPT0NCQjb292XPw0Dva1+m9vRnq709XV1fOOuusw9vnzp2buXPnJkkeffTRXHrppVmyZEmS5Oqrrz7m/R/t85o4cWLuuOOOfPe7331HcwP/Ngg74Jht3rw5n/zkJ7NixYqMGDEil156ae655558/etfz9q1azN79uzcfPPNaW5uzrJly9Ld3Z22trZ8+9vfzuc///n09vamubk5P/zhDzN79uwjrjF9+vR87nOfy9KlS/MfL744D+7elbt27syhwaF8+Iwz8t/OOSdJ8n9/1/Wm7b0DA
/ny+vXpOvhakuRrZ5+TKaNG5Te9vVly3nlpaGhIksxobU2SrHj55cPrPvfKK/nrLZtzcHAwrU1N+dsZH8ykUaPy5Msv5682/zYNaciIxobcP3tO9m7blgULFqT19f088sgjWbduXZYsWZJvfOMbWbx4cQ4cOJAnn3wyS5cuzWc+85k8/fTTGRgYyOLFi/PLX/4yBw8ezOLFi3PjjTfmzjvvfMPndf/997/pc5k8eXImT56cxkZftABHJ+yAI3r55ZffEF/f+9738rGPfSyLFi3KLbfckrFjx+aGG27IzJkzc9ttt2XJkiVZunRpkuTOO+/MqlWr8uyzz2bs2LHp6+vLz372s4wcOTKrV6/OLbfckp/+9KdHXXvKlCn5q7/8y2x74IE8+tJLue+i2WluaMitGzbkn7q7M2XUqCNuPzg4mDNGNOeOCy/M0NBQegcGsmLv3pzb2prG16PuaD7Q0pJ7Zl2UpoaGPPrSS/nB9u25rbMz/7BjR/7r2edk3rhxefX1O0+sfOqpfPTqq/Ojv//77N+//w1H2GbOnJlvfvObWbNmTb7zne/k+eefP/zYHXfckY6OjqxcuTL79+/P5Zdfnk984hNJ8obPC+CdEnbAER3tq9hFixbl6quvTm9vb1asWHHU18+fP/9wpLz22mu5+eabs3r16jQ1NeXFF198y7U/+9nP5ifLluU327fnuVdfzaeeezZJcmBgMBeOGZMXDhw44vaPtI3LbZv35m+3bMk1Z56ZOW8jkvb29+fWjRuy7cCBDA4N5fTmEUmSi8eOzXeefz6/3d+XT4w/K6clOffM8Vn60EP51re+lYULF+ac148i/jGPPPJI1qxZk7vvvvv3a+7dm82bN7/p8wJ4p4Qd8La8+uqr2bNnT4aGhnLgwIGMGTPmiM9raWk5/P/vf//7mT59eu6+++709vZm+vTpb7lGS0tLBgcGkqGhLJwwIX82ddobHr9r544jbk+S/zfn4vxTd3f+55bN+U9n/btcNW5cNvT1ZnBo6C2P2v3dtq25uq0tn5vQkY29vfnzTRuTJP9lypR8ZNy4/LynOwtXPZf/PeuifOS8czPxmv+QkSNH5pprrsmPf/zjt3w//2JwcDC33377m865W7t27Rs+L4B3yskawNvy1a9+NYsWLcoXvvCFLF68OEly2mmn5dVXXz3qa1555ZVMnDgxDQ0NufPOO49pncamplzQ0ZF/fPHF9Bz6/Q8dXjp4ML87eDAfPv2MI27veu21tDQ15VPt7blp4qSs792X6aNHZ0ZLa/7X9m0Zev0ix5t6e/P03r1vWG9f/0Da3zcySXL/77oOb9+2f3/OGzMmfzplat7f0pIXDhzIrn296ejoyFe+8pVce+21Wbfu2H5Ze+211+YHP/jB4UulrFmzxmVTgOPKETvgiP7wHLubbropM2fOzIYNG3L77bdnaGgoV111VX7+859n3rx5OXTo0Bt+PPGvfelLX8qnP/3p/OhHP8qCBQuOaf2Ro0dnwlln5U+nTM1Na36doaGhjGhszN90zkhna+sRt+8+eDB/s2VzGhsaMqqxMX/d2Zkk+daMzty2eXM+/vTTaWlqzISRI/Pfz3l/ul577fB6X5w8OV/buDF/t/X5XDWu7fD2f9i5Iyv27k1TkpmnnZY5Y8fmf2zbmsf/4i/yvSVLMm3atCxYsCArV678o+/pi1/8YrZs2ZI5c+ZkcHAwHR0d+clPfnJMn8fq1atz3XXXpaenJw8++GA6OzvzxBNPHNNrgX87GoaGhvk+PQBH8Oijj2bDww/nmieeHO5R3uSnH748H5w/Px//+MeHexSAN/BVLHBSam9vz77Ro3PoCNd0G06Hmpqyb/TotLe3D/coAG8i7ICTUnt7exqam7P39WvFnSz2tramobn5hIXdww8//IY7fsyePTuLFi06IWsB9TjHDjgptbW1ZVRra3a1tWX8K68M9ziH7Trz93O1tbX98Se/A/Pnz8/8+fNPyL6B+hyxA
05KTU1NmXnxxdk2dUoGTpK7LQw0NmbrlCmZdcklR7ztF8BwOzn+WgIcwUUXXZRDLS15Yfz44R4lSbJ9/Pj0t7Rk1qxZwz0KwBEJO+CkNW7cuJzd2Zl/njY1g3/klmAn2mBDQ347bWrOnjEj48aNG9ZZAI5G2AEntXlXXZV948dn06RJwzrHxkmTsm/8+My78sphnQPgrQg74KTW0dGRS+fNy286O/PKMN12a29LSzbM6MyHrrwyHR0dwzIDwLEQdsBJb968eRk3eVKeOe+89L/HP6Tob2zMM+efl7ZJk3LFFVe8p2sDvF3CDjjpNTc355PXX5++iROz4oLz37Pz7QYbGrLigvOzv2Nirrv++jfdKg3gZCPsgFPChAkTsuCGhemeOjVPXHjBCT9y19/YmCcuvCDdU6dmwQ0LM2HChBO6HsDx4F6xwCll69ateeDe+9Kyc2cuWb8+Y/v6jvsae1ta8sz552V/x8QsuGFhpk2bdtzXADgRhB1wytm9e3ceWrYsPS/syLmbNqVzx440Hoc/ZYMNDdk4aVI2zOhM26RJue766x2pA04pwg44JfX392f58uVZuXx5xuzZk/dv3ZYpe/akaXDwbe9roLEx28ePz2+nTc2+8ePzoSuvzBVXXOGcOuCUI+yAU9rOnTvz+PLl2bJxY5r7+jJt+/Z0vNSd03t7M2Jg4KivO9TUlL2trdl1Zlu2TpmS/paWnD1jRua5pAlwChN2QAk9PT1ZvXp1Vj/zTA709maovz9j9u/P2O6evK+/P41DgxlsaMzB5ua80jYu+0aPTkNzc0a1tmbWJZdk1qxZ7igBnPKEHVDKwMBAuru709XVla6urry4e3cOHjiQgf7+NDU3532jRuWsCRPS3t6e9vb2tLW1pampabjHBjguhB0AQBGuYwcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOA
KAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABFCDsAgCKEHQBAEcIOAKAIYQcAUISwAwAoQtgBABQh7AAAihB2AABF/H/A1NoqwhhDAgAAAABJRU5ErkJggg==", + "image/png": "", "text/plain": [ "
" ] @@ -79,7 +83,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -107,7 +111,8 @@ " Parents\n", " Variation_Function\n", " Individual\n", - " Generation\n", + " Submitted Timestamp\n", + " Completed Timestamp\n", " Pareto_Front\n", " Instance\n", " \n", @@ -119,48 +124,53 @@ " NaN\n", " NaN\n", " ['LogisticRegression_1']\n", - " 0.0\n", + " 1.690848e+09\n", + " 1.690848e+09\n", " NaN\n", " ['LogisticRegression_1']\n", " \n", " \n", " 1\n", - " 0.978247\n", + " 0.982956\n", " NaN\n", " NaN\n", " ['DecisionTreeClassifier_1']\n", - " 0.0\n", + " 1.690848e+09\n", + " 1.690848e+09\n", " NaN\n", " ['DecisionTreeClassifier_1']\n", " \n", " \n", " 2\n", - " 0.966548\n", + " 0.953313\n", " NaN\n", " NaN\n", " ['KNeighborsClassifier_1']\n", - " 0.0\n", + " 1.690848e+09\n", + " 1.690848e+09\n", " NaN\n", " ['KNeighborsClassifier_1']\n", " \n", " \n", " 3\n", - " 0.99877\n", + " 0.5\n", " NaN\n", " NaN\n", " ['GradientBoostingClassifier_1']\n", - " 0.0\n", + " 1.690848e+09\n", + " 1.690848e+09\n", " NaN\n", " ['GradientBoostingClassifier_1']\n", " \n", " \n", " 4\n", - " 1.0\n", + " 0.983413\n", " NaN\n", " NaN\n", " ['ExtraTreesClassifier_1']\n", - " 0.0\n", - " 0.0\n", + " 1.690848e+09\n", + " 1.690848e+09\n", + " NaN\n", " ['ExtraTreesClassifier_1']\n", " \n", " \n", @@ -172,106 +182,125 @@ " ...\n", " ...\n", " ...\n", + " ...\n", " \n", " \n", - " 145\n", - " 0.993849\n", - " (42,)\n", + " 106\n", + " 0.989742\n", + " (104,)\n", " mutate\n", - " [('GaussianNB_1', 'SelectPercentile_1')]\n", - " 4.0\n", + " [('MLPClassifier_1', 'SelectPercentile_1')]\n", + " 1.690848e+09\n", + " 1.690848e+09\n", " NaN\n", - " [('GaussianNB_1', 'SelectPercentile_1')]\n", + " [('MLPClassifier_1', 'SelectPercentile_1')]\n", " \n", " \n", - " 146\n", - " 0.99877\n", - " (114,)\n", + " 107\n", + " 0.99631\n", + " (12,)\n", " mutate\n", - " [('MLPClassifier_1', 'PolynomialFeatures_1')]\n", - " 4.0\n", + " ['MLPClassifier_1']\n", + " 
1.690848e+09\n", + " 1.690848e+09\n", " NaN\n", - " [('MLPClassifier_1', 'PolynomialFeatures_1')]\n", + " ['MLPClassifier_1']\n", " \n", " \n", - " 147\n", + " 108\n", " 0.99754\n", - " (48,)\n", - " mutate\n", - " [('GaussianNB_1', 'SelectPercentile_1'), ('Sel...\n", - " 4.0\n", + " (97, 93)\n", + " crossover_then_mutate\n", + " [('MLPClassifier_1', 'PolynomialFeatures_1'), ...\n", + " 1.690848e+09\n", + " 1.690848e+09\n", " NaN\n", - " [('GaussianNB_1', 'SelectPercentile_1'), ('Sel...\n", + " [('MLPClassifier_1', 'PolynomialFeatures_1'), ...\n", " \n", " \n", - " 148\n", - " 0.5\n", - " (114,)\n", - " mutate\n", - " ['BernoulliNB_1']\n", - " 4.0\n", + " 109\n", + " 0.989484\n", + " (80, 61)\n", + " crossover\n", + " [('MLPClassifier_1', 'FastICA_1')]\n", + " 1.690848e+09\n", + " 1.690848e+09\n", " NaN\n", - " ['BernoulliNB_1']\n", + " [('MLPClassifier_1', 'FastICA_1')]\n", " \n", " \n", - " 149\n", - " 0.999365\n", - " (4,)\n", + " 110\n", + " 0.994008\n", + " (71,)\n", " mutate\n", - " [('ExtraTreesClassifier_1', 'SelectFwe_1')]\n", - " 4.0\n", + " [('GradientBoostingClassifier_1', 'PolynomialF...\n", + " 1.690848e+09\n", + " 1.690848e+09\n", " NaN\n", - " [('ExtraTreesClassifier_1', 'SelectFwe_1')]\n", + " [('GradientBoostingClassifier_1', 'PolynomialF...\n", " \n", " \n", "\n", - "

150 rows × 7 columns

\n", + "

111 rows × 8 columns

\n", "" ], "text/plain": [ - " roc_auc_score Parents Variation_Function \\\n", - "0 0.99631 NaN NaN \n", - "1 0.978247 NaN NaN \n", - "2 0.966548 NaN NaN \n", - "3 0.99877 NaN NaN \n", - "4 1.0 NaN NaN \n", - ".. ... ... ... \n", - "145 0.993849 (42,) mutate \n", - "146 0.99877 (114,) mutate \n", - "147 0.99754 (48,) mutate \n", - "148 0.5 (114,) mutate \n", - "149 0.999365 (4,) mutate \n", + " roc_auc_score Parents Variation_Function \\\n", + "0 0.99631 NaN NaN \n", + "1 0.982956 NaN NaN \n", + "2 0.953313 NaN NaN \n", + "3 0.5 NaN NaN \n", + "4 0.983413 NaN NaN \n", + ".. ... ... ... \n", + "106 0.989742 (104,) mutate \n", + "107 0.99631 (12,) mutate \n", + "108 0.99754 (97, 93) crossover_then_mutate \n", + "109 0.989484 (80, 61) crossover \n", + "110 0.994008 (71,) mutate \n", "\n", - " Individual Generation \\\n", - "0 ['LogisticRegression_1'] 0.0 \n", - "1 ['DecisionTreeClassifier_1'] 0.0 \n", - "2 ['KNeighborsClassifier_1'] 0.0 \n", - "3 ['GradientBoostingClassifier_1'] 0.0 \n", - "4 ['ExtraTreesClassifier_1'] 0.0 \n", - ".. ... ... \n", - "145 [('GaussianNB_1', 'SelectPercentile_1')] 4.0 \n", - "146 [('MLPClassifier_1', 'PolynomialFeatures_1')] 4.0 \n", - "147 [('GaussianNB_1', 'SelectPercentile_1'), ('Sel... 4.0 \n", - "148 ['BernoulliNB_1'] 4.0 \n", - "149 [('ExtraTreesClassifier_1', 'SelectFwe_1')] 4.0 \n", + " Individual Submitted Timestamp \\\n", + "0 ['LogisticRegression_1'] 1.690848e+09 \n", + "1 ['DecisionTreeClassifier_1'] 1.690848e+09 \n", + "2 ['KNeighborsClassifier_1'] 1.690848e+09 \n", + "3 ['GradientBoostingClassifier_1'] 1.690848e+09 \n", + "4 ['ExtraTreesClassifier_1'] 1.690848e+09 \n", + ".. ... ... \n", + "106 [('MLPClassifier_1', 'SelectPercentile_1')] 1.690848e+09 \n", + "107 ['MLPClassifier_1'] 1.690848e+09 \n", + "108 [('MLPClassifier_1', 'PolynomialFeatures_1'), ... 1.690848e+09 \n", + "109 [('MLPClassifier_1', 'FastICA_1')] 1.690848e+09 \n", + "110 [('GradientBoostingClassifier_1', 'PolynomialF... 
1.690848e+09 \n", "\n", - " Pareto_Front Instance \n", - "0 NaN ['LogisticRegression_1'] \n", - "1 NaN ['DecisionTreeClassifier_1'] \n", - "2 NaN ['KNeighborsClassifier_1'] \n", - "3 NaN ['GradientBoostingClassifier_1'] \n", - "4 0.0 ['ExtraTreesClassifier_1'] \n", - ".. ... ... \n", - "145 NaN [('GaussianNB_1', 'SelectPercentile_1')] \n", - "146 NaN [('MLPClassifier_1', 'PolynomialFeatures_1')] \n", - "147 NaN [('GaussianNB_1', 'SelectPercentile_1'), ('Sel... \n", - "148 NaN ['BernoulliNB_1'] \n", - "149 NaN [('ExtraTreesClassifier_1', 'SelectFwe_1')] \n", + " Completed Timestamp Pareto_Front \\\n", + "0 1.690848e+09 NaN \n", + "1 1.690848e+09 NaN \n", + "2 1.690848e+09 NaN \n", + "3 1.690848e+09 NaN \n", + "4 1.690848e+09 NaN \n", + ".. ... ... \n", + "106 1.690848e+09 NaN \n", + "107 1.690848e+09 NaN \n", + "108 1.690848e+09 NaN \n", + "109 1.690848e+09 NaN \n", + "110 1.690848e+09 NaN \n", "\n", - "[150 rows x 7 columns]" + " Instance \n", + "0 ['LogisticRegression_1'] \n", + "1 ['DecisionTreeClassifier_1'] \n", + "2 ['KNeighborsClassifier_1'] \n", + "3 ['GradientBoostingClassifier_1'] \n", + "4 ['ExtraTreesClassifier_1'] \n", + ".. ... \n", + "106 [('MLPClassifier_1', 'SelectPercentile_1')] \n", + "107 ['MLPClassifier_1'] \n", + "108 [('MLPClassifier_1', 'PolynomialFeatures_1'), ... \n", + "109 [('MLPClassifier_1', 'FastICA_1')] \n", + "110 [('GradientBoostingClassifier_1', 'PolynomialF... 
\n", + "\n", + "[111 rows x 8 columns]" ] }, - "execution_count": 3, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -283,7 +312,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -311,35 +340,43 @@ " Parents\n", " Variation_Function\n", " Individual\n", - " Generation\n", + " Submitted Timestamp\n", + " Completed Timestamp\n", " Pareto_Front\n", " Instance\n", " \n", " \n", " \n", " \n", - " 4\n", + " 69\n", + " 0.99754\n", + " (51,)\n", + " mutate\n", + " [('LogisticRegression_1', 'PolynomialFeatures_...\n", + " 1.690848e+09\n", + " 1.690848e+09\n", " 1.0\n", - " NaN\n", - " NaN\n", - " ['ExtraTreesClassifier_1']\n", - " 0.0\n", - " 0.0\n", - " ['ExtraTreesClassifier_1']\n", + " [('LogisticRegression_1', 'PolynomialFeatures_...\n", " \n", " \n", "\n", "" ], "text/plain": [ - " roc_auc_score Parents Variation_Function Individual \\\n", - "4 1.0 NaN NaN ['ExtraTreesClassifier_1'] \n", + " roc_auc_score Parents Variation_Function \\\n", + "69 0.99754 (51,) mutate \n", "\n", - " Generation Pareto_Front Instance \n", - "4 0.0 0.0 ['ExtraTreesClassifier_1'] " + " Individual Submitted Timestamp \\\n", + "69 [('LogisticRegression_1', 'PolynomialFeatures_... 1.690848e+09 \n", + "\n", + " Completed Timestamp Pareto_Front \\\n", + "69 1.690848e+09 1.0 \n", + "\n", + " Instance \n", + "69 [('LogisticRegression_1', 'PolynomialFeatures_... " ] }, - "execution_count": 4, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -349,6 +386,68 @@ "est.pareto_front" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "TPOTEstimator does a standard evolutionary algorithm. In this version, the next generation doesn't start evaluation until all individuals in the previous generation are finished evaluating." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generation: 100%|██████████| 5/5 [00:35<00:00, 7.17s/it]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.9947089947089948\n" + ] + } + ], + "source": [ + "import tpot2\n", + "import sklearn\n", + "import sklearn.datasets\n", + "\n", + "est = tpot2.TPOTEstimator( population_size=30,\n", + " generations=5,\n", + " scorers=['roc_auc_ovr'], #scorers can be a list of strings or a list of scorers. These get evaluated during cross validation. \n", + " scorers_weights=[1],\n", + " classification=True,\n", + " n_jobs=1, \n", + " early_stop=5, #how many generations with no improvement to stop after\n", + " \n", + " #List of other objective functions. All objective functions take in an untrained GraphPipeline and return a score or a list of scores\n", + " other_objective_functions= [ ],\n", + " \n", + " #List of weights for the other objective functions. Must be the same length as other_objective_functions. By default, bigger is better is set to True. \n", + " other_objective_functions_weights=[],\n", + " verbose=2)\n", + "\n", + "scorer = sklearn.metrics.get_scorer('roc_auc_ovo')\n", + "X, y = sklearn.datasets.load_iris(return_X_y=True)\n", + "X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(X, y, train_size=0.75, test_size=0.25)\n", + "est.fit(X_train, y_train)\n", + "print(scorer(est, X_test, y_test))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The TPOTClassifier and TPOTRegressor set default parameters for the TPOTEstimator for classification and regression, respectively.\n", + "In the future, a metalearner will be used to predict the best values for a given dataset." 
+ ] + }, { "cell_type": "code", "execution_count": 5, diff --git a/tpot2/tpot_estimator/estimator.py b/tpot2/tpot_estimator/estimator.py index 6d1c43b4..092fedeb 100644 --- a/tpot2/tpot_estimator/estimator.py +++ b/tpot2/tpot_estimator/estimator.py @@ -581,7 +581,7 @@ def fit(self, X, y): if isinstance(self.cv, int) or isinstance(self.cv, float): n_folds = self.cv else: - n_folds = self.cv.n_splits + n_folds = self.cv.get_n_splits(X, y) X, y = remove_underrepresented_classes(X, y, n_folds) diff --git a/tpot2/tpot_estimator/steady_state_estimator.py b/tpot2/tpot_estimator/steady_state_estimator.py index 085d5ffe..7ecf0cd0 100644 --- a/tpot2/tpot_estimator/steady_state_estimator.py +++ b/tpot2/tpot_estimator/steady_state_estimator.py @@ -570,7 +570,7 @@ def fit(self, X, y): if isinstance(self.cv, int) or isinstance(self.cv, float): n_folds = self.cv else: - n_folds = self.cv.n_splits + n_folds = self.cv.get_n_splits(X, y) X, y = remove_underrepresented_classes(X, y, n_folds)