36 #ifndef VIGRA_RANDOM_FOREST_NP_HXX 
   37 #define VIGRA_RANDOM_FOREST_NP_HXX 
   42 #include "vigra/mathutil.hxx" 
   43 #include "vigra/array_vector.hxx" 
   44 #include "vigra/sized_int.hxx" 
   45 #include "vigra/matrix.hxx" 
   46 #include "vigra/random.hxx" 
   47 #include "vigra/functorexpression.hxx" 
   58     AllColumns          = 0x00000000,
 
   59     ToBePrunedTag       = 0x80000000,
 
   60     LeafNodeTag         = 0x40000000,
 
   64     i_HypersphereNode   = 2,
 
   65     e_ConstProbNode     = 0 | LeafNodeTag,
 
   66     e_LogRegProbNode    = 1 | LeafNodeTag
 
   93     typedef T_Container_type::iterator          Topology_type;
 
   94     typedef P_Container_type::iterator          Parameter_type;
 
   97     mutable Topology_type                       topology_;
 
  100     mutable Parameter_type                      parameters_;
 
  101     int                                         parameter_size_ ;
 
  141     INT 
const &          
typeID()
 const 
  161         return topology_ + 4 ;
 
  177             return featureCount_;
 
  197     Topology_type   topology_end()
 const 
  201     int          topology_size()
 const 
  203         return topology_size_;
 
  211     Parameter_type  parameters_end()
 const 
  216     int          parameters_size()
 const 
  218         return parameter_size_;
 
  243         vigra_precondition(topology_size_==o.topology_size_,
"Cannot copy nodes of different sizes");
 
  244         vigra_precondition(featureCount_==o.featureCount_,
"Cannot copy nodes with different feature count");
 
  245         vigra_precondition(classCount_==o.classCount_,
"Cannot copy nodes with different class counts");
 
  246         vigra_precondition(parameters_size() ==o.parameters_size(),
"Cannot copy nodes with different parameter sizes");
 
  258                     topology_   (const_cast<Topology_type>(topology.begin()+ n)),
 
  260                     parameters_  (const_cast<Parameter_type>(parameter.begin() + 
parameter_addr())),
 
  262                     featureCount_(topology[0]),
 
  263                     classCount_(topology[1]),
 
  278                     topology_   (const_cast<Topology_type>(topology.begin()+ n)),
 
  279                     topology_size_(tLen),
 
  280                     parameters_  (const_cast<Parameter_type>(parameter.begin() + 
parameter_addr())),
 
  281                     parameter_size_(pLen),
 
  282                     featureCount_(topology[0]),
 
  283                     classCount_(topology[1]),
 
  296                     topology_   (node.topology_),
 
  297                     topology_size_(tLen),
 
  298                     parameters_  (node.parameters_),
 
  299                     parameter_size_(pLen),
 
  300                     featureCount_(node.featureCount_),
 
  301                     classCount_(node.classCount_),
 
  321                     topology_size_(tLen),
 
  322                     parameter_size_(pLen),
 
  323                     featureCount_(topology[0]),
 
  324                     classCount_(topology[1]),
 
  330         size_t n = topology.
size();
 
  331         for(
int ii = 0; ii < tLen; ++ii)
 
  332             topology.push_back(0);
 
  335         topology_           =   topology.
begin()+ n;
 
  341         for(
int ii = 0; ii < pLen; ++ii)
 
  342             parameter.push_back(0);
 
  360                     topology_size_(toCopy.topology_size()),
 
  361                     parameter_size_(toCopy.parameters_size()),
 
  362                     featureCount_(topology[0]),
 
  363                     classCount_(topology[1]),
 
  369         size_t n            = topology.
size();
 
  370         for(
int ii = 0; ii < toCopy.topology_size(); ++ii)
 
  373         topology_           =   topology.
