Halide 17.0.1
Halide compiler and libraries
Loading...
Searching...
No Matches
IRMatch.h
Go to the documentation of this file.
1#ifndef HALIDE_IR_MATCH_H
2#define HALIDE_IR_MATCH_H
3
4/** \file
5 * Defines a method to match a fragment of IR against a pattern containing wildcards
6 */
7
8#include <map>
9#include <random>
10#include <set>
11#include <vector>
12
13#include "IR.h"
14#include "IREquality.h"
15#include "IROperator.h"
16
17namespace Halide {
18namespace Internal {
19
20/** Does the first expression have the same structure as the second?
21 * Variables in the first expression with the name * are interpreted
22 * as wildcards, and their matching equivalent in the second
23 * expression is placed in the vector given as the third argument.
24 * Wildcards require the types to match. For the type bits and width,
25 * a 0 indicates "match anything". So an Int(8, 0) will match 8-bit
26 * integer vectors of any width (including scalars), and a UInt(0, 0)
27 * will match any unsigned integer type.
28 *
29 * For example:
30 \code
31 Expr x = Variable::make(Int(32), "*");
32 expr_match(x + x, 3 + (2*k), result)
33 \endcode
34 * should return true, and set result[0] to 3 and
35 * result[1] to 2*k.
36 */
37bool expr_match(const Expr &pattern, const Expr &expr, std::vector<Expr> &result);
38
39/** Does the first expression have the same structure as the second?
40 * Variables are matched consistently. The first time a variable is
41 * matched, it assumes the value of the matching part of the second
42 * expression. Subsequent matches must be equal to the first match.
43 *
44 * For example:
45 \code
46 Var x("x"), y("y");
47 expr_match(x*(x + y), a*(a + b), result)
48 \endcode
49 * should return true, and set result["x"] = a, and result["y"] = b.
50 */
51bool expr_match(const Expr &pattern, const Expr &expr, std::map<std::string, Expr> &result);
52
53/** Rewrite the expression x to have `lanes` lanes. This is useful
54 * for substituting the results of expr_match into a pattern expression. */
55Expr with_lanes(const Expr &x, int lanes);
56
58
59/** An alternative template-metaprogramming approach to expression
60 * matching. Potentially more efficient. We lift the expression
61 * pattern into a type, and then use force-inlined functions to
62 * generate efficient matching and reconstruction code for any
63 * pattern. Pattern elements are either one of the classes in the
64 * namespace IRMatcher, or are non-null Exprs (represented as
65 * BaseExprNode &).
66 *
67 * Pattern elements that are fully specified by their pattern can be
68 * built into an expression using the make method. Some patterns,
69 * such as a broadcast that matches any number of lanes, don't have
70 * enough information to recreate an Expr.
71 */
72namespace IRMatcher {
73
74constexpr int max_wild = 6;
75
76static const halide_type_t i64_type = {halide_type_int, 64, 1};
77
78/** To save stack space, the matcher objects are largely stateless and
79 * immutable. This state object is built up during matching and then
80 * consumed when constructing a replacement Expr.
81 */
85
86 // values of the lanes field with special meaning.
87 static constexpr uint16_t signed_integer_overflow = 0x8000;
88 static constexpr uint16_t special_values_mask = 0x8000; // currently only one
89
91
93 void set_binding(int i, const BaseExprNode &n) noexcept {
94 bindings[i] = &n;
95 }
96
98 const BaseExprNode *get_binding(int i) const noexcept {
99 return bindings[i];
100 }
101
103 void set_bound_const(int i, int64_t s, halide_type_t t) noexcept {
104 bound_const[i].u.i64 = s;
105 bound_const_type[i] = t;
106 }
107
109 void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept {
110 bound_const[i].u.u64 = u;
111 bound_const_type[i] = t;
112 }
113
115 void set_bound_const(int i, double f, halide_type_t t) noexcept {
116 bound_const[i].u.f64 = f;
117 bound_const_type[i] = t;
118 }
119
122 bound_const[i] = val;
123 bound_const_type[i] = t;
124 }
125
127 void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept {
128 val = bound_const[i];
129 type = bound_const_type[i];
130 }
131
133 // NOLINTNEXTLINE(modernize-use-equals-default): Can't use `= default`; clang-tidy complains about noexcept mismatch
136};
137
138template<typename T,
139 typename = typename std::remove_reference<T>::type::pattern_tag>
141 struct type {};
142};
143
/** Exposes, as `mask`, the `binds` bitmask declared by pattern type T.
 * Any reference qualifier on T is stripped before the lookup. */
