35 #ifndef VIGRA_RF_ALGORITHM_HXX 
   36 #define VIGRA_RF_ALGORITHM_HXX 
   38 #include "splices.hxx" 
   58     template<
class OrigMultiArray,
 
   61     void choose(OrigMultiArray     
const & in,
 
   70         for(Iter iter = b; iter != e; ++iter, ++ii)
 
  100     template<
class Feature_t, 
class Response_t>
 
  102                        Response_t 
const & response)
 
  125     typedef std::vector<int> FeatureList_t;
 
  126     typedef std::vector<double> ErrorList_t;
 
  127     typedef FeatureList_t::iterator Pivot_t;
 
  153     template<
class FeatureT, 
 
  156              class ErrorRateCallBack>
 
  157     bool init(FeatureT 
const & all_features,
 
  158               ResponseT 
const & response,
 
  161               ErrorRateCallBack errorcallback)
 
  163         bool ret_ = init(all_features, response, errorcallback); 
 
  166         vigra_precondition(std::distance(b, e) == static_cast<std::ptrdiff_t>(
selected.size()),
 
  167                            "Number of features in ranking != number of features matrix");
 
  172     template<
class FeatureT, 
 
  175     bool init(FeatureT 
const & all_features,
 
  176               ResponseT 
const & response,
 
  181         return init(all_features, response, b, e, ecallback);
 
  185     template<
class FeatureT, 
 
  187     bool init(FeatureT 
const & all_features,
 
  188               ResponseT 
const & response)
 
  190         return init(all_features, response, RFErrorCallback());
 
  202     template<
class FeatureT, 
 
  204              class ErrorRateCallBack>
 
  205     bool init(FeatureT 
const & all_features,
 
  206               ResponseT 
const & response,
 
  207               ErrorRateCallBack errorcallback)
 
  215         selected.resize(all_features.shape(1), 0);
 
  216         for(
unsigned int ii = 0; ii < 
selected.size(); ++ii)
 
  218         errors.resize(all_features.shape(1), -1);
 
  219         errors.back() = errorcallback(all_features, response);
 
  223         std::map<typename ResponseT::value_type, int>     res_map;
 
  224         std::vector<int>                                 cts;
 
  226         for(
int ii = 0; ii < response.shape(0); ++ii)
 
  228             if(res_map.find(response(ii, 0)) == res_map.end())
 
  230                 res_map[response(ii, 0)] = counter;
 
  234             cts[res_map[response(ii,0)]] +=1;
 
  236         no_features = double(*(std::max_element(cts.begin(),
 
  238                     / 
double(response.shape(0));
 
  293 template<
class FeatureT, 
class ResponseT, 
class ErrorRateCallBack>
 
  295                        ResponseT          
const & response,
 
  297                        ErrorRateCallBack          errorcallback)
 
  299     VariableSelectionResult::FeatureList_t & selected         = result.
selected;
 
  300     VariableSelectionResult::ErrorList_t &     errors            = result.
errors;
 
  301     VariableSelectionResult::Pivot_t       & pivot            = result.pivot;    
 
  302     int featureCount = features.shape(1);
 
  304     if(!result.init(features, response, errorcallback))
 
  308         vigra_precondition(static_cast<int>(selected.size()) == featureCount,
 
  309                            "forward_selection(): Number of features in Feature " 
  310                            "matrix and number of features in previously used " 
  311                            "result struct mismatch!");
 
  315     int not_selected_size = std::distance(pivot, selected.end());
 
  316     while(not_selected_size > 1)
 
  318         std::vector<double> current_errors;
 
  319         VariableSelectionResult::Pivot_t next = pivot;
 
  320         for(
int ii = 0; ii < not_selected_size; ++ii, ++next)
 
  322             std::swap(*pivot, *next);
 
  324             detail::choose( features, 
 
  328             double error = errorcallback(cur_feats, response);
 
  329             current_errors.push_back(error);
 
  330             std::swap(*pivot, *next);
 
  332         int pos = std::distance(current_errors.begin(),
 
  333                                 std::min_element(current_errors.begin(),
 
  334                                                    current_errors.end()));
 
  336         std::advance(next, pos);
 
  337         std::swap(*pivot, *next);
 
  338         errors[std::distance(selected.begin(), pivot)] = current_errors[pos];
 
  340             std::copy(current_errors.begin(), current_errors.end(), std::ostream_iterator<double>(std::cerr, 
", "));
 
  341             std::cerr << 
"Choosing " << *pivot << 
" at error of " <<  current_errors[pos] << std::endl;
 
  344         not_selected_size = std::distance(pivot, selected.end());
 
  347 template<
class FeatureT, 
class ResponseT>
 
  349                        ResponseT          
const & response,
 
  350                        VariableSelectionResult & result)
 
  395 template<
class FeatureT, 
class ResponseT, 
class ErrorRateCallBack>
 
  397                              ResponseT         
const & response,
 
  399                           ErrorRateCallBack         errorcallback)
 
  401     int featureCount = features.shape(1);
 
  402     VariableSelectionResult::FeatureList_t & selected         = result.
selected;
 
  403     VariableSelectionResult::ErrorList_t &     errors            = result.
errors;
 
  404     VariableSelectionResult::Pivot_t       & pivot            = result.pivot;    
 
  407     if(!result.init(features, response, errorcallback))
 
  411         vigra_precondition(static_cast<int>(selected.size()) == featureCount,
 
  412                            "backward_elimination(): Number of features in Feature " 
  413                            "matrix and number of features in previously used " 
  414                            "result struct mismatch!");
 
  416     pivot = selected.end() - 1;    
 
  418     int selected_size = std::distance(selected.begin(), pivot);
 
  419     while(selected_size > 1)
 
  421         VariableSelectionResult::Pivot_t next = selected.begin();
 
  422         std::vector<double> current_errors;
 
  423         for(
int ii = 0; ii < selected_size; ++ii, ++next)
 
  425             std::swap(*pivot, *next);
 
  427             detail::choose( features, 
 
  431             double error = errorcallback(cur_feats, response);
 
  432             current_errors.push_back(error);
 
  433             std::swap(*pivot, *next);
 
  435         int pos = std::distance(current_errors.begin(),
 
  436                                 std::min_element(current_errors.begin(),
 
  437                                                    current_errors.end()));
 
  438         next = selected.begin();
 
  439         std::advance(next, pos);
 
  440         std::swap(*pivot, *next);
 
  442         errors[std::distance(selected.begin(), pivot)-1] = current_errors[pos];
 
  443         selected_size = std::distance(selected.begin(), pivot);
 
  445             std::copy(current_errors.begin(), current_errors.end(), std::ostream_iterator<double>(std::cerr, 
", "));
 
  446             std::cerr << 
"Eliminating " << *pivot << 
" at error of " << current_errors[pos] << std::endl;
 
  452 template<
class FeatureT, 
class ResponseT>
 
  454                              ResponseT         
const & response,
 
  455                           VariableSelectionResult & result)
 
  492 template<
class FeatureT, 
class ResponseT, 
class ErrorRateCallBack>
 
  494                              ResponseT         
const & response,
 
  496                           ErrorRateCallBack         errorcallback)
 
  498     VariableSelectionResult::FeatureList_t & selected         = result.
selected;
 
  499     VariableSelectionResult::ErrorList_t &     errors            = result.
errors;
 
  500     VariableSelectionResult::Pivot_t       & iter            = result.pivot;
 
  501     int featureCount = features.shape(1);
 
  503     if(!result.init(features, response, errorcallback))
 
  507         vigra_precondition(static_cast<int>(selected.size()) == featureCount,
 
  508                            "forward_selection(): Number of features in Feature " 
  509                            "matrix and number of features in previously used " 
  510                            "result struct mismatch!");
 
  514     for(; iter != selected.end(); ++iter)
 
  518         detail::choose( features, 
 
  522         double error = errorcallback(cur_feats, response);
 
  523         errors[std::distance(selected.begin(), iter)] = error;
 
  525             std::copy(selected.begin(), iter+1, std::ostream_iterator<int>(std::cerr, 
", "));
 
  526             std::cerr << 
"Choosing " << *(iter+1) << 
" at error of " <<  error << std::endl;
 
  532 template<
class FeatureT, 
class ResponseT>
 
  534                              ResponseT         
const & response,
 
  535                           VariableSelectionResult & result)
 
  542 enum ClusterLeafTypes{c_Leaf = 95, c_Node = 99};
 
  557     ClusterNode():NodeBase(){}
 
  558     ClusterNode(    
int                      nCol,
 
  559                     BT::T_Container_type    &   topology,
 
  560                     BT::P_Container_type    &   split_param)
 
  561                 :   BT(nCol + 5, 5,topology, split_param)
 
  571     ClusterNode(           BT::T_Container_type  
const  &   topology,
 
  572                     BT::P_Container_type  
const  &   split_param,
 
  574                 :   
NodeBase(5 , 5,topology, split_param, n)
 
  580     ClusterNode( BT & node_)
 
  585         BT::parameter_size_ += 0;
 
  591     void set_index(
int in)
 
  617     HC_Entry(
int p, 
int l, 
int a, 
bool in)
 
  618         : parent(p), level(l), addr(a), infm(in)
 
  647     double dist_func(
double a, 
double b)
 
  649         return std::min(a, b); 
 
  655     template<
class Functor>
 
  659         std::vector<int> stack; 
 
  660         stack.push_back(begin_addr); 
 
  661         while(!stack.empty())
 
  663             ClusterNode node(topology_, parameters_, stack.
back());
 
  667                 if(node.columns_size() != 1)
 
  669                     stack.push_back(node.child(0));
 
  670                     stack.push_back(node.child(1));
 
  678     template<
class Functor>
 
  682         std::queue<HC_Entry> queue; 
 
  687         queue.push(
HC_Entry(parent,level,begin_addr, infm)); 
 
  688         while(!queue.empty())
 
  690             level  = queue.front().level;
 
  691             parent = queue.front().parent;
 
  692             addr   = queue.front().addr;
 
  693             infm   = queue.front().infm;
 
  694             ClusterNode node(topology_, parameters_, queue.
front().addr);
 
  698                 parnt = ClusterNode(topology_, parameters_, parent); 
 
  701             bool istrue = tester(node, level, parnt, infm);
 
  702             if(node.columns_size() != 1)
 
  704                 queue.push(
HC_Entry(addr, level +1,node.child(0),istrue));
 
  705                 queue.push(
HC_Entry(addr, level +1,node.child(1),istrue));
 
  712     void save(std::string file, std::string prefix)
 
  717                                     Shp(topology_.
size(),1),
 
  721                                     Shp(parameters_.
size(), 1),
 
  722                                     parameters_.
data()));
 
  732     template<
