Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Accept pre-binned data in fit #74

Merged
merged 10 commits into from
Dec 18, 2018
11 changes: 6 additions & 5 deletions pygbm/binning.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ def _find_binning_thresholds(data, max_bins=256, subsample=int(2e5),
random_state=None):
"""Extract feature-wise equally-spaced quantiles from numerical data


Return
------
binning_thresholds: tuple of arrays
Expand Down Expand Up @@ -152,13 +151,15 @@ def fit(self, X, y=None):
self : object
"""
X = check_array(X)
self.bin_thresholds_ = _find_binning_thresholds(
self.numerical_thresholds_ = _find_binning_thresholds(
X, self.max_bins, subsample=self.subsample,
random_state=self.random_state)

self.n_bins_per_feature_ = np.array(
[thresholds.shape[0] + 1 for thresholds in self.bin_thresholds_],
dtype=np.uint32)
[thresholds.shape[0] + 1
for thresholds in self.numerical_thresholds_],
dtype=np.uint32
)

return self

Expand All @@ -175,4 +176,4 @@ def transform(self, X):
X_binned : array-like
The binned data
"""
return _map_to_bins(X, binning_thresholds=self.bin_thresholds_)
return _map_to_bins(X, binning_thresholds=self.numerical_thresholds_)
80 changes: 56 additions & 24 deletions pygbm/gradient_boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(self, loss, learning_rate, max_iter, max_leaf_nodes,
self.verbose = verbose
self.random_state = random_state

def _validate_parameters(self):
def _validate_parameters(self, X):
"""Validate parameters passed to __init__.

The parameters that are directly passed to the grower are checked in
Expand All @@ -69,14 +69,24 @@ def _validate_parameters(self):
if self.tol is not None and self.tol < 0:
raise ValueError(f'tol={self.tol} '
f'must not be smaller than 0.')
if X.dtype == np.uint8: # pre-binned data
max_bin_index = X.max()
if self.max_bins < max_bin_index + 1:
raise ValueError(
f'max_bins is set to {self.max_bins} but the data is '
f'pre-binned with {max_bin_index + 1} bins.'
)

