35 #ifndef VIGRA_RANDOM_FOREST_SPLIT_HXX 
   36 #define VIGRA_RANDOM_FOREST_SPLIT_HXX 
   42 #include "../mathutil.hxx" 
   43 #include "../array_vector.hxx" 
   44 #include "../sized_int.hxx" 
   45 #include "../matrix.hxx" 
   46 #include "../random.hxx" 
   47 #include "../functorexpression.hxx" 
   48 #include "rf_nodeproxy.hxx" 
   50 #include "rf_region.hxx" 
   59 class CompileTimeError;
 
   69         static void exec(Iter , Iter )
 
   74     class Normalise<ClassificationTag>
 
   78         static void exec (Iter begin, Iter end)
 
   80             double bla = std::accumulate(begin, end, 0.0);
 
   81             for(
int ii = 0; ii < end - begin; ++ii)
 
   82                 begin[ii] = begin[ii]/bla ;
 
  115         t_data.push_back(in.column_count_);
 
  116         t_data.push_back(in.class_count_);
 
  124     int classCount()
 const 
  126         return int(t_data[1]);
 
  129     int featureCount()
 const 
  131         return int(t_data[0]);
 
  149     template<
class T, 
class C, 
class T2, 
class C2, 
class Region, 
class Random>
 
  158         CompileTimeError SplitFunctor__findBestSplit_member_was_not_defined;
 
  167     template<
class T, 
class C, 
class T2,
class C2, 
class Region, 
class Random>
 
  175         if(ext_param_.class_weights_.
size() != region.classCounts().size())
 
  177             std::copy(region.classCounts().begin(),
 
  178                       region.classCounts().end(),
 
  183             std::transform(region.classCounts().begin(),
 
  184                            region.classCounts().end(),
 
  185                            ext_param_.class_weights_.
begin(),
 
  186                            ret.prob_begin(), std::multiplies<double>());
 
  188         detail::Normalise<RF_Tag>::exec(ret.prob_begin(), ret.prob_end());
 
  192         return e_ConstProbNode;
 
  200 template<
class DataMatrix>
 
  203     DataMatrix 
const & data_;
 
  210                             double thresVal = 0.0)
 
  212       sortColumn_(sortColumn),
 
  218         sortColumn_ = sortColumn;
 
  220     void setThreshold(
double value)
 
  227         return data_(l, sortColumn_) < data_(r, sortColumn_);
 
  231         return data_(l, sortColumn_) < thresVal_;
 
  235 template<
class DataMatrix>
 
  236 class DimensionNotEqual
 
  238     DataMatrix 
const & data_;
 
  243     DimensionNotEqual(DataMatrix 
const & data, 
 
  246       sortColumn_(sortColumn)
 
  251         sortColumn_ = sortColumn;
 
  256         return data_(l, sortColumn_) != data_(r, sortColumn_);
 
  260 template<
class DataMatrix>
 
  261 class SortSamplesByHyperplane
 
  263     DataMatrix 
const & data_;
 
  264     Node<i_HyperplaneNode> 
const & node_;
 
  268     SortSamplesByHyperplane(DataMatrix              
const & data, 
 
  269                             Node<i_HyperplaneNode>  
const & node)
 
  279         double result_l = -1 * node_.intercept();
 
  280         for(
int ii = 0; ii < node_.columns_size(); ++ii)
 
  282             result_l +=     
rowVector(data_, l)[node_.columns_begin()[ii]] 
 
  283                         *   node_.weights()[ii];
 
  290         return (*
this)[l]  < (*this)[r];
 
  304 template <
class DataSource, 
class CountArray>
 
  307     DataSource  
const &     labels_;
 
  308     CountArray        &     counts_;
 
  327         counts_[labels_[l]] +=1;
 
  342         double operator[](
size_t)
 const 
  362     template<
class Array, 
class Array2>
 
  364                               Array2    
const & weights, 
 
  365                               double            total = 1.0)
 const 
  367         return impurity(hist, weights, total);
 
  372     template<
