Point Cloud Library (PCL) 1.12.0
Loading...
Searching...
No Matches
fern_trainer.hpp
1/*
2 * Software License Agreement (BSD License)
3 *
4 * Point Cloud Library (PCL) - www.pointclouds.org
5 * Copyright (c) 2010-2011, Willow Garage, Inc.
6 *
7 * All rights reserved.
8 *
9 * Redistribution and use in source and binary forms, with or without
10 * modification, are permitted provided that the following conditions
11 * are met:
12 *
13 * * Redistributions of source code must retain the above copyright
14 * notice, this list of conditions and the following disclaimer.
15 * * Redistributions in binary form must reproduce the above
16 * copyright notice, this list of conditions and the following
17 * disclaimer in the documentation and/or other materials provided
18 * with the distribution.
19 * * Neither the name of Willow Garage, Inc. nor the names of its
20 * contributors may be used to endorse or promote products derived
21 * from this software without specific prior written permission.
22 *
23 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
24 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
25 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
26 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
27 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
28 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
29 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
30 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
31 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
32 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
33 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
34 * POSSIBILITY OF SUCH DAMAGE.
35 *
36 */
37
38#pragma once
39
40namespace pcl {
41
42template <class FeatureType,
43 class DataSet,
44 class LabelType,
45 class ExampleIndex,
46 class NodeType>
48: fern_depth_(10)
49, num_of_features_(1000)
50, num_of_thresholds_(10)
51, feature_handler_(nullptr)
52, stats_estimator_(nullptr)
53, data_set_()
54, label_data_()
55, examples_()
56{}
57
58template <class FeatureType,
59 class DataSet,
60 class LabelType,
61 class ExampleIndex,
62 class NodeType>
65
66template <class FeatureType,
67 class DataSet,
68 class LabelType,
69 class ExampleIndex,
70 class NodeType>
71void
74{
75 const std::size_t num_of_branches = stats_estimator_->getNumOfBranches();
76 const std::size_t num_of_examples = examples_.size();
77
78 // create random features
79 std::vector<FeatureType> features;
80 feature_handler_->createRandomFeatures(num_of_features_, features);
81
82 // setup fern
83 fern.initialize(fern_depth_);
84
85 // evaluate all features
86 std::vector<std::vector<float>> feature_results(num_of_features_);
87 std::vector<std::vector<unsigned char>> flags(num_of_features_);
88
89 for (std::size_t feature_index = 0; feature_index < num_of_features_;
90 ++feature_index) {
92 flags[feature_index].reserve(num_of_examples);
93
94 feature_handler_->evaluateFeature(features[feature_index],
95 data_set_,
96 examples_,
98 flags[feature_index]);
99 }
100
101 // iteratively select features and thresholds
102 std::vector<std::vector<std::vector<float>>> branch_feature_results(
103 num_of_features_); // [feature_index][branch_index][result_index]
104 std::vector<std::vector<std::vector<unsigned char>>> branch_flags(
105 num_of_features_); // [feature_index][branch_index][flag_index]
106 std::vector<std::vector<std::vector<ExampleIndex>>> branch_examples(
107 num_of_features_); // [feature_index][branch_index][result_index]
108 std::vector<std::vector<std::vector<LabelType>>> branch_label_data(
109 num_of_features_); // [feature_index][branch_index][flag_index]
110
111 // - initialize branch feature results and flags
112 for (std::size_t feature_index = 0; feature_index < num_of_features_;
113 ++feature_index) {
115 branch_flags[feature_index].resize(1);
118
121 branch_examples[feature_index][0] = examples_;
122 branch_label_data[feature_index][0] = label_data_;
123 }
124
125 for (std::size_t depth_index = 0; depth_index < fern_depth_; ++depth_index) {
126 // get thresholds
127 std::vector<std::vector<float>> thresholds(num_of_features_);
128
129 for (std::size_t feature_index = 0; feature_index < num_of_features_;
130 ++feature_index) {
131 thresholds.reserve(num_of_thresholds_);
132 createThresholdsUniform(num_of_thresholds_,
135 }
136
137 // compute information gain
138 int best_feature_index = -1;
139 float best_feature_threshold = 0.0f;
141
142 for (std::size_t feature_index = 0; feature_index < num_of_features_;
143 ++feature_index) {
144 for (std::size_t threshold_index = 0; threshold_index < num_of_thresholds_;
145 ++threshold_index) {
146 float information_gain = 0.0f;
147 for (std::size_t branch_index = 0;
149 ++branch_index) {
150 const float branch_information_gain =
151 stats_estimator_->computeInformationGain(
152 data_set_,
158
162 }
163
166 best_feature_index = static_cast<int>(feature_index);
168 }
169 }
170 }
171
172 // add feature to the feature list of the fern
173 fern.accessFeature(depth_index) = features[best_feature_index];
174 fern.accessThreshold(depth_index) = best_feature_threshold;
175
176 // update branch feature results and flags
177 for (std::size_t feature_index = 0; feature_index < num_of_features_;
178 ++feature_index) {
179 std::vector<std::vector<float>>& cur_branch_feature_results =
181 std::vector<std::vector<unsigned char>>& cur_branch_flags =
183 std::vector<std::vector<ExampleIndex>>& cur_branch_examples =
185 std::vector<std::vector<LabelType>>& cur_branch_label_data =
187
188 const std::size_t total_num_of_new_branches =
190
191 std::vector<std::vector<float>> new_branch_feature_results(
192 total_num_of_new_branches); // [branch_index][example_index]
193 std::vector<std::vector<unsigned char>> new_branch_flags(
194 total_num_of_new_branches); // [branch_index][example_index]
195 std::vector<std::vector<ExampleIndex>> new_branch_examples(
196 total_num_of_new_branches); // [branch_index][example_index]
197 std::vector<std::vector<LabelType>> new_branch_label_data(
198 total_num_of_new_branches); // [branch_index][example_index]
199
200 for (std::size_t branch_index = 0;
202 ++branch_index) {
203 const std::size_t num_of_examples_in_this_branch =
205
206 std::vector<unsigned char> branch_indices;
208
209 stats_estimator_->computeBranchIndices(cur_branch_feature_results[branch_index],
213
214 // split results into different branches
215 const std::size_t base_branch_index = branch_index * num_of_branches;
216 for (std::size_t example_index = 0;
218 ++example_index) {
219 const std::size_t combined_branch_index =
221
230 }
231 }
232
237 }
238 }
239
240 // set node statistics
241 // - re-evaluate selected features
242 std::vector<std::vector<float>> final_feature_results(
243 fern_depth_); // [feature_index][example_index]
244 std::vector<std::vector<unsigned char>> final_flags(
245 fern_depth_); // [feature_index][example_index]
246 std::vector<std::vector<unsigned char>> final_branch_indices(
247 fern_depth_); // [feature_index][example_index]
248 for (std::size_t depth_index = 0; depth_index < fern_depth_; ++depth_index) {
252
253 feature_handler_->evaluateFeature(fern.accessFeature(depth_index),
254 data_set_,
255 examples_,
258
259 stats_estimator_->computeBranchIndices(final_feature_results[depth_index],
261 fern.accessThreshold(depth_index),
263 }
264
265 // - distribute examples to nodes
266 std::vector<std::vector<LabelType>> node_labels(
267 0x1 << fern_depth_); // [node_index][example_index]
268 std::vector<std::vector<ExampleIndex>> node_examples(
269 0x1 << fern_depth_); // [node_index][example_index]
270
271 for (std::size_t example_index = 0; example_index < num_of_examples;
272 ++example_index) {
273 std::size_t node_index = 0;
274 for (std::size_t depth_index = 0; depth_index < fern_depth_; ++depth_index) {
277 }
278
279 node_labels[node_index].push_back(label_data_[example_index]);
280 node_examples[node_index].push_back(examples_[example_index]);
281 }
282
283 // - compute and set statistics for every node
284 const std::size_t num_of_nodes = 0x1 << fern_depth_;
285 for (std::size_t node_index = 0; node_index < num_of_nodes; ++node_index) {
286 stats_estimator_->computeAndSetNodeStats(data_set_,
290 }
291}
292
293template <class FeatureType,
294 class DataSet,
295 class LabelType,
296 class ExampleIndex,
297 class NodeType>
298void
301 std::vector<float>& values,
302 std::vector<float>& thresholds)
303{
304 // estimate range of values
305 float min_value = ::std::numeric_limits<float>::max();
306 float max_value = -::std::numeric_limits<float>::max();
307
308 const std::size_t num_of_values = values.size();
310 const float value = values[value_index];
311
312 if (value < min_value)
313 min_value = value;
314 if (value > max_value)
315 max_value = value;
316 }
317
318 const float range = max_value - min_value;
319 const float step = range / (num_of_thresholds + 2);
320
321 // compute thresholds
323
325 ++threshold_index) {
327 }
328}
329
330} // namespace pcl
Iterator class for point clouds with or without given indices.
ConstCloudIterator(const PointCloud< PointT > &cloud)
std::size_t size() const
Size of the range the iterator is going through.
static void createThresholdsUniform(const std::size_t num_of_thresholds, std::vector< float > &values, std::vector< float > &thresholds)
Creates uniformly distributed thresholds over the range of the supplied values.
virtual ~FernTrainer()
Destructor.
void train(Fern< FeatureType, NodeType > &fern)
Trains a decision tree using the set training data and settings.
FernTrainer()
Constructor.