gtsam 4.2.0
gtsam
Loading...
Searching...
No Matches
DecisionTree.h
Go to the documentation of this file.
1/* ----------------------------------------------------------------------------
2
3 * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4 * Atlanta, Georgia 30332-0415
5 * All Rights Reserved
6 * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7
8 * See LICENSE for the license information
9
10 * -------------------------------------------------------------------------- */
11
20#pragma once
21
22#include <gtsam/base/Testable.h>
23#include <gtsam/base/types.h>
25
26#include <boost/serialization/nvp.hpp>
27#include <boost/shared_ptr.hpp>
28#include <functional>
29#include <iostream>
30#include <map>
31#include <set>
32#include <sstream>
33#include <string>
34#include <utility>
35#include <vector>
36
37namespace gtsam {
38
60 template<typename L, typename Y>
62 protected:
/// Default method for comparison of two objects of type Y: plain operator==.
/// It is the default `compare` argument of the equals() declarations in this
/// header.
64 static bool DefaultCompare(const Y& a, const Y& b) {
65 return a == b;
66 }
67
68 public:
/// Callback that renders a label (L) as a string; taken by print() and dot().
69 using LabelFormatter = std::function<std::string(L)>;
/// Callback that renders a value (Y) as a string; taken by print() and dot().
70 using ValueFormatter = std::function<std::string(Y)>;
/// Binary predicate over two values; taken by equals().
71 using CompareFunc = std::function<bool(const Y&, const Y&)>;
72
/// Handy typedefs for unary and binary function types.
74 using Unary = std::function<Y(const Y&)>;
/// Unary operation that receives an Assignment<L> in addition to the value.
75 using UnaryAssignment = std::function<Y(const Assignment<L>&, const Y&)>;
76 using Binary = std::function<Y(const Y&, const Y&)>;
77
/// A label annotated with cardinality.
79 using LabelC = std::pair<L, size_t>;
80
82 struct Leaf;
83 struct Choice;
84
/// Node base class: the abstract interface of a tree node.
/// Every operation except id() is pure virtual. Leaf and Choice are only
/// forward-declared in this header and appear here as parameter types;
/// presumably they are the concrete node kinds (defined outside this listing).
86 struct Node {
// Nodes are shared, and only ever handed out as pointers-to-const.
87 using Ptr = boost::shared_ptr<const Node>;
88
// Debug-only live-node counter: incremented by the constructor and
// decremented by the destructor when DT_DEBUG_MEMORY is defined.
89#ifdef DT_DEBUG_MEMORY
90 static int nrNodes;
91#endif
92
93 // Constructor
94 Node() {
95#ifdef DT_DEBUG_MEMORY
// std::endl already flushes, so the explicit flush() below is redundant
// (harmless; debug builds only).
96 std::cout << ++nrNodes << " constructed " << id() << std::endl;
97 std::cout.flush();
98#endif
99 }
100
101 // Destructor
// Virtual because Node is a polymorphic base (it declares pure virtuals).
102 virtual ~Node() {
103#ifdef DT_DEBUG_MEMORY
104 std::cout << --nrNodes << " destructed " << id() << std::endl;
105 std::cout.flush();
106#endif
107 }
108
109 // Unique ID for dot files
// The ID is simply the node's own address.
110 const void* id() const { return this; }
111
112 // everything else is virtual, no documentation here as internal
113 virtual void print(const std::string& s,
114 const LabelFormatter& labelFormatter,
115 const ValueFormatter& valueFormatter) const = 0;
116 virtual void dot(std::ostream& os, const LabelFormatter& labelFormatter,
117 const ValueFormatter& valueFormatter,
118 bool showZero) const = 0;
119 virtual bool sameLeaf(const Leaf& q) const = 0;
120 virtual bool sameLeaf(const Node& q) const = 0;
// `compare` defaults to DefaultCompare, i.e. operator== on Y.
121 virtual bool equals(const Node& other, const CompareFunc& compare =
122 &DefaultCompare) const = 0;
123 virtual const Y& operator()(const Assignment<L>& x) const = 0;
124 virtual Ptr apply(const Unary& op) const = 0;
125 virtual Ptr apply(const UnaryAssignment& op,
126 const Assignment<L>& assignment) const = 0;
// NOTE(review): the three apply_* helpers take a Node, a Leaf and a Choice
// respectively; the naming suggests dispatch on the operand's node type for
// binary operations - confirm against the implementation file.
127 virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0;
128 virtual Ptr apply_g_op_fL(const Leaf&, const Binary&) const = 0;
129 virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0;
130 virtual Ptr choose(const L& label, size_t index) const = 0;
131 virtual bool isLeaf() const = 0;
132
133 private:
// Serialization function. The body is empty: the base class archives nothing.
// NOTE(review): this listing omits source lines 134-135; the symbol index
// reports `friend class boost::serialization::access` at DecisionTree.h:135.
136 template <class ARCHIVE>
137 void serialize(ARCHIVE& ar, const unsigned int /*version*/) {}
138 };
141 public:
143 using NodePtr = typename Node::Ptr;
144
147
148 protected:
/// Internal recursive function to create from keys, cardinalities, and Y
/// values.
/// Presumably [begin, end) ranges over the keys/cardinalities and
/// [beginY, endY) over the Y values - confirm in the implementation file.
153 template<typename It, typename ValueIt>
154 NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
155
/// Convert from a DecisionTree<M, X> to DecisionTree<L, Y>.
/// L_of_M maps a source label (M) to a label (L); Y_of_X maps a source value
/// (X) to a value (Y), per their std::function signatures.
/// NOTE(review): this listing drops source line 167; the symbol index gives
/// the declaration as
/// `NodePtr convertFrom(const typename DecisionTree<M, X>::NodePtr& f, ...)`.
166 template <typename M, typename X>
168 std::function<L(const M&)> L_of_M,
169 std::function<Y(const X&)> Y_of_X) const;
170
171 public:
174
/// Default constructor (for serialization).
176 DecisionTree();
177
/// Construct from a single value y. Marked explicit, so a Y never converts to
/// a tree implicitly.
179 explicit DecisionTree(const Y& y);
180
/// Construct from one label and two values.
/// NOTE(review): presumably y1 and y2 are the values on the two branches of
/// `label` - confirm in the implementation file.
188 DecisionTree(const L& label, const Y& y1, const Y& y2);
189
/// Same arguments, but the label is a LabelC (label plus cardinality).
191 DecisionTree(const LabelC& label, const Y& y1, const Y& y2);
192
/// Construct from labels-with-cardinalities and a vector of values.
194 DecisionTree(const std::vector<LabelC>& labelCs, const std::vector<Y>& ys);
195
/// Construct from labels-with-cardinalities and a table given as a string.
/// NOTE(review): the string format is defined by the implementation, which is
/// not visible in this listing.
197 DecisionTree(const std::vector<LabelC>& labelCs, const std::string& table);
198
/// Construct from an iterator range [begin, end) and a label.
/// NOTE(review): the element type of the range is not visible in this listing.
200 template<typename Iterator>
201 DecisionTree(Iterator begin, Iterator end, const L& label);
202
/// Construct from a label and two existing trees f0 and f1.
204 DecisionTree(const L& label, const DecisionTree& f0,
205 const DecisionTree& f1);
206
/// Construct from a tree with the same label type but value type X, using the
/// callable Y_of_X (by its name, a conversion from X to Y).
214 template <typename X, typename Func>
215 DecisionTree(const DecisionTree<L, X>& other, Func Y_of_X);
216
/// Construct from a tree with label type M and value type X, using `map`
/// (M -> L) for the labels and the callable Y_of_X for the values.
227 template <typename M, typename X, typename Func>
228 DecisionTree(const DecisionTree<M, X>& other, const std::map<M, L>& map,
229 Func Y_of_X);
230
234
/// GTSAM-style print.
/// @param s string argument passed by the caller (conventionally a prefix in
///          GTSAM print methods - confirm in the implementation file).
/// @param labelFormatter renders a label as a string.
/// @param valueFormatter renders a value as a string.
242 void print(const std::string& s, const LabelFormatter& labelFormatter,
243 const ValueFormatter& valueFormatter) const;
244
245 // Testable
/// Compare with `other`. `compare` is the predicate on values (CompareFunc);
/// it defaults to DefaultCompare, i.e. operator== on Y.
246 bool equals(const DecisionTree& other,
247 const CompareFunc& compare = &DefaultCompare) const;
248
252
/// Make virtual (defaulted destructor).
254 virtual ~DecisionTree() = default;
255
/// Check if tree is empty, i.e. whether root_ is a null pointer.
257 bool empty() const { return !root_; }
258
/// Equality.
260 bool operator==(const DecisionTree& q) const;
261
/// Evaluate the tree at assignment x; returns a const reference to a value.
263 const Y& operator()(const Assignment<L>& x) const;
264
/// Visit all leaves in depth-first fashion.
/// NOTE(review): visit, visitLeaf and visitWith carry the same brief in the
/// symbol index; what each passes to `f` is not visible in this listing.
279 template <typename Func>
280 void visit(Func f) const;
281
/// Visit all leaves in depth-first fashion (see the note on visit()).
296 template <typename Func>
297 void visitLeaf(Func f) const;
298
/// Visit all leaves in depth-first fashion (see the note on visit()).
313 template <typename Func>
314 void visitWith(Func f) const;
315
/// Return the number of leaves in the tree.
317 size_t nrLeaves() const;
318
/// Fold a binary function over the tree, returning accumulator.
/// @param f  the binary function.
/// @param x0 initial accumulator value (taken and returned by value).
334 template <typename Func, typename X>
335 X fold(Func f, X x0) const;
336
/// Retrieve all unique labels as a set.
338 std::set<L> labels() const;
339
/// Apply Unary operation "op" to f.
341 DecisionTree apply(const Unary& op) const;
342
/// Apply a UnaryAssignment operation, whose callback also receives an
/// Assignment<L> alongside each value.
351 DecisionTree apply(const UnaryAssignment& op) const;
352
/// Apply the Binary operation "op" to this tree and g.
354 DecisionTree apply(const DecisionTree& g, const Binary& op) const;
355
/// Create a new function where value(label)==index.
/// It's like "restrict" in Darwiche09book pg329 (per the upstream brief).
/// Delegates to the root node's choose() and wraps the result in a new tree.
// NOTE(review): root_ is dereferenced unconditionally, while empty() tests
// root_ for null - calling this on an empty tree looks like a null
// dereference. Confirm that callers guarantee a non-empty tree.
358 DecisionTree choose(const L& label, size_t index) const {
359 NodePtr newRoot = root_->choose(label, index);
360 return DecisionTree(newRoot);
361 }
362
/// Combine subtrees on key with binary operation "op".
/// @param label       the key to combine on.
/// @param cardinality presumably the number of values `label` can take -
///                    confirm in the implementation file.
364 DecisionTree combine(const L& label, size_t cardinality,
365 const Binary& op) const;
366
/// Combine with LabelC for convenience: forwards labelC.first as the label and
/// labelC.second as the cardinality to the overload above.
368 DecisionTree combine(const LabelC& labelC, const Binary& op) const {
369 return combine(labelC.first, labelC.second, op);
370 }
371
/// Output to graphviz format, stream version.
/// showZero defaults to true and is passed on to the nodes' dot().
/// NOTE(review): the exact effect of showZero is not visible in this listing.
373 void dot(std::ostream& os, const LabelFormatter& labelFormatter,
374 const ValueFormatter& valueFormatter, bool showZero = true) const;
375
/// Graphviz output keyed by a `name` string.
/// NOTE(review): presumably `name` identifies an output file - confirm in the
/// implementation file.
377 void dot(const std::string& name, const LabelFormatter& labelFormatter,
378 const ValueFormatter& valueFormatter, bool showZero = true) const;
379
/// Graphviz output returned as a std::string.
381 std::string dot(const LabelFormatter& labelFormatter,
382 const ValueFormatter& valueFormatter,
383 bool showZero = true) const;
384
387
388 // internal use only
// Wrap an existing root node pointer in a DecisionTree. Used by choose(),
// which builds its result this way.
389 explicit DecisionTree(const NodePtr& root);
390
391 // internal use only
// Builds a node from an iterator range [begin, end) and a label.
// NOTE(review): the element type of the range is not visible in this listing.
392 template<typename Iterator> NodePtr
393 compose(Iterator begin, Iterator end, const L& label) const;
394
396
397 private:
// Serialization function: archives the root pointer as a name-value pair.
// NOTE(review): this listing omits source lines 398-399; the symbol index
// reports `friend class boost::serialization::access` at DecisionTree.h:399.
400 template <class ARCHIVE>
401 void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
402 ar& BOOST_SERIALIZATION_NVP(root_);
403 }
404 }; // DecisionTree
405
/// Traits specialization: DecisionTree<L, Y> takes its traits from Testable,
/// the helper that implements the traits interface for GTSAM types.
406 template <class L, class Y>
407 struct traits<DecisionTree<L, Y>> : public Testable<DecisionTree<L, Y>> {};
408
/// Free versions of apply: forwards to the member f.apply(op) with a Unary op.
/// NOTE(review): this listing drops source line 413; the symbol index gives
/// the signature as
/// `DecisionTree<L, Y> apply(const DecisionTree<L, Y>& f,
///                           const typename DecisionTree<L, Y>::Unary& op)`.
412 template<typename L, typename Y>
414 const typename DecisionTree<L, Y>::Unary& op) {
415 return f.apply(op);
416 }
417
/// Free version of apply for a UnaryAssignment op: forwards to f.apply(op).
/// NOTE(review): this listing drops source line 420, which holds the return
/// type, the function name and the declaration of `f`; presumably it mirrors
/// the Unary overload - confirm against the original header.
419 template<typename L, typename Y>
421 const typename DecisionTree<L, Y>::UnaryAssignment& op) {
422 return f.apply(op);
423 }
424
/// Free version of apply for a Binary op: forwards to f.apply(g, op).
/// NOTE(review): this listing drops source line 427, which holds the return
/// type, the function name and the declaration of `f`; presumably it mirrors
/// the Unary overload - confirm against the original header.
426 template<typename L, typename Y>
428 const DecisionTree<L, Y>& g,
429 const typename DecisionTree<L, Y>::Binary& op) {
430 return f.apply(g, op);
431 }
432
/// unzip a DecisionTree with std::pair values.
/// Returns a pair of trees over the same labels. The first is built from
/// `input` with a lambda selecting `.first` of each value; the second lambda
/// selects `.second`.
/// NOTE(review): this listing drops source line 444, the start of the second
/// make_pair argument (presumably `DecisionTree<L, T2>(input,`).
/// NOTE(review): both lambdas take std::pair<T1, T2> by value, which copies
/// the pair on every call; a const reference parameter would avoid the copy.
439 template <typename L, typename T1, typename T2>
440 std::pair<DecisionTree<L, T1>, DecisionTree<L, T2> > unzip(
441 const DecisionTree<L, std::pair<T1, T2> >& input) {
442 return std::make_pair(
443 DecisionTree<L, T1>(input, [](std::pair<T1, T2> i) { return i.first; }),
445 [](std::pair<T1, T2> i) { return i.second; }));
446 }
447
448} // namespace gtsam
Concept check for values that can be used in unit tests.
Typedefs for easier changing of types.
An assignment from labels to a discrete value index (size_t)
Global functions in a separate testing namespace.
Definition chartTesting.h:28
std::pair< DecisionTree< L, T1 >, DecisionTree< L, T2 > > unzip(const DecisionTree< L, std::pair< T1, T2 > > &input)
unzip a DecisionTree with std::pair values.
Definition DecisionTree.h:440
DecisionTree< L, Y > apply(const DecisionTree< L, Y > &f, const typename DecisionTree< L, Y >::Unary &op)
free versions of apply
Definition DecisionTree.h:413
A manifold defines a space in which there is a notion of a linear tangent space that can be centered ...
Definition concepts.h:30
Template to create a binary predicate.
Definition Testable.h:111
A helper that implements the traits interface for GTSAM types.
Definition Testable.h:151
An assignment from labels to value index (size_t).
Definition Assignment.h:37
Definition DecisionTree-inl.h:52
Definition DecisionTree-inl.h:172
a decision tree is a function from assignments to values.
Definition DecisionTree.h:61
DecisionTree apply(const Unary &op) const
apply Unary operation "op" to f
Definition DecisionTree-inl.h:889
DecisionTree choose(const L &label, size_t index) const
Create a new function where value(label)==index. It's like "restrict" in Darwiche09book pg329,...
Definition DecisionTree.h:358
NodePtr convertFrom(const typename DecisionTree< M, X >::NodePtr &f, std::function< L(const M &)> L_of_M, std::function< Y(const X &)> Y_of_X) const
Convert from a DecisionTree<M, X> to DecisionTree<L, Y>.
Definition DecisionTree-inl.h:671
DecisionTree combine(const LabelC &labelC, const Binary &op) const
combine with LabelC for convenience
Definition DecisionTree.h:368
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const
Internal recursive function to create from keys, cardinalities, and Y values.
Definition DecisionTree-inl.h:630
virtual ~DecisionTree()=default
Defaulted destructor, declared virtual.
static bool DefaultCompare(const Y &a, const Y &b)
Default method for comparison of two objects of type Y.
Definition DecisionTree.h:64
typename Node::Ptr NodePtr
Alias for Node::Ptr, a shared pointer to a const Node. (The extracted brief here is only the "Node base class" section banner.)
Definition DecisionTree.h:143
std::set< L > labels() const
Retrieve all unique labels as a set.
Definition DecisionTree-inl.h:853
bool empty() const
Check if tree is empty.
Definition DecisionTree.h:257
void visit(Func f) const
Visit all leaves in depth-first fashion.
Definition DecisionTree-inl.h:736
void visitLeaf(Func f) const
Visit all leaves in depth-first fashion.
Definition DecisionTree-inl.h:773
std::function< Y(const Y &)> Unary
Handy typedefs for unary and binary function types.
Definition DecisionTree.h:74
X fold(Func f, X x0) const
Fold a binary function over the tree, returning accumulator.
Definition DecisionTree-inl.h:833
NodePtr root_
A DecisionTree just contains the root. TODO(dellaert): make protected.
Definition DecisionTree.h:146
void print(const std::string &s, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter) const
GTSAM-style print.
Definition DecisionTree-inl.h:872
DecisionTree combine(const L &label, size_t cardinality, const Binary &op) const
combine subtrees on key with binary operation "op"
Definition DecisionTree-inl.h:937
void visitWith(Func f) const
Visit all leaves in depth-first fashion.
Definition DecisionTree-inl.h:816
const Y & operator()(const Assignment< L > &x) const
Evaluate the tree at assignment x; returns a const reference to a value.
Definition DecisionTree-inl.h:884
void dot(std::ostream &os, const LabelFormatter &labelFormatter, const ValueFormatter &valueFormatter, bool showZero=true) const
output to graphviz format, stream version
Definition DecisionTree-inl.h:949
friend class boost::serialization::access
Serialization function.
Definition DecisionTree.h:399
bool operator==(const DecisionTree &q) const
equality
Definition DecisionTree-inl.h:879
std::pair< L, size_t > LabelC
A label annotated with cardinality.
Definition DecisionTree.h:79
size_t nrLeaves() const
Return the number of leaves in the tree.
Definition DecisionTree-inl.h:823
DecisionTree()
Default constructor (for serialization)
Definition DecisionTree-inl.h:462
Node base class.
Definition DecisionTree.h:86
friend class boost::serialization::access
Serialization function.
Definition DecisionTree.h:135