12 #ifndef VIGRA_RANDOM_FOREST_RIDGE_SPLIT_H 
   13 #define VIGRA_RANDOM_FOREST_RIDGE_SPLIT_H 
   15 #include "../sampling.hxx" 
   16 #include "rf_split.hxx" 
   17 #include "rf_nodeproxy.hxx" 
   18 #include "../regression.hxx" 
   // Debug print helpers (not used anywhere in the visible code).
   // outm(v):  writes "<expression text>: <value>" and then std::endl (newline + flush).
   // outm2(v): writes "<expression text>: <value>, " with no newline.
   // Both expansions already end in ';', so a trailing ';' at the call site is redundant.
   20 #define outm(v) std::cout << (#v) << ": " << (v) << std::endl; 
   21 #define outm2(v) std::cout << (#v) << ": " << (v) << ", "; 
   // RidgeSplit: split functor that builds oblique hyperplane splits
   // (i_HyperplaneNode) from ridge-regression coefficients computed on a random
   // subset of feature columns (see findBestSplit).
   //   ColumnDecisionFunctor - threshold-search functor stored as bgfunc; its
   //                           min_threshold_, min_index_ and bestCurrentCounts
   //                           are read in findBestSplit (e.g.
   //                           BestGiniOfColumn<GiniCriterion> for GiniRidgeSplit).
   //   Tag                   - forwarded to SplitBase; defaults to ClassificationTag.
   // NOTE: this listing is a Doxygen scrape; braces and some lines are missing.
   70 template<
class ColumnDecisionFunctor, 
class Tag = ClassificationTag>

   71 class RidgeSplit: 
public SplitBase<Tag>

   // Shorthand for the base class (gives access to ext_param_, t_data, p_data).
   76     typedef SplitBase<Tag> SB;

   // Feature column indices; findBestSplit shuffles them so that the first
   // actual_mtry_ entries are the candidate columns for the current split.
   78     ArrayVector<Int32>          splitColumns;

   // Threshold search applied to the hyperplane projections.
   79     ColumnDecisionFunctor       bgfunc;

   // Per-feature buffers, sized to the column count in set_external_parameters.
   // min_gini_ is read by minGini(); min_indices_ and min_thresholds_ are not
   // written by any visible code.
   82     ArrayVector<double>         min_gini_;

   83     ArrayVector<std::ptrdiff_t> min_indices_;

   84     ArrayVector<double>         min_thresholds_;

   // If true, candidate columns are divided by their estimated std before the
   // ridge fit (otherwise the divisor stays 1.0).
   89     bool            m_bDoScalingInTraining;

   // If true, the intercept used to score each ridge penalty is taken from
   // bgfunc's best threshold instead of from the column means.
   90     bool            m_bDoBestLambdaBasedOnGini;

   // Constructor initializer list (the constructor's name line is not visible):
   // both options default to true.
   93     :m_bDoScalingInTraining(true),

   94     m_bDoBestLambdaBasedOnGini(true)
 
   // Returns min_gini_[bestSplitIndex], the Gini value recorded for the best
   // split. NOTE(review): bestSplitIndex is not declared in the visible lines;
   // presumably a member set elsewhere - confirm.
   98     double minGini()
const 
  100         return min_gini_[bestSplitIndex];
 
  // Returns the feature column index splitColumns[bestSplitIndex].
  // NOTE(review): bestSplitIndex is not declared in the visible lines - confirm
  // where it is set.
  103     int bestSplitColumn()
