-
Notifications
You must be signed in to change notification settings - Fork 32
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Single threaded split #85
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -6,7 +6,7 @@ | |||||||
into the newly created left and right childs. | ||||||||
""" | ||||||||
import numpy as np | ||||||||
from numba import njit, jitclass, prange, float32, uint8, uint32 | ||||||||
from numba import njit, jitclass, prange, float32, uint8, uint32, bool_ | ||||||||
import numba | ||||||||
|
||||||||
from .histogram import _build_histogram | ||||||||
|
@@ -88,6 +88,7 @@ def __init__(self, gain=-1., feature_idx=0, bin_idx=0, | |||||||
('partition', uint32[::1]), | ||||||||
('left_indices_buffer', uint32[::1]), | ||||||||
('right_indices_buffer', uint32[::1]), | ||||||||
('parallel_splitting', bool_) | ||||||||
]) | ||||||||
class SplittingContext: | ||||||||
"""Pure data class defining a splitting context. | ||||||||
|
@@ -128,7 +129,7 @@ class SplittingContext: | |||||||
def __init__(self, X_binned, max_bins, n_bins_per_feature, | ||||||||
gradients, hessians, l2_regularization, | ||||||||
min_hessian_to_split=1e-3, min_samples_leaf=20, | ||||||||
min_gain_to_split=0.): | ||||||||
min_gain_to_split=0., parallel_splitting=True): | ||||||||
|
||||||||
self.X_binned = X_binned | ||||||||
self.n_features = X_binned.shape[1] | ||||||||
|
@@ -148,6 +149,7 @@ def __init__(self, X_binned, max_bins, n_bins_per_feature, | |||||||
self.min_hessian_to_split = min_hessian_to_split | ||||||||
self.min_samples_leaf = min_samples_leaf | ||||||||
self.min_gain_to_split = min_gain_to_split | ||||||||
self.parallel_splitting = parallel_splitting | ||||||||
if self.constant_hessian: | ||||||||
self.constant_hessian_value = self.hessians[0] # 1 scalar | ||||||||
else: | ||||||||
|
@@ -163,39 +165,45 @@ def __init__(self, X_binned, max_bins, n_bins_per_feature, | |||||||
# we have 2 leaves, the left one is at position 0 and the second one at | ||||||||
# position 3. The order of the samples is irrelevant. | ||||||||
self.partition = np.arange(0, X_binned.shape[0], 1, np.uint32) | ||||||||
# buffers used in split_indices to support parallel splitting. | ||||||||
self.left_indices_buffer = np.empty_like(self.partition) | ||||||||
self.right_indices_buffer = np.empty_like(self.partition) | ||||||||
if self.parallel_splitting: | ||||||||
# buffers used in split_indices_parallel to support parallel | ||||||||
# splitting. | ||||||||
self.right_indices_buffer = np.empty_like(self.partition) | ||||||||
self.left_indices_buffer = np.empty_like(self.partition) | ||||||||
|
||||||||
def split_indices(self, split_info, sample_indices): | ||||||||
"""Split samples into left and right arrays. | ||||||||
|
||||||||
Parameters | ||||||||
---------- | ||||||||
split_info : SplitInfo | ||||||||
The SplitInfo of the node to split | ||||||||
sample_indices : array of int | ||||||||
The indices of the samples at the node to split. This is a view on | ||||||||
context.partition, and it is modified inplace by placing the indices | ||||||||
of the left child at the beginning, and the indices of the right child | ||||||||
at the end. | ||||||||
|
||||||||
Returns | ||||||||
------- | ||||||||
left_indices : array of int | ||||||||
The indices of the samples in the left child. This is a view on | ||||||||
context.partition. | ||||||||
right_indices : array of int | ||||||||
The indices of the samples in the right child. This is a view on | ||||||||
context.partition. | ||||||||
""" | ||||||||
if self.parallel_splitting: | ||||||||
return _split_indices_parallel(self, split_info, sample_indices) | ||||||||
else: | ||||||||
return _split_indices_single_threaded(self, split_info, sample_indices) | ||||||||
|
||||||||
|
||||||||
@njit(parallel=True, | ||||||||
locals={'sample_idx': uint32, | ||||||||
'left_count': uint32, | ||||||||
'right_count': uint32}) | ||||||||
def split_indices(context, split_info, sample_indices): | ||||||||
"""Split samples into left and right arrays. | ||||||||
|
||||||||
Parameters | ||||||||
---------- | ||||||||
context : SplittingContext | ||||||||
The splitting context | ||||||||
split_ingo : SplitInfo | ||||||||
The SplitInfo of the node to split | ||||||||
sample_indices : array of int | ||||||||
The indices of the samples at the node to split. This is a view on | ||||||||
context.partition, and it is modified inplace by placing the indices | ||||||||
of the left child at the beginning, and the indices of the right child | ||||||||
at the end. | ||||||||
|
||||||||
Returns | ||||||||
------- | ||||||||
left_indices : array of int | ||||||||
The indices of the samples in the left child. This is a view on | ||||||||
context.partition. | ||||||||
right_indices : array of int | ||||||||
The indices of the samples in the right child. This is a view on | ||||||||
context.partition. | ||||||||
""" | ||||||||
def _split_indices_parallel(context, split_info, sample_indices): | ||||||||
# This is a multi-threaded implementation inspired by lightgbm. | ||||||||
# Here is a quick break down. Let's suppose we want to split a node with | ||||||||
# 24 samples named from a to x. context.partition looks like this (the * | ||||||||
|
@@ -304,6 +312,31 @@ def split_indices(context, split_info, sample_indices): | |||||||
return (sample_indices[:right_child_position], | ||||||||
sample_indices[right_child_position:]) | ||||||||
|
||||||||
@njit(parallel=False) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can just use
Suggested change
|
||||||||
def _split_indices_single_threaded(context, split_info, sample_indices): | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
binned_feature = context.X_binned.T[split_info.feature_idx] | ||||||||
n_samples = sample_indices.shape[0] | ||||||||
# approach from left with i | ||||||||
i = 0 | ||||||||
# approach from right with j | ||||||||
j = n_samples - 1 | ||||||||
pivot = split_info.bin_idx | ||||||||
while i != j: | ||||||||
# continue until we find an element that should be on right | ||||||||
while binned_feature[sample_indices[i]] <= pivot and i < n_samples: | ||||||||
i += 1 | ||||||||
# same, but now an element that should be on the left | ||||||||
while binned_feature[sample_indices[j]] > pivot and j >= 0: | ||||||||
j -= 1 | ||||||||
if i >= j: # j can become smaller than j! | ||||||||
break | ||||||||
else: | ||||||||
# swap | ||||||||
sample_indices[i], sample_indices[j] = sample_indices[j], sample_indices[i] | ||||||||
i += 1 | ||||||||
j -= 1 | ||||||||
return (sample_indices[:i], sample_indices[i:]) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no need for parenthesis |
||||||||
|
||||||||
|
||||||||
@njit(parallel=True) | ||||||||
def find_node_split(context, sample_indices): | ||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,7 @@ | |
from pygbm.splitting import _find_histogram_split | ||
from pygbm.splitting import (SplittingContext, find_node_split, | ||
find_node_split_subtraction, | ||
split_indices) | ||
_split_indices_parallel, _split_indices_single_threaded) | ||
|
||
|
||
@pytest.mark.parametrize('n_bins', [3, 32, 256]) | ||
|
@@ -39,7 +39,8 @@ def test_histogram_split(n_bins): | |
all_gradients, all_hessians, | ||
l2_regularization, | ||
min_hessian_to_split, | ||
min_samples_leaf, min_gain_to_split) | ||
min_samples_leaf, min_gain_to_split, | ||
True) | ||
|
||
split_info, _ = _find_histogram_split(context, feature_idx, | ||
sample_indices) | ||
|
@@ -85,7 +86,7 @@ def test_split_vs_split_subtraction(constant_hessian): | |
n_bins_per_feature, | ||
all_gradients, all_hessians, | ||
l2_regularization, min_hessian_to_split, | ||
min_samples_leaf, min_gain_to_split) | ||
min_samples_leaf, min_gain_to_split, True) | ||
|
||
mask = rng.randint(0, 2, n_samples).astype(np.bool) | ||
sample_indices_left = sample_indices[mask] | ||
|
@@ -165,7 +166,7 @@ def test_gradient_and_hessian_sanity(constant_hessian): | |
n_bins_per_feature, | ||
all_gradients, all_hessians, | ||
l2_regularization, min_hessian_to_split, | ||
min_samples_leaf, min_gain_to_split) | ||
min_samples_leaf, min_gain_to_split, True) | ||
|
||
mask = rng.randint(0, 2, n_samples).astype(np.bool) | ||
sample_indices_left = sample_indices[mask] | ||
|
@@ -261,7 +262,7 @@ def test_split_indices(): | |
n_bins_per_feature, | ||
all_gradients, all_hessians, | ||
l2_regularization, min_hessian_to_split, | ||
min_samples_leaf, min_gain_to_split) | ||
min_samples_leaf, min_gain_to_split, True) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please parametrize this test with |
||
|
||
assert_array_almost_equal(sample_indices, context.partition) | ||
si_root, _ = find_node_split(context, sample_indices) | ||
|
@@ -270,8 +271,8 @@ def test_split_indices(): | |
assert si_root.feature_idx == 1 | ||
assert si_root.bin_idx == 3 | ||
|
||
samples_left, samples_right = split_indices( | ||
context, si_root, context.partition.view()) | ||
samples_left, samples_right = context.split_indices( | ||
si_root, context.partition.view()) | ||
assert set(samples_left) == set([0, 1, 3, 4, 5, 6, 8]) | ||
assert set(samples_right) == set([2, 7, 9]) | ||
|
||
|
@@ -287,6 +288,13 @@ def test_split_indices(): | |
assert samples_left.shape[0] == si_root.n_samples_left | ||
assert samples_right.shape[0] == si_root.n_samples_right | ||
|
||
# test if the single thread version gives the same result | ||
samples_left_single_thread, samples_right_single_thread = \ | ||
_split_indices_single_threaded(context, si_root, context.partition.view()) | ||
|
||
assert samples_left.tolist() == samples_left_single_thread.tolist() | ||
assert samples_right.tolist() == samples_right_single_thread.tolist() | ||
|
||
|
||
def test_min_gain_to_split(): | ||
# Try to split a pure node (all gradients are equal, same for hessians) | ||
|
@@ -314,7 +322,7 @@ def test_min_gain_to_split(): | |
all_gradients, all_hessians, | ||
l2_regularization, | ||
min_hessian_to_split, | ||
min_samples_leaf, min_gain_to_split) | ||
min_samples_leaf, min_gain_to_split, True) | ||
|
||
split_info, _ = _find_histogram_split(context, feature_idx, sample_indices) | ||
assert split_info.gain == -1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PEP8: you need 2 blank lines to separate top level functions.