37 #ifndef VIGRA_RF_COMMON_HXX 
   38 #define VIGRA_RF_COMMON_HXX 
   44 struct ClassificationTag
 
   69         friend RF_DEFAULT& ::vigra::rf_default();
 
   99 template<
class T, 
class C>
 
  104     static T & choose(T & t, C &)
 
  111 class Value_Chooser<detail::RF_DEFAULT, C>
 
  116     static C & choose(detail::RF_DEFAULT &, C & c)
 
  133     static detail::RF_DEFAULT result;
 
  176     double  training_set_proportion_;
 
  177     int     training_set_size_;
 
  178     int (*training_set_func_)(int);
 
  180         training_set_calc_switch_;
 
  182     bool    sample_with_replacement_;
 
  184             stratification_method_;
 
  195     int (*mtry_func_)(int) ;
 
  197     bool predict_weighted_;
 
  199     int min_split_node_size_;
 
  200     bool prepare_online_learning_;
 
  204     typedef std::map<std::string, double_array> map_type;
 
  206     int serialized_size()
 const 
  215         #define COMPARE(field) result = result && (this->field == rhs.field); 
  216         COMPARE(training_set_proportion_);
 
  217         COMPARE(training_set_size_);
 
  218         COMPARE(training_set_calc_switch_);
 
  219         COMPARE(sample_with_replacement_);
 
  220         COMPARE(stratification_method_);
 
  221         COMPARE(mtry_switch_);
 
  223         COMPARE(tree_count_);
 
  224         COMPARE(min_split_node_size_);
 
  225         COMPARE(predict_weighted_);
 
  232         return !(*
this == rhs_);
 
  235     void unserialize(Iter 
const & begin, Iter 
const & end)
 
  238         vigra_precondition(static_cast<int>(end - begin) == serialized_size(),
 
  239                            "RandomForestOptions::unserialize():" 
  240                            "wrong number of parameters");
 
  241         #define PULL(item_, type_) item_ = type_(*iter); ++iter; 
  242         PULL(training_set_proportion_, 
double);
 
  243         PULL(training_set_size_, 
int);
 
  246         PULL(sample_with_replacement_, 0 != );
 
  251         PULL(tree_count_, 
int);
 
  252         PULL(min_split_node_size_, 
int);
 
  253         PULL(predict_weighted_, 0 !=);
 
  257     void serialize(Iter 
const &  begin, Iter 
const & end)
 const 
  260         vigra_precondition(static_cast<int>(end - begin) == serialized_size(),
 
  261                            "RandomForestOptions::serialize():" 
  262                            "wrong number of parameters");
 
  263         #define PUSH(item_) *iter = double(item_); ++iter; 
  264         PUSH(training_set_proportion_);
 
  265         PUSH(training_set_size_);
 
  266         if(training_set_func_ != 0)
 
  274         PUSH(training_set_calc_switch_);
 
  275         PUSH(sample_with_replacement_);
 
  276         PUSH(stratification_method_);
 
  288         PUSH(min_split_node_size_);
 
  289         PUSH(predict_weighted_);
 
  293     void make_from_map(map_type & in) 
 
  295         #define PULL(item_, type_) item_ = type_(in[#item_][0]); 
  296         #define PULLBOOL(item_, type_) item_ = type_(in[#item_][0] > 0); 
  297         PULL(training_set_proportion_,
double);
 
  298         PULL(training_set_size_, 
int);
 
  300         PULL(tree_count_, 
int);
 
  301         PULL(min_split_node_size_, 
int);
 
  302         PULLBOOL(sample_with_replacement_, 
bool);
 
  303         PULLBOOL(prepare_online_learning_, 
bool);
 
  304         PULLBOOL(predict_weighted_, 
bool);
 
  317     void make_map(map_type & in)
 const 
  319         #define PUSH(item_, type_) in[#item_] = double_array(1, double(item_)); 
  320         #define PUSHFUNC(item_, type_) in[#item_] = double_array(1, double(item_!=0)); 
  321         PUSH(training_set_proportion_,
double);
 
  322         PUSH(training_set_size_, 
int);
 
  324         PUSH(tree_count_, 
int);
 
  325         PUSH(min_split_node_size_, 
int);
 
  326         PUSH(sample_with_replacement_, 
bool);
 
  327         PUSH(prepare_online_learning_, 
bool);
 
  328         PUSH(predict_weighted_, 
bool);
 
  334         PUSHFUNC(mtry_func_, 
int);
 
  335         PUSHFUNC(training_set_func_,
int);
 
  348         training_set_proportion_(1.0),
 
  349         training_set_size_(0),
 
  350         training_set_func_(0),
 
  351         training_set_calc_switch_(RF_PROPORTIONAL),
 
  352         sample_with_replacement_(true),
 
  353         stratification_method_(RF_NONE),
 
  354         mtry_switch_(RF_SQRT),
 
  357         predict_weighted_(false),
 
  359         min_split_node_size_(1),
 
  360         prepare_online_learning_(false)
 
  376         vigra_precondition(in == RF_EQUAL ||
 
  377                            in == RF_PROPORTIONAL ||
 
  380                            "RandomForestOptions::use_stratification()" 
  381                            "input must be RF_EQUAL, RF_PROPORTIONAL," 
  382                            "RF_EXTERNAL or RF_NONE");
 
  383         stratification_method_ = in;
 
  389         prepare_online_learning_=in;
 
  399         sample_with_replacement_ = in;
 
  413         training_set_proportion_ = in;
 
  414         training_set_calc_switch_ = RF_PROPORTIONAL;
 
  425         training_set_size_ = in;
 
  426         training_set_calc_switch_ = RF_CONST;
 
  438         training_set_func_ = in;
 
  439         training_set_calc_switch_ = RF_FUNCTION;
 
  447         predict_weighted_ = 