const 
  105         return splitColumns[bestSplitIndex];
 
  // Mutable access to m_bDoScalingInTraining (default true). When false,
  // findBestSplit only centers the candidate columns instead of also dividing
  // them by their std.
  108     bool& doScalingInTraining()

  109     { 
return m_bDoScalingInTraining; }
 
  // Mutable access to m_bDoBestLambdaBasedOnGini (default true). When true,
  // findBestSplit scores each ridge penalty using bgfunc's best threshold as
  // the intercept.
  111     bool& doBestLambdaBasedOnGini()

  112     { 
return m_bDoBestLambdaBasedOnGini; }
 
  // Forwards the problem specification to bgfunc and sizes splitColumns,
  // min_gini_, min_indices_ and min_thresholds_ to in.column_count_ entries.
  // NOTE: the template header, the braces and the body of the for-loop below
  // are missing from this listing. Presumably the loop fills
  // splitColumns[k] = k - confirm against the full source.
  115             void set_external_parameters(ProblemSpec<T> 
const & in)

  118         bgfunc.set_external_parameters(in);

  119         int featureCount_ = in.column_count_;

  120         splitColumns.resize(featureCount_);

  121         for(
int k=0; k<featureCount_; ++k)

  123         min_gini_.resize(featureCount_);

  124         min_indices_.resize(featureCount_);

  125         min_thresholds_.resize(featureCount_);
 
  // findBestSplit: fits a ridge-regression hyperplane on a random subset of
  // feature columns and splits the region along it.
  // NOTE: excerpt only - the signature line, most braces, some declarations
  // (e.g. nNumClasses, nMaxClass, nCounter, nOOBindx, region_gini_) and
  // several branch bodies are missing from this listing.
  // Visible steps:
  //  1. Recount the region's class counts if they do not sum to region.size().
  //  2. Early-out test (the action taken is not visible).
  //  3. Shuffle splitColumns so that the first actual_mtry_ entries are the
  //     candidate columns.
  //  4. Build a 0/1 label column.
  //  5. Center (and optionally scale) the candidate columns into xtrain.
  //  6. Try 11 ridge penalties (10^-5 .. 10^5) and keep the one with the most
  //     correctly classified out-of-bag (OOB) samples.
  //  7. Store normalized weights and the column list in an i_HyperplaneNode,
  //     project the samples, and let bgfunc choose the intercept.
  //  8. Fill childRegions[0]/[1] with counts, sample ranges, rules and OOB ranges.
  // Returns i_HyperplaneNode.
  129     template<
class T, 
class C, 
class T2, 
class C2, 
class Region, 
class Random>

  131                       MultiArrayView<2, T2, C2>  multiClassLabels,

  133                       ArrayVector<Region>& childRegions,

  138     typedef typename MultiArrayView <2, T, C>::difference_type fShape;

  // Step 1: if the cached class counts do not add up to the region size,
  // recount them from multiClassLabels.
  // NOTE(review): std::accumulate starts from the int literal 0, so the sum is
  // accumulated in int. That is fine only while the counts are whole numbers.
  144         if(std::accumulate(region.classCounts().begin(),

  145                            region.classCounts().end(), 0) != region.size())

  147             RandomForestClassCounter<   MultiArrayView<2,T2, C2>, 

  148                                         ArrayVector<double> >

  149                 counter(multiClassLabels, region.classCounts());

  150             std::for_each(  region.begin(), region.end(), counter);

  151             region.classCountsIsValid = 
