From f18519a0e29c95c0e9179fcd7e86d9b19301cc50 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sat, 15 Dec 2018 17:28:06 -0500 Subject: [PATCH 1/9] Accept pre-binned data in fit --- pygbm/binning.py | 11 +++-- pygbm/gradient_boosting.py | 81 +++++++++++++++++++++++---------- pygbm/grower.py | 30 +++++++----- pygbm/predictor.py | 10 +++- tests/test_binning.py | 15 +++--- tests/test_compare_lightgbm.py | 12 +++-- tests/test_gradient_boosting.py | 38 ++++++++++++++++ tests/test_grower.py | 3 +- tests/test_predictor.py | 34 +++++++++++++- 9 files changed, 179 insertions(+), 55 deletions(-) diff --git a/pygbm/binning.py b/pygbm/binning.py index df0ddc8..793ac5d 100644 --- a/pygbm/binning.py +++ b/pygbm/binning.py @@ -14,7 +14,6 @@ def _find_binning_thresholds(data, max_bins=256, subsample=int(2e5), random_state=None): """Extract feature-wise equally-spaced quantiles from numerical data - Return ------ binning_thresholds: tuple of arrays @@ -152,13 +151,15 @@ def fit(self, X, y=None): self : object """ X = check_array(X) - self.bin_thresholds_ = _find_binning_thresholds( + self.numerical_thresholds_ = _find_binning_thresholds( X, self.max_bins, subsample=self.subsample, random_state=self.random_state) self.n_bins_per_feature_ = np.array( - [thresholds.shape[0] + 1 for thresholds in self.bin_thresholds_], - dtype=np.uint32) + [thresholds.shape[0] + 1 + for thresholds in self.numerical_thresholds_], + dtype=np.uint32 + ) return self @@ -175,4 +176,4 @@ def transform(self, X): X_binned : array-like The binned data """ - return _map_to_bins(X, binning_thresholds=self.bin_thresholds_) + return _map_to_bins(X, binning_thresholds=self.numerical_thresholds_) diff --git a/pygbm/gradient_boosting.py b/pygbm/gradient_boosting.py index dd31471..f3595a0 100644 --- a/pygbm/gradient_boosting.py +++ b/pygbm/gradient_boosting.py @@ -42,7 +42,7 @@ def __init__(self, loss, learning_rate, max_iter, max_leaf_nodes, self.verbose = verbose self.random_state = random_state - def _validate_parameters(self): + def _validate_parameters(self, X): """Validate parameters passed to __init__. The parameters that are directly passed to the grower are checked in @@ -69,6 +69,13 @@ def _validate_parameters(self): if self.tol is not None and self.tol < 0: raise ValueError(f'tol={self.tol} ' f'must not be smaller than 0.') + if self.train_data_is_pre_binned_: + max_bin_index = X.max() + if self.max_bins < max_bin_index + 1: + raise ValueError( + f'Data is pre-binned and max_bins={self.max_bins}, ' + f'but data has {max_bin_index + 1} bins.' + ) def fit(self, X, y): """Fit the gradient boosting model. @@ -76,7 +83,10 @@ def fit(self, X, y): Parameters ---------- X : array-like, shape=(n_samples, n_features) - The input samples. + The input samples. If ``X.dtype == np.uint8``, the data is + assumed to be pre-binned and the prediction methods + (``predict``, ``predict_proba``) will only accept pre-binned + data as well. y : array-like, shape=(n_samples,) Target values. @@ -93,8 +103,7 @@ def fit(self, X, y): acc_prediction_time = 0. # TODO: add support for mixed-typed (numerical + categorical) data # TODO: add support for missing data - # TODO: add support for pre-binned data (pass-through)? - X, y = check_X_y(X, y, dtype=[np.float32, np.float64]) + X, y = check_X_y(X, y, dtype=[np.float32, np.float64, np.uint8]) y = self._encode_y(y) if X.shape[0] == 1 or X.shape[1] == 1: raise ValueError( @@ -103,20 +112,33 @@ def fit(self, X, y): ) rng = check_random_state(self.random_state) - self._validate_parameters() + self.train_data_is_pre_binned_ = X.dtype == np.uint8 + self._validate_parameters(X) self.n_features_ = X.shape[1] # used for validation in predict() - if self.verbose: - print(f"Binning {X.nbytes / 1e9:.3f} GB of data: ", end="", - flush=True) - tic = time() - self.bin_mapper_ = BinMapper(max_bins=self.max_bins, random_state=rng) - X_binned = self.bin_mapper_.fit_transform(X) - toc = time() - if self.verbose: - duration = toc - tic - troughput = X.nbytes / duration - print(f"{duration:.3f} s ({troughput / 1e6:.3f} MB/s)") + if self.train_data_is_pre_binned_: + if self.verbose: + print("X is pre-binned.") + X_binned = X + self.bin_mapper_ = None + numerical_thresholds = None + n_bins_per_feature = X.max(axis=0).astype(np.uint32) + else: + if self.verbose: + print(f"Binning {X.nbytes / 1e9:.3f} GB of data: ", end="", + flush=True) + tic = time() + self.bin_mapper_ = BinMapper(max_bins=self.max_bins, + random_state=rng) + X_binned = self.bin_mapper_.fit_transform(X) + numerical_thresholds = self.bin_mapper_.numerical_thresholds_ + n_bins_per_feature = self.bin_mapper_.n_bins_per_feature_ + toc = time() + + if self.verbose: + duration = toc - tic + throughput = X.nbytes / duration + print(f"{duration:.3f} s ({throughput / 1e6:.3f} MB/s)") self.loss_ = self._get_loss() @@ -217,7 +239,7 @@ def fit(self, X, y): grower = TreeGrower( X_binned_train, gradients_at_k, hessians_at_k, max_bins=self.max_bins, - n_bins_per_feature=self.bin_mapper_.n_bins_per_feature_, + n_bins_per_feature=n_bins_per_feature, max_leaf_nodes=self.max_leaf_nodes, max_depth=self.max_depth, min_samples_leaf=self.min_samples_leaf, @@ -228,8 +250,7 @@ def fit(self, X, y): acc_apply_split_time += grower.total_apply_split_time acc_find_split_time += grower.total_find_split_time - predictor = grower.make_predictor( - bin_thresholds=self.bin_mapper_.bin_thresholds_) + predictor = grower.make_predictor(numerical_thresholds) predictors[-1].append(predictor) tic_pred = time() @@ -352,7 +373,8 @@ def _raw_predict(self, X): ---------- X : array-like, shape=(n_samples, n_features) The input samples. If ``X.dtype == np.uint8``, the data is assumed - to be pre-binned. + to be pre-binned and the estimator must have been fitted with + pre-binned data. Returns ------- @@ -366,6 +388,15 @@ def _raw_predict(self, X): f'X has {X.shape[1]} features but this estimator was ' f'trained with {self.n_features_} features.' ) + is_binned = X.dtype == np.uint8 + if not is_binned and self.train_data_is_pre_binned_: + raise ValueError( + 'This estimator was fitted with pre-binned data and ' + 'can only predict pre-binned data as well. If your data *is* ' + 'already pre-binnned, convert it to uint8 using e.g. ' + 'X.astype(np.uint8). If the data passed to fit() was *not* ' + 'pre-binned, convert it to float32 and call fit() again.' + ) n_samples = X.shape[0] raw_predictions = np.zeros( shape=(n_samples, self.n_trees_per_iteration_), @@ -373,7 +404,6 @@ def _raw_predict(self, X): ) raw_predictions += self.baseline_prediction_ # Should we parallelize this? - is_binned = X.dtype == np.uint8 for predictors_of_ith_iteration in self.predictors_: for k, predictor in enumerate(predictors_of_ith_iteration): predict = (predictor.predict_binned if is_binned @@ -489,7 +519,8 @@ def predict(self, X): ---------- X : array-like, shape=(n_samples, n_features) The input samples. If ``X.dtype == np.uint8``, the data is assumed - to be pre-binned. + to be pre-binned and the estimator must have been fitted with + pre-binned data. Returns ------- @@ -607,7 +638,8 @@ def predict(self, X): ---------- X : array-like, shape=(n_samples, n_features) The input samples. If ``X.dtype == np.uint8``, the data is assumed - to be pre-binned. + to be pre-binned and the estimator must have been fitted with + pre-binned data. Returns ------- @@ -625,7 +657,8 @@ def predict_proba(self, X): ---------- X : array-like, shape=(n_samples, n_features) The input samples. If ``X.dtype == np.uint8``, the data is assumed - to be pre-binned. + to be pre-binned and the estimator must have been fitted with + pre-binned data. Returns ------- diff --git a/pygbm/grower.py b/pygbm/grower.py index f1b5000..1ea3749 100644 --- a/pygbm/grower.py +++ b/pygbm/grower.py @@ -413,25 +413,30 @@ def _finalize_splittable_nodes(self): node = self.splittable_nodes.pop() self._finalize_leaf(node) - def make_predictor(self, bin_thresholds=None): + def make_predictor(self, numerical_thresholds=None): """Make a TreePredictor object out of the current tree. Parameters ---------- - bin_thresholds : array-like of floats, optional (default=None) - The actual thresholds values of each bin. + numerical_thresholds : array-like of floats, optional (default=None) + The actual thresholds values of each bin. None if the training data + was pre-binned. Returns ------- A TreePredictor object. """ predictor_nodes = np.zeros(self.n_nodes, dtype=PREDICTOR_RECORD_DTYPE) - self._fill_predictor_node_array(predictor_nodes, self.root, - bin_thresholds=bin_thresholds) - return TreePredictor(predictor_nodes) + self._fill_predictor_node_array( + predictor_nodes, self.root, + numerical_thresholds=numerical_thresholds + ) + has_numerical_thresholds = numerical_thresholds is not None + return TreePredictor(nodes=predictor_nodes, + has_numerical_thresholds=has_numerical_thresholds) def _fill_predictor_node_array(self, predictor_nodes, grower_node, - bin_thresholds=None, next_free_idx=0): + numerical_thresholds=None, next_free_idx=0): """Helper used in make_predictor to set the TreePredictor fields.""" node = predictor_nodes[next_free_idx] node['count'] = grower_node.n_samples @@ -452,17 +457,18 @@ def _fill_predictor_node_array(self, predictor_nodes, grower_node, feature_idx, bin_idx = split_info.feature_idx, split_info.bin_idx node['feature_idx'] = feature_idx node['bin_threshold'] = bin_idx - if bin_thresholds is not None: - threshold = bin_thresholds[feature_idx][bin_idx] - node['threshold'] = threshold + if numerical_thresholds is not None: + node['threshold'] = numerical_thresholds[feature_idx][bin_idx] next_free_idx += 1 node['left'] = next_free_idx next_free_idx = self._fill_predictor_node_array( predictor_nodes, grower_node.left_child, - bin_thresholds=bin_thresholds, next_free_idx=next_free_idx) + numerical_thresholds=numerical_thresholds, + next_free_idx=next_free_idx) node['right'] = next_free_idx return self._fill_predictor_node_array( predictor_nodes, grower_node.right_child, - bin_thresholds=bin_thresholds, next_free_idx=next_free_idx) + numerical_thresholds=numerical_thresholds, + next_free_idx=next_free_idx) diff --git a/pygbm/predictor.py b/pygbm/predictor.py index 8b4afa0..0f46dc9 100644 --- a/pygbm/predictor.py +++ b/pygbm/predictor.py @@ -28,8 +28,9 @@ class TreePredictor: nodes : list of PREDICTOR_RECORD_DTYPE. The nodes of the tree. """ - def __init__(self, nodes): + def __init__(self, nodes, has_numerical_thresholds=True): self.nodes = nodes + self.has_numerical_thresholds = has_numerical_thresholds def get_n_leaf_nodes(self): """Return number of leaves.""" @@ -74,6 +75,13 @@ def predict(self, X): """ # TODO: introspect X to dispatch to numerical or categorical data # (dense or sparse) on a feature by feature basis. + + if not self.has_numerical_thresholds: + raise ValueError( + 'This predictor does not have numerical thresholds so it can' + 'only predict pre-binned data.' + ) + out = np.empty(X.shape[0], dtype=np.float32) _predict_from_numeric_data(self.nodes, X, out) return out diff --git a/tests/test_binning.py b/tests/test_binning.py index f7d7f8f..10db67a 100644 --- a/tests/test_binning.py +++ b/tests/test_binning.py @@ -95,10 +95,10 @@ def test_bin_mapper_random_data(n_bins): assert binned.dtype == np.uint8 assert_array_equal(binned.min(axis=0), np.array([0, 0])) assert_array_equal(binned.max(axis=0), np.array([n_bins - 1, n_bins - 1])) - assert len(mapper.bin_thresholds_) == n_features - for i in range(len(mapper.bin_thresholds_)): - assert mapper.bin_thresholds_[i].shape == (n_bins - 1,) - assert mapper.bin_thresholds_[i].dtype == DATA.dtype + assert len(mapper.numerical_thresholds_) == n_features + for i in range(len(mapper.numerical_thresholds_)): + assert mapper.numerical_thresholds_[i].shape == (n_bins - 1,) + assert mapper.numerical_thresholds_[i].dtype == DATA.dtype assert np.all(mapper.n_bins_per_feature_ == n_bins) # Check that the binned data is approximately balanced across bins. @@ -159,7 +159,8 @@ def test_bin_mapper_repeated_values_invariance(n_distinct): mapper_2 = BinMapper(max_bins=min(256, n_distinct * 3)) binned_2 = mapper_2.fit_transform(data) - assert_allclose(mapper_1.bin_thresholds_[0], mapper_2.bin_thresholds_[0]) + assert_allclose(mapper_1.numerical_thresholds_[0], + mapper_2.numerical_thresholds_[0]) assert_array_equal(binned_1, binned_2) @@ -214,7 +215,7 @@ def test_subsample(): for feature in range(DATA.shape[1]): with pytest.raises(AssertionError): np.testing.assert_array_almost_equal( - mapper_no_subsample.bin_thresholds_[feature], - mapper_subsample.bin_thresholds_[feature], + mapper_no_subsample.numerical_thresholds_[feature], + mapper_subsample.numerical_thresholds_[feature], decimal=3 ) diff --git a/tests/test_compare_lightgbm.py b/tests/test_compare_lightgbm.py index 1f39989..5ffce6f 100644 --- a/tests/test_compare_lightgbm.py +++ b/tests/test_compare_lightgbm.py @@ -45,7 +45,9 @@ def test_same_predictions_regression(seed, min_samples_leaf, n_samples, n_informative=5, random_state=0) if n_samples > 255: - X = BinMapper(max_bins=max_bins).fit_transform(X) + # bin data and convert it to float32 so that the estimator doesn't + # treat it as pre-binned + X = BinMapper(max_bins=max_bins).fit_transform(X).astype(np.float32) X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=rng) @@ -94,7 +96,9 @@ def test_same_predictions_classification(seed, min_samples_leaf, n_samples, n_informative=5, n_redundant=0, random_state=0) if n_samples > 255: - X = BinMapper(max_bins=max_bins).fit_transform(X) + # bin data and convert it to float32 so that the estimator doesn't + # treat it as pre-binned + X = BinMapper(max_bins=max_bins).fit_transform(X).astype(np.float32) X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=rng) @@ -154,7 +158,9 @@ def test_same_predictions_multiclass_classification( n_clusters_per_class=1, random_state=0) if n_samples > 255: - X = BinMapper(max_bins=max_bins).fit_transform(X) + # bin data and convert it to float32 so that the estimator doesn't + # treat it as pre-binned + X = BinMapper(max_bins=max_bins).fit_transform(X).astype(np.float32) X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=rng) diff --git a/tests/test_gradient_boosting.py b/tests/test_gradient_boosting.py index a16502e..b3ce0c2 100644 --- a/tests/test_gradient_boosting.py +++ b/tests/test_gradient_boosting.py @@ -1,12 +1,14 @@ import os import warnings +import numpy as np import pytest from sklearn.utils.testing import assert_raises_regex from sklearn.datasets import make_classification, make_regression from pygbm import GradientBoostingClassifier from pygbm import GradientBoostingRegressor +from pygbm.binning import BinMapper X_classification, y_classification = make_classification(random_state=0) @@ -69,6 +71,12 @@ def test_init_parameters_validation(GradientBoosting, X, y): GradientBoosting(max_bins=max_bins).fit, X, y ) + assert_raises_regex( + ValueError, + f"Data is pre-binned and max_bins=4, but data has 256 bins", + GradientBoosting(max_bins=4).fit, X.astype(np.uint8), y + ) + assert_raises_regex( ValueError, f"n_iter_no_change=1 must not be smaller than 2", @@ -243,3 +251,33 @@ def test_estimator_checks(Estimator): # dataset, the root is never split with min_samples_leaf=20 and only the # majority class is predicted. custom_check_estimator(Estimator) + + +def test_pre_binned_data(): + # Make sure that: + # - training on numerical data and predicting on numerical data is the + # same as training on binned data and predicting on binned data + # - training on numerical data and predicting on numerical data is the + # same as training on numerical data and predicting on binned data + # - training on binned data and predicting on numerical data is not + # possible. + + X, y = make_regression(random_state=0) + gbdt = GradientBoostingRegressor(scoring=None, random_state=0) + mapper = BinMapper(random_state=0) + X_binned = mapper.fit_transform(X) + + fit_num_pred_num = gbdt.fit(X, y).predict(X) + fit_binned_pred_binned = gbdt.fit(X_binned, y).predict(X_binned) + fit_num_pred_binned = gbdt.fit(X, y).predict(X_binned) + + np.testing.assert_array_almost_equal(fit_num_pred_num, + fit_binned_pred_binned) + np.testing.assert_array_almost_equal(fit_num_pred_num, + fit_num_pred_binned) + + assert_raises_regex( + ValueError, + 'This estimator was fitted with pre-binned data ', + gbdt.fit(X_binned, y).predict, X + ) diff --git a/tests/test_grower.py b/tests/test_grower.py index e4c9dc7..d48f658 100644 --- a/tests/test_grower.py +++ b/tests/test_grower.py @@ -215,7 +215,8 @@ def test_min_samples_leaf(n_samples, min_samples_leaf, n_bins, min_samples_leaf=min_samples_leaf, max_leaf_nodes=n_samples) grower.grow() - predictor = grower.make_predictor(bin_thresholds=mapper.bin_thresholds_) + predictor = grower.make_predictor( + numerical_thresholds=mapper.numerical_thresholds_) if n_samples >= min_samples_leaf: for node in predictor.nodes: diff --git a/tests/test_predictor.py b/tests/test_predictor.py index 9f52655..70cd86c 100644 --- a/tests/test_predictor.py +++ b/tests/test_predictor.py @@ -1,8 +1,9 @@ import numpy as np from numpy.testing import assert_allclose -from sklearn.datasets import load_boston +from sklearn.datasets import load_boston, make_regression from sklearn.model_selection import train_test_split from sklearn.metrics import r2_score +from sklearn.utils.testing import assert_raises_regex import pytest from pygbm.binning import BinMapper @@ -31,7 +32,8 @@ def test_boston_dataset(max_bins): n_bins_per_feature=mapper.n_bins_per_feature_) grower.grow() - predictor = grower.make_predictor(bin_thresholds=mapper.bin_thresholds_) + predictor = grower.make_predictor( + numerical_thresholds=mapper.numerical_thresholds_) assert r2_score(y_train, predictor.predict_binned(X_train_binned)) > 0.85 assert r2_score(y_test, predictor.predict_binned(X_test_binned)) > 0.70 @@ -44,3 +46,31 @@ def test_boston_dataset(max_bins): assert r2_score(y_train, predictor.predict(X_train)) > 0.85 assert r2_score(y_test, predictor.predict(X_test)) > 0.70 + + +def test_pre_binned_data(): + # Make sure ValueError is raised when predictor.predict() is called while + # the predictor does not have any numerical thresholds. + + X, y = make_regression() + + # Init gradients and hessians to that of least squares loss + gradients = -y.astype(np.float32) + hessians = np.ones(1, dtype=np.float32) + + mapper = BinMapper(random_state=0) + X_binned = mapper.fit_transform(X) + grower = TreeGrower(X_binned, gradients, hessians, + n_bins_per_feature=mapper.n_bins_per_feature_) + grower.grow() + predictor = grower.make_predictor( + numerical_thresholds=None + ) + + assert_raises_regex( + ValueError, + 'This predictor does not have numerical thresholds', + predictor.predict, X_binned + ) + + predictor.predict_binned(X) # No error From 1fcf558bdb46a558dd508facb248111a6ee186af Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 17 Dec 2018 09:46:47 -0500 Subject: [PATCH 2/9] Addressed comments --- pygbm/gradient_boosting.py | 7 +++---- pygbm/predictor.py | 7 +++++++ tests/test_gradient_boosting.py | 7 +++---- tests/test_predictor.py | 19 +++++++++++++++++-- 4 files changed, 30 insertions(+), 10 deletions(-) diff --git a/pygbm/gradient_boosting.py b/pygbm/gradient_boosting.py index f3595a0..c6dcca6 100644 --- a/pygbm/gradient_boosting.py +++ b/pygbm/gradient_boosting.py @@ -69,7 +69,7 @@ def _validate_parameters(self, X): if self.tol is not None and self.tol < 0: raise ValueError(f'tol={self.tol} ' f'must not be smaller than 0.') - if self.train_data_is_pre_binned_: + if X.dtype == np.uint8: # pre-binned data max_bin_index = X.max() if self.max_bins < max_bin_index + 1: raise ValueError( @@ -112,11 +112,10 @@ def fit(self, X, y): ) rng = check_random_state(self.random_state) - self.train_data_is_pre_binned_ = X.dtype == np.uint8 self._validate_parameters(X) self.n_features_ = X.shape[1] # used for validation in predict() - if self.train_data_is_pre_binned_: + if X.dtype == np.uint8: # data is pre-binned if self.verbose: print("X is pre-binned.") X_binned = X @@ -389,7 +388,7 @@ def _raw_predict(self, X): f'trained with {self.n_features_} features.' ) is_binned = X.dtype == np.uint8 - if not is_binned and self.train_data_is_pre_binned_: + if not is_binned and self.bin_mapper_ is None: raise ValueError( 'This estimator was fitted with pre-binned data and ' 'can only predict pre-binned data as well. If your data *is* ' diff --git a/pygbm/predictor.py b/pygbm/predictor.py index 0f46dc9..8a93818 100644 --- a/pygbm/predictor.py +++ b/pygbm/predictor.py @@ -55,6 +55,10 @@ def predict_binned(self, binned_data, out=None): y : array, shape (n_samples,) The raw predicted values. """ + + if binned_data.dtype != np.uint8: + raise ValueError('binned_data dtype should be uint8') + if out is None: out = np.empty(binned_data.shape[0], dtype=np.float32) _predict_binned(self.nodes, binned_data, out) @@ -82,6 +86,9 @@ def predict(self, X): 'only predict pre-binned data.' ) + if X.dtype == np.uint8: + raise ValueError('X dtype should be float32 or float64') + out = np.empty(X.shape[0], dtype=np.float32) _predict_from_numeric_data(self.nodes, X, out) return out diff --git a/tests/test_gradient_boosting.py b/tests/test_gradient_boosting.py index b3ce0c2..067ed61 100644 --- a/tests/test_gradient_boosting.py +++ b/tests/test_gradient_boosting.py @@ -2,6 +2,7 @@ import warnings import numpy as np +from numpy.testing import assert_allclose import pytest from sklearn.utils.testing import assert_raises_regex from sklearn.datasets import make_classification, make_regression @@ -271,10 +272,8 @@ def test_pre_binned_data(): fit_binned_pred_binned = gbdt.fit(X_binned, y).predict(X_binned) fit_num_pred_binned = gbdt.fit(X, y).predict(X_binned) - np.testing.assert_array_almost_equal(fit_num_pred_num, - fit_binned_pred_binned) - np.testing.assert_array_almost_equal(fit_num_pred_num, - fit_num_pred_binned) + assert_allclose(fit_num_pred_num, fit_binned_pred_binned) + assert_allclose(fit_num_pred_num, fit_num_pred_binned) assert_raises_regex( ValueError, diff --git a/tests/test_predictor.py b/tests/test_predictor.py index 70cd86c..7adc9eb 100644 --- a/tests/test_predictor.py +++ b/tests/test_predictor.py @@ -70,7 +70,22 @@ def test_pre_binned_data(): assert_raises_regex( ValueError, 'This predictor does not have numerical thresholds', - predictor.predict, X_binned + predictor.predict, X + ) + + assert_raises_regex( + ValueError, + 'binned_data dtype should be uint8', + predictor.predict_binned, X ) - predictor.predict_binned(X) # No error + predictor.predict_binned(X_binned) # No error + + predictor = grower.make_predictor( + numerical_thresholds=mapper.numerical_thresholds_ + ) + assert_raises_regex( + ValueError, + 'X dtype should be float32 or float64', + predictor.predict, X_binned + ) From 648ad172bfea2723c186e99d66a357b572da36b9 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 18 Dec 2018 11:26:33 -0500 Subject: [PATCH 3/9] Addressed comments --- pygbm/gradient_boosting.py | 4 ++-- pygbm/grower.py | 4 ++-- pygbm/predictor.py | 6 +++++- tests/test_gradient_boosting.py | 2 +- tests/test_predictor.py | 2 +- tox.ini | 2 +- 6 files changed, 12 insertions(+), 8 deletions(-) diff --git a/pygbm/gradient_boosting.py b/pygbm/gradient_boosting.py index a0814f5..0e37918 100644 --- a/pygbm/gradient_boosting.py +++ b/pygbm/gradient_boosting.py @@ -73,8 +73,8 @@ def _validate_parameters(self, X): max_bin_index = X.max() if self.max_bins < max_bin_index + 1: raise ValueError( - f'Data is pre-binned and max_bins={self.max_bins}, ' - f'but data has {max_bin_index + 1} bins.' + f'max_bins is set to {self.max_bins} but the data is ' + f'pre-binned with {max_bin_index + 1} bins.' ) def fit(self, X, y): diff --git a/pygbm/grower.py b/pygbm/grower.py index 1ea3749..c77d000 100644 --- a/pygbm/grower.py +++ b/pygbm/grower.py @@ -419,8 +419,8 @@ def make_predictor(self, numerical_thresholds=None): Parameters ---------- numerical_thresholds : array-like of floats, optional (default=None) - The actual thresholds values of each bin. None if the training data - was pre-binned. + The actual thresholds values of each bin, expected to be in sorted + increasing order. None if the training data was pre-binned. Returns ------- diff --git a/pygbm/predictor.py b/pygbm/predictor.py index 8a93818..7a7fa7f 100644 --- a/pygbm/predictor.py +++ b/pygbm/predictor.py @@ -87,7 +87,11 @@ def predict(self, X): ) if X.dtype == np.uint8: - raise ValueError('X dtype should be float32 or float64') + raise ValueError( + 'X has uint8 dtype: use estimator.predict(X) if X is ' + 'pre-binned, or convert X to a float32 dtype to be treated ' + 'as numerical data' + ) out = np.empty(X.shape[0], dtype=np.float32) _predict_from_numeric_data(self.nodes, X, out) diff --git a/tests/test_gradient_boosting.py b/tests/test_gradient_boosting.py index 5aadf23..1b02929 100644 --- a/tests/test_gradient_boosting.py +++ b/tests/test_gradient_boosting.py @@ -74,7 +74,7 @@ def test_init_parameters_validation(GradientBoosting, X, y): assert_raises_regex( ValueError, - f"Data is pre-binned and max_bins=4, but data has 256 bins", + f"max_bins is set to 4 but the data is pre-binned with 256 bins", GradientBoosting(max_bins=4).fit, X.astype(np.uint8), y ) diff --git a/tests/test_predictor.py b/tests/test_predictor.py index 7adc9eb..be6bdad 100644 --- a/tests/test_predictor.py +++ b/tests/test_predictor.py @@ -86,6 +86,6 @@ def test_pre_binned_data(): ) assert_raises_regex( ValueError, - 'X dtype should be float32 or float64', + 'X has uint8 dtype', predictor.predict, X_binned ) diff --git a/tox.ini b/tox.ini index b8ed507..cf0701c 100644 --- a/tox.ini +++ b/tox.ini @@ -6,7 +6,7 @@ skip_missing_interpreters=True [testenv] deps = numpy - scipy + scipy == 1.1.0 scikit-learn numba pytest From 3826804ff344b549603eb34208bfb5b7bb184594 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 18 Dec 2018 11:35:44 -0500 Subject: [PATCH 4/9] Added reference to scipy version issue in tox.ini --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index cf0701c..6dce557 100644 --- a/tox.ini +++ b/tox.ini @@ -6,7 +6,7 @@ skip_missing_interpreters=True [testenv] deps = numpy - scipy == 1.1.0 + scipy == 1.1.0 # temporary fix for issue #82 scikit-learn numba pytest From 1e7075e436f2fd3a4453acb1db0b68365feacae3 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 18 Dec 2018 12:29:53 -0500 Subject: [PATCH 5/9] Trying tox with no whitespace, single = sign --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 6dce557..c144a52 100644 --- a/tox.ini +++ b/tox.ini @@ -6,7 +6,7 @@ skip_missing_interpreters=True [testenv] deps = numpy - scipy == 1.1.0 # temporary fix for issue #82 + scipy=1.1.0 # temporary fix for issue #82 scikit-learn numba pytest From 1f060086c002d8e40c840a33fe5511bcf1fdc6de Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 18 Dec 2018 13:03:11 -0500 Subject: [PATCH 6/9] double equal sign... ? --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index c144a52..07ccd77 100644 --- a/tox.ini +++ b/tox.ini @@ -6,7 +6,7 @@ skip_missing_interpreters=True [testenv] deps = numpy - scipy=1.1.0 # temporary fix for issue #82 + scipy==1.1.0 # temporary fix for issue #82 scikit-learn numba pytest From 688419d2e2d94f7ac9646ebe8cf92af126a38ed4 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 18 Dec 2018 13:11:06 -0500 Subject: [PATCH 7/9] sigh... --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 07ccd77..1c3b899 100644 --- a/tox.ini +++ b/tox.ini @@ -6,7 +6,7 @@ skip_missing_interpreters=True [testenv] deps = numpy - scipy==1.1.0 # temporary fix for issue #82 + scipy <= 1.1.0 # temporary fix for issue #82 scikit-learn numba pytest From 32eb845d87c76b48b62391b652b153d72ee645af Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 18 Dec 2018 13:19:29 -0500 Subject: [PATCH 8/9] removing comment ??????? --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 1c3b899..029fe24 100644 --- a/tox.ini +++ b/tox.ini @@ -6,7 +6,7 @@ skip_missing_interpreters=True [testenv] deps = numpy - scipy <= 1.1.0 # temporary fix for issue #82 + scipy <= 1.1.0 scikit-learn numba pytest From 056512dbfe3ec1be19ed69219d5bbf21e57607ff Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 18 Dec 2018 13:26:49 -0500 Subject: [PATCH 9/9] trying == again. Comment was causing error --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 029fe24..cf0701c 100644 --- a/tox.ini +++ b/tox.ini @@ -6,7 +6,7 @@ skip_missing_interpreters=True [testenv] deps = numpy - scipy <= 1.1.0 + scipy == 1.1.0 scikit-learn numba pytest