35 #ifndef VIGRA_RF3_RANDOM_FOREST_HXX 
   36 #define VIGRA_RF3_RANDOM_FOREST_HXX 
   38 #include <type_traits> 
   41 #include "../multi_shape.hxx" 
   42 #include "../binary_forest.hxx" 
   43 #include "../threadpool.hxx" 
   44 #include "random_forest_common.hxx" 
   64 template <
typename FEATURES,
 
   66           typename SPLITTESTS = LessEqualSplitTest<typename FEATURES::value_type>,
 
   67           typename ACCTYPE = ArgMaxVectorAcc<double>>
 
   72     typedef FEATURES Features;
 
   73     typedef typename Features::value_type FeatureType;
 
   74     typedef LABELS Labels;
 
   75     typedef typename Labels::value_type LabelType;
 
   76     typedef SPLITTESTS SplitTests;
 
   78     typedef typename ACC::input_type AccInputType;
 
   82     static ContainerTag 
const container_tag = VectorTag;
 
  102         typename NodeMap<SplitTests>::type 
const & split_tests,
 
  103         typename NodeMap<AccInputType>::type 
const & node_responses,
 
  104         ProblemSpec<LabelType> 
const & problem_spec
 
  115         FEATURES 
const & features,
 
  118         const std::vector<size_t> & tree_indices = std::vector<size_t>()
 
  123     template <
typename PROBS>
 
  125         FEATURES 
const & features,
 
  128         const std::vector<size_t> & tree_indices = std::vector<size_t>()
 
  133     template <
typename IDS>
 
  135         FEATURES 
const & features,
 
  138         const std::vector<size_t> tree_indices = std::vector<size_t>()
 
  183     template <
typename IDS, 
typename INDICES>
 
  184     double leaf_ids_impl(
 
  185         FEATURES 
const & features,
 
  189         INDICES 
const & tree_indices
 
  192     template<
typename PROBS>
 
  193     void predict_probabilities_impl(
 
  194         FEATURES 
const & features,
 
  197         const std::vector<size_t> & tree_indices) 
const;
 
  201 template <
typename FEATURES, 
typename LABELS, 
typename SPLITTESTS, 
typename ACC>
 
  210 template <
typename FEATURES, 
typename LABELS, 
typename SPLITTESTS, 
typename ACC>
 
  213     typename NodeMap<SplitTests>::type 
const & split_tests,
 
  214     typename NodeMap<AccInputType>::type 
const & node_responses,
 
  218     split_tests_(split_tests),
 
  219     node_responses_(node_responses),
 
  220     problem_spec_(problem_spec)
 
  223 template <
typename FEATURES, 
typename LABELS, 
typename SPLITTESTS, 
typename ACC>
 
  228                        "RandomForest::merge(): You cannot merge with different problem specs.");
 
  232     size_t const offset = num_nodes();
 
  233     graph_.merge(other.
graph_);
 
  236         split_tests_.insert(Node(p.first.id()+offset), p.second);
 
  240         node_responses_.insert(Node(p.first.id()+offset), p.second);
 
  246 template <
typename FEATURES, 
typename LABELS, 
typename SPLITTESTS, 
typename ACC>
 
  248     FEATURES 
const & features,
 
  251     const std::vector<size_t> & tree_indices
 
  253     vigra_precondition(features.shape()[0] == labels.shape()[0],
 
  254                        "RandomForest::predict(): Shape mismatch between features and labels.");
 
  255     vigra_precondition((
size_t)features.shape()[1] == problem_spec_.num_features_,
 
  256                        "RandomForest::predict(): Number of features in prediction differs from training.");
 
  259     predict_probabilities(features, probs, n_threads, tree_indices);
 
  260     for (
size_t i = 0; i < (size_t)features.shape()[0]; ++i)
 
  262         auto const sub_probs = probs.template bind<0>(i);
 
  263         auto it = std::max_element(sub_probs.begin(), sub_probs.end());
 
  264         size_t const label = std::distance(sub_probs.begin(), it);
 
  265         labels(i) = problem_spec_.distinct_classes_[label];
 
  272 template <
typename FEATURES, 
typename LABELS, 
typename SPLITTESTS, 
typename ACC>
 
  273 template <
typename PROBS>
 
  275     FEATURES 