true;

  // Step 2: do not split when region_gini_ (computed on a line not shown) is 0,
  // when the region has fewer samples than actual_mtry_, or when it has fewer
  // than 2 OOB samples. The branch body is not visible (presumably it makes a
  // terminal node).
  158         if(region_gini_ == 0 || region.size() < SB::ext_param_.actual_mtry_ || region.oob_size() < 2)

  // Step 3: partial shuffle - swap entry ii with a random entry at or after ii,
  // so the first actual_mtry_ entries become a random subset of the columns.
  // This assumes randint(k) returns an index in [0, k) and that splitColumns
  // has features.shape(1) entries.
  162     for(
int ii = 0; ii < SB::ext_param_.actual_mtry_; ++ii)

  163         std::swap(splitColumns[ii], 

  164             splitColumns[ii+ randint(features.shape(1) - ii)]);

  // Step 4: 0/1 labels. nNumClasses counts the classes present in the region.
  // The loop over classCounts() tracks the largest count. nMaxClass (assigned on
  // a line not shown) is presumably the most frequent class, which is mapped to
  // label 1 and every other class to 0 (one-vs-rest). On the other branch (its
  // condition is not shown) the labels are copied unchanged.
  167     MultiArray<2, T2> labels(lShape(multiClassLabels.shape(0),1));

  170       for(
int n=0; n<static_cast<int>(region.classCounts().size()); n++)

  171         nNumClasses+=((region.classCounts()[n]>0) ? 1:0);

  177         int nMaxClassCounts=0;

  178         for(
int n=0; n<static_cast<int>(region.classCounts().size()); n++)

  182           if(region.classCounts()[n]>nMaxClassCounts)

  184         nMaxClassCounts=region.classCounts()[n];

  190         for(
int n=0; n<multiClassLabels.shape(0); n++)

  191           labels(n,0)=((multiClassLabels(n,0)==nMaxClass) ? 1:0);

  194         labels=multiClassLabels;

  // Step 5: for each candidate column m, compute the mean over the region's
  // samples and (if m_bDoScalingInTraining) a std estimate. Store both in
  // meanMatrix/stdMatrix and write (x - mean) / std into column m of xtrain.
  // cVector is presumably bound to column splitColumns[m] on a line not shown.
  228     MultiArrayView<2, T, C> cVector;

  229     MultiArray<2, T> xtrain(fShape(region.size(),SB::ext_param_.actual_mtry_));

  231     MultiArray<2, double> regrLabels(dShape(region.size(),1));

  234     MultiArray<2, double> meanMatrix(dShape(SB::ext_param_.actual_mtry_,1));

  235     MultiArray<2, double> stdMatrix(dShape(SB::ext_param_.actual_mtry_,1));

  236     for(
int m=0; m<SB::ext_param_.actual_mtry_; m++)

  241         double dCurrFeatureColumnMean=0.0;

  // NOTE(review): the std accumulator starts at 1.0, which is the right divisor
  // when scaling is off. No visible line resets it to 0.0 before the
  // squared-deviation sum below, so with scaling on the result is
  // sqrt((1 + sum of squared deviations) / (n - 1)). Confirm this is intended.
  242         double dCurrFeatureColumnStd=1.0; 

  245         for(
int n=0; n<region.size(); n++)

  246           dCurrFeatureColumnMean+=cVector[region[n]];

  247         dCurrFeatureColumnMean/=region.size();

  249         if(m_bDoScalingInTraining)

  251           for(
int n=0; n<region.size(); n++)

  253               dCurrFeatureColumnStd+=

  254             (cVector[region[n]]-dCurrFeatureColumnMean)*(cVector[region[n]]-dCurrFeatureColumnMean);

  257           dCurrFeatureColumnStd=