template<typename T>
struct bindings {
private:
    using pattern_type = typename std::remove_reference<T>::type;

public:
    static constexpr uint32_t mask = pattern_type::binds;
};
148
150 const uint16_t flags = ty.lanes & MatcherState::special_values_mask;
151 ty.lanes &= ~MatcherState::special_values_mask;
154 }
155 // unreachable
156 return Expr();
157}
158
164 }
165
166 const int lanes = scalar_type.lanes;
167 scalar_type.lanes = 1;
168
169 Expr e;
170 switch (scalar_type.code) {
171 case halide_type_int:
172 e = IntImm::make(scalar_type, val.u.i64);
173 break;
174 case halide_type_uint:
175 e = UIntImm::make(scalar_type, val.u.u64);
176 break;
180 break;
181 default:
182 // Unreachable
183 return Expr();
184 }
185 if (lanes > 1) {
186 e = Broadcast::make(e, lanes);
187 }
188 return e;
189}
190
191bool equal_helper(const BaseExprNode &a, const BaseExprNode &b) noexcept;
192
193// A fast version of expression equality that assumes a well-typed non-null expression tree.
195bool equal(const BaseExprNode &a, const BaseExprNode &b) noexcept {
196 // Early out
197 return (&a == &b) ||
198 ((a.type == b.type) &&
199 (a.node_type == b.node_type) &&
200 equal_helper(a, b));
201}
202
203// A pattern that matches a specific expression
205 struct pattern_tag {};
206
207 constexpr static uint32_t binds = 0;
208
209 // What is the weakest and strongest IR node this could possibly be
212 constexpr static bool canonical = true;
213
215
216 template<uint32_t bound>
217 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
218 return equal(expr, e);
219 }
220
223 return Expr(&expr);
224 }
225
226 constexpr static bool foldable = false;
227};
228
229inline std::ostream &operator<<(std::ostream &s, const SpecificExpr &e) {
230 s << Expr(&e.expr);
231 return s;
232}
233
234template<int i>
236 struct pattern_tag {};
237
238 constexpr static uint32_t binds = 1 << i;
239
242 constexpr static bool canonical = true;
243
244 template<uint32_t bound>
245 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
246 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
247 const BaseExprNode *op = &e;
248 if (op->node_type == IRNodeType::Broadcast) {
249 op = ((const Broadcast *)op)->value.get();
250 }
251 if (op->node_type != IRNodeType::IntImm) {
252 return false;
253 }
254 int64_t value = ((const IntImm *)op)->value;
255 if (bound & binds) {
257 halide_type_t type;
258 state.get_bound_const(i, val, type);
259 return (halide_type_t)e.type == type && value == val.u.i64;
260 }
261 state.set_bound_const(i, value, e.type);
262 return true;
263 }
264
265 template<uint32_t bound>
266 HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept {
267 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
268 if (bound & binds) {
270 halide_type_t type;
271 state.get_bound_const(i, val, type);
272 return type == i64_type && value == val.u.i64;
273 }
274 state.set_bound_const(i, value, i64_type);
275 return true;
276 }
277
281 halide_type_t type;
282 state.get_bound_const(i, val, type);
283 return make_const_expr(val, type);
284 }
285
286 constexpr static bool foldable = true;
287
290 state.get_bound_const(i, val, ty);
291 }
292};
293
294template<int i>
295std::ostream &operator<<(std::ostream &s, const WildConstInt<i> &c) {
296 s << "ci" << i;
297 return s;
298}
299
300template<int i>
302 struct pattern_tag {};
303
304 constexpr static uint32_t binds = 1 << i;
305
308 constexpr static bool canonical = true;
309
310 template<uint32_t bound>
311 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
312 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
313 const BaseExprNode *op = &e;
314 if (op->node_type == IRNodeType::Broadcast) {
315 op = ((const Broadcast *)op)->value.get();
316 }
317 if (op->node_type != IRNodeType::UIntImm) {
318 return false;
319 }
320 uint64_t value = ((const UIntImm *)op)->value;
321 if (bound & binds) {
323 halide_type_t type;
324 state.get_bound_const(i, val, type);
325 return (halide_type_t)e.type == type && value == val.u.u64;
326 }
327 state.set_bound_const(i, value, e.type);
328 return true;
329 }
330
334 halide_type_t type;
335 state.get_bound_const(i, val, type);
336 return make_const_expr(val, type);
337 }
338
339 constexpr static bool foldable = true;
340
343 state.get_bound_const(i, val, ty);
344 }
345};
346
347template<int i>
348std::ostream &operator<<(std::ostream &s, const WildConstUInt<i> &c) {
349 s << "cu" << i;
350 return s;
351}
352
353template<int i>
355 struct pattern_tag {};
356
357 constexpr static uint32_t binds = 1 << i;
358
361 constexpr static bool canonical = true;
362
363 template<uint32_t bound>
364 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
365 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
366 const BaseExprNode *op = &e;
367 if (op->node_type == IRNodeType::Broadcast) {
368 op = ((const Broadcast *)op)->value.get();
369 }
370 if (op->node_type != IRNodeType::FloatImm) {
371 return false;
372 }
373 double value = ((const FloatImm *)op)->value;
374 if (bound & binds) {
376 halide_type_t type;
377 state.get_bound_const(i, val, type);
378 return (halide_type_t)e.type == type && value == val.u.f64;
379 }
380 state.set_bound_const(i, value, e.type);
381 return true;
382 }
383
387 halide_type_t type;
388 state.get_bound_const(i, val, type);
389 return make_const_expr(val, type);
390 }
391
392 constexpr static bool foldable = true;
393
396 state.get_bound_const(i, val, ty);
397 }
398};
399
400template<int i>
401std::ostream &operator<<(std::ostream &s, const WildConstFloat<i> &c) {
402 s << "cf" << i;
403 return s;
404}
405
406// Matches and binds to any constant Expr. Does not support constant-folding.
407template<int i>
408struct WildConst {
409 struct pattern_tag {};
410
411 constexpr static uint32_t binds = 1 << i;
412
415 constexpr static bool canonical = true;
416
417 template<uint32_t bound>
418 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
419 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
420 const BaseExprNode *op = &e;
421 if (op->node_type == IRNodeType::Broadcast) {
422 op = ((const Broadcast *)op)->value.get();
423 }
424 switch (op->node_type) {
426 return WildConstInt<i>().template match<bound>(e, state);
428 return WildConstUInt<i>().template match<bound>(e, state);
430 return WildConstFloat<i>().template match<bound>(e, state);
431 default:
432 return false;
433 }
434 }
435
436 template<uint32_t bound>
437 HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept {
438 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
439 return WildConstInt<i>().template match<bound>(e, state);
440 }
441
445 halide_type_t type;
446 state.get_bound_const(i, val, type);
447 return make_const_expr(val, type);
448 }
449
450 constexpr static bool foldable = true;
451
454 state.get_bound_const(i, val, ty);
455 }
456};
457
458template<int i>
459std::ostream &operator<<(std::ostream &s, const WildConst<i> &c) {
460 s << "c" << i;
461 return s;
462}
463
464// Matches and binds to any Expr
465template<int i>
466struct Wild {
467 struct pattern_tag {};
468
469 constexpr static uint32_t binds = 1 << (i + 16);
470
473 constexpr static bool canonical = true;
474
475 template<uint32_t bound>
476 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
477 if (bound & binds) {
478 return equal(*state.get_binding(i), e);
479 }
480 state.set_binding(i, e);
481 return true;
482 }
483
486 return state.get_binding(i);
487 }
488
489 constexpr static bool foldable = false;
490};
491
492template<int i>
493std::ostream &operator<<(std::ostream &s, const Wild<i> &op) {
494 s << "_" << i;
495 return s;
496}
497
498// Matches a specific constant or broadcast of that constant. The
499// constant must be representable as an int64_t.
501 struct pattern_tag {};
503
504 constexpr static uint32_t binds = 0;
505
508 constexpr static bool canonical = true;
509
512 : v(v) {
513 }
514
515 template<uint32_t bound>
516 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
517 const BaseExprNode *op = &e;
518 if (e.node_type == IRNodeType::Broadcast) {
519 op = ((const Broadcast *)op)->value.get();
520 }
521 switch (op->node_type) {
523 return ((const IntImm *)op)->value == (int64_t)v;
525 return ((const UIntImm *)op)->value == (uint64_t)v;
527 return ((const FloatImm *)op)->value == (double)v;
528 default:
529 return false;
530 }
531 }
532
533 template<uint32_t bound>
534 HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept {
535 return v == val;
536 }
537
538 template<uint32_t bound>
539 HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept {
540 return v == b.v;
541 }
542
545 return make_const(type_hint, v);
546 }
547
548 constexpr static bool foldable = true;
549
552 // Assume type is already correct
553 switch (ty.code) {
554 case halide_type_int:
555 val.u.i64 = v;
556 break;
557 case halide_type_uint:
558 val.u.u64 = (uint64_t)v;
559 break;
562 val.u.f64 = (double)v;
563 break;
564 default:
565 // Unreachable
566 ;
567 }
568 }
569};
570
574
575// Convert a provided pattern, expr, or constant int into the internal
576// representation we use in the matcher trees.
577template<typename T,
578 typename = typename std::decay<T>::type::pattern_tag>
580 return t;
581}
584 return IntLiteral{x};
585}
586
587template<typename T>
589 static_assert(!std::is_same<typename std::decay<T>::type, Expr>::value || std::is_lvalue_reference<T>::value,
590 "Exprs are captured by reference by IRMatcher objects and so must be lvalues");
591}
592
594 return {*e.get()};
595}
596
597// Helpers to deref SpecificExprs to const BaseExprNode & rather than
598// passing them by value anywhere (incurring lots of refcounting)
599template<typename T,
600 // T must be a pattern node
601 typename = typename std::decay<T>::type::pattern_tag,
602 // But T may not be SpecificExpr
603 typename = typename std::enable_if<!std::is_same<typename std::decay<T>::type, SpecificExpr>::value>::type>
605 return t;
606}
607
610 return e.expr;
611}
612
613inline std::ostream &operator<<(std::ostream &s, const IntLiteral &op) {
614 s << op.v;
615 return s;
616}
617
618template<typename Op>
620
621template<typename Op>
623
624template<typename Op>
625double constant_fold_bin_op(halide_type_t &, double, double) noexcept;
626
627constexpr bool commutative(IRNodeType t) {
628 return (t == IRNodeType::Add ||
629 t == IRNodeType::Mul ||
630 t == IRNodeType::And ||
631 t == IRNodeType::Or ||
632 t == IRNodeType::Min ||
633 t == IRNodeType::Max ||
634 t == IRNodeType::EQ ||
635 t == IRNodeType::NE);
636}
637
638// Matches one of the binary operators
639template<typename Op, typename A, typename B>
640struct BinOp {
641 struct pattern_tag {};
642 A a;
643 B b;
644
646
647 constexpr static IRNodeType min_node_type = Op::_node_type;
648 constexpr static IRNodeType max_node_type = Op::_node_type;
649
650 // For commutative bin ops, we expect the weaker IR node type on
651 // the right. That is, for the rule to be canonical it must be
652 // possible that A is at least as strong as B.
653 constexpr static bool canonical =
654 A::canonical && B::canonical && (!commutative(Op::_node_type) || (A::max_node_type >= B::min_node_type));
655
656 template<uint32_t bound>
657 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
658 if (e.node_type != Op::_node_type) {
659 return false;
660 }
661 const Op &op = (const Op &)e;
662 return (a.template match<bound>(*op.a.get(), state) &&
663 b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
664 }
665
666 template<uint32_t bound, typename Op2, typename A2, typename B2>
667 HALIDE_ALWAYS_INLINE bool match(const BinOp<Op2, A2, B2> &op, MatcherState &state) const noexcept {
668 return (std::is_same<Op, Op2>::value &&
669 a.template match<bound>(unwrap(op.a), state) &&
670 b.template match<bound | bindings<A>::mask>(unwrap(op.b), state));
671 }
672
673 constexpr static bool foldable = A::foldable && B::foldable;
674
678 if (std::is_same<A, IntLiteral>::value) {
679 b.make_folded_const(val_b, ty, state);
680 if ((std::is_same<Op, And>::value && val_b.u.u64 == 0) ||
681 (std::is_same<Op, Or>::value && val_b.u.u64 == 1)) {
682 // Short circuit
683 val = val_b;
684 return;
685 }
686 const uint16_t l = ty.lanes;
687 a.make_folded_const(val_a, ty, state);
688 ty.lanes |= l; // Make sure the overflow bits are sticky
689 } else {
690 a.make_folded_const(val_a, ty, state);
691 if ((std::is_same<Op, And>::value && val_a.u.u64 == 0) ||
692 (std::is_same<Op, Or>::value && val_a.u.u64 == 1)) {
693 // Short circuit
694 val = val_a;
695 return;
696 }
697 const uint16_t l = ty.lanes;
698 b.make_folded_const(val_b, ty, state);
699 ty.lanes |= l;
700 }
701 switch (ty.code) {
702 case halide_type_int:
703 val.u.i64 = constant_fold_bin_op<Op>(ty, val_a.u.i64, val_b.u.i64);
704 break;
705 case halide_type_uint:
706 val.u.u64 = constant_fold_bin_op<Op>(ty, val_a.u.u64, val_b.u.u64);
707 break;
710 val.u.f64 = constant_fold_bin_op<Op>(ty, val_a.u.f64, val_b.u.f64);
711 break;
712 default:
713 // unreachable
714 ;
715 }
716 }
717
719 Expr make(MatcherState &state, halide_type_t type_hint) const noexcept {
720 Expr ea, eb;
721 if (std::is_same<A, IntLiteral>::value) {
722 eb = b.make(state, type_hint);
723 ea = a.make(state, eb.type());
724 } else {
725 ea = a.make(state, type_hint);
726 eb = b.make(state, ea.type());
727 }
728 // We sometimes mix vectors and scalars in the rewrite rules,
729 // so insert a broadcast if necessary.
730 if (ea.type().is_vector() && !eb.type().is_vector()) {
731 eb = Broadcast::make(eb, ea.type().lanes());
732 }
733 if (eb.type().is_vector() && !ea.type().is_vector()) {
734 ea = Broadcast::make(ea, eb.type().lanes());
735 }
736 return Op::make(std::move(ea), std::move(eb));
737 }
738};
739
740template<typename Op>
742
743template<typename Op>
745
746template<typename Op>
747uint64_t constant_fold_cmp_op(double, double) noexcept;
748
749// Matches one of the comparison operators
750template<typename Op, typename A, typename B>
751struct CmpOp {
752 struct pattern_tag {};
753 A a;
754 B b;
755
757
758 constexpr static IRNodeType min_node_type = Op::_node_type;
759 constexpr static IRNodeType max_node_type = Op::_node_type;
760 constexpr static bool canonical = (A::canonical &&
761 B::canonical &&
762 (!commutative(Op::_node_type) || A::max_node_type >= B::min_node_type) &&
763 (Op::_node_type != IRNodeType::GE) &&
764 (Op::_node_type != IRNodeType::GT));
765
766 template<uint32_t bound>
767 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
768 if (e.node_type != Op::_node_type) {
769 return false;
770 }
771 const Op &op = (const Op &)e;
772 return (a.template match<bound>(*op.a.get(), state) &&
773 b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
774 }
775
776 template<uint32_t bound, typename Op2, typename A2, typename B2>
777 HALIDE_ALWAYS_INLINE bool match(const CmpOp<Op2, A2, B2> &op, MatcherState &state) const noexcept {
778 return (std::is_same<Op, Op2>::value &&
779 a.template match<bound>(unwrap(op.a), state) &&
780 b.template match<bound | bindings<A>::mask>(unwrap(op.b), state));
781 }
782
783 constexpr static bool foldable = A::foldable && B::foldable;
784
788 // If one side is an untyped const, evaluate the other side first to get a type hint.
789 if (std::is_same<A, IntLiteral>::value) {
790 b.make_folded_const(val_b, ty, state);
791 const uint16_t l = ty.lanes;
792 a.make_folded_const(val_a, ty, state);
793 ty.lanes |= l;
794 } else {
795 a.make_folded_const(val_a, ty, state);
796 const uint16_t l = ty.lanes;
797 b.make_folded_const(val_b, ty, state);
798 ty.lanes |= l;
799 }
800 switch (ty.code) {
801 case halide_type_int:
802 val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.i64, val_b.u.i64);
803 break;
804 case halide_type_uint:
805 val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.u64, val_b.u.u64);
806 break;
809 val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.f64, val_b.u.f64);
810 break;
811 default:
812 // unreachable
813 ;
814 }
815 ty.code = halide_type_uint;
816 ty.bits = 1;
817 }
818
821 // If one side is an untyped const, evaluate the other side first to get a type hint.
822 Expr ea, eb;
823 if (std::is_same<A, IntLiteral>::value) {
824 eb = b.make(state, {});
825 ea = a.make(state, eb.type());
826 } else {
827 ea = a.make(state, {});
828 eb = b.make(state, ea.type());
829 }
830 // We sometimes mix vectors and scalars in the rewrite rules,
831 // so insert a broadcast if necessary.
832 if (ea.type().is_vector() && !eb.type().is_vector()) {
833 eb = Broadcast::make(eb, ea.type().lanes());
834 }
835 if (eb.type().is_vector() && !ea.type().is_vector()) {
836 ea = Broadcast::make(ea, eb.type().lanes());
837 }
838 return Op::make(std::move(ea), std::move(eb));
839 }
840};
841
842template<typename A, typename B>
843std::ostream &operator<<(std::ostream &s, const BinOp<Add, A, B> &op) {
844 s << "(" << op.a << " + " << op.b << ")";
845 return s;
846}
847
848template<typename A, typename B>
849std::ostream &operator<<(std::ostream &s, const BinOp<Sub, A, B> &op) {
850 s << "(" << op.a << " - " << op.b << ")";
851 return s;
852}
853
854template<typename A, typename B>
855std::ostream &operator<<(std::ostream &s, const BinOp<Mul, A, B> &op) {
856 s << "(" << op.a << " * " << op.b << ")";
857 return s;
858}
859
860template<typename A, typename B>
861std::ostream &operator<<(std::ostream &s, const BinOp<Div, A, B> &op) {
862 s << "(" << op.a << " / " << op.b << ")";
863 return s;
864}
865
866template<typename A, typename B>
867std::ostream &operator<<(std::ostream &s, const BinOp<And, A, B> &op) {
868 s << "(" << op.a << " && " << op.b << ")";
869 return s;
870}
871
872template<typename A, typename B>
873std::ostream &operator<<(std::ostream &s, const BinOp<Or, A, B> &op) {
874 s << "(" << op.a << " || " << op.b << ")";
875 return s;
876}
877
878template<typename A, typename B>
879std::ostream &operator<<(std::ostream &s, const BinOp<Min, A, B> &op) {
880 s << "min(" << op.a << ", " << op.b << ")";
881 return s;
882}
883
884template<typename A, typename B>
885std::ostream &operator<<(std::ostream &s, const BinOp<Max, A, B> &op) {
886 s << "max(" << op.a << ", " << op.b << ")";
887 return s;
888}
889
890template<typename A, typename B>
891std::ostream &operator<<(std::ostream &s, const CmpOp<LE, A, B> &op) {
892 s << "(" << op.a << " <= " << op.b << ")";
893 return s;
894}
895
896template<typename A, typename B>
897std::ostream &operator<<(std::ostream &s, const CmpOp<LT, A, B> &op) {
898 s << "(" << op.a << " < " << op.b << ")";
899 return s;
900}
901
902template<typename A, typename B>
903std::ostream &operator<<(std::ostream &s, const CmpOp<GE, A, B> &op) {
904 s << "(" << op.a << " >= " << op.b << ")";
905 return s;
906}
907
908template<typename A, typename B>
909std::ostream &operator<<(std::ostream &s, const CmpOp<GT, A, B> &op) {
910 s << "(" << op.a << " > " << op.b << ")";
911 return s;
912}
913
914template<typename A, typename B>
915std::ostream &operator<<(std::ostream &s, const CmpOp<EQ, A, B> &op) {
916 s << "(" << op.a << " == " << op.b << ")";
917 return s;
918}
919
920template<typename A, typename B>
921std::ostream &operator<<(std::ostream &s, const CmpOp<NE, A, B> &op) {
922 s << "(" << op.a << " != " << op.b << ")";
923 return s;
924}
925
926template<typename A, typename B>
927std::ostream &operator<<(std::ostream &s, const BinOp<Mod, A, B> &op) {
928 s << "(" << op.a << " % " << op.b << ")";
929 return s;
930}
931
932template<typename A, typename B>
933HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp<Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
936 return {pattern_arg(a), pattern_arg(b)};
937}
938
939template<typename A, typename B>
940HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b)) {
943 return IRMatcher::operator+(a, b);
944}
945
946template<>
948 t.lanes |= ((t.bits >= 32) && add_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
949 int dead_bits = 64 - t.bits;
950 // Drop the high bits then sign-extend them back
951 return int64_t((uint64_t(a) + uint64_t(b)) << dead_bits) >> dead_bits;
952}
953
954template<>
956 uint64_t ones = (uint64_t)(-1);
957 return (a + b) & (ones >> (64 - t.bits));
958}
959
960template<>
961HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Add>(halide_type_t &t, double a, double b) noexcept {
962 return a + b;
963}
964
965template<typename A, typename B>
966HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp<Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
969 return {pattern_arg(a), pattern_arg(b)};
970}
971
972template<typename A, typename B>
973HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b)) {
976 return IRMatcher::operator-(a, b);
977}
978
979template<>
981 t.lanes |= ((t.bits >= 32) && sub_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
982 // Drop the high bits then sign-extend them back
983 int dead_bits = 64 - t.bits;
984 return int64_t((uint64_t(a) - uint64_t(b)) << dead_bits) >> dead_bits;
985}
986
987template<>
989 uint64_t ones = (uint64_t)(-1);
990 return (a - b) & (ones >> (64 - t.bits));
991}
992
993template<>
994HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Sub>(halide_type_t &t, double a, double b) noexcept {
995 return a - b;
996}
997
998template<typename A, typename B>
999HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp<Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1002 return {pattern_arg(a), pattern_arg(b)};
1003}
1004
1005template<typename A, typename B>
1006HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b)) {
1009 return IRMatcher::operator*(a, b);
1010}
1011
1012template<>
1014 t.lanes |= ((t.bits >= 32) && mul_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
1015 int dead_bits = 64 - t.bits;
1016 // Drop the high bits then sign-extend them back
1017 return int64_t((uint64_t(a) * uint64_t(b)) << dead_bits) >> dead_bits;
1018}
1019
1020template<>
1022 uint64_t ones = (uint64_t)(-1);
1023 return (a * b) & (ones >> (64 - t.bits));
1024}
1025
1026template<>
1027HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Mul>(halide_type_t &t, double a, double b) noexcept {
1028 return a * b;
1029}
1030
1031template<typename A, typename B>
1032HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp<Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1035 return {pattern_arg(a), pattern_arg(b)};
1036}
1037
1038template<typename A, typename B>
1039HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b)) {
1040 return IRMatcher::operator/(a, b);
1041}
1042
1043template<>
1047
1048template<>
1052
1053template<>
1054HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Div>(halide_type_t &t, double a, double b) noexcept {
1055 return div_imp(a, b);
1056}
1057
1058template<typename A, typename B>
1059HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp<Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1062 return {pattern_arg(a), pattern_arg(b)};
1063}
1064
1065template<typename A, typename B>
1066HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b)) {
1069 return IRMatcher::operator%(a, b);
1070}
1071
1072template<>
1076
1077template<>
1081
1082template<>
1083HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Mod>(halide_type_t &t, double a, double b) noexcept {
1084 return mod_imp(a, b);
1085}
1086
1087template<typename A, typename B>
1088HALIDE_ALWAYS_INLINE auto min(A &&a, B &&b) noexcept -> BinOp<Min, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1091 return {pattern_arg(a), pattern_arg(b)};
1092}
1093
1094template<>
1096 return std::min(a, b);
1097}
1098
1099template<>
1101 return std::min(a, b);
1102}
1103
1104template<>
1105HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Min>(halide_type_t &t, double a, double b) noexcept {
1106 return std::min(a, b);
1107}
1108
1109template<typename A, typename B>
1110HALIDE_ALWAYS_INLINE auto max(A &&a, B &&b) noexcept -> BinOp<Max, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1113 return {pattern_arg(std::forward<A>(a)), pattern_arg(std::forward<B>(b))};
1114}
1115
1116template<>
1118 return std::max(a, b);
1119}
1120
1121template<>
1123 return std::max(a, b);
1124}
1125
1126template<>
1127HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Max>(halide_type_t &t, double a, double b) noexcept {
1128 return std::max(a, b);
1129}
1130
1131template<typename A, typename B>
1132HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp<LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1133 return {pattern_arg(a), pattern_arg(b)};
1134}
1135
1136template<typename A, typename B>
1137HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b)) {
1138 return IRMatcher::operator<(a, b);
1139}
1140
1141template<>
1145
1146template<>
1150
1151template<>
1153 return a < b;
1154}
1155
1156template<typename A, typename B>
1157HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp<GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1158 return {pattern_arg(a), pattern_arg(b)};
1159}
1160
1161template<typename A, typename B>
1162HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b)) {
1163 return IRMatcher::operator>(a, b);
1164}
1165
1166template<>
1170
1171template<>
1175
1176template<>
1178 return a > b;
1179}
1180
1181template<typename A, typename B>
1182HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp<LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1183 return {pattern_arg(a), pattern_arg(b)};
1184}
1185
1186template<typename A, typename B>
1187HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b)) {
1188 return IRMatcher::operator<=(a, b);
1189}
1190
1191template<>
1193 return a <= b;
1194}
1195
1196template<>
1200
1201template<>
1203 return a <= b;
1204}
1205
1206template<typename A, typename B>
1207HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp<GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1208 return {pattern_arg(a), pattern_arg(b)};
1209}
1210
1211template<typename A, typename B>
1212HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b)) {
1213 return IRMatcher::operator>=(a, b);
1214}
1215
1216template<>
1218 return a >= b;
1219}
1220
1221template<>
1225
1226template<>
1228 return a >= b;
1229}
1230
1231template<typename A, typename B>
1232HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp<EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1233 return {pattern_arg(a), pattern_arg(b)};
1234}
1235
1236template<typename A, typename B>
1237HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b)) {
1238 return IRMatcher::operator==(a, b);
1239}
1240
1241template<>
1243 return a == b;
1244}
1245
1246template<>
1250
1251template<>
1253 return a == b;
1254}
1255
1256template<typename A, typename B>
1257HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp<NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1258 return {pattern_arg(a), pattern_arg(b)};
1259}
1260
1261template<typename A, typename B>
1262HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b)) {
1263 return IRMatcher::operator!=(a, b);
1264}
1265
1266template<>
1268 return a != b;
1269}
1270
1271template<>
1275
1276template<>
1278 return a != b;
1279}
1280
1281template<typename A, typename B>
1282HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp<Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1283 return {pattern_arg(a), pattern_arg(b)};
1284}
1285
1286template<typename A, typename B>
1287HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b)) {
1288 return IRMatcher::operator||(a, b);
1289}
1290
1291template<>
1293 return (a | b) & 1;
1294}
1295
1296template<>
1298 return (a | b) & 1;
1299}
1300
1301template<>
1302HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Or>(halide_type_t &t, double a, double b) noexcept {
1303 // Unreachable, as it would be a type mismatch.
1304 return 0;
1305}
1306
1307template<typename A, typename B>
1308HALIDE_ALWAYS_INLINE auto operator&&(A &&a, B &&b) noexcept -> BinOp<And, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1309 return {pattern_arg(a), pattern_arg(b)};
1310}
1311
1312template<typename A, typename B>
1313HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b)) {
1314 return IRMatcher::operator&&(a, b);
1315}
1316
1317template<>
1319 return a & b & 1;
1320}
1321
1322template<>
1326
1327template<>
1328HALIDE_ALWAYS_INLINE double constant_fold_bin_op<And>(halide_type_t &t, double a, double b) noexcept {
1329 // Unreachable
1330 return 0;
1331}
1332
/** Base case: OR-ing together no values yields zero. */
constexpr inline uint32_t bitwise_or_reduce() {
    return 0u;
}

