From e6ea8a3b62df2811f3373d021f485d160471b9d0 Mon Sep 17 00:00:00 2001
From: Maarten Breddels <maartenbreddels@gmail.com>
Date: Tue, 18 Dec 2018 22:21:26 +0100
Subject: [PATCH 1/3] non-parallel split_indices that requires less memory

---
 benchmarks/bench_higgs_boson.py |  2 +-
 pygbm/gradient_boosting.py      | 15 ++++---
 pygbm/grower.py                 | 10 +++--
 pygbm/splitting.py              | 69 ++++++++++++++++++++++++++++++---
 tests/test_splitting.py         | 21 ++++++----
 5 files changed, 94 insertions(+), 23 deletions(-)

diff --git a/benchmarks/bench_higgs_boson.py b/benchmarks/bench_higgs_boson.py
index d7cb38e..b50132e 100644
--- a/benchmarks/bench_higgs_boson.py
+++ b/benchmarks/bench_higgs_boson.py
@@ -85,7 +85,7 @@ def load_data():
                                     max_leaf_nodes=n_leaf_nodes,
                                     n_iter_no_change=None,
                                     random_state=0,
-                                    verbose=1)
+                                    verbose=1, parallel_splitting=False)
 pygbm_model.fit(data_train, target_train)
 toc = time()
 predicted_test = pygbm_model.predict(data_test)
diff --git a/pygbm/gradient_boosting.py b/pygbm/gradient_boosting.py
index b70e4ce..2705970 100644
--- a/pygbm/gradient_boosting.py
+++ b/pygbm/gradient_boosting.py
@@ -26,7 +26,7 @@ class BaseGradientBoostingMachine(BaseEstimator, ABC):
     def __init__(self, loss, learning_rate, max_iter, max_leaf_nodes,
                  max_depth, min_samples_leaf, l2_regularization, max_bins,
                  scoring, validation_split, n_iter_no_change, tol, verbose,
-                 random_state):
+                 random_state, parallel_splitting):
         self.loss = loss
         self.learning_rate = learning_rate
         self.max_iter = max_iter
@@ -41,6 +41,7 @@ def __init__(self, loss, learning_rate, max_iter, max_leaf_nodes,
         self.tol = tol
         self.verbose = verbose
         self.random_state = random_state
+        self.parallel_splitting = parallel_splitting

     def _validate_parameters(self, X):
         """Validate parameters passed to __init__.
@@ -249,7 +250,8 @@ def fit(self, X, y):
                 max_depth=self.max_depth,
                 min_samples_leaf=self.min_samples_leaf,
                 l2_regularization=self.l2_regularization,
-                shrinkage=self.learning_rate)
+                shrinkage=self.learning_rate,
+                parallel_splitting=self.parallel_splitting)
             grower.grow()

             acc_apply_split_time += grower.total_apply_split_time
@@ -524,7 +526,8 @@ def __init__(self, loss='least_squares', learning_rate=0.1,
                  max_iter=100, max_leaf_nodes=31, max_depth=None,
                  min_samples_leaf=20, l2_regularization=0., max_bins=256,
                  scoring=None, validation_split=0.1, n_iter_no_change=5,
-                 tol=1e-7, verbose=0, random_state=None):
+                 tol=1e-7, verbose=0, random_state=None,
+                 parallel_splitting=True):
         super(GradientBoostingRegressor, self).__init__(
             loss=loss, learning_rate=learning_rate, max_iter=max_iter,
             max_leaf_nodes=max_leaf_nodes, max_depth=max_depth,
@@ -532,7 +535,7 @@ def __init__(self, loss='least_squares', learning_rate=0.1,
             l2_regularization=l2_regularization, max_bins=max_bins,
             scoring=scoring, validation_split=validation_split,
             n_iter_no_change=n_iter_no_change, tol=tol, verbose=verbose,
-            random_state=random_state)
+            random_state=random_state, parallel_splitting=parallel_splitting)

     def predict(self, X):
         """Predict values for X.
@@ -644,7 +647,7 @@ def __init__(self, loss='auto', learning_rate=0.1, max_iter=100,
                  max_leaf_nodes=31, max_depth=None, min_samples_leaf=20,
                  l2_regularization=0., max_bins=256, scoring=None,
                  validation_split=0.1, n_iter_no_change=5, tol=1e-7,
-                 verbose=0, random_state=None):
+                 verbose=0, random_state=None, parallel_splitting=True):
         super(GradientBoostingClassifier, self).__init__(
             loss=loss, learning_rate=learning_rate, max_iter=max_iter,
             max_leaf_nodes=max_leaf_nodes, max_depth=max_depth,
@@ -652,7 +655,7 @@ def __init__(self, loss='auto', learning_rate=0.1, max_iter=100,
             l2_regularization=l2_regularization, max_bins=max_bins,
             scoring=scoring, validation_split=validation_split,
             n_iter_no_change=n_iter_no_change, tol=tol, verbose=verbose,
-            random_state=random_state)
+            random_state=random_state, parallel_splitting=parallel_splitting)

     def predict(self, X):
         """Predict classes for X.
diff --git a/pygbm/grower.py b/pygbm/grower.py
index c77d000..7e8c543 100644
--- a/pygbm/grower.py
+++ b/pygbm/grower.py
@@ -8,7 +8,8 @@
 import numpy as np
 from time import time

-from .splitting import (SplittingContext, split_indices, find_node_split,
+from .splitting import (SplittingContext, split_indices_parallel,
+                        split_indices_single_thread, find_node_split,
                         find_node_split_subtraction)
 from .predictor import TreePredictor, PREDICTOR_RECORD_DTYPE

@@ -163,7 +164,8 @@ class TreeGrower:
     def __init__(self, X_binned, gradients, hessians, max_leaf_nodes=None,
                  max_depth=None, min_samples_leaf=20, min_gain_to_split=0.,
                  max_bins=256, n_bins_per_feature=None, l2_regularization=0.,
-                 min_hessian_to_split=1e-3, shrinkage=1.):
+                 min_hessian_to_split=1e-3, shrinkage=1.,
+                 parallel_splitting=True):

         self._validate_parameters(X_binned, max_leaf_nodes, max_depth,
                                   min_samples_leaf, min_gain_to_split,
@@ -180,13 +182,14 @@ def __init__(self, X_binned, gradients, hessians, max_leaf_nodes=None,
         self.splitting_context = SplittingContext(
             X_binned, max_bins, n_bins_per_feature, gradients,
             hessians, l2_regularization, min_hessian_to_split,
-            min_samples_leaf, min_gain_to_split)
+            min_samples_leaf, min_gain_to_split, parallel_splitting)
         self.max_leaf_nodes = max_leaf_nodes
         self.max_depth = max_depth
         self.min_samples_leaf = min_samples_leaf
         self.X_binned = X_binned
         self.min_gain_to_split = min_gain_to_split
         self.shrinkage = shrinkage
+        self.parallel_splitting = parallel_splitting
         self.splittable_nodes = []
         self.finalized_leaves = []
         self.total_find_split_time = 0.  # time spent finding the best splits
@@ -336,6 +339,7 @@ def split_next(self):
         node = heappop(self.splittable_nodes)

         tic = time()
+        split_indices = split_indices_parallel if self.parallel_splitting else split_indices_single_thread
         (sample_indices_left, sample_indices_right) = split_indices(
             self.splitting_context, node.split_info, node.sample_indices)
         toc = time()
diff --git a/pygbm/splitting.py b/pygbm/splitting.py
index 56ae412..b574816 100644
--- a/pygbm/splitting.py
+++ b/pygbm/splitting.py
@@ -6,7 +6,7 @@
 into the newly created left and right childs.
 """
 import numpy as np
-from numba import njit, jitclass, prange, float32, uint8, uint32
+from numba import njit, jitclass, prange, float32, uint8, uint32, bool_
 import numba

 from .histogram import _build_histogram
@@ -88,6 +88,7 @@ def __init__(self, gain=-1., feature_idx=0, bin_idx=0,
     ('partition', uint32[::1]),
     ('left_indices_buffer', uint32[::1]),
     ('right_indices_buffer', uint32[::1]),
+    ('parallel_splitting', bool_)
 ])
 class SplittingContext:
     """Pure data class defining a splitting context.
@@ -128,7 +129,7 @@ class SplittingContext:
     def __init__(self, X_binned, max_bins, n_bins_per_feature,
                  gradients, hessians, l2_regularization,
                  min_hessian_to_split=1e-3, min_samples_leaf=20,
-                 min_gain_to_split=0.):
+                 min_gain_to_split=0., parallel_splitting=True):
         self.X_binned = X_binned
         self.n_features = X_binned.shape[1]
@@ -148,6 +149,7 @@ def __init__(self, X_binned, max_bins, n_bins_per_feature,
         self.min_hessian_to_split = min_hessian_to_split
         self.min_samples_leaf = min_samples_leaf
         self.min_gain_to_split = min_gain_to_split
+        self.parallel_splitting = parallel_splitting
         if self.constant_hessian:
             self.constant_hessian_value = self.hessians[0]  # 1 scalar
         else:
@@ -163,16 +165,18 @@ def __init__(self, X_binned, max_bins, n_bins_per_feature,
         # we have 2 leaves, the left one is at position 0 and the second one at
         # position 3. The order of the samples is irrelevant.
         self.partition = np.arange(0, X_binned.shape[0], 1, np.uint32)
-        # buffers used in split_indices to support parallel splitting.
-        self.left_indices_buffer = np.empty_like(self.partition)
-        self.right_indices_buffer = np.empty_like(self.partition)
+        if self.parallel_splitting:
+            # buffers used in split_indices_parallel to support parallel
+            # splitting.
+            self.right_indices_buffer = np.empty_like(self.partition)
+            self.left_indices_buffer = np.empty_like(self.partition)


 @njit(parallel=True,
       locals={'sample_idx': uint32,
               'left_count': uint32,
               'right_count': uint32})
-def split_indices(context, split_info, sample_indices):
+def split_indices_parallel(context, split_info, sample_indices):
     """Split samples into left and right arrays.

     Parameters
@@ -304,6 +308,59 @@ def split_indices(context, split_info, sample_indices):
     return (sample_indices[:right_child_position],
             sample_indices[right_child_position:])


+@njit(parallel=False)
+def split_indices_single_thread(context, split_info, sample_indices):
+    """Split samples into left and right arrays.
+
+    This implementation requires less memory than the parallel version.
+
+    Parameters
+    ----------
+    context : SplittingContext
+        The splitting context
+    split_info : SplitInfo
+        The SplitInfo of the node to split
+    sample_indices : array of int
+        The indices of the samples at the node to split. This is a view on
+        context.partition, and it is modified inplace by placing the indices
+        of the left child at the beginning, and the indices of the right child
+        at the end.
+
+    Returns
+    -------
+    left_indices : array of int
+        The indices of the samples in the left child. This is a view on
+        context.partition.
+    right_indices : array of int
+        The indices of the samples in the right child. This is a view on
+        context.partition.
+    """
+    X_binned = context.X_binned.T[split_info.feature_idx]
+    n_samples = sample_indices.shape[0]
+
+    # approach from left with i
+    i = 0
+    # approach from right with j
+    j = n_samples - 1
+    X = X_binned
+    pivot = split_info.bin_idx
+    while i != j:
+        # continue until we find an element that should be on the right
+        while X[sample_indices[i]] <= pivot and i < n_samples:
+            i += 1
+        # same, but now an element that should be on the left
+        while X[sample_indices[j]] > pivot and j >= 0:
+            j -= 1
+        if i >= j:  # j can become smaller than i!
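+            # i and j crossed: sample_indices[:i] now holds the left child's
+            # indices and sample_indices[i:] the right child's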
+            break
+        else:
+            # swap
+            sample_indices[i], sample_indices[j] = sample_indices[j], sample_indices[i]
+            i += 1
+            j -= 1
+    return (sample_indices[:i],
+            sample_indices[i:])
+

 @njit(parallel=True)
 def find_node_split(context, sample_indices):
diff --git a/tests/test_splitting.py b/tests/test_splitting.py
index 3fc26db..d3e8848 100644
--- a/tests/test_splitting.py
+++ b/tests/test_splitting.py
@@ -6,7 +6,7 @@
 from pygbm.splitting import _find_histogram_split
 from pygbm.splitting import (SplittingContext, find_node_split,
                              find_node_split_subtraction,
-                             split_indices)
+                             split_indices_parallel, split_indices_single_thread)


 @pytest.mark.parametrize('n_bins', [3, 32, 256])
@@ -39,7 +39,8 @@ def test_histogram_split(n_bins):
                                      all_gradients, all_hessians,
                                      l2_regularization,
                                      min_hessian_to_split,
-                                     min_samples_leaf, min_gain_to_split)
+                                     min_samples_leaf, min_gain_to_split,
+                                     True)

             split_info, _ = _find_histogram_split(context, feature_idx,
                                                   sample_indices)
@@ -85,7 +86,7 @@ def test_split_vs_split_subtraction(constant_hessian):
                                n_bins_per_feature,
                                all_gradients, all_hessians,
                                l2_regularization, min_hessian_to_split,
-                               min_samples_leaf, min_gain_to_split)
+                               min_samples_leaf, min_gain_to_split, True)

     mask = rng.randint(0, 2, n_samples).astype(np.bool)
     sample_indices_left = sample_indices[mask]
@@ -165,7 +166,7 @@ def test_gradient_and_hessian_sanity(constant_hessian):
                                n_bins_per_feature,
                                all_gradients, all_hessians,
                                l2_regularization, min_hessian_to_split,
-                               min_samples_leaf, min_gain_to_split)
+                               min_samples_leaf, min_gain_to_split, True)

     mask = rng.randint(0, 2, n_samples).astype(np.bool)
     sample_indices_left = sample_indices[mask]
@@ -261,7 +262,7 @@ def test_split_indices():
                                n_bins_per_feature,
                                all_gradients, all_hessians,
                                l2_regularization, min_hessian_to_split,
-                               min_samples_leaf, min_gain_to_split)
+                               min_samples_leaf, min_gain_to_split, True)

     assert_array_almost_equal(sample_indices, context.partition)
     si_root, _ = find_node_split(context, sample_indices)
@@ -270,7 +271,7 @@ def test_split_indices():
     assert si_root.feature_idx == 1
     assert si_root.bin_idx == 3

-    samples_left, samples_right = split_indices(
+    samples_left, samples_right = split_indices_parallel(
         context, si_root, context.partition.view())
     assert set(samples_left) == set([0, 1, 3, 4, 5, 6, 8])
     assert set(samples_right) == set([2, 7, 9])
@@ -287,6 +288,12 @@ def test_split_indices():
     assert samples_left.shape[0] == si_root.n_samples_left
     assert samples_right.shape[0] == si_root.n_samples_right

+    samples_left_single_thread, samples_right_single_thread = split_indices_single_thread(
+        context, si_root, context.partition.view())
+
+    assert samples_left.tolist() == samples_left_single_thread.tolist()
+    assert samples_right.tolist() == samples_right_single_thread.tolist()
+

 def test_min_gain_to_split():
     # Try to split a pure node (all gradients are equal, same for hessians)
@@ -314,7 +321,7 @@ def test_min_gain_to_split():
                                    all_gradients, all_hessians,
                                    l2_regularization,
                                    min_hessian_to_split,
-                                   min_samples_leaf, min_gain_to_split)
+                                   min_samples_leaf, min_gain_to_split, True)
     split_info, _ = _find_histogram_split(context, feature_idx,
                                           sample_indices)
     assert split_info.gain == -1

From b46604736c8d4521a99e27b5f9574f7260fc9d14 Mon Sep 17 00:00:00 2001
From: Maarten Breddels <maartenbreddels@gmail.com>
Date: Wed, 19 Dec 2018 15:28:56 +0100
Subject: [PATCH 2/3] implemented suggestions

---
 pygbm/grower.py         | 10 ++---
 pygbm/splitting.py      | 90 +++++++++++++++--------------------------
 tests/test_splitting.py | 11 ++---
 3 files changed, 43 insertions(+), 68 deletions(-)
diff --git a/pygbm/grower.py b/pygbm/grower.py
index 7e8c543..b60f0da 100644
--- a/pygbm/grower.py
+++ b/pygbm/grower.py
@@ -8,8 +8,7 @@
 import numpy as np
 from time import time

-from .splitting import (SplittingContext, split_indices_parallel,
-                        split_indices_single_thread, find_node_split,
+from .splitting import (SplittingContext, find_node_split,
                         find_node_split_subtraction)
 from .predictor import TreePredictor, PREDICTOR_RECORD_DTYPE

@@ -189,7 +188,6 @@ def __init__(self, X_binned, gradients, hessians, max_leaf_nodes=None,
         self.X_binned = X_binned
         self.min_gain_to_split = min_gain_to_split
         self.shrinkage = shrinkage
-        self.parallel_splitting = parallel_splitting
         self.splittable_nodes = []
         self.finalized_leaves = []
         self.total_find_split_time = 0.  # time spent finding the best splits
@@ -339,9 +337,9 @@ def split_next(self):
         node = heappop(self.splittable_nodes)

         tic = time()
-        split_indices = split_indices_parallel if self.parallel_splitting else split_indices_single_thread
-        (sample_indices_left, sample_indices_right) = split_indices(
-            self.splitting_context, node.split_info, node.sample_indices)
+        (sample_indices_left, sample_indices_right) = \
+            self.splitting_context.split_indices(node.split_info,
+                                                 node.sample_indices)
         toc = time()
         node.apply_split_time = toc - tic
         self.total_apply_split_time += node.apply_split_time
diff --git a/pygbm/splitting.py b/pygbm/splitting.py
index b574816..d9f8e22 100644
--- a/pygbm/splitting.py
+++ b/pygbm/splitting.py
@@ -171,35 +171,39 @@ def __init__(self, X_binned, max_bins, n_bins_per_feature,
             self.right_indices_buffer = np.empty_like(self.partition)
             self.left_indices_buffer = np.empty_like(self.partition)

+    def split_indices(self, split_info, sample_indices):
+        """Split samples into left and right arrays.
+
+        Parameters
+        ----------
+        split_info : SplitInfo
+            The SplitInfo of the node to split
+        sample_indices : array of int
+            The indices of the samples at the node to split. This is a view on
+            context.partition, and it is modified inplace by placing the indices
+            of the left child at the beginning, and the indices of the right child
+            at the end.
+
+        Returns
+        -------
+        left_indices : array of int
+            The indices of the samples in the left child. This is a view on
+            context.partition.
+        right_indices : array of int
+            The indices of the samples in the right child. This is a view on
+            context.partition.
+        """
+        if self.parallel_splitting:
+            return _split_indices_parallel(self, split_info, sample_indices)
+        else:
+            return _split_indices_single_threaded(self, split_info, sample_indices)
+

 @njit(parallel=True,
       locals={'sample_idx': uint32,
               'left_count': uint32,
               'right_count': uint32})
-def split_indices_parallel(context, split_info, sample_indices):
-    """Split samples into left and right arrays.
-
-    Parameters
-    ----------
-    context : SplittingContext
-        The splitting context
-    split_info : SplitInfo
-        The SplitInfo of the node to split
-    sample_indices : array of int
-        The indices of the samples at the node to split. This is a view on
-        context.partition, and it is modified inplace by placing the indices
-        of the left child at the beginning, and the indices of the right child
-        at the end.
-
-    Returns
-    -------
-    left_indices : array of int
-        The indices of the samples in the left child. This is a view on
-        context.partition.
-    right_indices : array of int
-        The indices of the samples in the right child. This is a view on
-        context.partition.
-    """
+def _split_indices_parallel(context, split_info, sample_indices):
     # This is a multi-threaded implementation inspired by lightgbm.
     # Here is a quick breakdown. Let's suppose we want to split a node with
     # 24 samples named from a to x. context.partition looks like this (the *
@@ -309,47 +313,20 @@ def split_indices_parallel(context, split_info, sample_indices):
     return (sample_indices[:right_child_position],
             sample_indices[right_child_position:])


 @njit(parallel=False)
-def split_indices_single_thread(context, split_info, sample_indices):
-    """Split samples into left and right arrays.
-
-    This implementation requires less memory than the parallel version.
-
-    Parameters
-    ----------
-    context : SplittingContext
-        The splitting context
-    split_info : SplitInfo
-        The SplitInfo of the node to split
-    sample_indices : array of int
-        The indices of the samples at the node to split. This is a view on
-        context.partition, and it is modified inplace by placing the indices
-        of the left child at the beginning, and the indices of the right child
-        at the end.
-
-    Returns
-    -------
-    left_indices : array of int
-        The indices of the samples in the left child. This is a view on
-        context.partition.
-    right_indices : array of int
-        The indices of the samples in the right child. This is a view on
-        context.partition.
-    """
-    X_binned = context.X_binned.T[split_info.feature_idx]
+def _split_indices_single_threaded(context, split_info, sample_indices):
+    binned_feature = context.X_binned.T[split_info.feature_idx]
     n_samples = sample_indices.shape[0]

-    # approach from left with i
     i = 0
     # approach from right with j
     j = n_samples - 1
-    X = X_binned
     pivot = split_info.bin_idx
     while i != j:
         # continue until we find an element that should be on the right
-        while X[sample_indices[i]] <= pivot and i < n_samples:
+        while binned_feature[sample_indices[i]] <= pivot and i < n_samples:
             i += 1
         # same, but now an element that should be on the left
-        while X[sample_indices[j]] > pivot and j >= 0:
+        while binned_feature[sample_indices[j]] > pivot and j >= 0:
             j -= 1
         if i >= j:  # j can become smaller than i!
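             # i and j crossed: sample_indices[:i] now holds the left child's
             # indices and sample_indices[i:] the right child's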
             break
@@ -358,8 +335,7 @@ def split_indices_single_thread(context, split_info, sample_indices):
             sample_indices[i], sample_indices[j] = sample_indices[j], sample_indices[i]
             i += 1
             j -= 1
-    return (sample_indices[:i],
-            sample_indices[i:])
+    return (sample_indices[:i], sample_indices[i:])


 @njit(parallel=True)
 def find_node_split(context, sample_indices):
diff --git a/tests/test_splitting.py b/tests/test_splitting.py
index d3e8848..8dbbda6 100644
--- a/tests/test_splitting.py
+++ b/tests/test_splitting.py
@@ -6,7 +6,7 @@
 from pygbm.splitting import _find_histogram_split
 from pygbm.splitting import (SplittingContext, find_node_split,
                              find_node_split_subtraction,
-                             split_indices_parallel, split_indices_single_thread)
+                             _split_indices_parallel, _split_indices_single_threaded)


 @pytest.mark.parametrize('n_bins', [3, 32, 256])
@@ -271,8 +271,8 @@ def test_split_indices():
     assert si_root.feature_idx == 1
     assert si_root.bin_idx == 3

-    samples_left, samples_right = split_indices_parallel(
-        context, si_root, context.partition.view())
+    samples_left, samples_right = context.split_indices(
+        si_root, context.partition.view())
     assert set(samples_left) == set([0, 1, 3, 4, 5, 6, 8])
     assert set(samples_right) == set([2, 7, 9])
@@ -288,8 +288,9 @@ def test_split_indices():
     assert samples_left.shape[0] == si_root.n_samples_left
     assert samples_right.shape[0] == si_root.n_samples_right

-    samples_left_single_thread, samples_right_single_thread = split_indices_single_thread(
-        context, si_root, context.partition.view())
+    # test that the single-threaded version gives the same result
+    samples_left_single_thread, samples_right_single_thread = \
+        _split_indices_single_threaded(context, si_root, context.partition.view())

     assert samples_left.tolist() == samples_left_single_thread.tolist()
     assert samples_right.tolist() == samples_right_single_thread.tolist()

From e84cdc31e035f6b2872fcd49637af499ce40f932 Mon Sep 17 00:00:00 2001
From: Maarten Breddels <maartenbreddels@gmail.com>
Date: Wed, 19 Dec 2018 15:29:27 +0100
Subject: [PATCH 3/3] add --no-parallel-split argument to Higgs benchmark

---
 benchmarks/bench_higgs_boson.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/benchmarks/bench_higgs_boson.py b/benchmarks/bench_higgs_boson.py
index b50132e..93110c0 100644
--- a/benchmarks/bench_higgs_boson.py
+++ b/benchmarks/bench_higgs_boson.py
@@ -21,6 +21,7 @@
 parser.add_argument('--learning-rate', type=float, default=1.)
 parser.add_argument('--subsample', type=int, default=None)
 parser.add_argument('--max-bins', type=int, default=255)
+parser.add_argument('--no-parallel-split', action="store_true", default=False)
 args = parser.parse_args()

 HERE = os.path.dirname(__file__)
@@ -85,7 +86,8 @@ def load_data():
                                     max_leaf_nodes=n_leaf_nodes,
                                     n_iter_no_change=None,
                                     random_state=0,
-                                    verbose=1, parallel_splitting=False)
+                                    verbose=1,
+                                    parallel_splitting=not args.no_parallel_split)
 pygbm_model.fit(data_train, target_train)
 toc = time()
 predicted_test = pygbm_model.predict(data_test)
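---

A minimal usage sketch of the new flag (an illustration, not part of the patch series; `data_train` and `target_train` are placeholders for any training set, such as the one loaded by the Higgs benchmark):

    from pygbm import GradientBoostingClassifier

    # parallel_splitting=False selects the single-threaded splitter from
    # PATCH 1: it partitions context.partition in place with two pointers,
    # so it never allocates left_indices_buffer/right_indices_buffer, two
    # uint32 arrays with one entry per training sample.
    # NOTE: data_train/target_train are hypothetical placeholders.
    model = GradientBoostingClassifier(max_iter=100, random_state=0, verbose=1,
                                       parallel_splitting=False)
    model.fit(data_train, target_train)

The saving matters most on large datasets: with the benchmark's roughly 11M-row Higgs data, each uint32 buffer alone is on the order of 44 MB, and the two-pointer splitter skips both.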