44 #include "multi_array.hxx" 
   45 #include "sampling.hxx" 
   46 #include "threading.hxx" 
   47 #include "threadpool.hxx" 
   48 #include "random_forest_3/random_forest.hxx" 
   49 #include "random_forest_3/random_forest_common.hxx" 
   50 #include "random_forest_3/random_forest_visitors.hxx" 
   // Type selector parameterised on FEATURES and LABELS. Its nested typedef
   // `type` names a RandomForest instantiation whose split test is
   // LessEqualSplitTest over FEATURES::value_type and whose accumulator is
   // ArgMaxVectorAcc<double>.
   // NOTE(review): the enclosing struct header (listing lines 69-70) and the
   // template argument on listing line 72 are missing from this excerpt.
   68 template <
typename FEATURES, 
typename LABELS>
 
   71     typedef RandomForest<FEATURES,
 
   73                          LessEqualSplitTest<typename FEATURES::value_type>,
 
   74                          ArgMaxVectorAcc<double> > type;
 
   // Generic updater template parameterised on an accumulator type ACC.
   // NOTE(review): presumably the primary template of RFMapUpdater (the
   // struct header is not visible here) - confirm against the full file.
   85 template <
typename ACC>
 
   // Call operator: takes a mutable target `a` and a read-only source `b`.
   // It is const-qualified, so it does not modify the updater object itself.
   // The body is not visible in this excerpt.
   88     template <
typename A, 
typename B>
 
   89     void operator()(A & a, B 
const & b)
 const 
   // Specialisation of RFMapUpdater for ArgMaxAcc: the stored response is the
   // position of the largest entry of the source range.
   98 struct RFMapUpdater<ArgMaxAcc>
 
  // a: target that receives the index of the maximum.
  // b: source range; must provide begin() and end().
  100     template <
typename A, 
typename B>
 
  101     void operator()(A & a, B 
const & b)
 const 
  // std::max_element yields the first of several equal maxima, so ties
  // resolve to the lowest index. For an empty range it returns b.end(),
  // which makes the stored distance 0.
  103         auto it = std::max_element(b.begin(), b.end());
 
  104         a = std::distance(b.begin(), it);
 
  // For every feature dimension provided by dim_sampler, orders the given
  // instances by their feature value in that dimension and hands the ordered
  // instance range to the scorer.
  //
  // features         - indexable as features(instance, dimension)
  // labels           - forwarded unchanged to the scorer
  // instance_weights - forwarded unchanged to the scorer
  // instances        - indices of the instances to consider
  // dim_sampler      - supplies sampleSize() and operator[] for the dimensions
  // NOTE(review): the function name line and the declaration of the `score`
  // parameter (listing lines 112 and 118) are missing from this excerpt;
  // `score` is presumably of type SCORER and taken by non-const reference.
  111 template <
typename FEATURES, 
typename LABELS, 
typename SAMPLER, 
typename SCORER>
 
  113         FEATURES 
const & features,
 
  114         LABELS 
const & labels,
 
  115         std::vector<double> 
const & instance_weights,
 
  116         std::vector<size_t> 
const & instances,
 
  117         SAMPLER 