/** Recursive case: bitwise-OR the leading value with the reduction of all
 * the remaining ones. Usable in constant expressions. */
template<typename... Tail>
constexpr uint32_t bitwise_or_reduce(uint32_t head, Tail... tail) {
    return bitwise_or_reduce(tail...) | head;
}
1341
/** Base case: the conjunction of no values is true. */
constexpr inline bool and_reduce() {
    return true;
}

/** Recursive case: true only if the leading value and every remaining
 * value are true. Usable in constant expressions. */
template<typename... Tail>
constexpr bool and_reduce(bool head, Tail... tail) {
    return head ? and_reduce(tail...) : false;
}
1350
// Minimum of two ints, usable in constant expressions.
// TODO: this can be replaced with std::min() once we require C++14 or later
constexpr int const_min(int a, int b) {
    return (b < a) ? b : a;
}
1355
1356template<typename... Args>
1357struct Intrin {
1358 struct pattern_tag {};
1360 std::tuple<Args...> args;
1361 // The type of the output of the intrinsic node.
1362 // Only necessary in cases where it can't be inferred
1363 // from the input types (e.g. saturating_cast).
1365
1367
1370 constexpr static bool canonical = and_reduce((Args::canonical)...);
1371
1372 template<int i,
1373 uint32_t bound,
1374 typename = typename std::enable_if<(i < sizeof...(Args))>::type>
1375 HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept {
1376 using T = decltype(std::get<i>(args));
1377 return (std::get<i>(args).template match<bound>(*c.args[i].get(), state) &&
1379 }
1380
1381 template<int i, uint32_t binds>
1382 HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept {
1383 return true;
1384 }
1385
1386 template<uint32_t bound>
1387 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1388 if (e.node_type != IRNodeType::Call) {
1389 return false;
1390 }
1391 const Call &c = (const Call &)e;
1392 return (c.is_intrinsic(intrin) &&
1393 ((optional_type_hint == Type()) || optional_type_hint == e.type) &&
1394 match_args<0, bound>(0, c, state));
1395 }
1396
1397 template<int i,
1398 typename = typename std::enable_if<(i < sizeof...(Args))>::type>
1399 HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const {
1401 if (i + 1 < sizeof...(Args)) {
1402 s << ", ";
1403 }
1404 print_args<i + 1>(0, s);
1405 }
1406
1407 template<int i>
1408 HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const {
1409 }
1410
1412 void print_args(std::ostream &s) const {
1413 print_args<0>(0, s);
1414 }
1415
1418 Expr arg0 = std::get<0>(args).make(state, type_hint);
1419 if (intrin == Call::likely) {
1420 return likely(arg0);
1421 } else if (intrin == Call::likely_if_innermost) {
1422 return likely_if_innermost(arg0);
1423 } else if (intrin == Call::abs) {
1424 return abs(arg0);
1425 } else if (intrin == Call::saturating_cast) {
1427 }
1428
1429 Expr arg1 = std::get<const_min(1, sizeof...(Args) - 1)>(args).make(state, type_hint);
1430 if (intrin == Call::absd) {
1431 return absd(arg0, arg1);
1432 } else if (intrin == Call::widen_right_add) {
1433 return widen_right_add(arg0, arg1);
1434 } else if (intrin == Call::widen_right_mul) {
1435 return widen_right_mul(arg0, arg1);
1436 } else if (intrin == Call::widen_right_sub) {
1437 return widen_right_sub(arg0, arg1);
1438 } else if (intrin == Call::widening_add) {
1439 return widening_add(arg0, arg1);
1440 } else if (intrin == Call::widening_sub) {
1441 return widening_sub(arg0, arg1);
1442 } else if (intrin == Call::widening_mul) {
1443 return widening_mul(arg0, arg1);
1444 } else if (intrin == Call::saturating_add) {
1445 return saturating_add(arg0, arg1);
1446 } else if (intrin == Call::saturating_sub) {
1447 return saturating_sub(arg0, arg1);
1448 } else if (intrin == Call::halving_add) {
1449 return halving_add(arg0, arg1);
1450 } else if (intrin == Call::halving_sub) {
1451 return halving_sub(arg0, arg1);
1452 } else if (intrin == Call::rounding_halving_add) {
1454 } else if (intrin == Call::shift_left) {
1455 return arg0 << arg1;
1456 } else if (intrin == Call::shift_right) {
1457 return arg0 >> arg1;
1458 } else if (intrin == Call::rounding_shift_left) {
1459 return rounding_shift_left(arg0, arg1);
1460 } else if (intrin == Call::rounding_shift_right) {
1462 }
1463
1464 Expr arg2 = std::get<const_min(2, sizeof...(Args) - 1)>(args).make(state, type_hint);
1466 return mul_shift_right(arg0, arg1, arg2);
1467 } else if (intrin == Call::rounding_mul_shift_right) {
1469 }
1470
1471 internal_error << "Unhandled intrinsic in IRMatcher: " << intrin;
1472 return Expr();
1473 }
1474
1475 constexpr static bool foldable = true;
1476
1479 // Assuming the args have the same type as the intrinsic is incorrect in
1480 // general. But for the intrinsics we can fold (just shifts), the LHS
1481 // has the same type as the intrinsic, and we can always treat the RHS
1482 // as a signed int, because we're using 64 bits for it.
1483 std::get<0>(args).make_folded_const(val, ty, state);
1486 // We can just directly get the second arg here, because we only want to
1487 // instantiate this method for shifts, which have two args.
1488 std::get<1>(args).make_folded_const(arg1, signed_ty, state);
1489
1490 if (intrin == Call::shift_left) {
1491 if (arg1.u.i64 < 0) {
1492 if (ty.code == halide_type_int) {
1493 // Arithmetic shift
1494 val.u.i64 >>= -arg1.u.i64;
1495 } else {
1496 // Logical shift
1497 val.u.u64 >>= -arg1.u.i64;
1498 }
1499 } else {
1500 val.u.u64 <<= arg1.u.i64;
1501 }
1502 } else if (intrin == Call::shift_right) {
1503 if (arg1.u.i64 > 0) {
1504 if (ty.code == halide_type_int) {
1505 // Arithmetic shift
1506 val.u.i64 >>= arg1.u.i64;
1507 } else {
1508 // Logical shift
1509 val.u.u64 >>= arg1.u.i64;
1510 }
1511 } else {
1512 val.u.u64 <<= -arg1.u.i64;
1513 }
1514 } else {
1515 internal_error << "Folding not implemented for intrinsic: " << intrin;
1516 }
1517 }
1518
1523};
1524
1525template<typename... Args>
1526std::ostream &operator<<(std::ostream &s, const Intrin<Args...> &op) {
1527 s << op.intrin << "(";
1528 op.print_args(s);
1529 s << ")";
1530 return s;
1531}
1532
1533template<typename... Args>
1534HALIDE_ALWAYS_INLINE auto intrin(Call::IntrinsicOp intrinsic_op, Args... args) noexcept -> Intrin<decltype(pattern_arg(args))...> {
1535 return {intrinsic_op, pattern_arg(args)...};
1536}
1537
1538template<typename A, typename B>
1539auto widen_right_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1541}
1542template<typename A, typename B>
1543auto widen_right_mul(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1545}
1546template<typename A, typename B>
1547auto widen_right_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1549}
1550
1551template<typename A, typename B>
1552auto widening_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1554}
1555template<typename A, typename B>
1556auto widening_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1558}
1559template<typename A, typename B>
1560auto widening_mul(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1562}
1563template<typename A, typename B>
1564auto saturating_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1566}
1567template<typename A, typename B>
1568auto saturating_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1570}
1571template<typename A>
1572auto saturating_cast(const Type &t, A &&a) noexcept -> Intrin<decltype(pattern_arg(a))> {
1575 return p;
1576}
1577template<typename A, typename B>
1578auto halving_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1579 return {Call::halving_add, pattern_arg(a), pattern_arg(b)};
1580}
1581template<typename A, typename B>
1582auto halving_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1583 return {Call::halving_sub, pattern_arg(a), pattern_arg(b)};
1584}
1585template<typename A, typename B>
1586auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1588}
1589template<typename A, typename B>
1590auto shift_left(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1591 return {Call::shift_left, pattern_arg(a), pattern_arg(b)};
1592}
1593template<typename A, typename B>
1594auto shift_right(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1595 return {Call::shift_right, pattern_arg(a), pattern_arg(b)};
1596}
1597template<typename A, typename B>
1598auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1600}
1601template<typename A, typename B>
1602auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1604}
1605template<typename A, typename B, typename C>
1606auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1608}
1609template<typename A, typename B, typename C>
1610auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1612}
1613
1614template<typename A>
1615struct NotOp {
1616 struct pattern_tag {};
1617 A a;
1618
1620
1623 constexpr static bool canonical = A::canonical;
1624
1625 template<uint32_t bound>
1626 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1627 if (e.node_type != IRNodeType::Not) {
1628 return false;
1629 }
1630 const Not &op = (const Not &)e;
1631 return (a.template match<bound>(*op.a.get(), state));
1632 }
1633
1634 template<uint32_t bound, typename A2>
1635 HALIDE_ALWAYS_INLINE bool match(const NotOp<A2> &op, MatcherState &state) const noexcept {
1636 return a.template match<bound>(unwrap(op.a), state);
1637 }
1638
1641 return Not::make(a.make(state, type_hint));
1642 }
1643
1644 constexpr static bool foldable = A::foldable;
1645
1646 template<typename A1 = A>
1648 a.make_folded_const(val, ty, state);
1649 val.u.u64 = ~val.u.u64;
1650 val.u.u64 &= 1;
1651 }
1652};
1653
1654template<typename A>
1655HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp<decltype(pattern_arg(a))> {
1657 return {pattern_arg(a)};
1658}
1659
1660template<typename A>
1665
1666template<typename A>
1667inline std::ostream &operator<<(std::ostream &s, const NotOp<A> &op) {
1668 s << "!(" << op.a << ")";
1669 return s;
1670}
1671
1672template<typename C, typename T, typename F>
1673struct SelectOp {
1674 struct pattern_tag {};
1676 T t;
1678
1680
1683
1684 constexpr static bool canonical = C::canonical && T::canonical && F::canonical;
1685
1686 template<uint32_t bound>
1687 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1688 if (e.node_type != Select::_node_type) {
1689 return false;
1690 }
1691 const Select &op = (const Select &)e;
1692 return (c.template match<bound>(*op.condition.get(), state) &&
1693 t.template match<bound | bindings<C>::mask>(*op.true_value.get(), state) &&
1694 f.template match<bound | bindings<C>::mask | bindings<T>::mask>(*op.false_value.get(), state));
1695 }
1696 template<uint32_t bound, typename C2, typename T2, typename F2>
1697 HALIDE_ALWAYS_INLINE bool match(const SelectOp<C2, T2, F2> &instance, MatcherState &state) const noexcept {
1698 return (c.template match<bound>(unwrap(instance.c), state) &&
1699 t.template match<bound | bindings<C>::mask>(unwrap(instance.t), state) &&
1700 f.template match<bound | bindings<C>::mask | bindings<T>::mask>(unwrap(instance.f), state));
1701 }
1702
1705 return Select::make(c.make(state, {}), t.make(state, type_hint), f.make(state, type_hint));
1706 }
1707
1708 constexpr static bool foldable = C::foldable && T::foldable && F::foldable;
1709
1710 template<typename C1 = C>
1714 c.make_folded_const(c_val, c_ty, state);
1715 if ((c_val.u.u64 & 1) == 1) {
1716 t.make_folded_const(val, ty, state);
1717 } else {
1718 f.make_folded_const(val, ty, state);
1719 }
1721 }
1722};
1723
1724template<typename C, typename T, typename F>
1725std::ostream &operator<<(std::ostream &s, const SelectOp<C, T, F> &op) {
1726 s << "select(" << op.c << ", " << op.t << ", " << op.f << ")";
1727 return s;
1728}
1729
1730template<typename C, typename T, typename F>
1731HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp<decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))> {
1735 return {pattern_arg(c), pattern_arg(t), pattern_arg(f)};
1736}
1737
1738template<typename A, typename B>
1740 struct pattern_tag {};
1741 A a;
1743
1745
1748
1749 constexpr static bool canonical = A::canonical && B::canonical;
1750
1751 template<uint32_t bound>
1752 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1753 if (e.node_type == Broadcast::_node_type) {
1754 const Broadcast &op = (const Broadcast &)e;
1755 if (a.template match<bound>(*op.value.get(), state) &&
1756 lanes.template match<bound>(op.lanes, state)) {
1757 return true;
1758 }
1759 }
1760 return false;
1761 }
1762
1763 template<uint32_t bound, typename A2, typename B2>
1764 HALIDE_ALWAYS_INLINE bool match(const BroadcastOp<A2, B2> &op, MatcherState &state) const noexcept {
1765 return (a.template match<bound>(unwrap(op.a), state) &&
1766 lanes.template match<bound | bindings<A>::mask>(unwrap(op.lanes), state));
1767 }
1768
1773 lanes.make_folded_const(lanes_val, ty, state);
1774 int32_t l = (int32_t)lanes_val.u.i64;
1775 type_hint.lanes /= l;
1776 Expr val = a.make(state, type_hint);
1777 if (l == 1) {
1778 return val;
1779 } else {
1780 return Broadcast::make(std::move(val), l);
1781 }
1782 }
1783
1784 constexpr static bool foldable = false;
1785
1786 template<typename A1 = A>
1790 lanes.make_folded_const(lanes_val, lanes_ty, state);
1791 uint16_t l = (uint16_t)lanes_val.u.i64;
1792 a.make_folded_const(val, ty, state);
1793 ty.lanes = l | (ty.lanes & MatcherState::special_values_mask);
1794 }
1795};
1796
1797template<typename A, typename B>
1798inline std::ostream &operator<<(std::ostream &s, const BroadcastOp<A, B> &op) {
1799 s << "broadcast(" << op.a << ", " << op.lanes << ")";
1800 return s;
1801}
1802
1803template<typename A, typename B>
1804HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes))> {
1806 return {pattern_arg(a), pattern_arg(lanes)};
1807}
1808
1809template<typename A, typename B, typename C>
1810struct RampOp {
1811 struct pattern_tag {};
1812 A a;
1813 B b;
1815
1817
1820
1821 constexpr static bool canonical = A::canonical && B::canonical && C::canonical;
1822
1823 template<uint32_t bound>
1824 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1825 if (e.node_type != Ramp::_node_type) {
1826 return false;
1827 }
1828 const Ramp &op = (const Ramp &)e;
1829 if (a.template match<bound>(*op.base.get(), state) &&
1830 b.template match<bound | bindings<A>::mask>(*op.stride.get(), state) &&
1831 lanes.template match<bound | bindings<A>::mask | bindings<B>::mask>(op.lanes, state)) {
1832 return true;
1833 } else {
1834 return false;
1835 }
1836 }
1837
1838 template<uint32_t bound, typename A2, typename B2, typename C2>
1839 HALIDE_ALWAYS_INLINE bool match(const RampOp<A2, B2, C2> &op, MatcherState &state) const noexcept {
1840 return (a.template match<bound>(unwrap(op.a), state) &&
1841 b.template match<bound | bindings<A>::mask>(unwrap(op.b), state) &&
1842 lanes.template match<bound | bindings<A>::mask | bindings<B>::mask>(unwrap(op.lanes), state));
1843 }
1844
1849 lanes.make_folded_const(lanes_val, ty, state);
1850 int32_t l = (int32_t)lanes_val.u.i64;
1851 type_hint.lanes /= l;
1852 Expr ea, eb;
1853 eb = b.make(state, type_hint);
1854 ea = a.make(state, eb.type());
1855 return Ramp::make(ea, eb, l);
1856 }
1857
1858 constexpr static bool foldable = false;
1859};
1860
1861template<typename A, typename B, typename C>
1862std::ostream &operator<<(std::ostream &s, const RampOp<A, B, C> &op) {
1863 s << "ramp(" << op.a << ", " << op.b << ", " << op.lanes << ")";
1864 return s;
1865}
1866
1867template<typename A, typename B, typename C>
1868HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1872 return {pattern_arg(a), pattern_arg(b), pattern_arg(c)};
1873}
1874
1875template<typename A, typename B, VectorReduce::Operator reduce_op>
1877 struct pattern_tag {};
1878 A a;
1880
1882
1885 constexpr static bool canonical = A::canonical;
1886
1887 template<uint32_t bound>
1888 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1889 if (e.node_type == VectorReduce::_node_type) {
1890 const VectorReduce &op = (const VectorReduce &)e;
1891 if (op.op == reduce_op &&
1892 a.template match<bound>(*op.value.get(), state) &&
1893 lanes.template match<bound | bindings<A>::mask>(op.type.lanes(), state)) {
1894 return true;
1895 }
1896 }
1897 return false;
1898 }
1899
1900 template<uint32_t bound, typename A2, typename B2, VectorReduce::Operator reduce_op_2>
1902 return (reduce_op == reduce_op_2 &&
1903 a.template match<bound>(unwrap(op.a), state) &&
1904 lanes.template match<bound | bindings<A>::mask>(unwrap(op.lanes), state));
1905 }
1906
1911 lanes.make_folded_const(lanes_val, ty, state);
1912 int l = (int)lanes_val.u.i64;
1913 return VectorReduce::make(reduce_op, a.make(state, type_hint), l);
1914 }
1915
1916 constexpr static bool foldable = false;
1917};
1918
1919template<typename A, typename B, VectorReduce::Operator reduce_op>
1920inline std::ostream &operator<<(std::ostream &s, const VectorReduceOp<A, B, reduce_op> &op) {
1921 s << "vector_reduce(" << reduce_op << ", " << op.a << ", " << op.lanes << ")";
1922 return s;
1923}
1924
1925template<typename A, typename B>
1926HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add> {
1928 return {pattern_arg(a), pattern_arg(lanes)};
1929}
1930
1931template<typename A, typename B>
1932HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min> {
1934 return {pattern_arg(a), pattern_arg(lanes)};
1935}
1936
1937template<typename A, typename B>
1938HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max> {
1940 return {pattern_arg(a), pattern_arg(lanes)};
1941}
1942
1943template<typename A, typename B>
1944HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And> {
1946 return {pattern_arg(a), pattern_arg(lanes)};
1947}
1948
1949template<typename A, typename B>
1950HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or> {
1952 return {pattern_arg(a), pattern_arg(lanes)};
1953}
1954
1955template<typename A>
1956struct NegateOp {
1957 struct pattern_tag {};
1958 A a;
1959
1961
1964
1965 constexpr static bool canonical = A::canonical;
1966
1967 template<uint32_t bound>
1968 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1969 if (e.node_type != Sub::_node_type) {
1970 return false;
1971 }
1972 const Sub &op = (const Sub &)e;
1973 return (a.template match<bound>(*op.b.get(), state) &&
1974 is_const_zero(op.a));
1975 }
1976
1977 template<uint32_t bound, typename A2>
1978 HALIDE_ALWAYS_INLINE bool match(NegateOp<A2> &&p, MatcherState &state) const noexcept {
1979 return a.template match<bound>(unwrap(p.a), state);
1980 }
1981
1984 Expr ea = a.make(state, type_hint);
1985 Expr z = make_zero(ea.type());
1986 return Sub::make(std::move(z), std::move(ea));
1987 }
1988
1989 constexpr static bool foldable = A::foldable;
1990
1991 template<typename A1 = A>
1993 a.make_folded_const(val, ty, state);
1994 int dead_bits = 64 - ty.bits;
1995 switch (ty.code) {
1996 case halide_type_int:
1997 if (ty.bits >= 32 && val.u.u64 && (val.u.u64 << (65 - ty.bits)) == 0) {
1998 // Trying to negate the most negative signed int for a no-overflow type.
2000 } else {
2001 // Negate, drop the high bits, and then sign-extend them back
2002 val.u.i64 = int64_t(uint64_t(-val.u.i64) << dead_bits) >> dead_bits;
2003 }
2004 break;
2005 case halide_type_uint:
2006 val.u.u64 = ((-val.u.u64) << dead_bits) >> dead_bits;
2007 break;
2008 case halide_type_float:
2009 case halide_type_bfloat:
2010 val.u.f64 = -val.u.f64;
2011 break;
2012 default:
2013 // unreachable
2014 ;
2015 }
2016 }
2017};
2018
2019template<typename A>
2020std::ostream &operator<<(std::ostream &s, const NegateOp<A> &op) {
2021 s << "-" << op.a;
2022 return s;
2023}
2024
2025template<typename A>
2026HALIDE_ALWAYS_INLINE auto operator-(A &&a) noexcept -> NegateOp<decltype(pattern_arg(a))> {
2028 return {pattern_arg(a)};
2029}
2030
2031template<typename A>
2036
2037template<typename A>
2038struct CastOp {
2039 struct pattern_tag {};
2041 A a;
2042
2044
2047 constexpr static bool canonical = A::canonical;
2048
2049 template<uint32_t bound>
2050 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2051 if (e.node_type != Cast::_node_type) {
2052 return false;
2053 }
2054 const Cast &op = (const Cast &)e;
2055 return (e.type == t &&
2056 a.template match<bound>(*op.value.get(), state));
2057 }
2058 template<uint32_t bound, typename A2>
2059 HALIDE_ALWAYS_INLINE bool match(const CastOp<A2> &op, MatcherState &state) const noexcept {
2060 return t == op.t && a.template match<bound>(unwrap(op.a), state);
2061 }
2062
2065 return cast(t, a.make(state, {}));
2066 }
2067
2068 constexpr static bool foldable = false;
2069};
2070
2071template<typename A>
2072std::ostream &operator<<(std::ostream &s, const CastOp<A> &op) {
2073 s << "cast(" << op.t << ", " << op.a << ")";
2074 return s;
2075}
2076
2077template<typename A>
2078HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp<decltype(pattern_arg(a))> {
2080 return {t, pattern_arg(a)};
2081}
2082
2083template<typename Vec, typename Base, typename Stride, typename Lanes>
2084struct SliceOp {
2085 struct pattern_tag {};
2090
2091 static constexpr uint32_t binds = Vec::binds | Base::binds | Stride::binds | Lanes::binds;
2092
2095 constexpr static bool canonical = Vec::canonical && Base::canonical && Stride::canonical && Lanes::canonical;
2096
2097 template<uint32_t bound>
2098 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2099 if (e.node_type != IRNodeType::Shuffle) {
2100 return false;
2101 }
2102 const Shuffle &v = (const Shuffle &)e;
2103 return v.vectors.size() == 1 &&
2104 v.is_slice() &&
2105 vec.template match<bound>(*v.vectors[0].get(), state) &&
2106 base.template match<bound | bindings<Vec>::mask>(v.slice_begin(), state) &&
2107 stride.template match<bound | bindings<Vec>::mask | bindings<Base>::mask>(v.slice_stride(), state) &&
2109 }
2110
2115 base.make_folded_const(base_val, ty, state);
2116 int b = (int)base_val.u.i64;
2117 stride.make_folded_const(stride_val, ty, state);
2118 int s = (int)stride_val.u.i64;
2119 lanes.make_folded_const(lanes_val, ty, state);
2120 int l = (int)lanes_val.u.i64;
2121 return Shuffle::make_slice(vec.make(state, type_hint), b, s, l);
2122 }
2123
2124 constexpr static bool foldable = false;
2125
2128 : vec(v), base(b), stride(s), lanes(l) {
2129 static_assert(Base::foldable, "Base of slice should consist only of operations that constant-fold");
2130 static_assert(Stride::foldable, "Stride of slice should consist only of operations that constant-fold");
2131 static_assert(Lanes::foldable, "Lanes of slice should consist only of operations that constant-fold");
2132 }
2133};
2134
2135template<typename Vec, typename Base, typename Stride, typename Lanes>
2136std::ostream &operator<<(std::ostream &s, const SliceOp<Vec, Base, Stride, Lanes> &op) {
2137 s << "slice(" << op.vec << ", " << op.base << ", " << op.stride << ", " << op.lanes << ")";
2138 return s;
2139}
2140
2141template<typename Vec, typename Base, typename Stride, typename Lanes>
2142HALIDE_ALWAYS_INLINE auto slice(Vec vec, Base base, Stride stride, Lanes lanes) noexcept
2143 -> SliceOp<decltype(pattern_arg(vec)), decltype(pattern_arg(base)), decltype(pattern_arg(stride)), decltype(pattern_arg(lanes))> {
2144 return {pattern_arg(vec), pattern_arg(base), pattern_arg(stride), pattern_arg(lanes)};
2145}
2146
2147template<typename A>
2148struct Fold {
2149 struct pattern_tag {};
2150 A a;
2151
2153
2156 constexpr static bool canonical = true;
2157
2162 a.make_folded_const(c, ty, state);
2163
2164 // The result of the fold may have an underspecified type
2165 // (e.g. because it's from an int literal). Make the type code
2166 // and bits match the required type, if there is one (we can
2167 // tell from the bits field).
2168 if (type_hint.bits) {
2169 if (((int)ty.code == (int)halide_type_int) &&
2170 ((int)type_hint.code == (int)halide_type_float)) {
2171 int64_t x = c.u.i64;
2172 c.u.f64 = (double)x;
2173 }
2174 ty.code = type_hint.code;
2175 ty.bits = type_hint.bits;
2176 }
2177
2178 Expr e = make_const_expr(c, ty);
2179 return e;
2180 }
2181
2182 constexpr static bool foldable = A::foldable;
2183
2184 template<typename A1 = A>
2186 a.make_folded_const(val, ty, state);
2187 }
2188};
2189
2190template<typename A>
2191HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold<decltype(pattern_arg(a))> {
2193 return {pattern_arg(a)};
2194}
2195
2196template<typename A>
2197std::ostream &operator<<(std::ostream &s, const Fold<A> &op) {
2198 s << "fold(" << op.a << ")";
2199 return s;
2200}
2201
2202template<typename A>
2204 struct pattern_tag {};
2205 A a;
2206
2208
2209 // This rule is a predicate, so it always evaluates to a boolean,
2210 // which has IRNodeType UIntImm
2213 constexpr static bool canonical = true;
2214
2215 constexpr static bool foldable = A::foldable;
2216
2217 template<typename A1 = A>
2219 a.make_folded_const(val, ty, state);
2220 ty.code = halide_type_uint;
2221 ty.bits = 64;
2222 val.u.u64 = (ty.lanes & MatcherState::special_values_mask) != 0;
2223 ty.lanes = 1;
2224 }
2225};
2226
2227template<typename A>
2228HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows<decltype(pattern_arg(a))> {
2230 return {pattern_arg(a)};
2231}
2232
2233template<typename A>
2234std::ostream &operator<<(std::ostream &s, const Overflows<A> &op) {
2235 s << "overflows(" << op.a << ")";
2236 return s;
2237}
2238
2239struct Overflow {
2240 struct pattern_tag {};
2241
2242 constexpr static uint32_t binds = 0;
2243
2244 // Overflow is an intrinsic, represented as a Call node
2247 constexpr static bool canonical = true;
2248
2249 template<uint32_t bound>
2250 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2251 if (e.node_type != Call::_node_type) {
2252 return false;
2253 }
2254 const Call &op = (const Call &)e;
2256 }
2257
2263
2264 constexpr static bool foldable = true;
2265
2268 val.u.u64 = 0;
2270 }
2271};
2272
2273inline std::ostream &operator<<(std::ostream &s, const Overflow &op) {
2274 s << "overflow()";
2275 return s;
2276}
2277
2278template<typename A>
2279struct IsConst {
2280 struct pattern_tag {};
2281
2283
2284 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2287 constexpr static bool canonical = true;
2288
2289 A a;
2292
2293 constexpr static bool foldable = true;
2294
2295 template<typename A1 = A>
2297 Expr e = a.make(state, {});
2298 ty.code = halide_type_uint;
2299 ty.bits = 64;
2300 ty.lanes = 1;
2301 if (check_v) {
2302 val.u.u64 = ::Halide::Internal::is_const(e, v) ? 1 : 0;
2303 } else {
2304 val.u.u64 = ::Halide::Internal::is_const(e) ? 1 : 0;
2305 }
2306 }
2307};
2308
2309template<typename A>
2310HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst<decltype(pattern_arg(a))> {
2312 return {pattern_arg(a), false, 0};
2313}
2314
2315template<typename A>
2316HALIDE_ALWAYS_INLINE auto is_const(A &&a, int64_t value) noexcept -> IsConst<decltype(pattern_arg(a))> {
2318 return {pattern_arg(a), true, value};
2319}
2320
2321template<typename A>
2322std::ostream &operator<<(std::ostream &s, const IsConst<A> &op) {
2323 if (op.check_v) {
2324 s << "is_const(" << op.a << ")";
2325 } else {
2326 s << "is_const(" << op.a << ", " << op.v << ")";
2327 }
2328 return s;
2329}
2330
2331template<typename A, typename Prover>
2332struct CanProve {
2333 struct pattern_tag {};
2334 A a;
2335 Prover *prover; // An existing simplifying mutator
2336
2338
2339 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2342 constexpr static bool canonical = true;
2343
2344 constexpr static bool foldable = true;
2345
2346 // Includes a raw call to an inlined make method, so don't inline.
2348 Expr condition = a.make(state, {});
2349 condition = prover->mutate(condition, nullptr);
2350 val.u.u64 = is_const_one(condition);
2351 ty.code = halide_type_uint;
2352 ty.bits = 1;
2353 ty.lanes = condition.type().lanes();
2354 }
2355};
2356
2357template<typename A, typename Prover>
2358HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve<decltype(pattern_arg(a)), Prover> {
2360 return {pattern_arg(a), p};
2361}
2362
2363template<typename A, typename Prover>
2364std::ostream &operator<<(std::ostream &s, const CanProve<A, Prover> &op) {
2365 s << "can_prove(" << op.a << ")";
2366 return s;
2367}
2368
2369template<typename A>
2370struct IsFloat {
2371 struct pattern_tag {};
2372 A a;
2373
2375
2376 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2379 constexpr static bool canonical = true;
2380
2381 constexpr static bool foldable = true;
2382
2385 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2386 Type t = a.make(state, {}).type();
2387 val.u.u64 = t.is_float();
2388 ty.code = halide_type_uint;
2389 ty.bits = 1;
2390 ty.lanes = t.lanes();
2391 }
2392};
2393
// Factory for an IsFloat predicate over the given operand.
// NOTE(review): one line of the body is missing from this listing.
2394template<typename A>
2395HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat<decltype(pattern_arg(a))> {
2397 return {pattern_arg(a)};
2398}
2399
2400template<typename A>
2401std::ostream &operator<<(std::ostream &s, const IsFloat<A> &op) {
2402 s << "is_float(" << op.a << ")";
2403 return s;
2404}
2405
// Predicate pattern that folds to 1 when the operand's type is a signed
// integer type that also satisfies the optional bits/lanes constraints.
// NOTE(review): the bits/lanes member declarations, the binds/node-type
// constants and the make_folded_const signature are missing from this listing.
2406template<typename A>
2407struct IsInt {
2408 struct pattern_tag {};
2409 A a;
2411
2413
2414 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2417 constexpr static bool canonical = true;
2418
2419 constexpr static bool foldable = true;
2420
2423 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2424 Type t = a.make(state, {}).type();
 // A bits or lanes value of 0 means "do not constrain this property".
2425 val.u.u64 = t.is_int() && (bits == 0 || t.bits() == bits) && (lanes == 0 || t.lanes() == lanes);
 // Result type: 1-bit unsigned, with the operand's lane count.
2426 ty.code = halide_type_uint;
2427 ty.bits = 1;
2428 ty.lanes = t.lanes();
2429 }
2430};
2431
// Factory for an IsInt predicate. bits and lanes default to 0, which the
// predicate treats as "any".
// NOTE(review): one line of the body is missing from this listing.
2432template<typename A>
2433HALIDE_ALWAYS_INLINE auto is_int(A &&a, int bits = 0, int lanes = 0) noexcept -> IsInt<decltype(pattern_arg(a))> {
2435 return {pattern_arg(a), bits, lanes};
2436}
2437
2438template<typename A>
2439std::ostream &operator<<(std::ostream &s, const IsInt<A> &op) {
2440 s << "is_int(" << op.a;
2441 if (op.bits > 0) {
2442 s << ", " << op.bits;
2443 }
2444 if (op.lanes > 0) {
2445 s << ", " << op.lanes;
2446 }
2447 s << ")";
2448 return s;
2449}
2450
// Predicate pattern that folds to 1 when the operand's type is an unsigned
// integer type that also satisfies the optional bits/lanes constraints.
// NOTE(review): the bits/lanes member declarations, the binds/node-type
// constants and the make_folded_const signature are missing from this listing.
2451template<typename A>
2452struct IsUInt {
2453 struct pattern_tag {};
2454 A a;
2456
2458
2459 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2462 constexpr static bool canonical = true;
2463
2464 constexpr static bool foldable = true;
2465
2468 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2469 Type t = a.make(state, {}).type();
 // A bits or lanes value of 0 means "do not constrain this property".
2470 val.u.u64 = t.is_uint() && (bits == 0 || t.bits() == bits) && (lanes == 0 || t.lanes() == lanes);
 // Result type: 1-bit unsigned, with the operand's lane count.
2471 ty.code = halide_type_uint;
2472 ty.bits = 1;
2473 ty.lanes = t.lanes();
2474 }
2475};
2476
// Factory for an IsUInt predicate. bits and lanes default to 0, which the
// predicate treats as "any".
// NOTE(review): one line of the body is missing from this listing.
2477template<typename A>
2478HALIDE_ALWAYS_INLINE auto is_uint(A &&a, int bits = 0, int lanes = 0) noexcept -> IsUInt<decltype(pattern_arg(a))> {
2480 return {pattern_arg(a), bits, lanes};
2481}
2482
2483template<typename A>
2484std::ostream &operator<<(std::ostream &s, const IsUInt<A> &op) {
2485 s << "is_uint(" << op.a;
2486 if (op.bits > 0) {
2487 s << ", " << op.bits;
2488 }
2489 if (op.lanes > 0) {
2490 s << ", " << op.lanes;
2491 }
2492 s << ")";
2493 return s;
2494}
2495
// Predicate pattern that folds to 1 when the operand's type is scalar.
// NOTE(review): the binds/node-type constants and the make_folded_const
// signature are missing from this listing.
2496template<typename A>
2497struct IsScalar {
2498 struct pattern_tag {};
2499 A a;
2500
2502
2503 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2506 constexpr static bool canonical = true;
2507
2508 constexpr static bool foldable = true;
2509
2512 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
 // Only the type of the materialized operand is inspected.
2513 Type t = a.make(state, {}).type();
2514 val.u.u64 = t.is_scalar();
 // Result type: 1-bit unsigned, with the operand's lane count.
2515 ty.code = halide_type_uint;
2516 ty.bits = 1;
2517 ty.lanes = t.lanes();
2518 }
2519};
2520
// Factory for an IsScalar predicate over the given operand.
// NOTE(review): one line of the body is missing from this listing.
2521template<typename A>
2522HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar<decltype(pattern_arg(a))> {
2524 return {pattern_arg(a)};
2525}
2526
2527template<typename A>
2528std::ostream &operator<<(std::ostream &s, const IsScalar<A> &op) {
2529 s << "is_scalar(" << op.a << ")";
2530 return s;
2531}
2532
// Predicate pattern that folds to 1 when the constant-folded operand equals
// the largest value of its own integer type, and to 0 for any other type code.
// NOTE(review): the struct header line, the binds/node-type constants and
// the make_folded_const signature are missing from this listing.
2533template<typename A>
2535 struct pattern_tag {};
2536 A a;
2537
2539
2540 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2543 constexpr static bool canonical = true;
2544
2545 constexpr static bool foldable = true;
2546
2549 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
 // Fold the operand first; val and ty then describe the operand.
2550 a.make_folded_const(val, ty, state);
 // All-ones in the low ty.bits bits, with one bit fewer for signed types
 // so that the sign bit is excluded.
2551 const uint64_t max_bits = (uint64_t)(-1) >> (64 - ty.bits + (ty.code == halide_type_int));
2552 if (ty.code == halide_type_uint || ty.code == halide_type_int) {
2553 val.u.u64 = (val.u.u64 == max_bits);
2554 } else {
2555 val.u.u64 = 0;
2556 }
 // Only code and bits are rewritten; ty.lanes keeps whatever the operand's
 // fold produced.
2557 ty.code = halide_type_uint;
2558 ty.bits = 1;
2559 }
2560};
2561
// Factory for an IsMaxValue predicate over the given operand.
// NOTE(review): one line of the body is missing from this listing.
2562template<typename A>
2563HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue<decltype(pattern_arg(a))> {
2565 return {pattern_arg(a)};
2566}
2567
2568template<typename A>
2569std::ostream &operator<<(std::ostream &s, const IsMaxValue<A> &op) {
2570 s << "is_max_value(" << op.a << ")";
2571 return s;
2572}
2573
// Predicate pattern that folds to 1 when the constant-folded operand equals
// the smallest value of its own integer type, and to 0 for any other type code.
// NOTE(review): the struct header line, the binds/node-type constants and
// the make_folded_const signature are missing from this listing.
2574template<typename A>
2576 struct pattern_tag {};
2577 A a;
2578
2580
2581 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2584 constexpr static bool canonical = true;
2585
2586 constexpr static bool foldable = true;
2587
2590 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
 // Fold the operand first; val and ty then describe the operand.
2591 a.make_folded_const(val, ty, state);
2592 if (ty.code == halide_type_int) {
 // Signed minimum: every bit from the sign bit (ty.bits - 1) upward is set.
2593 const uint64_t min_bits = (uint64_t)(-1) << (ty.bits - 1);
2594 val.u.u64 = (val.u.u64 == min_bits);
2595 } else if (ty.code == halide_type_uint) {
 // Unsigned minimum is zero.
2596 val.u.u64 = (val.u.u64 == 0);
2597 } else {
2598 val.u.u64 = 0;
2599 }
 // Only code and bits are rewritten; ty.lanes keeps whatever the operand's
 // fold produced.
2600 ty.code = halide_type_uint;
2601 ty.bits = 1;
2602 }
2603};
2604
// Factory for an IsMinValue predicate over the given operand.
// NOTE(review): one line of the body is missing from this listing.
2605template<typename A>
2606HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue<decltype(pattern_arg(a))> {
2608 return {pattern_arg(a)};
2609}
2610
2611template<typename A>
2612std::ostream &operator<<(std::ostream &s, const IsMinValue<A> &op) {
2613 s << "is_min_value(" << op.a << ")";
2614 return s;
2615}
2616
// Pattern that folds to the number of lanes in the operand's type.
// NOTE(review): the binds/node-type constants and the make_folded_const
// signature are missing from this listing.
2617template<typename A>
2618struct LanesOf {
2619 struct pattern_tag {};
2620 A a;
2621
2623
2624 // This rule folds to an unsigned integer (the lane count), not a boolean.
2627 constexpr static bool canonical = true;
2628
2629 constexpr static bool foldable = true;
2630
2633 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2634 Type t = a.make(state, {}).type();
2635 val.u.u64 = t.lanes();
 // Result type: scalar 32-bit unsigned, regardless of the operand's lanes.
2636 ty.code = halide_type_uint;
2637 ty.bits = 32;
2638 ty.lanes = 1;
2639 }
2640};
2641
// Factory for a LanesOf pattern over the given operand.
// NOTE(review): one line of the body is missing from this listing.
2642template<typename A>
2643HALIDE_ALWAYS_INLINE auto lanes_of(A &&a) noexcept -> LanesOf<decltype(pattern_arg(a))> {
2645 return {pattern_arg(a)};
2646}
2647
2648template<typename A>
2649std::ostream &operator<<(std::ostream &s, const LanesOf<A> &op) {
2650 s << "lanes_of(" << op.a << ")";
2651 return s;
2652}
2653
2654// Verify properties of each rewrite rule. Currently just fuzz tests them.
// This overload is selected when both Before and After are foldable. It binds
// random scalar constants to every wildcard, skips trials where the predicate
// is false, and compares the constant-folded values of before and after.
// NOTE(review): the line carrying the function name and first parameters,
// and a few statements in the body, are missing from this listing.
2655template<typename Before,
2656 typename After,
2657 typename Predicate,
2658 typename = typename std::enable_if<std::decay<Before>::type::foldable &&
2659 std::decay<After>::type::foldable>::type>
2661 halide_type_t wildcard_type, halide_type_t output_type) noexcept {
2662
2663 // We only validate the rules in the scalar case
2664 wildcard_type.lanes = output_type.lanes = 1;
2665
2666 // Track which types this rule has been tested for before
 // Function-local static inside a template, so there is one set per
 // template instantiation. The key is the wildcard type only.
2667 static std::set<uint32_t> tested;
2668
2669 if (!tested.insert(reinterpret_bits<uint32_t>(wildcard_type)).second) {
2670 return;
2671 }
2672
2673 // Print it in a form where it can be piped into a python/z3 validator
2674 debug(0) << "validate('" << before << "', '" << after << "', '" << pred << "', " << Type(wildcard_type) << ", " << Type(output_type) << ")\n";
2675
2676 // Substitute some random constants into the before and after
2677 // expressions and see if the rule holds true. This should catch
2678 // silly errors, but not necessarily corner cases.
 // The generator is seeded with a fixed value (0).
2679 static std::mt19937_64 rng(0);
2680 MatcherState state;
2681
2682 Expr exprs[max_wild];
2683
2684 for (int trials = 0; trials < 100; trials++) {
2685 // We want to test small constants more frequently than
2686 // large ones, otherwise we'll just get coverage of
2687 // overflow rules.
 // One shift amount is drawn per trial and applied to every random value
 // below, shrinking their magnitude.
2688 int shift = (int)(rng() & (wildcard_type.bits - 1));
2689
2690 for (int i = 0; i < max_wild; i++) {
2691 // Bind all the exprs and constants
 // Each wildcard index gets two independent random values: one stored
 // as a bound constant and one wrapped in an Expr and stored as a binding.
2692 switch (wildcard_type.code) {
2693 case halide_type_uint: {
2694 // Normalize to the type's range by adding zero
2695 uint64_t val = constant_fold_bin_op<Add>(wildcard_type, (uint64_t)rng() >> shift, 0);
2696 state.set_bound_const(i, val, wildcard_type);
2697 val = constant_fold_bin_op<Add>(wildcard_type, (uint64_t)rng() >> shift, 0);
2698 exprs[i] = make_const(wildcard_type, val);
 // NOTE(review): the same set_binding call is made unconditionally
 // after the switch, so this one looks redundant.
2699 state.set_binding(i, *exprs[i].get());
2700 } break;
2701 case halide_type_int: {
2702 int64_t val = constant_fold_bin_op<Add>(wildcard_type, (int64_t)rng() >> shift, 0);
2703 state.set_bound_const(i, val, wildcard_type);
2704 val = constant_fold_bin_op<Add>(wildcard_type, (int64_t)rng() >> shift, 0);
2705 exprs[i] = make_const(wildcard_type, val);
2706 } break;
2707 case halide_type_float:
2708 case halide_type_bfloat: {
2709 // Use a very narrow range of precise floats, so
2710 // that none of the rules a human is likely to
2711 // write have instabilities.
 // (rng() & 15) - 8 is in [-8, 7]; halving gives steps of 0.5.
2712 double val = ((int64_t)(rng() & 15) - 8) / 2.0;
2713 state.set_bound_const(i, val, wildcard_type);
2714 val = ((int64_t)(rng() & 15) - 8) / 2.0;
2715 exprs[i] = make_const(wildcard_type, val);
2716 } break;
2717 default:
2718 return; // Don't care about handles
2719 }
2720 state.set_binding(i, *exprs[i].get());
2721 }
2722
2724 halide_type_t type = output_type;
 // Trials whose random values do not satisfy the predicate are skipped.
2725 if (!evaluate_predicate(pred, state)) {
2726 continue;
2727 }
 // Fold both sides, accumulating the lanes field written by each fold.
2728 before.make_folded_const(val_before, type, state);
2729 uint16_t lanes = type.lanes;
2730 after.make_folded_const(val_after, type, state);
2731 lanes |= type.lanes;
2732
2734 continue;
2735 }
2736
2737 bool ok = true;
2738 switch (output_type.code) {
2739 case halide_type_uint:
2740 // Compare normalized representations
2741 ok &= (constant_fold_bin_op<Add>(output_type, val_before.u.u64, 0) ==
2742 constant_fold_bin_op<Add>(output_type, val_after.u.u64, 0));
2743 break;
2744 case halide_type_int:
2745 ok &= (constant_fold_bin_op<Add>(output_type, val_before.u.i64, 0) ==
2746 constant_fold_bin_op<Add>(output_type, val_after.u.i64, 0));
2747 break;
2748 case halide_type_float:
2749 case halide_type_bfloat: {
2750 double error = std::abs(val_before.u.f64 - val_after.u.f64);
2751 // We accept an equal bit pattern (e.g. inf vs inf),
2752 // a small floating point difference, or turning a nan into not-a-nan.
2753 ok &= (error < 0.01 ||
2754 val_before.u.u64 == val_after.u.u64 ||
2755 std::isnan(val_before.u.f64));
2756 break;
2757 }
2758 default:
2759 return;
2760 }
2761
2762 if (!ok) {
 // Dump every bound constant and binding, then both folded results.
2763 debug(0) << "Fails with values:\n";
2764 for (int i = 0; i < max_wild; i++) {
2766 state.get_bound_const(i, val, wildcard_type);
2767 debug(0) << " c" << i << ": " << make_const_expr(val, wildcard_type) << "\n";
2768 }
2769 for (int i = 0; i < max_wild; i++) {
2770 debug(0) << " _" << i << ": " << Expr(state.get_binding(i)) << "\n";
2771 }
2772 debug(0) << " Before: " << make_const_expr(val_before, output_type) << "\n";
2773 debug(0) << " After: " << make_const_expr(val_after, output_type) << "\n";
2774 debug(0) << val_before.u.u64 << " " << val_after.u.u64 << "\n";
2776 }
2777 }
2778}
2779
// Fallback overload, selected when Before and After are not both foldable.
// It does nothing. The trailing defaulted int parameter makes its signature
// differ from the foldable overload's.
// NOTE(review): the line carrying the function name and first parameters is
// missing from this listing.
2780template<typename Before,
2781 typename After,
2782 typename Predicate,
2783 typename = typename std::enable_if<!(std::decay<Before>::type::foldable &&
2784 std::decay<After>::type::foldable)>::type>
2786 halide_type_t, halide_type_t, int dummy = 0) noexcept {
2787 // We can't verify rewrite rules that can't be constant-folded.
2788}
2789
// Overload for predicates that are already a plain bool: returns the value
// unchanged and ignores the matcher state.
2791bool evaluate_predicate(bool x, MatcherState &) noexcept {
2792 return x;
2793}
2794
// Overload for predicates expressed as patterns: constant-folds the pattern
// against the match state and returns true only when the folded value is
// non-zero and none of the special-value bits are set in the folded lanes.
// NOTE(review): the signature and local declarations are missing from this
// listing.
2795template<typename Pattern,
2796 typename = typename enable_if_pattern<Pattern>::type>
2800 p.make_folded_const(c, ty, state);
2801 // Overflow counts as a failed predicate
2802 return (c.u.u64 != 0) && ((ty.lanes & MatcherState::special_values_mask) == 0);
2803}
2804
2805// #defines for testing
// These are compile-time switches tested with #if inside Rewriter, so a
// change only takes effect after rebuilding the code that includes this file.
2806
2807// Print all successful or failed matches
2808#define HALIDE_DEBUG_MATCHED_RULES 0
2809#define HALIDE_DEBUG_UNMATCHED_RULES 0
2810
2811// Set to true if you want to fuzz test every rewrite passed to
2812// operator() to ensure the input and the output have the same value
2813// for lots of random values of the wildcards. Run
2814// correctness_simplify with this on.
2815#define HALIDE_FUZZ_TEST_RULES 0
2816
// Holds an instance to be rewritten and tries rewrite rules against it. Each
// rule overload below matches `before` against the instance and, on success
// (and, where present, a true predicate), produces `result` and returns true.
// NOTE(review): the member declarations, the constructor and the signature
// line of every overload are missing from this listing; the comments below
// describe only what the remaining lines show.
2817template<typename Instance>
2818struct Rewriter {
2824
2829
2830 template<typename After>
2834
 // Rule with pattern Before and pattern After, no predicate.
2835 template<typename Before,
2836 typename After,
2837 typename = typename enable_if_pattern<Before>::type,
2838 typename = typename enable_if_pattern<After>::type>
2840 static_assert((Before::binds & After::binds) == After::binds, "Rule result uses unbound values");
2841 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2842 static_assert(After::canonical, "RHS of rewrite rule should be in canonical form");
2843#if HALIDE_FUZZ_TEST_RULES
2845#endif
2846 if (before.template match<0>(unwrap(instance), state)) {
2848#if HALIDE_DEBUG_MATCHED_RULES
2849 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2850#endif
2851 return true;
2852 } else {
2853#if HALIDE_DEBUG_UNMATCHED_RULES
2854 debug(0) << instance << " does not match " << before << "\n";
2855#endif
2856 return false;
2857 }
2858 }
2859
 // Rule with pattern Before whose replacement is assigned to result
 // directly; this overload is not fuzz tested.
