non-parallel split_indices that requires less memory
maartenbreddels committed Dec 18, 2018
1 parent 23925b7 commit d8585bb
Showing 5 changed files with 94 additions and 23 deletions.
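In brief: this commit threads a new parallel_splitting flag (default True) from the estimators through TreeGrower into SplittingContext, and adds a single-threaded split_indices_single_thread that partitions sample indices in place instead of going through the two scratch index buffers that the parallel version needs. A minimal usage sketch of the new flag on synthetic data (only the parallel_splitting keyword comes from this commit; the rest is illustrative):

import numpy as np
from pygbm import GradientBoostingRegressor

rng = np.random.RandomState(0)
X = rng.uniform(size=(1000, 5))
y = rng.uniform(size=1000)

# parallel_splitting=False selects the single-threaded splitter and
# skips allocating the two per-sample uint32 index buffers.
model = GradientBoostingRegressor(max_iter=10,
                                  parallel_splitting=False)
model.fit(X, y)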
2 changes: 1 addition & 1 deletion benchmarks/bench_higgs_boson.py
@@ -85,7 +85,7 @@ def load_data():
max_leaf_nodes=n_leaf_nodes,
n_iter_no_change=None,
random_state=0,
-verbose=1)
+verbose=1, parallel_splitting=False)
pygbm_model.fit(data_train, target_train)
toc = time()
predicted_test = pygbm_model.predict(data_test)
15 changes: 9 additions & 6 deletions pygbm/gradient_boosting.py
@@ -26,7 +26,7 @@ class BaseGradientBoostingMachine(BaseEstimator, ABC):
def __init__(self, loss, learning_rate, max_iter, max_leaf_nodes,
max_depth, min_samples_leaf, l2_regularization, max_bins,
scoring, validation_split, n_iter_no_change, tol, verbose,
-random_state):
+random_state, parallel_splitting):
self.loss = loss
self.learning_rate = learning_rate
self.max_iter = max_iter
@@ -41,6 +41,7 @@ def __init__(self, loss, learning_rate, max_iter, max_leaf_nodes,
self.tol = tol
self.verbose = verbose
self.random_state = random_state
+self.parallel_splitting = parallel_splitting

def _validate_parameters(self):
"""Validate parameters passed to __init__.
@@ -228,7 +229,8 @@ def fit(self, X, y):
max_depth=self.max_depth,
min_samples_leaf=self.min_samples_leaf,
l2_regularization=self.l2_regularization,
-shrinkage=self.learning_rate)
+shrinkage=self.learning_rate,
+parallel_splitting=self.parallel_splitting)
grower.grow()

acc_apply_split_time += grower.total_apply_split_time
@@ -495,15 +497,16 @@ def __init__(self, loss='least_squares', learning_rate=0.1,
max_iter=100, max_leaf_nodes=31, max_depth=None,
min_samples_leaf=20, l2_regularization=0., max_bins=256,
scoring=None, validation_split=0.1, n_iter_no_change=5,
-tol=1e-7, verbose=0, random_state=None):
+tol=1e-7, verbose=0, random_state=None,
+parallel_splitting=True):
super(GradientBoostingRegressor, self).__init__(
loss=loss, learning_rate=learning_rate, max_iter=max_iter,
max_leaf_nodes=max_leaf_nodes, max_depth=max_depth,
min_samples_leaf=min_samples_leaf,
l2_regularization=l2_regularization, max_bins=max_bins,
scoring=scoring, validation_split=validation_split,
n_iter_no_change=n_iter_no_change, tol=tol, verbose=verbose,
-random_state=random_state)
+random_state=random_state, parallel_splitting=parallel_splitting)

def predict(self, X):
"""Predict values for X.
@@ -614,15 +617,15 @@ def __init__(self, loss='auto', learning_rate=0.1, max_iter=100,
max_leaf_nodes=31, max_depth=None, min_samples_leaf=20,
l2_regularization=0., max_bins=256, scoring=None,
validation_split=0.1, n_iter_no_change=5, tol=1e-7,
-verbose=0, random_state=None):
+verbose=0, random_state=None, parallel_splitting=True):
super(GradientBoostingClassifier, self).__init__(
loss=loss, learning_rate=learning_rate, max_iter=max_iter,
max_leaf_nodes=max_leaf_nodes, max_depth=max_depth,
min_samples_leaf=min_samples_leaf,
l2_regularization=l2_regularization, max_bins=max_bins,
scoring=scoring, validation_split=validation_split,
n_iter_no_change=n_iter_no_change, tol=tol, verbose=verbose,
-random_state=random_state)
+random_state=random_state, parallel_splitting=parallel_splitting)

def predict(self, X):
"""Predict classes for X.
10 changes: 7 additions & 3 deletions pygbm/grower.py
@@ -8,7 +8,8 @@
import numpy as np
from time import time

-from .splitting import (SplittingContext, split_indices, find_node_split,
+from .splitting import (SplittingContext, split_indices_parallel,
+                        split_indices_single_thread, find_node_split,
find_node_split_subtraction)
from .predictor import TreePredictor, PREDICTOR_RECORD_DTYPE

@@ -163,7 +164,8 @@ class TreeGrower:
def __init__(self, X_binned, gradients, hessians, max_leaf_nodes=None,
max_depth=None, min_samples_leaf=20, min_gain_to_split=0.,
max_bins=256, n_bins_per_feature=None, l2_regularization=0.,
-min_hessian_to_split=1e-3, shrinkage=1.):
+min_hessian_to_split=1e-3, shrinkage=1.,
+parallel_splitting=True):

self._validate_parameters(X_binned, max_leaf_nodes, max_depth,
min_samples_leaf, min_gain_to_split,
@@ -180,13 +182,14 @@ def __init__(self, X_binned, gradients, hessians, max_leaf_nodes=None,
self.splitting_context = SplittingContext(
X_binned, max_bins, n_bins_per_feature, gradients,
hessians, l2_regularization, min_hessian_to_split,
-min_samples_leaf, min_gain_to_split)
+min_samples_leaf, min_gain_to_split, parallel_splitting)
self.max_leaf_nodes = max_leaf_nodes
self.max_depth = max_depth
self.min_samples_leaf = min_samples_leaf
self.X_binned = X_binned
self.min_gain_to_split = min_gain_to_split
self.shrinkage = shrinkage
+self.parallel_splitting = parallel_splitting
self.splittable_nodes = []
self.finalized_leaves = []
self.total_find_split_time = 0. # time spent finding the best splits
@@ -336,6 +339,7 @@ def split_next(self):
node = heappop(self.splittable_nodes)

tic = time()
+split_indices = split_indices_parallel if self.parallel_splitting else split_indices_single_thread
(sample_indices_left, sample_indices_right) = split_indices(
self.splitting_context, node.split_info, node.sample_indices)
toc = time()
69 changes: 63 additions & 6 deletions pygbm/splitting.py
@@ -6,7 +6,7 @@
into the newly created left and right children.
"""
import numpy as np
-from numba import njit, jitclass, prange, float32, uint8, uint32
+from numba import njit, jitclass, prange, float32, uint8, uint32, bool_
import numba

from .histogram import _build_histogram
@@ -88,6 +88,7 @@ def __init__(self, gain=-1., feature_idx=0, bin_idx=0,
('partition', uint32[::1]),
('left_indices_buffer', uint32[::1]),
('right_indices_buffer', uint32[::1]),
+('parallel_splitting', bool_)
])
class SplittingContext:
"""Pure data class defining a splitting context.
@@ -128,7 +129,7 @@ class SplittingContext:
def __init__(self, X_binned, max_bins, n_bins_per_feature,
gradients, hessians, l2_regularization,
min_hessian_to_split=1e-3, min_samples_leaf=20,
-min_gain_to_split=0.):
+min_gain_to_split=0., parallel_splitting=True):

self.X_binned = X_binned
self.n_features = X_binned.shape[1]
@@ -148,6 +149,7 @@ def __init__(self, X_binned, max_bins, n_bins_per_feature,
self.min_hessian_to_split = min_hessian_to_split
self.min_samples_leaf = min_samples_leaf
self.min_gain_to_split = min_gain_to_split
+self.parallel_splitting = parallel_splitting
if self.constant_hessian:
self.constant_hessian_value = self.hessians[0] # 1 scalar
else:
@@ -163,16 +165,18 @@ def __init__(self, X_binned, max_bins, n_bins_per_feature,
# we have 2 leaves, the left one is at position 0 and the second one at
# position 3. The order of the samples is irrelevant.
self.partition = np.arange(0, X_binned.shape[0], 1, np.uint32)
-# buffers used in split_indices to support parallel splitting.
-self.left_indices_buffer = np.empty_like(self.partition)
-self.right_indices_buffer = np.empty_like(self.partition)
+if self.parallel_splitting:
+    # buffers used in split_indices_parallel to support parallel
+    # splitting.
+    self.right_indices_buffer = np.empty_like(self.partition)
+    self.left_indices_buffer = np.empty_like(self.partition)
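Making these buffers conditional is where the memory saving comes from: the parallel splitter stages left/right assignments in them from multiple threads before copying back into partition, so each holds one uint32 per training sample for the life of the grower. A back-of-the-envelope estimate (my arithmetic, not a number from the commit):

# Two uint32 index buffers, one entry per training sample.
n_samples = 11_000_000           # roughly the Higgs boson dataset
bytes_saved = 2 * n_samples * 4  # 2 buffers * 4 bytes per uint32
print(bytes_saved / 1e6)         # ~88 MB avoided with parallel_splitting=False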


@njit(parallel=True,
locals={'sample_idx': uint32,
'left_count': uint32,
'right_count': uint32})
-def split_indices(context, split_info, sample_indices):
+def split_indices_parallel(context, split_info, sample_indices):
"""Split samples into left and right arrays.
Parameters
@@ -304,6 +308,59 @@ def split_indices(context, split_info, sample_indices):
return (sample_indices[:right_child_position],
sample_indices[right_child_position:])

+@njit(parallel=False)
+def split_indices_single_thread(context, split_info, sample_indices):
+    """Split samples into left and right arrays.
+
+    This implementation requires less memory than the parallel version.
+
+    Parameters
+    ----------
+    context : SplittingContext
+        The splitting context
+    split_info : SplitInfo
+        The SplitInfo of the node to split
+    sample_indices : array of int
+        The indices of the samples at the node to split. This is a view on
+        context.partition, and it is modified inplace by placing the indices
+        of the left child at the beginning, and the indices of the right child
+        at the end.
+
+    Returns
+    -------
+    left_indices : array of int
+        The indices of the samples in the left child. This is a view on
+        context.partition.
+    right_indices : array of int
+        The indices of the samples in the right child. This is a view on
+        context.partition.
+    """
+    X_binned = context.X_binned.T[split_info.feature_idx]
+    n_samples = sample_indices.shape[0]
+
+    # approach from the left with i
+    i = 0
+    # approach from the right with j
+    j = n_samples - 1
+    X = X_binned
+    pivot = split_info.bin_idx
+    while i != j:
+        # continue until we find an element that should be on the right
+        while X[sample_indices[i]] <= pivot and i < n_samples:
+            i += 1
+        # same, but now an element that should be on the left
+        while X[sample_indices[j]] > pivot and j >= 0:
+            j -= 1
+        if i >= j:  # j can become smaller than i!
+            break
+        else:
+            # swap
+            sample_indices[i], sample_indices[j] = sample_indices[j], sample_indices[i]
+            i += 1
+            j -= 1
+    return (sample_indices[:i],
+            sample_indices[i:])
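The loop above is a two-pointer, Hoare-partition-style pass over sample_indices: i advances past entries whose binned value is <= split_info.bin_idx, j retreats past entries that are greater, and mismatched pairs are swapped, so the split happens in place with no extra buffers. A standalone sketch of the same idea on toy NumPy data (hypothetical helper, not the jitted function; bounds are checked before indexing here for extra safety):

import numpy as np

def partition_in_place(binned_feature, indices, pivot):
    # After the call, indices[:i] hold bins <= pivot, indices[i:] the rest.
    i, j = 0, len(indices) - 1
    while i != j:
        while i < len(indices) and binned_feature[indices[i]] <= pivot:
            i += 1
        while j >= 0 and binned_feature[indices[j]] > pivot:
            j -= 1
        if i >= j:
            break
        indices[i], indices[j] = indices[j], indices[i]
        i += 1
        j -= 1
    return indices[:i], indices[i:]

binned = np.array([0, 5, 1, 3, 7, 2], dtype=np.uint8)
idx = np.arange(6, dtype=np.uint32)
left, right = partition_in_place(binned, idx, pivot=3)
print(left, right)  # bins of left are all <= 3, bins of right all > 3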


@njit(parallel=True)
def find_node_split(context, sample_indices):
21 changes: 14 additions & 7 deletions tests/test_splitting.py
@@ -6,7 +6,7 @@
from pygbm.splitting import _find_histogram_split
from pygbm.splitting import (SplittingContext, find_node_split,
find_node_split_subtraction,
-split_indices)
+split_indices_parallel, split_indices_single_thread)


@pytest.mark.parametrize('n_bins', [3, 32, 256])
@@ -39,7 +39,8 @@ def test_histogram_split(n_bins):
all_gradients, all_hessians,
l2_regularization,
min_hessian_to_split,
-min_samples_leaf, min_gain_to_split)
+min_samples_leaf, min_gain_to_split,
+True)

split_info, _ = _find_histogram_split(context, feature_idx,
sample_indices)
@@ -85,7 +86,7 @@ def test_split_vs_split_subtraction(constant_hessian):
n_bins_per_feature,
all_gradients, all_hessians,
l2_regularization, min_hessian_to_split,
-min_samples_leaf, min_gain_to_split)
+min_samples_leaf, min_gain_to_split, True)

mask = rng.randint(0, 2, n_samples).astype(np.bool)
sample_indices_left = sample_indices[mask]
@@ -165,7 +166,7 @@ def test_gradient_and_hessian_sanity(constant_hessian):
n_bins_per_feature,
all_gradients, all_hessians,
l2_regularization, min_hessian_to_split,
-min_samples_leaf, min_gain_to_split)
+min_samples_leaf, min_gain_to_split, True)

mask = rng.randint(0, 2, n_samples).astype(np.bool)
sample_indices_left = sample_indices[mask]
@@ -261,7 +262,7 @@ def test_split_indices():
n_bins_per_feature,
all_gradients, all_hessians,
l2_regularization, min_hessian_to_split,
-min_samples_leaf, min_gain_to_split)
+min_samples_leaf, min_gain_to_split, True)

assert_array_almost_equal(sample_indices, context.partition)
si_root, _ = find_node_split(context, sample_indices)
@@ -270,7 +271,7 @@
assert si_root.feature_idx == 1
assert si_root.bin_idx == 3

-samples_left, samples_right = split_indices(
+samples_left, samples_right = split_indices_parallel(
context, si_root, context.partition.view())
assert set(samples_left) == set([0, 1, 3, 4, 5, 6, 8])
assert set(samples_right) == set([2, 7, 9])
@@ -287,6 +288,12 @@
assert samples_left.shape[0] == si_root.n_samples_left
assert samples_right.shape[0] == si_root.n_samples_right

+samples_left_single_thread, samples_right_single_thread = split_indices_single_thread(
+    context, si_root, context.partition.view())
+
+assert samples_left.tolist() == samples_left_single_thread.tolist()
+assert samples_right.tolist() == samples_right_single_thread.tolist()
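Since both implementations are now importable, the same invariants could also be checked against each of them with a parametrized test; a hedged sketch in the style of this file (the test name and toy data are mine, not part of the commit):

import numpy as np
import pytest
from pygbm.splitting import (SplittingContext, find_node_split,
                             split_indices_parallel,
                             split_indices_single_thread)

@pytest.mark.parametrize('split_fn', [split_indices_parallel,
                                      split_indices_single_thread])
def test_splitters_partition_consistently(split_fn):
    rng = np.random.RandomState(42)
    n_samples, n_features, n_bins = 30, 2, 8
    X_binned = rng.randint(0, n_bins, size=(n_samples, n_features))
    X_binned = np.asfortranarray(X_binned.astype(np.uint8))
    all_gradients = rng.normal(size=n_samples).astype(np.float32)
    all_hessians = np.ones(1, dtype=np.float32)  # constant hessian
    n_bins_per_feature = np.array([n_bins] * n_features, dtype=np.uint32)
    # parallel_splitting=True so the context allocates the buffers that
    # split_indices_parallel needs; the single-threaded variant ignores them.
    context = SplittingContext(X_binned, n_bins, n_bins_per_feature,
                               all_gradients, all_hessians,
                               0., 1e-3, 1, 0., True)
    si_root, _ = find_node_split(context, context.partition.view())
    left, right = split_fn(context, si_root, context.partition.view())
    # Both variants must produce a clean two-way partition of all samples.
    assert len(left) + len(right) == n_samples
    assert set(left).isdisjoint(set(right))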


def test_min_gain_to_split():
# Try to split a pure node (all gradients are equal, same for hessians)
@@ -314,7 +321,7 @@ def test_min_gain_to_split():
all_gradients, all_hessians,
l2_regularization,
min_hessian_to_split,
-min_samples_leaf, min_gain_to_split)
+min_samples_leaf, min_gain_to_split, True)

split_info, _ = _find_histogram_split(context, feature_idx, sample_indices)
assert split_info.gain == -1
