From 9d680835fcd253c411187632346af2064453d444 Mon Sep 17 00:00:00 2001 From: byron jaeger Date: Tue, 16 Jan 2024 09:11:32 -0500 Subject: [PATCH] trying to fix cran error --- src/orsf_oop.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/orsf_oop.cpp b/src/orsf_oop.cpp index 1402ae91..649ad1e6 100644 --- a/src/orsf_oop.cpp +++ b/src/orsf_oop.cpp @@ -537,7 +537,7 @@ double compute_mse_exported(arma::vec& y, // Load forest object if it was already grown if(!grow_mode){ - uword n_obs = loaded_forest["n_obs"]; + n_obs = loaded_forest["n_obs"]; std::vector rows_oobag = loaded_forest["rows_oobag"]; std::vector> cutpoint = loaded_forest["cutpoint"]; @@ -608,7 +608,7 @@ double compute_mse_exported(arma::vec& y, List forest_out; forest_out.push_back(n_obs, "n_obs"); - forest_out.push_back(forest->get_oobag_denom(), "oobag_denom"); + // forest_out.push_back(forest->get_oobag_denom(), "oobag_denom"); forest_out.push_back(forest->get_rows_oobag(), "rows_oobag"); forest_out.push_back(forest->get_cutpoint(), "cutpoint"); forest_out.push_back(forest->get_child_left(), "child_left"); @@ -644,7 +644,13 @@ double compute_mse_exported(arma::vec& y, vec vi_output; if(run_forest){ if(vi_type == VI_ANOVA){ - vi_output = forest->get_vi_numer() / forest->get_vi_denom(); + + uvec denom = forest->get_vi_denom(); + uvec zeros = find(denom == 0); + if(zeros.size() > 0) denom(zeros).fill(1); + + vi_output = forest->get_vi_numer() / denom; + } else { vi_output = forest->get_vi_numer() / n_tree; }