37 #ifndef VIGRA_RF3_IMPEX_HDF5_HXX 
   38 #define VIGRA_RF3_IMPEX_HDF5_HXX 
   46 #include "random_forest_3/random_forest.hxx" 
   47 #include "random_forest_3/random_forest_common.hxx" 
   48 #include "random_forest_3/random_forest_visitors.hxx" 
   49 #include "hdf5impex.hxx" 
   57 static const char *
const rf_hdf5_ext_param     = 
"_ext_param";
 
   58 static const char *
const rf_hdf5_options       = 
"_options";
 
   59 static const char *
const rf_hdf5_topology      = 
"topology";
 
   60 static const char *
const rf_hdf5_parameters    = 
"parameters";
 
   61 static const char *
const rf_hdf5_tree          = 
"Tree_";
 
   62 static const char *
const rf_hdf5_version_group = 
".";
 
   63 static const char *
const rf_hdf5_version_tag   = 
"vigra_random_forest_version";
 
   64 static const double      rf_hdf5_version       =  0.1;
 
   70     rf_AllColumns          = 0x00000000,
 
   71     rf_ToBePrunedTag       = 0x80000000,
 
   72     rf_LeafNodeTag         = 0x40000000,
 
   74     rf_i_ThresholdNode     = 0,
 
   75     rf_i_HyperplaneNode    = 1,
 
   76     rf_i_HypersphereNode   = 2,
 
   77     rf_e_ConstProbNode     = 0 | rf_LeafNodeTag,
 
   78     rf_e_LogRegProbNode    = 1 | rf_LeafNodeTag
 
   81 static const unsigned int rf_tag_mask = 0xf0000000;
 
   82 static const unsigned int rf_type_mask = 0x00000003;
 
   83 static const unsigned int rf_zero_mask = 0xffffffff & ~rf_tag_mask & ~rf_type_mask;
 
   87     inline std::string get_cwd(HDF5File & h5context)
 
   89         return h5context.get_absolute_path(h5context.pwd());
 
   93 template <
