diff --git a/adanet/core/estimator_distributed_test_runner.py b/adanet/core/estimator_distributed_test_runner.py index b108bbc5..25dc4195 100644 --- a/adanet/core/estimator_distributed_test_runner.py +++ b/adanet/core/estimator_distributed_test_runner.py @@ -50,8 +50,6 @@ # pylint: disable=g-direct-tensorflow-import # Contrib -from tensorflow.contrib.boosted_trees.estimator_batch.estimator import CoreGradientBoostedDecisionTreeEstimator -from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.boosted_trees.python.utils import losses as bt_losses from tensorflow.python.ops import partitioned_variables from tensorflow.python.training import session_manager as session_manager_lib @@ -314,13 +312,12 @@ def tree_loss_fn(labels, logits): learning_rate=.001), config=config), "gbdt": - CoreGradientBoostedDecisionTreeEstimator( + tf.estimator.BoostedTreesEstimator( head=tree_head, - learner_config=learner_pb2.LearnerConfig(num_classes=n_classes), - examples_per_layer=8, - num_trees=None, - center_bias=False, # Required for multi-class. feature_columns=feature_columns, + n_trees=10, + n_batches_per_layer=1, + center_bias=False, config=config), }