-
Notifications
You must be signed in to change notification settings - Fork 32
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG] Accept pre-binned data in fit #74
Merged
Merged
Changes from 3 commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
f18519a
Accept pre-binned data in fit
NicolasHug 1fcf558
Addressed comments
NicolasHug a5689ad
Merge branch 'master' into binned_X
NicolasHug 648ad17
Addressed comments
NicolasHug 3826804
Added reference to scipy version issue in tox.ini
NicolasHug 1e7075e
Trying tox with no whitespace, single = sign
NicolasHug 1f06008
double equal sign... ?
NicolasHug 688419d
sigh...
NicolasHug 32eb845
removing comment ???????
NicolasHug 056512d
trying == again. Comment was causing error
NicolasHug File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -413,25 +413,30 @@ def _finalize_splittable_nodes(self): | |
node = self.splittable_nodes.pop() | ||
self._finalize_leaf(node) | ||
|
||
def make_predictor(self, bin_thresholds=None): | ||
def make_predictor(self, numerical_thresholds=None): | ||
"""Make a TreePredictor object out of the current tree. | ||
|
||
Parameters | ||
---------- | ||
bin_thresholds : array-like of floats, optional (default=None) | ||
The actual thresholds values of each bin. | ||
numerical_thresholds : array-like of floats, optional (default=None) | ||
The actual thresholds values of each bin. None if the training data | ||
was pre-binned. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe add that the values of the array are expected to be sorted in increasing order. |
||
|
||
Returns | ||
------- | ||
A TreePredictor object. | ||
""" | ||
predictor_nodes = np.zeros(self.n_nodes, dtype=PREDICTOR_RECORD_DTYPE) | ||
self._fill_predictor_node_array(predictor_nodes, self.root, | ||
bin_thresholds=bin_thresholds) | ||
return TreePredictor(predictor_nodes) | ||
self._fill_predictor_node_array( | ||
predictor_nodes, self.root, | ||
numerical_thresholds=numerical_thresholds | ||
) | ||
has_numerical_thresholds = numerical_thresholds is not None | ||
return TreePredictor(nodes=predictor_nodes, | ||
has_numerical_thresholds=has_numerical_thresholds) | ||
|
||
def _fill_predictor_node_array(self, predictor_nodes, grower_node, | ||
bin_thresholds=None, next_free_idx=0): | ||
numerical_thresholds=None, next_free_idx=0): | ||
"""Helper used in make_predictor to set the TreePredictor fields.""" | ||
node = predictor_nodes[next_free_idx] | ||
node['count'] = grower_node.n_samples | ||
|
@@ -452,17 +457,18 @@ def _fill_predictor_node_array(self, predictor_nodes, grower_node, | |
feature_idx, bin_idx = split_info.feature_idx, split_info.bin_idx | ||
node['feature_idx'] = feature_idx | ||
node['bin_threshold'] = bin_idx | ||
if bin_thresholds is not None: | ||
threshold = bin_thresholds[feature_idx][bin_idx] | ||
node['threshold'] = threshold | ||
if numerical_thresholds is not None: | ||
node['threshold'] = numerical_thresholds[feature_idx][bin_idx] | ||
next_free_idx += 1 | ||
|
||
node['left'] = next_free_idx | ||
next_free_idx = self._fill_predictor_node_array( | ||
predictor_nodes, grower_node.left_child, | ||
bin_thresholds=bin_thresholds, next_free_idx=next_free_idx) | ||
numerical_thresholds=numerical_thresholds, | ||
next_free_idx=next_free_idx) | ||
|
||
node['right'] = next_free_idx | ||
return self._fill_predictor_node_array( | ||
predictor_nodes, grower_node.right_child, | ||
bin_thresholds=bin_thresholds, next_free_idx=next_free_idx) | ||
numerical_thresholds=numerical_thresholds, | ||
next_free_idx=next_free_idx) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,8 +28,9 @@ class TreePredictor: | |
nodes : list of PREDICTOR_RECORD_DTYPE. | ||
The nodes of the tree. | ||
""" | ||
def __init__(self, nodes): | ||
def __init__(self, nodes, has_numerical_thresholds=True): | ||
self.nodes = nodes | ||
self.has_numerical_thresholds = has_numerical_thresholds | ||
|
||
def get_n_leaf_nodes(self): | ||
"""Return number of leaves.""" | ||
|
@@ -54,6 +55,10 @@ def predict_binned(self, binned_data, out=None): | |
y : array, shape (n_samples,) | ||
The raw predicted values. | ||
""" | ||
|
||
if binned_data.dtype != np.uint8: | ||
raise ValueError('binned_data dtype should be uint8') | ||
|
||
if out is None: | ||
out = np.empty(binned_data.shape[0], dtype=np.float32) | ||
_predict_binned(self.nodes, binned_data, out) | ||
|
@@ -74,6 +79,16 @@ def predict(self, X): | |
""" | ||
# TODO: introspect X to dispatch to numerical or categorical data | ||
# (dense or sparse) on a feature by feature basis. | ||
|
||
if not self.has_numerical_thresholds: | ||
raise ValueError( | ||
'This predictor does not have numerical thresholds so it can' | ||
'only predict pre-binned data.' | ||
) | ||
|
||
if X.dtype == np.uint8: | ||
raise ValueError('X dtype should be float32 or float64') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe be more explicit: if X.dtype == np.uint8:
raise ValueError('X has uint8 dtype: use grower.predict_binned(X) if X is pre-binned, or'
' convert X to a float32 dtype to be treated as numeral data') |
||
|
||
out = np.empty(X.shape[0], dtype=np.float32) | ||
_predict_from_numeric_data(self.nodes, X, out) | ||
return out | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In retrospect I find the phrasing confusing. May I suggest the following: