35 #ifndef RF_VISITORS_HXX 
   36 #define RF_VISITORS_HXX 
   39 # include "vigra/hdf5impex.hxx" 
   41 #include <vigra/windows.h> 
   45 #include <vigra/metaprogramming.hxx> 
   46 #include <vigra/multi_pointoperators.hxx> 
  141     template<
class Tree, 
class Split, 
class Region, 
class Feature_t, 
class Label_t>
 
  147                             Feature_t     & features,
 
  150         ignore_argument(tree,split,parent,leftChild,rightChild,features,labels);
 
  162     template<
class RF, 
class PR, 
class SM, 
class ST>
 
  165         ignore_argument(rf,pr,sm,st,index);
 
  174     template<
class RF, 
class PR>
 
  177         ignore_argument(rf,pr);
 
  186     template<
class RF, 
class PR>
 
  189         ignore_argument(rf,pr);
 
  204     template<
class TR, 
class IntT, 
class TopT,
class Feat>
 
  207         ignore_argument(tr,index,node_t,features);
 
  214     template<
class TR, 
class IntT, 
class TopT,
class Feat>
 
  253 template <
class Visitor, 
class Next = StopVisiting>
 
  263         next_(next), visitor_(visitor)
 
  268         next_(stop_), visitor_(visitor)
 
  271     template<
