35 #ifndef VIGRA_RF3_VISITORS_HXX 
   36 #define VIGRA_RF3_VISITORS_HXX 
   40 #include "../multi_array.hxx" 
   41 #include "../multi_shape.hxx" 
   89     template <
typename VISITORS, 
typename RF, 
typename FEATURES, 
typename LABELS>
 
   98     template <
typename TREE, 
typename FEATURES, 
typename LABELS, 
typename WEIGHTS>
 
  105     template <
typename RF, 
typename FEATURES, 
typename LABELS, 
typename WEIGHTS>
 
  115     template <
typename TREE,
 
  179     template <
typename TREE, 
typename FEATURES, 
typename LABELS, 
typename WEIGHTS>
 
  186         double const EPS = 1e-20;
 
  190         is_in_bag_.resize(weights.size(), 
true);
 
  191         for (
size_t i = 0; i < weights.size(); ++i)
 
  195                 is_in_bag_[i] = 
false;
 
  201             throw std::runtime_error(
"OOBError::visit_before_tree(): The tree has no out-of-bags.");
 
  207     template <
typename VISITORS, 
typename RF, 
typename FEATURES, 
typename LABELS>
 
  211             const FEATURES & features,
 
  212             const LABELS & labels
 
  215         vigra_precondition(rf.num_trees() > 0, 
"OOBError::visit_after_training(): Number of trees must be greater than zero after training.");
 
  216         vigra_precondition(visitors.size() == rf.num_trees(), 
"OOBError::visit_after_training(): Number of visitors must be equal to number of trees.");
 
  217         size_t const num_instances = features.shape()[0];
 
  218         auto const num_features = features.shape()[1];
 
  219         for (
auto vptr : visitors)
 
  220             vigra_precondition(vptr->is_in_bag_.size() == num_instances, 
"OOBError::visit_after_training(): Some visitors have the wrong number of data points.");
 
  223         typedef typename std::remove_const<LABELS>::type Labels;
 
  226         for (
size_t i = 0; i < (size_t)num_instances; ++i)
 
  229             std::vector<size_t> tree_indices;
 
  230             for (
size_t k = 0; k < visitors.size(); ++k)
 
  231                 if (!visitors[k]->is_in_bag_[i])
 
  232                     tree_indices.push_back(k);
 
  235             auto const sub_features = features.subarray(Shape2(i, 0), Shape2(i+1, num_features));
 
  236             rf.predict(sub_features, pred, 1, tree_indices);
 
  237             if (pred(0) != labels(i))
 
  249     std::vector<bool> is_in_bag_; 
 
  269     template <
typename TREE, 
typename FEATURES, 
typename LABELS, 
typename WEIGHTS>
 
  279         auto const num_features = features.shape()[1];
 
  283         double const EPS = 1e-20;
 
  285         is_in_bag_.resize(weights.size(), 
true);
 
  286         for (
size_t i = 0; i < weights.size(); ++i)
 
  290                 is_in_bag_[i] = 
false;
 
  295             throw std::runtime_error(
"VariableImportance::visit_before_tree(): The tree has no out-of-bags.");
 
  301     template <
typename TREE,
 
  317         typename SCORER::Functor functor;
 
  318         auto const region_impurity = functor.region_score(labels, weights, begin, end);
 
  319         auto const split_impurity = scorer.best_score_;
 
  320         variable_importance_(scorer.best_dim_, tree.num_classes()+1) += region_impurity - split_impurity;
 
  326     template <
typename RF, 
typename FEATURES, 
typename LABELS, 
typename WEIGHTS>
 
  328                           const FEATURES & features,
 
  329                           const LABELS & labels,
 
  333         typedef typename std::remove_const<FEATURES>::type Features;
 
  334         typedef typename std::remove_const<LABELS>::type Labels;
 
  336         typedef typename Features::value_type FeatureType;
 
  338         auto const num_features = features.shape()[1];
 
  345         copy_out_of_bags(features, labels, feats, labs);
 
  346         auto const num_oobs = feats.shape()[0];
 
  351         rf.predict(feats, pred, 1);
 
  352         for (
size_t i = 0; i < (size_t)labs.size(); ++i)
 
  354             if (labs(i) == pred(i))
 
  356                 oob_right(labs(i)) += 1.0; 
 
  357                 oob_right(rf.num_classes()) += 1.0; 
 
  363         for (
size_t j = 0; j < (size_t)num_features; ++j)
 
  366             backup = feats.template bind<1>(j);
 
  372                 for (
int ii = num_oobs-1; ii >= 1; --ii)
 
  373                     std::swap(feats(ii, j), feats(randint(ii+1), j));
 
  376                 rf.predict(feats, pred, 1);
 
  377                 for (
size_t i = 0; i < (size_t)labs.size(); ++i)
 
  379                     if (labs(i) == pred(i))
 
  381                         perm_oob_right(0, labs(i)) += 1.0; 
 
  382                         perm_oob_right(0, rf.num_classes()) += 1.0; 
 
  389             perm_oob_right.bind<0>(0) -= oob_right;
 
  390             perm_oob_right *= -1;
 
  391             perm_oob_right /= num_oobs;
 
  395             feats.template bind<1>(j) = backup;
 
  402     template <
typename VISITORS, 
typename RF, 
typename FEATURES, 
typename LABELS>
 
  406             const FEATURES & features,
 
  409         vigra_precondition(rf.num_trees() > 0, 
"VariableImportance::visit_after_training(): Number of trees must be greater than zero after training.");
 
  410         vigra_precondition(visitors.size() == rf.num_trees(), 
"VariableImportance::visit_after_training(): Number of visitors must be equal to number of trees.");
 
  413         auto const num_features = features.shape()[1];
 
  415         for (
auto vptr : visitors)
 
  418                                "VariableImportance::visit_after_training(): Shape mismatch.");
 
  464     template <
