37 #ifndef VIGRA_RANDOM_FOREST_HXX 
   38 #define VIGRA_RANDOM_FOREST_HXX 
   46 #include "mathutil.hxx" 
   47 #include "array_vector.hxx" 
   48 #include "sized_int.hxx" 
   50 #include "metaprogramming.hxx" 
   52 #include "functorexpression.hxx" 
   53 #include "random_forest/rf_common.hxx" 
   54 #include "random_forest/rf_nodeproxy.hxx" 
   55 #include "random_forest/rf_split.hxx" 
   56 #include "random_forest/rf_decisionTree.hxx" 
   57 #include "random_forest/rf_visitors.hxx" 
   58 #include "random_forest/rf_region.hxx" 
   59 #include "sampling.hxx" 
   60 #include "random_forest/rf_preprocessing.hxx" 
   61 #include "random_forest/rf_online_prediction_set.hxx" 
   62 #include "random_forest/rf_earlystopping.hxx" 
   63 #include "random_forest/rf_ridge_split.hxx" 
   83 inline SamplerOptions make_sampler_opt ( RandomForestOptions     & RF_opt)
 
   85     SamplerOptions return_opt;
 
   87     return_opt.
stratified(RF_opt.stratification_method_ == RF_EQUAL);
 
  146 template <
class LabelType = 
double , 
class PreprocessorTag = ClassificationTag >
 
  153     typedef detail::DecisionTree            DecisionTree_t;
 
  160     typedef LabelType                       LabelT;
 
  227     template<
class TopologyIterator, 
class ParameterIterator>
 
  229                   TopologyIterator         topology_begin,
 
  230                   ParameterIterator        parameter_begin,
 
  234         trees_(treeCount, DecisionTree_t(problem_spec)),
 
  235         ext_param_(problem_spec),
 
  241         for(
int k=0; k<treeCount; ++k, ++topology_begin, ++parameter_begin)
 
  243             trees_[k].topology_ = *topology_begin;
 
  244             trees_[k].parameters_ = *parameter_begin;
 
  262         vigra_precondition(ext_param_.used() == 
true,
 
  263            "RandomForest::ext_param(): " 
  264            "Random forest has not been trained yet.");
 
  281         vigra_precondition(ext_param_.used() == 
false,
 
  282             "RandomForest::set_ext_param():" 
  283             "Random forest has been trained! Call reset()" 
  284             "before specifying new extrinsic parameters.");
 
  308     DecisionTree_t 
const & 
tree(
int index)
 const 
  310         return trees_[index];
 
  315     DecisionTree_t & 
tree(
int index)
 
  317         return trees_[index];
 
  325       return ext_param_.column_count_;
 
  336       return ext_param_.column_count_;
 
  344       return ext_param_.class_count_;
 
  351       return options_.tree_count_;
 
  392     template <
class U, 
class C1,
 
  403                 Random_t                 
const  &   random);
 
  405     template <
class U, 
class C1,
 
  426     template <
