Skip to content

Commit

Permalink
implemented suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
maartenbreddels committed Dec 19, 2018
1 parent e6ea8a3 commit b466047
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 68 deletions.
10 changes: 4 additions & 6 deletions pygbm/grower.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
import numpy as np
from time import time

from .splitting import (SplittingContext, split_indices_parallel,
split_indices_single_thread, find_node_split,
from .splitting import (SplittingContext, find_node_split,
find_node_split_subtraction)
from .predictor import TreePredictor, PREDICTOR_RECORD_DTYPE

Expand Down Expand Up @@ -189,7 +188,6 @@ def __init__(self, X_binned, gradients, hessians, max_leaf_nodes=None,
self.X_binned = X_binned
self.min_gain_to_split = min_gain_to_split
self.shrinkage = shrinkage
self.parallel_splitting = parallel_splitting
self.splittable_nodes = []
self.finalized_leaves = []
self.total_find_split_time = 0. # time spent finding the best splits
Expand Down Expand Up @@ -339,9 +337,9 @@ def split_next(self):
node = heappop(self.splittable_nodes)

tic = time()
split_indices = split_indices_parallel if self.parallel_splitting else split_indices_single_thread
(sample_indices_left, sample_indices_right) = split_indices(
self.splitting_context, node.split_info, node.sample_indices)
(sample_indices_left, sample_indices_right) = \
self.splitting_context.split_indices(node.split_info,
node.sample_indices)
toc = time()
node.apply_split_time = toc - tic
self.total_apply_split_time += node.apply_split_time
Expand Down
90 changes: 33 additions & 57 deletions pygbm/splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,35 +171,39 @@ def __init__(self, X_binned, max_bins, n_bins_per_feature,
self.right_indices_buffer = np.empty_like(self.partition)
self.left_indices_buffer = np.empty_like(self.partition)

def split_indices(self, split_info, sample_indices):
    """Partition the samples of a node between its left and right children.

    Dispatches to the multi-threaded implementation when
    ``self.parallel_splitting`` is truthy, and to the single-threaded
    implementation otherwise.

    Parameters
    ----------
    split_info : SplitInfo
        The SplitInfo of the node to split.
    sample_indices : array of int
        The indices of the samples at the node to split. This is a view
        on this context's ``partition`` array, and it is modified inplace:
        the indices of the left child are placed at the beginning and the
        indices of the right child at the end.

    Returns
    -------
    left_indices : array of int
        The indices of the samples in the left child. This is a view on
        this context's ``partition`` array.
    right_indices : array of int
        The indices of the samples in the right child. This is a view on
        this context's ``partition`` array.
    """
    # Guard clause for the single-threaded path; the parallel path is the
    # fall-through.
    if not self.parallel_splitting:
        return _split_indices_single_threaded(self, split_info,
                                              sample_indices)
    return _split_indices_parallel(self, split_info, sample_indices)


@njit(parallel=True,
locals={'sample_idx': uint32,
'left_count': uint32,
'right_count': uint32})
def split_indices_parallel(context, split_info, sample_indices):
"""Split samples into left and right arrays.
Parameters
----------
context : SplittingContext
The splitting context
split_info : SplitInfo
The SplitInfo of the node to split
sample_indices : array of int
The indices of the samples at the node to split. This is a view on
context.partition, and it is modified inplace by placing the indices
of the left child at the beginning, and the indices of the right child
at the end.
Returns
-------
left_indices : array of int
The indices of the samples in the left child. This is a view on
context.partition.
right_indices : array of int
The indices of the samples in the right child. This is a view on
context.partition.
"""
def _split_indices_parallel(context, split_info, sample_indices):
# This is a multi-threaded implementation inspired by lightgbm.
# Here is a quick break down. Let's suppose we want to split a node with
# 24 samples named from a to x. context.partition looks like this (the *
Expand Down Expand Up @@ -309,47 +313,20 @@ def split_indices_parallel(context, split_info, sample_indices):
sample_indices[right_child_position:])

@njit(parallel=False)
def split_indices_single_thread(context, split_info, sample_indices):
"""Split samples into left and right arrays.
This implementation requires less memory than the parallel version.
Parameters
----------
context : SplittingContext
The splitting context
split_info : SplitInfo
The SplitInfo of the node to split
sample_indices : array of int
The indices of the samples at the node to split. This is a view on
context.partition, and it is modified inplace by placing the indices
of the left child at the beginning, and the indices of the right child
at the end.
Returns
-------
left_indices : array of int
The indices of the samples in the left child. This is a view on
context.partition.
right_indices : array of int
The indices of the samples in the right child. This is a view on
context.partition.
"""
X_binned = context.X_binned.T[split_info.feature_idx]
def _split_indices_single_threaded(context, split_info, sample_indices):
binned_feature = context.X_binned.T[split_info.feature_idx]
n_samples = sample_indices.shape[0]

# approach from left with i
i = 0
# approach from right with j
j = n_samples - 1
X = X_binned
pivot = split_info.bin_idx
while i != j:
# continue until we find an element that should be on right
while X[sample_indices[i]] <= pivot and i < n_samples:
while binned_feature[sample_indices[i]] <= pivot and i < n_samples:
i += 1
# same, but now an element that should be on the left
while X[sample_indices[j]] > pivot and j >= 0:
while binned_feature[sample_indices[j]] > pivot and j >= 0:
j -= 1
if i >= j: # j can become smaller than i!
break
Expand All @@ -358,8 +335,7 @@ def split_indices_single_thread(context, split_info, sample_indices):
sample_indices[i], sample_indices[j] = sample_indices[j], sample_indices[i]
i += 1
j -= 1
return (sample_indices[:i],
sample_indices[i:])
return (sample_indices[:i], sample_indices[i:])


@njit(parallel=True)
Expand Down
11 changes: 6 additions & 5 deletions tests/test_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pygbm.splitting import _find_histogram_split
from pygbm.splitting import (SplittingContext, find_node_split,
find_node_split_subtraction,
split_indices_parallel, split_indices_single_thread)
_split_indices_parallel, _split_indices_single_threaded)


@pytest.mark.parametrize('n_bins', [3, 32, 256])
Expand Down Expand Up @@ -271,8 +271,8 @@ def test_split_indices():
assert si_root.feature_idx == 1
assert si_root.bin_idx == 3

samples_left, samples_right = split_indices_parallel(
context, si_root, context.partition.view())
samples_left, samples_right = context.split_indices(
si_root, context.partition.view())
assert set(samples_left) == set([0, 1, 3, 4, 5, 6, 8])
assert set(samples_right) == set([2, 7, 9])

Expand All @@ -288,8 +288,9 @@ def test_split_indices():
assert samples_left.shape[0] == si_root.n_samples_left
assert samples_right.shape[0] == si_root.n_samples_right

samples_left_single_thread, samples_right_single_thread = split_indices_single_thread(
context, si_root, context.partition.view())
# test if the single thread version gives the same result
samples_left_single_thread, samples_right_single_thread = \
_split_indices_single_threaded(context, si_root, context.partition.view())

assert samples_left.tolist() == samples_left_single_thread.tolist()
assert samples_right.tolist() == samples_right_single_thread.tolist()
Expand Down

0 comments on commit b466047

Please sign in to comment.