const & features,
 
  278     const std::vector<size_t> & tree_indices
 
  280     vigra_precondition(features.shape()[0] == probs.shape()[0],
 
  281                        "RandomForest::predict_probabilities(): Shape mismatch between features and probabilities.");
 
  282     vigra_precondition((
size_t)features.shape()[1] == problem_spec_.num_features_,
 
  283                        "RandomForest::predict_probabilities(): Number of features in prediction differs from training.");
 
  284     vigra_precondition((
size_t)probs.shape()[1] == problem_spec_.num_classes_,
 
  285                        "RandomForest::predict_probabilities(): Number of labels in probabilities differs from training.");
 
  289     std::vector<size_t> tree_indices_cpy(tree_indices);
 
  290     if (tree_indices_cpy.size() == 0)
 
  292         tree_indices_cpy.resize(graph_.numRoots());
 
  293         std::iota(tree_indices_cpy.begin(), tree_indices_cpy.end(), 0);
 
  297         std::sort(tree_indices_cpy.begin(), tree_indices_cpy.end());
 
  298         tree_indices_cpy.erase(std::unique(tree_indices_cpy.begin(), tree_indices_cpy.end()), tree_indices_cpy.end());
 
  299         for (
auto i : tree_indices_cpy)
 
  300             vigra_precondition(i < graph_.numRoots(), 
"RandomForest::leaf_ids(): Tree index out of range.");
 
  303     size_t const num_instances = features.shape()[0];
 
  306         n_threads = std::thread::hardware_concurrency();
 
  313         [&features,&probs,&tree_indices_cpy,
this](
size_t, 
size_t i) {
 
  314             this->predict_probabilities_impl(features, probs, i, tree_indices_cpy);
 
  319 template <
typename FEATURES, 
typename LABELS, 
typename SPLITTESTS, 
typename ACC>
 
  320 template <
typename PROBS>
 
  322     FEATURES 
const & features,
 
  325     const std::vector<size_t> & tree_indices
 
  330     std::vector<AccInputType> tree_results;
 
  331     tree_results.reserve(tree_indices.size());
 
  332     auto const sub_features = features.template bind<0>(i);
 
  335     for (
auto k : tree_indices)
 
  337         Node node = graph_.getRoot(k);
 
  338         while (graph_.outDegree(node) > 0)
 
  340             size_t const child_index = split_tests_.at(node)(sub_features);
 
  341             node = graph_.getChild(node, child_index);
 
  343         tree_results.emplace_back(node_responses_.at(node));
 
  347     auto sub_probs = probs.template bind<0>(i);
 
  348     acc(tree_results.begin(), tree_results.end(), sub_probs.begin());    
 
  351 template <
typename FEATURES, 
typename LABELS, 
typename SPLITTESTS, 
typename ACC>
 
  352 template <
typename IDS>
 
  354     FEATURES 
const & features,
 
  357     std::vector<size_t> tree_indices
 
  359     vigra_precondition(features.shape()[0] == ids.shape()[0],
 
  360                        "RandomForest::leaf_ids(): Shape mismatch between features and probabilities.");
 
  361     vigra_precondition((
size_t)features.shape()[1] == problem_spec_.num_features_,
 
  362                        "RandomForest::leaf_ids(): Number of features in prediction differs from training.");
 
  363     vigra_precondition(ids.shape()[1] == graph_.numRoots(),
 
  364                        "RandomForest::leaf_ids(): Leaf array has wrong shape.");
 
  367     std::sort(tree_indices.begin(), tree_indices.end());
 
  368     tree_indices.erase(std::unique(tree_indices.begin(), tree_indices.end()), tree_indices.end());
 
  369     for (
auto i : tree_indices)
 
  370         vigra_precondition(i < graph_.numRoots(), 
"RandomForest::leaf_ids(): Tree index out of range.");
 
  373     if (tree_indices.size() == 0)
 
  375         tree_indices.resize(graph_.numRoots());
 
  376         std::iota(tree_indices.begin(), tree_indices.end(), 0);
 
  379     size_t const num_instances = features.shape()[0];
 
  381         n_threads = std::thread::hardware_concurrency();
 
  384     std::vector<double> split_comparisons(n_threads, 0.0);
 
  385     std::vector<size_t> indices(num_instances);
 
  386     std::iota(indices.begin(), indices.end(), 0);
 
  387     std::fill(ids.begin(), ids.end(), -1);
 
  392         [
this, &features, &ids, &split_comparisons, &tree_indices](
size_t thread_id, 
size_t i) {
 
  393             split_comparisons[thread_id] += this->leaf_ids_impl(features, ids, i, i+1, tree_indices);
 
  397     double const sum_split_comparisons = std::accumulate(split_comparisons.begin(), split_comparisons.end(), 0.0);
 
  398     return sum_split_comparisons / features.shape()[0];
 
  401 template <
typename FEATURES, 
typename LABELS, 
typename SPLITTESTS, 
typename ACC>
 
  402 template <
typename IDS, 
typename INDICES>
 
  404     FEATURES 
const & features,
 
  408     INDICES 
const & tree_indices
 
  410     vigra_precondition(features.shape()[0] == ids.shape()[0],
 
  411                        "RandomForest::leaf_ids_impl(): Shape mismatch between features and labels.");
 
  412     vigra_precondition(features.shape()[1] == problem_spec_.num_features_,
 
  413                        "RandomForest::leaf_ids_impl(): Number of Features in prediction differs from training.");
 
  414     vigra_precondition(from >= 0 && from <= to && to <= (
size_t)features.shape()[0],
 
  415                        "RandomForest::leaf_ids_impl(): Indices out of range.");
 
  416     vigra_precondition(ids.shape()[1] == graph_.numRoots(),
 
  417                        "RandomForest::leaf_ids_impl(): Leaf array has wrong shape.");
 
  419     double split_comparisons = 0.0;
 
  420     for (
size_t i = from; i < to; ++i)
 
  422         auto const sub_features = features.template bind<0>(i);
 
  423         for (
auto k : tree_indices)
 
  425             Node node = graph_.getRoot(k);
 
  426             while (graph_.outDegree(node) > 0)
 
  428                 size_t const child_index = split_tests_.at(node)(sub_features);
 
  429                 node = graph_.getChild(node, child_index);
 
  430                 split_comparisons += 1.0;
 
  432             ids(i, k) = node.id();
 
  435     return split_comparisons;
 
void predict_probabilities(FEATURES const &features, PROBS &probs, int n_threads=-1, const std::vector< size_t > &tree_indices=std::vector< size_t >()) const 
Predict the class probabilities of the given data and store them in probs (the function returns void; no split-comparison count is returned).
Definition: random_forest.hxx:274
void predict(FEATURES const &features, LABELS &labels, int n_threads=-1, const std::vector< size_t > &tree_indices=std::vector< size_t >()) const 
Predict the labels of the given data and store them in labels (the function returns void; no split-comparison count is returned).
Definition: random_forest.hxx:247
The PropertyMap is used to store Node or Arc information of graphs. 
Definition: graphs.hxx:410
void merge(RandomForest const &other)
Grow this forest by incorporating the other. 
Definition: random_forest.hxx:224
detail::NodeDescriptor< index_type > Node
Node descriptor type of the present graph. 
Definition: binary_forest.hxx:70
size_t numNodes() const 
Return the number of nodes (equivalent to maxNodeId()+1). 
Definition: binary_forest.hxx:289
NodeMap< SplitTests >::type split_tests_
Contains a test for each internal node that is used to determine whether given data goes to the left or the right child.
Definition: random_forest.hxx:169
problem specification class for the random forest. 
Definition: rf_common.hxx:538
size_t num_features() const 
Return the number of features. 
Definition: random_forest.hxx:160
Graph graph_
The graph structure. 
Definition: random_forest.hxx:166
Random forest version 3. 
Definition: random_forest.hxx:68
RandomForestOptions options_
The options that were used for training. 
Definition: random_forest.hxx:178
double leaf_ids(FEATURES const &features, IDS &ids, int n_threads=-1, const std::vector< size_t > tree_indices=std::vector< size_t >()) const 
For each data point in features, compute the corresponding leaf ids and return the average number of split comparisons.
Definition: random_forest.hxx:353
Random forest version 2 (see also vigra::rf3::RandomForest for version 3) 
Definition: random_forest.hxx:147
ProblemSpec< LabelType > problem_spec_
The specifications. 
Definition: random_forest.hxx:175
void parallel_foreach(...)
Apply a functor to all items in a range in parallel. 
BinaryForest stores a collection of rooted binary trees. 
Definition: binary_forest.hxx:64
NodeMap< AccInputType >::type node_responses_
Contains the responses of each node (for example the most frequent label). 
Definition: random_forest.hxx:172
size_t numRoots() const 
Return the number of trees in the forest. 
Definition: binary_forest.hxx:332
size_t num_nodes() const 
Return the number of nodes. 
Definition: random_forest.hxx:142
Options class for vigra::rf3::RandomForest version 3. 
Definition: random_forest_common.hxx:582
size_t num_classes() const 
Return the number of classes. 
Definition: random_forest.hxx:154
size_t num_trees() const 
Return the number of trees. 
Definition: random_forest.hxx:148