35 #ifndef VIGRA_RF3_COMMON_HXX 
   36 #define VIGRA_RF3_COMMON_HXX 
   39 #include <type_traits> 
   43 #include "../multi_array.hxx" 
   44 #include "../mathutil.hxx" 
   57 struct LessEqualSplitTest
 
   60     LessEqualSplitTest(
size_t dim = 0, T 
const & val = 0)
 
   66     template<
typename FEATURES>
 
   67     size_t operator()(FEATURES 
const & features)
 const 
   69         return features(dim_) <= val_ ? 0 : 1;
 
   81     typedef size_t input_type;
 
   83     template <
typename ITER, 
typename OUTITER>
 
   84     void operator()(ITER begin, ITER end, OUTITER out)
 
   86         std::fill(buffer_.begin(), buffer_.end(), 0);
 
   89         for (ITER it = begin; it != end; ++it)
 
   92             if (v >= buffer_.size())
 
   94                 buffer_.resize(v+1, 0);
 
   98             max_v = std::max(max_v, v);
 
  100         for (
size_t i = 0; i <= max_v; ++i)
 
  102             *out = buffer_[i] / 
static_cast<double>(n);
 
  107     std::vector<size_t> buffer_;
 
  112 template <
typename VALUETYPE>
 
  113 struct ArgMaxVectorAcc
 
  116     typedef VALUETYPE value_type;
 
  117     typedef std::vector<value_type> input_type;
 
  118     template <
typename ITER, 
typename OUTITER>
 
  119     void operator()(ITER begin, ITER end, OUTITER out)
 
  121         std::fill(buffer_.begin(), buffer_.end(), 0);
 
  123         for (ITER it = begin; it != end; ++it)
 
  125             input_type 
const & vec = *it;
 
  126             if (vec.size() >= buffer_.size())
 
  128                 buffer_.resize(vec.size(), 0);
 
  130             value_type 
const n = std::accumulate(vec.begin(), vec.end(), 
static_cast<value_type
>(0));
 
  131             for (
size_t i = 0; i < vec.size(); ++i)
 
  133                 buffer_[i] += vec[i] / 
static_cast<double>(n);
 
  135             max_v = std::max(vec.size()-1, max_v);
 
  137         for (
size_t i = 0; i <= max_v; ++i)
 
  144         std::vector<double> buffer_;
 
  216     template <
typename FUNCTOR>
 
  221         typedef FUNCTOR Functor;
 
  228             best_score_(std::numeric_limits<double>::max()),
 
  230             n_total_(std::accumulate(priors.begin(), priors.end(), 0.0))
 
  233         template <
typename FEATURES, 
typename LABELS, 
typename WEIGHTS, 
typename ITER>
 
  235             FEATURES 
const & features,
 
  236             LABELS 
const & labels,
 
  237             WEIGHTS 
const & weights,
 
  247             std::vector<double> counts(priors_.size(), 0.0);
 
  251             for (; next != end; ++begin, ++next)
 
  254                 size_t const left_index = *begin;
 
  255                 size_t const right_index = *next;
 
  256                 size_t const label = 
static_cast<size_t>(labels(left_index));
 
  257                 counts[label] += weights[left_index];
 
  258                 n_left += weights[left_index];
 
  261                 auto const left = features(left_index, dim);
 
  262                 auto const right = features(right_index, dim);
 
  268                 double const s = score(priors_, counts, n_total_, n_left);
 
  269                 bool const better_score = s < best_score_;
 
  273                     best_split_ = 0.5*(left+right);
 
  286         std::vector<double> 
const priors_; 
 
  287         double const n_total_; 
 
  299     double operator()(std::vector<double> 
const & priors,
 
  300                       std::vector<double> 
const & counts, 
double n_total, 
double n_left)
 const 
  302         double const n_right = n_total - n_left;
 
  303         double gini_left = 1.0;
 
  304         double gini_right = 1.0;
 
  305         for (
size_t i = 0; i < counts.size(); ++i)
 
  307             double const p_left = counts[i] / n_left;
 
  308             double const p_right = (priors[i] - counts[i]) / n_right;
 
  309             gini_left -= (p_left*p_left);
 
  310             gini_right -= (p_right*p_right);
 
  312         return n_left*gini_left + n_right*gini_right;
 
  316     template <