2860 template<typename Before,
2861 typename = typename enable_if_pattern<Before>::type>
2863 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2864 if (before.template match<0>(unwrap(instance), state)) {
2865 result = after;
2866#if HALIDE_DEBUG_MATCHED_RULES
2867 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2868#endif
2869 return true;
2870 } else {
2871#if HALIDE_DEBUG_UNMATCHED_RULES
2872 debug(0) << instance << " does not match " << before << "\n";
2873#endif
2874 return false;
2875 }
2876 }
2877
 // Rule with pattern Before and a non-pattern replacement.
 // NOTE(review): the line that builds result is missing here; confirm the
 // replacement type against the full source.
2878 template<typename Before,
2879 typename = typename enable_if_pattern<Before>::type>
2881 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2882#if HALIDE_FUZZ_TEST_RULES
2884#endif
2885 if (before.template match<0>(unwrap(instance), state)) {
2887#if HALIDE_DEBUG_MATCHED_RULES
2888 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2889#endif
2890 return true;
2891 } else {
2892#if HALIDE_DEBUG_UNMATCHED_RULES
2893 debug(0) << instance << " does not match " << before << "\n";
2894#endif
2895 return false;
2896 }
2897 }
2898
 // Rule with pattern Before, pattern After and a foldable pattern predicate.
 // Both After and Predicate may only use wildcards bound by Before.
2899 template<typename Before,
2900 typename After,
2901 typename Predicate,
2902 typename = typename enable_if_pattern<Before>::type,
2903 typename = typename enable_if_pattern<After>::type,
2904 typename = typename enable_if_pattern<Predicate>::type>
2906 static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
2907 static_assert((Before::binds & After::binds) == After::binds, "Rule result uses unbound values");
2908 static_assert((Before::binds & Predicate::binds) == Predicate::binds, "Rule predicate uses unbound values");
2909 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2910 static_assert(After::canonical, "RHS of rewrite rule should be in canonical form");
2911
2912#if HALIDE_FUZZ_TEST_RULES
2914#endif
2915 if (before.template match<0>(unwrap(instance), state) &&
2918#if HALIDE_DEBUG_MATCHED_RULES
2919 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
2920#endif
2921 return true;
2922 } else {
2923#if HALIDE_DEBUG_UNMATCHED_RULES
2924 debug(0) << instance << " does not match " << before << "\n";
2925#endif
2926 return false;
2927 }
2928 }
2929
 // Rule with pattern Before, a foldable pattern predicate, and a replacement
 // assigned to result directly; this overload is not fuzz tested.