class Tree, 
class Split, 
class Region, 
class Feature_t, 
class Label_t>
 
  272     void visit_after_split( Tree          & tree,
 
  277                             Feature_t     & features,
 
  280         if(visitor_.is_active())
 
  281             visitor_.visit_after_split(tree, split,
 
  282                                        parent, leftChild, rightChild,
 
  284         next_.visit_after_split(tree, split, parent, leftChild, rightChild,
 
  288     template<
class RF, 
class PR, 
class SM, 
class ST>
 
  289     void visit_after_tree(RF& rf, PR & pr,  SM & sm, ST & st, 
int index)
 
  291         if(visitor_.is_active())
 
  292             visitor_.visit_after_tree(rf, pr, sm, st, index);
 
  293         next_.visit_after_tree(rf, pr, sm, st, index);
 
  296     template<
class RF, 
class PR>
 
  297     void visit_at_beginning(RF & rf, PR & pr)
 
  299         if(visitor_.is_active())
 
  300             visitor_.visit_at_beginning(rf, pr);
 
  301         next_.visit_at_beginning(rf, pr);
 
  303     template<
class RF, 
class PR>
 
  304     void visit_at_end(RF & rf, PR & pr)
 
  306         if(visitor_.is_active())
 
  307             visitor_.visit_at_end(rf, pr);
 
  308         next_.visit_at_end(rf, pr);
 
  311     template<
class TR, 
class IntT, 
class TopT,
class Feat>
 
  312     void visit_external_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
 
  314         if(visitor_.is_active())
 
  315             visitor_.visit_external_node(tr, index, node_t,features);
 
  316         next_.visit_external_node(tr, index, node_t,features);
 
  318     template<
class TR, 
class IntT, 
class TopT,
class Feat>
 
  319     void visit_internal_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
 
  321         if(visitor_.is_active())
 
  322             visitor_.visit_internal_node(tr, index, node_t,features);
 
  323         next_.visit_internal_node(tr, index, node_t,features);
 
  328         if(visitor_.is_active() && visitor_.has_value())
 
  329             return visitor_.return_val();
 
  330         return next_.return_val();
 
  354 template<
class A, 
class B>
 
  355 detail::VisitorNode<A, detail::VisitorNode<B> >
 
  368 template<
class A, 
class B, 
class C>
 
  369 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C> > >
 
  384 template<
class A, 
class B, 
class C, 
class D>
 
  385 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
 
  386     detail::VisitorNode<D> > > >
 
  403 template<
class A, 
class B, 
class C, 
class D, 
class E>
 
  404 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
 
  405     detail::VisitorNode<D, detail::VisitorNode<E> > > > >
 
  425 template<
class A, 
class B, 
class C, 
class D, 
class E,
 
  427 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
 
  428     detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F> > > > > >
 
  450 template<
class A, 
class B, 
class C, 
class D, 
class E,
 
  452 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
 
  453     detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
 
  454     detail::VisitorNode<G> > > > > > >
 
  456                D & d, E & e, F & f, G & g)
 
  478 template<
class A, 
class B, 
class C, 
class D, 
class E,
 
  479          class F, 
class G, 
class H>
 
  480 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
 
  481     detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
 
  482     detail::VisitorNode<G, detail::VisitorNode<H> > > > > > > >
 
  509 template<
class A, 
class B, 
class C, 
class D, 
class E,
 
  510          class F, 
class G, 
class H, 
class I>
 
  511 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
 
  512     detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
 
  513     detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I> > > > > > > > >
 
  541 template<
class A, 
class B, 
class C, 
class D, 
class E,
 
  542          class F, 
class G, 
class H, 
class I, 
class J>
 
  543 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
 
  544     detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
 
  545     detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I,
 
  546     detail::VisitorNode<J> > > > > > > > > >
 
  587     bool adjust_thresholds;
 
  597         adjust_thresholds(
false), tree_id(0), last_node_id(0), current_label(0)
 
  599     struct MarginalDistribution
 
  602         Int32 leftTotalCounts;
 
  604         Int32 rightTotalCounts;
 
  611     struct TreeOnlineInformation
 
  613         std::vector<MarginalDistribution> mag_distributions;
 
  614         std::vector<IndexList> index_lists;
 
  616         std::map<int,int> interior_to_index;
 
  618         std::map<int,int> exterior_to_index;
 
  622     std::vector<TreeOnlineInformation> trees_online_information;
 
  626     template<
class RF,
class PR>
 
  630         trees_online_information.resize(rf.options_.tree_count_);
 
  637         trees_online_information[tree_id].mag_distributions.clear();
 
  638         trees_online_information[tree_id].index_lists.clear();
 
  639         trees_online_information[tree_id].interior_to_index.clear();
 
  640         trees_online_information[tree_id].exterior_to_index.clear();
 
  645     template<
class RF, 
class PR, 
class SM, 
class ST>
 
  651     template<
class Tree, 
class Split, 
class Region, 
class Feature_t, 
class Label_t>
 
  652     void visit_after_split( Tree          & tree,
 
  657                             Feature_t     & features,
 
  661         int addr=tree.topology_.size();
 
  662         if(split.createNode().typeID() == i_ThresholdNode)
 
  664             if(adjust_thresholds)
 
  667                 linear_index=trees_online_information[tree_id].mag_distributions.size();
 
  668                 trees_online_information[tree_id].interior_to_index[addr]=linear_index;
 
  669                 trees_online_information[tree_id].mag_distributions.push_back(MarginalDistribution());
 
  671                 trees_online_information[tree_id].mag_distributions.back().leftCounts=leftChild.classCounts_;
 
  672                 trees_online_information[tree_id].mag_distributions.back().rightCounts=rightChild.classCounts_;
 
  674                 trees_online_information[tree_id].mag_distributions.back().leftTotalCounts=leftChild.size_;
 
  675                 trees_online_information[tree_id].mag_distributions.back().rightTotalCounts=rightChild.size_;
 
  677                 double gap_left,gap_right;
 
  679                 gap_left=features(leftChild[0],split.bestSplitColumn());
 
  680                 for(i=1;i<leftChild.size();++i)
 
  681                     if(features(leftChild[i],split.bestSplitColumn())>gap_left)
 
  682                         gap_left=features(leftChild[i],split.bestSplitColumn());
 
  683                 gap_right=features(rightChild[0],split.bestSplitColumn());
 
  684                 for(i=1;i<rightChild.size();++i)
 
  685                     if(features(rightChild[i],split.bestSplitColumn())<gap_right)
 
  686                         gap_right=features(rightChild[i],split.bestSplitColumn());
 
  687                 trees_online_information[tree_id].mag_distributions.back().gap_left=gap_left;
 
  688                 trees_online_information[tree_id].mag_distributions.back().gap_right=gap_right;
 
  694             linear_index=trees_online_information[tree_id].index_lists.size();
 
  695             trees_online_information[tree_id].exterior_to_index[addr]=linear_index;
 
  697             trees_online_information[tree_id].index_lists.push_back(IndexList());
 
  699             trees_online_information[tree_id].index_lists.back().resize(parent.size_,0);
 
  700             std::copy(parent.begin_,parent.end_,trees_online_information[tree_id].index_lists.back().begin());
 
  703     void add_to_index_list(
int tree,
int node,
int index)
 
  707         TreeOnlineInformation &ti=trees_online_information[tree];
 
  708         ti.index_lists[ti.exterior_to_index[node]].push_back(index);
 
  710     void move_exterior_node(
int src_tree,
int src_index,
int dst_tree,
int dst_index)
 
  714         trees_online_information[dst_tree].exterior_to_index[dst_index]=trees_online_information[src_tree].exterior_to_index[src_index];
 
  715         trees_online_information[src_tree].exterior_to_index.erase(src_index);
 
  722     template<
class TR, 
class IntT, 
class TopT,
class Feat>
 
  726             if(adjust_thresholds)
 
  728                 vigra_assert(node_t==i_ThresholdNode,
"We can only visit threshold nodes");
 
  730                 double value=features(0, Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).column());
 
  731                 TreeOnlineInformation &ti=trees_online_information[tree_id];
 
  732                 MarginalDistribution &m=ti.mag_distributions[ti.interior_to_index[index]];
 
  733                 if(value>m.gap_left && value<m.gap_right)
 
  736                     if(m.leftCounts[current_label]/
double(m.leftTotalCounts)>m.rightCounts[current_label]/double(m.rightTotalCounts))
 
  746                     Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold()=(m.gap_right+m.gap_left)/2.0;
 
  749                 if(value>Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold())
 
  751                     ++m.rightTotalCounts;
 
  752                     ++m.rightCounts[current_label];
 
  757                     ++m.rightCounts[current_label];
 
  805     template<
class RF, 
class PR, 
class SM, 
class ST>
 
  809         if(
int(oobCount.
size()) != rf.ext_param_.row_count_)
 
  811             oobCount.resize(rf.ext_param_.row_count_, 0);
 
  812             oobErrorCount.resize(rf.ext_param_.row_count_, 0);
 
  815         for(
int l = 0; l < rf.ext_param_.row_count_; ++l)
 
  822                             .predictLabel(
rowVector(pr.features(), l))
 
  823                     !=  pr.response()(l,0))
 
  834     template<
class RF, 
class PR>
 
  838         for(
int l=0; l < static_cast<int>(rf.ext_param_.row_count_); ++l)
 
  842                 oobError += double(oobErrorCount[l]) / oobCount[l];
 
  880     void save(std::string filen, std::string pathn)
 
  882         if(*(pathn.end()-1) != 
'/')
 
  884         const char* filename = filen.c_str();
 
  887         writeHDF5(filename, (pathn + 
"breiman_error").c_str(), temp);
 
  893     template<
class RF, 
class PR>
 
  894     void visit_at_beginning(RF & rf, PR &)
 
  896         class_count = rf.class_count();
 
  897         tmp_prob.
reshape(Shp(1, class_count), 0);
 
  898         prob_oob.
reshape(Shp(rf.ext_param().row_count_,class_count), 0);
 
  899         is_weighted = rf.options().predict_weighted_;
 
  900         indices.resize(rf.ext_param().row_count_);
 
  901         if(
int(oobCount.
size()) != rf.ext_param_.row_count_)
 
  903             oobCount.
reshape(Shp(rf.ext_param_.row_count_, 1), 0);
 
  905         for(
int ii = 0; ii < rf.ext_param().row_count_; ++ii)
 
  911     template<
class RF, 
class PR, 
class SM, 
class ST>
 
  912     void visit_after_tree(RF& rf, PR & pr,  SM & sm, ST &, 
int index)
 
  919         if(rf.ext_param_.actual_msample_ < pr.features().shape(0) - 10000)
 
  921             ArrayVector<int> oob_indices;
 
  922             ArrayVector<int> cts(class_count, 0);
 
  923             std::random_shuffle(indices.
begin(), indices.
end());
 
  924             for(
int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
 
  926                 if(!sm.is_used()[indices[ii]] && cts[pr.response()(indices[ii], 0)] < 40000)
 
  928                     oob_indices.push_back(indices[ii]);
 
  929                     ++cts[pr.response()(indices[ii], 0)];
 
  932             for(
unsigned int ll = 0; ll < oob_indices.size(); ++ll)
 
  935                 ++oobCount[oob_indices[ll]];
 
  940                 int pos =  rf.tree(index).getToLeaf(
rowVector(pr.features(),oob_indices[ll]));
 
  941                 Node<e_ConstProbNode> node ( rf.tree(index).topology_,
 
  942                                                     rf.tree(index).parameters_,
 
  945                 for(
int ii = 0; ii < class_count; ++ii)
 
  947                     tmp_prob[ii] = node.prob_begin()[ii];
 
  951                     for(
int ii = 0; ii < class_count; ++ii)
 
  952                         tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
 
  954                 rowVector(prob_oob, oob_indices[ll]) += tmp_prob;
 
  959             for(
int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
 
  962                 if(!sm.is_used()[ll])
 
  970                     int pos =  rf.tree(index).getToLeaf(
rowVector(pr.features(),ll));
 
  971                     Node<e_ConstProbNode> node ( rf.tree(index).topology_,
 
  972                                                         rf.tree(index).parameters_,
 
  975                     for(
int ii = 0; ii < class_count; ++ii)
 
  977                         tmp_prob[ii] = node.prob_begin()[ii];
 
  981                         for(
int ii = 0; ii < class_count; ++ii)
 
  982                             tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
 
  993     template<
class RF, 
class PR>
 
  997         int totalOobCount =0;
 
  998         int breimanstyle = 0;
 
  999         for(
int ll=0; ll < static_cast<int>(rf.ext_param_.row_count_); ++ll)
 
 1074     void save(std::string filen, std::string pathn)
 
 1076         if(*(pathn.end()-1) != 
'/')
 
 1078         const char* filename = filen.c_str();
 
 1084         writeHDF5(filename, (pathn + 
"per_tree_error").c_str(), temp);
 
 1086         writeHDF5(filename, (pathn + 
"per_tree_error_std").c_str(), temp);
 
 1088         writeHDF5(filename, (pathn + 
"breiman_error").c_str(), temp);
 
 1090         writeHDF5(filename, (pathn + 
"ulli_error").c_str(), temp);
 
 1096     template<
class RF, 
class PR>
 
 1097     void visit_at_beginning(RF & rf, PR &)
 
 1099         class_count = rf.class_count();
 
 1100         if(class_count == 2)
 
 1104         tmp_prob.
reshape(Shp(1, class_count), 0);
 
 1105         prob_oob.
reshape(Shp(rf.ext_param().row_count_,class_count), 0);
 
 1106         is_weighted = rf.options().predict_weighted_;
 
 1110         if(
int(oobCount.
size()) != rf.ext_param_.row_count_)
 
 1112             oobCount.
reshape(Shp(rf.ext_param_.row_count_, 1), 0);
 
 1113             oobErrorCount.
reshape(Shp(rf.ext_param_.row_count_,1), 0);
 
 1117     template<
class RF, 
class PR, 
class SM, 
class ST>
 
 1118     void visit_after_tree(RF& rf, PR & pr,  SM & sm, ST &, 
int index)
 
 1123         for(
int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
 
 1126             if(!sm.is_used()[ll])
 
 1134                 int pos =  rf.tree(index).getToLeaf(
rowVector(pr.features(),ll));
 
 1135                 Node<e_ConstProbNode> node ( rf.tree(index).topology_,
 
 1136                                                     rf.tree(index).parameters_,
 
 1139                 for(
int ii = 0; ii < class_count; ++ii)
 
 1141                     tmp_prob[ii] = node.prob_begin()[ii];
 
 1145                     for(
int ii = 0; ii < class_count; ++ii)
 
 1146                         tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
 
 1149                 int label = 
argMax(tmp_prob);
 
 1151                 if(label != pr.response()(ll, 0))
 
 1156                     ++oobErrorCount[ll];
 
 1160         int breimanstyle = 0;
 
 1161         int totalOobCount = 0;
 
 1162         for(
int ll=0; ll < static_cast<int>(rf.ext_param_.row_count_); ++ll)
 
 1179             MultiArrayView<3, double> current_roc
 
 1181             for(
int gg = 0; gg < current_roc.shape(2); ++gg)
 
 1183                 for(
int ll=0; ll < static_cast<int>(rf.ext_param_.row_count_); ++ll)
 
 1187                         int pred = prob_oob(ll, 1) > (double(gg)/double(current_roc.shape(2)))?
 
 1189                         current_roc(pr.response()(ll, 0), pred, gg)+= 1;
 
 1192                 current_roc.
bindOuter(gg)/= totalOobCount;
 
 1196         oob_per_tree[index] = double(wrong_oob)/double(total_oob);
 
 1202     template<
class RF, 
class PR>
 
 1207         int totalOobCount =0;
 
 1208         int breimanstyle = 0;
 
 1209         for(
int ll=0; ll < static_cast<int>(rf.ext_param_.row_count_); ++ll)
 
 1260     int                         repetition_count_;
 
 1264     void save(std::string filename, std::string prefix)
 
 1266         prefix = 
"variable_importance_" + prefix;
 
 1279     :   repetition_count_(rep_cnt)
 
 1286     template<
class Tree, 
class Split, 
class Region, 
class Feature_t, 
class Label_t>
 
 1297         Int32 const  class_count = tree.ext_param_.class_count_;
 
 1298         Int32 const  column_count = tree.ext_param_.column_count_;
 
 1307         if(split.createNode().typeID() == i_ThresholdNode)
 
 1309             Node<i_ThresholdNode> node(split.createNode());
 
 1311                 += split.region_gini_ - split.minGini();
 
 1321     template<
class RF, 
class PR, 
class SM, 
class ST>
 
 1325         Int32                   column_count = rf.ext_param_.column_count_;
 
 1326         Int32                   class_count  = rf.ext_param_.class_count_;
 
 1336         typedef typename PR::FeatureWithMemory_t FeatureArray;
 
 1337         typedef typename FeatureArray::value_type FeatureValue;
 
 1339         FeatureArray features = pr.features();
 
 1345         for(
int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
 
 1346             if(!sm.is_used()[ii])
 
 1347                 oob_indices.push_back(ii);
 
 1353 #ifdef CLASSIFIER_TEST 
 1364                     oob_right(Shp_t(1, class_count + 1));
 
 1366                     perm_oob_right (Shp_t(1, class_count + 1));
 
 1370         for(iter = oob_indices.
begin();
 
 1371             iter != oob_indices.
end();
 
 1375                     .predictLabel(
rowVector(features, *iter))
 
 1376                 ==  pr.response()(*iter, 0))
 
 1379                 ++oob_right[pr.response()(*iter,0)];
 
 1381                 ++oob_right[class_count];
 
 1385         for(
int ii = 0; ii < column_count; ++ii)
 
 1387             perm_oob_right.
init(0.0);
 
 1389             backup_column.clear();
 
 1390             for(iter = oob_indices.
begin();
 
 1391                 iter != oob_indices.
end();
 
 1394                 backup_column.push_back(features(*iter,ii));
 
 1398             for(
int rr = 0; rr < repetition_count_; ++rr)
 
 1401                 int n = oob_indices.
size();
 
 1402                 for(
int jj = n-1; jj >= 1; --jj)
 
 1403                     std::swap(features(oob_indices[jj], ii),
 
 1404                               features(oob_indices[randint(jj+1)], ii));
 
 1407                 for(iter = oob_indices.
begin();
 
 1408                     iter != oob_indices.
end();
 
 1412                             .predictLabel(
rowVector(features, *iter))
 
 1413                         ==  pr.response()(*iter, 0))
 
 1416                         ++perm_oob_right[pr.response()(*iter, 0)];
 
 1418                         ++perm_oob_right[class_count];
 
 1425             perm_oob_right  /=  repetition_count_;
 
 1426             perm_oob_right -=oob_right;
 
 1427             perm_oob_right *= -1;
 
 1428             perm_oob_right      /=  oob_indices.
size();
 
 1431                           Shp_t(ii+1,class_count+1)) += perm_oob_right;
 
 1433             for(
int jj = 0; jj < int(oob_indices.
size()); ++jj)
 
 1434                 features(oob_indices[jj], ii) = backup_column[jj];
 
 1443     template<
class RF, 
class PR, 
class SM, 
class ST>
 
 1451     template<
class RF, 
class PR>
 
 1464     template<
class RF, 
class PR, 
class SM, 
class ST>
 
 1465     void visit_after_tree(RF& rf, PR &,  SM &, ST &, 
int index){
 
 1466         if(index != rf.options().tree_count_-1) {
 
 1467             std::cout << 
"\r[" << std::setw(10) << (index+1)/static_cast<double>(rf.options().tree_count_)*100 << 
"%]" 
 1468                       << 
" (" << index+1 << 
" of " << rf.options().tree_count_ << 
") done" << std::flush;
 
 1471             std::cout << 
"\r[" << std::setw(10) << 100.0 << 
"%]" << std::endl;
 
 1475     template<
class RF, 
class PR>
 
 1476     void visit_at_end(RF 
const & rf, PR 
const &) {
 
 1477         std::string a = 
TOCS;
 
 1478         std::cout << 
"all " << rf.options().tree_count_ << 
" trees have been learned in " << a  << std::endl;
 
 1481     template<
class RF, 
class PR>
 
 1482     void visit_at_beginning(RF 
const & rf, PR 
const &) {
 
 1484         std::cout << 
"growing random forest, which will have " << rf.options().tree_count_ << 
" trees" << std::endl;
 
 1532     void save(std::string, std::string)
 
 1550     template<
class RF, 
class PR>
 
 1551     void visit_at_beginning(RF 
const & rf, PR & pr)
 
 1554         int n = rf.ext_param_.column_count_;
 
 1557         corr_l.
reshape(Shp(n +1, 10));
 
 1560         noise_l.
reshape(Shp(pr.features().shape(0), 10));
 
 1562         for(
int ii = 0; ii < 
noise.
size(); ++ii)
 
 1564             noise[ii]   = random.uniform53();
 
 1565             noise_l[ii] = random.uniform53()  > 0.5;
 
 1567         bgfunc = ColumnDecisionFunctor( rf.ext_param_);
 
 1568         tmp_labels.
reshape(pr.response().shape());
 
 1573     template<
class RF, 
class PR>
 
 1574     void visit_at_end(RF 
const &, PR 
const &)
 
 1583         for(
int jj = 0; jj < rC-1; ++jj)
 
 1588         for(
int jj = 0; jj < rC; ++jj)
 
 1594         FindMinMax<double> minmax;
 
 1597         for(
int jj = 0; jj < rC; ++jj)
 
 1604         for(
int jj = 0; jj < rC; ++jj)
 
 1607         FindMinMax<double> minmax2;
 
 1609         for(
int jj = 0; jj < rC; ++jj)
 
 1615     template<
class Tree, 
class Split, 
class Region, 
class Feature_t, 
class Label_t>
 
 1616     void visit_after_split( Tree          &,
 
 1621                             Feature_t     & features,
 
 1624         if(split.createNode().typeID() == i_ThresholdNode)
 
 1628             for(
int ii = 0; ii < parent.size(); ++ii)
 
 1630                 tmp_labels[parent[ii]]
 
 1631                     = (features(parent[ii], split.bestSplitColumn()) < split.bestSplitThreshold());
 
 1632                 ++tmp_cc[tmp_labels[parent[ii]]];
 
 1634             double region_gini = bgfunc.loss_of_region(tmp_labels,
 
 1639             int n = split.bestSplitColumn();
 
 1643             for(
int k = 0; k < features.shape(1); ++k)
 
 1647                        parent.begin(), parent.end(),
 
 1649                 wgini = (region_gini - bgfunc.min_gini_);
 
 1653             for(
int k = 0; k < 10; ++k)
 
 1657                        parent.begin(), parent.end(),
 
 1659                 wgini = (region_gini - bgfunc.min_gini_);
 
 1664             for(
int k = 0; k < 10; ++k)
 
 1668                        parent.begin(), parent.end(),
 
 1670                 wgini = (region_gini - bgfunc.min_gini_);
 
 1674             bgfunc(labels, tmp_labels, parent.begin(), parent.end(),tmp_cc);
 
 1675             wgini = (region_gini - bgfunc.min_gini_);
 
 1679             region_gini = split.region_gini_;
 
 1681             Node<i_ThresholdNode> node(split.createNode());
 
 1684                  +=split.region_gini_ - split.minGini();
 
 1686             for(
int k = 0; k < 10; ++k)
 
 1690                              parent.begin(), parent.end(),
 
 1691                              parent.classCounts());
 
 1697             for(
int k = 0; k < tree.ext_param_.actual_mtry_; ++k)
 
 1699                 wgini = region_gini - split.min_gini_[k];
 
 1702                                       split.splitColumns[k])
 
 1706             for(
int k=tree.ext_param_.actual_mtry_; k<features.shape(1); ++k)
 
 1708                 split.bgfunc(
columnVector(features, split.splitColumns[k]),
 
 1710                              parent.begin(), parent.end(),
 
 1711                              parent.classCounts());
 
 1712                 wgini = region_gini - split.bgfunc.min_gini_;
 
 1714                                       split.splitColumns[k]) += wgini;
 
 1721                 SortSamplesByDimensions<Feature_t>
 
 1722                 sorter(features, split.bestSplitColumn(), split.bestSplitThreshold());
 
 1723             std::partition(parent.begin(), parent.end(), sorter);
 
 1733 #endif // RF_VISITORS_HXX 
#define TIC
Definition: timing.hxx:322
void visit_at_end(RF &rf, PR &pr)
Definition: rf_visitors.hxx:994
MultiArray< 2, double > oob_per_tree
Definition: rf_visitors.hxx:1025
void visit_at_beginning(RF &rf, const PR &)
Definition: rf_visitors.hxx:627
MultiArray< 2, double > noise
Definition: rf_visitors.hxx:1505
MultiArrayView< 2, T, C > columnVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition: matrix.hxx:727
MultiArrayIndex rowCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:671
MultiArray< 2, double > variable_importance_
Definition: rf_visitors.hxx:1259
const difference_type & shape() const 
Definition: multi_array.hxx:1648
MultiArray< 2, double > breiman_per_tree
Definition: rf_visitors.hxx:1050
void after_tree_ip_impl(RF &rf, PR &pr, SM &sm, ST &, int index)
Definition: rf_visitors.hxx:1322
MultiArray< 2, double > corr_noise
Definition: rf_visitors.hxx:1509
void visit_after_tree(RF &, PR &, SM &, ST &, int)
Definition: rf_visitors.hxx:646
const_iterator begin() const 
Definition: array_vector.hxx:223
MultiArray< 2, double > similarity
Definition: rf_visitors.hxx:1521
void reshape(const difference_type &shape)
Definition: multi_array.hxx:2861
Definition: rf_visitors.hxx:863
MultiArray< 4, double > oobroc_per_tree
Definition: rf_visitors.hxx:1067
Definition: rf_visitors.hxx:1495
ArrayVector< int > numChoices
Definition: rf_visitors.hxx:1529
Definition: rf_visitors.hxx:1230
MultiArrayView< N, T, StridedArrayTag > transpose() const 
Definition: multi_array.hxx:1567
double oob_breiman
Definition: rf_visitors.hxx:1038
double oob_mean
Definition: rf_visitors.hxx:1028
MultiArray< 2, double > gini_missc
Definition: rf_visitors.hxx:1501
double return_val()
Definition: rf_visitors.hxx:225
difference_type_1 size() const 
Definition: multi_array.hxx:1641
MultiArray< 2, double > distance
Definition: rf_visitors.hxx:1524
Definition: multi_fwd.hxx:63
void reset_tree(int tree_id)
Definition: rf_visitors.hxx:635
double oobError
Definition: rf_visitors.hxx:787
Definition: rf_visitors.hxx:254
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int 
Definition: sized_int.hxx:175
void visit_internal_node(TR &, IntT, TopT, Feat &)
Definition: rf_visitors.hxx:215
void init(U const &initial)
Definition: array_vector.hxx:146
void visit_after_tree(RF &rf, PR &pr, SM &sm, ST &, int index)
Definition: rf_visitors.hxx:806
double oob_per_tree2
Definition: rf_visitors.hxx:1045
Definition: rf_split.hxx:831
MultiArray & init(const U &init)
Definition: multi_array.hxx:2851
Iterator argMax(Iterator first, Iterator last)
Find the maximum element in a sequence. 
Definition: algorithm.hxx:96
void visit_at_end(RF &rf, PR &)
Definition: rf_visitors.hxx:1452
Definition: rf_visitors.hxx:1015
Definition: rf_visitors.hxx:583
double oob_breiman
Definition: rf_visitors.hxx:874
#define TOCS
Definition: timing.hxx:325
Class for fixed size vectors.This class contains an array of size SIZE of the specified VALUETYPE...
Definition: accessor.hxx:940
void writeHDF5(...)
Store array data in an HDF5 file. 
Definition: rf_visitors.hxx:1460
double oob_std
Definition: rf_visitors.hxx:1031
void visit_internal_node(TR &tr, IntT index, TopT node_t, Feat &features)
Definition: rf_visitors.hxx:723
void visit_at_end(RF const &rf, PR const &pr)
Definition: rf_visitors.hxx:175
void visit_at_end(RF &rf, PR &pr)
Definition: rf_visitors.hxx:1203
MultiArrayView< 2, T, C > rowVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition: matrix.hxx:697
Definition: rf_visitors.hxx:101
void visit_at_end(RF &rf, PR &)
Definition: rf_visitors.hxx:835
MultiArrayIndex columnCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:684
FFTWComplex< R >::NormType abs(const FFTWComplex< R > &a)
absolute value (= magnitude) 
Definition: fftw3.hxx:1002
void visit_at_beginning(RF const &rf, PR const &pr)
Definition: rf_visitors.hxx:187
Definition: random.hxx:336
void visit_after_split(Tree &tree, Split &split, Region &parent, Region &leftChild, Region &rightChild, Feature_t &features, Label_t &labels)
Definition: rf_visitors.hxx:142
const_iterator end() const 
Definition: array_vector.hxx:237
const_pointer data() const 
Definition: array_vector.hxx:209
size_type size() const 
Definition: array_vector.hxx:358
MultiArrayView subarray(difference_type p, difference_type q) const 
Definition: multi_array.hxx:1528
void inspectMultiArray(...)
Call an analyzing functor at every element of a multi-dimensional array. 
void visit_after_split(Tree &tree, Split &split, Region &, Region &, Region &, Feature_t &, Label_t &)
Definition: rf_visitors.hxx:1287
void visit_external_node(TR &tr, IntT index, TopT node_t, Feat &features)
Definition: rf_visitors.hxx:205
Definition: rf_visitors.hxx:782
void visit_after_tree(RF &rf, PR &pr, SM &sm, ST &st, int index)
Definition: rf_visitors.hxx:163
void visit_after_tree(RF &rf, PR &pr, SM &sm, ST &st, int index)
Definition: rf_visitors.hxx:1444
MultiArrayView< N-M, T, StrideTag > bindOuter(const TinyVector< Index, M > &d) const 
Definition: multi_array.hxx:2184
Definition: rf_visitors.hxx:234
detail::VisitorNode< A > create_visitor(A &a)
Definition: rf_visitors.hxx:344