class U, 
class C1, 
class U2,
class C2, 
class Visitor_t>
 
  427     void learn( MultiArrayView<2, U, C1> 
const  & features,
 
  428                 MultiArrayView<2, U2,C2> 
const  & labels,
 
  438     template <
class U, 
class C1, 
class U2,
class C2,
 
  439               class Visitor_t, 
class Split_t>
 
  440     void learn(   MultiArrayView<2, U, C1> 
const  & features,
 
  441                   MultiArrayView<2, U2,C2> 
const  & labels,
 
  470     template <
class U, 
class C1, 
class U2,
class C2>
 
  482     template<
class U,
class C1,
 
  495                         bool adjust_thresholds=
false);
 
  497     template <
class U, 
class C1, 
class U2,
class C2>
 
  502         onlineLearn(features,
 
  512     template<
class U,
class C1,
 
  518     void reLearnTree(MultiArrayView<2,U,C1> 
const & features,
 
  519                      MultiArrayView<2,U2,C2> 
const & response,
 
  526     template<
class U, 
class C1, 
class U2, 
class C2>
 
  527     void reLearnTree(MultiArrayView<2, U, C1> 
const & features,
 
  528                      MultiArrayView<2, U2, C2> 
const & labels,
 
  531         RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
 
  561     template <
class U, 
class C, 
class Stop>
 
  562     LabelType 
predictLabel(MultiArrayView<2, U, C>
const & features, Stop & stop) 
const;
 
  564     template <
class U, 
class C>
 
  565     LabelType 
predictLabel(MultiArrayView<2, U, C>
const & features)
 
  575     template <
class U, 
class C>
 
  576     LabelType 
predictLabel(MultiArrayView<2, U, C> 
const & features,
 
  577                                 ArrayVectorView<double> prior) 
const;
 
  589     template <
class U, 
class C1, 
class T, 
class C2>
 
  593         vigra_precondition(features.
shape(0) == labels.
shape(0),
 
  594             "RandomForest::predictLabels(): Label array has wrong size.");
 
  595         for(
int k=0; k<features.
shape(0); ++k)
 
  597             vigra_precondition(!detail::contains_nan(
rowVector(features, k)),
 
  598                 "RandomForest::predictLabels(): NaN in feature matrix.");
 
  613     template <
class U, 
class C1, 
class T, 
class C2>
 
  616                        LabelType nanLabel)
 const 
  618         vigra_precondition(features.
shape(0) == labels.
shape(0),
 
  619             "RandomForest::predictLabels(): Label array has wrong size.");
 
  620         for(
int k=0; k<features.
shape(0); ++k)
 
  622             if(detail::contains_nan(
rowVector(features, k)))
 
  623                 labels(k,0) = nanLabel;
 
  638     template <
class U, 
class C1, 
class T, 
class C2, 
class Stop>
 
  643         vigra_precondition(features.
shape(0) == labels.
shape(0),
 
  644             "RandomForest::predictLabels(): Label array has wrong size.");
 
  645         for(
int k=0; k<features.
shape(0); ++k)
 
  660     template <
class U, 
class C1, 
class T, 
class C2, 
class Stop>
 
  664     template <
class T1,
class T2, 
class C>
 
  674     template <
class U, 
class C1, 
class T, 
class C2>
 
  681     template <
class U, 
class C1, 
class T, 
class C2>
 
  691 template <
class LabelType, 
class PreprocessorTag>
 
  692 template<
class U,
class C1,
 
  698 void RandomForest<LabelType, PreprocessorTag>::onlineLearn(MultiArrayView<2,U,C1> 
const & features,
 
  699                                                              MultiArrayView<2,U2,C2> 
const & response,
 
  705                                                              bool adjust_thresholds)
 
  707     online_visitor_.activate();
 
  708     online_visitor_.adjust_thresholds=adjust_thresholds;
 
  712     typedef Processor<PreprocessorTag,LabelType,U,C1,U2,C2> Preprocessor_t;
 
  713     typedef          UniformIntRandomFunctor<Random_t>
 
  720     #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_> 
  721     Default_Stop_t default_stop(options_);
 
  722     typename RF_CHOOSER(Stop_t)::type stop
 
  723             = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
 
  724     Default_Split_t default_split;
 
  725     typename RF_CHOOSER(Split_t)::type split
 
  726             = RF_CHOOSER(Split_t)::choose(split_, default_split);
 
  727     rf::visitors::StopVisiting stopvisiting;
 
  728     typedef  rf::visitors::detail::VisitorNode
 
  729                 <rf::visitors::OnlineLearnVisitor,
 
  730                  typename RF_CHOOSER(Visitor_t)::type>
 
  733         visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
 
  735     vigra_precondition(options_.prepare_online_learning_,
"onlineLearn: online learning must be enabled on RandomForest construction");
 
  741     ext_param_.class_count_=0;
 
  742     Preprocessor_t preprocessor(    features, response,
 
  743                                     options_, ext_param_);
 
  746     RandFunctor_t           randint     ( random);
 
  749     split.set_external_parameters(ext_param_);
 
  750     stop.set_external_parameters(ext_param_);
 
  754     PoissonSampler<RandomTT800> poisson_sampler(1.0,
vigra::Int32(new_start_index),
vigra::Int32(ext_param().row_count_));
 
  760     for(
int ii = 0; ii < static_cast<int>(trees_.size()); ++ii)
 
  762         online_visitor_.tree_id=ii;
 
  763         poisson_sampler.sample();
 
  764         std::map<int,int> leaf_parents;
 
  765         leaf_parents.clear();
 
  767         for(
int s=0;s<poisson_sampler.numOfSamples();++s)
 
  769             int sample=poisson_sampler[s];
 
  770             online_visitor_.current_label=preprocessor.response()(sample,0);
 
  771             online_visitor_.last_node_id=StackEntry_t::DecisionTreeNoParent;
 
  772             int leaf=trees_[ii].getToLeaf(
rowVector(features,sample),online_visitor_);
 
  776             online_visitor_.add_to_index_list(ii,leaf,sample);
 
  779             if(Node<e_ConstProbNode>(trees_[ii].topology_,trees_[ii].parameters_,leaf).prob_begin()[preprocessor.response()(sample,0)]!=1.0)
 
  781                 leaf_parents[leaf]=online_visitor_.last_node_id;
 
  786         std::map<int,int>::iterator leaf_iterator;
 
  787         for(leaf_iterator=leaf_parents.begin();leaf_iterator!=leaf_parents.end();++leaf_iterator)
 
  789             int leaf=leaf_iterator->first;
 
  790             int parent=leaf_iterator->second;
 
  791             int lin_index=online_visitor_.trees_online_information[ii].exterior_to_index[leaf];
 
  792             ArrayVector<Int32> indeces;
 
  794             indeces.swap(online_visitor_.trees_online_information[ii].index_lists[lin_index]);
 
  795             StackEntry_t stack_entry(indeces.begin(),
 
  797                                      ext_param_.class_count_);
 
  802                 if(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(0)==leaf)
 
  804                     stack_entry.leftParent=parent;
 
  808                     vigra_assert(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(1)==leaf,
"last_node_id seems to be wrong");
 
  809                     stack_entry.rightParent=parent;
 
  813             trees_[ii].continueLearn(preprocessor.features(),preprocessor.response(),stack_entry,split,stop,visitor,randint,-1);
 
  815             online_visitor_.move_exterior_node(ii,trees_[ii].topology_.size(),ii,leaf);
 
  828     online_visitor_.deactivate();
 
  831 template<
class LabelType, 
class PreprocessorTag>
 
  832 template<
class U,
class C1,
 
  853     ext_param_.class_count_=0;
 
  861     #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_> 
  863     typename RF_CHOOSER(Stop_t)::type stop
 
  864             = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
 
  866     typename RF_CHOOSER(Split_t)::type split
 
  867             = RF_CHOOSER(Split_t)::choose(split_, default_split);
 
  871                 typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
 
  873         visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
 
  875     vigra_precondition(options_.prepare_online_learning_,
"reLearnTree: Re learning trees only makes sense, if online learning is enabled");
 
  876     online_visitor_.activate();
 
  879     RandFunctor_t           randint     ( random);
 
  885     Preprocessor_t preprocessor(    features, response,
 
  886                                     options_, ext_param_);
 
  889     split.set_external_parameters(ext_param_);
 
  890     stop.set_external_parameters(ext_param_);
 
  897                                preprocessor.strata().end(),
 
  898                                detail::make_sampler_opt(options_)
 
  899                                         .sampleSize(ext_param().actual_msample_),
 
  906         first_stack_entry(  sampler.sampledIndices().begin(),
 
  907                             sampler.sampledIndices().end(),
 
  908                             ext_param_.class_count_);
 
  910         .set_oob_range(     sampler.oobIndices().begin(),
 
  911                             sampler.oobIndices().end());
 
  912     online_visitor_.reset_tree(treeId);
 
  913     online_visitor_.tree_id=treeId;
 
  914     trees_[treeId].reset();
 
  916         .learn( preprocessor.features(),
 
  917                 preprocessor.response(),
 
  924         .visit_after_tree(  *
this,
 
  930     online_visitor_.deactivate();
 
  933 template <
class LabelType, 
class PreprocessorTag>
 
  934 template <
class U, 
class C1,
 
  946                             Random_t                 
const  &   random)
 
  957     vigra_precondition(features.
shape(0) == response.
shape(0),
 
  958         "RandomForest::learn(): shape mismatch between features and response.");
 
  965     #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_> 
  967     typename RF_CHOOSER(Stop_t)::type stop
 
  968             = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
 
  970     typename RF_CHOOSER(Split_t)::type split
 
  971             = RF_CHOOSER(Split_t)::choose(split_, default_split);
 
  975                 typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
 
  977         visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
 
  979     if(options_.prepare_online_learning_)
 
  980         online_visitor_.activate();
 
  982         online_visitor_.deactivate();
 
  986     RandFunctor_t           randint     ( random);
 
  993     Preprocessor_t preprocessor(    features, response,
 
  994                                     options_, ext_param_);
 
  997     split.set_external_parameters(ext_param_);
 
  998     stop.set_external_parameters(ext_param_);
 
 1002     trees_.resize(options_.tree_count_  , DecisionTree_t(ext_param_));
 
 1005                                preprocessor.strata().end(),
 
 1006                                detail::make_sampler_opt(options_)
 
 1007                                         .sampleSize(ext_param().actual_msample_),
 
 1010     visitor.visit_at_beginning(*
this, preprocessor);
 
 1013     for(
int ii = 0; ii < static_cast<int>(trees_.size()); ++ii)
 
 1019             first_stack_entry(  sampler.sampledIndices().begin(),
 
 1020                                 sampler.sampledIndices().end(),
 
 1021                                 ext_param_.class_count_);
 
 1023             .set_oob_range(     sampler.oobIndices().begin(),
 
 1024                                 sampler.oobIndices().end());
 
 1026             .learn(             preprocessor.features(),
 
 1027                                 preprocessor.response(),
 
 1034             .visit_after_tree(  *
this,
 
 1041     visitor.visit_at_end(*
this, preprocessor);
 
 1043     online_visitor_.deactivate();
 
 1049 template <
class LabelType, 
class Tag>
 
 1050 template <
class U, 
class C, 
class Stop>
 
 1054     vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
 
 1055         "RandomForestn::predictLabel():" 
 1056             " Too few columns in feature matrix.");
 
 1057     vigra_precondition(
rowCount(features) == 1,
 
 1058         "RandomForestn::predictLabel():" 
 1059             " Feature matrix must have a singlerow.");
 
 1062     predictProbabilities(features, probabilities, stop);
 
 1063     ext_param_.to_classlabel(
argMax(probabilities), d);
 
 1069 template <
class LabelType, 
class PreprocessorTag>
 
 1070 template <
class U, 
class C>
 
 1075     using namespace functor;
 
 1076     vigra_precondition(
columnCount(features) >= ext_param_.column_count_,
 
 1077         "RandomForestn::predictLabel(): Too few columns in feature matrix.");
 
 1078     vigra_precondition(
rowCount(features) == 1,
 
 1079         "RandomForestn::predictLabel():" 
 1080         " Feature matrix must have a single row.");
 
 1081     Matrix<double>  prob(1,ext_param_.class_count_);
 
 1082     predictProbabilities(features, prob);
 
 1083     std::transform( prob.begin(), prob.end(),
 
 1084                     priors.
begin(), prob.begin(),
 
 1087     ext_param_.to_classlabel(
argMax(prob), d);
 
 1091 template<
class LabelType,
class PreprocessorTag>
 
 1092 template <
class T1,
class T2, 
class C>
 
 1101                        "RandomFroest::predictProbabilities():" 
 1102                        " Feature matrix and probability matrix size mismatch.");
 
 1105     vigra_precondition( 
columnCount(predictionSet.features) >= ext_param_.column_count_,
 
 1106       "RandomForestn::predictProbabilities():" 
 1107         " Too few columns in feature matrix.");
 
 1109                         == static_cast<MultiArrayIndex>(ext_param_.class_count_),
 
 1110       "RandomForestn::predictProbabilities():" 
 1111       " Probability matrix must have as many columns as there are classes.");
 
 1114     std::vector<T1> totalWeights(predictionSet.indices[0].size(),0.0);
 
 1117     for(
int k=0; k<options_.tree_count_; ++k)
 
 1119         set_id=(set_id+1) % predictionSet.indices[0].size();
 
 1120         typedef std::set<SampleRange<T1> > my_set;
 
 1121         typedef typename my_set::iterator set_it;
 
 1124         std::vector<std::pair<int,set_it> > stack;
 
 1126         for(set_it i=predictionSet.ranges[set_id].begin();
 
 1127              i!=predictionSet.ranges[set_id].end();++i)
 
 1128             stack.push_back(std::pair<int,set_it>(2,i));
 
 1130         int num_decisions=0;
 
 1131         while(!stack.empty())
 
 1133             set_it range=stack.back().second;
 
 1134             int index=stack.back().first;
 
 1138             if(trees_[k].isLeafNode(trees_[k].topology_[index]))
 
 1141                                                                             trees_[k].parameters_,
 
 1142                                                                             index).prob_begin();
 
 1143                 for(
int i=range->start;i!=range->end;++i)
 
 1146                     for(
int l=0; l<ext_param_.class_count_; ++l)
 
 1148                         prob(predictionSet.indices[set_id][i], l) += 
static_cast<T2
>(weights[l]);
 
 1150                         totalWeights[predictionSet.indices[set_id][i]] += 
static_cast<T1
>(weights[l]);
 
 1157                 if(trees_[k].topology_[index]!=i_ThresholdNode)
 
 1159                     throw std::runtime_error(
"predicting with online prediction sets is only supported for RFs with threshold nodes");
 
 1161                 Node<i_ThresholdNode> node(trees_[k].topology_,trees_[k].parameters_,index);
 
 1162                 if(range->min_boundaries[node.column()]>=node.threshold())
 
 1165                     stack.push_back(std::pair<int,set_it>(node.child(1),range));
 
 1168                 if(range->max_boundaries[node.column()]<node.threshold())
 
 1171                     stack.push_back(std::pair<int,set_it>(node.child(0),range));
 
 1175                 SampleRange<T1> new_range=*range;
 
 1176                 new_range.min_boundaries[node.column()]=FLT_MAX;
 
 1177                 range->max_boundaries[node.column()]=-FLT_MAX;
 
 1178                 new_range.start=new_range.end=range->end;
 
 1180                 while(i!=range->end)
 
 1183                     if(predictionSet.features(predictionSet.indices[set_id][i],node.column())>=node.threshold())
 
 1185                         new_range.min_boundaries[node.column()]=std::min(new_range.min_boundaries[node.column()],
 
 1186                                                                     predictionSet.features(predictionSet.indices[set_id][i],node.column()));
 
 1189                         std::swap(predictionSet.indices[set_id][i],predictionSet.indices[set_id][range->end]);
 
 1194                         range->max_boundaries[node.column()]=std::max(range->max_boundaries[node.column()],
 
 1195                                                                  predictionSet.features(predictionSet.indices[set_id][i],node.column()));
 
 1200                 if(range->start==range->end)
 
 1202                     predictionSet.ranges[set_id].erase(range);
 
 1206                     stack.push_back(std::pair<int,set_it>(node.child(0),range));
 
 1209                 if(new_range.start!=new_range.end)
 
 1211                     std::pair<set_it,bool> new_it=predictionSet.ranges[set_id].insert(new_range);
 
 1212                     stack.push_back(std::pair<int,set_it>(node.child(1),new_it.first));
 
 1216         predictionSet.cumulativePredTime[k]=num_decisions;
 
 1218     for(
unsigned int i=0;i<totalWeights.size();++i)
 
 1222         for(
int l=0; l<ext_param_.class_count_; ++l)
 
 1225             prob(i, l) /= totalWeights[i];
 
 1227         assert(test==totalWeights[i]);
 
 1228         assert(totalWeights[i]>0.0);
 
 1232 template <
class LabelType, 
class PreprocessorTag>
 
 1233 template <
class U, 
class C1, 
class T, 
class C2, 
class Stop_t>
 
 1236                            MultiArrayView<2, T, C2> &       prob,
 
 1237                            Stop_t                   &       stop_)
 const 
 1243       "RandomForestn::predictProbabilities():" 
 1244         " Feature matrix and probability matrix size mismatch.");
 
 1248     vigra_precondition( 
columnCount(features) >= ext_param_.column_count_,
 
 1249       "RandomForestn::predictProbabilities():" 
 1250         " Too few columns in feature matrix.");
 
 1252                         == static_cast<MultiArrayIndex>(ext_param_.class_count_),
 
 1253       "RandomForestn::predictProbabilities():" 
 1254       " Probability matrix must have as many columns as there are classes.");
 
 1256     #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_> 
 1257     Default_Stop_t default_stop(options_);
 
 1258     typename RF_CHOOSER(Stop_t)::type & stop
 
 1259             = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
 
 1261     stop.set_external_parameters(ext_param_, tree_count());
 
 1262     prob.init(NumericTraits<T>::zero());
 
 1272     for(
int row=0; row < 
rowCount(features); ++row)
 
 1274         MultiArrayView<2, U, StridedArrayTag> currentRow(
rowVector(features, row));
 
 1278         if(detail::contains_nan(currentRow))
 
 1284         ArrayVector<double>::const_iterator weights;
 
 1287         double totalWeight = 0.0;
 
 1290         for(
int k=0; k<options_.tree_count_; ++k)
 
 1293             weights = trees_[k ].predict(currentRow);
 
 1296             int weighted = options_.predict_weighted_;
 
 1297             for(
int l=0; l<ext_param_.class_count_; ++l)
 
 1299                 double cur_w = weights[l] * (weighted * (*(weights-1))
 
 1301                 prob(row, l) += 
static_cast<T
>(cur_w);
 
 1303                 totalWeight += cur_w;
 
 1305             if(stop.after_prediction(weights,
 
 1315         for(
int l=0; l< ext_param_.class_count_; ++l)
 
 1317             prob(row, l) /= detail::RequiresExplicitCast<T>::cast(totalWeight);
 
 1323 template <
class LabelType, 
class PreprocessorTag>
 
 1324 template <
class U, 
class C1, 
class T, 
class C2>
 
 1325 void RandomForest<LabelType, PreprocessorTag>
 
 1326     ::predictRaw(MultiArrayView<2, U, C1>
const &  features,
 
 1327                            MultiArrayView<2, T, C2> &       prob)
 const 
 1333       "RandomForestn::predictProbabilities():" 
 1334         " Feature matrix and probability matrix size mismatch.");
 
 1338     vigra_precondition( 
columnCount(features) >= ext_param_.column_count_,
 
 1339       "RandomForestn::predictProbabilities():" 
 1340         " Too few columns in feature matrix.");
 
 1342                         == static_cast<MultiArrayIndex>(ext_param_.class_count_),
 
 1343       "RandomForestn::predictProbabilities():" 
 1344       " Probability matrix must have as many columns as there are classes.");
 
 1346     #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_> 
 1347     prob.init(NumericTraits<T>::zero());
 
 1357     for(
int row=0; row < 
rowCount(features); ++row)
 
 1359         ArrayVector<double>::const_iterator weights;
 
 1362         double totalWeight = 0.0;
 
 1365         for(
int k=0; k<options_.tree_count_; ++k)
 
 1368             weights = trees_[k ].predict(
rowVector(features, row));
 
 1371             int weighted = options_.predict_weighted_;
 
 1372             for(
int l=0; l<ext_param_.class_count_; ++l)
 
 1374                 double cur_w = weights[l] * (weighted * (*(weights-1))
 
 1376                 prob(row, l) += 
static_cast<T
>(cur_w);
 
 1378                 totalWeight += cur_w;
 
 1382     prob/= options_.tree_count_;
 
 1388 #include "random_forest/rf_algorithm.hxx" 
 1389 #endif // VIGRA_RANDOM_FOREST_HXX 
Definition: rf_region.hxx:57
void set_ext_param(ProblemSpec_t const &in)
set external parameters 
Definition: random_forest.hxx:278
int class_count() const 
return number of classes used while training. 
Definition: random_forest.hxx:342
Definition: rf_nodeproxy.hxx:626
detail::RF_DEFAULT & rf_default()
factory function to return a RF_DEFAULT tag 
Definition: rf_common.hxx:131
MultiArrayIndex rowCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:671
Definition: rf_preprocessing.hxx:63
int feature_count() const 
return number of features used while training. 
Definition: random_forest.hxx:323
int column_count() const 
return number of features used while training. 
Definition: random_forest.hxx:334
Create random samples from a sequence of indices. 
Definition: sampling.hxx:232
const difference_type & shape() const 
Definition: multi_array.hxx:1648
void sample()
Definition: sampling.hxx:467
Definition: rf_split.hxx:993
void predictLabels(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &labels, LabelType nanLabel) const 
predict multiple labels with given features 
Definition: random_forest.hxx:614
const_iterator begin() const 
Definition: array_vector.hxx:223
problem specification class for the random forest. 
Definition: rf_common.hxx:538
RandomForest(Options_t const &options=Options_t(), ProblemSpec_t const &ext_param=ProblemSpec_t())
default constructor 
Definition: random_forest.hxx:197
Standard early stopping criterion. 
Definition: rf_common.hxx:885
ProblemSpec_t const & ext_param() const 
return external parameters for viewing 
Definition: random_forest.hxx:260
DecisionTree_t & tree(int index)
access trees 
Definition: random_forest.hxx:315
DecisionTree_t const & tree(int index) const 
access const trees 
Definition: random_forest.hxx:308
Options_t & set_options()
access random forest options 
Definition: random_forest.hxx:291
void learn(MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &labels)
learn on data with default configuration 
Definition: random_forest.hxx:471
void predictProbabilities(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &prob) const 
predict the class probabilities for multiple labels 
Definition: random_forest.hxx:675
Random forest version 2 (see also vigra::rf3::RandomForest for version 3) 
Definition: random_forest.hxx:147
Options_t const & options() const 
access const random forest options 
Definition: random_forest.hxx:301
Definition: rf_visitors.hxx:254
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int 
Definition: sized_int.hxx:175
Iterator argMax(Iterator first, Iterator last)
Find the maximum element in a sequence. 
Definition: algorithm.hxx:96
LabelType predictLabel(MultiArrayView< 2, U, C >const &features, Stop &stop) const 
predict a label given a feature. 
Definition: random_forest.hxx:1052
Definition: rf_visitors.hxx:583
void predictLabels(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &labels) const 
predict multiple labels with given features 
Definition: random_forest.hxx:590
SamplerOptions & stratified(bool in=true)
Draw equally many samples from each "stratum". A stratum is a group of like entities, e.g. pixels belonging to the same object class. This is useful to create balanced samples when the class probabilities are very unbalanced (e.g. when there are many background and few foreground pixels). Stratified sampling thus prevents a trained classifier from being biased towards the majority class. 
Definition: sampling.hxx:141
SamplerOptions & withReplacement(bool in=true)
Sample from training population with replacement. 
Definition: sampling.hxx:83
MultiArrayView< 2, T, C > rowVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition: matrix.hxx:697
int tree_count() const 
return number of trees 
Definition: random_forest.hxx:349
void reLearnTree(MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &response, int treeId, Visitor_t visitor_, Split_t split_, Stop_t stop_, Random_t &random)
Definition: random_forest.hxx:838
MultiArrayIndex columnCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:684
RandomForest(int treeCount, TopologyIterator topology_begin, ParameterIterator parameter_begin, ProblemSpec_t const &problem_spec, Options_t const &options=Options_t())
Create RF from external source. 
Definition: random_forest.hxx:228
Definition: random.hxx:336
Base class for, and view to, vigra::MultiArray. 
Definition: multi_array.hxx:704
Options object for the random forest. 
Definition: rf_common.hxx:170
void predictLabels(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &labels, Stop &stop) const 
predict multiple labels with given features 
Definition: random_forest.hxx:639
MultiArrayView & init(const U &init)
Definition: multi_array.hxx:1206
void learn(MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &response, Visitor_t visitor, Split_t split, Stop_t stop, Random_t const &random)
learn on data with custom config and random number generator 
Definition: random_forest.hxx:941
void predictProbabilities(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &prob, Stop &stop) const 
predict the class probabilities for multiple labels 
Definition: rf_visitors.hxx:234