const & dim_sampler,
 
  120     typedef typename FEATURES::value_type FeatureType;
 
  // Scratch buffers, all sized to the number of instances and reused for
  // every sampled dimension.
  122     auto feats = std::vector<FeatureType>(instances.size()); 
 
  123     auto sorted_indices = std::vector<size_t>(feats.size()); 
 
  124     auto tosort_instances = std::vector<size_t>(feats.size()); 
 
  // NOTE(review): `int i` is compared against sampleSize(); the return type
  // of sampleSize() is not visible here - check for a signed/unsigned mix.
  126     for (
int i = 0; i < dim_sampler.sampleSize(); ++i)
 
  128         size_t const d = dim_sampler[i];
 
  // Gather the value of dimension d for each instance.
  131         for (
size_t kk = 0; kk < instances.size(); ++kk)
 
  132             feats[kk] = features(instances[kk], d);
 
  // indexSort produces the permutation that would sort feats;
  // applyPermutation then writes `instances` in that order into
  // tosort_instances.
  // NOTE(review): the std::copy in between looks redundant if
  // applyPermutation overwrites the whole output range - confirm.
  135         indexSort(feats.begin(), feats.end(), sorted_indices.begin());
 
  136         std::copy(instances.begin(), instances.end(), tosort_instances.begin());
 
  137         applyPermutation(sorted_indices.begin(), sorted_indices.end(), instances.begin(), tosort_instances.begin());
 
  // Pass the ordered instance range together with the dimension index to
  // the scorer.
  140         score(features, labels, instance_weights, tosort_instances.begin(), tosort_instances.end(), d);
 
  // Grows one tree of a random forest into `tree`, using an explicit node
  // stack instead of recursion.
  //
  // features   - 2D array; shape()[0] is the instance count, shape()[1] the
  //              feature count
  // labels     - one label per instance, used directly as an index into the
  //              per-class vectors
  // options    - training options (bootstrap, class weights, resampling, ...)
  // randengine - random engine handed to the samplers
  // NOTE(review): the parameters `visitor`, `stop` and `tree` are used below
  // but their declarations (listing lines 154-156) are missing from this
  // excerpt, as are all braces and several other lines.
  149 template <