begin()+ n;
 
  375         for(
int ii = 0; ii < toCopy.parameters_size(); ++ii)
 
  383 template<NodeTags NodeType>
 
  387 class Node<i_ThresholdNode>
 
  397     Node(   BT::T_Container_type &   topology,
 
  398             BT::P_Container_type &   param)
 
  399                 :   BT(5,2,topology, param)
 
  401         BT::typeID() = i_ThresholdNode;
 
  404     Node(   BT::T_Container_type 
const     &   topology,
 
  405             BT::P_Container_type 
const     &   param,
 
  407                 :   BT(5,2,topology, param, n)
 
  416         return BT::parameters_begin()[1];
 
  419     double const & threshold()
 const 
  421         return BT::parameters_begin()[1];
 
  426         return BT::column_data()[0];
 
  428     BT::INT 
const & column()
 const 
  430         return BT::column_data()[0];
 
  433     template<
class U, 
class C>
 
  434     BT::INT  next(MultiArrayView<2,U,C> 
const & feature)
 const 
  436         return (feature(0, column()) < threshold())? child(0):child(1);
 
  442 class Node<i_HyperplaneNode>
 
  452                     BT::T_Container_type    &   topology,
 
  453                     BT::P_Container_type    &   split_param)
 
  454                 :   BT(nCol + 5,nCol + 2,topology, split_param)
 
  456         BT::typeID() = i_HyperplaneNode;
 
  459     Node(           BT::T_Container_type  
const  &   topology,
 
  460                     BT::P_Container_type  
const  &   split_param,
 
  462                 :   NodeBase(5 , 2,topology, split_param, n)
 
  465         BT::topology_size_ += BT::column_data()[0]== AllColumns ?
 
  467                                     :   BT::column_data()[0];
 
  468         BT::parameter_size_ += BT::columns_size();
 
  475         BT::topology_size_ += BT::column_data()[0]== AllColumns ?
 
  477                                     :   BT::column_data()[0];
 
  478         BT::parameter_size_ += BT::columns_size();
 
  482     double const & intercept()
 const 
  484         return BT::parameters_begin()[1];
 
  488         return BT::parameters_begin()[1];
 
  491     BT::Parameter_type weights()
 const 
  493         return BT::parameters_begin()+2;
 
  496     BT::Parameter_type weights()
 
  498         return BT::parameters_begin()+2;
 
  502     template<
class U, 
class C>
 
  503     BT::INT next(MultiArrayView<2,U,C> 
const & feature)
 const 
  505         double result = -1 * intercept();
 
  506         if(*(BT::column_data()) == AllColumns)
 
  508             for(
int ii = 0; ii < BT::columns_size(); ++ii)
 
  510                 result +=feature[ii] * weights()[ii];
 
  515             for(
int ii = 0; ii < BT::columns_size(); ++ii)
 
  517                 result +=feature[BT::columns_begin()[ii]] * weights()[ii];
 
  520         return result < 0 ? BT::child(0)
 
  528 class Node<i_HypersphereNode>
 
  538                     BT::T_Container_type    &   topology,
 
  539                     BT::P_Container_type    &   param)
 
  540                 :   NodeBase(nCol + 5,nCol + 1,topology, param)
 
  542         BT::typeID() = i_HypersphereNode;
 
  545     Node(           BT::T_Container_type  
const  &   topology,
 
  546                     BT::P_Container_type  
const  &  param,
 
  548                 :   NodeBase(5, 1,topology, param, n)
 
  550         BT::topology_size_ += BT::column_data()[0]== AllColumns ?
 
  552                                     :   BT::column_data()[0];
 
  553         BT::parameter_size_ += BT::columns_size();
 
  559         BT::topology_size_ += BT::column_data()[0]== AllColumns ?
 
  561                                     :   BT::column_data()[0];
 
  562         BT::parameter_size_ += BT::columns_size();
 
  566     double const & squaredRadius()
 const 
  568         return BT::parameters_begin()[1];
 
  571     double& squaredRadius()
 
  573         return BT::parameters_begin()[1];
 
  576     BT::Parameter_type center()
 const 
  578         return BT::parameters_begin()+2;
 
  581     BT::Parameter_type center()
 
  583         return BT::parameters_begin()+2;
 
  586     template<
class U, 
class C>
 
  587     BT::INT next(MultiArrayView<2,U,C> 
const & feature)
 const 
  589         double result = -1 * squaredRadius();
 
  590         if(*(BT::column_data()) == AllColumns)
 
  592             for(
int ii = 0; ii < BT::columns_size(); ++ii)
 
  594                 result += (feature[ii] - center()[ii])*
 
  595                           (feature[ii] - center()[ii]);
 
  600             for(
int ii = 0; ii < BT::columns_size(); ++ii)
 
  602                 result += (feature[BT::columns_begin()[ii]] - center()[ii])*
 
  603                           (feature[BT::columns_begin()[ii]] - center()[ii]);
 
  606         return result < 0 ? BT::child(0)
 
  626 class Node<e_ConstProbNode>
 
  636                 BT(2,topology[1]+1, topology, param)
 
  639         BT::typeID() = e_ConstProbNode;
 
  646                 :   
BT(2, topology[1]+1,topology, param, n)
 
  651         :   
BT(2, node_.classCount_ +1, node_) 
 
  653     BT::Parameter_type  prob_begin()
 const 
  655         return BT::parameters_begin()+1;
 
  657     BT::Parameter_type  prob_end()
 const 
  659         return prob_begin() + prob_size();
 
  661     int prob_size()
 const 
  663         return BT::classCount_;
 
  668 class Node<e_LogRegProbNode>;
 
  672 #endif //RF_nodeproxy 
Topology_type column_data() const 
Definition: rf_nodeproxy.hxx:159
const_iterator begin() const 
Definition: array_vector.hxx:223
NodeBase()
Definition: rf_nodeproxy.hxx:237
NodeBase(T_Container_type const &topology, P_Container_type const ¶meter, INT n)
Definition: rf_nodeproxy.hxx:254
Topology_type columns_begin() const 
Definition: rf_nodeproxy.hxx:167
INT & child(Int32 l)
Definition: rf_nodeproxy.hxx:224
int columns_size() const 
Definition: rf_nodeproxy.hxx:174
NodeBase(int tLen, int pLen, NodeBase &node)
Definition: rf_nodeproxy.hxx:292
Topology_type columns_end() const 
Definition: rf_nodeproxy.hxx:184
NodeBase(int tLen, int pLen, T_Container_type const &topology, P_Container_type const ¶meter, INT n)
Definition: rf_nodeproxy.hxx:272
Definition: rf_nodeproxy.hxx:87
bool data() const 
Definition: rf_nodeproxy.hxx:128
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int 
Definition: sized_int.hxx:175
Parameter_type parameters_begin() const 
Definition: rf_nodeproxy.hxx:207
INT & typeID()
Definition: rf_nodeproxy.hxx:136
NodeBase(int tLen, int pLen, T_Container_type &topology, P_Container_type ¶meter)
Definition: rf_nodeproxy.hxx:316
INT & parameter_addr()
Definition: rf_nodeproxy.hxx:148
size_type size() const 
Definition: array_vector.hxx:358
INT const & child(Int32 l) const 
Definition: rf_nodeproxy.hxx:231
Topology_type topology_begin() const 
Definition: rf_nodeproxy.hxx:193
double & weights()
Definition: rf_nodeproxy.hxx:115
NodeBase(NodeBase const &toCopy, T_Container_type &topology, P_Container_type ¶meter)
Definition: rf_nodeproxy.hxx:356