class Array>
 
  373     double operator()(Array 
const & hist, 
double total = 1.0)
 const 
  380     template<
class Array>
 
  381     static double impurity(Array 
const & hist, 
double total)
 
  383         return impurity(hist, detail::ConstArr<1>(), total);
 
  388     template<
class Array, 
class Array2>
 
  390                               Array2    
const & weights, 
 
  394         int     class_count     = hist.size();
 
  395         double  entropy            = 0.0;
 
  398             double p0           = (hist[0]/total);
 
  399             double p1           = (hist[1]/total);
 
  404             for(
int ii = 0; ii < class_count; ++ii)
 
  406                 double w        = weights[ii];
 
  407                 double pii      = hist[ii]/total;
 
  411         entropy             = total * entropy;
 
  424     template<
class Array, 
class Array2>
 
  426                               Array2    
const & weights, 
 
  427                               double            total = 1.0)
 const 
  429         return impurity(hist, weights, total);
 
  434     template<
class Array>
 
  435     double operator()(Array 
const & hist, 
double total = 1.0)
 const 
  442     template<
class Array>
 
  443     static double impurity(Array 
const & hist, 
double total)
 
  445         return impurity(hist, detail::ConstArr<1>(), total);
 
  450     template<
class Array, 
class Array2>
 
  452                               Array2    
const & weights, 
 
  456         int     class_count     = hist.size();
 
  460             double w            = weights[0] * weights[1];
 
  461             gini                = w * (hist[0] * hist[1] / total);
 
  465             for(
int ii = 0; ii < class_count; ++ii)
 
  467                 double w        = weights[ii];
 
  468                 gini           += w*( hist[ii]*( 1.0 - w * hist[ii]/total ) );
 
  476 template <
class DataSource, 
class Impurity= GiniCriterion>
 
  480     DataSource  
const &         labels_;
 
  481     ArrayVector<double>        counts_;
 
  482     ArrayVector<double> 
const  class_weights_;
 
  483     double                      total_counts_;
 
  489     ImpurityLoss(DataSource  
const & labels, 
 
  490                                 ProblemSpec<T> 
const & ext_)
 
  492       counts_(ext_.class_count_, 0.0),
 
  493       class_weights_(ext_.class_weights_),
 
  503     template<
