gtsam 4.2.0
gtsam
Loading...
Searching...
No Matches
BayesTree-inst.h
Go to the documentation of this file.
1/* ----------------------------------------------------------------------------
2
3 * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4 * Atlanta, Georgia 30332-0415
5 * All Rights Reserved
6 * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7
8 * See LICENSE for the license information
9
10 * -------------------------------------------------------------------------- */
11
21#pragma once
22
26#include <gtsam/base/timing.h>
27
28#include <boost/optional.hpp>
29#include <fstream>
30
31namespace gtsam {
32
33 /* ************************************************************************* */
34 template<class CLIQUE>
37 for (const sharedClique& root : roots_) getCliqueData(root, &stats);
38 return stats;
39 }
40
41 /* ************************************************************************* */
42 template <class CLIQUE>
44 BayesTreeCliqueData* stats) const {
45 const auto conditional = clique->conditional();
46 stats->conditionalSizes.push_back(conditional->nrFrontals());
47 stats->separatorSizes.push_back(conditional->nrParents());
48 for (sharedClique c : clique->children) {
49 getCliqueData(c, stats);
50 }
51 }
52
53 /* ************************************************************************* */
54 template<class CLIQUE>
56 size_t count = 0;
57 for(const sharedClique& root: roots_)
58 count += root->numCachedSeparatorMarginals();
59 return count;
60 }
61
62 /* ************************************************************************* */
63 template <class CLIQUE>
64 void BayesTree<CLIQUE>::dot(std::ostream& os,
65 const KeyFormatter& keyFormatter) const {
66 if (roots_.empty())
67 throw std::invalid_argument(
68 "the root of Bayes tree has not been initialized!");
69 os << "digraph G{\n";
70 for (const sharedClique& root : roots_) dot(os, root, keyFormatter);
71 os << "}";
72 std::flush(os);
73 }
74
75 /* ************************************************************************* */
76 template <class CLIQUE>
77 std::string BayesTree<CLIQUE>::dot(const KeyFormatter& keyFormatter) const {
78 std::stringstream ss;
79 dot(ss, keyFormatter);
80 return ss.str();
81 }
82
83 /* ************************************************************************* */
84 template <class CLIQUE>
85 void BayesTree<CLIQUE>::saveGraph(const std::string& filename,
86 const KeyFormatter& keyFormatter) const {
87 std::ofstream of(filename.c_str());
88 dot(of, keyFormatter);
89 of.close();
90 }
91
92 /* ************************************************************************* */
93 template <class CLIQUE>
94 void BayesTree<CLIQUE>::dot(std::ostream& s, sharedClique clique,
95 const KeyFormatter& keyFormatter,
96 int parentnum) const {
97 static int num = 0;
98 bool first = true;
99 std::stringstream out;
100 out << num;
101 std::string parent = out.str();
102 parent += "[label=\"";
103
104 for (Key key : clique->conditional_->frontals()) {
105 if (!first) parent += ", ";
106 first = false;
107 parent += keyFormatter(key);
108 }
109
110 if (clique->parent()) {
111 parent += " : ";
112 s << parentnum << "->" << num << "\n";
113 }
114
115 first = true;
116 for (Key parentKey : clique->conditional_->parents()) {
117 if (!first) parent += ", ";
118 first = false;
119 parent += keyFormatter(parentKey);
120 }
121 parent += "\"];\n";
122 s << parent;
123 parentnum = num;
124
125 for (sharedClique c : clique->children) {
126 num++;
127 dot(s, c, keyFormatter, parentnum);
128 }
129 }
130
131 /* ************************************************************************* */
132 template<class CLIQUE>
133 size_t BayesTree<CLIQUE>::size() const {
134 size_t size = 0;
135 for(const sharedClique& clique: roots_)
136 size += clique->treeSize();
137 return size;
138 }
139
140 /* ************************************************************************* */
141 template<class CLIQUE>
142 void BayesTree<CLIQUE>::addClique(const sharedClique& clique, const sharedClique& parent_clique) {
143 for(Key j: clique->conditional()->frontals())
144 nodes_[j] = clique;
145 if (parent_clique != nullptr) {
146 clique->parent_ = parent_clique;
147 parent_clique->children.push_back(clique);
148 } else {
149 roots_.push_back(clique);
150 }
151 }
152
153 /* ************************************************************************* */
154 namespace {
155 template <class FACTOR, class CLIQUE>
156 struct _pushCliqueFunctor {
157 _pushCliqueFunctor(FactorGraph<FACTOR>* graph_) : graph(graph_) {}
158 FactorGraph<FACTOR>* graph;
159 int operator()(const boost::shared_ptr<CLIQUE>& clique, int dummy) {
160 graph->push_back(clique->conditional_);
161 return 0;
162 }
163 };
164 } // namespace
165
166 /* ************************************************************************* */
167 template <class CLIQUE>
169 FactorGraph<FactorType>* graph) const {
170 // Traverse the BayesTree and add all conditionals to this graph
171 int data = 0; // Unused
172 _pushCliqueFunctor<FactorType, CLIQUE> functor(graph);
173 treeTraversal::DepthFirstForest(*this, data, functor);
174 }
175
176 /* ************************************************************************* */
177 template<class CLIQUE>
179 *this = other;
180 }
181
182 /* ************************************************************************* */
183 namespace {
184 template<typename NODE>
185 boost::shared_ptr<NODE>
186 BayesTreeCloneForestVisitorPre(const boost::shared_ptr<NODE>& node, const boost::shared_ptr<NODE>& parentPointer)
187 {
188 // Clone the current node and add it to its cloned parent
189 boost::shared_ptr<NODE> clone = boost::make_shared<NODE>(*node);
190 clone->children.clear();
191 clone->parent_ = parentPointer;
192 parentPointer->children.push_back(clone);
193 return clone;
194 }
195 }
197 /* ************************************************************************* */
198 template<class CLIQUE>
200 this->clear();
201 boost::shared_ptr<Clique> rootContainer = boost::make_shared<Clique>();
202 treeTraversal::DepthFirstForest(other, rootContainer, BayesTreeCloneForestVisitorPre<Clique>);
203 for(const sharedClique& root: rootContainer->children) {
204 root->parent_ = typename Clique::weak_ptr(); // Reset the parent since it's set to the dummy clique
205 insertRoot(root);
206 }
207 return *this;
209
210 /* ************************************************************************* */
211 template<class CLIQUE>
212 void BayesTree<CLIQUE>::print(const std::string& s, const KeyFormatter& keyFormatter) const {
213 std::cout << s << ": cliques: " << size() << ", variables: " << nodes_.size() << std::endl;
214 treeTraversal::PrintForest(*this, s, keyFormatter);
215 }
216
217 /* ************************************************************************* */
218 // binary predicate to test equality of a pair for use in equals
219 template<class CLIQUE>
220 bool check_sharedCliques(
221 const std::pair<Key, typename BayesTree<CLIQUE>::sharedClique>& v1,
222 const std::pair<Key, typename BayesTree<CLIQUE>::sharedClique>& v2
223 ) {
224 return v1.first == v2.first &&
225 ((!v1.second && !v2.second) || (v1.second && v2.second && v1.second->equals(*v2.second)));
227
228 /* ************************************************************************* */
229 template<class CLIQUE>
230 bool BayesTree<CLIQUE>::equals(const BayesTree<CLIQUE>& other, double tol) const {
231 return size()==other.size() &&
232 std::equal(nodes_.begin(), nodes_.end(), other.nodes_.begin(), &check_sharedCliques<CLIQUE>);
233 }
234
235 /* ************************************************************************* */
236 template<class CLIQUE>
237 template<class CONTAINER>
238 Key BayesTree<CLIQUE>::findParentClique(const CONTAINER& parents) const {
239 typename CONTAINER::const_iterator lowestOrderedParent = min_element(parents.begin(), parents.end());
240 assert(lowestOrderedParent != parents.end());
241 return *lowestOrderedParent;
242 }
243
244 /* ************************************************************************* */
245 template<class CLIQUE>
247 // Add each frontal variable of this root node
248 for(const Key& j: subtree->conditional()->frontals()) {
249 bool inserted = nodes_.insert(std::make_pair(j, subtree)).second;
250 assert(inserted); (void)inserted;
251 }
252 // Fill index for each child
254 for(const sharedClique& child: subtree->children) {
255 fillNodesIndex(child); }
257
258 /* ************************************************************************* */
259 template<class CLIQUE>
261 roots_.push_back(subtree); // Add to roots
262 fillNodesIndex(subtree); // Populate nodes index
263 }
264
265 /* ************************************************************************* */
266 // First finds clique marginal then marginalizes that
267 /* ************************************************************************* */
268 template<class CLIQUE>
269 typename BayesTree<CLIQUE>::sharedConditional
270 BayesTree<CLIQUE>::marginalFactor(Key j, const Eliminate& function) const
271 {
272 gttic(BayesTree_marginalFactor);
273
274 // get clique containing Key j
275 sharedClique clique = this->clique(j);
276
277 // calculate or retrieve its marginal P(C) = P(F,S)
278 FactorGraphType cliqueMarginal = clique->marginal2(function);
279
280 // Now, marginalize out everything that is not variable j
281 BayesNetType marginalBN =
282 *cliqueMarginal.marginalMultifrontalBayesNet(Ordering{j}, function);
283
284 // The Bayes net should contain only one conditional for variable j, so return it
285 return marginalBN.front();
286 }
287
288 /* ************************************************************************* */
289 // Find two cliques, their joint, then marginalizes
290 /* ************************************************************************* */
291 template<class CLIQUE>
292 typename BayesTree<CLIQUE>::sharedFactorGraph
293 BayesTree<CLIQUE>::joint(Key j1, Key j2, const Eliminate& function) const
294 {
295 gttic(BayesTree_joint);
296 return boost::make_shared<FactorGraphType>(*jointBayesNet(j1, j2, function));
297 }
298
299 /* ************************************************************************* */
300 template<class CLIQUE>
301 typename BayesTree<CLIQUE>::sharedBayesNet
302 BayesTree<CLIQUE>::jointBayesNet(Key j1, Key j2, const Eliminate& function) const
303 {
304 gttic(BayesTree_jointBayesNet);
305 // get clique C1 and C2
306 sharedClique C1 = (*this)[j1], C2 = (*this)[j2];
307
308 gttic(Lowest_common_ancestor);
309 // Find lowest common ancestor clique
310 sharedClique B; {
311 // Build two paths to the root
312 FastList<sharedClique> path1, path2; {
313 sharedClique p = C1;
314 while(p) {
315 path1.push_front(p);
316 p = p->parent();
317 }
318 } {
319 sharedClique p = C2;
320 while(p) {
321 path2.push_front(p);
322 p = p->parent();
323 }
324 }
325 // Find the path intersection
326 typename FastList<sharedClique>::const_iterator p1 = path1.begin(), p2 = path2.begin();
327 if(*p1 == *p2)
328 B = *p1;
329 while(p1 != path1.end() && p2 != path2.end() && *p1 == *p2) {
330 B = *p1;
331 ++p1;
332 ++p2;
333 }
334 }
335 gttoc(Lowest_common_ancestor);
336
337 // Build joint on all involved variables
338 FactorGraphType p_BC1C2;
339
340 if(B)
341 {
342 // Compute marginal on lowest common ancestor clique
343 gttic(LCA_marginal);
344 FactorGraphType p_B = B->marginal2(function);
345 gttoc(LCA_marginal);
346
347 // Compute shortcuts of the requested cliques given the lowest common ancestor
348 gttic(Clique_shortcuts);
349 BayesNetType p_C1_Bred = C1->shortcut(B, function);
350 BayesNetType p_C2_Bred = C2->shortcut(B, function);
351 gttoc(Clique_shortcuts);
352
353 // Factor the shortcuts to be conditioned on the full root
354 // Get the set of variables to eliminate, which is C1\B.
355 gttic(Full_root_factoring);
356 boost::shared_ptr<typename EliminationTraitsType::BayesTreeType> p_C1_B; {
357 KeyVector C1_minus_B; {
358 KeySet C1_minus_B_set(C1->conditional()->beginParents(), C1->conditional()->endParents());
359 for(const Key j: *B->conditional()) {
360 C1_minus_B_set.erase(j); }
361 C1_minus_B.assign(C1_minus_B_set.begin(), C1_minus_B_set.end());
362 }
363 // Factor into C1\B | B.
364 sharedFactorGraph temp_remaining;
365 boost::tie(p_C1_B, temp_remaining) =
366 FactorGraphType(p_C1_Bred).eliminatePartialMultifrontal(Ordering(C1_minus_B), function);
367 }
368 boost::shared_ptr<typename EliminationTraitsType::BayesTreeType> p_C2_B; {
369 KeyVector C2_minus_B; {
370 KeySet C2_minus_B_set(C2->conditional()->beginParents(), C2->conditional()->endParents());
371 for(const Key j: *B->conditional()) {
372 C2_minus_B_set.erase(j); }
373 C2_minus_B.assign(C2_minus_B_set.begin(), C2_minus_B_set.end());
374 }
375 // Factor into C2\B | B.
376 sharedFactorGraph temp_remaining;
377 boost::tie(p_C2_B, temp_remaining) =
378 FactorGraphType(p_C2_Bred).eliminatePartialMultifrontal(Ordering(C2_minus_B), function);
379 }
380 gttoc(Full_root_factoring);
381
382 gttic(Variable_joint);
383 p_BC1C2 += p_B;
384 p_BC1C2 += *p_C1_B;
385 p_BC1C2 += *p_C2_B;
386 if(C1 != B)
387 p_BC1C2 += C1->conditional();
388 if(C2 != B)
389 p_BC1C2 += C2->conditional();
390 gttoc(Variable_joint);
391 }
392 else
393 {
394 // The nodes have no common ancestor, they're in different trees, so they're joint is just the
395 // product of their marginals.
396 gttic(Disjoint_marginals);
397 p_BC1C2 += C1->marginal2(function);
398 p_BC1C2 += C2->marginal2(function);
399 gttoc(Disjoint_marginals);
400 }
401
402 // now, marginalize out everything that is not variable j1 or j2
403 return p_BC1C2.marginalMultifrontalBayesNet(Ordering{j1, j2}, function);
404 }
405
406 /* ************************************************************************* */
407 template<class CLIQUE>
409 // Remove all nodes and clear the root pointer
410 nodes_.clear();
411 roots_.clear();
412 }
413
414 /* ************************************************************************* */
415 template<class CLIQUE>
417 for(const sharedClique& root: roots_) {
418 root->deleteCachedShortcuts();
419 }
420 }
421
422 /* ************************************************************************* */
423 template<class CLIQUE>
425 {
426 if (clique->isRoot()) {
427 typename Roots::iterator root = std::find(roots_.begin(), roots_.end(), clique);
428 if(root != roots_.end())
429 roots_.erase(root);
430 } else { // detach clique from parent
431 sharedClique parent = clique->parent_.lock();
432 typename Roots::iterator child = std::find(parent->children.begin(), parent->children.end(), clique);
433 assert(child != parent->children.end());
434 parent->children.erase(child);
435 }
436
437 // orphan my children
438 for(sharedClique child: clique->children)
439 child->parent_ = typename Clique::weak_ptr();
440
441 for(Key j: clique->conditional()->frontals()) {
442 nodes_.unsafe_erase(j);
443 }
444 }
445
446 /* ************************************************************************* */
447 template <class CLIQUE>
448 void BayesTree<CLIQUE>::removePath(sharedClique clique, BayesNetType* bn,
449 Cliques* orphans) {
450 // base case is nullptr, if so we do nothing and return empties above
451 if (clique) {
452 // remove the clique from orphans in case it has been added earlier
453 orphans->remove(clique);
454
455 // remove me
456 this->removeClique(clique);
457
458 // remove path above me
459 this->removePath(typename Clique::shared_ptr(clique->parent_.lock()), bn,
460 orphans);
461
462 // add children to list of orphans (splice also removed them from
463 // clique->children_)
464 orphans->insert(orphans->begin(), clique->children.begin(),
465 clique->children.end());
466 clique->children.clear();
467
468 bn->push_back(clique->conditional_);
469 }
470 }
471
472 /* *************************************************************************
473 */
474 template <class CLIQUE>
475 void BayesTree<CLIQUE>::removeTop(const KeyVector& keys, BayesNetType* bn,
476 Cliques* orphans) {
477 gttic(removetop);
478 // process each key of the new factor
479 for (const Key& j : keys) {
480 // get the clique
481 // TODO(frank): Nodes will be searched again in removeClique
482 typename Nodes::const_iterator node = nodes_.find(j);
483 if (node != nodes_.end()) {
484 // remove path from clique to root
485 this->removePath(node->second, bn, orphans);
486 }
487 }
488
489 // Delete cachedShortcuts for each orphan subtree
490 // TODO(frank): Consider Improving
491 for (sharedClique& orphan : *orphans) orphan->deleteCachedShortcuts();
492 }
493
494 /* ************************************************************************* */
495 template<class CLIQUE>
497 const sharedClique& subtree)
498 {
499 // Result clique list
500 Cliques cliques;
501 cliques.push_back(subtree);
502
503 // Remove the first clique from its parents
504 if(!subtree->isRoot())
505 subtree->parent()->children.erase(std::find(
506 subtree->parent()->children.begin(), subtree->parent()->children.end(), subtree));
507 else
508 roots_.erase(std::find(roots_.begin(), roots_.end(), subtree));
509
510 // Add all subtree cliques and erase the children and parent of each
511 for(typename Cliques::iterator clique = cliques.begin(); clique != cliques.end(); ++clique)
512 {
513 // Add children
514 for(const sharedClique& child: (*clique)->children) {
515 cliques.push_back(child); }
516
517 // Delete cached shortcuts
518 (*clique)->deleteCachedShortcutsNonRecursive();
519
520 // Remove this node from the nodes index
521 for(Key j: (*clique)->conditional()->frontals()) {
522 nodes_.unsafe_erase(j); }
523
524 // Erase the parent and children pointers
525 (*clique)->parent_.reset();
526 (*clique)->children.clear();
527 }
528
529 return cliques;
530 }
531
532}
Timing utilities.
Bayes Tree is a tree of cliques of a Bayes Chain.
Variable ordering for the elimination algorithm.
Global functions in a separate testing namespace.
Definition chartTesting.h:28
FastVector< Key > KeyVector
Define collection type once and for all - also used in wrappers.
Definition Key.h:86
double dot(const V1 &a, const V2 &b)
Dot product.
Definition Vector.h:195
std::uint64_t Key
Integer nonlinear key type.
Definition types.h:100
std::function< std::string(Key)> KeyFormatter
Typedef for a function to format a key, i.e. to convert it to a string.
Definition Key.h:35
void DepthFirstForest(FOREST &forest, DATA &rootData, VISITOR_PRE &visitorPre, VISITOR_POST &visitorPost)
Traverse a forest depth-first with pre-order and post-order visits.
Definition treeTraversal-inst.h:77
void PrintForest(const FOREST &forest, std::string str, const KeyFormatter &keyFormatter)
Print a tree, prefixing each line with str, and formatting keys using keyFormatter.
Definition treeTraversal-inst.h:219
FastList is a thin wrapper around std::list that uses the boost fast_pool_allocator instead of the de...
Definition FastList.h:40
A factor graph is a bipartite graph with factor nodes connected to variable nodes.
Definition FactorGraph.h:97
Per-clique statistics: the conditional (frontal) size and the separator size of every clique.
Definition BayesTree.h:48
Bayes tree.
Definition BayesTree.h:67
Nodes nodes_
Map from indices to Clique.
Definition BayesTree.h:100
void removeClique(sharedClique clique)
remove a clique: warning, can result in a forest
Definition BayesTree-inst.h:424
sharedFactorGraph joint(Key j1, Key j2, const Eliminate &function=EliminationTraitsType::DefaultEliminate) const
return joint on two variables Limitation: can only calculate joint if cliques are disjoint or one of ...
Definition BayesTree-inst.h:293
void fillNodesIndex(const sharedClique &subtree)
Fill the nodes index for a subtree.
Definition BayesTree-inst.h:246
void dot(std::ostream &os, const KeyFormatter &keyFormatter=DefaultKeyFormatter) const
Output to graphviz format, stream version.
Definition BayesTree-inst.h:64
void addFactorsToGraph(FactorGraph< FactorType > *graph) const
Add all cliques in this BayesTree to the specified factor graph.
Definition BayesTree-inst.h:168
bool equals(const This &other, double tol=1e-9) const
check equality
Definition BayesTree-inst.h:230
This & operator=(const This &other)
Assignment operator.
Definition BayesTree-inst.h:199
boost::shared_ptr< Clique > sharedClique
Shared pointer to a clique.
Definition BayesTree.h:74
BayesTree()
Create an empty Bayes Tree.
Definition BayesTree.h:109
void clear()
Remove all nodes.
Definition BayesTree-inst.h:408
void addClique(const sharedClique &clique, const sharedClique &parent_clique=sharedClique())
add a clique (top down)
Definition BayesTree-inst.h:142
sharedBayesNet jointBayesNet(Key j1, Key j2, const Eliminate &function=EliminationTraitsType::DefaultEliminate) const
return joint on two variables as a BayesNet Limitation: can only calculate joint if cliques are disjo...
Definition BayesTree-inst.h:302
Key findParentClique(const CONTAINER &parents) const
Find parent clique of a conditional.
Definition BayesTree-inst.h:238
size_t size() const
number of cliques
Definition BayesTree-inst.h:133
void deleteCachedShortcuts()
Clear all shortcut caches - use before timing on marginal calculation to avoid residual cache data.
Definition BayesTree-inst.h:416
void removePath(sharedClique clique, BayesNetType *bn, Cliques *orphans)
Remove path from clique to root and return that path as factors plus a list of orphaned subtree roots...
Definition BayesTree-inst.h:448
sharedConditional marginalFactor(Key j, const Eliminate &function=EliminationTraitsType::DefaultEliminate) const
Return marginal on any variable.
Definition BayesTree-inst.h:270
size_t numCachedSeparatorMarginals() const
Collect number of cliques with cached separator marginals.
Definition BayesTree-inst.h:55
BayesTreeCliqueData getCliqueData() const
Gather data on all cliques.
Definition BayesTree-inst.h:35
Cliques removeSubtree(const sharedClique &subtree)
Remove the requested subtree.
Definition BayesTree-inst.h:496
void print(const std::string &s="", const KeyFormatter &keyFormatter=DefaultKeyFormatter) const
print
Definition BayesTree-inst.h:212
void insertRoot(const sharedClique &subtree)
Insert a new subtree as an additional root of the forest and populate the nodes index for it.
Definition BayesTree-inst.h:260
void saveGraph(const std::string &filename, const KeyFormatter &keyFormatter=DefaultKeyFormatter) const
output to file with graphviz format.
Definition BayesTree-inst.h:85
void removeTop(const KeyVector &keys, BayesNetType *bn, Cliques *orphans)
Given a list of indices, turn "contaminated" part of the tree back into a factor graph.
Definition BayesTree-inst.h:475
Definition Ordering.h:34