2930 template<typename Before,
2931 typename Predicate,
2932 typename = typename enable_if_pattern<Before>::type,
2933 typename = typename enable_if_pattern<Predicate>::type>
2935 static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
2936 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2937
2938 if (before.template match<0>(unwrap(instance), state) &&
2940 result = after;
2941#if HALIDE_DEBUG_MATCHED_RULES
2942 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
2943#endif
2944 return true;
2945 } else {
2946#if HALIDE_DEBUG_UNMATCHED_RULES
2947 debug(0) << instance << " does not match " << before << "\n";
2948#endif
2949 return false;
2950 }
2951 }
2952
 // Rule with pattern Before, a foldable pattern predicate, and a non-pattern
 // replacement.
 // NOTE(review): the line that builds result is missing here; confirm the
 // replacement type against the full source.
2953 template<typename Before,
2954 typename Predicate,
2955 typename = typename enable_if_pattern<Before>::type,
2956 typename = typename enable_if_pattern<Predicate>::type>
2958 static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
2959 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2960#if HALIDE_FUZZ_TEST_RULES
2962#endif
2963 if (before.template match<0>(unwrap(instance), state) &&
2966#if HALIDE_DEBUG_MATCHED_RULES
2967 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
2968#endif
2969 return true;
2970 } else {
2971#if HALIDE_DEBUG_UNMATCHED_RULES
2972 debug(0) << instance << " does not match " << before << "\n";
2973#endif
2974 return false;
2975 }
2976 }
2977};
2978
2979/** Construct a rewriter for the given instance, which may be a pattern
2980 * with concrete expressions as leaves, or just an expression. The
2981 * second optional argument (wildcard_type) is a hint as to what the
2982 * type of the wildcards is likely to be. If omitted it uses the same
2983 * type as the expression itself. They are not required to be this
2984 * type, but the rule will only be tested for wildcards of that type
2985 * when testing is enabled.
2986 *
2987 * The rewriter can be used to check to see if the instance is one of
2988 * some number of patterns and if so rewrite it into another form,
2989 * using its operator() method. See Simplify.cpp for a bunch of
2990 * example usage.
2991 *
2992 * Important: Any Exprs in patterns are captured by reference, not by
2993 * value, so ensure they outlive the rewriter.
2994 */
2995// @{
2996template<typename Instance,
2997 typename = typename enable_if_pattern<Instance>::type>
2998HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter<decltype(pattern_arg(instance))> {
2999 return {pattern_arg(instance), output_type, wildcard_type};
3000}
3001
3002template<typename Instance,
3003 typename = typename enable_if_pattern<Instance>::type>
3004HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type) noexcept -> Rewriter<decltype(pattern_arg(instance))> {
3005 return {pattern_arg(instance), output_type, output_type};
3006}
3007
3009auto rewriter(const Expr &e, halide_type_t wildcard_type) noexcept -> Rewriter<decltype(pattern_arg(e))> {
3010 return {pattern_arg(e), e.type(), wildcard_type};
3011}
3012
3014auto rewriter(const Expr &e) noexcept -> Rewriter<decltype(pattern_arg(e))> {
3015 return {pattern_arg(e), e.type(), e.type()};
3016}
3017// @}
3018
3019} // namespace IRMatcher
3020
3021} // namespace Internal
3022} // namespace Halide
3023
3024#endif
#define internal_error
Definition Errors.h:23
@ halide_type_float
IEEE floating point numbers.
@ halide_type_bfloat
floating point numbers in the bfloat format
@ halide_type_int
signed integers
@ halide_type_uint
unsigned integers
#define HALIDE_NEVER_INLINE
#define HALIDE_ALWAYS_INLINE
Subtypes for Halide expressions (Halide::Expr) and statements (Halide::Internal::Stmt)
Methods to test Exprs and Stmts for equality of value.
Defines various operator overloads and utility functions that make it more pleasant to work with Hali...
For optional debugging during codegen, use the debug class as follows:
Definition Debug.h:49
auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1598
auto shift_left(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1590
HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter< decltype(pattern_arg(instance))>
Construct a rewriter for the given instance, which may be a pattern with concrete expressions as leav...
Definition IRMatch.h:2998
HALIDE_ALWAYS_INLINE T pattern_arg(T t)
Definition IRMatch.h:579
auto widen_right_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1539
HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b))
Definition IRMatch.h:1287
HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp< decltype(pattern_arg(a))>
Definition IRMatch.h:1655
HALIDE_ALWAYS_INLINE bool evaluate_predicate(bool x, MatcherState &) noexcept
Definition IRMatch.h:2791
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Div >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1044
HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b))
Definition IRMatch.h:1262
HALIDE_ALWAYS_INLINE auto negate(A &&a) -> decltype(IRMatcher::operator-(a))
Definition IRMatch.h:2032
uint64_t constant_fold_cmp_op(int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp< LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1182
HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp< Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:933
HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue< decltype(pattern_arg(a))>
Definition IRMatch.h:2563
std::ostream & operator<<(std::ostream &s, const SpecificExpr &e)
Definition IRMatch.h:229
HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b))
Definition IRMatch.h:1313
HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And >
Definition IRMatch.h:1944
HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b))
Definition IRMatch.h:1162
HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst< decltype(pattern_arg(a))>
Definition IRMatch.h:2310
HALIDE_ALWAYS_INLINE auto intrin(Call::IntrinsicOp intrinsic_op, Args... args) noexcept -> Intrin< decltype(pattern_arg(args))... >
Definition IRMatch.h:1534
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LE >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1192
HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp< Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:999
auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1586
auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1602
auto widen_right_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1547
HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b))
Definition IRMatch.h:940
HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b))
Definition IRMatch.h:1039
auto saturating_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1564
HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b))
Definition IRMatch.h:1006
HALIDE_ALWAYS_INLINE auto slice(Vec vec, Base base, Stride stride, Lanes lanes) noexcept -> SliceOp< decltype(pattern_arg(vec)), decltype(pattern_arg(base)), decltype(pattern_arg(stride)), decltype(pattern_arg(lanes))>
Definition IRMatch.h:2142
HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition IRMatch.h:1868
HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp< Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1032
auto widening_mul(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1560
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mod >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1073
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< And >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1318
HALIDE_ALWAYS_INLINE int64_t unwrap(IntLiteral t)
Definition IRMatch.h:571
HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp< GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1157
HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp< decltype(pattern_arg(a))>
Definition IRMatch.h:2078
HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows< decltype(pattern_arg(a))>
Definition IRMatch.h:2228
auto widening_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1552
HALIDE_ALWAYS_INLINE void assert_is_lvalue_if_expr()
Definition IRMatch.h:588
HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp< Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1059
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Sub >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:980
HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar< decltype(pattern_arg(a))>
Definition IRMatch.h:2522
HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold< decltype(pattern_arg(a))>
Definition IRMatch.h:2191
HALIDE_ALWAYS_INLINE auto not_op(A &&a) -> decltype(IRMatcher::operator!(a))
Definition IRMatch.h:1661
auto halving_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1578
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Max >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1117
constexpr bool and_reduce()
Definition IRMatch.h:1342
HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp< Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1282
auto widening_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1556
constexpr int max_wild
Definition IRMatch.h:74
HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp< NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1257
HALIDE_ALWAYS_INLINE bool equal(const BaseExprNode &a, const BaseExprNode &b) noexcept
Definition IRMatch.h:195
HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat< decltype(pattern_arg(a))>
Definition IRMatch.h:2395
HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp< GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1207
bool equal_helper(const BaseExprNode &a, const BaseExprNode &b) noexcept
HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp< LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1132
HALIDE_ALWAYS_INLINE auto operator&&(A &&a, B &&b) noexcept -> BinOp< And, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1308
HALIDE_ALWAYS_INLINE auto is_uint(A &&a, int bits=0, int lanes=0) noexcept -> IsUInt< decltype(pattern_arg(a))>
Definition IRMatch.h:2478
HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or >
Definition IRMatch.h:1950
constexpr bool commutative(IRNodeType t)
Definition IRMatch.h:627
auto widen_right_mul(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1543
HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b))
Definition IRMatch.h:973
HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max >
Definition IRMatch.h:1938
HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes))>
Definition IRMatch.h:1804
HALIDE_ALWAYS_INLINE auto is_int(A &&a, int bits=0, int lanes=0) noexcept -> IsInt< decltype(pattern_arg(a))>
Definition IRMatch.h:2433
HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp< decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))>
Definition IRMatch.h:1731
HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue< decltype(pattern_arg(a))>
Definition IRMatch.h:2606
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Min >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1095
HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred, halide_type_t wildcard_type, halide_type_t output_type) noexcept
Definition IRMatch.h:2660
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GT >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1167
auto halving_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1582
auto saturating_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1568
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mul >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1013
auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition IRMatch.h:1606
auto shift_right(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1594
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GE >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1217
HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp< Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:966
HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b))
Definition IRMatch.h:1187
HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b))
Definition IRMatch.h:1137
HALIDE_ALWAYS_INLINE auto lanes_of(A &&a) noexcept -> LanesOf< decltype(pattern_arg(a))>
Definition IRMatch.h:2643
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LT >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1142
HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min >
Definition IRMatch.h:1932
HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add >
Definition IRMatch.h:1926
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Or >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1292
HALIDE_ALWAYS_INLINE Expr make_const_expr(halide_scalar_value_t val, halide_type_t ty)
Definition IRMatch.h:160
constexpr uint32_t bitwise_or_reduce()
Definition IRMatch.h:1333
auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition IRMatch.h:1610
int64_t constant_fold_bin_op(halide_type_t &, int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< EQ >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1242
HALIDE_NEVER_INLINE Expr make_const_special_expr(halide_type_t ty)
Definition IRMatch.h:149
HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b))
Definition IRMatch.h:1212
auto saturating_cast(const Type &t, A &&a) noexcept -> Intrin< decltype(pattern_arg(a))>
Definition IRMatch.h:1572
constexpr int const_min(int a, int b)
Definition IRMatch.h:1352
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< NE >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1267
HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b))
Definition IRMatch.h:1066
HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp< EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1232
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Add >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:947
HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve< decltype(pattern_arg(a)), Prover >
Definition IRMatch.h:2358
HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b))
Definition IRMatch.h:1237
T div_imp(T a, T b)
Definition IROperator.h:260
bool is_const_zero(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to zero (in all lanes,...
Expr make_zero(Type t)
Construct the representation of zero in the given type.
void expr_match_test()
bool is_const_one(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to one (in all lanes,...
constexpr IRNodeType StrongestExprNodeType
Definition Expr.h:81
Expr make_const(Type t, int64_t val)
Construct an immediate of the given type from any numeric C++ type.
T mod_imp(T a, T b)
Implementations of division and mod that are specific to Halide.
Definition IROperator.h:239
bool sub_would_overflow(int bits, int64_t a, int64_t b)
bool add_would_overflow(int bits, int64_t a, int64_t b)
Routines to test if math would overflow for signed integers with the given number of bits.
bool mul_would_overflow(int bits, int64_t a, int64_t b)
Expr with_lanes(const Expr &x, int lanes)
Rewrite the expression x to have `lanes` lanes.
bool expr_match(const Expr &pattern, const Expr &expr, std::vector< Expr > &result)
Does the first expression have the same structure as the second? Variables in the first expression wi...
Expr make_signed_integer_overflow(Type type)
Construct a unique signed_integer_overflow Expr.
IRNodeType
All our IR node types get unique IDs for the purposes of RTTI.
Definition Expr.h:25
bool is_const(const Expr &e)
Is the expression either an IntImm, a FloatImm, a StringImm, or a Cast of the same,...
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
@ Internal
Not visible externally, similar to 'static' linkage in C.
@ Predicate
Guard the loads and stores in the loop with an if statement that prevents evaluation beyond the origi...
Expr min(const FuncRef &a, const FuncRef &b)
Explicit overloads of min and max for FuncRef.
Definition Func.h:603
Expr absd(Expr a, Expr b)
Return the absolute difference between two values.
@ C
No name mangling.
Expr likely_if_innermost(Expr e)
Equivalent to likely, but only triggers a loop partitioning if found in an innermost loop.
Expr abs(Expr a)
Returns the absolute value of a signed integer or floating-point expression.
Expr max(const FuncRef &a, const FuncRef &b)
Definition Func.h:606
Expr likely(Expr e)
Expressions tagged with this intrinsic are considered to be part of the steady state of some loop wit...
unsigned __INT64_TYPE__ uint64_t
signed __INT64_TYPE__ int64_t
signed __INT32_TYPE__ int32_t
unsigned __INT16_TYPE__ uint16_t
unsigned __INT32_TYPE__ uint32_t
A fragment of Halide syntax.
Definition Expr.h:258
HALIDE_ALWAYS_INLINE Type type() const
Get the type of this expression node.
Definition Expr.h:322
HALIDE_ALWAYS_INLINE const Internal::BaseExprNode * get() const
Override get() to return a BaseExprNode * instead of an IRNode *.
Definition Expr.h:316
The sum of two expressions.
Definition IR.h:56
Logical and - are both expressions true.
Definition IR.h:175
A base class for expression nodes.
Definition Expr.h:143
A vector with 'lanes' elements, in which every element is 'value'.
Definition IR.h:259
static Expr make(Expr value, int lanes)
static const IRNodeType _node_type
Definition IR.h:265
A function call.
Definition IR.h:490
bool is_intrinsic() const
Definition IR.h:707
static const IRNodeType _node_type
Definition IR.h:752
The actual IR nodes begin here.
Definition IR.h:30
static const IRNodeType _node_type
Definition IR.h:35
The ratio of two expressions.
Definition IR.h:83
Is the first expression equal to the second.
Definition IR.h:121
Floating point constants.
Definition Expr.h:236
static const FloatImm * make(Type t, double value)
Is the first expression greater than or equal to the second.
Definition IR.h:166
Is the first expression greater than the second.
Definition IR.h:157
static constexpr bool canonical
Definition IRMatch.h:653
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:676
static constexpr uint32_t binds
Definition IRMatch.h:645
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:657
static constexpr bool foldable
Definition IRMatch.h:673
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
Definition IRMatch.h:719
HALIDE_ALWAYS_INLINE bool match(const BinOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:667
static constexpr IRNodeType max_node_type
Definition IRMatch.h:648
static constexpr IRNodeType min_node_type
Definition IRMatch.h:647
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1746
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1770
HALIDE_ALWAYS_INLINE bool match(const BroadcastOp< A2, B2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:1764
static constexpr uint32_t binds
Definition IRMatch.h:1744
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1752
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1747
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:1787
HALIDE_NEVER_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2347
static constexpr uint32_t binds
Definition IRMatch.h:2337
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2340
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2341
static constexpr bool foldable
Definition IRMatch.h:2344
static constexpr bool canonical
Definition IRMatch.h:2342
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2046
static constexpr bool foldable
Definition IRMatch.h:2068
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:2050
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2045
static constexpr uint32_t binds
Definition IRMatch.h:2043
static constexpr bool canonical
Definition IRMatch.h:2047
HALIDE_ALWAYS_INLINE bool match(const CastOp< A2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:2059
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:2064
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:820
static constexpr IRNodeType max_node_type
Definition IRMatch.h:759
static constexpr uint32_t binds
Definition IRMatch.h:756
static constexpr bool canonical
Definition IRMatch.h:760
static constexpr bool foldable
Definition IRMatch.h:783
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:767
static constexpr IRNodeType min_node_type
Definition IRMatch.h:758
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:786
HALIDE_ALWAYS_INLINE bool match(const CmpOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:777
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2155
static constexpr uint32_t binds
Definition IRMatch.h:2152
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2154
static constexpr bool canonical
Definition IRMatch.h:2156
static constexpr bool foldable
Definition IRMatch.h:2182
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
Definition IRMatch.h:2159
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:2185
static constexpr IRNodeType max_node_type
Definition IRMatch.h:507
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:516
HALIDE_ALWAYS_INLINE IntLiteral(int64_t v)
Definition IRMatch.h:511
HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept
Definition IRMatch.h:539
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:551
static constexpr IRNodeType min_node_type
Definition IRMatch.h:506
static constexpr bool canonical
Definition IRMatch.h:508
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:544
HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept
Definition IRMatch.h:534
static constexpr uint32_t binds
Definition IRMatch.h:504
HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept
Definition IRMatch.h:1382
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1369
static constexpr bool canonical
Definition IRMatch.h:1370
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1417
HALIDE_ALWAYS_INLINE void print_args(std::ostream &s) const
Definition IRMatch.h:1412
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:1477
static constexpr uint32_t binds
Definition IRMatch.h:1366
static constexpr bool foldable
Definition IRMatch.h:1475
HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept
Definition IRMatch.h:1375
HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const
Definition IRMatch.h:1399
std::tuple< Args... > args
Definition IRMatch.h:1360
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1387
HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const
Definition IRMatch.h:1408
HALIDE_ALWAYS_INLINE Intrin(Call::IntrinsicOp intrin, Args... args) noexcept
Definition IRMatch.h:1520
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1368
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2285
static constexpr bool canonical
Definition IRMatch.h:2287
static constexpr bool foldable
Definition IRMatch.h:2293
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2286
static constexpr uint32_t binds
Definition IRMatch.h:2282
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:2296
static constexpr bool foldable
Definition IRMatch.h:2381
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2384
static constexpr bool canonical
Definition IRMatch.h:2379
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2377
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2378
static constexpr uint32_t binds
Definition IRMatch.h:2374
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2416
static constexpr bool foldable
Definition IRMatch.h:2419
static constexpr uint32_t binds
Definition IRMatch.h:2412
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2422
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2415
static constexpr bool canonical
Definition IRMatch.h:2417
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2541
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2542
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2548
static constexpr uint32_t binds
Definition IRMatch.h:2538
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2582
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2589
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2583
static constexpr uint32_t binds
Definition IRMatch.h:2579
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2505
static constexpr uint32_t binds
Definition IRMatch.h:2501
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2511
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2504
static constexpr bool foldable
Definition IRMatch.h:2508
static constexpr bool canonical
Definition IRMatch.h:2506
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2467
static constexpr bool foldable
Definition IRMatch.h:2464
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2460
static constexpr bool canonical
Definition IRMatch.h:2462
static constexpr uint32_t binds
Definition IRMatch.h:2457
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2461
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2626
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2632
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2625
static constexpr bool foldable
Definition IRMatch.h:2629
static constexpr uint32_t binds
Definition IRMatch.h:2622
static constexpr bool canonical
Definition IRMatch.h:2627
To save stack space, the matcher objects are largely stateless and immutable.
Definition IRMatch.h:82
HALIDE_ALWAYS_INLINE void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept
Definition IRMatch.h:127
HALIDE_ALWAYS_INLINE void set_bound_const(int i, int64_t s, halide_type_t t) noexcept
Definition IRMatch.h:103
HALIDE_ALWAYS_INLINE void set_bound_const(int i, double f, halide_type_t t) noexcept
Definition IRMatch.h:115
static constexpr uint16_t special_values_mask
Definition IRMatch.h:88
HALIDE_ALWAYS_INLINE void set_bound_const(int i, halide_scalar_value_t val, halide_type_t t) noexcept
Definition IRMatch.h:121
halide_type_t bound_const_type[max_wild]
Definition IRMatch.h:90
HALIDE_ALWAYS_INLINE void set_binding(int i, const BaseExprNode &n) noexcept
Definition IRMatch.h:93
HALIDE_ALWAYS_INLINE MatcherState() noexcept
Definition IRMatch.h:134
HALIDE_ALWAYS_INLINE const BaseExprNode * get_binding(int i) const noexcept
Definition IRMatch.h:98
halide_scalar_value_t bound_const[max_wild]
Definition IRMatch.h:84
HALIDE_ALWAYS_INLINE void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept
Definition IRMatch.h:109
static constexpr uint16_t signed_integer_overflow
Definition IRMatch.h:87
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1968
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1983
HALIDE_ALWAYS_INLINE bool match(NegateOp< A2 > &&p, MatcherState &state) const noexcept
Definition IRMatch.h:1978
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:1992
static constexpr uint32_t binds
Definition IRMatch.h:1960
static constexpr bool canonical
Definition IRMatch.h:1965
static constexpr bool foldable
Definition IRMatch.h:1989
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1963
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1962
static constexpr uint32_t binds
Definition IRMatch.h:1619
static constexpr bool foldable
Definition IRMatch.h:1644
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1626
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1622
static constexpr bool canonical
Definition IRMatch.h:1623
HALIDE_ALWAYS_INLINE bool match(const NotOp< A2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:1635
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1640
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:1647
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1621
static constexpr uint32_t binds
Definition IRMatch.h:2242
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2246
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:2250
static constexpr bool canonical
Definition IRMatch.h:2247
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:2259
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:2267
static constexpr bool foldable
Definition IRMatch.h:2264
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2245
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:2218
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2211
static constexpr uint32_t binds
Definition IRMatch.h:2207
static constexpr bool canonical
Definition IRMatch.h:2213
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2212
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1846
static constexpr bool canonical
Definition IRMatch.h:1821
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1819
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1818
static constexpr uint32_t binds
Definition IRMatch.h:1816
HALIDE_ALWAYS_INLINE bool match(const RampOp< A2, B2, C2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:1839
static constexpr bool foldable
Definition IRMatch.h:1858
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1824
HALIDE_NEVER_INLINE void build_replacement(After after)
Definition IRMatch.h:2831
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after, Predicate pred)
Definition IRMatch.h:2905
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after) noexcept
Definition IRMatch.h:2880
HALIDE_ALWAYS_INLINE Rewriter(Instance instance, halide_type_t ot, halide_type_t wt)
Definition IRMatch.h:2826
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after, Predicate pred)
Definition IRMatch.h:2934
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after) noexcept
Definition IRMatch.h:2862
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after, Predicate pred)
Definition IRMatch.h:2957
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after)
Definition IRMatch.h:2839
static constexpr uint32_t binds
Definition IRMatch.h:1679
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:1711
static constexpr bool foldable
Definition IRMatch.h:1708
static constexpr bool canonical
Definition IRMatch.h:1684
HALIDE_ALWAYS_INLINE bool match(const SelectOp< C2, T2, F2 > &instance, MatcherState &state) const noexcept
Definition IRMatch.h:1697
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1687
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1704
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1682
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1681
static constexpr bool canonical
Definition IRMatch.h:2095
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2094
static constexpr bool foldable
Definition IRMatch.h:2124
HALIDE_ALWAYS_INLINE SliceOp(Vec v, Base b, Stride s, Lanes l)
Definition IRMatch.h:2127
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2093
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:2098
static constexpr uint32_t binds
Definition IRMatch.h:2091
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:2112
static constexpr IRNodeType min_node_type
Definition IRMatch.h:210
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:217
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:222
static constexpr IRNodeType max_node_type
Definition IRMatch.h:211
static constexpr uint32_t binds
Definition IRMatch.h:207
HALIDE_ALWAYS_INLINE bool match(const VectorReduceOp< A2, B2, reduce_op_2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:1901
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1883
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1888
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1884
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1908
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:364
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:385
static constexpr IRNodeType max_node_type
Definition IRMatch.h:360
static constexpr IRNodeType min_node_type
Definition IRMatch.h:359
static constexpr uint32_t binds
Definition IRMatch.h:357
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:395
static constexpr bool canonical
Definition IRMatch.h:415
static constexpr IRNodeType max_node_type
Definition IRMatch.h:414
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:443
static constexpr uint32_t binds
Definition IRMatch.h:411
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:418
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:453
HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept
Definition IRMatch.h:437
static constexpr IRNodeType min_node_type
Definition IRMatch.h:413
static constexpr bool foldable
Definition IRMatch.h:450
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:279
static constexpr uint32_t binds
Definition IRMatch.h:238
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:289
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:245
HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept
Definition IRMatch.h:266
static constexpr IRNodeType min_node_type
Definition IRMatch.h:240
static constexpr IRNodeType max_node_type
Definition IRMatch.h:241
static constexpr uint32_t binds
Definition IRMatch.h:304
static constexpr IRNodeType max_node_type
Definition IRMatch.h:307
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:311
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:342
static constexpr IRNodeType min_node_type
Definition IRMatch.h:306
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:332
static constexpr IRNodeType min_node_type
Definition IRMatch.h:471
static constexpr uint32_t binds
Definition IRMatch.h:469
static constexpr IRNodeType max_node_type
Definition IRMatch.h:472
static constexpr bool canonical
Definition IRMatch.h:473
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:485
static constexpr bool foldable
Definition IRMatch.h:489
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:476
static constexpr uint32_t mask
Definition IRMatch.h:146
IRNodeType node_type
Each IR node subclass has a unique identifier.
Definition Expr.h:113
Integer constants.
Definition Expr.h:218
static const IntImm * make(Type t, int64_t value)
Is the first expression less than or equal to the second.
Definition IR.h:148
Is the first expression less than the second.
Definition IR.h:139
The greater of two values.
Definition IR.h:112
The lesser of two values.
Definition IR.h:103
The remainder of a / b.
Definition IR.h:94
The product of two expressions.
Definition IR.h:74
Is the first expression not equal to the second.
Definition IR.h:130
Logical not - true if the expression is false.
Definition IR.h:193
static Expr make(Expr a)
Logical or - is at least one of the expressions true.
Definition IR.h:184
A linear ramp vector node.
Definition IR.h:247
static const IRNodeType _node_type
Definition IR.h:253
static Expr make(Expr base, Expr stride, int lanes)
A ternary operator.
Definition IR.h:204
static Expr make(Expr condition, Expr true_value, Expr false_value)
static const IRNodeType _node_type
Definition IR.h:209
Construct a new vector by taking elements from another sequence of vectors.
Definition IR.h:841
static Expr make_slice(Expr vector, int begin, int stride, int size)
Convenience constructor for making a shuffle representing a contiguous subset of a vector.
std::vector< Expr > vectors
Definition IR.h:842
bool is_slice() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so,...
int slice_stride() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so,...
Definition IR.h:896
int slice_begin() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so,...
Definition IR.h:893
The difference of two expressions.
Definition IR.h:65
static const IRNodeType _node_type
Definition IR.h:70
static Expr make(Expr a, Expr b)
Unsigned integer constants.
Definition Expr.h:227
static const UIntImm * make(Type t, uint64_t value)
Horizontally reduce a vector to a scalar or narrower vector using the given commutative and associati...
Definition IR.h:966
static const IRNodeType _node_type
Definition IR.h:985
static Expr make(Operator op, Expr vec, int lanes)
Types in the halide type system.
Definition Type.h:276
HALIDE_ALWAYS_INLINE bool is_int() const
Is this type a signed integer type?
Definition Type.h:428
HALIDE_ALWAYS_INLINE int lanes() const
Return the number of vector elements in this type.
Definition Type.h:348
HALIDE_ALWAYS_INLINE bool is_uint() const
Is this type an unsigned integer type?
Definition Type.h:434
HALIDE_ALWAYS_INLINE int bits() const
Return the bit size of a single element of this type.
Definition Type.h:342
HALIDE_ALWAYS_INLINE bool is_scalar() const
Is this type a scalar type? (lanes() == 1).
Definition Type.h:410
HALIDE_ALWAYS_INLINE bool is_float() const
Is this type a floating point type (float or double).
Definition Type.h:416
halide_scalar_value_t is a simple union able to represent all the well-known scalar values in a filte...
union halide_scalar_value_t::@4 u
A runtime tag for a type in the halide type system.
uint16_t lanes
How many elements in a vector.
uint8_t code
The basic type code: signed integer, unsigned integer, or floating point.