typename FEATURES, 
typename LABELS>
 
   94 typename DefaultRF<FEATURES, LABELS>::type
 
   95 random_forest_import_HDF5(HDF5File & h5ctx, std::string 
const & pathname = 
"")
 
   97     typedef typename DefaultRF<FEATURES, LABELS>::type RF;
 
   98     typedef typename RF::Graph Graph;
 
   99     typedef typename RF::Node Node;
 
  100     typedef typename RF::SplitTests SplitTest;
 
  101     typedef typename LABELS::value_type LabelType;
 
  102     typedef typename RF::AccInputType AccInputType;
 
  103     typedef typename AccInputType::value_type AccValueType;
 
  107     if (pathname.size()) {
 
  108         cwd = detail::get_cwd(h5ctx);
 
  112     if (h5ctx.existsAttribute(rf_hdf5_version_group, rf_hdf5_version_tag)) {
 
  114         h5ctx.readAttribute(rf_hdf5_version_group, rf_hdf5_version_tag, version);
 
  115         vigra_precondition(version <= rf_hdf5_version, 
"random_forest_import_HDF5(): unexpected file format version.");
 
  120     size_t num_instances;
 
  125     MultiArray<1, LabelType> distinct_labels_marray;
 
  126     MultiArray<1, double> class_weights_marray;
 
  128     h5ctx.cd(rf_hdf5_ext_param);
 
  129     h5ctx.read(
"actual_msample_", msample);
 
  130     h5ctx.read(
"actual_mtry_", actual_mtry);
 
  131     h5ctx.read(
"class_count_", num_classes);
 
  132     h5ctx.readAndResize(
"class_weights_", class_weights_marray);
 
  133     h5ctx.read(
"column_count_", num_features);
 
  134     h5ctx.read(
"is_weighted_", is_weighted_int);
 
  135     h5ctx.readAndResize(
"labels", distinct_labels_marray);
 
  136     h5ctx.read(
"row_count_", num_instances);
 
  139     bool is_weighted = is_weighted_int == 1 ? 
true : 
false;
 
  142     size_t min_num_instances;
 
  145     int bootstrap_sampling_int;
 
  147     h5ctx.cd(rf_hdf5_options);
 
  148     h5ctx.read(
"min_split_node_size_", min_num_instances);
 
  149     h5ctx.read(
"mtry_", mtry);
 
  150     h5ctx.read(
"mtry_switch_", mtry_switch_int);
 
  151     h5ctx.read(
"sample_with_replacement_", bootstrap_sampling_int);
 
  152     h5ctx.read(
"tree_count_", tree_count);
 
  155     RandomForestOptionTags mtry_switch = (RandomForestOptionTags)mtry_switch_int;
 
  156     bool bootstrap_sampling = bootstrap_sampling_int == 1 ? 
true : 
false;
 
  158     std::vector<LabelType> 
const distinct_labels(distinct_labels_marray.begin(), distinct_labels_marray.end());
 
  159     std::vector<double> 
const class_weights(class_weights_marray.begin(), class_weights_marray.end());
 
  161     auto const pspec = ProblemSpec<LabelType>()
 
  162                                .num_features(num_features)
 
  163                                .num_instances(num_instances)
 
  164                                .num_classes(num_classes)
 
  165                                .distinct_classes(distinct_labels)
 
  166                                .actual_mtry(actual_mtry)
 
  167                                .actual_msample(msample);
 
  169     auto options = RandomForestOptions()
 
  170                             .min_num_instances(min_num_instances)
 
  171                             .bootstrap_sampling(bootstrap_sampling)
 
  172                             .tree_count(tree_count);
 
  173     options.features_per_node_switch_ = mtry_switch;
 
  174     options.features_per_node_ = mtry;
 
  176         options.class_weights(class_weights);
 
  179     typename RF::template NodeMap<SplitTest>::type split_tests;
 
  180     typename RF::template NodeMap<AccInputType>::type leaf_responses;
 
  182     auto const groups = h5ctx.ls();
 
  183     for (
auto const & groupname : groups) {
 
  184         if (groupname.substr(0, std::char_traits<char>::length(rf_hdf5_tree)).compare(rf_hdf5_tree) != 0) {
 
  188         MultiArray<1, unsigned int> topology;
 
  189         MultiArray<1, double> parameters;
 
  191         h5ctx.readAndResize(rf_hdf5_topology, topology);
 
  192         h5ctx.readAndResize(rf_hdf5_parameters, parameters);
 
  195         vigra_precondition(topology[0] == num_features, 
"random_forest_import_HDF5(): number of features mismatch.");
 
  196         vigra_precondition(topology[1] == num_classes, 
"random_forest_import_HDF5(): number of classes mismatch.");
 
  198         Node 
const n = gr.addNode();
 
  200         std::queue<std::pair<unsigned int, Node> > q;
 
  203             auto const el = q.front();
 
  205             unsigned int const index = el.first;
 
  206             Node 
const parent = el.second;
 
  208             vigra_precondition((topology[index] & rf_zero_mask) == 0, 
"random_forest_import_HDF5(): unexpected node type: type & zero_mask > 0");
 
  210             if (topology[index] & rf_LeafNodeTag) {
 
  211                 unsigned int const probs_start = topology[index+1] + 1;
 
  213                 vigra_precondition((topology[index] & rf_tag_mask) == rf_LeafNodeTag, 
"random_forest_import_HDF5(): unexpected node type: additional tags in leaf node");
 
  215                 std::vector<AccValueType> node_response;
 
  217                 for (
unsigned int i = 0; i < num_classes; ++i) {
 
  218                     node_response.push_back(parameters[probs_start + i]);
 
  221                 leaf_responses.insert(parent, node_response);
 
  224                 vigra_precondition(topology[index] == rf_i_ThresholdNode, 
"random_forest_import_HDF5(): unexpected node type.");
 
  226                 Node 
const left = gr.addNode();
 
  227                 Node 
const right = gr.addNode();
 
  229                 gr.addArc(parent, left);
 
  230                 gr.addArc(parent, right);
 
  232                 split_tests.insert(parent, SplitTest(topology[index+4], parameters[topology[index+1]+1]));
 
  234                 q.push(std::make_pair(topology[index+2], left));
 
  235                 q.push(std::make_pair(topology[index+3], right));
 
  246     RF rf(gr, split_tests, leaf_responses, pspec);
 
  247     rf.options_ = options;
 
  253     class PaddedNumberString
 
  257         PaddedNumberString(
int n)
 
  260             width_ = ss_.str().size();
 
  263         std::string operator()(
int k)
 const 
  266             ss_ << std::setw(width_) << std::setfill(
'0') << k;
 
  272         mutable std::ostringstream ss_;
 
  277 template <
typename RF>
 
  278 void random_forest_export_HDF5(
 
  280         HDF5File & h5context,
 
  281         std::string 
const & pathname = 
"" 
  283     typedef typename RF::LabelType LabelType;
 
  284     typedef typename RF::Node Node;
 
  287     if (pathname.size()) {
 
  288         cwd = detail::get_cwd(h5context);
 
  289         h5context.cd_mk(pathname);
 
  293     h5context.writeAttribute(rf_hdf5_version_group, rf_hdf5_version_tag,
 
  297     auto const & p = rf.problem_spec_;
 
  298     auto const & opts = rf.options_;
 
  299     MultiArray<1, LabelType> distinct_classes(Shape1(p.distinct_classes_.size()), p.distinct_classes_.data());
 
  300     MultiArray<1, double> class_weights(Shape1(p.num_classes_), 1.0);
 
  302     if (opts.class_weights_.size() > 0)
 
  305         for (
size_t i = 0; i < opts.class_weights_.size(); ++i)
 
  306             class_weights(i) = opts.class_weights_[i];
 
  310     h5context.cd_mk(rf_hdf5_ext_param);
 
  311     h5context.write(
"column_count_", p.num_features_);
 
  312     h5context.write(
"row_count_", p.num_instances_);
 
  313     h5context.write(
"class_count_", p.num_classes_);
 
  314     h5context.write(
"actual_mtry_", p.actual_mtry_);
 
  315     h5context.write(
"actual_msample_", p.actual_msample_);
 
  316     h5context.write(
"labels", distinct_classes);
 
  317     h5context.write(
"is_weighted_", is_weighted);
 
  318     h5context.write(
"class_weights_", class_weights);
 
  319     h5context.write(
"precision_", 0.0);
 
  320     h5context.write(
"problem_type_", 1.0);
 
  321     h5context.write(
"response_size_", 1.0);
 
  322     h5context.write(
"used_", 1.0);
 
  326     h5context.cd_mk(rf_hdf5_options);
 
  327     h5context.write(
"min_split_node_size_", opts.min_num_instances_);
 
  328     h5context.write(
"mtry_", opts.features_per_node_);
 
  329     h5context.write(
"mtry_func_", 0.0);
 
  330     h5context.write(
"mtry_switch_", opts.features_per_node_switch_);
 
  331     h5context.write(
"predict_weighted_", 0.0);
 
  332     h5context.write(
"prepare_online_learning_", 0.0);
 
  333     h5context.write(
"sample_with_replacement_", opts.bootstrap_sampling_ ? 1 : 0);
 
  334     h5context.write(
"stratification_method_", 3.0);
 
  335     h5context.write(
"training_set_calc_switch_", 1.0);
 
  336     h5context.write(
"training_set_func_", 0.0);
 
  337     h5context.write(
"training_set_proportion_", 1.0);
 
  338     h5context.write(
"training_set_size_", 0.0);
 
  339     h5context.write(
"tree_count_", opts.tree_count_);
 
  343     detail::PaddedNumberString tree_number(rf.num_trees());
 
  344     for (
size_t i = 0; i < rf.num_trees(); ++i)
 
  347         std::vector<UInt32> topology;
 
  348         std::vector<double> parameters;
 
  349         topology.push_back(p.num_features_);
 
  350         topology.push_back(p.num_classes_);
 
  352         auto const & probs = rf.node_responses_;
 
  353         auto const & splits = rf.split_tests_;
 
  354         auto const & gr = rf.graph_;
 
  355         auto const root = gr.getRoot(i);
 
  361         std::stack<std::pair<Node, std::ptrdiff_t> > stack;
 
  362         stack.emplace(root, -1);
 
  363         while (!stack.empty())
 
  365             auto const n = stack.top().first; 
 
  366             auto const i = stack.top().second; 
 
  371                 topology[i] = topology.size();
 
  373             if (gr.numChildren(n) == 0)
 
  378                 topology.push_back(rf_LeafNodeTag);
 
  379                 topology.push_back(parameters.size());
 
  380                 auto const & prob = probs.at(n);
 
  381                 auto const weight = std::accumulate(prob.begin(), prob.end(), 0.0);
 
  382                 parameters.push_back(weight);
 
  383                 parameters.insert(parameters.end(), prob.begin(), prob.end());
 
  390                 topology.push_back(rf_i_ThresholdNode);
 
  391                 topology.push_back(parameters.size());
 
  392                 topology.push_back(-1); 
 
  393                 topology.push_back(-1); 
 
  394                 topology.push_back(splits.at(n).dim_);
 
  395                 parameters.push_back(1.0); 
 
  396                 parameters.push_back(splits.at(n).val_);
 
  399                 stack.emplace(gr.getChild(n, 0), topology.size()-3);
 
  400                 stack.emplace(gr.getChild(n, 1), topology.size()-2);
 
  405         MultiArray<1, UInt32> topo(Shape1(topology.size()), topology.data());
 
  406         MultiArray<1, double> para(Shape1(parameters.size()), parameters.data());
 
  408         auto const name = rf_hdf5_tree + tree_number(i);
 
  409         h5context.cd_mk(name);
 
  410         h5context.write(rf_hdf5_topology, topo);
 
  411         h5context.write(rf_hdf5_parameters, para);
 
  424 #endif // VIGRA_NEW_RANDOM_FOREST_IMPEX_HDF5_HXX