true;
 
  462         vigra_precondition(in == RF_LOG ||
 
  465                            "RandomForestOptions()::features_per_node():" 
  466                            "input must be of type RF_LOG or RF_SQRT");
 
  480         mtry_switch_ = RF_CONST;
 
  492         mtry_switch_ = RF_FUNCTION;
 
  516         min_split_node_size_ = in;
 
  524 enum Problem_t{REGRESSION, CLASSIFICATION, CHECKLATER};
 
  537 template<
class LabelType = 
double>
 
  550     typedef std::map<std::string, double_array> map_type;
 
  559     Problem_t               problem_type_;    
 
  568     void to_classlabel(
int index, T & out)
 const 
  570         out = T(classes[index]);
 
  573     int to_classIndex(T index)
 const 
  575         return std::find(classes.
begin(), classes.
end(), index) - classes.
begin();
 
  578     #define EQUALS(field) field(rhs.field) 
  581         EQUALS(column_count_),
 
  582         EQUALS(class_count_),
 
  584         EQUALS(actual_mtry_),
 
  585         EQUALS(actual_msample_),
 
  586         EQUALS(problem_type_),
 
  588         EQUALS(class_weights_),
 
  589         EQUALS(is_weighted_),
 
  591         EQUALS(response_size_)
 
  593         std::back_insert_iterator<ArrayVector<Label_t> >
 
  595         std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
 
  598     #define EQUALS(field) field(rhs.field) 
  602         EQUALS(column_count_),
 
  603         EQUALS(class_count_),
 
  605         EQUALS(actual_mtry_),
 
  606         EQUALS(actual_msample_),
 
  607         EQUALS(problem_type_),
 
  609         EQUALS(class_weights_),
 
  610         EQUALS(is_weighted_),
 
  612         EQUALS(response_size_)
 
  614         std::back_insert_iterator<ArrayVector<Label_t> >
 
  616         std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
 
  620     #define EQUALS(field) (this->field = rhs.field); 
  623         EQUALS(column_count_);
 
  624         EQUALS(class_count_);
 
  626         EQUALS(actual_mtry_);
 
  627         EQUALS(actual_msample_);
 
  628         EQUALS(problem_type_);
 
  630         EQUALS(is_weighted_);
 
  632         EQUALS(response_size_)
 
  633         class_weights_.clear();
 
  634         std::back_insert_iterator<ArrayVector<
