From e6ea8a3b62df2811f3373d021f485d160471b9d0 Mon Sep 17 00:00:00 2001
From: Maarten Breddels <maartenbreddels@gmail.com>
Date: Tue, 18 Dec 2018 22:21:26 +0100
Subject: [PATCH 1/3] non-parallel split_indices that requires less memory

---
 benchmarks/bench_higgs_boson.py |  2 +-
 pygbm/gradient_boosting.py      | 15 ++++---
 pygbm/grower.py                 | 10 +++--
 pygbm/splitting.py              | 69 ++++++++++++++++++++++++++++++---
 tests/test_splitting.py         | 21 ++++++----
 5 files changed, 94 insertions(+), 23 deletions(-)

diff --git a/benchmarks/bench_higgs_boson.py b/benchmarks/bench_higgs_boson.py
index d7cb38e..b50132e 100644
--- a/benchmarks/bench_higgs_boson.py
+++ b/benchmarks/bench_higgs_boson.py
@@ -85,7 +85,7 @@ def load_data():
                                     max_leaf_nodes=n_leaf_nodes,
                                     n_iter_no_change=None,
                                     random_state=0,
-                                    verbose=1)
+                                    verbose=1, parallel_splitting=False)
 pygbm_model.fit(data_train, target_train)
 toc = time()
 predicted_test = pygbm_model.predict(data_test)
diff --git a/pygbm/gradient_boosting.py b/pygbm/gradient_boosting.py
index b70e4ce..2705970 100644
--- a/pygbm/gradient_boosting.py
+++ b/pygbm/gradient_boosting.py
@@ -26,7 +26,7 @@ class BaseGradientBoostingMachine(BaseEstimator, ABC):
     def __init__(self, loss, learning_rate, max_iter, max_leaf_nodes,
                  max_depth, min_samples_leaf, l2_regularization, max_bins,
                  scoring, validation_split, n_iter_no_change, tol, verbose,
-                 random_state):
+                 random_state, parallel_splitting):
         self.loss = loss
         self.learning_rate = learning_rate
         self.max_iter = max_iter
@@ -41,6 +41,7 @@ def __init__(self, loss, learning_rate, max_iter, max_leaf_nodes,
         self.tol = tol
         self.verbose = verbose
         self.random_state = random_state
+        self.parallel_splitting = parallel_splitting

     def _validate_parameters(self, X):
         """Validate parameters passed to __init__.
@@ -249,7 +250,8 @@ def fit(self, X, y):
                 max_depth=self.max_depth,
                 min_samples_leaf=self.min_samples_leaf,
                 l2_regularization=self.l2_regularization,
-                shrinkage=self.learning_rate)
+                shrinkage=self.learning_rate,
+                parallel_splitting=self.parallel_splitting)
             grower.grow()

             acc_apply_split_time += grower.total_apply_split_time
@@ -524,7 +526,8 @@ def __init__(self, loss='least_squares', learning_rate=0.1,
                  max_iter=100, max_leaf_nodes=31, max_depth=None,
                  min_samples_leaf=20, l2_regularization=0., max_bins=256,
                  scoring=None, validation_split=0.1, n_iter_no_change=5,
-                 tol=1e-7, verbose=0, random_state=None):
+                 tol=1e-7, verbose=0, random_state=None,
+                 parallel_splitting=True):
         super(GradientBoostingRegressor, self).__init__(
             loss=loss, learning_rate=learning_rate, max_iter=max_iter,
             max_leaf_nodes=max_leaf_nodes, max_depth=max_depth,
@@ -532,7 +535,7 @@ def __init__(self, loss='least_squares', learning_rate=0.1,
             l2_regularization=l2_regularization, max_bins=max_bins,
             scoring=scoring, validation_split=validation_split,
             n_iter_no_change=n_iter_no_change, tol=tol, verbose=verbose,
-            random_state=random_state)
+            random_state=random_state, parallel_splitting=parallel_splitting)

     def predict(self, X):
         """Predict values for X.
@@ -644,7 +647,7 @@ def __init__(self, loss='auto', learning_rate=0.1, max_iter=100,
                  max_leaf_nodes=31, max_depth=None, min_samples_leaf=20,
                  l2_regularization=0., max_bins=256, scoring=None,
                  validation_split=0.1, n_iter_no_change=5, tol=1e-7,
-                 verbose=0, random_state=None):
+                 verbose=0, random_state=None, parallel_splitting=True):
         super(GradientBoostingClassifier, self).__init__(
             loss=loss, learning_rate=learning_rate, max_iter=max_iter,
             max_leaf_nodes=max_leaf_nodes, max_depth=max_depth,
@@ -652,7 +655,7 @@ def __init__(self, loss='auto', learning_rate=0.1, max_iter=100,
             l2_regularization=l2_regularization, max_bins=max_bins,
             scoring=scoring, validation_split=validation_split,
             n_iter_no_change=n_iter_no_change, tol=tol, verbose=verbose,
-            random_state=random_state)
+            random_state=random_state, parallel_splitting=parallel_splitting)

     def predict(self, X):
         """Predict classes for X.
diff --git a/pygbm/grower.py b/pygbm/grower.py
index c77d000..7e8c543 100644
--- a/pygbm/grower.py
+++ b/pygbm/grower.py
@@ -8,7 +8,8 @@
 import numpy as np
 from time import time

-from .splitting import (SplittingContext, split_indices, find_node_split,
+from .splitting import (SplittingContext, split_indices_parallel,
+                        split_indices_single_thread, find_node_split,
                         find_node_split_subtraction)
 from .predictor import TreePredictor, PREDICTOR_RECORD_DTYPE

@@ -163,7 +164,8 @@ class TreeGrower:
     def __init__(self, X_binned, gradients, hessians, max_leaf_nodes=None,
                  max_depth=None, min_samples_leaf=20, min_gain_to_split=0.,
                  max_bins=256, n_bins_per_feature=None, l2_regularization=0.,
-                 min_hessian_to_split=1e-3, shrinkage=1.):
+                 min_hessian_to_split=1e-3, shrinkage=1.,
+                 parallel_splitting=True):

         self._validate_parameters(X_binned, max_leaf_nodes, max_depth,
                                   min_samples_leaf, min_gain_to_split,
@@ -180,13 +182,14 @@ def __init__(self, X_binned, gradients, hessians, max_leaf_nodes=None,
         self.splitting_context = SplittingContext(
             X_binned, max_bins, n_bins_per_feature, gradients,
             hessians, l2_regularization, min_hessian_to_split,
-            min_samples_leaf, min_gain_to_split)
+            min_samples_leaf, min_gain_to_split, parallel_splitting)
         self.max_leaf_nodes = max_leaf_nodes
         self.max_depth = max_depth
         self.min_samples_leaf = min_samples_leaf
         self.X_binned = X_binned
         self.min_gain_to_split = min_gain_to_split
         self.shrinkage = shrinkage
+        self.parallel_splitting = parallel_splitting
         self.splittable_nodes = []
         self.finalized_leaves = []
         self.total_find_split_time = 0.  # time spent finding the best splits
@@ -336,6 +339,7 @@ def split_next(self):
         node = heappop(self.splittable_nodes)

         tic = time()
+        split_indices = split_indices_parallel if self.parallel_splitting else split_indices_single_thread
         (sample_indices_left, sample_indices_right) = split_indices(
             self.splitting_context, node.split_info, node.sample_indices)
         toc = time()
diff --git a/pygbm/splitting.py b/pygbm/splitting.py
index 56ae412..b574816 100644
--- a/pygbm/splitting.py
+++ b/pygbm/splitting.py
@@ -6,7 +6,7 @@
 into the newly created left and right childs.
 """
 import numpy as np
-from numba import njit, jitclass, prange, float32, uint8, uint32
+from numba import njit, jitclass, prange, float32, uint8, uint32, bool_
 import numba

 from .histogram import _build_histogram
@@ -88,6 +88,7 @@ def __init__(self, gain=-1., feature_idx=0, bin_idx=0,
     ('partition', uint32[::1]),
     ('left_indices_buffer', uint32[::1]),
     ('right_indices_buffer', uint32[::1]),
+    ('parallel_splitting', bool_)
 ])
 class SplittingContext:
     """Pure data class defining a splitting context.
@@ -128,7 +129,7 @@ class SplittingContext:
     def __init__(self, X_binned, max_bins, n_bins_per_feature,
                  gradients, hessians, l2_regularization,
                  min_hessian_to_split=1e-3, min_samples_leaf=20,
-                 min_gain_to_split=0.):
+                 min_gain_to_split=0., parallel_splitting=True):
         self.X_binned = X_binned
         self.n_features = X_binned.shape[1]
@@ -148,6 +149,7 @@ def __init__(self, X_binned, max_bins, n_bins_per_feature,
         self.min_hessian_to_split = min_hessian_to_split
         self.min_samples_leaf = min_samples_leaf
         self.min_gain_to_split = min_gain_to_split
+        self.parallel_splitting = parallel_splitting
         if self.constant_hessian:
             self.constant_hessian_value = self.hessians[0]  # 1 scalar
         else:
@@ -163,16 +165,18 @@ def __init__(self, X_binned, max_bins, n_bins_per_feature,
         # we have 2 leaves, the left one is at position 0 and the second one at
         # position 3. The order of the samples is irrelevant.
         self.partition = np.arange(0, X_binned.shape[0], 1, np.uint32)
-        # buffers used in split_indices to support parallel splitting.
-        self.left_indices_buffer = np.empty_like(self.partition)
-        self.right_indices_buffer = np.empty_like(self.partition)
+        if self.parallel_splitting:
+            # buffers used in split_indices_parallel to support parallel
+            # splitting.
+            self.right_indices_buffer = np.empty_like(self.partition)
+            self.left_indices_buffer = np.empty_like(self.partition)


 @njit(parallel=True,
       locals={'sample_idx': uint32,
               'left_count': uint32,
               'right_count': uint32})
-def split_indices(context, split_info, sample_indices):
+def split_indices_parallel(context, split_info, sample_indices):
     """Split samples into left and right arrays.

     Parameters
@@ -304,6 +308,59 @@ def split_indices(context, split_info, sample_indices):
     return (sample_indices[:right_child_position],
             sample_indices[right_child_position:])


+@njit(parallel=False)
+def split_indices_single_thread(context, split_info, sample_indices):
+    """Split samples into left and right arrays.
+
+    This implementation requires less memory than the parallel version.
+
+    Parameters
+    ----------
+    context : SplittingContext
+        The splitting context
+    split_info : SplitInfo
+        The SplitInfo of the node to split
+    sample_indices : array of int
+        The indices of the samples at the node to split. This is a view on
+        context.partition, and it is modified inplace by placing the indices
+        of the left child at the beginning, and the indices of the right child
+        at the end.
+
+    Returns
+    -------
+    left_indices : array of int
+        The indices of the samples in the left child. This is a view on
+        context.partition.
+    right_indices : array of int
+        The indices of the samples in the right child. This is a view on
+        context.partition.
+    """
+    X_binned = context.X_binned.T[split_info.feature_idx]
+    n_samples = sample_indices.shape[0]
+
+    # approach from left with i
+    i = 0
+    # approach from right with j
+    j = n_samples - 1
+    X = X_binned
+    pivot = split_info.bin_idx
+    while i != j:
+        # continue until we find an element that should be on the right
+        while X[sample_indices[i]] <= pivot and i < n_samples:
+            i += 1
+        # same, but now an element that should be on the left
+        while X[sample_indices[j]] > pivot and j >= 0:
+            j -= 1
+        if i >= j:  # j can become smaller than i!
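+            # i and j crossed: sample_indices[:i] now holds the left child's
+            # indices and sample_indices[i:] the right child's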
+            break
+        else:
+            # swap
+            sample_indices[i], sample_indices[j] = sample_indices[j], sample_indices[i]
+            i += 1
+            j -= 1
+    return (sample_indices[:i],
+            sample_indices[i:])
+

 @njit(parallel=True)
 def find_node_split(context, sample_indices):
diff --git a/tests/test_splitting.py b/tests/test_splitting.py
index 3fc26db..d3e8848 100644
--- a/tests/test_splitting.py
+++ b/tests/test_splitting.py
@@ -6,7 +6,7 @@
 from pygbm.splitting import _find_histogram_split
 from pygbm.splitting import (SplittingContext, find_node_split,
                              find_node_split_subtraction,
-                             split_indices)
+                             split_indices_parallel, split_indices_single_thread)


 @pytest.mark.parametrize('n_bins', [3, 32, 256])
@@ -39,7 +39,8 @@ def test_histogram_split(n_bins):
                                      all_gradients, all_hessians,
                                      l2_regularization,
                                      min_hessian_to_split,
-                                     min_samples_leaf, min_gain_to_split)
+                                     min_samples_leaf, min_gain_to_split,
+                                     True)

             split_info, _ = _find_histogram_split(context, feature_idx,
                                                   sample_indices)
@@ -85,7 +86,7 @@ def test_split_vs_split_subtraction(constant_hessian):
                                n_bins_per_feature,
                                all_gradients, all_hessians,
                                l2_regularization, min_hessian_to_split,
-                               min_samples_leaf, min_gain_to_split)
+                               min_samples_leaf, min_gain_to_split, True)

     mask = rng.randint(0, 2, n_samples).astype(np.bool)
     sample_indices_left = sample_indices[mask]
@@ -165,7 +166,7 @@ def test_gradient_and_hessian_sanity(constant_hessian):
                                n_bins_per_feature,
                                all_gradients, all_hessians,
                                l2_regularization, min_hessian_to_split,
-                               min_samples_leaf, min_gain_to_split)
+                               min_samples_leaf, min_gain_to_split, True)

     mask = rng.randint(0, 2, n_samples).astype(np.bool)
     sample_indices_left = sample_indices[mask]
@@ -261,7 +262,7 @@ def test_split_indices():
                                n_bins_per_feature,
                                all_gradients, all_hessians,
                                l2_regularization, min_hessian_to_split,
-                               min_samples_leaf, min_gain_to_split)
+                               min_samples_leaf, min_gain_to_split, True)

     assert_array_almost_equal(sample_indices, context.partition)
     si_root, _ = find_node_split(context, sample_indices)
@@ -270,7 +271,7 @@ def test_split_indices():
     assert si_root.feature_idx == 1
     assert si_root.bin_idx == 3

-    samples_left, samples_right = split_indices(
+    samples_left, samples_right = split_indices_parallel(
         context, si_root, context.partition.view())
     assert set(samples_left) == set([0, 1, 3, 4, 5, 6, 8])
     assert set(samples_right) == set([2, 7, 9])
@@ -287,6 +288,12 @@ def test_split_indices():
     assert samples_left.shape[0] == si_root.n_samples_left
     assert samples_right.shape[0] == si_root.n_samples_right

+    samples_left_single_thread, samples_right_single_thread = split_indices_single_thread(
+        context, si_root, context.partition.view())
+
+    assert samples_left.tolist() == samples_left_single_thread.tolist()
+    assert samples_right.tolist() == samples_right_single_thread.tolist()
+

 def test_min_gain_to_split():
     # Try to split a pure node (all gradients are equal, same for hessians)
@@ -314,7 +321,7 @@ def test_min_gain_to_split():
                                    all_gradients, all_hessians,
                                    l2_regularization,
                                    min_hessian_to_split,
-                                   min_samples_leaf, min_gain_to_split)
+                                   min_samples_leaf, min_gain_to_split, True)
     split_info, _ = _find_histogram_split(context, feature_idx,
                                           sample_indices)
     assert split_info.gain == -1

From b46604736c8d4521a99e27b5f9574f7260fc9d14 Mon Sep 17 00:00:00 2001
From: Maarten Breddels <maartenbreddels@gmail.com>
Date: Wed, 19 Dec 2018 15:28:56 +0100
Subject: [PATCH 2/3] implemented suggestions

---
 pygbm/grower.py         | 10 ++---
 pygbm/splitting.py      | 90 +++++++++++++++--------------------------
 tests/test_splitting.py | 11 ++---
 3 files changed, 43 insertions(+), 68 deletions(-)
diff --git a/pygbm/grower.py b/pygbm/grower.py
index 7e8c543..b60f0da 100644
--- a/pygbm/grower.py
+++ b/pygbm/grower.py
@@ -8,8 +8,7 @@
 import numpy as np
 from time import time

-from .splitting import (SplittingContext, split_indices_parallel,
-                        split_indices_single_thread, find_node_split,
+from .splitting import (SplittingContext, find_node_split,
                         find_node_split_subtraction)
 from .predictor import TreePredictor, PREDICTOR_RECORD_DTYPE

@@ -189,7 +188,6 @@ def __init__(self, X_binned, gradients, hessians, max_leaf_nodes=None,
         self.X_binned = X_binned
         self.min_gain_to_split = min_gain_to_split
         self.shrinkage = shrinkage
-        self.parallel_splitting = parallel_splitting
         self.splittable_nodes = []
         self.finalized_leaves = []
         self.total_find_split_time = 0.  # time spent finding the best splits
@@ -339,9 +337,9 @@ def split_next(self):
         node = heappop(self.splittable_nodes)

         tic = time()
-        split_indices = split_indices_parallel if self.parallel_splitting else split_indices_single_thread
-        (sample_indices_left, sample_indices_right) = split_indices(
-            self.splitting_context, node.split_info, node.sample_indices)
+        (sample_indices_left, sample_indices_right) = \
+            self.splitting_context.split_indices(node.split_info,
+                                                 node.sample_indices)
         toc = time()
         node.apply_split_time = toc - tic
         self.total_apply_split_time += node.apply_split_time
diff --git a/pygbm/splitting.py b/pygbm/splitting.py
index b574816..d9f8e22 100644
--- a/pygbm/splitting.py
+++ b/pygbm/splitting.py
@@ -171,35 +171,39 @@ def __init__(self, X_binned, max_bins, n_bins_per_feature,
             self.right_indices_buffer = np.empty_like(self.partition)
             self.left_indices_buffer = np.empty_like(self.partition)

+    def split_indices(self, split_info, sample_indices):
+        """Split samples into left and right arrays.
+
+        Parameters
+        ----------
+        split_info : SplitInfo
+            The SplitInfo of the node to split
+        sample_indices : array of int
+            The indices of the samples at the node to split. This is a view on
+            context.partition, and it is modified inplace by placing the indices
+            of the left child at the beginning, and the indices of the right child
+            at the end.
+
+        Returns
+        -------
+        left_indices : array of int
+            The indices of the samples in the left child. This is a view on
+            context.partition.
+        right_indices : array of int
+            The indices of the samples in the right child. This is a view on
+            context.partition.
+        """
+        if self.parallel_splitting:
+            return _split_indices_parallel(self, split_info, sample_indices)
+        else:
+            return _split_indices_single_threaded(self, split_info, sample_indices)
+

 @njit(parallel=True,
       locals={'sample_idx': uint32,
               'left_count': uint32,
               'right_count': uint32})
-def split_indices_parallel(context, split_info, sample_indices):
-    """Split samples into left and right arrays.
-
-    Parameters
-    ----------
-    context : SplittingContext
-        The splitting context
-    split_info : SplitInfo
-        The SplitInfo of the node to split
-    sample_indices : array of int
-        The indices of the samples at the node to split. This is a view on
-        context.partition, and it is modified inplace by placing the indices
-        of the left child at the beginning, and the indices of the right child
-        at the end.
-
-    Returns
-    -------
-    left_indices : array of int
-        The indices of the samples in the left child. This is a view on
-        context.partition.
-    right_indices : array of int
-        The indices of the samples in the right child. This is a view on
-        context.partition.
-    """
+def _split_indices_parallel(context, split_info, sample_indices):
     # This is a multi-threaded implementation inspired by lightgbm.
     # Here is a quick breakdown. Let's suppose we want to split a node with
     # 24 samples named from a to x. context.partition looks like this (the *
@@ -309,47 +313,20 @@ def split_indices_parallel(context, split_info, sample_indices):
     return (sample_indices[:right_child_position],
             sample_indices[right_child_position:])


 @njit(parallel=False)
-def split_indices_single_thread(context, split_info, sample_indices):
-    """Split samples into left and right arrays.
-
-    This implementation requires less memory than the parallel version.
-
-    Parameters
-    ----------
-    context : SplittingContext
-        The splitting context
-    split_info : SplitInfo
-        The SplitInfo of the node to split
-    sample_indices : array of int
-        The indices of the samples at the node to split. This is a view on
-        context.partition, and it is modified inplace by placing the indices
-        of the left child at the beginning, and the indices of the right child
-        at the end.
-
-    Returns
-    -------
-    left_indices : array of int
-        The indices of the samples in the left child. This is a view on
-        context.partition.
-    right_indices : array of int
-        The indices of the samples in the right child. This is a view on
-        context.partition.
-    """
-    X_binned = context.X_binned.T[split_info.feature_idx]
+def _split_indices_single_threaded(context, split_info, sample_indices):
+    binned_feature = context.X_binned.T[split_info.feature_idx]
     n_samples = sample_indices.shape[0]

-    # approach from left with i
     i = 0
     # approach from right with j
     j = n_samples - 1
-    X = X_binned
     pivot = split_info.bin_idx
     while i != j:
         # continue until we find an element that should be on the right
-        while X[sample_indices[i]] <= pivot and i < n_samples:
+        while binned_feature[sample_indices[i]] <= pivot and i < n_samples:
             i += 1
         # same, but now an element that should be on the left
-        while X[sample_indices[j]] > pivot and j >= 0:
+        while binned_feature[sample_indices[j]] > pivot and j >= 0:
             j -= 1
         if i >= j:  # j can become smaller than i!
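             # i and j crossed: sample_indices[:i] now holds the left child's
             # indices and sample_indices[i:] the right child's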
             break
@@ -358,8 +335,7 @@ def split_indices_single_thread(context, split_info, sample_indices):
             sample_indices[i], sample_indices[j] = sample_indices[j], sample_indices[i]
             i += 1
             j -= 1
-    return (sample_indices[:i],
-            sample_indices[i:])
+    return (sample_indices[:i], sample_indices[i:])


 @njit(parallel=True)
 def find_node_split(context, sample_indices):
diff --git a/tests/test_splitting.py b/tests/test_splitting.py
index d3e8848..8dbbda6 100644
--- a/tests/test_splitting.py
+++ b/tests/test_splitting.py
@@ -6,7 +6,7 @@
 from pygbm.splitting import _find_histogram_split
 from pygbm.splitting import (SplittingContext, find_node_split,
                              find_node_split_subtraction,
-                             split_indices_parallel, split_indices_single_thread)
+                             _split_indices_parallel, _split_indices_single_threaded)


 @pytest.mark.parametrize('n_bins', [3, 32, 256])
@@ -271,8 +271,8 @@ def test_split_indices():
     assert si_root.feature_idx == 1
     assert si_root.bin_idx == 3

-    samples_left, samples_right = split_indices_parallel(
-        context, si_root, context.partition.view())
+    samples_left, samples_right = context.split_indices(
+        si_root, context.partition.view())
     assert set(samples_left) == set([0, 1, 3, 4, 5, 6, 8])
     assert set(samples_right) == set([2, 7, 9])
@@ -288,8 +288,9 @@ def test_split_indices():
     assert samples_left.shape[0] == si_root.n_samples_left
     assert samples_right.shape[0] == si_root.n_samples_right

-    samples_left_single_thread, samples_right_single_thread = split_indices_single_thread(
-        context, si_root, context.partition.view())
+    # test that the single-threaded version gives the same result
+    samples_left_single_thread, samples_right_single_thread = \
+        _split_indices_single_threaded(context, si_root, context.partition.view())

     assert samples_left.tolist() == samples_left_single_thread.tolist()
     assert samples_right.tolist() == samples_right_single_thread.tolist()

From e84cdc31e035f6b2872fcd49637af499ce40f932 Mon Sep 17 00:00:00 2001
From: Maarten Breddels <maartenbreddels@gmail.com>
Date: Wed, 19 Dec 2018 15:29:27 +0100
Subject: [PATCH 3/3] add --no-parallel-split argument to Higgs benchmark

---
 benchmarks/bench_higgs_boson.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/benchmarks/bench_higgs_boson.py b/benchmarks/bench_higgs_boson.py
index b50132e..93110c0 100644
--- a/benchmarks/bench_higgs_boson.py
+++ b/benchmarks/bench_higgs_boson.py
@@ -21,6 +21,7 @@
 parser.add_argument('--learning-rate', type=float, default=1.)
 parser.add_argument('--subsample', type=int, default=None)
 parser.add_argument('--max-bins', type=int, default=255)
+parser.add_argument('--no-parallel-split', action="store_true", default=False)
 args = parser.parse_args()

 HERE = os.path.dirname(__file__)
@@ -85,7 +86,8 @@ def load_data():
                                     max_leaf_nodes=n_leaf_nodes,
                                     n_iter_no_change=None,
                                     random_state=0,
-                                    verbose=1, parallel_splitting=False)
+                                    verbose=1,
+                                    parallel_splitting=not args.no_parallel_split)
 pygbm_model.fit(data_train, target_train)
 toc = time()
 predicted_test = pygbm_model.predict(data_test)
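---

A minimal usage sketch of the new flag (an illustration, not part of the patch series; `data_train` and `target_train` are placeholders for any training set, such as the one loaded by the Higgs benchmark):

    from pygbm import GradientBoostingClassifier

    # parallel_splitting=False selects the single-threaded splitter from
    # PATCH 1: it partitions context.partition in place with two pointers,
    # so it never allocates left_indices_buffer/right_indices_buffer, two
    # uint32 arrays with one entry per training sample.
    # NOTE: data_train/target_train are hypothetical placeholders.
    model = GradientBoostingClassifier(max_iter=100, random_state=0, verbose=1,
                                       parallel_splitting=False)
    model.fit(data_train, target_train)

The saving matters most on large datasets: with the benchmark's roughly 11M-row Higgs data, each uint32 buffer alone is on the order of 44 MB, and the two-pointer splitter skips both.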