typename RF, 
typename SCORER, 
typename VISITOR, 
typename STOP, 
typename RANDENGINE>
 
  150 void random_forest_single_tree(
 
  151         typename RF::Features 
const & features,
 
  152         MultiArray<1, size_t>  
const & labels,
 
  153         RandomForestOptions 
const & options,
 
  157         RANDENGINE 
const & randengine
 
  159     typedef typename RF::Features Features;
 
  160     typedef typename Features::value_type FeatureType;
 
  161     typedef LessEqualSplitTest<FeatureType> SplitTests;
 
  162     typedef typename RF::Node Node;
 
  163     typedef typename RF::ACC ACC;
 
  164     typedef typename ACC::input_type ACCInputType;
 
  // This routine only builds less-or-equal threshold splits, so the forest
  // type must use exactly that split test.
  166     static_assert(std::is_same<SplitTests, typename RF::SplitTests>::value,
 
  167                   "random_forest_single_tree(): Wrong Random Forest class.");
 
  // NOTE(review): num_instances is a signed int but is compared with
  // labels.size() below - check for signed/unsigned comparison and for
  // truncation on very large inputs.
  170     int const num_instances = features.shape()[0];
 
  171     size_t const num_features = features.shape()[1];
 
  172     auto const & spec = tree.problem_spec_;
 
  // Input validation: one label per feature row, and the feature count must
  // match what the tree's problem spec was set up with.
  174     vigra_precondition(num_instances == labels.size(),
 
  175                        "random_forest_single_tree(): Shape mismatch between features and labels.");
 
  176     vigra_precondition(num_features == spec.num_features_,
 
  177                        "random_forest_single_tree(): Wrong number of features.");
 
  // Instance indices 0 .. num_instances-1; nodes refer to sub-ranges of this
  // vector by iterator pairs.
  180     std::vector<size_t> instance_indices(num_instances);
 
  181     std::iota(instance_indices.begin(), instance_indices.end(), 0);
 
  182     typedef std::vector<size_t>::iterator InstanceIter;
 
  // Every instance starts with weight 1. With bootstrap sampling the weights
  // are reset to 0 and then incremented once per sampled index, so an
  // instance's weight is the number of times it was drawn.
  185     std::vector<double> instance_weights(num_instances, 1.0);
 
  186     if (options.bootstrap_sampling_)
 
  188         std::fill(instance_weights.begin(), instance_weights.end(), 0.0);
 
  // NOTE(review): the remaining constructor arguments of `sampler`
  // (listing lines 191-192) are not visible in this excerpt.
  189         Sampler<MersenneTwister> sampler(num_instances,
 
  190                                          SamplerOptions().withReplacement().stratified(options.use_stratification_),
 
  193         for (
int i = 0; i < sampler.sampleSize(); ++i)
 
  195             int const index = sampler[i];
 
  196             ++instance_weights[index];
 
  // Optional class weights scale each instance weight by the weight of its
  // class; at() range-checks the label.
  201     if (options.class_weights_.size() > 0)
 
  203         for (
size_t i = 0; i < instance_weights.size(); ++i)
 
  204             instance_weights[i] *= options.class_weights_.at(labels(i));
 
  // Sampler over the feature dimensions: mtry of num_features, without
  // replacement.
  208     auto const mtry = spec.actual_mtry_;
 
  209     Sampler<MersenneTwister> dim_sampler(num_features, SamplerOptions().withoutReplacement().sampleSize(mtry), &randengine);
 
  // Work stack of nodes still to be processed, plus per-node bookkeeping:
  // instance range, weighted class distribution and depth.
  212     std::stack<Node> node_stack;
 
  213     typedef std::pair<InstanceIter, InstanceIter> IterPair;
 
  214     PropertyMap<Node, IterPair> instance_range;  
 
  215     PropertyMap<Node, std::vector<double> > node_distributions;  
 
  216     PropertyMap<Node, size_t> node_depths;  
 
  // Root node: covers all instances, has depth 0, and its distribution is
  // the per-class sum of instance weights.
  218         auto const rootnode = tree.graph_.addNode();
 
  219         node_stack.push(rootnode);
 
  221         instance_range.insert(rootnode, IterPair(instance_indices.begin(), instance_indices.end()));
 
  223         std::vector<double> priors(spec.num_classes_, 0.0);
 
  224         for (
auto i : instance_indices)
 
  225             priors[labels(i)] += instance_weights[i];
 
  226         node_distributions.insert(rootnode, priors);
 
  228         node_depths.insert(rootnode, 0);
 
  232     visitor.visit_before_tree(tree, features, labels, instance_weights);
 
  // Main loop: take a node, look for a split, then either finalise it as a
  // leaf or create two children.
  // NOTE(review): the pop() matching top() is not visible in this excerpt
  // (presumably on listing line 240) - confirm.
  235     detail::RFMapUpdater<ACC> node_map_updater;
 
  236     while (!node_stack.empty())
 
  239         auto const node = node_stack.top();
 
  241         auto const begin = instance_range.at(node).first;
 
  242         auto const end = instance_range.at(node).second;
 
  243         auto const & priors = node_distributions.at(node);
 
  244         auto const depth = node_depths.at(node);
 
  // Only instances with a weight above 1e-10 take part in the split search.
  247         std::vector<size_t> used_instances;
 
  248         for (
auto it = begin; it != end; ++it)
 
  249             if (instance_weights[*it] > 1e-10)
 
  250                 used_instances.push_back(*it);
 
  // sample() is called once per node (presumably redrawing the feature
  // subset); the scorer is seeded with this node's class distribution.
  253         dim_sampler.sample();
 
  254         SCORER score(priors);
 
  // Resampling is skipped when resample_count_ is 0 or the node has no more
  // than resample_count_ usable instances.
  // NOTE(review): the bodies of both branches (listing lines 256-269 and
  // 275-287) are missing here; the visible lines below appear to belong to
  // the else branch and build a subset of resample_count_ instances.
  255         if (options.resample_count_ == 0 || used_instances.size() <= options.resample_count_)
 
  270             Sampler<MersenneTwister> resampler(used_instances.begin(), used_instances.end(), SamplerOptions().withoutReplacement().sampleSize(options.resample_count_), &randengine);
 
  272             auto indices = std::vector<size_t>(options.resample_count_);
 
  273             for (
size_t i = 0; i < options.resample_count_; ++i)
 
  274                 indices[i] = used_instances[resampler[i]];
 
  // No split found: the node gets a default-constructed response which is
  // then filled from the node's class distribution.
  // NOTE(review): control flow after this (presumably a `continue`) is not
  // visible.
  288         if (!score.split_found_)
 
  290             tree.node_responses_.insert(node, ACCInputType());
 
  291             node_map_updater(tree.node_responses_.at(node), node_distributions.at(node));
 
  // Split found: create both children and connect them to the node.
  296         auto const n_left = tree.graph_.addNode();
 
  297         auto const n_right = tree.graph_.addNode();
 
  298         tree.graph_.addArc(node, n_left);
 
  299         tree.graph_.addArc(node, n_right);
 
  300         auto const best_split = score.best_split_;
 
  301         auto const best_dim = score.best_dim_;
 
  // Reorder the node's instance range in place: instances whose value in
  // best_dim is <= best_split come first, and split_iter marks the boundary.
  // NOTE(review): the lambda header (listing lines 303-304) is not visible.
  302         auto const split_iter = std::partition(begin, end,
 
  305                 return features(i, best_dim) <= best_split;
 
  310         visitor.visit_after_split(tree, features, labels, instance_weights, score, begin, split_iter, end);
 
  // Record the sub-ranges, the split test and the depths of the children.
  312         instance_range.insert(n_left, IterPair(begin, split_iter));
 
  313         instance_range.insert(n_right, IterPair(split_iter, end));
 
  314         tree.split_tests_.insert(node, SplitTests(best_dim, best_split));
 
  315         node_depths.insert(n_left, depth+1);
 
  316         node_depths.insert(n_right, depth+1);
 
  // Left child: weighted class distribution over [begin, split_iter).
  319         auto priors_left = std::vector<double>(spec.num_classes_, 0.0);
 
  320         for (
auto it = begin; it != split_iter; ++it)
 
  321             priors_left[labels(*it)] += instance_weights[*it];
 
  322         node_distributions.insert(n_left, priors_left);
 
  // If the stop criterion fires, the left child is finalised with a
  // response; the push below presumably sits in the else branch (the `else`
  // line is missing from the listing).
  325         if (stop(labels, RFNodeDescription<decltype(priors_left)>(depth+1, priors_left)))
 
  327             tree.node_responses_.insert(n_left, ACCInputType());
 
  328             node_map_updater(tree.node_responses_.at(n_left), node_distributions.at(n_left));
 
  332             node_stack.push(n_left);
 
  // Right child: same procedure over [split_iter, end).
  336         auto priors_right = std::vector<double>(spec.num_classes_, 0.0);
 
  337         for (
auto it = split_iter; it != end; ++it)
 
  338             priors_right[labels(*it)] += instance_weights[*it];
 
  339         node_distributions.insert(n_right, priors_right);
 
  342         if (stop(labels, RFNodeDescription<decltype(priors_right)>(depth+1, priors_right)))
 
  344             tree.node_responses_.insert(n_right, ACCInputType());
 
  345             node_map_updater(tree.node_responses_.at(n_right), node_distributions.at(n_right));
 
  349             node_stack.push(n_right);
 
  354     visitor.visit_after_tree(tree, features, labels, instance_weights);
 
  // Trains options.tree_count_ trees on a thread pool and reports the result
  // to the visitor.
  //
  // features   - 2D array; shape()[0] instances by shape()[1] features
  // labels     - one label per instance, of arbitrary value type
  // options    - training options
  // randengine - source of the per-thread seeds
  // NOTE(review): most of the template parameter list, the function name
  // line, the `visitor` and `stop` parameters, the declaration of `rf` and
  // all braces are missing from this excerpt. The name random_forest_impl is
  // taken from the messages below.
  360 template <
typename FEATURES,
 
  366 RandomForest<FEATURES, LABELS>
 
  368         FEATURES 
const & features,
 
  369         LABELS 
const & labels,
 
  370         RandomForestOptions 
const & options,
 
  373         RANDENGINE & randengine
 
  376     typedef LABELS Labels;
 
  378     typedef typename Labels::value_type LabelType;
 
  379     typedef RandomForest<FEATURES, LABELS> RF;
 
  // Problem description shared by all trees: instance and feature counts,
  // the number of features tried per node, and the sample size.
  381     ProblemSpec<LabelType> pspec;
 
  382     pspec.num_instances(features.shape()[0])
 
  383          .num_features(features.shape()[1])
 
  384          .actual_mtry(options.get_features_per_node(features.shape()[1]))
 
  385          .actual_msample(labels.size());
 
  388     size_t const tree_count = options.tree_count_;
 
  389     vigra_precondition(tree_count > 0, 
"random_forest_impl(): tree_count must not be zero.");
 
  390     std::vector<RF> trees(tree_count);
 
  // Collect the distinct labels; going through std::set leaves them sorted
  // and unique. label_map sends each label to its position in that order.
  393     std::set<LabelType> 
const dlabels(labels.begin(), labels.end());
 
  394     std::vector<LabelType> 
const distinct_labels(dlabels.begin(), dlabels.end());
 
  395     pspec.distinct_classes(distinct_labels);
 
  396     std::map<LabelType, size_t> label_map;
 
  397     for (
size_t i = 0; i < distinct_labels.size(); ++i)
 
  399         label_map[distinct_labels[i]] = i;
 
  // Replace every label by its dense class index 0 .. num_classes-1.
  // NOTE(review): the indices are stored in an array of LabelType, while
  // random_forest_single_tree declares its labels as MultiArray<1, size_t> -
  // confirm the intended element type.
  402     MultiArray<1, LabelType> transformed_labels(Shape1(labels.size()));
 
  403     for (
size_t i = 0; i < (size_t)labels.size(); ++i)
 
  405         transformed_labels(i) = label_map[labels(i)];
 
  409     vigra_precondition(options.class_weights_.size() == 0 || options.class_weights_.size() == distinct_labels.size(),
 
  410                        "random_forest_impl(): The number of class weights must be 0 or equal to the number of classes.");
 
  // Every tree starts from the same problem spec.
  413     for (
auto & t : trees)
 
  414         t.problem_spec_ = pspec;
 
  // Thread count: n_threads_ >= 1 is used as given, -1 asks the hardware,
  // and any other value keeps the default of 1.
  // NOTE(review): std::thread::hardware_concurrency() may return 0 when the
  // value is not computable. That would leave `seeds` and `rand_engines`
  // empty while tasks index rand_engines by thread_id - confirm how
  // ThreadPool behaves with zero threads.
  417     size_t n_threads = 1;
 
  418     if (options.n_threads_ >= 1)
 
  419         n_threads = options.n_threads_;
 
  420     else if (options.n_threads_ == -1)
 
  421         n_threads = std::thread::hardware_concurrency();
 
  // Draw seeds until there is one distinct seed per thread; the set
  // discards duplicates.
  424     UniformIntRandomFunctor<RANDENGINE> rand_functor(randengine);
 
  425     std::set<UInt32> seeds;
 
  426     while (seeds.size() < n_threads)
 
  428         seeds.insert(rand_functor());
 
  430     vigra_assert(seeds.size() == n_threads, 
"random_forest_impl(): Could not create random seeds.");
 
  // One random engine per thread, so tasks do not share engine state.
  433     std::vector<RANDENGINE> rand_engines;
 
  434     for (
auto seed : seeds)
 
  436         rand_engines.push_back(RANDENGINE(seed));
 
  440     visitor.visit_before_training();
 
  // Each tree gets its own visitor copy constructed from `visitor`.
  444     typedef typename VisitorCopy<VISITOR>::type VisitorCopyType;
 
  445     std::vector<VisitorCopyType> tree_visitors;
 
  446     for (
size_t i = 0; i < tree_count; ++i)
 
  448         tree_visitors.emplace_back(visitor);
 
  // Enqueue one task per tree. The lambda captures the shared inputs by
  // reference and the tree index by value; each task trains trees[i] with
  // tree_visitors[i] and the engine of the executing thread.
  452     ThreadPool pool((
size_t)n_threads);
 
  453     std::vector<threading::future<void> > futures;
 
  454     for (
size_t i = 0; i < tree_count; ++i)
 
  456         futures.emplace_back(
 
  457             pool.enqueue([&features, &transformed_labels, &options, &tree_visitors, &stop, &trees, i, &rand_engines](
size_t thread_id)
 
  459                     random_forest_single_tree<RF, SCORER, VisitorCopyType, STOP>(features, transformed_labels, options, tree_visitors[i], stop, trees[i], rand_engines[thread_id]);
 
  // NOTE(review): the loop body (presumably waiting on each future) is not
  // visible in this excerpt.
  464     for (
auto & fut : futures)
 
  // NOTE(review): the declaration of `rf` and the body of the loop below are
  // not visible. The loop starts at index 1, which suggests rf is
  // initialised from trees[0] and the remaining trees are combined into it -
  // confirm.
  469     rf.options_ = options;
 
  470     for (
size_t i = 1; i < trees.size(); ++i)
 
  476     visitor.visit_after_training(tree_visitors, rf, features, labels);
 
  // Chooses the stop criterion from the options and forwards to
  // random_forest_impl with the matching STOP type. Checked in order:
  //   max_depth_ > 0           -> DepthStop(max_depth_)
  //   min_num_instances_ > 1   -> NumInstancesStop(min_num_instances_)
  //   node_complexity_tau_ > 0 -> NodeComplexityStop(node_complexity_tau_)
  //   otherwise                -> PurityStop()
  // Only the first matching option takes effect.
  // NOTE(review): the function name line, the `visitor` parameter and the
  // final `else` are missing from this excerpt; the name random_forest_impl0
  // is taken from the calls in random_forest.
  484 template <
typename FEATURES, 
typename LABELS, 
typename VISITOR, 
typename SCORER, 
typename RANDENGINE>
 
  486 RandomForest<FEATURES, LABELS>
 
  488         FEATURES 
const & features,
 
  489         LABELS 
const & labels,
 
  490         RandomForestOptions 
const & options,
 
  492         RANDENGINE & randengine
 
  494     if (options.max_depth_ > 0)
 
  495         return random_forest_impl<FEATURES, LABELS, VISITOR, SCORER, DepthStop, RANDENGINE>(features, labels, options, visitor, DepthStop(options.max_depth_), randengine);
 
  496     else if (options.min_num_instances_ > 1)
 
  497         return random_forest_impl<FEATURES, LABELS, VISITOR, SCORER, NumInstancesStop, RANDENGINE>(features, labels, options, visitor, NumInstancesStop(options.min_num_instances_), randengine);
 
  498     else if (options.node_complexity_tau_ > 0)
 
  499         return random_forest_impl<FEATURES, LABELS, VISITOR, SCORER, NodeComplexityStop, RANDENGINE>(features, labels, options, visitor, NodeComplexityStop(options.node_complexity_tau_), randengine);
 
  501         return random_forest_impl<FEATURES, LABELS, VISITOR, SCORER, PurityStop, RANDENGINE>(features, labels, options, visitor, PurityStop(), randengine);
 
  // Entry point with an explicit random engine: selects the scorer type
  // from options.split_ and forwards to detail::random_forest_impl0.
  //   RF_GINI    -> GeneralScorer<GiniScore>
  //   RF_ENTROPY -> GeneralScorer<EntropyScore>
  //   RF_KSD     -> GeneralScorer<KolmogorovSmirnovScore>
  // Any other value throws std::runtime_error.
  // NOTE(review): the function name line, the `visitor` parameter and the
  // final `else` are missing from this excerpt; the name random_forest is
  // taken from the error message below.
  579 template <
typename FEATURES, 
typename LABELS, 
typename VISITOR, 
typename RANDENGINE>
 
  581 RandomForest<FEATURES, LABELS>
 
  583         FEATURES 
const & features,
 
  584         LABELS 
const & labels,
 
  585         RandomForestOptions 
const & options,
 
  587         RANDENGINE & randengine
 
  589     typedef detail::GeneralScorer<GiniScore> GiniScorer;
 
  590     typedef detail::GeneralScorer<EntropyScore> EntropyScorer;
 
  591     typedef detail::GeneralScorer<KolmogorovSmirnovScore> KSDScorer;
 
  592     if (options.split_ == RF_GINI)
 
  593         return detail::random_forest_impl0<FEATURES, LABELS, VISITOR, GiniScorer, RANDENGINE>(features, labels, options, visitor, randengine);
 
  594     else if (options.split_ == RF_ENTROPY)
 
  595         return detail::random_forest_impl0<FEATURES, LABELS, VISITOR, EntropyScorer, RANDENGINE>(features, labels, options, visitor, randengine);
 
  596     else if (options.split_ == RF_KSD)
 
  597         return detail::random_forest_impl0<FEATURES, LABELS, VISITOR, KSDScorer, RANDENGINE>(features, labels, options, visitor, randengine);
 
  599         throw std::runtime_error(
"random_forest(): Unknown split criterion.");
 
  // Convenience overload without a caller-supplied random engine: forwards
  // to random_forest(features, labels, options, visitor, randengine).
  // NOTE(review): the `visitor` parameter and the definition of the local
  // `randengine` (listing lines 609-611) are missing from this excerpt, so
  // where that engine comes from cannot be told from here.
  602 template <
typename FEATURES, 
typename LABELS, 
typename VISITOR>
 
  604 RandomForest<FEATURES, LABELS>
 
  606         FEATURES 
const & features,
 
  607         LABELS 
const & labels,
 
  608         RandomForestOptions 
const & options,
 
  612     return random_forest(features, labels, options, visitor, randengine);
 
  // Overload taking only features, labels and options, returning
  // RandomForest<FEATURES, LABELS>.
  // NOTE(review): the function name line and the entire body are missing
  // from this excerpt; presumably it supplies a default visitor and
  // forwards to the overload that takes one - confirm against the full file.
  615 template <
typename FEATURES, 
typename LABELS>
 
  617 RandomForest<FEATURES, LABELS>
 
  619         FEATURES 
const & features,
 
  620         LABELS 
const & labels,
 
  621         RandomForestOptions 
const & options
 
  // Simplest overload: trains with default-constructed RandomForestOptions
  // by forwarding to random_forest(features, labels, options).
  // NOTE(review): the function name line is missing from this excerpt.
  627 template <
typename FEATURES, 
typename LABELS>
 
  629 RandomForest<FEATURES, LABELS>
 
  631         FEATURES 
const & features,
 
  632         LABELS 
const & labels
 
  634     return random_forest(features, labels, RandomForestOptions());
 
void applyPermutation(IndexIterator index_first, IndexIterator index_last, InIterator in, OutIterator out)
Sort an array according to the given index permutation. 
Definition: algorithm.hxx:456
void indexSort(Iterator first, Iterator last, IndexIterator index_first, Compare c)
Return the index permutation that would sort the input array. 
Definition: algorithm.hxx:414
static RandomNumberGenerator & global()
Definition: random.hxx:566
doxygen_overloaded_function(template<...> void separableConvolveBlockwise) template< unsigned int N
Separated convolution on ChunkedArrays. 
void random_forest(...)
Train a vigra::rf3::RandomForest classifier.