def fit(self, X, y):
"""Fit the gradient boosting model.

Parameters
----------
X : array-like, shape=(n_samples, n_features)
The input samples.
The input samples. If ``X.dtype == np.uint8``, the data is
assumed to be pre-binned and the prediction methods
(``predict``, ``predict_proba``) will only accept pre-binned
data as well.

y : array-like, shape=(n_samples,)
Target values.
Expand All @@ -93,8 +103,7 @@ def fit(self, X, y):
acc_prediction_time = 0.
# TODO: add support for mixed-typed (numerical + categorical) data
# TODO: add support for missing data
# TODO: add support for pre-binned data (pass-through)?
X, y = check_X_y(X, y, dtype=[np.float32, np.float64])
X, y = check_X_y(X, y, dtype=[np.float32, np.float64, np.uint8])
y = self._encode_y(y)
if X.shape[0] == 1 or X.shape[1] == 1:
raise ValueError(
Expand All @@ -103,20 +112,32 @@ def fit(self, X, y):
)
rng = check_random_state(self.random_state)

self._validate_parameters()
self._validate_parameters(X)
self.n_features_ = X.shape[1] # used for validation in predict()

if self.verbose:
print(f"Binning {X.nbytes / 1e9:.3f} GB of data: ", end="",
flush=True)
tic = time()
self.bin_mapper_ = BinMapper(max_bins=self.max_bins, random_state=rng)
X_binned = self.bin_mapper_.fit_transform(X)
toc = time()
if self.verbose:
duration = toc - tic
troughput = X.nbytes / duration
print(f"{duration:.3f} s ({troughput / 1e6:.3f} MB/s)")
if X.dtype == np.uint8: # data is pre-binned
if self.verbose:
print("X is pre-binned.")
X_binned = X
self.bin_mapper_ = None
numerical_thresholds = None
n_bins_per_feature = X.max(axis=0).astype(np.uint32)
else:
if self.verbose:
print(f"Binning {X.nbytes / 1e9:.3f} GB of data: ", end="",
flush=True)
tic = time()
self.bin_mapper_ = BinMapper(max_bins=self.max_bins,
random_state=rng)
X_binned = self.bin_mapper_.fit_transform(X)
numerical_thresholds = self.bin_mapper_.numerical_thresholds_
n_bins_per_feature = self.bin_mapper_.n_bins_per_feature_
toc = time()

if self.verbose:
duration = toc - tic
throughput = X.nbytes / duration
print(f"{duration:.3f} s ({throughput / 1e6:.3f} MB/s)")

self.loss_ = self._get_loss()

Expand Down Expand Up @@ -220,7 +241,7 @@ def fit(self, X, y):
grower = TreeGrower(
X_binned_train, gradients_at_k, hessians_at_k,
max_bins=self.max_bins,
n_bins_per_feature=self.bin_mapper_.n_bins_per_feature_,
n_bins_per_feature=n_bins_per_feature,
max_leaf_nodes=self.max_leaf_nodes,
max_depth=self.max_depth,
min_samples_leaf=self.min_samples_leaf,
Expand All @@ -231,8 +252,7 @@ def fit(self, X, y):
acc_apply_split_time += grower.total_apply_split_time
acc_find_split_time += grower.total_find_split_time

predictor = grower.make_predictor(
bin_thresholds=self.bin_mapper_.bin_thresholds_)
predictor = grower.make_predictor(numerical_thresholds)
predictors[-1].append(predictor)

tic_pred = time()
Expand Down Expand Up @@ -371,7 +391,8 @@ def _raw_predict(self, X):
----------
X : array-like, shape=(n_samples, n_features)
The input samples. If ``X.dtype == np.uint8``, the data is assumed
to be pre-binned.
to be pre-binned and the estimator must have been fitted with
pre-binned data.

Returns
-------
Expand All @@ -385,14 +406,22 @@ def _raw_predict(self, X):
f'X has {X.shape[1]} features but this estimator was '
f'trained with {self.n_features_} features.'
)
is_binned = X.dtype == np.uint8
if not is_binned and self.bin_mapper_ is None:
raise ValueError(
'This estimator was fitted with pre-binned data and '
'can only predict pre-binned data as well. If your data *is* '
'already pre-binnned, convert it to uint8 using e.g. '
'X.astype(np.uint8). If the data passed to fit() was *not* '
'pre-binned, convert it to float32 and call fit() again.'
)
n_samples = X.shape[0]
raw_predictions = np.zeros(
shape=(n_samples, self.n_trees_per_iteration_),
dtype=self.baseline_prediction_.dtype
)
raw_predictions += self.baseline_prediction_
# Should we parallelize this?
is_binned = X.dtype == np.uint8
for predictors_of_ith_iteration in self.predictors_:
for k, predictor in enumerate(predictors_of_ith_iteration):
predict = (predictor.predict_binned if is_binned
Expand Down Expand Up @@ -509,7 +538,8 @@ def predict(self, X):
----------
X : array-like, shape=(n_samples, n_features)
The input samples. If ``X.dtype == np.uint8``, the data is assumed
to be pre-binned.
to be pre-binned and the estimator must have been fitted with
pre-binned data.

Returns
-------
Expand Down Expand Up @@ -628,7 +658,8 @@ def predict(self, X):
----------
X : array-like, shape=(n_samples, n_features)
The input samples. If ``X.dtype == np.uint8``, the data is assumed
to be pre-binned.
to be pre-binned and the estimator must have been fitted with
pre-binned data.

Returns
-------
Expand All @@ -646,7 +677,8 @@ def predict_proba(self, X):
----------
X : array-like, shape=(n_samples, n_features)
The input samples. If ``X.dtype == np.uint8``, the data is assumed
to be pre-binned.
to be pre-binned and the estimator must have been fitted with
pre-binned data.

Returns
-------
Expand Down
30 changes: 18 additions & 12 deletions pygbm/grower.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,25 +413,30 @@ def _finalize_splittable_nodes(self):
node = self.splittable_nodes.pop()
self._finalize_leaf(node)

def make_predictor(self, bin_thresholds=None):
def make_predictor(self, numerical_thresholds=None):
"""Make a TreePredictor object out of the current tree.