class T, 
class C>
 
  736         std::vector<std::pair<int, int> > addr; 
 
  738         for(
int ii = 0; ii < distance.
shape(0); ++ii)
 
  740             addr.push_back(std::make_pair(topology_.
size(), ii));
 
  741             ClusterNode leaf(1, topology_, parameters_);
 
  742             leaf.set_index(index);
 
  744             leaf.columns_begin()[0] = ii;
 
  747         while(addr.size() != 1)
 
  752             double min_dist = dist((addr.begin()+ii_min)->second, 
 
  753                               (addr.begin()+jj_min)->second);
 
  754             for(
unsigned int ii = 0; ii < addr.size(); ++ii)
 
  756                 for(
unsigned int jj = ii+1; jj < addr.size(); ++jj)
 
  758                     if(  dist((addr.begin()+ii_min)->second, 
 
  759                               (addr.begin()+jj_min)->second)
 
  760                        > dist((addr.begin()+ii)->second, 
 
  761                               (addr.begin()+jj)->second))
 
  763                         min_dist = dist((addr.begin()+ii)->second, 
 
  764                               (addr.begin()+jj)->second);
 
  776                 ClusterNode firstChild(topology_, 
 
  778                                        (addr.begin() +ii_min)->first);
 
  779                 ClusterNode secondChild(topology_, 
 
  781                                        (addr.begin() +jj_min)->first);
 
  782                 col_size = firstChild.columns_size() + secondChild.columns_size();
 
  784             int cur_addr = topology_.
