36 #ifndef VIGRA_RF_PREPROCESSING_HXX 
   37 #define VIGRA_RF_PREPROCESSING_HXX 
   40 #include <vigra/mathutil.hxx> 
   41 #include "rf_common.hxx" 
   62 template<
class Tag, 
class LabelType, 
class T1, 
class C1, 
class T2, 
class C2>
 
   77         switch(options.mtry_switch_)
 
   80                 ext_param.actual_mtry_ =
 
   82                             std::sqrt(
double(ext_param.column_count_))
 
   87                 ext_param.actual_mtry_ =
 
   88                     int(1+(
std::log(
double(ext_param.column_count_))
 
   92                 ext_param.actual_mtry_ =
 
   93                     options.mtry_func_(ext_param.column_count_);
 
   96                 ext_param.actual_mtry_ = ext_param.column_count_;
 
   99                 ext_param.actual_mtry_ =
 
  103         switch(options.training_set_calc_switch_)
 
  106                 ext_param.actual_msample_ =
 
  107                     options.training_set_size_;
 
  109             case RF_PROPORTIONAL:
 
  110                 ext_param.actual_msample_ =
 
  111                     static_cast<int>(
std::ceil(options.training_set_proportion_ *
 
  112                                                ext_param.row_count_));
 
  115                 ext_param.actual_msample_ =
 
  116                     options.training_set_func_(ext_param.row_count_);
 
  119                 vigra_precondition(1!= 1, 
"unexpected error");
 
  127     template<
unsigned int N, 
class T, 
class C>
 
  128     bool contains_nan(MultiArrayView<N, T, C> 
const & in)
 
  130         typedef typename MultiArrayView<N, T, C>::const_iterator Iter;
 
  131         Iter i = in.begin(), end = in.end();
 
  133             if(isnan(NumericTraits<T>::toRealPromote(*i)))
 
  140     template<
unsigned int N, 
class T, 
class C>
 
  141     bool contains_inf(MultiArrayView<N, T, C> 
const & in)
 
  143          if(!std::numeric_limits<T>::has_infinity)
 
  145         typedef typename MultiArrayView<N, T, C>::const_iterator Iter;
 
  146         Iter i = in.begin(), end = in.end();
 
  148             if(
abs(*i) == std::numeric_limits<T>::infinity())
 
  161 template<
class LabelType, 
class T1, 
class C1, 
class T2, 
class C2>
 
  162 class Processor<ClassificationTag, LabelType, T1, C1, T2, C2>
 
  165     typedef Int32 LabelInt;
 
  181         vigra_precondition(!detail::contains_nan(features), 
"RandomForest(): Feature matrix " 
  183         vigra_precondition(!detail::contains_nan(response), 
"RandomForest(): Response " 
  185         vigra_precondition(!detail::contains_inf(features), 
"RandomForest(): Feature matrix " 
  187         vigra_precondition(!detail::contains_inf(response), 
"RandomForest(): Response " 
  190         ext_param.column_count_  = features.
shape(1);
 
  191         ext_param.row_count_     = features.
shape(0);
 
  192         ext_param.problem_type_  = CLASSIFICATION;
 
  193         ext_param.used_          = 
true;
 
  194         intLabels_.reshape(response.
shape());
 
  197         if(ext_param.class_count_ == 0)
 
  201             std::set<T2>                    labelToInt;
 
  203                 labelToInt.insert(response(k,0));
 
  204             std::vector<T2> tmp_(labelToInt.begin(), labelToInt.end());
 
  205             ext_param.
classes_(tmp_.begin(), tmp_.end());
 
  209             if(std::find(ext_param.classes.
begin(), ext_param.classes.
end(), response(k,0)) == ext_param.classes.
end())
 
  211                 throw std::runtime_error(
"RandomForest(): invalid label in training data.");
 
  214                 intLabels_(k, 0) = std::find(ext_param.classes.
begin(), ext_param.classes.
end(), response(k,0))
 
  215                                     - ext_param.classes.
begin();
 
  218         if(ext_param.class_weights_.
size() == 0)
 
  221                 tmp(static_cast<std::size_t>(ext_param.class_count_),
 
  222                     NumericTraits<T2>::one());
 
  227         detail::fill_external_parameters(options, ext_param);
 
  230         strata_ = intLabels_;
 
  268 template<
class LabelType, 
class T1, 
class C1, 
class T2, 
class C2>
 
  269 class Processor<RegressionTag,LabelType, T1, C1, T2, C2>
 
  292         ext_param_(ext_param)
 
  295         ext_param.column_count_  = features.
shape(1);
 
  296         ext_param.row_count_     = features.
shape(0);
 
  297         ext_param.problem_type_  = REGRESSION;
 
  298         ext_param.used_          = 
true;
 
  299         detail::fill_external_parameters(options, ext_param);
 
  300         vigra_precondition(!detail::contains_nan(features), 
"Processor(): Feature Matrix " 
  302         vigra_precondition(!detail::contains_nan(response), 
"Processor(): Response " 
  304         vigra_precondition(!detail::contains_inf(features), 
"Processor(): Feature Matrix " 
  306         vigra_precondition(!detail::contains_inf(response), 
"Processor(): Response " 
  309         ext_param.response_size_ = response.
shape(1);
 
  310         ext_param.class_count_ = response_.shape(1);
 
  311         std::vector<T2> tmp_(ext_param.class_count_, 0);
 
  312             ext_param.
classes_(tmp_.begin(), tmp_.end());
 
  337 #endif //VIGRA_RF_PREPROCESSING_HXX 
ArrayVectorView< double > strata_prob()
Definition: rf_preprocessing.hxx:257
MultiArrayView< 2, LabelInt > response()
Definition: rf_preprocessing.hxx:243
Definition: rf_preprocessing.hxx:63
const difference_type & shape() const 
Definition: multi_array.hxx:1648
Definition: array_vector.hxx:76
const_iterator begin() const 
Definition: array_vector.hxx:223
problem specification class for the random forest. 
Definition: rf_common.hxx:538
MultiArrayView< 2, T1, C1 > & features()
Definition: rf_preprocessing.hxx:317
Main MultiArray class containing the memory management. 
Definition: multi_array.hxx:2474
std::ptrdiff_t MultiArrayIndex
Definition: multi_fwd.hxx:60
MultiArrayView< 2, T1, C1 > const & features()
Definition: rf_preprocessing.hxx:236
ProblemSpec & classes_(C_Iter begin, C_Iter end)
supply with class labels - 
Definition: rf_common.hxx:828
Definition: array_vector.hxx:58
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int 
Definition: sized_int.hxx:175
MultiArrayView< 2, T2, C2 > & response()
Definition: rf_preprocessing.hxx:324
TinyVector< MultiArrayIndex, N > type
Definition: multi_shape.hxx:272
ProblemSpec & class_weights(W_Iter begin, W_Iter end)
supply with class weights - 
Definition: rf_common.hxx:844
linalg::TemporaryMatrix< T > log(MultiArrayView< 2, T, C > const &v)
MultiArray< 2, int > & strata()
Definition: rf_preprocessing.hxx:331
ArrayVectorView< LabelInt > strata()
Definition: rf_preprocessing.hxx:250
FFTWComplex< R >::NormType abs(const FFTWComplex< R > &a)
absolute value (= magnitude) 
Definition: fftw3.hxx:1002
Options object for the random forest. 
Definition: rf_common.hxx:170
const_iterator end() const 
Definition: array_vector.hxx:237
size_type size() const 
Definition: array_vector.hxx:358
int ceil(FixedPoint< IntBits, FracBits > v)
rounding up. 
Definition: fixedpoint.hxx:675
int floor(FixedPoint< IntBits, FracBits > v)
rounding down. 
Definition: fixedpoint.hxx:667
SquareRootTraits< FixedPoint< IntBits, FracBits > >::SquareRootResult sqrt(FixedPoint< IntBits, FracBits > v)
square root. 
Definition: fixedpoint.hxx:616