36 #ifndef VIGRA_RANDOM_FOREST_DT_HXX 
   37 #define VIGRA_RANDOM_FOREST_DT_HXX 
   42 #include "vigra/multi_array.hxx" 
   43 #include "vigra/mathutil.hxx" 
   44 #include "vigra/metaprogramming.hxx" 
   45 #include "vigra/array_vector.hxx" 
   46 #include "vigra/sized_int.hxx" 
   47 #include "vigra/matrix.hxx" 
   48 #include "vigra/random.hxx" 
   49 #include "vigra/functorexpression.hxx" 
   52 #include "rf_common.hxx" 
   53 #include "rf_visitors.hxx" 
   54 #include "rf_nodeproxy.hxx" 
   86     typedef Int32 TreeInt;
 
   88     ArrayVector<TreeInt>  topology_;
 
   89     ArrayVector<double>   parameters_;
 
   91     ProblemSpec<> ext_param_;
 
   92     unsigned int classCount_;
 
   98     DecisionTree(ProblemSpec<T> ext_param)
 
  100         ext_param_(ext_param),
 
  101         classCount_(ext_param.class_count_)
 
  106     void reset(
unsigned int classCount = 0)
 
  109             classCount_ = classCount;
 
  122     template <  
class U, 
class C,
 
  129     void learn(     MultiArrayView<2, U, C> 
const      & features,
 
  130                     MultiArrayView<2, U2, C2> 
const    & labels,
 
  131                     StackEntry_t 
const &                 stack_entry,
 
  136     template <  
class U, 
class C,
 
  143     void continueLearn(   MultiArrayView<2, U, C> 
const       & features,
 
  144                           MultiArrayView<2, U2, C2> 
const     & labels,
 
  145                           StackEntry_t 
const &                  stack_entry,
 
  151                           int                                   garbaged_child=-1);
 
  154     inline bool isLeafNode(TreeInt in)
 const 
  156         return (in & LeafNodeTag) == LeafNodeTag;
 
  164     template<
class U, 
class C, 
class Visitor_t>
 
  165     TreeInt getToLeaf(MultiArrayView<2, U, C> 
const & features, 
 
  166                       Visitor_t  & visitor)
 const 
  169         while(!isLeafNode(topology_[index]))
 
  171             visitor.visit_internal_node(*
this, index, topology_[index],features);
 
  172             switch(topology_[index])
 
  174                 case i_ThresholdNode:
 
  176                     Node<i_ThresholdNode> 
 
  177                                 node(topology_, parameters_, index);
 
  178                     index = node.next(features);
 
  181                 case i_HyperplaneNode:
 
  183                     Node<i_HyperplaneNode> 
 
  184                                 node(topology_, parameters_, index);
 
  185                     index = node.next(features);
 
  188                 case i_HypersphereNode:
 
  190                     Node<i_HypersphereNode> 
 
  191                                 node(topology_, parameters_, index);
 
  192                     index = node.next(features);
 
  200                                 node(topology_, parameters, index);
 
  201                     index = node.next(features);
 
  205                     vigra_fail(
"DecisionTree::getToLeaf():" 
  206                                "encountered unknown internal Node Type");
 
  209         visitor.visit_external_node(*
this, index, topology_[index],features);
 
  217     template<
class Visitor_t>
 
  218     void traverse_mem_order(Visitor_t visitor)
 const 
  221         while(index < topology_.size())
 
  223             if(isLeafNode(topology_[index]))
 
  226                     .visit_external_node(*
this, index, topology_[index]);
 
  231                     ._internal_node(*
this, index, topology_[index]);
 
  236     template<
class Visitor_t>
 
  237     void traverse_post_order(Visitor_t visitor,  TreeInt  = 2)
 const 
  239         typedef TinyVector<double, 2> Entry; 
 
  240         std::vector<Entry > stack;
 
  241         std::vector<double> result_stack;
 
  242         stack.push_back(Entry(2, 0));
 
  244         while(!stack.empty())
 
  246             addr = stack.back()[0];
 
  247             NodeBase node(topology_, parameters_, stack.back()[0]);
 
  248             if(stack.back()[1] == 1)
 
  251                 double leftRes = result_stack.back();
 
  252                 double rightRes = result_stack.back();
 
  253                 result_stack.pop_back();
 
  254                 result_stack.pop_back();
 
  255                 result_stack.push_back(rightRes+ leftRes);
 
  256                 visitor.visit_internal_node(*
this, 
 
  263                 if(isLeafNode(node.typeID()))
 
  265                     visitor.visit_external_node(*
this, 
 
  270                     result_stack.push_back(node.weights());
 
  275                     stack.push_back(Entry(node.child(0), 0));
 
  276                     stack.push_back(Entry(node.child(1), 0));
 
  284     template<
class U, 
class C>
 
  285     TreeInt getToLeaf(MultiArrayView<2, U, C> 
const & features)
 const 
  288         return getToLeaf(features, stop);
 
  292     template <
class U, 
class C>
 
  293     ArrayVector<double>::iterator
 
  294     predict(MultiArrayView<2, U, C> 
const & features)
 const 
  296         TreeInt nodeindex = getToLeaf(features);
 
  297         switch(topology_[nodeindex])
 
  299             case e_ConstProbNode:
 
  300                 return Node<e_ConstProbNode>(topology_, 
 
  302                                              nodeindex).prob_begin();
 
  306             case e_LogRegProbNode:
 
  307                 return Node<e_LogRegProbNode>(topology_, 
 
  309                                               nodeindex).prob_begin();
 
  312                 vigra_fail(
"DecisionTree::predict() :" 
  313                            " encountered unknown external Node Type");
 
  315         return ArrayVector<double>::iterator();
 
  320     template <
class U, 
class C>
 
  321     Int32 predictLabel(MultiArrayView<2, U, C> 
const & features)
 const 
  323         ArrayVector<double>::const_iterator weights = predict(features);
 
  324         return argMax(weights, weights+classCount_) - weights;
 
  330 template <  
class U, 
class C,
 
  337 void DecisionTree::learn(   MultiArrayView<2, U, C> 
const       & features,
 
  338                             MultiArrayView<2, U2, C2> 
const     & labels,
 
  339                             StackEntry_t 
const &                  stack_entry,
 
  346     topology_.reserve(256);
 
  347     parameters_.reserve(256);
 
  348     topology_.push_back(features.shape(1));
 
  349     topology_.push_back(classCount_);
 
  350     continueLearn(features,labels,stack_entry,split,stop,visitor,randint);
 
  353 template <  
class U, 
class C,
 
  360 void DecisionTree::continueLearn(   MultiArrayView<2, U, C> 
const       & features,
 
  361                             MultiArrayView<2, U2, C2> 
const     & labels,
 
  362                             StackEntry_t 
const &                  stack_entry,
 
  370     std::vector<StackEntry_t> stack;
 
  372     ArrayVector<StackEntry_t> child_stack_entry(2, stack_entry);
 
  373     stack.push_back(stack_entry);
 
  374     size_t last_node_pos = 0;
 
  375     StackEntry_t top=stack.back();
 
  377     while(!stack.empty())
 
  385         child_stack_entry[0].reset();
 
  386         child_stack_entry[1].reset();
 
  396             NodeID = split.makeTerminalNode(features, 
 
  403             NodeID = split.findBestSplit(features, 
 
  413         visitor.visit_after_split(*
this, split, top, 
 
  414                                   child_stack_entry[0], 
 
  415                                   child_stack_entry[1],
 
  423         last_node_pos = topology_.size();
 
  424         if(top.leftParent != StackEntry_t::DecisionTreeNoParent)
 
  428                      top.leftParent).child(0) = last_node_pos;
 
  430         else if(top.rightParent != StackEntry_t::DecisionTreeNoParent)
 
  434                      top.rightParent).child(1) = last_node_pos;
 
  441         if(!isLeafNode(NodeID))
 
  443             child_stack_entry[0].leftParent = topology_.size();
 
  444             child_stack_entry[1].rightParent = topology_.size();    
 
  445             child_stack_entry[0].rightParent = -1;
 
  446             child_stack_entry[1].leftParent = -1;
 
  447             stack.push_back(child_stack_entry[0]);
 
  448             stack.push_back(child_stack_entry[1]);
 
  453         NodeBase node(split.createNode(), topology_, parameters_ );
 
  454         ignore_argument(node);
 
  456     if(garbaged_child!=-1)
 
  458         Node<e_ConstProbNode>(topology_,parameters_,garbaged_child).copy(Node<e_ConstProbNode>(topology_,parameters_,last_node_pos));
 
  460         int last_parameter_size = Node<e_ConstProbNode>(topology_,parameters_,garbaged_child).parameters_size();
 
  461         topology_.resize(last_node_pos);
 
  462         parameters_.resize(parameters_.size() - last_parameter_size);
 
  464         if(top.leftParent != StackEntry_t::DecisionTreeNoParent)
 
  467                      top.leftParent).child(0) = garbaged_child;
 
  468         else if(top.rightParent != StackEntry_t::DecisionTreeNoParent)
 
  471                      top.rightParent).child(1) = garbaged_child;
 
  479 #endif //VIGRA_RANDOM_FOREST_DT_HXX 
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int 
Definition: sized_int.hxx:175
Iterator argMax(Iterator first, Iterator last)
Find the maximum element in a sequence. 
Definition: algorithm.hxx:96
detail::SelectIntegerType< 32, detail::UnsignedIntTypes >::type UInt32
32-bit unsigned int 
Definition: sized_int.hxx:183
Definition: rf_visitors.hxx:234