Parameters
----------
bin_thresholds : array-like of floats, optional (default=None)
The actual thresholds values of each bin.
numerical_thresholds : array-like of floats, optional (default=None)
The actual thresholds values of each bin, expected to be in sorted
increasing order. None if the training data was pre-binned.

Returns
-------
A TreePredictor object.
"""
predictor_nodes = np.zeros(self.n_nodes, dtype=PREDICTOR_RECORD_DTYPE)
self._fill_predictor_node_array(predictor_nodes, self.root,
bin_thresholds=bin_thresholds)
return TreePredictor(predictor_nodes)
self._fill_predictor_node_array(
predictor_nodes, self.root,
numerical_thresholds=numerical_thresholds
)
has_numerical_thresholds = numerical_thresholds is not None
return TreePredictor(nodes=predictor_nodes,
has_numerical_thresholds=has_numerical_thresholds)

def _fill_predictor_node_array(self, predictor_nodes, grower_node,
bin_thresholds=None, next_free_idx=0):
numerical_thresholds=None, next_free_idx=0):
"""Helper used in make_predictor to set the TreePredictor fields."""
node = predictor_nodes[next_free_idx]
node['count'] = grower_node.n_samples
Expand All @@ -452,17 +457,18 @@ def _fill_predictor_node_array(self, predictor_nodes, grower_node,
feature_idx, bin_idx = split_info.feature_idx, split_info.bin_idx
node['feature_idx'] = feature_idx
node['bin_threshold'] = bin_idx
if bin_thresholds is not None:
threshold = bin_thresholds[feature_idx][bin_idx]
node['threshold'] = threshold
if numerical_thresholds is not None:
node['threshold'] = numerical_thresholds[feature_idx][bin_idx]
next_free_idx += 1

node['left'] = next_free_idx
next_free_idx = self._fill_predictor_node_array(
predictor_nodes, grower_node.left_child,
bin_thresholds=bin_thresholds, next_free_idx=next_free_idx)
numerical_thresholds=numerical_thresholds,
next_free_idx=next_free_idx)

node['right'] = next_free_idx
return self._fill_predictor_node_array(
predictor_nodes, grower_node.right_child,
bin_thresholds=bin_thresholds, next_free_idx=next_free_idx)
numerical_thresholds=numerical_thresholds,
next_free_idx=next_free_idx)
21 changes: 20 additions & 1 deletion pygbm/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ class TreePredictor:
nodes : list of PREDICTOR_RECORD_DTYPE.
The nodes of the tree.
"""
def __init__(self, nodes):
def __init__(self, nodes, has_numerical_thresholds=True):
self.nodes = nodes
self.has_numerical_thresholds = has_numerical_thresholds

def get_n_leaf_nodes(self):
"""Return number of leaves."""
Expand All @@ -54,6 +55,10 @@ def predict_binned(self, binned_data, out=None):
y : array, shape (n_samples,)
The raw predicted values.
"""

if binned_data.dtype != np.uint8:
raise ValueError('binned_data dtype should be uint8')

if out is None:
out = np.empty(binned_data.shape[0], dtype=np.float32)
_predict_binned(self.nodes, binned_data, out)
Expand All @@ -74,6 +79,20 @@ def predict(self, X):
"""
# TODO: introspect X to dispatch to numerical or categorical data
# (dense or sparse) on a feature by feature basis.

