Skip to content

Commit

Permalink
almost done separating members into survival
Browse files Browse the repository at this point in the history
  • Loading branch information
bcjaeger committed Sep 11, 2023
1 parent 822846e commit 2db644d
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 54 deletions.
7 changes: 1 addition & 6 deletions src/Forest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,12 @@ void Forest::init_trees(){
for(uword i = 0; i < n_tree; ++i){

trees[i]->init(data.get(),
&unique_event_times,
tree_seeds[i],
mtry,
leaf_min_events,
leaf_min_obs,
vi_type,
vi_max_pvalue,
split_rule,
split_min_events,
split_min_obs,
split_min_stat,
split_max_cuts,
Expand Down Expand Up @@ -258,9 +255,7 @@ void Forest::predict_in_threads(uint thread_idx,

trees[i]->predict_leaf(prediction_data, oobag);

trees[i]->predict_value(result_ptr, denom_ptr,
pred_horizon, 'S',
oobag);
trees[i]->predict_value(result_ptr, denom_ptr, 'S', oobag);

// Check for user interrupt
if (aborted) {
Expand Down
7 changes: 4 additions & 3 deletions src/ForestSurvival.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ void ForestSurvival::load(arma::uword n_tree,
Rcout << std::endl << std::endl;
}


// Create trees
trees.reserve(n_tree);

Expand All @@ -55,7 +54,8 @@ void ForestSurvival::load(arma::uword n_tree,
forest_leaf_pred_indx[i],
forest_leaf_pred_prob[i],
forest_leaf_pred_chaz[i],
forest_leaf_summary[i])
forest_leaf_summary[i],
pred_horizon)
);
}

Expand All @@ -72,7 +72,8 @@ void ForestSurvival::plant() {
for (arma::uword i = 0; i < n_tree; ++i) {
trees.push_back(std::make_unique<TreeSurvival>(leaf_min_events,
split_min_events,
&unique_event_times));
&unique_event_times,
pred_horizon));
}

}
Expand Down
1 change: 0 additions & 1 deletion src/ForestSurvival.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ class ForestSurvival: public Forest {
ForestSurvival(const ForestSurvival&) = delete;
ForestSurvival& operator=(const ForestSurvival&) = delete;


void load(arma::uword n_tree,
std::vector<std::vector<double>>& forest_cutpoint,
std::vector<std::vector<arma::uword>>& forest_child_left,
Expand Down
23 changes: 11 additions & 12 deletions src/Tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
mtry(0),
vi_type(VI_NONE),
vi_max_pvalue(DEFAULT_ANOVA_VI_PVALUE),
leaf_min_events(DEFAULT_LEAF_MIN_EVENTS),
// leaf_min_events(DEFAULT_LEAF_MIN_EVENTS),
leaf_min_obs(DEFAULT_LEAF_MIN_OBS),
split_rule(DEFAULT_SPLITRULE),
split_min_events(DEFAULT_SPLIT_MIN_EVENTS),
// split_min_events(DEFAULT_SPLIT_MIN_EVENTS),
split_min_obs(DEFAULT_SPLIT_MIN_OBS),
split_min_stat(DEFAULT_SPLIT_MIN_STAT),
split_max_cuts(DEFAULT_SPLIT_MAX_CUTS),
Expand Down Expand Up @@ -53,10 +53,10 @@
mtry(0),
vi_type(VI_NONE),
vi_max_pvalue(DEFAULT_ANOVA_VI_PVALUE),
leaf_min_events(DEFAULT_LEAF_MIN_EVENTS),
// leaf_min_events(DEFAULT_LEAF_MIN_EVENTS),
leaf_min_obs(DEFAULT_LEAF_MIN_OBS),
split_rule(DEFAULT_SPLITRULE),
split_min_events(DEFAULT_SPLIT_MIN_EVENTS),
// split_min_events(DEFAULT_SPLIT_MIN_EVENTS),
split_min_obs(DEFAULT_SPLIT_MIN_OBS),
split_min_stat(DEFAULT_SPLIT_MIN_STAT),
split_max_cuts(DEFAULT_SPLIT_MAX_CUTS),
Expand All @@ -79,15 +79,14 @@


void Tree::init(Data* data,
arma::vec* unique_event_times,
int seed,
arma::uword mtry,
double leaf_min_events,
// double leaf_min_events,
double leaf_min_obs,
VariableImportance vi_type,
double vi_max_pvalue,
SplitRule split_rule,
double split_min_events,
// double split_min_events,
double split_min_obs,
double split_min_stat,
arma::uword split_max_cuts,
Expand All @@ -105,17 +104,16 @@
random_number_generator.seed(seed);

this->data = data;
this->unique_event_times = unique_event_times;
this->n_cols_total = data->n_cols;
this->n_rows_total = data->n_rows;
this->seed = seed;
this->mtry = mtry;
this->leaf_min_events = leaf_min_events;
// this->leaf_min_events = leaf_min_events;
this->leaf_min_obs = leaf_min_obs;
this->vi_type = vi_type;
this->vi_max_pvalue = vi_max_pvalue;
this->split_rule = split_rule;
this->split_min_events = split_min_events;
// this->split_min_events = split_min_events;
this->split_min_obs = split_min_obs;
this->split_min_stat = split_min_stat;
this->split_max_cuts = split_max_cuts;
Expand Down Expand Up @@ -920,7 +918,6 @@

void Tree::predict_value(arma::mat* pred_output,
arma::vec* pred_denom,
arma::vec& pred_times,
char pred_type,
bool oobag){

Expand Down Expand Up @@ -954,8 +951,10 @@

} while (it < pred_leaf_sort.end());

}


// Base-class placeholder for prediction accuracy.
// Always reports 0.0; the method is declared virtual in Tree.h so that
// derived tree types (e.g. TreeSurvival) can provide their own version.
double Tree::compute_prediction_accuracy(){
 const double placeholder_accuracy = 0.0;
 return placeholder_accuracy;
}


Expand Down
15 changes: 6 additions & 9 deletions src/Tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,14 @@
Tree& operator=(const Tree&) = delete;

void init(Data* data,
arma::vec* unique_event_times,
int seed,
arma::uword mtry,
double leaf_min_events,
// double leaf_min_events,
double leaf_min_obs,
VariableImportance vi_type,
double vi_max_pvalue,
SplitRule split_rule,
double split_min_events,
// double split_min_events,
double split_min_obs,
double split_min_stat,
arma::uword split_max_cuts,
Expand Down Expand Up @@ -86,7 +85,6 @@

virtual void predict_value(arma::mat* pred_output,
arma::vec* pred_denom,
arma::vec& pred_times,
char pred_type,
bool oobag);

Expand Down Expand Up @@ -124,9 +122,6 @@
arma::vec* vi_numer;
arma::uvec* vi_denom;

// pointer to event times in forest
arma::vec* unique_event_times;

// Pointer to original data
Data* data;

Expand Down Expand Up @@ -179,12 +174,12 @@
std::mt19937_64 random_number_generator;

// tree growing members
double leaf_min_events;
// double leaf_min_events;
double leaf_min_obs;

// node split members
SplitRule split_rule;
double split_min_events;
// double split_min_events;
double split_min_obs;
double split_min_stat;
arma::uword split_max_cuts;
Expand Down Expand Up @@ -227,6 +222,8 @@
// leaf values (only in leaf nodes)
std::vector<double> leaf_summary;

virtual double compute_prediction_accuracy();



protected:
Expand Down
77 changes: 66 additions & 11 deletions src/TreeSurvival.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@

TreeSurvival::TreeSurvival(double leaf_min_events,
double split_min_events,
arma::vec* unique_event_times){
arma::vec* unique_event_times,
arma::vec pred_horizon){

this->leaf_min_events = leaf_min_events;
this->split_min_events = split_min_events;
this->unique_event_times = unique_event_times;
this->pred_horizon = pred_horizon;

}

Expand All @@ -33,11 +35,13 @@
std::vector<arma::vec>& leaf_pred_indx,
std::vector<arma::vec>& leaf_pred_prob,
std::vector<arma::vec>& leaf_pred_chaz,
std::vector<double>& leaf_summary) :
std::vector<double>& leaf_summary,
arma::vec pred_horizon) :
Tree(cutpoint, child_left, coef_values, coef_indices, leaf_summary),
leaf_pred_indx(leaf_pred_indx),
leaf_pred_prob(leaf_pred_prob),
leaf_pred_chaz(leaf_pred_chaz){ }
leaf_pred_chaz(leaf_pred_chaz),
pred_horizon(pred_horizon){ }

void TreeSurvival::resize_leaves(arma::uword new_size) {

Expand Down Expand Up @@ -399,7 +403,8 @@

uvec::iterator event;

double total=0, concordant=0;
// protection from case where there are no comparables.
double total=0.001, concordant=0;

for (event = event_indices.begin(); event < event_indices.end(); ++event) {

Expand All @@ -409,7 +414,7 @@

total += w_node[j];

if (g_node[j] > g_node[*event]){
if (g_node[j] < g_node[*event]){

concordant += w_node[j];

Expand All @@ -425,8 +430,6 @@

}

Rcout << "concordance is: " << concordant / total << std::endl;

return(concordant / total);

}
Expand Down Expand Up @@ -558,7 +561,6 @@

void TreeSurvival::predict_value(arma::mat* pred_output,
arma::vec* pred_denom,
arma::vec& pred_times,
char pred_type,
bool oobag){

Expand Down Expand Up @@ -592,7 +594,7 @@

vec leaf_times, leaf_values;

vec temp_vec(pred_times.size());
vec temp_vec(pred_horizon.size());
double temp_dbl;

do {
Expand All @@ -614,10 +616,10 @@
// (wasteful b/c leaf_times ascend)
i = 0;

for(j = 0; j < pred_times.size(); j++){
for(j = 0; j < pred_horizon.size(); j++){

// t is the current prediction time
double t = pred_times[j];
double t = pred_horizon[j];

// if t < t', where t' is the max time in this leaf,
// then we may find a time t* such that t* < t < t'.
Expand Down Expand Up @@ -683,5 +685,58 @@

}

// Out-of-bag prediction accuracy for a survival tree.
//
// Work in progress: the concordance computation is not wired up yet, so
// this function currently always returns 0.0. What it does so far is gather
// the inputs that computation will need: the first two columns of y_oobag
// and, for every out-of-bag row, the leaf_summary value of its leaf.
//
// Returns: 0.0 (placeholder until the concordance below is implemented).
double TreeSurvival::compute_prediction_accuracy(){

 // Columns 0 and 1 of the out-of-bag outcome matrix, named time and status
 // here. Not consumed yet; the concordance sketch below reads both.
 vec y_time = y_oobag.unsafe_col(0);
 vec y_status = y_oobag.unsafe_col(1);

 // Leaf index for each out-of-bag row
 // (presumably filled in by predict_leaf -- TODO confirm).
 uvec oobag_pred_leaf = pred_leaf(rows_oobag);

 // mortality[i] is the leaf summary value of the leaf that row i fell into.
 vec mortality(rows_oobag.size());

 for(uword i = 0; i < mortality.size(); ++i){
  mortality[i] = leaf_summary[oobag_pred_leaf[i]];
 }

 // Deliberately no printing of `mortality` here: dumping the full vector on
 // every call floods the console and is not part of this function's result.

 // TODO: finish the concordance computation. The sketch below was copied
 // from the node-level version and still refers to node-level names
 // (y_node, w_node, g_node); these need out-of-bag equivalents, with
 // `mortality` presumably taking the place of g_node -- TODO confirm.
 //
 // uvec event_indices = find(y_status == 1);
 //
 // uvec::iterator event;
 //
 // // protection from case where there are no comparables.
 // double total=0.001, concordant=0;
 //
 // for (event = event_indices.begin(); event < event_indices.end(); ++event) {
 //
 //  for(uword j = *event; j < y_node.n_rows; ++j){
 //
 //   if (y_time[j] > y_time[*event]) { // ties not counted
 //
 //    total += w_node[j];
 //
 //    if (g_node[j] < g_node[*event]){
 //
 //     concordant += w_node[j];
 //
 //    } else if (g_node[j] == g_node[*event]){
 //
 //     concordant += (w_node[j] / 2);
 //
 //    }
 //
 //   }
 //
 //  }
 //
 // }
 //
 // return(concordant / total);

 return(0.0);

}


} // namespace aorsf

Loading

0 comments on commit 2db644d

Please sign in to comment.