Skip to content

Commit

Permalink
Fix rng for the column sampler. (#10998)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Nov 19, 2024
1 parent 1766d48 commit a92ed1d
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/common/random.h
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ class ColumnSampler {
};

inline auto MakeColumnSampler(Context const* ctx) {
std::uint32_t seed = common::GlobalRandomEngine()();
std::uint32_t seed = common::GlobalRandom()();
auto rc = collective::Broadcast(ctx, linalg::MakeVec(&seed, 1), 0);
collective::SafeColl(rc);
auto cs = std::make_shared<common::ColumnSampler>(seed);
Expand Down
8 changes: 2 additions & 6 deletions src/tree/updater_gpu_hist.cu
Original file line number Diff line number Diff line change
Expand Up @@ -826,10 +826,7 @@ class GPUHistMaker : public TreeUpdater {
CHECK_GE(ctx_->Ordinal(), 0) << "Must have at least one device";

// Synchronise the column sampling seed
std::uint32_t column_sampling_seed = common::GlobalRandom()();
SafeColl(collective::Broadcast(
ctx_, linalg::MakeVec(&column_sampling_seed, sizeof(column_sampling_seed)), 0));
this->column_sampler_ = std::make_shared<common::ColumnSampler>(column_sampling_seed);
this->column_sampler_ = common::MakeColumnSampler(ctx_);

curt::SetDevice(ctx_->Ordinal());
p_fmat->Info().feature_types.SetDevice(ctx_->Device());
Expand Down Expand Up @@ -968,8 +965,7 @@ class GPUGlobalApproxMaker : public TreeUpdater {

monitor_.Start(__func__);
CHECK(ctx_->IsCUDA()) << error::InvalidCUDAOrdinal();
uint32_t column_sampling_seed = common::GlobalRandom()();
this->column_sampler_ = std::make_shared<common::ColumnSampler>(column_sampling_seed);
this->column_sampler_ = common::MakeColumnSampler(ctx_);

p_last_fmat_ = p_fmat;
initialised_ = true;
Expand Down
22 changes: 22 additions & 0 deletions tests/python/test_updaters.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,28 @@ def test_exact_sample_by_node_error(self) -> None:
num_boost_round=2,
)

@pytest.mark.parametrize("tree_method", ["approx", "hist"])
def test_colsample_rng(self, tree_method: str) -> None:
"""Test rng has an effect on column sampling."""
X, y, _ = tm.make_regression(128, 16, use_cupy=False)
reg0 = xgb.XGBRegressor(
n_estimators=2,
colsample_bynode=0.5,
random_state=42,
tree_method=tree_method,
)
reg0.fit(X, y)

reg1 = xgb.XGBRegressor(
n_estimators=2,
colsample_bynode=0.5,
random_state=43,
tree_method=tree_method,
)
reg1.fit(X, y)

assert list(reg0.feature_importances_) != list(reg1.feature_importances_)

@given(
exact_parameter_strategy,
hist_parameter_strategy,
Expand Down

0 comments on commit a92ed1d

Please sign in to comment.