36 #ifndef VIGRA_RANDOM_FOREST_DEPREC_HXX 
   37 #define VIGRA_RANDOM_FOREST_DEPREC_HXX 
   45 #include "vigra/mathutil.hxx" 
   46 #include "vigra/array_vector.hxx" 
   47 #include "vigra/sized_int.hxx" 
   48 #include "vigra/matrix.hxx" 
   49 #include "vigra/random.hxx" 
   50 #include "vigra/functorexpression.hxx" 
   63 template<
class DataMatrix>
 
   64 class RandomForestDeprecFeatureSorter
 
   66     DataMatrix 
const & data_;
 
   71     RandomForestDeprecFeatureSorter(DataMatrix 
const & data, 
MultiArrayIndex sortColumn)
 
   73       sortColumn_(sortColumn)
 
   78         sortColumn_ = sortColumn;
 
   83         return data_(l, sortColumn_) < data_(r, sortColumn_);
 
   87 template<
class LabelArray>
 
   88 class RandomForestDeprecLabelSorter
 
   90     LabelArray 
const & labels_;
 
   94     RandomForestDeprecLabelSorter(LabelArray 
const & labels)
 
  100         return labels_[l] < labels_[r];
 
  104 template <
class CountArray>
 
  105 class RandomForestDeprecClassCounter
 
  107     ArrayVector<int> 
const & labels_;
 
  108     CountArray & counts_;
 
  112     RandomForestDeprecClassCounter(ArrayVector<int> 
const & labels, CountArray & counts)
 
  126         ++counts_[labels_[l]];
 
  130 struct DecisionTreeDeprecCountNonzeroFunctor
 
  132     double operator()(
double old, 
double other)
 const 
  140 struct DecisionTreeDeprecNode
 
  143     : thresholdIndex(t), splitColumn(bestColumn)
 
  152 struct DecisionTreeDeprecNodeProxy
 
  154     DecisionTreeDeprecNodeProxy(ArrayVector<INT> 
const & tree, INT n)
 
  155     : node(const_cast<ArrayVector<INT> &>(tree).begin()+n)
 
  158     INT & child(INT l)
 const 
  163     INT & decisionWeightsIndex()
 const 
  168     typename ArrayVector<INT>::iterator decisionColumns()
 const 
  173     mutable typename ArrayVector<INT>::iterator node;
 
  176 struct DecisionTreeDeprecAxisSplitFunctor
 
  178     ArrayVector<Int32> splitColumns;
 
  179     ArrayVector<double> classCounts, currentCounts[2], bestCounts[2], classWeights;
 
  181     double totalCounts[2], bestTotalCounts[2];
 
  182     int mtry, classCount, bestSplitColumn;
 
  183     bool pure[2], isWeighted;
 
  185     void init(
int mtry, 
int cols, 
int classCount, ArrayVector<double> 
const & weights)
 
  188         splitColumns.resize(cols);
 
  189         for(
int k=0; k<cols; ++k)
 
  192         this->classCount = classCount;
 
  193         classCounts.resize(classCount);
 
  194         currentCounts[0].resize(classCount);
 
  195         currentCounts[1].resize(classCount);
 
  196         bestCounts[0].resize(classCount);
 
  197         bestCounts[1].resize(classCount);
 
  199         isWeighted = weights.size() > 0;
 
  201             classWeights = weights;
 
  203             classWeights.resize(classCount, 1.0);
 
  206     bool isPure(
int k)
 const 
  211     unsigned int totalCount(
int k)
 const 
  213         return (
unsigned int)bestTotalCounts[k];
 
  216     int sizeofNode()
 const { 
return 4; }
 
  218     int writeSplitParameters(ArrayVector<Int32> & tree,
 
  219                                 ArrayVector<double> &terminalWeights)
 
  221         int currentWeightIndex = terminalWeights.size();
 
  222         terminalWeights.push_back(threshold);
 
  224         int currentNodeIndex = tree.size();
 
  227         tree.push_back(currentWeightIndex);
 
  228         tree.push_back(bestSplitColumn);
 
  230         return currentNodeIndex;
 
  233     void writeWeights(
int l, ArrayVector<double> &terminalWeights)
 
  235         for(
int k=0; k<classCount; ++k)
 
  236             terminalWeights.push_back(isWeighted
 
  238                                            : bestCounts[l][k] / totalCount(l));
 
  241     template <