typename LABELS, 
typename WEIGHTS, 
typename ITER>
 
  317     static double region_score(LABELS 
const & labels, WEIGHTS 
const & weights, ITER begin, ITER end)
 
  320         std::vector<double> counts;
 
  322         for (
auto it = begin; it != end; ++it)
 
  325             auto const lbl = labels[d];
 
  326             if (counts.size() <= lbl)
 
  328                 counts.resize(lbl+1, 0.0);
 
  330             counts[lbl] += weights[d];
 
  336         for (
auto x : counts)
 
  351     double operator()(std::vector<double> 
const & priors, std::vector<double> 
const & counts, 
double n_total, 
double n_left)
 const 
  353         double const n_right = n_total - n_left;
 
  355         for (
size_t i = 0; i < counts.size(); ++i)
 
  357             double c = counts[i];
 
  368     template <
typename LABELS, 
typename WEIGHTS, 
typename ITER>
 
  369     double region_score(LABELS 
const & , WEIGHTS 
const & , ITER , ITER )
 const 
  371         vigra_fail(
"EntropyScore::region_score(): Not implemented yet.");
 
  385     double operator()(std::vector<double> 
const & priors, std::vector<double> 
const & counts, 
double , 
double ) 
const  
  387         double const eps = 1e-10;
 
  389         std::vector<double> norm_counts(counts.size(), 0.0);
 
  390         for (
size_t i = 0; i < counts.size(); ++i)
 
  394                 norm_counts[i] = counts[i] / priors[i];
 
  403         double const mean = std::accumulate(norm_counts.begin(), norm_counts.end(), 0.0) / nnz;
 
  407         for (
size_t i = 0; i < norm_counts.size(); ++i)
 
  411                 double const v = (mean-norm_counts[i]);
 
  418     template <
typename LABELS, 
typename WEIGHTS, 
typename ITER>
 
  419     double region_score(LABELS 
const & , WEIGHTS 
const & , ITER , ITER )
 const 
  421         vigra_fail(
"KolmogorovSmirnovScore::region_score(): Region score not available for the Kolmogorov-Smirnov split.");
 
  427 template <
typename ARR>
 
  428 struct RFNodeDescription
 
  431     RFNodeDescription(
size_t depth, ARR 
const & priors)
 
  443 template <
typename LABELS, 
typename ITER>
 
  444 bool is_pure(LABELS 
const & , RFNodeDescription<ITER> 
const & desc)
 
  447     for (
auto n : desc.priors_)
 
  466     template <
typename LABELS, 
typename ITER>
 
  467     bool operator()(LABELS 
const & labels, RFNodeDescription<ITER> 
const & desc)
 const 
  469         return is_pure(labels, desc);
 
  482         max_depth_(max_depth)
 
  485     template <
typename LABELS, 
typename ITER>
 
  486     bool operator()(LABELS 
const & labels, RFNodeDescription<ITER> 
const & desc)
 const 
  488         if (desc.depth_ >= max_depth_)
 
  491             return is_pure(labels, desc);
 
  508     template <
typename LABELS, 
typename ARR>
 
  509     bool operator()(LABELS 
const & labels, RFNodeDescription<ARR> 
const & desc)
 const 
  511         typedef typename ARR::value_type value_type;
 
  512         if (std::accumulate(desc.priors_.begin(), desc.priors_.end(), 
static_cast<value_type
>(0)) <= min_n_)
 
  515             return is_pure(labels, desc);
 
  530         logtau_(std::
log(tau))
 
  532         vigra_precondition(tau > 0 && tau < 1, 
"NodeComplexityStop(): Tau must be in the open interval (0, 1).");
 
  535     template <