class Counts>
 
  504     double increment_histogram(Counts 
const & counts)
 
  506         std::transform(counts.begin(), counts.end(),
 
  507                        counts_.begin(), counts_.begin(),
 
  508                        std::plus<double>());
 
  509         total_counts_ = std::accumulate( counts_.begin(), 
 
  512         return impurity_(counts_, class_weights_, total_counts_);
 
  515     template<
class Counts>
 
  516     double decrement_histogram(Counts 
const & counts)
 
  518         std::transform(counts.begin(), counts.end(),
 
  519                        counts_.begin(), counts_.begin(),
 
  520                        std::minus<double>());
 
  521         total_counts_ = std::accumulate( counts_.begin(), 
 
  524         return impurity_(counts_, class_weights_, total_counts_);
 
  528     double increment(Iter begin, Iter end)
 
  530         for(Iter iter = begin; iter != end; ++iter)
 
  532             counts_[labels_(*iter, 0)] +=1.0;
 
  535         return impurity_(counts_, class_weights_, total_counts_);
 
  539     double decrement(Iter 
const &  begin, Iter 
const & end)
 
  541         for(Iter iter = begin; iter != end; ++iter)
 
  543             counts_[labels_(*iter,0)] -=1.0;
 
  546         return impurity_(counts_, class_weights_, total_counts_);
 
  549     template<
class Iter, 
class Resp_t>
 
  550     double init (Iter , Iter , Resp_t resp)
 
  553         std::copy(resp.begin(), resp.end(), counts_.begin());
 
  554         total_counts_ = std::accumulate(counts_.begin(), counts_.end(), 0.0); 
 
  555         return impurity_(counts_,class_weights_, total_counts_);
 
  558     ArrayVector<double> 
const & response()
 
  566     template <
class DataSource>
 
  567     class RegressionForestCounter
 
  571         DataSource 
const &      labels_;
 
  572         ArrayVector <double>    mean_;
 
  573         ArrayVector <double>    variance_;
 
  574         ArrayVector <double>    tmp_;
 
  579         RegressionForestCounter(DataSource 
const & labels, 
 
  580                                 ProblemSpec<T> 
const & ext_)
 
  583         mean_(ext_.response_size_, 0.0),
 
  584         variance_(ext_.response_size_, 0.0),
 
  585         tmp_(ext_.response_size_),
 
  590         double increment (Iter begin, Iter end)
 
  592             for(Iter iter = begin; iter != end; ++iter)
 
  595                 for(
unsigned int ii = 0; ii < mean_.size(); ++ii)
 
  596                     tmp_[ii] = labels_(*iter, ii) - mean_[ii]; 
 
  597                 double f  = 1.0 / count_,
 
  599                 for(
unsigned int ii = 0; ii < mean_.size(); ++ii)
 
  600                     mean_[ii] += f*tmp_[ii]; 
 
  601                 for(
unsigned int ii = 0; ii < mean_.size(); ++ii)
 
  602                     variance_[ii] += f1*
sq(tmp_[ii]);
 
  604             double res = std::accumulate(variance_.begin(), 
 
  607                                          std::plus<double>());
 
  613         double decrement (Iter begin, Iter end)
 
  615             for(Iter iter = begin; iter != end; ++iter)
 
  624             for(
unsigned int ii = 0; ii < mean_.size(); ++ii)
 
  627                 for(Iter iter = begin; iter != end; ++iter)
 
  629                     mean_[ii] += labels_(*iter, ii);
 
  633                 for(Iter iter = begin; iter != end; ++iter)
 
  635                     variance_[ii] += (labels_(*iter, ii) - mean_[ii])*(labels_(*iter, ii) - mean_[ii]);
 
  638             double res = std::accumulate(variance_.begin(), 
 
  641                                          std::plus<double>());
 
  647         template<
class Iter, 
class Resp_t>
 
  648         double init (Iter begin, Iter end, Resp_t )
 
  651             return this->increment(begin, end);
 
  656         ArrayVector<double> 
const & response()
 
  670 template <
class DataSource>
 
  671 class RegressionForestCounter2
 
  675     DataSource 
const &      labels_;
 
  676     ArrayVector <double>    mean_;
 
  677     ArrayVector <double>    variance_;
 
  678     ArrayVector <double>    tmp_;
 
  682     RegressionForestCounter2(DataSource 
const & labels, 
 
  683                             ProblemSpec<T> 
const & ext_)
 
  686         mean_(ext_.response_size_, 0.0),
 
  687         variance_(ext_.response_size_, 0.0),
 
  688         tmp_(ext_.response_size_),
 
  693     double increment (Iter begin, Iter end)
 
  695         for(Iter iter = begin; iter != end; ++iter)
 
  698             for(
int ii = 0; ii < mean_.size(); ++ii)
 
  699                 tmp_[ii] = labels_(*iter, ii) - mean_[ii]; 
 
  700             double f  = 1.0 / count_,
 
  702             for(
int ii = 0; ii < mean_.size(); ++ii)
 
  703                 mean_[ii] += f*tmp_[ii]; 
 
  704             for(
int ii = 0; ii < mean_.size(); ++ii)
 
  705                 variance_[ii] += f1*
sq(tmp_[ii]);
 
  707         double res = std::accumulate(variance_.begin(), 
 
  711                 /((count_ == 1)? 1:(count_ -1));
 
  717     double decrement (Iter begin, Iter end)
 
  719         for(Iter iter = begin; iter != end; ++iter)
 
  721             double f  = 1.0 / count_,
 
  723             for(
int ii = 0; ii < mean_.size(); ++ii)
 
  724                 mean_[ii] = (mean_[ii] - f*labels_(*iter,ii))/(1-f); 
 
  725             for(
int ii = 0; ii < mean_.size(); ++ii)
 
  726                 variance_[ii] -= f1*
sq(labels_(*iter,ii) - mean_[ii]);
 
  729         double res =  std::accumulate(variance_.begin(), 
 
  733                 /((count_ == 1)? 1:(count_ -1));
 
  783     template<
class Iter, 
class Resp_t>
 
  784     double init (Iter begin, Iter end, Resp_t resp)
 
  787         return this->increment(begin, end, resp);
 
  791     ArrayVector<double> 
const & response()
 
  804 template<
class Tag, 
class Datatyp>
 
  810 template<
class Datatype>
 
  811 struct LossTraits<GiniCriterion, Datatype>
 
  813     typedef ImpurityLoss<Datatype, GiniCriterion> type;
 
  816 template<
class Datatype>
 
  817 struct LossTraits<EntropyCriterion, Datatype>
 
  819     typedef ImpurityLoss<Datatype, EntropyCriterion> type;
 
  822 template<
class Datatype>
 
  823 struct LossTraits<LSQLoss, Datatype>
 
  825     typedef RegressionForestCounter<Datatype> type;
 
  830 template<
class LineSearchLossTag>
 
  837     std::ptrdiff_t               min_index_;
 
  838     double                  min_threshold_;
 
  847         class_weights_(ext.class_weights_),
 
  850         bestCurrentCounts[0].resize(ext.class_count_);
 
  851         bestCurrentCounts[1].resize(ext.class_count_);
 
  856         class_weights_ = ext.class_weights_; 
 
  858         bestCurrentCounts[0].resize(ext.class_count_);
 
  859         bestCurrentCounts[1].resize(ext.class_count_);
 
  888     template<   
class DataSourceF_t,
 
  893                     DataSource_t    
const & labels,
 
  896                     Array           
const & region_response)
 
  898         std::sort(begin, end, 
 
  901             LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
 
  902         LineSearchLoss left(labels, ext_param_); 
 
  903         LineSearchLoss right(labels, ext_param_);
 
  907         min_gini_ = right.init(begin, end, region_response);  
 
  908         min_threshold_ = *begin;
 
  910         DimensionNotEqual<DataSourceF_t> comp(column, 0); 
 
  913         I_Iter next = std::adjacent_find(iter, end, comp);
 
  917             double lr  =  right.decrement(iter, next + 1);
 
  918             double ll  =  left.increment(iter , next + 1);
 
  919             double loss = lr +ll;
 
  921 #ifdef CLASSIFIER_TEST 
  924             if(loss < min_gini_ )
 
  927                 bestCurrentCounts[0] = left.response();
 
  928                 bestCurrentCounts[1] = right.response();
 
  929 #ifdef CLASSIFIER_TEST 
  930                 min_gini_       = loss < min_gini_? loss : min_gini_;
 
  934                 min_index_      = next - begin +1 ;
 
  935                 min_threshold_  = (double(column(*next,0)) + double(column(*(next +1), 0)))/2.0;
 
  938             next = std::adjacent_find(iter, end, comp);
 
  945     template<
class DataSource_t, 
class Iter, 
class Array>
 
  946     double loss_of_region(DataSource_t 
const & labels,
 
  949                           Array 
const & region_response)
 const 
  952             LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
 
  953         LineSearchLoss region_loss(labels, ext_param_);
 
  955             region_loss.init(begin, end, region_response);
 
  965         template<
class Region, 
class LabelT>
 
  966         static void exec(Region & , LabelT & )
 
  971     struct Correction<ClassificationTag>
 
  973         template<
class Region, 
class LabelT>
 
  974         static void exec(Region & region, LabelT & labels) 
 
  976             if(std::accumulate(region.classCounts().begin(),
 
  977                                region.classCounts().end(), 0.0) != region.size())
 
  979                 RandomForestClassCounter<   LabelT, 
 
  980                                             ArrayVector<double> >
 
  981                     counter(labels, region.classCounts());
 
  982                 std::for_each(  region.begin(), region.end(), counter);
 
  983                 region.classCountsIsValid = 
true;
 
  992 template<
class ColumnDecisionFunctor, 
class Tag = ClassificationTag>
 
 1001     ColumnDecisionFunctor       bgfunc;
 
 1003     double                      region_gini_;
 
 1010     double minGini()
 const 
 1012         return min_gini_[bestSplitIndex];
 
 1014     int bestSplitColumn()
 const 
 1016         return splitColumns[bestSplitIndex];
 
 1018     double bestSplitThreshold()
 const 
 1020         return min_thresholds_[bestSplitIndex];
 
 1027         bgfunc.set_external_parameters( SB::ext_param_);
 
 1028         int featureCount_ = SB::ext_param_.column_count_;
 
 1029         splitColumns.resize(featureCount_);
 
 1030         for(
int k=0; k<featureCount_; ++k)
 
 1031             splitColumns[k] = k;
 
 1032         min_gini_.resize(featureCount_);
 
 1033         min_indices_.resize(featureCount_);
 
 1034         min_thresholds_.resize(featureCount_);
 
 1038     template<
class T, 
class C, 
class T2, 
class C2, 
class Region, 
class Random>
 
 1046         typedef typename Region::IndexIterator IndexIterator;
 
 1047         if(region.size() == 0)
 
 1049            std::cerr << 
"SplitFunctor::findBestSplit(): stackentry with 0 examples encountered\n" 
 1050                         "continuing learning process...."; 
 
 1053         detail::Correction<Tag>::exec(region, labels);
 
 1057         region_gini_ = bgfunc.loss_of_region(labels,
 
 1060                                              region.classCounts());
 
 1061         if(region_gini_ <= SB::ext_param_.precision_)
 
 1065         for(
int ii = 0; ii < SB::ext_param_.actual_mtry_; ++ii)
 
 1066             std::swap(splitColumns[ii], 
 
 1067                       splitColumns[ii+ randint(features.
shape(1) - ii)]);
 
 1071         double  current_min_gini    = region_gini_;
 
 1072         int     num2try             = features.
shape(1);
 
 1073         for(
int k=0; k<num2try; ++k)
 
 1078                    region.
begin(), region.end(), 
 
 1079                    region.classCounts());
 
 1080             min_gini_[k]            = bgfunc.min_gini_; 
 
 1081             min_indices_[k]         = bgfunc.min_index_;
 
 1082             min_thresholds_[k]      = bgfunc.min_threshold_;
 
 1083 #ifdef CLASSIFIER_TEST 
 1084             if(     bgfunc.min_gini_ < current_min_gini
 
 1087             if(bgfunc.min_gini_ < current_min_gini)
 
 1090                 current_min_gini = bgfunc.min_gini_;
 
 1091                 childRegions[0].classCounts() = bgfunc.bestCurrentCounts[0];
 
 1092                 childRegions[1].classCounts() = bgfunc.bestCurrentCounts[1];
 
 1093                 childRegions[0].classCountsIsValid = 
true;
 
 1094                 childRegions[1].classCountsIsValid = 
true;
 
 1097                 num2try = SB::ext_param_.actual_mtry_;
 
 1108         Node<i_ThresholdNode>   node(SB::t_data, SB::p_data);
 
 1110         node.threshold()    = min_thresholds_[bestSplitIndex];
 
 1111         node.column()       = splitColumns[bestSplitIndex];
 
 1115             sorter(features, node.column(), node.threshold());
 
 1116         IndexIterator bestSplit =
 
 1117             std::partition(region.begin(), region.end(), sorter);
 
 1119         childRegions[0].setRange(   region.begin()  , bestSplit       );
 
 1120         childRegions[0].rule = region.rule;
 
 1121         childRegions[0].rule.push_back(std::make_pair(1, 1.0));
 
 1122         childRegions[1].setRange(   bestSplit       , region.end()    );
 
 1123         childRegions[1].rule = region.rule;
 
 1124         childRegions[1].rule.push_back(std::make_pair(1, 1.0));
 
 1126         return i_ThresholdNode;
 
 1171     std::ptrdiff_t          min_index_;
 
 1172     double                  min_threshold_;
 
 1181         class_weights_(ext.class_weights_),
 
 1184         bestCurrentCounts[0].resize(ext.class_count_);
 
 1185         bestCurrentCounts[1].resize(ext.class_count_);
 
 1191         class_weights_ = ext.class_weights_; 
 
 1193         bestCurrentCounts[0].resize(ext.class_count_);
 
 1194         bestCurrentCounts[1].resize(ext.class_count_);
 
 1197     template<   
class DataSourceF_t,
 
 1201     void operator()(DataSourceF_t   
const & column,
 
 1202                     DataSource_t    
const & labels,
 
 1205                     Array           
const & region_response)
 
 1207         std::sort(begin, end, 
 
 1210             LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
 
 1211         LineSearchLoss left(labels, ext_param_);
 
 1212         LineSearchLoss right(labels, ext_param_);
 
 1213         right.init(begin, end, region_response);
 
 1215         min_gini_ = NumericTraits<double>::max();
 
 1216         min_index_ = 
floor(
double(end - begin)/2.0); 
 
 1217         min_threshold_ =  column[*(begin + min_index_)];
 
 1219             sorter(column, 0, min_threshold_);
 
 1220         I_Iter part = std::partition(begin, end, sorter);
 
 1221         DimensionNotEqual<DataSourceF_t> comp(column, 0); 
 
 1224             part= std::adjacent_find(part, end, comp)+1;
 
 1233             min_threshold_ = column[*part];
 
 1235         min_gini_ = right.decrement(begin, part) 
 
 1236               +     left.increment(begin , part);
 
 1238         bestCurrentCounts[0] = left.response();
 
 1239         bestCurrentCounts[1] = right.response();
 
 1241         min_index_      = part - begin;
 
 1244     template<
class DataSource_t, 
class Iter, 
class Array>
 
 1245     double loss_of_region(DataSource_t 
const & labels,
 
 1248                           Array 
const & region_response)
 const 
 1251             LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
 
 1252         LineSearchLoss region_loss(labels, ext_param_);
 
 1254             region_loss.init(begin, end, region_response);
 
 1272     std::ptrdiff_t          min_index_;
 
 1273     double                  min_threshold_;
 
 1284         class_weights_(ext.class_weights_),
 
 1288         bestCurrentCounts[0].resize(ext.class_count_);
 
 1289         bestCurrentCounts[1].resize(ext.class_count_);
 
 1295         class_weights_(ext.class_weights_),
 
 1299         bestCurrentCounts[0].resize(ext.class_count_);
 
 1300         bestCurrentCounts[1].resize(ext.class_count_);
 
 1306         class_weights_ = ext.class_weights_; 
 
 1308         bestCurrentCounts[0].resize(ext.class_count_);
 
 1309         bestCurrentCounts[1].resize(ext.class_count_);
 
 1312     template<   
class DataSourceF_t,
 
 1316     void operator()(DataSourceF_t   
const & column,
 
 1317                     DataSource_t    
const & labels,
 
 1320                     Array           
const & region_response)
 
 1322         std::sort(begin, end, 
 
 1325             LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
 
 1326         LineSearchLoss left(labels, ext_param_);
 
 1327         LineSearchLoss right(labels, ext_param_);
 
 1328         right.init(begin, end, region_response);
 
 1331         min_gini_ = NumericTraits<double>::max();
 
 1332         int tmp_pt = random.
uniformInt(std::distance(begin, end));
 
 1333         min_index_ = tmp_pt;
 
 1334         min_threshold_ =  column[*(begin + min_index_)];
 
 1336             sorter(column, 0, min_threshold_);
 
 1337         I_Iter part = std::partition(begin, end, sorter);
 
 1338         DimensionNotEqual<DataSourceF_t> comp(column, 0); 
 
 1341             part= std::adjacent_find(part, end, comp)+1;
 
 1350             min_threshold_ = column[*part];
 
 1352         min_gini_ = right.decrement(begin, part) 
 
 1353               +     left.increment(begin , part);
 
 1355         bestCurrentCounts[0] = left.response();
 
 1356         bestCurrentCounts[1] = right.response();
 
 1358         min_index_      = part - begin;
 
 1361     template<
class DataSource_t, 
class Iter, 
class Array>
 
 1362     double loss_of_region(DataSource_t 
const & labels,
 
 1365                           Array 
const & region_response)
 const 
 1368             LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
 
 1369         LineSearchLoss region_loss(labels, ext_param_);
 
 1371             region_loss.init(begin, end, region_response);
 
 1382 #endif // VIGRA_RANDOM_FOREST_SPLIT_HXX 
UInt32 uniformInt() const 
Definition: random.hxx:464
static double impurity(Array const &hist, double total)
Definition: rf_split.hxx:443
Definition: rf_region.hxx:57
Definition: rf_nodeproxy.hxx:626
MultiArrayView< 2, T, C > columnVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition: matrix.hxx:727
Definition: rf_split.hxx:201
const difference_type & shape() const 
Definition: multi_array.hxx:1648
Definition: rf_split.hxx:993
Definition: rf_split.hxx:305
const_iterator begin() const 
Definition: array_vector.hxx:223
void set_external_parameters(ProblemSpec< T > const &in)
Definition: rf_split.hxx:112
problem specification class for the random forest. 
Definition: rf_common.hxx:538
iterator begin()
Definition: multi_array.hxx:1921
int findBestSplit(MultiArrayView< 2, T, C >, MultiArrayView< 2, T2, C2 >, Region, ArrayVector< Region >, Random)
Definition: rf_split.hxx:150
Definition: rf_split.hxx:356
std::ptrdiff_t MultiArrayIndex
Definition: multi_fwd.hxx:60
double operator()(Array const &hist, Array2 const &weights, double total=1.0) const 
Definition: rf_split.hxx:425
static double impurity(Array const &hist, Array2 const &weights, double total)
Definition: rf_split.hxx:389
Definition: rf_nodeproxy.hxx:87
double operator()(Array const &hist, double total=1.0) const 
Definition: rf_split.hxx:435
NumericTraits< T >::Promote sq(T t)
The square function. 
Definition: mathutil.hxx:382
Definition: rf_split.hxx:831
double operator()(Array const &hist, double total=1.0) const 
Definition: rf_split.hxx:373
TinyVector< MultiArrayIndex, N > type
Definition: multi_shape.hxx:272
void reset()
Definition: rf_split.hxx:137
bool closeAtTolerance(T1 l, T2 r, typename PromoteTraits< T1, T2 >::Promote epsilon)
Tolerance based floating-point equality. 
Definition: mathutil.hxx:1638
linalg::TemporaryMatrix< T > log(MultiArrayView< 2, T, C > const &v)
void operator()(DataSourceF_t const &column, DataSource_t const &labels, I_Iter &begin, I_Iter &end, Array const &region_response)
Definition: rf_split.hxx:892
static double impurity(Array const &hist, double total)
Definition: rf_split.hxx:381
MultiArrayView< 2, T, C > rowVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition: matrix.hxx:697
Definition: random.hxx:336
Base class for, and view to, vigra::MultiArray. 
Definition: multi_array.hxx:704
static double impurity(Array const &hist, Array2 const &weights, double total)
Definition: rf_split.hxx:451
Definition: rf_split.hxx:1264
size_type size() const 
Definition: array_vector.hxx:358
int floor(FixedPoint< IntBits, FracBits > v)
rounding down. 
Definition: fixedpoint.hxx:667
double & weights()
Definition: rf_nodeproxy.hxx:115
Definition: rf_split.hxx:92
double operator()(Array const &hist, Array2 const &weights, double total=1.0) const 
Definition: rf_split.hxx:363
int makeTerminalNode(MultiArrayView< 2, T, C >, MultiArrayView< 2, T2, C2 >, Region &region, Random)
Definition: rf_split.hxx:168
Definition: rf_split.hxx:418