class U, 
class C, 
class AxesIterator, 
class WeightIterator>
 
  242     bool decideAtNode(MultiArrayView<2, U, C> 
const & features,
 
  243                       AxesIterator a, WeightIterator w)
 const 
  245         return (features(0, *a) < *w);
 
  248     template <
class U, 
class C, 
class IndexIterator, 
class Random>
 
  249     IndexIterator findBestSplit(MultiArrayView<2, U, C> 
const & features,
 
  250                                 ArrayVector<int> 
const & labels,
 
  251                                 IndexIterator indices, 
int exampleCount,
 
  257 template <
class U, 
class C, 
class IndexIterator, 
class Random>
 
  259 DecisionTreeDeprecAxisSplitFunctor::findBestSplit(MultiArrayView<2, U, C> 
const & features,
 
  260                                             ArrayVector<int> 
const & labels,
 
  261                                             IndexIterator indices, 
int exampleCount,
 
  265     for(
int k=0; k<mtry; ++k)
 
  266         std::swap(splitColumns[k], splitColumns[k+randint(
columnCount(features)-k)]);
 
  268     RandomForestDeprecFeatureSorter<MultiArrayView<2, U, C> > sorter(features, 0);
 
  269     RandomForestDeprecClassCounter<ArrayVector<double> > counter(labels, classCounts);
 
  270     std::for_each(indices, indices+exampleCount, counter);
 
  273     double minGini = NumericTraits<double>::max();
 
  274     IndexIterator bestSplit = indices;
 
  275     for(
int k=0; k<mtry; ++k)
 
  277         sorter.setColumn(splitColumns[k]);
 
  278         std::sort(indices, indices+exampleCount, sorter);
 
  280         currentCounts[0].init(0);
 
  281         std::transform(classCounts.begin(), classCounts.end(), classWeights.begin(),
 
  282                        currentCounts[1].begin(), std::multiplies<double>());
 
  284         totalCounts[1] = std::accumulate(currentCounts[1].begin(), currentCounts[1].end(), 0.0);
 
  285         for(
int m = 0; m < exampleCount-1; ++m)
 
  287             int label = labels[indices[m]];
 
  288             double w = classWeights[label];
 
  289             currentCounts[0][label] += w;
 
  291             currentCounts[1][label] -= w;
 
  294             if (m < exampleCount-2 &&
 
  295                 features(indices[m], splitColumns[k]) == features(indices[m+1], splitColumns[k]))
 
  301                 gini = currentCounts[0][0]*currentCounts[0][1] / totalCounts[0] +
 
  302                        currentCounts[1][0]*currentCounts[1][1] / totalCounts[1];
 
  306                 for(
int l=0; l<classCount; ++l)
 
  307                     gini += currentCounts[0][l]*(1.0 - currentCounts[0][l] / totalCounts[0]) +
 
  308                             currentCounts[1][l]*(1.0 - currentCounts[1][l] / totalCounts[1]);
 
  313                 bestSplit = indices+m;
 
  314                 bestSplitColumn = splitColumns[k];
 
  315                 bestCounts[0] = currentCounts[0];
 
  316                 bestCounts[1] = currentCounts[1];
 
  325     sorter.setColumn(bestSplitColumn);
 
  326     std::sort(indices, indices+exampleCount, sorter);
 
  328     for(
int k=0; k<2; ++k)
 
  330         bestTotalCounts[k] = std::accumulate(bestCounts[k].begin(), bestCounts[k].end(), 0.0);
 
  333     threshold = (features(bestSplit[0], bestSplitColumn) + features(bestSplit[1], bestSplitColumn)) / 2.0;
 
  337     std::for_each(indices, bestSplit, counter);
 
  338     pure[0] = 1.0 == std::accumulate(classCounts.begin(), classCounts.end(), 0.0, DecisionTreeDeprecCountNonzeroFunctor());
 
  340     std::for_each(bestSplit, indices+exampleCount, counter);
 
  341     pure[1] = 1.0 == std::accumulate(classCounts.begin(), classCounts.end(), 0.0, DecisionTreeDeprecCountNonzeroFunctor());
 
  346 enum  { DecisionTreeDeprecNoParent = -1 };
 
  348 template <
class Iterator>
 
  349 struct DecisionTreeDeprecStackEntry
 
  351     DecisionTreeDeprecStackEntry(Iterator i, 
int c,
 
  352                            int lp = DecisionTreeDeprecNoParent, 
int rp = DecisionTreeDeprecNoParent)
 
  353     : indices(i), exampleCount(c),
 
  354       leftParent(lp), rightParent(rp)
 
  358     int exampleCount, leftParent, rightParent;
 
  361 class DecisionTreeDeprec
 
  364     typedef Int32 TreeInt;
 
  365     ArrayVector<TreeInt>  tree_;
 
  366     ArrayVector<double> terminalWeights_;
 
  367     unsigned int classCount_;
 
  368     DecisionTreeDeprecAxisSplitFunctor split;
 
  373     DecisionTreeDeprec(
unsigned int classCount)
 
  374     : classCount_(classCount)
 
  377     void reset(
unsigned int classCount = 0)
 
  380             classCount_ = classCount;
 
  382         terminalWeights_.clear();
 
  385     template <
class U, 
class C, 
class Iterator, 
class Options, 
class Random>
 
  386     void learn(MultiArrayView<2, U, C> 
const & features,
 
  387                ArrayVector<int> 
const & labels,
 
  388                Iterator indices, 
int exampleCount,
 
  389                Options 
const & options,
 
  392     template <
class U, 
class C>
 
  393     ArrayVector<double>::const_iterator
 
  394     predict(MultiArrayView<2, U, C> 
const & features)
 const 
  399             DecisionTreeDeprecNodeProxy<TreeInt> node(tree_, nodeindex);
 
  400             nodeindex = split.decideAtNode(features, node.decisionColumns(),
 
  401                                        terminalWeights_.begin() + node.decisionWeightsIndex())
 
  405                 return terminalWeights_.begin() + (-nodeindex);
 
  409     template <
class U, 
class C>
 
  411     predictLabel(MultiArrayView<2, U, C> 
const & features)
 const 
  413         ArrayVector<double>::const_iterator weights = predict(features);
 
  414         return argMax(weights, weights+classCount_) - weights;
 
  417     template <
class U, 
class C>
 
  419     leafID(MultiArrayView<2, U, C> 
const & features)
 const 
  424             DecisionTreeDeprecNodeProxy<TreeInt> node(tree_, nodeindex);
 
  425             nodeindex = split.decideAtNode(features, node.decisionColumns(),
 
  426                                        terminalWeights_.begin() + node.decisionWeightsIndex())
 
  434     void depth(
int & maxDep, 
int & interiorCount, 
int & leafCount, 
int k = 0, 
int d = 1)
 const 
  436         DecisionTreeDeprecNodeProxy<TreeInt> node(tree_, k);
 
  439         for(
int l=0; l<2; ++l)
 
  441             int child = node.child(l);
 
  443                 depth(maxDep, interiorCount, leafCount, child, d);
 
  453     void printStatistics(std::ostream & o)
 const 
  455         int maxDep = 0, interiorCount = 0, leafCount = 0;
 
  456         depth(maxDep, interiorCount, leafCount);
 
  458         o << 
"interior nodes: " << interiorCount <<
 
  459              ", terminal nodes: " << leafCount <<
 
  460              ", depth: " << maxDep << 
"\n";
 
  463     void print(std::ostream & o, 
int k = 0, std::string s = 
"")
 const 
  465         DecisionTreeDeprecNodeProxy<TreeInt> node(tree_, k);
 
  466         o << s << (*node.decisionColumns()) << 
" " << terminalWeights_[node.decisionWeightsIndex()] << 
"\n";
 
  468         for(
int l=0; l<2; ++l)
 
  470             int child = node.child(l);
 
  472                 o << s << 
" weights " << terminalWeights_[-child] << 
" " 
  473                                       << terminalWeights_[-child+1] << 
"\n";
 
  475                 print(o, child, s+
" ");
 
  481 template <
class U, 
class C, 
class Iterator, 
class Options, 
class Random>
 
  482 void DecisionTreeDeprec::learn(MultiArrayView<2, U, C> 
const & features,
 
  483                           ArrayVector<int> 
const & labels,
 
  484                           Iterator indices, 
int exampleCount,
 
  485                           Options 
const & options,
 
  488     ArrayVector<double> 
const & classLoss = options.class_weights;
 
  490     vigra_precondition(classLoss.size() == 0 || classLoss.size() == classCount_,
 
  491         "DecisionTreeDeprec2::learn(): class weights array has wrong size.");
 
  495     unsigned int mtry = options.mtry;
 
  498     split.init(mtry, cols, classCount_, classLoss);
 
  500     typedef DecisionTreeDeprecStackEntry<Iterator> Entry;
 
  501     ArrayVector<Entry> stack;
 
  502     stack.push_back(Entry(indices, exampleCount));
 
  504     while(!stack.empty())
 
  507         indices = stack.back().indices;
 
  508         exampleCount = stack.back().exampleCount;
 
  509         int leftParent  = stack.back().leftParent,
 
  510             rightParent = stack.back().rightParent;
 
  514         Iterator bestSplit = split.findBestSplit(features, labels, indices, exampleCount, randint);
 
  517         int currentNode = split.writeSplitParameters(tree_, terminalWeights_);
 
  519         if(leftParent != DecisionTreeDeprecNoParent)
 
  520             DecisionTreeDeprecNodeProxy<TreeInt>(tree_, leftParent).child(0) = currentNode;
 
  521         if(rightParent != DecisionTreeDeprecNoParent)
 
  522             DecisionTreeDeprecNodeProxy<TreeInt>(tree_, rightParent).child(1) = currentNode;
 
  523         leftParent = currentNode;
 
  524         rightParent = DecisionTreeDeprecNoParent;
 
  526         for(
int l=0; l<2; ++l)
 
  529             if(!split.isPure(l) && split.totalCount(l) >= options.min_split_node_size)
 
  532                 stack.push_back(Entry(indices, split.totalCount(l), leftParent, rightParent));
 
  536                 DecisionTreeDeprecNodeProxy<TreeInt>(tree_, currentNode).child(l) = -(TreeInt)terminalWeights_.size();
 
  538                 split.writeWeights(l, terminalWeights_);
 
  540             std::swap(leftParent, rightParent);
 
  549 class RandomForestOptionsDeprec
 
  554     RandomForestOptionsDeprec()
 
  555     : training_set_proportion(1.0),
 
  557       min_split_node_size(1),
 
  558       training_set_size(0),
 
  559       sample_with_replacement(true),
 
  560       sample_classes_individually(false),
 
  572     RandomForestOptionsDeprec & featuresPerNode(
unsigned int n)
 
  585     RandomForestOptionsDeprec & sampleWithReplacement(
bool r)
 
  587         sample_with_replacement = r;
 
  591     RandomForestOptionsDeprec & setTreeCount(
unsigned int cnt)
 
  607     RandomForestOptionsDeprec & trainingSetSizeProportional(
double p)
 
  609         vigra_precondition(p >= 0.0 && p <= 1.0,
 
  610             "RandomForestOptionsDeprec::trainingSetSizeProportional(): proportion must be in [0, 1].");
 
  611         if(training_set_size == 0) 
 
  612             training_set_proportion = p;
 
  624     RandomForestOptionsDeprec & trainingSetSizeAbsolute(
unsigned int s)
 
  626         training_set_size = s;
 
  628             training_set_proportion = 0.0;
 
  642     RandomForestOptionsDeprec & sampleClassesIndividually(
bool s)
 
  644         sample_classes_individually = s;
 
  656     RandomForestOptionsDeprec & minSplitNodeSize(
unsigned int n)
 
  660         min_split_node_size = n;
 
  671     template <
class WeightIterator>
 
  672     RandomForestOptionsDeprec & weights(WeightIterator weights, 
unsigned int classCount)
 
  674         class_weights.clear();
 
  676             class_weights.insert(weights, classCount);
 
  680     RandomForestOptionsDeprec & oobData(MultiArrayView<2, UInt8>& data)
 
  686     MultiArrayView<2, UInt8> oob_data;
 
  687     ArrayVector<double> class_weights;
 
  688     double training_set_proportion;
 
  689     unsigned int mtry, min_split_node_size, training_set_size;
 
  690     bool sample_with_replacement, sample_classes_individually;
 
  691     unsigned int treeCount;
 
  700 template <
class ClassLabelType>
 
  701 class RandomForestDeprec
 
  704     ArrayVector<ClassLabelType> classes_;
 
  705     ArrayVector<detail::DecisionTreeDeprec> trees_;
 
  707     RandomForestOptionsDeprec options_;
 
  713     template<
class ClassLabelIterator>
 
  714     RandomForestDeprec(ClassLabelIterator cl, ClassLabelIterator cend,
 
  715                   unsigned int treeCount = 255,
 
  716                   RandomForestOptionsDeprec 
const & options = RandomForestOptionsDeprec())
 
  717     : classes_(cl, cend),
 
  718       trees_(treeCount, detail::DecisionTreeDeprec(classes_.size())),
 
  722         vigra_precondition(options.training_set_proportion == 0.0 ||
 
  723                            options.training_set_size == 0,
 
  724             "RandomForestOptionsDeprec: absolute and proportional training set sizes " 
  725             "cannot be specified at the same time.");
 
  726         vigra_precondition(classes_.size() > 1,
 
  727             "RandomForestOptionsDeprec::weights(): need at least two classes.");
 
  728         vigra_precondition(options.class_weights.size() == 0 || options.class_weights.size() == classes_.size(),
 
  729             "RandomForestOptionsDeprec::weights(): wrong number of classes.");
 
  732     RandomForestDeprec(ClassLabelType 
const & c1, ClassLabelType 
const & c2,
 
  733                   unsigned int treeCount = 255,
 
  734                   RandomForestOptionsDeprec 
const & options = RandomForestOptionsDeprec())
 
  736       trees_(treeCount, detail::DecisionTreeDeprec(2)),
 
  740         vigra_precondition(options.class_weights.size() == 0 || options.class_weights.size() == 2,
 
  741             "RandomForestOptionsDeprec::weights(): wrong number of classes.");
 
  746     template<
class ClassLabelIterator>
 
  747     RandomForestDeprec(ClassLabelIterator cl, ClassLabelIterator cend,
 
  748                   RandomForestOptionsDeprec 
const & options )
 
  749     : classes_(cl, cend),
 
  750       trees_(options.treeCount , detail::DecisionTreeDeprec(classes_.size())),
 
  755         vigra_precondition(options.training_set_proportion == 0.0 ||
 
  756                            options.training_set_size == 0,
 
  757             "RandomForestOptionsDeprec: absolute and proportional training set sizes " 
  758             "cannot be specified at the same time.");
 
  759         vigra_precondition(classes_.size() > 1,
 
  760             "RandomForestOptionsDeprec::weights(): need at least two classes.");
 
  761         vigra_precondition(options.class_weights.size() == 0 || options.class_weights.size() == classes_.size(),
 
  762             "RandomForestOptionsDeprec::weights(): wrong number of classes.");
 
  767     template<
class ClassLabelIterator, 
class TreeIterator, 
class WeightIterator>
 
  768     RandomForestDeprec(ClassLabelIterator cl, ClassLabelIterator cend,
 
  770                   TreeIterator trees, WeightIterator weights)
 
  771     : classes_(cl, cend),
 
  772       trees_(treeCount, detail::DecisionTreeDeprec(classes_.size())),
 
  773       columnCount_(columnCount)
 
  775         for(
unsigned int k=0; k<treeCount; ++k, ++trees, ++weights)
 
  777             trees_[k].tree_ = *trees;
 
  778             trees_[k].terminalWeights_ = *weights;
 
  782     int featureCount()
 const 
  784         vigra_precondition(columnCount_ > 0,
 
  785            "RandomForestDeprec::featureCount(): Random forest has not been trained yet.");
 
  789     int labelCount()
 const 
  791         return classes_.size();
 
  794     int treeCount()
 const 
  796         return trees_.size();
 
  800     template <
class U, 
class C, 
class Array, 
class Random>
 
  801     double learn(MultiArrayView<2, U, C> 
const & features, Array 
const & labels,
 
  802                Random 
const& random);
 
  804     template <
class U, 
class C, 
class Array>
 
  805     double learn(MultiArrayView<2, U, C> 
const & features, Array 
const & labels)
 
  807         RandomNumberGenerator<> generator(RandomSeed);
 
  808         return learn(features, labels, generator);
 
  811     template <
class U, 
class C>
 
  812     ClassLabelType predictLabel(MultiArrayView<2, U, C> 
const & features) 
const;
 
  814     template <
class U, 
class C1, 
class T, 
class C2>
 
  815     void predictLabels(MultiArrayView<2, U, C1> 
const & features,
 
  816                        MultiArrayView<2, T, C2> & labels)
 const 
  818         vigra_precondition(features.shape(0) == labels.shape(0),
 
  819             "RandomForestDeprec::predictLabels(): Label array has wrong size.");
 
  820         for(
int k=0; k<features.shape(0); ++k)
 
  821             labels(k,0) = predictLabel(
rowVector(features, k));
 
  824     template <
class U, 
class C, 
class Iterator>
 
  825     ClassLabelType predictLabel(MultiArrayView<2, U, C> 
const & features,
 
  826                                 Iterator priors) 
const;
 
  828     template <
class U, 
class C1, 
class T, 
class C2>
 
  829     void predictProbabilities(MultiArrayView<2, U, C1> 
const & features,
 
  830                               MultiArrayView<2, T, C2> & prob) 
const;
 
  832     template <
class U, 
class C1, 
class T, 
class C2>
 
  833     void predictNodes(MultiArrayView<2, U, C1> 
const & features,
 
  834                                                    MultiArrayView<2, T, C2> & NodeIDs) 
const;
 
  837 template <
class ClassLabelType>
 
  838 template <
class U, 
class C1, 
class Array, 
class Random>
 
  840 RandomForestDeprec<ClassLabelType>::learn(MultiArrayView<2, U, C1> 
const & features,
 
  841                                              Array 
const & labels,
 
  842                                              Random 
const& random)
 
  844     unsigned int classCount = classes_.size();
 
  845     unsigned int m = 
rowCount(features);
 
  847     vigra_precondition((
unsigned int)(m) == (
unsigned int)labels.size(),
 
  848       "RandomForestDeprec::learn(): Label array has wrong size.");
 
  850     vigra_precondition(options_.training_set_size <= m || options_.sample_with_replacement,
 
  851        "RandomForestDeprec::learn(): Requested training set size exceeds total number of examples.");
 
  858        "RandomForestDeprec::learn(): mtry must be less than number of features.");
 
  861     if(options_.sample_classes_individually)
 
  862         msamples = int(
std::ceil(
double(msamples) / classCount));
 
  864     ArrayVector<int> intLabels(m), classExampleCounts(classCount);
 
  869         typedef std::map<ClassLabelType, int > LabelChecker;
 
  870         typedef typename LabelChecker::iterator LabelCheckerIterator;
 
  871         LabelChecker labelChecker;
 
  872         for(
unsigned int k=0; k<classCount; ++k)
 
  873             labelChecker[classes_[k]] = k;
 
  875         for(
unsigned int k=0; k<m; ++k)
 
  877             LabelCheckerIterator found = labelChecker.find(labels[k]);
 
  878             vigra_precondition(found != labelChecker.end(),
 
  879                 "RandomForestDeprec::learn(): Unknown class label encountered.");
 
  880             intLabels[k] = found->second;
 
  881             ++classExampleCounts[intLabels[k]];
 
  883         minClassCount = *
argMin(classExampleCounts.begin(), classExampleCounts.end());
 
  884         vigra_precondition(minClassCount > 0,
 
  885              "RandomForestDeprec::learn(): At least one class is missing in the training set.");
 
  886         if(msamples > 0 && options_.sample_classes_individually &&
 
  887                           !options_.sample_with_replacement)
 
  889             vigra_precondition(msamples <= minClassCount,
 
  890                 "RandomForestDeprec::learn(): Too few examples in smallest class to reach " 
  891                 "requested training set size.");
 
  895     ArrayVector<int> indices(m);
 
  896     for(
unsigned int k=0; k<m; ++k)
 
  899     if(options_.sample_classes_individually)
 
  901         detail::RandomForestDeprecLabelSorter<ArrayVector<int> > sorter(intLabels);
 
  902         std::sort(indices.begin(), indices.end(), sorter);
 
  905     ArrayVector<int> usedIndices(m), oobCount(m), oobErrorCount(m);
 
  907     UniformIntRandomFunctor<Random> randint(0, m-1, random);
 
  909     for(
unsigned int k=0; k<trees_.size(); ++k)
 
  913         ArrayVector<int> trainingSet;
 
  916         if(options_.sample_classes_individually)
 
  919             for(
unsigned int l=0; l<classCount; ++l)
 
  921                 int lc = classExampleCounts[l];
 
  922                 int lsamples = (msamples == 0)
 
  923                                    ? 
int(
std::ceil(options_.training_set_proportion*lc))
 
  926                 if(options_.sample_with_replacement)
 
  928                     for(
int ll=0; ll<lsamples; ++ll)
 
  930                         trainingSet.push_back(indices[first+randint(lc)]);
 
  931                         ++usedIndices[trainingSet.back()];
 
  936                     for(
int ll=0; ll<lsamples; ++ll)
 
  938                         std::swap(indices[first+ll], indices[first+ll+randint(lc-ll)]);
 
  939                         trainingSet.push_back(indices[first+ll]);
 
  940                         ++usedIndices[trainingSet.back()];
 
  950                 msamples = int(
std::ceil(options_.training_set_proportion*m));
 
  952             if(options_.sample_with_replacement)
 
  954                 for(
int l=0; l<msamples; ++l)
 
  956                     trainingSet.push_back(indices[randint(m)]);
 
  957                     ++usedIndices[trainingSet.back()];
 
  962                 for(
int l=0; l<msamples; ++l)
 
  964                     std::swap(indices[l], indices[l+randint(m-l)]);
 
  965                     trainingSet.push_back(indices[l]);
 
  966                     ++usedIndices[trainingSet.back()];
 
  973         trees_[k].learn(features, intLabels,
 
  974                         trainingSet.begin(), trainingSet.size(),
 
  975                         options_.featuresPerNode(mtry), randint);
 
  986         for(
unsigned int l=0; l<m; ++l)
 
  991                 if(trees_[k].predictLabel(
rowVector(features, l)) != intLabels[l])
 
  994                     if(options_.oob_data.data() != 0)
 
  995                         options_.oob_data(l, k) = 2;
 
  997                 else if(options_.oob_data.data() != 0)
 
  999                     options_.oob_data(l, k) = 1;
 
 1008         #ifdef VIGRA_RF_VERBOSE 
 1009         trees_[k].printStatistics(std::cerr);
 
 1012     double oobError = 0.0;
 
 1013     int totalOobCount = 0;
 
 1014     for(
unsigned int l=0; l<m; ++l)
 
 1017             oobError += double(oobErrorCount[l]) / oobCount[l];
 
 1020     return oobError / totalOobCount;
 
 1023 template <
class ClassLabelType>
 
 1024 template <
class U, 
class C>
 
 1026 RandomForestDeprec<ClassLabelType>::predictLabel(MultiArrayView<2, U, C> 
const & features)
 const 
 1028     vigra_precondition(
columnCount(features) >= featureCount(),
 
 1029         "RandomForestDeprec::predictLabel(): Too few columns in feature matrix.");
 
 1030     vigra_precondition(
rowCount(features) == 1,
 
 1031         "RandomForestDeprec::predictLabel(): Feature matrix must have a single row.");
 
 1032     Matrix<double> prob(1, classes_.size());
 
 1033     predictProbabilities(features, prob);
 
 1034     return classes_[
argMax(prob)];
 
 1039 template <
class ClassLabelType>
 
 1040 template <
class U, 
class C, 
class Iterator>
 
 1042 RandomForestDeprec<ClassLabelType>::predictLabel(MultiArrayView<2, U, C> 
const & features,
 
 1043                                            Iterator priors)
 const 
 1045     using namespace functor;
 
 1046     vigra_precondition(
columnCount(features) >= featureCount(),
 
 1047         "RandomForestDeprec::predictLabel(): Too few columns in feature matrix.");
 
 1048     vigra_precondition(
rowCount(features) == 1,
 
 1049         "RandomForestDeprec::predictLabel(): Feature matrix must have a single row.");
 
 1050     Matrix<double> prob(1,classes_.size());
 
 1051     predictProbabilities(features, prob);
 
 1052     std::transform(prob.begin(), prob.end(), priors, prob.begin(), Arg1()*Arg2());
 
 1053     return classes_[
argMax(prob)];
 
 1056 template <
class ClassLabelType>
 
 1057 template <
class U, 
class C1, 
class T, 
class C2>
 
 1059 RandomForestDeprec<ClassLabelType>::predictProbabilities(MultiArrayView<2, U, C1> 
const & features,
 
 1060                                                    MultiArrayView<2, T, C2> & prob)
 const 
 1067       "RandomForestDeprec::predictProbabilities(): Feature matrix and probability matrix size mismatch.");
 
 1071     vigra_precondition(
columnCount(features) >= featureCount(),
 
 1072       "RandomForestDeprec::predictProbabilities(): Too few columns in feature matrix.");
 
 1074       "RandomForestDeprec::predictProbabilities(): Probability matrix must have as many columns as there are classes.");
 
 1077     for(
int row=0; row < 
rowCount(features); ++row)
 
 1082         ArrayVector<double>::const_iterator weights;
 
 1085     double totalWeight = 0.0;
 
 1089         for(
unsigned int l=0; l<classes_.size(); ++l)
 
 1093         for(
unsigned int k=0; k<trees_.size(); ++k)
 
 1096             weights = trees_[k].predict(
rowVector(features, row));
 
 1099             for(
unsigned int l=0; l<classes_.size(); ++l)
 
 1101                 prob(row, l) += detail::RequiresExplicitCast<T>::cast(weights[l]);
 
 1103                 totalWeight += weights[l];
 
 1108         for(
unsigned int l=0; l<classes_.size(); ++l)
 
 1109                 prob(row, l) /= detail::RequiresExplicitCast<T>::cast(totalWeight);
 
 1114 template <
class ClassLabelType>
 
 1115 template <
class U, 
class C1, 
class T, 
class C2>
 
 1117 RandomForestDeprec<ClassLabelType>::predictNodes(MultiArrayView<2, U, C1> 
const & features,
 
 1118                                                    MultiArrayView<2, T, C2> & NodeIDs)
 const 
 1120     vigra_precondition(
columnCount(features) >= featureCount(),
 
 1121       "RandomForestDeprec::getNodesRF(): Too few columns in feature matrix.");
 
 1123       "RandomForestDeprec::getNodesRF(): Too few rows in NodeIds matrix");
 
 1124     vigra_precondition(
columnCount(NodeIDs) >= treeCount(),
 
 1125       "RandomForestDeprec::getNodesRF(): Too few columns in NodeIds matrix.");
 
 1127     for(
unsigned int k=0; k<trees_.size(); ++k)
 
 1129         for(
int row=0; row < 
rowCount(features); ++row)
 
 1131             NodeIDs(row,k) = trees_[k].leafID(
rowVector(features, row));
 
 1141 #endif // VIGRA_RANDOM_FOREST_DEPREC_HXX 
MultiArrayIndex rowCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:671
std::ptrdiff_t MultiArrayIndex
Definition: multi_fwd.hxx:60
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int 
Definition: sized_int.hxx:175
Iterator argMax(Iterator first, Iterator last)
Find the maximum element in a sequence. 
Definition: algorithm.hxx:96
MultiArrayView< 2, T, C > rowVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition: matrix.hxx:697
MultiArrayIndex columnCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:684
Iterator argMin(Iterator first, Iterator last)
Find the minimum element in a sequence. 
Definition: algorithm.hxx:68
int ceil(FixedPoint< IntBits, FracBits > v)
rounding up. 
Definition: fixedpoint.hxx:675
int floor(FixedPoint< IntBits, FracBits > v)
rounding down. 
Definition: fixedpoint.hxx:667
SquareRootTraits< FixedPoint< IntBits, FracBits > >::SquareRootResult sqrt(FixedPoint< IntBits, FracBits > v)
square root. 
Definition: fixedpoint.hxx:616