Skip to content

Commit

Permalink
prep to use arma field for leaf data
Browse files Browse the repository at this point in the history
  • Loading branch information
bcjaeger committed Sep 6, 2023
1 parent b372d5b commit 27f96ac
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 74 deletions.
86 changes: 14 additions & 72 deletions src/Tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,20 +211,15 @@

for (i = rows_node.begin(); i != rows_node.end(); ++i) {

// if event occurred for this observation
if(y_inbag.at(*i, 1) == 1){
if(x_first_undef){

if(x_first_undef){
x_first_value = x_inbag.at(*i, j);
x_first_undef = false;

x_first_value = x_inbag.at(*i, j);
x_first_undef = false;

} else {

if(x_inbag.at(*i, j) != x_first_value){
return(true);
}
} else {

if(x_inbag.at(*i, j) != x_first_value){
return(true);
}

}
Expand All @@ -234,11 +229,6 @@
if(VERBOSITY > 1){

mat x_print = x_inbag.rows(rows_node);
mat y_print = y_inbag.rows(rows_node);

uvec rows_event = find(y_print.col(1) == 1);
x_print = x_print.rows(rows_event);

Rcout << "Column " << j << " was sampled but ";
Rcout << "unique values of column " << j << " are ";
Rcout << unique(x_print.col(j)) << std::endl;
Expand Down Expand Up @@ -426,65 +416,16 @@

}

double Tree::score_logrank(){

double
n_risk=0,
g_risk=0,
observed=0,
expected=0,
V=0,
temp1,
temp2,
n_events;

vec y_time = y_node.unsafe_col(0);
vec y_status = y_node.unsafe_col(1);

bool break_loop = false;

uword i = y_node.n_rows-1;

// breaking condition of outer loop governed by inner loop
for (; ;){

temp1 = y_time[i];
double Tree::compute_split_score(){

n_events = 0;
// default method is to pick one completely at random
// (this won't stay the default - it's a placeholder)

for ( ; y_time[i] == temp1; i--) {
std::normal_distribution<double> ndist_score(0, 1);

n_risk += w_node[i];
n_events += y_status[i] * w_node[i];
g_risk += g_node[i] * w_node[i];
observed += y_status[i] * g_node[i] * w_node[i];
double result = ndist_score(random_number_generator);

if(i == 0){
break_loop = true;
break;
}

}

// should only do these calculations if n_events > 0,
// but in practice its often faster to multiply by 0
// versus check if n_events is > 0.

temp2 = g_risk / n_risk;
expected += n_events * temp2;

// update variance if n_risk > 1 (if n_risk == 1, variance is 0)
// definitely check if n_risk is > 1 b/c otherwise divide by 0
if (n_risk > 1){
temp1 = n_events * temp2 * (n_risk-n_events) / (n_risk-1);
V += temp1 * (1 - temp2);
}

if(break_loop) break;

}

return(pow(expected-observed, 2) / V);
return(result);

}

Expand Down Expand Up @@ -561,7 +502,8 @@
// flip node assignments from left to right, up to the next cutpoint
g_node.elem(lincomb_sort.subvec(it_start, *it)).fill(0);
// compute split statistics with this cut-point
stat = score_logrank();
stat = compute_split_score();
// stat = score_logrank();
// update leaderboard
if(stat > stat_best) { stat_best = stat; it_best = *it; }
// set up next loop run
Expand Down
4 changes: 2 additions & 2 deletions src/Tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,15 @@

void sample_cols();

bool is_col_splittable(arma::uword j);
virtual bool is_col_splittable(arma::uword j);

bool is_node_splittable(arma::uword node_id);

virtual bool is_node_splittable_internal();

virtual arma::uvec find_cutpoints();

double score_logrank();
virtual double compute_split_score();

double node_split(arma::uvec& cuts_all);

Expand Down
130 changes: 130 additions & 0 deletions src/TreeSurvival.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,58 @@

}

bool TreeSurvival::is_col_splittable(uword j){

  // Decide whether column j of x_inbag can be used to split the
  // current node. Only rows listed in rows_node whose status
  // (column 1 of y_inbag) equals 1 are considered: the column is
  // splittable as soon as two of those rows hold different values.
  //
  // j: index of the column of x_inbag to inspect.
  // returns true if column j is non-constant among the event rows
  // of this node, false otherwise.

  // placeholder only; it is never compared against until the first
  // event row has overwritten it
  double x_reference = 0;
  bool reference_is_set = false;

  for (const uword row : rows_node) {

    // rows without an event carry no information for this check
    if(y_inbag.at(row, 1) != 1) continue;

    if(!reference_is_set){
      // first event row: remember its value as the reference
      x_reference = x_inbag.at(row, j);
      reference_is_set = true;
      continue;
    }

    // any later event row that differs makes the column splittable
    if(x_inbag.at(row, j) != x_reference) return(true);

  }

  // reaching this point means zero or one distinct value was seen
  // among the event rows; optionally report what was found
  if(VERBOSITY > 1){

    mat y_print = y_inbag.rows(rows_node);
    uvec rows_event = find(y_print.col(1) == 1);

    mat x_print = x_inbag.rows(rows_node);
    x_print = x_print.rows(rows_event);

    Rcout << "Column " << j << " was sampled but ";
    Rcout << "unique values of column " << j << " are ";
    Rcout << unique(x_print.col(j)) << std::endl;

  }

  return(false);

}

bool TreeSurvival::is_node_splittable_internal(){

double n_risk = sum(w_node);
Expand Down Expand Up @@ -219,6 +271,84 @@

}

double TreeSurvival::compute_split_score(){

  // Compute the score of the current candidate split by dispatching
  // on split_rule.
  //
  // returns the value of the scoring function selected by split_rule,
  // or 0 when split_rule matches none of the cases handled below.

  // Must be initialized: without an initializer, a split_rule value
  // that matches no case would make this function return an
  // indeterminate value (undefined behavior).
  double result = 0;

  switch (split_rule) {

  case SPLIT_LOGRANK: {
    result = score_logrank();
    break;
  }

  }

  return(result);

}

double TreeSurvival::score_logrank(){

  // Compute (expected - observed)^2 / V for the current node, using
  // the times in column 0 and the statuses in column 1 of y_node,
  // the weights in w_node and the group indicators in g_node.
  //
  // Rows are visited from the last row of y_node back to row 0, and
  // consecutive rows sharing the same time are processed together.
  // NOTE(review): presumably y_node is sorted by time so that tied
  // times are adjacent - confirm against the caller.
  // NOTE(review): y_node.n_rows - 1 is unsigned, so this assumes the
  // node holds at least one row.
  // NOTE(review): if V is still 0 at the end, the division below
  // yields a non-finite value (NaN or inf).

  // running totals; n_risk and g_risk keep accumulating across times
  double
    n_risk=0,
    g_risk=0,
    observed=0,
    expected=0,
    V=0;

  vec y_time = y_node.unsafe_col(0);
  vec y_status = y_node.unsafe_col(1);

  uword i = y_node.n_rows-1;

  // set by the inner loop once row 0 has been processed
  bool reached_first_row = false;

  while (!reached_first_row){

    const double time_current = y_time[i];

    double n_events = 0;

    // absorb every row tied at time_current
    while (y_time[i] == time_current) {

      n_risk += w_node[i];
      n_events += y_status[i] * w_node[i];
      g_risk += g_node[i] * w_node[i];
      observed += y_status[i] * g_node[i] * w_node[i];

      // stop before decrementing: i is unsigned and must not wrap
      if(i == 0){
        reached_first_row = true;
        break;
      }

      i--;

    }

    // these updates only matter when n_events > 0, but in practice
    // it's often faster to multiply by 0 than to check n_events.

    const double g_fraction = g_risk / n_risk;
    expected += n_events * g_fraction;

    // the variance term is only added when n_risk > 1; this check is
    // required because n_risk == 1 would divide by 0
    if (n_risk > 1){
      const double v_term =
        n_events * g_fraction * (n_risk-n_events) / (n_risk-1);
      V += v_term * (1 - g_fraction);
    }

  }

  return(pow(expected-observed, 2) / V);

}


} // namespace aorsf
Expand Down
6 changes: 6 additions & 0 deletions src/TreeSurvival.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,16 @@

double compute_max_leaves() override;

bool is_col_splittable(arma::uword j) override;

bool is_node_splittable_internal() override;

arma::uvec find_cutpoints() override;

double compute_split_score() override;

double score_logrank();

};

} // namespace aorsf
Expand Down

0 comments on commit 27f96ac

Please sign in to comment.