Skip to content

Commit

Permalink
cluster max with features and not samples
Browse files: browse the repository at this point in the history
  • Loading branch information
perib committed Oct 8, 2024
1 parent bfb40d9 commit 38fe6af
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
11 changes: 7 additions & 4 deletions tpot2/config/get_configspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
from sklearn.svm import SVC, SVR, LinearSVR, LinearSVC
from lightgbm import LGBMClassifier, LGBMRegressor
import sklearn
from sklearn.naive_bayes import GaussianNB, BernoulliNB, MultinomialNB
from sklearn.decomposition import FastICA, PCA
from sklearn.cluster import FeatureAgglomeration
Expand Down Expand Up @@ -117,7 +118,7 @@
"selectors": ["SelectFwe", "SelectPercentile", "VarianceThreshold",],
"selectors_classification": ["SelectFwe", "SelectPercentile", "VarianceThreshold", "RFE_classification", "SelectFromModel_classification"],
"selectors_regression": ["SelectFwe", "SelectPercentile", "VarianceThreshold", "RFE_regression", "SelectFromModel_regression"],
"classifiers" : ["LGBMClassifier", "BaggingClassifier", 'AdaBoostClassifier', 'BernoulliNB', 'DecisionTreeClassifier', 'ExtraTreesClassifier', 'GaussianNB', 'HistGradientBoostingClassifier', 'KNeighborsClassifier','LinearDiscriminantAnalysis', 'LogisticRegression', "LinearSVC", "SVC", 'MLPClassifier', 'MultinomialNB', "QuadraticDiscriminantAnalysis", 'RandomForestClassifier', 'SGDClassifier', 'XGBClassifier'],
"classifiers" : ["LGBMClassifier", "BaggingClassifier", 'AdaBoostClassifier', 'BernoulliNB', 'DecisionTreeClassifier', 'ExtraTreesClassifier', 'GaussianNB', 'HistGradientBoostingClassifier', 'KNeighborsClassifier','LinearDiscriminantAnalysis', 'LogisticRegression', "LinearSVC_wrapped", "SVC", 'MLPClassifier', 'MultinomialNB', "QuadraticDiscriminantAnalysis", 'RandomForestClassifier', 'SGDClassifier', 'XGBClassifier'],
"regressors" : ["LGBMRegressor", 'AdaBoostRegressor', "ARDRegression", 'DecisionTreeRegressor', 'ExtraTreesRegressor', 'HistGradientBoostingRegressor', 'KNeighborsRegressor', 'LinearSVR', "MLPRegressor", 'RandomForestRegressor', 'SGDRegressor', 'SVR', 'XGBRegressor'],


Expand Down Expand Up @@ -298,7 +299,7 @@ def get_configspace(name, n_classes=3, n_samples=1000, n_features=100, random_st
case "FastICA":
return transformers.get_FastICA_configspace(n_features=n_features, random_state=random_state)
case "FeatureAgglomeration":
return transformers.get_FeatureAgglomeration_configspace(n_samples=n_samples)
return transformers.get_FeatureAgglomeration_configspace(n_features=n_features)
case "Nystroem":
return transformers.get_Nystroem_configspace(n_features=n_features, random_state=random_state)
case "RBFSampler":
Expand Down Expand Up @@ -522,7 +523,9 @@ def get_node(name, n_classes=3, n_samples=100, n_features=100, random_state=None
"""


if name == "LinearSVC_wrapped":
ext = get_node("LinearSVC", n_classes=n_classes, n_samples=n_samples, random_state=random_state)
return WrapperPipeline(estimator_search_space=ext, method=sklearn.calibration.CalibratedClassifierCV, space={})
if name == "RFE_classification":
rfe_sp = get_configspace(name="RFE", n_classes=n_classes, n_samples=n_samples, random_state=random_state)
ext = get_node("ExtraTreesClassifier", n_classes=n_classes, n_samples=n_samples, random_state=random_state)
Expand Down Expand Up @@ -577,7 +580,7 @@ def get_node(name, n_classes=3, n_samples=100, n_features=100, random_state=None
configspace = get_configspace(name, n_classes=n_classes, n_samples=n_samples, random_state=random_state)
return base_node(STRING_TO_CLASS[name], configspace, hyperparameter_parser=classifiers.GaussianProcessClassifier_hyperparameter_parser)
if name == "FeatureAgglomeration":
configspace = get_configspace(name, n_features=n_features)
configspace = get_configspace(name, n_classes=n_classes, n_samples=n_samples, random_state=random_state)
return base_node(STRING_TO_CLASS[name], configspace, hyperparameter_parser=transformers.FeatureAgglomeration_hyperparameter_parser)

configspace = get_configspace(name, n_classes=n_classes, n_samples=n_samples, n_features=n_features, random_state=random_state)
Expand Down
4 changes: 2 additions & 2 deletions tpot2/config/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ def get_FastICA_configspace(n_features=100, random_state=None):

)

def get_FeatureAgglomeration_configspace(n_samples):
def get_FeatureAgglomeration_configspace(n_features):

linkage = Categorical('linkage', ['ward', 'complete', 'average'])
metric = Categorical('metric', ['euclidean', 'l1', 'l2', 'manhattan', 'cosine'])
n_clusters = Integer('n_clusters', bounds=(2, min(n_samples,400)))
n_clusters = Integer('n_clusters', bounds=(2, min(n_features,400)))
pooling_func = Categorical('pooling_func', ['mean', 'median', 'max'])

metric_condition = NotEqualsCondition(metric, linkage, 'ward')
Expand Down

0 comments on commit 38fe6af

Please sign in to comment.