sqrt(dCurrFeatureColumnStd/(region.size()-1));

  260         stdMatrix(m,0)=dCurrFeatureColumnStd;

  262         meanMatrix(m,0)=dCurrFeatureColumnMean;

  266         for(
int n=0; n<region.size(); n++)

  267             xtrain(n,m)=(cVector[region[n]]-dCurrFeatureColumnMean)/dCurrFeatureColumnStd;

  // Regression targets: -1 where the label is 0, +1 otherwise.
  272     for(
int n=0; n<region.size(); n++)

  277         regrLabels(n,0)=((labels[region[n]]==0) ? -1:1);

  // Step 6: penalties 10^-5 .. 10^5 (nCounter is declared on a line not shown).
  // regrCoef holds one coefficient column per penalty. It is presumably filled
  // by ridgeRegressionSeries on lines not shown - confirm.
  280     MultiArray<2, double> dLambdas(dShape(11,1));

  282     for(
int nLambda=-5; nLambda<=5; nLambda++)

  283         dLambdas[nCounter++]=pow(10.0,nLambda);

  285     MultiArray<2, double> regrCoef(dShape(SB::ext_param_.actual_mtry_,11));

  288     double dMaxRidgeSum=NumericTraits<double>::min();

  // NOTE(review): dCurrRidgeSum has no initializer, and no visible line resets
  // it before each penalty's accumulation below. Confirm that it is zeroed per
  // penalty on a line not shown; otherwise the score is indeterminate or
  // cumulative.
  289     double dCurrRidgeSum;

  290     int nMaxRidgeSumAtLambdaInd=0;

  292     for(
int nLambdaInd=0; nLambdaInd<11; nLambdaInd++)

  // Project every OOB sample onto this penalty's coefficients, using the raw
  // (uncentered, unscaled) values of columns splitColumns[m].
  300         MultiArray<2, double> dDistanceFromHyperplane(dShape(features.shape(0),1));

  302         for(
int n=0; n<region.oob_size(); n++)

  304           dDistanceFromHyperplane(region.oob_begin()[n],0)=0.0;

  305           for (
int m=0; m<SB::ext_param_.actual_mtry_; m++)

  307             dDistanceFromHyperplane(region.oob_begin()[n],0)+=

  308               features(region.oob_begin()[n],splitColumns[m])*regrCoef(m,nLambdaInd);

  // Intercept: bgfunc's best threshold over the OOB projections when
  // m_bDoBestLambdaBasedOnGini is set. On the other branch (its 'else' is not
  // shown) it is the sum of mean(m) * coef(m).
  312         double dCurrIntercept=0.0;

  313         if(m_bDoBestLambdaBasedOnGini)

  316           bgfunc(dDistanceFromHyperplane,

  318               region.oob_begin(), region.oob_end(), 

  319               region.classCounts());

  320           dCurrIntercept=bgfunc.min_threshold_;

  324           for (
int m=0; m<SB::ext_param_.actual_mtry_; m++)

  325             dCurrIntercept+=meanMatrix(m,0)*regrCoef(m,nLambdaInd);

  // Score = number of OOB samples whose prediction (projection >= intercept
  // gives 1) matches their label. The strict '>' keeps the first penalty
  // reaching the best score.
  328         for(
int n=0; n<region.oob_size(); n++)

  331             int nClassPrediction=((dDistanceFromHyperplane(region.oob_begin()[n],0) >=dCurrIntercept) ? 1:0);

  332             dCurrRidgeSum+=((nClassPrediction == labels(region.oob_begin()[n],0)) ? 1:0);

  334         if(dCurrRidgeSum>dMaxRidgeSum)

  336             dMaxRidgeSum=dCurrRidgeSum;

  337             nMaxRidgeSumAtLambdaInd=nLambdaInd;

  // Step 7: build the hyperplane node from the best penalty's coefficients.
  343         Node<i_HyperplaneNode>   node(SB::ext_param_.actual_mtry_, SB::t_data, SB::p_data);

  // NOTE(review): the coefficients were fit on columns divided by stdMatrix.
  // Converting them to raw-feature units would normally divide by the std, but
  // here they are multiplied. The normalization also uses the norm of the
  // unmodified regrCoef column, not of dCoeffVector, so the weights are not
  // unit-length when the std differs from 1. Confirm both are intended.
  347         MultiArray<2, double> dCoeffVector(dShape(SB::ext_param_.actual_mtry_,1));

  348         for(
int n=0; n<SB::ext_param_.actual_mtry_; n++)

  349           dCoeffVector(n,0)=regrCoef(n,nMaxRidgeSumAtLambdaInd)*stdMatrix(n,0);

  352         double dVnorm=
columnVector(regrCoef,nMaxRidgeSumAtLambdaInd).norm();

  354         for(
int n=0; n<SB::ext_param_.actual_mtry_; n++)

  355             node.weights()[n]=dCoeffVector(n,0)/dVnorm;

  // column_data layout: [0] = number of columns, [1..actual_mtry_] = the
  // selected feature column indices.
  359         node.column_data()[0]=SB::ext_param_.actual_mtry_;

  360         for(
int n=0; n<SB::ext_param_.actual_mtry_; n++)

  361             node.column_data()[n+1]=splitColumns[n];

  // Project the in-bag samples and then the OOB samples onto node.weights().
  // NOTE(review): these loops read features(., m), i.e. the first actual_mtry_
  // columns. The penalty search above and node.column_data() use
  // splitColumns[m]. Confirm whether features(., splitColumns[m]) was intended.
  367         MultiArray<2, double> dDistanceFromHyperplane(dShape(features.shape(0),1));

  369         for(
int n=0; n<region.size(); n++)

  371             dDistanceFromHyperplane(region[n],0)=0.0;

  372             for (
int m=0; m<SB::ext_param_.actual_mtry_; m++)

  374               dDistanceFromHyperplane(region[n],0)+=

  375                features(region[n],m)*node.weights()[m];

  378         for(
int n=0; n<region.oob_size(); n++)

  380             dDistanceFromHyperplane(region.oob_begin()[n],0)=0.0;

  381             for (
int m=0; m<SB::ext_param_.actual_mtry_; m++)

  383               dDistanceFromHyperplane(region.oob_begin()[n],0)+=

  384             features(region.oob_begin()[n],m)*node.weights()[m];

  // bgfunc searches the in-bag projections. Its min_threshold_ becomes the
  // node intercept, bestCurrentCounts become the child class counts, and
  // min_index_ is the split position within the region's sample range.
  389         bgfunc(dDistanceFromHyperplane,

  391             region.begin(), region.end(), 

  392             region.classCounts());

  399     node.intercept()    = bgfunc.min_threshold_;

  // Step 8: fill the two child regions.
  402     childRegions[0].classCounts() = bgfunc.bestCurrentCounts[0];

  403     childRegions[1].classCounts() = bgfunc.bestCurrentCounts[1];

  404     childRegions[0].classCountsIsValid = 