typename LABELS, 
typename ARR>
 
  536     bool operator()(LABELS 
const & , RFNodeDescription<ARR> 
const & desc) 
 
  538         typedef typename ARR::value_type value_type;
 
  541         size_t const total = std::accumulate(desc.priors_.begin(), desc.priors_.end(), 
static_cast<value_type
>(0));
 
  546         for (
auto v : desc.priors_)
 
  551                 lg += 
loggamma(static_cast<double>(v+1));
 
  554         lg += 
loggamma(static_cast<double>(nnz+1));
 
  555         lg -= 
loggamma(static_cast<double>(total+1));
 
  559         return lg >= logtau_;
 
  565 enum RandomForestOptionTags
 
  589         features_per_node_(0),
 
  590         features_per_node_switch_(RF_SQRT),
 
  591         bootstrap_sampling_(
true),
 
  595         node_complexity_tau_(-1),
 
  596         min_num_instances_(1),
 
  597         use_stratification_(
false),
 
  609         tree_count_ = p_tree_count;
 
  622         features_per_node_switch_ = RF_CONST;
 
  623         features_per_node_ = p_features_per_node;
 
  639         vigra_precondition(p_features_per_node_switch == RF_SQRT ||
 
  640                            p_features_per_node_switch == RF_LOG ||
 
  641                            p_features_per_node_switch == RF_ALL,
 
  642                            "RandomForestOptions::features_per_node(): Input must be RF_SQRT, RF_LOG or RF_ALL.");
 
  643         features_per_node_switch_ = p_features_per_node_switch;
 
  654         bootstrap_sampling_ = b;
 
  666         bootstrap_sampling_ = 
false;
 
  682         vigra_precondition(p_split == RF_GINI ||
 
  683                            p_split == RF_ENTROPY ||
 
  685                            "RandomForestOptions::split(): Input must be RF_GINI, RF_ENTROPY or RF_KSD.");
 
  708         node_complexity_tau_ = tau;
 
  719         min_num_instances_ = n;
 
  733         use_stratification_ = b;
 
  774         if (features_per_node_switch_ == RF_SQRT)
 
  776         else if (features_per_node_switch_ == RF_LOG)
 
  778         else if (features_per_node_switch_ == RF_CONST)
 
  779             return features_per_node_;
 
  780         else if (features_per_node_switch_ == RF_ALL)
 
  782         vigra_fail(
"RandomForestOptions::get_features_per_node(): Unknown switch.");
 
  787     int features_per_node_;
 
  788     RandomForestOptionTags features_per_node_switch_;
 
  789     bool bootstrap_sampling_;
 
  790     size_t resample_count_;
 
  791     RandomForestOptionTags split_;
 
  793     double node_complexity_tau_;
 
  794     size_t min_num_instances_;
 
  795     bool use_stratification_;
 
  797     std::vector<double> class_weights_;
 
  803 template <
typename LabelType>
 
  818     ProblemSpec & num_features(
size_t n)
 
  824     ProblemSpec & num_instances(
size_t n)
 
  830     ProblemSpec & num_classes(
size_t n)
 
  836     ProblemSpec & distinct_classes(std::vector<LabelType> v)
 
  838         distinct_classes_ = v;
 
  839         num_classes_ = v.size();
 
  843     ProblemSpec & actual_mtry(
size_t m)
 
  849     ProblemSpec & actual_msample(
size_t m)
 
  855     bool operator==(ProblemSpec 
const & other)
 const 
  857         #define COMPARE(field) if (field != other.field) return false; 
  858         COMPARE(num_features_);
 
  859         COMPARE(num_instances_);
 
  860         COMPARE(num_classes_);
 
  861         COMPARE(distinct_classes_);
 
  862         COMPARE(actual_mtry_);
 
  863         COMPARE(actual_msample_);
 
  868     size_t num_features_;
 
  869     size_t num_instances_;
 
  871     std::vector<LabelType> distinct_classes_;
 
  873     size_t actual_msample_;
 
