From b46604736c8d4521a99e27b5f9574f7260fc9d14 Mon Sep 17 00:00:00 2001 From: Maarten Breddels Date: Wed, 19 Dec 2018 15:28:56 +0100 Subject: [PATCH] implemented suggestions --- pygbm/grower.py | 10 ++--- pygbm/splitting.py | 90 +++++++++++++++-------------------------- tests/test_splitting.py | 11 ++--- 3 files changed, 43 insertions(+), 68 deletions(-) diff --git a/pygbm/grower.py b/pygbm/grower.py index 7e8c543..b60f0da 100644 --- a/pygbm/grower.py +++ b/pygbm/grower.py @@ -8,8 +8,7 @@ import numpy as np from time import time -from .splitting import (SplittingContext, split_indices_parallel, - split_indices_single_thread, find_node_split, +from .splitting import (SplittingContext, find_node_split, find_node_split_subtraction) from .predictor import TreePredictor, PREDICTOR_RECORD_DTYPE @@ -189,7 +188,6 @@ def __init__(self, X_binned, gradients, hessians, max_leaf_nodes=None, self.X_binned = X_binned self.min_gain_to_split = min_gain_to_split self.shrinkage = shrinkage - self.parallel_splitting = parallel_splitting self.splittable_nodes = [] self.finalized_leaves = [] self.total_find_split_time = 0. # time spent finding the best splits @@ -339,9 +337,9 @@ def split_next(self): node = heappop(self.splittable_nodes) tic = time() - split_indices = split_indices_parallel if self.parallel_splitting else split_indices_single_thread - (sample_indices_left, sample_indices_right) = split_indices( - self.splitting_context, node.split_info, node.sample_indices) + (sample_indices_left, sample_indices_right) = \ + self.splitting_context.split_indices(node.split_info, + node.sample_indices) toc = time() node.apply_split_time = toc - tic self.total_apply_split_time += node.apply_split_time diff --git a/pygbm/splitting.py b/pygbm/splitting.py index b574816..d9f8e22 100644 --- a/pygbm/splitting.py +++ b/pygbm/splitting.py @@ -171,35 +171,39 @@ def __init__(self, X_binned, max_bins, n_bins_per_feature, self.right_indices_buffer = np.empty_like(self.partition) self.left_indices_buffer = np.empty_like(self.partition) + def split_indices(self, split_info, sample_indices): + """Split samples into left and right arrays. + + Parameters + ---------- + split_info : SplitInfo + The SplitInfo of the node to split + sample_indices : array of int + The indices of the samples at the node to split. This is a view on + context.partition, and it is modified inplace by placing the indices + of the left child at the beginning, and the indices of the right child + at the end. + + Returns + ------- + left_indices : array of int + The indices of the samples in the left child. This is a view on + context.partition. + right_indices : array of int + The indices of the samples in the right child. This is a view on + context.partition. + """ + if self.parallel_splitting: + return _split_indices_parallel(self, split_info, sample_indices) + else: + return _split_indices_single_threaded(self, split_info, sample_indices) + @njit(parallel=True, locals={'sample_idx': uint32, 'left_count': uint32, 'right_count': uint32}) -def split_indices_parallel(context, split_info, sample_indices): - """Split samples into left and right arrays. - - Parameters - ---------- - context : SplittingContext - The splitting context - split_ingo : SplitInfo - The SplitInfo of the node to split - sample_indices : array of int - The indices of the samples at the node to split. This is a view on - context.partition, and it is modified inplace by placing the indices - of the left child at the beginning, and the indices of the right child - at the end. - - Returns - ------- - left_indices : array of int - The indices of the samples in the left child. This is a view on - context.partition. - right_indices : array of int - The indices of the samples in the right child. This is a view on - context.partition. - """ +def _split_indices_parallel(context, split_info, sample_indices): # This is a multi-threaded implementation inspired by lightgbm. # Here is a quick break down. Let's suppose we want to split a node with # 24 samples named from a to x. context.partition looks like this (the * @@ -309,47 +313,20 @@ def split_indices_parallel(context, split_info, sample_indices): sample_indices[right_child_position:]) @njit(parallel=False) -def split_indices_single_thread(context, split_info, sample_indices): - """Split samples into left and right arrays. - - This implementation requires less memory than the parallel version. - - Parameters - ---------- - context : SplittingContext - The splitting context - split_ingo : SplitInfo - The SplitInfo of the node to split - sample_indices : array of int - The indices of the samples at the node to split. This is a view on - context.partition, and it is modified inplace by placing the indices - of the left child at the beginning, and the indices of the right child - at the end. - - Returns - ------- - left_indices : array of int - The indices of the samples in the left child. This is a view on - context.partition. - right_indices : array of int - The indices of the samples in the right child. This is a view on - context.partition. - """ - X_binned = context.X_binned.T[split_info.feature_idx] +def _split_indices_single_threaded(context, split_info, sample_indices): + binned_feature = context.X_binned.T[split_info.feature_idx] n_samples = sample_indices.shape[0] - # approach from left with i i = 0 # approach from right with j j = n_samples - 1 - X = X_binned pivot = split_info.bin_idx while i != j: # continue until we find an element that should be on right - while X[sample_indices[i]] <= pivot and i < n_samples: + while binned_feature[sample_indices[i]] <= pivot and i < n_samples: i += 1 # same, but now an element that should be on the left - while X[sample_indices[j]] > pivot and j >= 0: + while binned_feature[sample_indices[j]] > pivot and j >= 0: j -= 1 if i >= j: # j can become smaller than j! break @@ -358,8 +335,7 @@ def split_indices_single_thread(context, split_info, sample_indices): sample_indices[i], sample_indices[j] = sample_indices[j], sample_indices[i] i += 1 j -= 1 - return (sample_indices[:i], - sample_indices[i:]) + return (sample_indices[:i], sample_indices[i:]) @njit(parallel=True) diff --git a/tests/test_splitting.py b/tests/test_splitting.py index d3e8848..8dbbda6 100644 --- a/tests/test_splitting.py +++ b/tests/test_splitting.py @@ -6,7 +6,7 @@ from pygbm.splitting import _find_histogram_split from pygbm.splitting import (SplittingContext, find_node_split, find_node_split_subtraction, - split_indices_parallel, split_indices_single_thread) + _split_indices_parallel, _split_indices_single_threaded) @pytest.mark.parametrize('n_bins', [3, 32, 256]) @@ -271,8 +271,8 @@ def test_split_indices(): assert si_root.feature_idx == 1 assert si_root.bin_idx == 3 - samples_left, samples_right = split_indices_parallel( - context, si_root, context.partition.view()) + samples_left, samples_right = context.split_indices( + si_root, context.partition.view()) assert set(samples_left) == set([0, 1, 3, 4, 5, 6, 8]) assert set(samples_right) == set([2, 7, 9]) @@ -288,8 +288,9 @@ def test_split_indices(): assert samples_left.shape[0] == si_root.n_samples_left assert samples_right.shape[0] == si_root.n_samples_right - samples_left_single_thread, samples_right_single_thread = split_indices_single_thread( - context, si_root, context.partition.view()) + # test if the single thread version gives the same result + samples_left_single_thread, samples_right_single_thread = \ + _split_indices_single_threaded(context, si_root, context.partition.view()) assert samples_left.tolist() == samples_left_single_thread.tolist() assert samples_right.tolist() == samples_right_single_thread.tolist()