if not self.has_numerical_thresholds:
raise ValueError(
'This predictor does not have numerical thresholds so it can'
'only predict pre-binned data.'
)

if X.dtype == np.uint8:
raise ValueError(
'X has uint8 dtype: use estimator.predict(X) if X is '
'pre-binned, or convert X to a float32 dtype to be treated '
'as numerical data'
)

out = np.empty(X.shape[0], dtype=np.float32)
_predict_from_numeric_data(self.nodes, X, out)
return out
Expand Down
15 changes: 8 additions & 7 deletions tests/test_binning.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,10 @@ def test_bin_mapper_random_data(n_bins):
assert binned.dtype == np.uint8
assert_array_equal(binned.min(axis=0), np.array([0, 0]))
assert_array_equal(binned.max(axis=0), np.array([n_bins - 1, n_bins - 1]))
assert len(mapper.bin_thresholds_) == n_features
for i in range(len(mapper.bin_thresholds_)):
assert mapper.bin_thresholds_[i].shape == (n_bins - 1,)
assert mapper.bin_thresholds_[i].dtype == DATA.dtype
assert len(mapper.numerical_thresholds_) == n_features
for i in range(len(mapper.numerical_thresholds_)):
assert mapper.numerical_thresholds_[i].shape == (n_bins - 1,)
assert mapper.numerical_thresholds_[i].dtype == DATA.dtype
assert np.all(mapper.n_bins_per_feature_ == n_bins)

# Check that the binned data is approximately balanced across bins.
Expand Down Expand Up @@ -159,7 +159,8 @@ def test_bin_mapper_repeated_values_invariance(n_distinct):
mapper_2 = BinMapper(max_bins=min(256, n_distinct * 3))
binned_2 = mapper_2.fit_transform(data)

assert_allclose(mapper_1.bin_thresholds_[0], mapper_2.bin_thresholds_[0])
assert_allclose(mapper_1.numerical_thresholds_[0],
mapper_2.numerical_thresholds_[0])
assert_array_equal(binned_1, binned_2)


Expand Down Expand Up @@ -214,7 +215,7 @@ def test_subsample():
for feature in range(DATA.shape[1]):
with pytest.raises(AssertionError):
np.testing.assert_array_almost_equal(
mapper_no_subsample.bin_thresholds_[feature],
mapper_subsample.bin_thresholds_[feature],
mapper_no_subsample.numerical_thresholds_[feature],
mapper_subsample.numerical_thresholds_[feature],
decimal=3
)
12 changes: 9 additions & 3 deletions tests/test_compare_lightgbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ def test_same_predictions_regression(seed, min_samples_leaf, n_samples,
n_informative=5, random_state=0)

if n_samples > 255:
X = BinMapper(max_bins=max_bins).fit_transform(X)
# bin data and convert it to float32 so that the estimator doesn't
# treat it as pre-binned
X = BinMapper(max_bins=max_bins).fit_transform(X).astype(np.float32)

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=rng)

Expand Down Expand Up @@ -94,7 +96,9 @@ def test_same_predictions_classification(seed, min_samples_leaf, n_samples,
n_informative=5, n_redundant=0, random_state=0)

if n_samples > 255:
X = BinMapper(max_bins=max_bins).fit_transform(X)
# bin data and convert it to float32 so that the estimator doesn't
# treat it as pre-binned
X = BinMapper(max_bins=max_bins).fit_transform(X).astype(np.float32)

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=rng)

Expand Down Expand Up @@ -153,7 +157,9 @@ def test_same_predictions_multiclass_classification(
n_clusters_per_class=1, random_state=0)

if n_samples > 255:
X = BinMapper(max_bins=max_bins).fit_transform(X)
# bin data and convert it to float32 so that the estimator doesn't
# treat it as pre-binned
X = BinMapper(max_bins=max_bins).fit_transform(X).astype(np.float32)

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=rng)

Expand Down
Loading