true;

  405     childRegions[1].classCountsIsValid = 
true;

  // NOTE(review): both children append the identical rule entry (1, 1.0) -
  // confirm this is intended.
  408     childRegions[0].setRange(   region.begin()  , region.begin() + bgfunc.min_index_   );

  409     childRegions[0].rule = region.rule;

  410     childRegions[0].rule.push_back(std::make_pair(1, 1.0));

  411     childRegions[1].setRange(   region.begin() + bgfunc.min_index_       , region.end()    );

  412     childRegions[1].rule = region.rule;

  413     childRegions[1].rule.push_back(std::make_pair(1, 1.0));

  // Sort the OOB samples by projection and find the first one at or above the
  // intercept (the loop body after the 'if' is not shown; presumably 'break').
  // Samples before that index go to child 0, the rest to child 1.
  418       std::sort(region.oob_begin(), region.oob_end(), 

  419             SortSamplesByDimensions< MultiArray<2, double> > (dDistanceFromHyperplane, 0));

  423       for(nOOBindx=0; nOOBindx<region.oob_size(); nOOBindx++)

  425         if(dDistanceFromHyperplane(region.oob_begin()[nOOBindx],0)>=node.intercept())

  429       childRegions[0].set_oob_range(   region.oob_begin()  , region.oob_begin() + nOOBindx   );

  430       childRegions[1].set_oob_range(   region.oob_begin() + nOOBindx , region.oob_end() );

  436     return i_HyperplaneNode;
 
  446 #endif // VIGRA_RANDOM_FOREST_RIDGE_SPLIT_H 
static double impurity(Array const &hist, double total)
Definition: rf_split.hxx:443
RidgeSplit< BestGiniOfColumn< GiniCriterion > > GiniRidgeSplit
Definition: rf_ridge_split.hxx:442
MultiArrayView< 2, T, C > columnVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition: matrix.hxx:727
MultiArrayShape< actual_dimension >::type difference_type
Definition: multi_array.hxx:739
void set_external_parameters(ProblemSpec< T > const &in)
Definition: rf_split.hxx:112
int findBestSplit(MultiArrayView< 2, T, C >, MultiArrayView< 2, T2, C2 >, Region, ArrayVector< Region >, Random)
Definition: rf_split.hxx:150
bool ridgeRegressionSeries(MultiArrayView< 2, T, C1 > const &A, MultiArrayView< 2, T, C2 > const &b, MultiArrayView< 2, T, C3 > &x, Array const &lambda)
Definition: regression.hxx:304
bool closeAtTolerance(T1 l, T2 r, typename PromoteTraits< T1, T2 >::Promote epsilon)
Tolerance based floating-point equality. 
Definition: mathutil.hxx:1638
int makeTerminalNode(MultiArrayView< 2, T, C >, MultiArrayView< 2, T2, C2 >, Region &region, Random)
Definition: rf_split.hxx:168
SquareRootTraits< FixedPoint< IntBits, FracBits > >::SquareRootResult sqrt(FixedPoint< IntBits, FracBits > v)
square root. 
Definition: fixedpoint.hxx:616