size();
 
  785             begin_addr = cur_addr;
 
  787             ClusterNode parent(col_size,
 
  790             ClusterNode firstChild(topology_, 
 
  792                                    (addr.begin() +ii_min)->first);
 
  793             ClusterNode secondChild(topology_, 
 
  795                                    (addr.begin() +jj_min)->first);
 
  796             parent.parameters_begin()[0] = min_dist;
 
  797             parent.set_index(index);
 
  799             std::merge(firstChild.columns_begin(), firstChild.columns_end(),
 
  800                        secondChild.columns_begin(),secondChild.columns_end(),
 
  801                        parent.columns_begin());
 
  805             if(*parent.columns_begin() ==  *firstChild.columns_begin())
 
  807                 parent.child(0) = (addr.begin()+ii_min)->first;
 
  808                 parent.child(1) = (addr.begin()+jj_min)->first;
 
  809                 (addr.begin()+ii_min)->first = cur_addr;
 
  811                 to_desc = (addr.begin()+jj_min)->second;
 
  812                 addr.erase(addr.begin()+jj_min);
 
  816                 parent.child(1) = (addr.begin()+ii_min)->first;
 
  817                 parent.child(0) = (addr.begin()+jj_min)->first;
 
  818                 (addr.begin()+jj_min)->first = cur_addr;
 
  820                 to_desc = (addr.begin()+ii_min)->second;
 
  821                 addr.erase(addr.begin()+ii_min);
 
  825             for(
int jj = 0 ; jj < static_cast<int>(addr.size()); ++jj)
 
  829                 double bla = dist_func(
 
  830                                   dist(to_desc, (addr.begin()+jj)->second),
 
  831                                   dist((addr.begin()+ii_keep)->second,
 
  832                                         (addr.begin()+jj)->second));
 
  834                 dist((addr.begin()+ii_keep)->second,
 
  835                      (addr.begin()+jj)->second) = bla;
 
  836                 dist((addr.begin()+jj)->second,
 
  837                      (addr.begin()+ii_keep)->second) = bla;
 
  858     bool operator()(Node& node)
 
  871 template<
class Iter, 
class DT>
 
  876     Matrix<double> tmp_mem_;
 
  879     Matrix<double> feats_;
 
  886     template<
class Feat_T, 
class Label_T>
 
  889                    Feat_T 
const & feats,
 
  890                    Label_T 
const & labls, 
 
  895         :tmp_mem_(_spl(a, b).size(), feats.shape(1)),
 
  898          feats_(_spl(a,b).size(), feats.shape(1)),
 
  899          labels_(_spl(a,b).size(),1),
 
  905         copy_splice(_spl(a,b),
 
  906                     _spl(feats.shape(1)),
 
  909         copy_splice(_spl(a,b),
 
  910                     _spl(labls.shape(1)),
 
  916     bool operator()(Node& node)
 
  920         int class_count = perm_imp.
shape(1) - 1;
 
  922         for(
int kk = 0; kk < nPerm; ++kk)
 
  925             for(
int ii = 0; ii < 
rowCount(feats_); ++ii)
 
  928                 for(
int jj = 0; jj < node.columns_size(); ++jj)
 
  930                     if(node.columns_begin()[jj] != feats_.shape(1))
 
  931                         tmp_mem_(ii, node.columns_begin()[jj]) 
 
  932                             = tmp_mem_(index, node.columns_begin()[jj]);
 
  936             for(
int ii = 0; ii < 
rowCount(tmp_mem_); ++ii)
 
  943                     ++perm_imp(index,labels_(ii, 0));
 
  945                     ++perm_imp(index, class_count);
 
  949         double node_status  = perm_imp(index, class_count);
 
  950         node_status /= nPerm;
 
  951         node_status -= orig_imp(0, class_count);
 
  953         node_status /= oob_size;
 
  954         node.status() += node_status;
 
  975     void save(std::string file, std::string prefix)
 
  983     bool operator()(Node& node)
 
  985         for(
int ii = 0; ii < node.columns_size(); ++ii)
 
  986             variables(index, ii) = node.columns_begin()[ii];
 
 1000     bool operator()(Nde & cur, 
int , Nde parent, 
bool )
 
 1003             cur.status() = std::min(parent.status(), cur.status());
 
 1030     std::ofstream graphviz;
 
 1035          std::string 
const  gz)
 
 1036         :features_(features), labels_(labels), 
 
 1037         graphviz(gz.c_str(), std::ios::out)
 
 1039         graphviz << 
"digraph G\n{\n node [shape=\"record\"]";
 
 1043         graphviz << 
"\n}\n";
 
 1048     bool operator()(Nde & cur, 
int , Nde parent, 
bool )
 
 1050         graphviz << 
"node" << cur.index() << 
" [style=\"filled\"][label = \" #Feats: "<< cur.columns_size() << 
"\\n";
 
 1051         graphviz << 
" status: " << cur.status() << 
"\\n";
 
 1052         for(
int kk = 0; kk < cur.columns_size(); ++kk)
 
 1054                 graphviz  << cur.columns_begin()[kk] << 
" ";
 
 1058         graphviz << 
"\"] [color = \"" <<cur.status() << 
" 1.000 1.000\"];\n";
 
 1060         graphviz << 
"\"node" << parent.index() << 
"\" -> \"node" << cur.index() <<
"\";\n";
 
 1080     int                         repetition_count_;
 
 1086     void save(std::string filename, std::string prefix)
 
 1088         std::string prefix1 = 
"cluster_importance_" + prefix;
 
 1092         prefix1 = 
"vars_" + prefix;
 
 1100     :   repetition_count_(rep_cnt), clustering(clst)
 
 1106     template<
class RF, 
class PR>
 
 1109         Int32 const  class_count = rf.ext_param_.class_count_;
 
 1110         Int32 const  column_count = rf.ext_param_.column_count_+1;
 
 1131     template<
class RF, 
class PR, 
class SM, 
class ST>
 
 1135         Int32                   column_count = rf.ext_param_.column_count_ +1;
 
 1136         Int32                   class_count  = rf.ext_param_.class_count_;  
 
 1140         typename PR::Feature_t & features 
 
 1141             = 
const_cast<typename PR::Feature_t &
>(pr.features());
 
 1148         if(rf.ext_param_.actual_msample_ < pr.features().shape(0)- 10000)
 
 1152             for(
int ii = 0; ii < pr.features().shape(0); ++ii)
 
 1153                indices.push_back(ii); 
 
 1154             std::random_shuffle(indices.begin(), indices.end());
 
 1155             for(
int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
 
 1157                 if(!sm.is_used()[indices[ii]] && cts[pr.response()(indices[ii], 0)] < 3000)
 
 1159                     oob_indices.push_back(indices[ii]);
 
 1160                     ++cts[pr.response()(indices[ii], 0)];
 
 1166             for(
int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
 
 1167                 if(!sm.is_used()[ii])
 
 1168                     oob_indices.push_back(ii);
 
 1178                     oob_right(Shp_t(1, class_count + 1)); 
 
 1181         for(iter = oob_indices.
begin(); 
 
 1182             iter != oob_indices.
end(); 
 
 1186                     .predictLabel(
rowVector(features, *iter)) 
 
 1187                 ==  pr.response()(*iter, 0))
 
 1190                 ++oob_right[pr.response()(*iter,0)];
 
 1192                 ++oob_right[class_count];
 
 1197                     perm_oob_right (Shp_t(2* column_count-1, class_count + 1)); 
 
 1200             pc(oob_indices.
begin(), oob_indices.
end(), 
 
 1209         perm_oob_right  /=  repetition_count_;
 
 1210         for(
int ii = 0; ii < 
rowCount(perm_oob_right); ++ii)
 
 1211             rowVector(perm_oob_right, ii) -= oob_right;
 
 1213         perm_oob_right       *= -1;
 
 1214         perm_oob_right       /= oob_indices.
size();
 
 1223     template<
class RF, 
class PR, 
class SM, 
class ST>
 
 1231     template<
class RF, 
class PR>
 
 1271 template<
class FeatureT, 
class ResponseT>
 
 1273                                          ResponseT         
const &     response,
 
 1280         if(features.shape(0) > 40000)
 
 1287         RF.
learn(features, response,
 
 1288                  create_visitor(missc, progress));
 
 1303                   create_visitor(progress, ci));
 
 1316 template<
class FeatureT, 
class ResponseT>
 
 1318                                          ResponseT         
const &     response,
 
 1319                                     HClustering               & linkage)
 
 1326 template<
class Array1, 
class Vector1>
 
 1327 void get_ranking(Array1 
const & in, Vector1 & out)
 
 1329     std::map<double, int> mymap;
 
 1330     for(
int ii = 0; ii < in.size(); ++ii)
 
 1332     for(std::map<double, int>::reverse_iterator iter = mymap.rbegin(); iter!= mymap.rend(); ++iter)
 
 1334         out.push_back(iter->second);
 
 1340 #endif //VIGRA_RF_ALGORITHM_HXX 
UInt32 uniformInt() const 
Definition: random.hxx:464
double no_features
Definition: rf_algorithm.hxx:151
reference back()
Definition: array_vector.hxx:321
MultiArrayView< 2, T, C > columnVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition: matrix.hxx:727
MultiArray< 2, double > cluster_stdev_
Definition: rf_algorithm.hxx:1079
MultiArrayIndex rowCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:671
Topology_type column_data() const 
Definition: rf_nodeproxy.hxx:159
MultiArray< 2, double > cluster_importance_
Definition: rf_algorithm.hxx:1076
RandomForestOptions & samples_per_tree(double in)
specify the fraction of the total number of samples used per tree for learning. 
Definition: rf_common.hxx:411
const difference_type & shape() const 
Definition: multi_array.hxx:1648
Definition: rf_algorithm.hxx:1067
void visit_at_end(RF &rf, PR &)
Definition: rf_algorithm.hxx:1232
const_iterator begin() const 
Definition: array_vector.hxx:223
void visit_at_beginning(RF const &rf, PR const &)
Definition: rf_algorithm.hxx:1107
NodeBase()
Definition: rf_nodeproxy.hxx:237
Definition: rf_algorithm.hxx:847
NormalizeStatus(double m)
Definition: rf_algorithm.hxx:854
Definition: rf_algorithm.hxx:996
void reshape(const difference_type &shape)
Definition: multi_array.hxx:2861
Definition: rf_visitors.hxx:863
void forward_selection(FeatureT const &features, ResponseT const &response, VariableSelectionResult &result, ErrorRateCallBack errorcallback)
Definition: rf_algorithm.hxx:294
Definition: rf_visitors.hxx:1495
void backward_elimination(FeatureT const &features, ResponseT const &response, VariableSelectionResult &result, ErrorRateCallBack errorcallback)
Definition: rf_algorithm.hxx:396
Definition: rf_algorithm.hxx:611
Definition: rf_algorithm.hxx:872
Definition: rf_algorithm.hxx:83
difference_type_1 size() const 
Definition: multi_array.hxx:1641
MultiArray< 2, double > distance
Definition: rf_visitors.hxx:1524
Definition: multi_fwd.hxx:63
Random forest version 2 (see also vigra::rf3::RandomForest for version 3) 
Definition: random_forest.hxx:147
reference front()
Definition: array_vector.hxx:307
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int 
Definition: sized_int.hxx:175
bool init(FeatureT const &all_features, ResponseT const &response, ErrorRateCallBack errorcallback)
Definition: rf_algorithm.hxx:205
void breadth_first_traversal(Functor &tester)
Definition: rf_algorithm.hxx:679
Definition: rf_algorithm.hxx:638
void cluster_permutation_importance(FeatureT const &features, ResponseT const &response, HClustering &linkage, MultiArray< 2, double > &distance)
Definition: rf_algorithm.hxx:1272
Definition: rf_algorithm.hxx:963
Parameter_type parameters_begin() const 
Definition: rf_nodeproxy.hxx:207
Definition: metaprogramming.hxx:123
double oob_breiman
Definition: rf_visitors.hxx:874
ErrorList_t errors
Definition: rf_algorithm.hxx:146
void writeHDF5(...)
Store array data in an HDF5 file. 
Definition: rf_visitors.hxx:1460
INT & typeID()
Definition: rf_nodeproxy.hxx:136
void cluster(MultiArrayView< 2, T, C > distance)
Definition: rf_algorithm.hxx:733
Definition: rf_algorithm.hxx:1024
MultiArray< 2, int > variables
Definition: rf_algorithm.hxx:1073
MultiArrayView< 2, T, C > rowVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition: matrix.hxx:697
Definition: rf_visitors.hxx:101
void rank_selection(FeatureT const &features, ResponseT const &response, VariableSelectionResult &result, ErrorRateCallBack errorcallback)
Definition: rf_algorithm.hxx:493
void visit_after_tree(RF &rf, PR &pr, SM &sm, ST &st, int index)
Definition: rf_algorithm.hxx:1224
MultiArrayIndex columnCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:684
Definition: random.hxx:336
Options object for the random forest. 
Definition: rf_common.hxx:170
void after_tree_ip_impl(RF &rf, PR &pr, SM &sm, ST &, int index)
Definition: rf_algorithm.hxx:1132
RandomForestOptions & use_stratification(RF_OptionTag in)
specify stratification strategy 
Definition: rf_common.hxx:374
RandomForestOptions & tree_count(unsigned int in)
Definition: rf_common.hxx:500
void learn(MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &response, Visitor_t visitor, Split_t split, Stop_t stop, Random_t const &random)
learn on data with custom config and random number generator 
Definition: random_forest.hxx:941
const_iterator end() const 
Definition: array_vector.hxx:237
const_pointer data() const 
Definition: array_vector.hxx:209
void iterate(Functor &tester)
Definition: rf_algorithm.hxx:656
FeatureList_t selected
Definition: rf_algorithm.hxx:133
size_type size() const 
Definition: array_vector.hxx:358
MultiArrayView< 2, int > variables
Definition: rf_algorithm.hxx:969
double operator()(Feature_t const &features, Response_t const &response)
Definition: rf_algorithm.hxx:101
RFErrorCallback(RandomForestOptions opt=RandomForestOptions())
Definition: rf_algorithm.hxx:93
Definition: rf_algorithm.hxx:116
detail::VisitorNode< A > create_visitor(A &a)
Definition: rf_visitors.hxx:344