double> >
 
  635                         iter2(class_weights_);
 
  636         std::copy(rhs.class_weights_.begin(), rhs.class_weights_.end(), iter2);
 
  638         std::back_insert_iterator<ArrayVector<
Label_t> >
 
  640         std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
 
  647         EQUALS(column_count_);
 
  648         EQUALS(class_count_);
 
  650         EQUALS(actual_mtry_);
 
  651         EQUALS(actual_msample_);
 
  652         EQUALS(problem_type_);
 
  654         EQUALS(is_weighted_);
 
  656         EQUALS(response_size_)
 
  657         class_weights_.clear();
 
  658         std::back_insert_iterator<ArrayVector<
double> >
 
  659                         iter2(class_weights_);
 
  660         std::copy(rhs.class_weights_.begin(), rhs.class_weights_.end(), iter2);
 
  662         std::back_insert_iterator<ArrayVector<
Label_t> >
 
  664         std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
 
  670     bool operator==(ProblemSpec<T> 
const & rhs)
 
  673         #define COMPARE(field) result = result && (this->field == rhs.field); 
  674         COMPARE(column_count_);
 
  675         COMPARE(class_count_);
 
  677         COMPARE(actual_mtry_);
 
  678         COMPARE(actual_msample_);
 
  679         COMPARE(problem_type_);
 
  680         COMPARE(is_weighted_);
 
  683         COMPARE(class_weights_);
 
  685         COMPARE(response_size_)
 
  692         return !(*
this == rhs);
 
  696     size_t serialized_size()
 const 
  698         return 10 + class_count_ *int(is_weighted_+1);
 
  703     void unserialize(Iter 
const & begin, Iter 
const & end)
 
  706         vigra_precondition(end - begin >= 10,
 
  707                            "ProblemSpec::unserialize():" 
  708                            "wrong number of parameters");
 
  709         #define PULL(item_, type_) item_ = type_(*iter); ++iter; 
  710         PULL(column_count_,
int);
 
  711         PULL(class_count_, 
int);
 
  713         vigra_precondition(end - begin >= 10 + class_count_,
 
  714                            "ProblemSpec::unserialize(): 1");
 
  715         PULL(row_count_, 
int);
 
  716         PULL(actual_mtry_,
int);
 
  717         PULL(actual_msample_, 
int);
 
  718         PULL(problem_type_, Problem_t);
 
  719         PULL(is_weighted_, 
int);
 
  721         PULL(precision_, 
double);
 
  722         PULL(response_size_, 
int);
 
  725             vigra_precondition(end - begin == 10 + 2*class_count_,
 
  726                                "ProblemSpec::unserialize(): 2");
 
  727             class_weights_.insert(class_weights_.end(),
 
  729                                   iter + class_count_);
 
  730             iter += class_count_;
 
  732         classes.insert(classes.end(), iter, end);
 
  738     void serialize(Iter 
const & begin, Iter 
const & end)
 const 
  741         vigra_precondition(end - begin == serialized_size(),
 
  742                            "RandomForestOptions::serialize():" 
  743                            "wrong number of parameters");
 
  744         #define PUSH(item_) *iter = double(item_); ++iter; 
  749         PUSH(actual_msample_);
 
  754         PUSH(response_size_);
 
  757             std::copy(class_weights_.begin(),
 
  758                       class_weights_.end(),
 
  760             iter += class_count_;
 
  762         std::copy(classes.begin(),
 
  768     void make_from_map(map_type & in) 
 
  770         #define PULL(item_, type_) item_ = type_(in[#item_][0]); 
  771         PULL(column_count_,
int);
 
  772         PULL(class_count_, 
int);
 
  773         PULL(row_count_, 
int);
 
  774         PULL(actual_mtry_,
int);
 
  775         PULL(actual_msample_, 
int);
 
  776         PULL(problem_type_, (Problem_t)
int);
 
  777         PULL(is_weighted_, 
int);
 
  779         PULL(precision_, 
double);
 
  780         PULL(response_size_, 
int);
 
  781         class_weights_ = in[
"class_weights_"];
 
  784     void make_map(map_type & in)
 const 
  786         #define PUSH(item_) in[#item_] = double_array(1, double(item_)); 
  791         PUSH(actual_msample_);
 
  796         PUSH(response_size_);
 
  797         in["class_weights_"] = class_weights_;
 
  809         problem_type_(CHECKLATER),
 
  827     template<
class C_Iter>
 
  831         int size = end-begin;
 
  832         for(
int k=0; k<size; ++k, ++begin)
 
  833             classes.push_back(detail::RequiresExplicitCast<LabelType>::cast(*begin));
 
  843     template<
class W_Iter>
 
  846         class_weights_.clear();
 
  847         class_weights_.insert(class_weights_.end(), begin, end);
 
  858         class_weights_.clear();
 
  863         problem_type_ = CHECKLATER;
 
  864         is_weighted_ = 
false;
 
  888     int min_split_node_size_;
 
  892     :   min_split_node_size_(opt.min_split_node_size_)
 
  896     void set_external_parameters(
ProblemSpec<T>const  &, 
int  = 0, 
bool  = 
false)
 
  899     template<
class Region>
 
  900     bool operator()(Region& region)
 
  902         return region.size() < min_split_node_size_;
 
  905     template<
class WeightIter, 
class T, 
class C>
 
  915 #endif //VIGRA_RF_COMMON_HXX 
RandomForestOptions & features_per_node(RF_OptionTag in)
use a built-in mapping to calculate mtry
Definition: rf_common.hxx:460
detail::RF_DEFAULT & rf_default()
factory function to return a RF_DEFAULT tag 
Definition: rf_common.hxx:131
RandomForestOptions & samples_per_tree(double in)
specify the fraction of the total number of samples used per tree for learning. 
Definition: rf_common.hxx:411
RandomForestOptions & features_per_node(int(*in)(int))
use an external function to calculate mtry
Definition: rf_common.hxx:489
const_iterator begin() const 
Definition: array_vector.hxx:223
RandomForestOptions & samples_per_tree(int in)
directly specify the number of samples per tree 
Definition: rf_common.hxx:423
RandomForestOptions & samples_per_tree(int(*in)(int))
use an external function to calculate the number of samples each tree is learned with.
Definition: rf_common.hxx:436
problem specification class for the random forest. 
Definition: rf_common.hxx:538
LabelType Label_t
problem class 
Definition: rf_common.hxx:547
RandomForestOptions & min_split_node_size(int in)
Number of examples required for a node to be split. 
Definition: rf_common.hxx:514
RandomForestOptions & features_per_node(int in)
Set mtry to a constant value. 
Definition: rf_common.hxx:477
Standard early stopping criterion. 
Definition: rf_common.hxx:885
ProblemSpec & classes_(C_Iter begin, C_Iter end)
supply the class labels
Definition: rf_common.hxx:828
RandomForestOptions()
create a RandomForestOptions object with default initialisation. 
Definition: rf_common.hxx:346
ProblemSpec & class_weights(W_Iter begin, W_Iter end)
supply the class weights
Definition: rf_common.hxx:844
RandomForestOptions & sample_with_replacement(bool in)
choose whether to sample from the training population with or without replacement
Definition: rf_common.hxx:397
RandomForestOptions & predict_weighted()
weight each tree's vote by the number of samples in the leaf node it reaches
Definition: rf_common.hxx:445
Base class for, and view to, vigra::MultiArray. 
Definition: multi_array.hxx:704
Options object for the random forest. 
Definition: rf_common.hxx:170
RandomForestOptions & use_stratification(RF_OptionTag in)
specify the stratification strategy
Definition: rf_common.hxx:374
RandomForestOptions & tree_count(unsigned int in)
Definition: rf_common.hxx:500
const_iterator end() const 
Definition: array_vector.hxx:237
ProblemSpec()
set default values (i.e. all values are marked as not yet set)
Definition: rf_common.hxx:803
RF_OptionTag
Definition: rf_common.hxx:140