Skip to content

Commit

Permalink
Change GBDT subestimator to tf.estimator.BoostedTreeEstimator in test.
Browse files Browse the repository at this point in the history
…#1 #121

In estimator_distributed_test_runner.py.

PiperOrigin-RevId: 270085601
  • Loading branch information
csvillalta authored and cweill committed Sep 19, 2019
1 parent a469fe9 commit 55e068e
Showing 1 changed file with 4 additions and 7 deletions.
11 changes: 4 additions & 7 deletions adanet/core/estimator_distributed_test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@
# pylint: disable=g-direct-tensorflow-import

# Contrib
from tensorflow.contrib.boosted_trees.estimator_batch.estimator import CoreGradientBoostedDecisionTreeEstimator
from tensorflow.contrib.boosted_trees.proto import learner_pb2
from tensorflow.contrib.boosted_trees.python.utils import losses as bt_losses
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.training import session_manager as session_manager_lib
Expand Down Expand Up @@ -314,13 +312,12 @@ def tree_loss_fn(labels, logits):
learning_rate=.001),
config=config),
"gbdt":
CoreGradientBoostedDecisionTreeEstimator(
tf.estimator.BoostedTreesEstimator(
head=tree_head,
learner_config=learner_pb2.LearnerConfig(num_classes=n_classes),
examples_per_layer=8,
num_trees=None,
center_bias=False, # Required for multi-class.
feature_columns=feature_columns,
n_trees=10,
n_batches_per_layer=1,
center_bias=False,
config=config),
}

Expand Down

0 comments on commit 55e068e

Please sign in to comment.