typename F0, 
typename L0, 
typename F1, 
typename L1>
 
  465     void copy_out_of_bags(
 
  466             F0 
const & features_in,
 
  467             L0 
const & labels_in,
 
  471         auto const num_instances = features_in.shape()[0];
 
  472         auto const num_features = features_in.shape()[1];
 
  476         for (
auto x : is_in_bag_)
 
  481         features_out.reshape(Shape2(num_oobs, num_features));
 
  482         labels_out.reshape(
Shape1(num_oobs));
 
  484         for (
size_t i = 0; i < (size_t)num_instances; ++i)
 
  488                 auto const src = features_in.template bind<0>(i);
 
  489                 auto out = features_out.template bind<0>(current);
 
  491                 labels_out(current) = labels_in(i);
 
  497     std::vector<bool> is_in_bag_; 
 
  518 template <
typename VISITOR, 
typename NEXT = RFStopVisiting, 
bool CPY = false>
 
  523     typedef VISITOR Visitor;
 
  526     typename std::conditional<CPY, Visitor, Visitor &>::type visitor_;
 
  543         visitor_(other.visitor_),
 
  549         visitor_(other.visitor_),
 
  553     void visit_before_training()
 
  555         if (visitor_.is_active())
 
  556             visitor_.visit_before_training();
 
  557         next_.visit_before_training();
 
  560     template <
typename VISITORS, 
typename RF, 
typename FEATURES, 
typename LABELS>
 
  561     void visit_after_training(VISITORS & v, RF & rf, 
const FEATURES & features, 
const LABELS & labels)
 
  563         typedef typename VISITORS::value_type VisitorNodeType;
 
  564         typedef typename VisitorNodeType::Visitor VisitorType;
 
  565         typedef typename VisitorNodeType::Next NextType;
 
  571         if (visitor_.is_active())
 
  573             std::vector<VisitorType*> visitors;
 
  575                 visitors.push_back(&x.visitor_);
 
  576             visitor_.visit_after_training(visitors, rf, features, labels);
 
  580         std::vector<NextType> nexts;
 
  582             nexts.push_back(x.next_);
 
  585         next_.visit_after_training(nexts, rf, features, labels);
 
  588     template <
typename TREE, 
typename FEATURES, 
typename LABELS, 
typename WEIGHTS>
 
  589     void visit_before_tree(TREE & tree, FEATURES & features, LABELS & labels, WEIGHTS & weights)
 
  591         if (visitor_.is_active())
 
  592             visitor_.visit_before_tree(tree, features, labels, weights);
 
  593         next_.visit_before_tree(tree, features, labels, weights);
 
  596     template <
typename RF, 
typename FEATURES, 
typename LABELS, 
typename WEIGHTS>
 
  597     void visit_after_tree(RF & rf,
 
  602         if (visitor_.is_active())
 
  603             visitor_.visit_after_tree(rf, features, labels, weights);
 
  604         next_.visit_after_tree(rf, features, labels, weights);
 
  607     template <
typename TREE,
 
  613     void visit_after_split(TREE & tree,
 
  622         if (visitor_.is_active())
 
  623             visitor_.visit_after_split(tree, features, labels, weights, scorer, begin, split, end);
 
  624         next_.visit_after_split(tree, features, labels, weights, scorer, begin, split, end);
 
  634 template <
typename VISITOR>
 
  654 detail::RFVisitorNode<A>
 
  655 create_visitor(A & a)
 
  657     typedef detail::RFVisitorNode<A> _0_t;
 
  662 template<
typename A, 
typename B>
 
  663 detail::RFVisitorNode<A, detail::RFVisitorNode<B> >
 
  664 create_visitor(A & a, B & b)
 
  666     typedef detail::RFVisitorNode<B> _1_t;
 
  668     typedef detail::RFVisitorNode<A, _1_t> _0_t;
 
  673 template<
typename A, 
typename B, 
typename C>
 
  674 detail::RFVisitorNode<A, detail::RFVisitorNode<B, detail::RFVisitorNode<C> > >
 
  677     typedef detail::RFVisitorNode<C> _2_t;
 
  679     typedef detail::RFVisitorNode<B, _2_t> _1_t;
 
  681     typedef detail::RFVisitorNode<A, _1_t> _0_t;
 
  686 template<
typename A, 
typename B, 
typename C, 
typename D>
 
  687 detail::RFVisitorNode<A, detail::RFVisitorNode<B, detail::RFVisitorNode<C, 
 
  688     detail::RFVisitorNode<D> > > >
 
  691     typedef detail::RFVisitorNode<D> _3_t;
 
  693     typedef detail::RFVisitorNode<C, _3_t> _2_t;
 
  695     typedef detail::RFVisitorNode<B, _2_t> _1_t;
 
  697     typedef detail::RFVisitorNode<A, _1_t> _0_t;
 
  702 template<
typename A, 
typename B, 
typename C, 
typename D, 
typename E>
 
  703 detail::RFVisitorNode<A, detail::RFVisitorNode<B, detail::RFVisitorNode<C, 
 
  704     detail::RFVisitorNode<D, detail::RFVisitorNode<E> > > > >
 
  707     typedef detail::RFVisitorNode<E> _4_t;
 
  709     typedef detail::RFVisitorNode<D, _4_t> _3_t;
 
  711     typedef detail::RFVisitorNode<C, _3_t> _2_t;
 
  713     typedef detail::RFVisitorNode<B, _2_t> _1_t;
 
  715     typedef detail::RFVisitorNode<A, _1_t> _0_t;
 
  720 template<
typename A, 
typename B, 
typename C, 
typename D, 
typename E,
 
  722 detail::RFVisitorNode<A, detail::RFVisitorNode<B, detail::RFVisitorNode<C, 
 
  723     detail::RFVisitorNode<D, detail::RFVisitorNode<E, detail::RFVisitorNode<F> > > > > >
 
  726     typedef detail::RFVisitorNode<F> _5_t;
 
  728     typedef detail::RFVisitorNode<E, _5_t> _4_t;
 
  730     typedef detail::RFVisitorNode<D, _4_t> _3_t;
 
  732     typedef detail::RFVisitorNode<C, _3_t> _2_t;
 
  734     typedef detail::RFVisitorNode<B, _2_t> _1_t;
 
  736     typedef detail::RFVisitorNode<A, _1_t> _0_t;
 
  741 template<
typename A, 
typename B, 
typename C, 
typename D, 
typename E,
 
  742          typename F, 
typename G>
 
  743 detail::RFVisitorNode<A, detail::RFVisitorNode<B, detail::RFVisitorNode<C, 
 
  744     detail::RFVisitorNode<D, detail::RFVisitorNode<E, detail::RFVisitorNode<F,
 
  745     detail::RFVisitorNode<G> > > > > > >
 
  748     typedef detail::RFVisitorNode<G> _6_t;
 
  750     typedef detail::RFVisitorNode<F, _6_t> _5_t;
 
  752     typedef detail::RFVisitorNode<E, _5_t> _4_t;
 
  754     typedef detail::RFVisitorNode<D, _4_t> _3_t;
 
  756     typedef detail::RFVisitorNode<C, _3_t> _2_t;
 
  758     typedef detail::RFVisitorNode<B, _2_t> _1_t;
 
  760     typedef detail::RFVisitorNode<A, _1_t> _0_t;
 
  765 template<
typename A, 
typename B, 
typename C, 
typename D, 
typename E,
 
  766          typename F, 
typename G, 
typename H>
 
  767 detail::RFVisitorNode<A, detail::RFVisitorNode<B, detail::RFVisitorNode<C, 
 
  768     detail::RFVisitorNode<D, detail::RFVisitorNode<E, detail::RFVisitorNode<F,
 
  769     detail::RFVisitorNode<G, detail::RFVisitorNode<H> > > > > > > >
 
  770 create_visitor(A & a, B & b, C & c, D & d, E & e, F & f, G & g, H & h)
 
  772     typedef detail::RFVisitorNode<H> _7_t;
 
  774     typedef detail::RFVisitorNode<G, _7_t> _6_t;
 
  776     typedef detail::RFVisitorNode<F, _6_t> _5_t;
 
  778     typedef detail::RFVisitorNode<E, _5_t> _4_t;
 
  780     typedef detail::RFVisitorNode<D, _4_t> _3_t;
 
  782     typedef detail::RFVisitorNode<C, _3_t> _2_t;
 
  784     typedef detail::RFVisitorNode<B, _2_t> _1_t;
 
  786     typedef detail::RFVisitorNode<A, _1_t> _0_t;
 
  791 template<
typename A, 
typename B, 
typename C, 
typename D, 
typename E,
 
  792          typename F, 
typename G, 
typename H, 
typename I>
 
  793 detail::RFVisitorNode<A, detail::RFVisitorNode<B, detail::RFVisitorNode<C, 
 
  794     detail::RFVisitorNode<D, detail::RFVisitorNode<E, detail::RFVisitorNode<F,
 
  795     detail::RFVisitorNode<G, detail::RFVisitorNode<H, detail::RFVisitorNode<I> > > > > > > > >
 
  796 create_visitor(A & a, B & b, C & c, D & d, E & e, F & f, G & g, H & h, I & i)
 
  798     typedef detail::RFVisitorNode<I> _8_t;
 
  800     typedef detail::RFVisitorNode<H, _8_t> _7_t;
 
  802     typedef detail::RFVisitorNode<G, _7_t> _6_t;
 
  804     typedef detail::RFVisitorNode<F, _6_t> _5_t;
 
  806     typedef detail::RFVisitorNode<E, _5_t> _4_t;
 
  808     typedef detail::RFVisitorNode<D, _4_t> _3_t;
 
  810     typedef detail::RFVisitorNode<C, _3_t> _2_t;
 
  812     typedef detail::RFVisitorNode<B, _2_t> _1_t;
 
  814     typedef detail::RFVisitorNode<A, _1_t> _0_t;
 
  819 template<
typename A, 
typename B, 
typename C, 
typename D, 
typename E,
 
  820          typename F, 
typename G, 
typename H, 
typename I, 
typename J>
 
  821 detail::RFVisitorNode<A, detail::RFVisitorNode<B, detail::RFVisitorNode<C, 
 
  822     detail::RFVisitorNode<D, detail::RFVisitorNode<E, detail::RFVisitorNode<F,
 
  823     detail::RFVisitorNode<G, detail::RFVisitorNode<H, detail::RFVisitorNode<I,
 
  824     detail::RFVisitorNode<J> > > > > > > > > >
 
  825 create_visitor(A & a, B & b, C & c, D & d, E & e, F & f, G & g, H & h, I & i,
 
  828     typedef detail::RFVisitorNode<J> _9_t;
 
  830     typedef detail::RFVisitorNode<I, _9_t> _8_t;
 
  832     typedef detail::RFVisitorNode<H, _8_t> _7_t;
 
  834     typedef detail::RFVisitorNode<G, _7_t> _6_t;
 
  836     typedef detail::RFVisitorNode<F, _6_t> _5_t;
 
  838     typedef detail::RFVisitorNode<E, _5_t> _4_t;
 
  840     typedef detail::RFVisitorNode<D, _4_t> _3_t;
 
  842     typedef detail::RFVisitorNode<C, _3_t> _2_t;
 
  844     typedef detail::RFVisitorNode<B, _2_t> _1_t;
 
  846     typedef detail::RFVisitorNode<A, _1_t> _0_t;
 
void visit_before_tree(TREE &, FEATURES &, LABELS &, WEIGHTS &)
Do something before a tree has been learned. 
Definition: random_forest_visitors.hxx:99
void visit_before_tree(TREE &tree, FEATURES &features, LABELS &, WEIGHTS &weights)
Definition: random_forest_visitors.hxx:270
void visit_after_split(TREE &, FEATURES &, LABELS &, WEIGHTS &, SCORER &, ITER, ITER, ITER)
Do something after the split was made. 
Definition: random_forest_visitors.hxx:121
const difference_type & shape() const 
Definition: multi_array.hxx:1648
void visit_after_tree(RF &, FEATURES &, LABELS &, WEIGHTS &)
Do something after a tree has been learned. 
Definition: random_forest_visitors.hxx:106
void deactivate()
Deactivate the visitor. 
Definition: random_forest_visitors.hxx:150
void reshape(const difference_type &shape)
Definition: multi_array.hxx:2861
Base class from which all random forest visitors derive. 
Definition: random_forest_visitors.hxx:68
size_t repetition_count_
Definition: random_forest_visitors.hxx:457
The default visitor node (= "do nothing"). 
Definition: random_forest_visitors.hxx:509
void visit_after_training(VISITORS &visitors, RF &rf, const FEATURES &features, const LABELS &labels)
Definition: random_forest_visitors.hxx:208
Compute the variable importance. 
Definition: random_forest_visitors.hxx:257
Compute the out of bag error. 
Definition: random_forest_visitors.hxx:172
double oob_err_
Definition: random_forest_visitors.hxx:246
void visit_after_tree(RF &rf, const FEATURES &features, const LABELS &labels, WEIGHTS &)
Definition: random_forest_visitors.hxx:327
Definition: random_forest_visitors.hxx:635
void activate()
Activate the visitor. 
Definition: random_forest_visitors.hxx:142
Class for fixed size vectors. This class contains an array of size SIZE of the specified VALUETYPE...
Definition: accessor.hxx:940
void visit_after_split(TREE &tree, FEATURES &, LABELS &labels, WEIGHTS &weights, SCORER &scorer, ITER begin, ITER, ITER end)
Definition: random_forest_visitors.hxx:307
void visit_before_tree(TREE &, FEATURES &, LABELS &, WEIGHTS &weights)
Definition: random_forest_visitors.hxx:180
Container elements of the statically linked visitor list. Use the create_visitor() functions to create...
Definition: random_forest_visitors.hxx:519
MultiArray< 2, double > variable_importance_
Definition: random_forest_visitors.hxx:452
FFTWComplex< R >::NormType abs(const FFTWComplex< R > &a)
absolute value (= magnitude) 
Definition: fftw3.hxx:1002
bool is_active() const 
Return whether the visitor is active or not. 
Definition: random_forest_visitors.hxx:134
void visit_before_training()
Do something before training starts. 
Definition: random_forest_visitors.hxx:80
MultiArrayView subarray(difference_type p, difference_type q) const 
Definition: multi_array.hxx:1528
detail::VisitorNode< A > create_visitor(A &a)
Definition: rf_visitors.hxx:344
void visit_after_training(VISITORS &, RF &, const FEATURES &, const LABELS &)
Do something after all trees have been learned. 
Definition: random_forest_visitors.hxx:90
void visit_after_training(VISITORS &visitors, RF &rf, const FEATURES &features, const LABELS &)
Definition: random_forest_visitors.hxx:403