RandomForestOptions & min_num_instances(size_t n)
Do not split a node if it contains fewer than min_num_instances data points.
Definition: random_forest_common.hxx:717
RandomForestOptions & split(RandomForestOptionTags p_split)
The split criterion. 
Definition: random_forest_common.hxx:680
Random forest 'maximum depth' stop criterion. 
Definition: random_forest_common.hxx:476
size_t get_features_per_node(size_t total) const 
Get the actual number of features per node. 
Definition: random_forest_common.hxx:772
DepthStop(size_t max_depth)
Constructor: terminate tree construction at max_depth. 
Definition: random_forest_common.hxx:480
RandomForestOptions & features_per_node(RandomForestOptionTags p_features_per_node_switch)
The number of features that are considered when computing the split. 
Definition: random_forest_common.hxx:637
RandomForestOptions & max_depth(size_t d)
Do not split a node if its depth is greater than or equal to max_depth.
Definition: random_forest_common.hxx:695
RandomForestOptions & bootstrap_sampling(bool b)
Use bootstrap sampling. 
Definition: random_forest_common.hxx:652
Problem specification class for the random forest.
Definition: rf_common.hxx:538
Definition: random_forest_common.hxx:217
RandomForestOptions & class_weights(std::vector< double > const &v)
Each datapoint is weighted by its class weight. By default, each class has weight 1...
Definition: random_forest_common.hxx:759
RandomForestOptions & use_stratification(bool b)
Use stratification when creating the bootstrap samples. 
Definition: random_forest_common.hxx:731
RandomForestOptions & resample_count(size_t n)
If resample_count is greater than zero, the split in each node is computed using only resample_count ...
Definition: random_forest_common.hxx:663
Functor that computes the entropy score. 
Definition: random_forest_common.hxx:348
RandomForestOptions & n_threads(int n)
The number of threads that are used in training. 
Definition: random_forest_common.hxx:744
bool operator==(FFTWComplex< R > const &a, const FFTWComplex< R > &b)
Equality comparison of two FFTWComplex values.
Definition: fftw3.hxx:825
Random forest 'node purity' stop criterion. 
Definition: random_forest_common.hxx:463
NodeComplexityStop(double tau=0.001)
Constructor: stop when fewer than 1/tau label arrangements are possible. 
Definition: random_forest_common.hxx:528
RandomForestOptions & node_complexity_tau(double tau)
Value of the node complexity termination criterion. 
Definition: random_forest_common.hxx:706
linalg::TemporaryMatrix< T > log(MultiArrayView< 2, T, C > const &v)
double loggamma(double x)
The natural logarithm of the gamma function. 
Definition: mathutil.hxx:1603
Functor that computes the gini score. 
Definition: random_forest_common.hxx:296
Random forest 'node complexity' stop criterion. 
Definition: random_forest_common.hxx:524
Options class for vigra::rf3::RandomForest version 3. 
Definition: random_forest_common.hxx:582
RandomForestOptions & features_per_node(int p_features_per_node)
The number of features that are considered when computing the split. 
Definition: random_forest_common.hxx:620
int ceil(FixedPoint< IntBits, FracBits > v)
rounding up. 
Definition: fixedpoint.hxx:675
RandomForestOptions & tree_count(int p_tree_count)
The number of trees. 
Definition: random_forest_common.hxx:607
NumInstancesStop(size_t min_n)
Constructor: terminate tree construction when a node contains fewer than min_n instances.
Definition: random_forest_common.hxx:503
SquareRootTraits< FixedPoint< IntBits, FracBits > >::SquareRootResult sqrt(FixedPoint< IntBits, FracBits > v)
square root. 
Definition: fixedpoint.hxx:616
Functor that computes the Kolmogorov-Smirnov score. 
Definition: random_forest_common.hxx:382
Random forest 'number of datapoints' stop criterion. 
Definition: random_forest_common.hxx:499