Halide  20.0.0
Halide compiler and libraries
IRMatch.h
Go to the documentation of this file.
1 #ifndef HALIDE_IR_MATCH_H
2 #define HALIDE_IR_MATCH_H
3 
4 /** \file
5  * Defines a method to match a fragment of IR against a pattern containing wildcards
6  */
7 
8 #include <map>
9 #include <random>
10 #include <set>
11 #include <vector>
12 
13 #include "IR.h"
14 #include "IREquality.h"
15 #include "IROperator.h"
16 
17 namespace Halide {
18 namespace Internal {
19 
20 /** Does the first expression have the same structure as the second?
21  * Variables in the first expression with the name * are interpreted
22  * as wildcards, and their matching equivalent in the second
23  * expression is placed in the vector give as the third argument.
24  * Wildcards require the types to match. For the type bits and width,
25  * a 0 indicates "match anything". So an Int(8, 0) will match 8-bit
26  * integer vectors of any width (including scalars), and a UInt(0, 0)
27  * will match any unsigned integer type.
28  *
29  * For example:
30  \code
31  Expr x = Variable::make(Int(32), "*");
32  match(x + x, 3 + (2*k), result)
33  \endcode
34  * should return true, and set result[0] to 3 and
35  * result[1] to 2*k.
36  */
37 bool expr_match(const Expr &pattern, const Expr &expr, std::vector<Expr> &result);
38 
39 /** Does the first expression have the same structure as the second?
40  * Variables are matched consistently. The first time a variable is
41  * matched, it assumes the value of the matching part of the second
42  * expression. Subsequent matches must be equal to the first match.
43  *
44  * For example:
45  \code
46  Var x("x"), y("y");
47  match(x*(x + y), a*(a + b), result)
48  \endcode
49  * should return true, and set result["x"] = a, and result["y"] = b.
50  */
51 bool expr_match(const Expr &pattern, const Expr &expr, std::map<std::string, Expr> &result);
52 
53 /** Rewrite the expression x to have `lanes` lanes. This is useful
54  * for substituting the results of expr_match into a pattern expression. */
55 Expr with_lanes(const Expr &x, int lanes);
56 
58 
59 /** An alternative template-metaprogramming approach to expression
60  * matching. Potentially more efficient. We lift the expression
61  * pattern into a type, and then use force-inlined functions to
62  * generate efficient matching and reconstruction code for any
63  * pattern. Pattern elements are either one of the classes in the
64  * namespace IRMatcher, or are non-null Exprs (represented as
65  * BaseExprNode &).
66  *
67  * Pattern elements that are fully specified by their pattern can be
68  * built into an expression using the make method. Some patterns,
69  * such as a broadcast that matches any number of lanes, don't have
70  * enough information to recreate an Expr.
71  */
72 namespace IRMatcher {
73 
74 constexpr int max_wild = 6;
75 
76 static const halide_type_t i64_type = {halide_type_int, 64, 1};
77 
78 /** To save stack space, the matcher objects are largely stateless and
79  * immutable. This state object is built up during matching and then
80  * consumed when constructing a replacement Expr.
81  */
82 struct MatcherState {
85 
86  // values of the lanes field with special meaning.
87  static constexpr uint16_t signed_integer_overflow = 0x8000;
88  static constexpr uint16_t special_values_mask = 0x8000; // currently only one
89 
91 
93  void set_binding(int i, const BaseExprNode &n) noexcept {
94  bindings[i] = &n;
95  }
96 
98  const BaseExprNode *get_binding(int i) const noexcept {
99  return bindings[i];
100  }
101 
103  void set_bound_const(int i, int64_t s, halide_type_t t) noexcept {
104  bound_const[i].u.i64 = s;
105  bound_const_type[i] = t;
106  }
107 
109  void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept {
110  bound_const[i].u.u64 = u;
111  bound_const_type[i] = t;
112  }
113 
115  void set_bound_const(int i, double f, halide_type_t t) noexcept {
116  bound_const[i].u.f64 = f;
117  bound_const_type[i] = t;
118  }
119 
121  void set_bound_const(int i, halide_scalar_value_t val, halide_type_t t) noexcept {
122  bound_const[i] = val;
123  bound_const_type[i] = t;
124  }
125 
127  void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept {
128  val = bound_const[i];
129  type = bound_const_type[i];
130  }
131 
133  // NOLINTNEXTLINE(modernize-use-equals-default): Can't use `= default`; clang-tidy complains about noexcept mismatch
134  MatcherState() noexcept {
135  }
136 };
137 
138 template<typename T,
139  typename = typename std::remove_reference<T>::type::pattern_tag>
141  struct type {};
142 };
143 
/** Compile-time accessor for the bitmask of wildcard slots bound by the
 * pattern type T. T may be a reference type; the reference is stripped
 * before reading its `binds` constant. */
template<typename T>
struct bindings {
    static constexpr uint32_t mask = std::remove_reference<T>::type::binds;
};
148 
151  ty.lanes &= ~MatcherState::special_values_mask;
153  return make_signed_integer_overflow(ty);
154  }
155  // unreachable
156  return Expr();
157 }
158 
161  halide_type_t scalar_type = ty;
162  if (scalar_type.lanes & MatcherState::special_values_mask) {
163  return make_const_special_expr(scalar_type);
164  }
165 
166  const int lanes = scalar_type.lanes;
167  scalar_type.lanes = 1;
168 
169  Expr e;
170  switch (scalar_type.code) {
171  case halide_type_int:
172  e = IntImm::make(scalar_type, val.u.i64);
173  break;
174  case halide_type_uint:
175  e = UIntImm::make(scalar_type, val.u.u64);
176  break;
177  case halide_type_float:
178  case halide_type_bfloat:
179  e = FloatImm::make(scalar_type, val.u.f64);
180  break;
181  default:
182  // Unreachable
183  return Expr();
184  }
185  if (lanes > 1) {
186  e = Broadcast::make(std::move(e), lanes);
187  }
188  return e;
189 }
190 
191 // A pattern that matches a specific expression
192 struct SpecificExpr {
193  struct pattern_tag {};
194 
195  constexpr static uint32_t binds = 0;
196 
197  // What is the weakest and strongest IR node this could possibly be
200  constexpr static bool canonical = true;
201 
203 
204  template<uint32_t bound>
205  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
206  return equal(expr, e);
207  }
208 
210  Expr make(MatcherState &state, halide_type_t type_hint) const {
211  return Expr(&expr);
212  }
213 
214  constexpr static bool foldable = false;
215 };
216 
217 inline std::ostream &operator<<(std::ostream &s, const SpecificExpr &e) {
218  s << Expr(&e.expr);
219  return s;
220 }
221 
222 template<int i>
223 struct WildConstInt {
224  struct pattern_tag {};
225 
226  constexpr static uint32_t binds = 1 << i;
227 
230  constexpr static bool canonical = true;
231 
232  template<uint32_t bound>
233  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
234  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
235  const BaseExprNode *op = &e;
236  if (op->node_type == IRNodeType::Broadcast) {
237  op = ((const Broadcast *)op)->value.get();
238  }
239  if (op->node_type != IRNodeType::IntImm) {
240  return false;
241  }
242  int64_t value = ((const IntImm *)op)->value;
243  if (bound & binds) {
245  halide_type_t type;
246  state.get_bound_const(i, val, type);
247  return (halide_type_t)e.type == type && value == val.u.i64;
248  }
249  state.set_bound_const(i, value, e.type);
250  return true;
251  }
252 
253  template<uint32_t bound>
254  HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept {
255  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
256  if (bound & binds) {
258  halide_type_t type;
259  state.get_bound_const(i, val, type);
260  return type == i64_type && value == val.u.i64;
261  }
262  state.set_bound_const(i, value, i64_type);
263  return true;
264  }
265 
267  Expr make(MatcherState &state, halide_type_t type_hint) const {
269  halide_type_t type;
270  state.get_bound_const(i, val, type);
271  return make_const_expr(val, type);
272  }
273 
274  constexpr static bool foldable = true;
275 
278  state.get_bound_const(i, val, ty);
279  }
280 };
281 
282 template<int i>
283 std::ostream &operator<<(std::ostream &s, const WildConstInt<i> &c) {
284  s << "ci" << i;
285  return s;
286 }
287 
288 template<int i>
290  struct pattern_tag {};
291 
292  constexpr static uint32_t binds = 1 << i;
293 
296  constexpr static bool canonical = true;
297 
298  template<uint32_t bound>
299  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
300  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
301  const BaseExprNode *op = &e;
302  if (op->node_type == IRNodeType::Broadcast) {
303  op = ((const Broadcast *)op)->value.get();
304  }
305  if (op->node_type != IRNodeType::UIntImm) {
306  return false;
307  }
308  uint64_t value = ((const UIntImm *)op)->value;
309  if (bound & binds) {
311  halide_type_t type;
312  state.get_bound_const(i, val, type);
313  return (halide_type_t)e.type == type && value == val.u.u64;
314  }
315  state.set_bound_const(i, value, e.type);
316  return true;
317  }
318 
320  Expr make(MatcherState &state, halide_type_t type_hint) const {
322  halide_type_t type;
323  state.get_bound_const(i, val, type);
324  return make_const_expr(val, type);
325  }
326 
327  constexpr static bool foldable = true;
328 
330  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
331  state.get_bound_const(i, val, ty);
332  }
333 };
334 
335 template<int i>
336 std::ostream &operator<<(std::ostream &s, const WildConstUInt<i> &c) {
337  s << "cu" << i;
338  return s;
339 }
340 
341 template<int i>
343  struct pattern_tag {};
344 
345  constexpr static uint32_t binds = 1 << i;
346 
349  constexpr static bool canonical = true;
350 
351  template<uint32_t bound>
352  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
353  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
354  const BaseExprNode *op = &e;
355  if (op->node_type == IRNodeType::Broadcast) {
356  op = ((const Broadcast *)op)->value.get();
357  }
358  if (op->node_type != IRNodeType::FloatImm) {
359  return false;
360  }
361  double value = ((const FloatImm *)op)->value;
362  if (bound & binds) {
364  halide_type_t type;
365  state.get_bound_const(i, val, type);
366  return (halide_type_t)e.type == type && value == val.u.f64;
367  }
368  state.set_bound_const(i, value, e.type);
369  return true;
370  }
371 
373  Expr make(MatcherState &state, halide_type_t type_hint) const {
375  halide_type_t type;
376  state.get_bound_const(i, val, type);
377  return make_const_expr(val, type);
378  }
379 
380  constexpr static bool foldable = true;
381 
383  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
384  state.get_bound_const(i, val, ty);
385  }
386 };
387 
388 template<int i>
389 std::ostream &operator<<(std::ostream &s, const WildConstFloat<i> &c) {
390  s << "cf" << i;
391  return s;
392 }
393 
394 // Matches and binds to any constant Expr. Does not support constant-folding.
395 template<int i>
396 struct WildConst {
397  struct pattern_tag {};
398 
399  constexpr static uint32_t binds = 1 << i;
400 
403  constexpr static bool canonical = true;
404 
405  template<uint32_t bound>
406  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
407  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
408  const BaseExprNode *op = &e;
409  if (op->node_type == IRNodeType::Broadcast) {
410  op = ((const Broadcast *)op)->value.get();
411  }
412  switch (op->node_type) {
413  case IRNodeType::IntImm:
414  return WildConstInt<i>().template match<bound>(e, state);
415  case IRNodeType::UIntImm:
416  return WildConstUInt<i>().template match<bound>(e, state);
418  return WildConstFloat<i>().template match<bound>(e, state);
419  default:
420  return false;
421  }
422  }
423 
424  template<uint32_t bound>
425  HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept {
426  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
427  return WildConstInt<i>().template match<bound>(e, state);
428  }
429 
431  Expr make(MatcherState &state, halide_type_t type_hint) const {
433  halide_type_t type;
434  state.get_bound_const(i, val, type);
435  return make_const_expr(val, type);
436  }
437 
438  constexpr static bool foldable = true;
439 
441  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
442  state.get_bound_const(i, val, ty);
443  }
444 };
445 
446 template<int i>
447 std::ostream &operator<<(std::ostream &s, const WildConst<i> &c) {
448  s << "c" << i;
449  return s;
450 }
451 
452 // Matches and binds to any Expr
453 template<int i>
454 struct Wild {
455  struct pattern_tag {};
456 
457  constexpr static uint32_t binds = 1 << (i + 16);
458 
461  constexpr static bool canonical = true;
462 
463  template<uint32_t bound>
464  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
465  if (bound & binds) {
466  return equal(*state.get_binding(i), e);
467  }
468  state.set_binding(i, e);
469  return true;
470  }
471 
473  Expr make(MatcherState &state, halide_type_t type_hint) const {
474  return state.get_binding(i);
475  }
476 
477  constexpr static bool foldable = false;
478 };
479 
480 template<int i>
481 std::ostream &operator<<(std::ostream &s, const Wild<i> &op) {
482  s << "_" << i;
483  return s;
484 }
485 
486 // Matches a specific constant or broadcast of that constant. The
487 // constant must be representable as an int64_t.
488 struct IntLiteral {
489  struct pattern_tag {};
491 
492  constexpr static uint32_t binds = 0;
493 
496  constexpr static bool canonical = true;
497 
499  explicit IntLiteral(int64_t v)
500  : v(v) {
501  }
502 
503  template<uint32_t bound>
504  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
505  const BaseExprNode *op = &e;
506  if (e.node_type == IRNodeType::Broadcast) {
507  op = ((const Broadcast *)op)->value.get();
508  }
509  switch (op->node_type) {
510  case IRNodeType::IntImm:
511  return ((const IntImm *)op)->value == (int64_t)v;
512  case IRNodeType::UIntImm:
513  return ((const UIntImm *)op)->value == (uint64_t)v;
515  return ((const FloatImm *)op)->value == (double)v;
516  default:
517  return false;
518  }
519  }
520 
521  template<uint32_t bound>
522  HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept {
523  return v == val;
524  }
525 
526  template<uint32_t bound>
527  HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept {
528  return v == b.v;
529  }
530 
532  Expr make(MatcherState &state, halide_type_t type_hint) const {
533  return make_const(type_hint, v);
534  }
535 
536  constexpr static bool foldable = true;
537 
539  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
540  // Assume type is already correct
541  switch (ty.code) {
542  case halide_type_int:
543  val.u.i64 = v;
544  break;
545  case halide_type_uint:
546  val.u.u64 = (uint64_t)v;
547  break;
548  case halide_type_float:
549  case halide_type_bfloat:
550  val.u.f64 = (double)v;
551  break;
552  default:
553  // Unreachable
554  ;
555  }
556  }
557 };
558 
560  return t.v;
561 }
562 
563 // Convert a provided pattern, expr, or constant int into the internal
564 // representation we use in the matcher trees.
565 template<typename T,
566  typename = typename std::decay<T>::type::pattern_tag>
568  return t;
569 }
572  return IntLiteral{x};
573 }
574 
575 template<typename T>
577  static_assert(!std::is_same<typename std::decay<T>::type, Expr>::value || std::is_lvalue_reference<T>::value,
578  "Exprs are captured by reference by IRMatcher objects and so must be lvalues");
579 }
580 
582  return {*e.get()};
583 }
584 
585 // Helpers to deref SpecificExprs to const BaseExprNode & rather than
586 // passing them by value anywhere (incurring lots of refcounting)
587 template<typename T,
588  // T must be a pattern node
589  typename = typename std::decay<T>::type::pattern_tag,
590  // But T may not be SpecificExpr
591  typename = typename std::enable_if<!std::is_same<typename std::decay<T>::type, SpecificExpr>::value>::type>
593  return t;
594 }
595 
597 const BaseExprNode &unwrap(const SpecificExpr &e) {
598  return e.expr;
599 }
600 
601 inline std::ostream &operator<<(std::ostream &s, const IntLiteral &op) {
602  s << op.v;
603  return s;
604 }
605 
606 template<typename Op>
608 
609 template<typename Op>
611 
612 template<typename Op>
613 double constant_fold_bin_op(halide_type_t &, double, double) noexcept;
614 
615 constexpr bool commutative(IRNodeType t) {
616  return (t == IRNodeType::Add ||
617  t == IRNodeType::Mul ||
618  t == IRNodeType::And ||
619  t == IRNodeType::Or ||
620  t == IRNodeType::Min ||
621  t == IRNodeType::Max ||
622  t == IRNodeType::EQ ||
623  t == IRNodeType::NE);
624 }
625 
626 // Matches one of the binary operators
627 template<typename Op, typename A, typename B>
628 struct BinOp {
629  struct pattern_tag {};
630  A a;
631  B b;
632 
634 
635  constexpr static IRNodeType min_node_type = Op::_node_type;
636  constexpr static IRNodeType max_node_type = Op::_node_type;
637 
638  // For commutative bin ops, we expect the weaker IR node type on
639  // the right. That is, for the rule to be canonical it must be
640  // possible that A is at least as strong as B.
641  constexpr static bool canonical =
642  A::canonical && B::canonical && (!commutative(Op::_node_type) || (A::max_node_type >= B::min_node_type));
643 
644  template<uint32_t bound>
645  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
646  if (e.node_type != Op::_node_type) {
647  return false;
648  }
649  const Op &op = (const Op &)e;
650  return (a.template match<bound>(*op.a.get(), state) &&
651  b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
652  }
653 
654  template<uint32_t bound, typename Op2, typename A2, typename B2>
655  HALIDE_ALWAYS_INLINE bool match(const BinOp<Op2, A2, B2> &op, MatcherState &state) const noexcept {
656  return (std::is_same<Op, Op2>::value &&
657  a.template match<bound>(unwrap(op.a), state) &&
658  b.template match<bound | bindings<A>::mask>(unwrap(op.b), state));
659  }
660 
661  constexpr static bool foldable = A::foldable && B::foldable;
662 
664  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
665  halide_scalar_value_t val_a, val_b;
666  if (std::is_same<A, IntLiteral>::value) {
667  b.make_folded_const(val_b, ty, state);
668  if ((std::is_same<Op, And>::value && val_b.u.u64 == 0) ||
669  (std::is_same<Op, Or>::value && val_b.u.u64 == 1)) {
670  // Short circuit
671  val = val_b;
672  return;
673  }
674  const uint16_t l = ty.lanes;
675  a.make_folded_const(val_a, ty, state);
676  ty.lanes |= l; // Make sure the overflow bits are sticky
677  } else {
678  a.make_folded_const(val_a, ty, state);
679  if ((std::is_same<Op, And>::value && val_a.u.u64 == 0) ||
680  (std::is_same<Op, Or>::value && val_a.u.u64 == 1)) {
681  // Short circuit
682  val = val_a;
683  return;
684  }
685  const uint16_t l = ty.lanes;
686  b.make_folded_const(val_b, ty, state);
687  ty.lanes |= l;
688  }
689  switch (ty.code) {
690  case halide_type_int:
691  val.u.i64 = constant_fold_bin_op<Op>(ty, val_a.u.i64, val_b.u.i64);
692  break;
693  case halide_type_uint:
694  val.u.u64 = constant_fold_bin_op<Op>(ty, val_a.u.u64, val_b.u.u64);
695  break;
696  case halide_type_float:
697  case halide_type_bfloat:
698  val.u.f64 = constant_fold_bin_op<Op>(ty, val_a.u.f64, val_b.u.f64);
699  break;
700  default:
701  // unreachable
702  ;
703  }
704  }
705 
707  Expr make(MatcherState &state, halide_type_t type_hint) const noexcept {
708  Expr ea, eb;
709  if (std::is_same<A, IntLiteral>::value) {
710  eb = b.make(state, type_hint);
711  ea = a.make(state, eb.type());
712  } else {
713  ea = a.make(state, type_hint);
714  eb = b.make(state, ea.type());
715  }
716  return Op::make(std::move(ea), std::move(eb));
717  }
718 };
719 
720 template<typename Op>
722 
723 template<typename Op>
725 
// Floating-point constant folding for comparison operator Op; specialized
// per operator below. Returns the comparison result as a uint64_t.
template<typename Op>
uint64_t constant_fold_cmp_op(double a, double b) noexcept;
728 
729 // Matches one of the comparison operators
730 template<typename Op, typename A, typename B>
731 struct CmpOp {
732  struct pattern_tag {};
733  A a;
734  B b;
735 
737 
738  constexpr static IRNodeType min_node_type = Op::_node_type;
739  constexpr static IRNodeType max_node_type = Op::_node_type;
740  constexpr static bool canonical = (A::canonical &&
741  B::canonical &&
742  (!commutative(Op::_node_type) || A::max_node_type >= B::min_node_type) &&
743  (Op::_node_type != IRNodeType::GE) &&
744  (Op::_node_type != IRNodeType::GT));
745 
746  template<uint32_t bound>
747  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
748  if (e.node_type != Op::_node_type) {
749  return false;
750  }
751  const Op &op = (const Op &)e;
752  return (a.template match<bound>(*op.a.get(), state) &&
753  b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
754  }
755 
756  template<uint32_t bound, typename Op2, typename A2, typename B2>
757  HALIDE_ALWAYS_INLINE bool match(const CmpOp<Op2, A2, B2> &op, MatcherState &state) const noexcept {
758  return (std::is_same<Op, Op2>::value &&
759  a.template match<bound>(unwrap(op.a), state) &&
760  b.template match<bound | bindings<A>::mask>(unwrap(op.b), state));
761  }
762 
763  constexpr static bool foldable = A::foldable && B::foldable;
764 
766  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
767  halide_scalar_value_t val_a, val_b;
768  // If one side is an untyped const, evaluate the other side first to get a type hint.
769  if (std::is_same<A, IntLiteral>::value) {
770  b.make_folded_const(val_b, ty, state);
771  const uint16_t l = ty.lanes;
772  a.make_folded_const(val_a, ty, state);
773  ty.lanes |= l;
774  } else {
775  a.make_folded_const(val_a, ty, state);
776  const uint16_t l = ty.lanes;
777  b.make_folded_const(val_b, ty, state);
778  ty.lanes |= l;
779  }
780  switch (ty.code) {
781  case halide_type_int:
782  val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.i64, val_b.u.i64);
783  break;
784  case halide_type_uint:
785  val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.u64, val_b.u.u64);
786  break;
787  case halide_type_float:
788  case halide_type_bfloat:
789  val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.f64, val_b.u.f64);
790  break;
791  default:
792  // unreachable
793  ;
794  }
795  ty.code = halide_type_uint;
796  ty.bits = 1;
797  }
798 
800  Expr make(MatcherState &state, halide_type_t type_hint) const {
801  // If one side is an untyped const, evaluate the other side first to get a type hint.
802  Expr ea, eb;
803  if (std::is_same<A, IntLiteral>::value) {
804  eb = b.make(state, {});
805  ea = a.make(state, eb.type());
806  } else {
807  ea = a.make(state, {});
808  eb = b.make(state, ea.type());
809  }
810  return Op::make(std::move(ea), std::move(eb));
811  }
812 };
813 
814 template<typename A, typename B>
815 std::ostream &operator<<(std::ostream &s, const BinOp<Add, A, B> &op) {
816  s << "(" << op.a << " + " << op.b << ")";
817  return s;
818 }
819 
820 template<typename A, typename B>
821 std::ostream &operator<<(std::ostream &s, const BinOp<Sub, A, B> &op) {
822  s << "(" << op.a << " - " << op.b << ")";
823  return s;
824 }
825 
826 template<typename A, typename B>
827 std::ostream &operator<<(std::ostream &s, const BinOp<Mul, A, B> &op) {
828  s << "(" << op.a << " * " << op.b << ")";
829  return s;
830 }
831 
832 template<typename A, typename B>
833 std::ostream &operator<<(std::ostream &s, const BinOp<Div, A, B> &op) {
834  s << "(" << op.a << " / " << op.b << ")";
835  return s;
836 }
837 
838 template<typename A, typename B>
839 std::ostream &operator<<(std::ostream &s, const BinOp<And, A, B> &op) {
840  s << "(" << op.a << " && " << op.b << ")";
841  return s;
842 }
843 
844 template<typename A, typename B>
845 std::ostream &operator<<(std::ostream &s, const BinOp<Or, A, B> &op) {
846  s << "(" << op.a << " || " << op.b << ")";
847  return s;
848 }
849 
850 template<typename A, typename B>
851 std::ostream &operator<<(std::ostream &s, const BinOp<Min, A, B> &op) {
852  s << "min(" << op.a << ", " << op.b << ")";
853  return s;
854 }
855 
856 template<typename A, typename B>
857 std::ostream &operator<<(std::ostream &s, const BinOp<Max, A, B> &op) {
858  s << "max(" << op.a << ", " << op.b << ")";
859  return s;
860 }
861 
862 template<typename A, typename B>
863 std::ostream &operator<<(std::ostream &s, const CmpOp<LE, A, B> &op) {
864  s << "(" << op.a << " <= " << op.b << ")";
865  return s;
866 }
867 
868 template<typename A, typename B>
869 std::ostream &operator<<(std::ostream &s, const CmpOp<LT, A, B> &op) {
870  s << "(" << op.a << " < " << op.b << ")";
871  return s;
872 }
873 
874 template<typename A, typename B>
875 std::ostream &operator<<(std::ostream &s, const CmpOp<GE, A, B> &op) {
876  s << "(" << op.a << " >= " << op.b << ")";
877  return s;
878 }
879 
880 template<typename A, typename B>
881 std::ostream &operator<<(std::ostream &s, const CmpOp<GT, A, B> &op) {
882  s << "(" << op.a << " > " << op.b << ")";
883  return s;
884 }
885 
886 template<typename A, typename B>
887 std::ostream &operator<<(std::ostream &s, const CmpOp<EQ, A, B> &op) {
888  s << "(" << op.a << " == " << op.b << ")";
889  return s;
890 }
891 
892 template<typename A, typename B>
893 std::ostream &operator<<(std::ostream &s, const CmpOp<NE, A, B> &op) {
894  s << "(" << op.a << " != " << op.b << ")";
895  return s;
896 }
897 
898 template<typename A, typename B>
899 std::ostream &operator<<(std::ostream &s, const BinOp<Mod, A, B> &op) {
900  s << "(" << op.a << " % " << op.b << ")";
901  return s;
902 }
903 
904 template<typename A, typename B>
905 HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp<Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
906  assert_is_lvalue_if_expr<A>();
907  assert_is_lvalue_if_expr<B>();
908  return {pattern_arg(a), pattern_arg(b)};
909 }
910 
911 template<typename A, typename B>
912 HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b)) {
913  assert_is_lvalue_if_expr<A>();
914  assert_is_lvalue_if_expr<B>();
915  return IRMatcher::operator+(a, b);
916 }
917 
918 template<>
920  t.lanes |= ((t.bits >= 32) && add_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
921  int dead_bits = 64 - t.bits;
922  // Drop the high bits then sign-extend them back
923  return int64_t((uint64_t(a) + uint64_t(b)) << dead_bits) >> dead_bits;
924 }
925 
926 template<>
928  uint64_t ones = (uint64_t)(-1);
929  return (a + b) & (ones >> (64 - t.bits));
930 }
931 
932 template<>
933 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Add>(halide_type_t &t, double a, double b) noexcept {
934  return a + b;
935 }
936 
937 template<typename A, typename B>
938 HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp<Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
939  assert_is_lvalue_if_expr<A>();
940  assert_is_lvalue_if_expr<B>();
941  return {pattern_arg(a), pattern_arg(b)};
942 }
943 
944 template<typename A, typename B>
945 HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b)) {
946  assert_is_lvalue_if_expr<A>();
947  assert_is_lvalue_if_expr<B>();
948  return IRMatcher::operator-(a, b);
949 }
950 
951 template<>
953  t.lanes |= ((t.bits >= 32) && sub_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
954  // Drop the high bits then sign-extend them back
955  int dead_bits = 64 - t.bits;
956  return int64_t((uint64_t(a) - uint64_t(b)) << dead_bits) >> dead_bits;
957 }
958 
959 template<>
961  uint64_t ones = (uint64_t)(-1);
962  return (a - b) & (ones >> (64 - t.bits));
963 }
964 
965 template<>
966 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Sub>(halide_type_t &t, double a, double b) noexcept {
967  return a - b;
968 }
969 
970 template<typename A, typename B>
971 HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp<Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
972  assert_is_lvalue_if_expr<A>();
973  assert_is_lvalue_if_expr<B>();
974  return {pattern_arg(a), pattern_arg(b)};
975 }
976 
977 template<typename A, typename B>
978 HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b)) {
979  assert_is_lvalue_if_expr<A>();
980  assert_is_lvalue_if_expr<B>();
981  return IRMatcher::operator*(a, b);
982 }
983 
984 template<>
986  t.lanes |= ((t.bits >= 32) && mul_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
987  int dead_bits = 64 - t.bits;
988  // Drop the high bits then sign-extend them back
989  return int64_t((uint64_t(a) * uint64_t(b)) << dead_bits) >> dead_bits;
990 }
991 
992 template<>
994  uint64_t ones = (uint64_t)(-1);
995  return (a * b) & (ones >> (64 - t.bits));
996 }
997 
998 template<>
999 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Mul>(halide_type_t &t, double a, double b) noexcept {
1000  return a * b;
1001 }
1002 
1003 template<typename A, typename B>
1004 HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp<Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1005  assert_is_lvalue_if_expr<A>();
1006  assert_is_lvalue_if_expr<B>();
1007  return {pattern_arg(a), pattern_arg(b)};
1008 }
1009 
1010 template<typename A, typename B>
1011 HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b)) {
1012  return IRMatcher::operator/(a, b);
1013 }
1014 
1015 template<>
1017  return div_imp(a, b);
1018 }
1019 
1020 template<>
1022  return div_imp(a, b);
1023 }
1024 
1025 template<>
1026 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Div>(halide_type_t &t, double a, double b) noexcept {
1027  return div_imp(a, b);
1028 }
1029 
1030 template<typename A, typename B>
1031 HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp<Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1032  assert_is_lvalue_if_expr<A>();
1033  assert_is_lvalue_if_expr<B>();
1034  return {pattern_arg(a), pattern_arg(b)};
1035 }
1036 
1037 template<typename A, typename B>
1038 HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b)) {
1039  assert_is_lvalue_if_expr<A>();
1040  assert_is_lvalue_if_expr<B>();
1041  return IRMatcher::operator%(a, b);
1042 }
1043 
1044 template<>
1046  return mod_imp(a, b);
1047 }
1048 
1049 template<>
1051  return mod_imp(a, b);
1052 }
1053 
1054 template<>
1055 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Mod>(halide_type_t &t, double a, double b) noexcept {
1056  return mod_imp(a, b);
1057 }
1058 
1059 template<typename A, typename B>
1060 HALIDE_ALWAYS_INLINE auto min(A &&a, B &&b) noexcept -> BinOp<Min, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1061  assert_is_lvalue_if_expr<A>();
1062  assert_is_lvalue_if_expr<B>();
1063  return {pattern_arg(a), pattern_arg(b)};
1064 }
1065 
1066 template<>
1068  return std::min(a, b);
1069 }
1070 
1071 template<>
1073  return std::min(a, b);
1074 }
1075 
1076 template<>
1077 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Min>(halide_type_t &t, double a, double b) noexcept {
1078  return std::min(a, b);
1079 }
1080 
1081 template<typename A, typename B>
1082 HALIDE_ALWAYS_INLINE auto max(A &&a, B &&b) noexcept -> BinOp<Max, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1083  assert_is_lvalue_if_expr<A>();
1084  assert_is_lvalue_if_expr<B>();
1085  return {pattern_arg(std::forward<A>(a)), pattern_arg(std::forward<B>(b))};
1086 }
1087 
// Constant-folding specializations for Max: each returns std::max(a, b).
// NOTE(review): the signature lines of the first two specializations are
// missing from this listing (source lines 1089 and 1094); presumably they are
// the signed and unsigned 64-bit integer overloads of constant_fold_bin_op<Max>
// — confirm against IRMatch.h.
1088 template<>
1090  return std::max(a, b);
1091 }
1092 
1093 template<>
1095  return std::max(a, b);
1096 }
1097 
1098 template<>
// Floating-point overload. std::max(a, b) evaluates `a < b`, so when either
// operand is NaN the result is `a`.
1099 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Max>(halide_type_t &t, double a, double b) noexcept {
1100  return std::max(a, b);
1101 }
1102 
1103 template<typename A, typename B>
1104 HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp<LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1105  return {pattern_arg(a), pattern_arg(b)};
1106 }
1107 
1108 template<typename A, typename B>
1109 HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b)) {
1110  return IRMatcher::operator<(a, b);
1111 }
1112 
// Constant-folding specializations for LT: each body returns `a < b`.
// NOTE(review): all three signature lines are missing from this listing
// (source lines 1114, 1119, 1124); presumably they are the int64_t, uint64_t
// and double overloads of constant_fold_bin_op<LT>, mirroring the Mod/Min/Max
// groups — confirm against IRMatch.h.
1113 template<>
1115  return a < b;
1116 }
1117 
1118 template<>
1120  return a < b;
1121 }
1122 
1123 template<>
1125  return a < b;
1126 }
1127 
1128 template<typename A, typename B>
1129 HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp<GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1130  return {pattern_arg(a), pattern_arg(b)};
1131 }
1132 
1133 template<typename A, typename B>
1134 HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b)) {
1135  return IRMatcher::operator>(a, b);
1136 }
1137 
// Constant-folding specializations for GT: each body returns `a > b`.
// NOTE(review): all three signature lines are missing from this listing
// (source lines 1139, 1144, 1149); presumably they are the int64_t, uint64_t
// and double overloads of constant_fold_bin_op<GT> — confirm against IRMatch.h.
1138 template<>
1140  return a > b;
1141 }
1142 
1143 template<>
1145  return a > b;
1146 }
1147 
1148 template<>
1150  return a > b;
1151 }
1152 
1153 template<typename A, typename B>
1154 HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp<LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1155  return {pattern_arg(a), pattern_arg(b)};
1156 }
1157 
1158 template<typename A, typename B>
1159 HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b)) {
1160  return IRMatcher::operator<=(a, b);
1161 }
1162 
// Constant-folding specializations for LE: each body returns `a <= b`.
// NOTE(review): all three signature lines are missing from this listing
// (source lines 1164, 1169, 1174); presumably they are the int64_t, uint64_t
// and double overloads of constant_fold_bin_op<LE> — confirm against IRMatch.h.
1163 template<>
1165  return a <= b;
1166 }
1167 
1168 template<>
1170  return a <= b;
1171 }
1172 
1173 template<>
1175  return a <= b;
1176 }
1177 
1178 template<typename A, typename B>
1179 HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp<GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1180  return {pattern_arg(a), pattern_arg(b)};
1181 }
1182 
1183 template<typename A, typename B>
1184 HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b)) {
1185  return IRMatcher::operator>=(a, b);
1186 }
1187 
// Constant-folding specializations for GE: each body returns `a >= b`.
// NOTE(review): all three signature lines are missing from this listing
// (source lines 1189, 1194, 1199); presumably they are the int64_t, uint64_t
// and double overloads of constant_fold_bin_op<GE> — confirm against IRMatch.h.
1188 template<>
1190  return a >= b;
1191 }
1192 
1193 template<>
1195  return a >= b;
1196 }
1197 
1198 template<>
1200  return a >= b;
1201 }
1202 
1203 template<typename A, typename B>
1204 HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp<EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1205  return {pattern_arg(a), pattern_arg(b)};
1206 }
1207 
1208 template<typename A, typename B>
1209 HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b)) {
1210  return IRMatcher::operator==(a, b);
1211 }
1212 
// Constant-folding specializations for EQ: each body returns `a == b`.
// NOTE(review): all three signature lines are missing from this listing
// (source lines 1214, 1219, 1224); presumably they are the int64_t, uint64_t
// and double overloads of constant_fold_bin_op<EQ> — confirm against IRMatch.h.
1213 template<>
1215  return a == b;
1216 }
1217 
1218 template<>
1220  return a == b;
1221 }
1222 
1223 template<>
1225  return a == b;
1226 }
1227 
1228 template<typename A, typename B>
1229 HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp<NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1230  return {pattern_arg(a), pattern_arg(b)};
1231 }
1232 
1233 template<typename A, typename B>
1234 HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b)) {
1235  return IRMatcher::operator!=(a, b);
1236 }
1237 
// Constant-folding specializations for NE: each body returns `a != b`.
// NOTE(review): all three signature lines are missing from this listing
// (source lines 1239, 1244, 1249); presumably they are the int64_t, uint64_t
// and double overloads of constant_fold_bin_op<NE> — confirm against IRMatch.h.
1238 template<>
1240  return a != b;
1241 }
1242 
1243 template<>
1245  return a != b;
1246 }
1247 
1248 template<>
1250  return a != b;
1251 }
1252 
1253 template<typename A, typename B>
1254 HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp<Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1255  return {pattern_arg(a), pattern_arg(b)};
1256 }
1257 
1258 template<typename A, typename B>
1259 HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b)) {
1260  return IRMatcher::operator||(a, b);
1261 }
1262 
// Constant-folding specializations for Or. The integer bodies compute
// `(a | b) & 1`, keeping only the low bit, so the result is 0 or 1.
// NOTE(review): the signature lines of the two integer specializations are
// missing from this listing (source lines 1264 and 1269); presumably they are
// the int64_t and uint64_t overloads — confirm against IRMatch.h.
1263 template<>
1265  return (a | b) & 1;
1266 }
1267 
1268 template<>
1270  return (a | b) & 1;
1271 }
1272 
1273 template<>
1274 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Or>(halide_type_t &t, double a, double b) noexcept {
1275  // Unreachable, as it would be a type mismatch.
1276  return 0;
1277 }
1278 
1279 template<typename A, typename B>
1280 HALIDE_ALWAYS_INLINE auto operator&&(A &&a, B &&b) noexcept -> BinOp<And, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1281  return {pattern_arg(a), pattern_arg(b)};
1282 }
1283 
1284 template<typename A, typename B>
1285 HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b)) {
1286  return IRMatcher::operator&&(a, b);
1287 }
1288 
// Constant-folding specializations for And. The integer bodies compute
// `a & b & 1`, keeping only the low bit, so the result is 0 or 1.
// NOTE(review): the signature lines of the two integer specializations are
// missing from this listing (source lines 1290 and 1295); presumably they are
// the int64_t and uint64_t overloads — confirm against IRMatch.h.
1289 template<>
1291  return a & b & 1;
1292 }
1293 
1294 template<>
1296  return a & b & 1;
1297 }
1298 
1299 template<>
1300 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<And>(halide_type_t &t, double a, double b) noexcept {
1301  // Unreachable
1302  return 0;
1303 }
1304 
// Compile-time bitwise OR over a pack of uint32_t values.
// NOTE(review): the header line of the first definition (source line 1305) is
// missing from this listing. Its body returns 0, so it is presumably the
// zero-argument base case `bitwise_or_reduce()` that ends the recursion in the
// variadic overload below (0 is the identity for |) — confirm against IRMatch.h.
1306  return 0;
1307 }
1308 
1309 template<typename... Args>
1310 constexpr uint32_t bitwise_or_reduce(uint32_t first, Args... rest) {
1311  return first | bitwise_or_reduce(rest...);
1312 }
1313 
/** Compile-time logical AND of an empty pack: the identity value, true. */
constexpr bool and_reduce() {
    return true;
}

/** Compile-time logical AND over a pack of bools.
 * Stops recursing at the first false value, exactly as `&&` would. */
template<typename... Args>
constexpr bool and_reduce(bool first, Args... rest) {
    if (!first) {
        return false;
    }
    return and_reduce(rest...);
}
1322 
/** constexpr minimum of two ints, usable in template arguments (e.g. clamping
 * a std::get index).
 *
 * std::min has been constexpr since C++14, and this file already requires
 * C++17 (it uses `if constexpr`), so the old hand-written TODO workaround can
 * simply defer to it. The wrapper stays because its int parameters accept
 * mixed-type arguments such as `sizeof...(Args) - 1`, which std::min would
 * reject. */
constexpr int const_min(int a, int b) {
    return std::min(a, b);
}
1327 
// Per-intrinsic output-type hint. The primary template's check() accepts any
// type. The explicit specialization's check() requires the type to equal its
// stored `type` member.
// NOTE(review): the struct header lines (source lines 1329 and 1336-1337) are
// missing from this listing. These are presumably the type of Intrin's
// `optional_type_hint` member, and the specialization is presumably for
// Call::saturating_cast, whose factory sets `optional_type_hint.type`
// — confirm against IRMatch.h.
1328 template<Call::IntrinsicOp intrin>
1330  bool check(const Type &) const {
1331  return true;
1332  }
1333 };
1334 
1335 template<>
1338  bool check(const Type &t) const {
1339  return t == Type(type);
1340  }
1341 };
1342 
// Pattern node for a call to the intrinsic `intrin`. Its arguments are the
// sub-patterns `Args...`, stored in the tuple `args`.
// NOTE(review): several source lines of this struct are missing from this
// listing (1351, 1353, 1355-1356, 1398, 1403, 1464, 1506). Line 1351 presumably
// declares `optional_type_hint`, which match() and make() use below, and line
// 1464 presumably holds the make_folded_const(val, ty, state) signature
// — confirm against IRMatch.h.
1343 template<Call::IntrinsicOp intrin, typename... Args>
1344 struct Intrin {
1345  struct pattern_tag {};
1346  std::tuple<Args...> args;
1347  // The type of the output of the intrinsic node.
1348  // Only necessary in cases where it can't be inferred
1349  // from the input types (e.g. saturating_cast).
1350 
1352 
1354 
1357  constexpr static bool canonical = and_reduce((Args::canonical)...);
1358 
    // Matches argument i against c.args[i], then recurses to i + 1 with the
    // bindings made by argument i added to `bound`. Callers pass the literal 0,
    // so this `int` overload is preferred while enable_if allows it
    // (i < number of args). Past the last argument only the `double` overload
    // below is viable, and it ends the recursion.
1359  template<int i,
1360  uint32_t bound,
1361  typename = typename std::enable_if<(i < sizeof...(Args))>::type>
1362  HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept {
1363  using T = decltype(std::get<i>(args));
1364  return (std::get<i>(args).template match<bound>(*c.args[i].get(), state) &&
1365  match_args<i + 1, bound | bindings<T>::mask>(0, c, state));
1366  }
1367 
    // Recursion terminator: every argument has matched.
1368  template<int i, uint32_t binds>
1369  HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept {
1370  return true;
1371  }
1372 
    // Succeeds only for a Call node that is this intrinsic, whose type passes
    // optional_type_hint.check(), and whose arguments all match.
1373  template<uint32_t bound>
1374  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1375  if (e.node_type != IRNodeType::Call) {
1376  return false;
1377  }
1378  const Call &c = (const Call &)e;
1379  return (c.is_intrinsic(intrin) &&
1380  optional_type_hint.check(e.type) &&
1381  match_args<0, bound>(0, c, state));
1382  }
1383 
    // Prints the arguments separated by ", ", using the same int/double tag
    // dispatch as match_args to stop after the last argument.
1384  template<int i,
1385  typename = typename std::enable_if<(i < sizeof...(Args))>::type>
1386  HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const {
1387  s << std::get<i>(args);
1388  if (i + 1 < sizeof...(Args)) {
1389  s << ", ";
1390  }
1391  print_args<i + 1>(0, s);
1392  }
1393 
1394  template<int i>
1395  HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const {
1396  }
1397 
1399  void print_args(std::ostream &s) const {
1400  print_args<0>(0, s);
1401  }
1402 
    // Rebuilds the intrinsic from its made arguments using the same-named
    // helper. arg1 and arg2 are made only after the one- and two-argument
    // intrinsics have returned. const_min clamps the std::get index so the code
    // still compiles when Args has fewer elements. saturating_cast is the only
    // branch written with `if constexpr`, presumably because only that
    // intrinsic's optional_type_hint has a `type` member, so the branch must be
    // discarded for every other intrinsic.
1404  Expr make(MatcherState &state, halide_type_t type_hint) const {
1405  Expr arg0 = std::get<0>(args).make(state, type_hint);
1406  if (intrin == Call::likely) {
1407  return likely(std::move(arg0));
1408  } else if (intrin == Call::likely_if_innermost) {
1409  return likely_if_innermost(std::move(arg0));
1410  } else if (intrin == Call::abs) {
1411  return abs(std::move(arg0));
1412  } else if constexpr (intrin == Call::saturating_cast) {
1413  return saturating_cast(optional_type_hint.type, std::move(arg0));
1414  }
1415 
1416  Expr arg1 = std::get<const_min(1, sizeof...(Args) - 1)>(args).make(state, type_hint);
1417  if (intrin == Call::absd) {
1418  return absd(std::move(arg0), std::move(arg1));
1419  } else if (intrin == Call::widen_right_add) {
1420  return widen_right_add(std::move(arg0), std::move(arg1));
1421  } else if (intrin == Call::widen_right_mul) {
1422  return widen_right_mul(std::move(arg0), std::move(arg1));
1423  } else if (intrin == Call::widen_right_sub) {
1424  return widen_right_sub(std::move(arg0), std::move(arg1));
1425  } else if (intrin == Call::widening_add) {
1426  return widening_add(std::move(arg0), std::move(arg1));
1427  } else if (intrin == Call::widening_sub) {
1428  return widening_sub(std::move(arg0), std::move(arg1));
1429  } else if (intrin == Call::widening_mul) {
1430  return widening_mul(std::move(arg0), std::move(arg1));
1431  } else if (intrin == Call::saturating_add) {
1432  return saturating_add(std::move(arg0), std::move(arg1));
1433  } else if (intrin == Call::saturating_sub) {
1434  return saturating_sub(std::move(arg0), std::move(arg1));
1435  } else if (intrin == Call::halving_add) {
1436  return halving_add(std::move(arg0), std::move(arg1));
1437  } else if (intrin == Call::halving_sub) {
1438  return halving_sub(std::move(arg0), std::move(arg1));
1439  } else if (intrin == Call::rounding_halving_add) {
1440  return rounding_halving_add(std::move(arg0), std::move(arg1));
1441  } else if (intrin == Call::shift_left) {
1442  return std::move(arg0) << std::move(arg1);
1443  } else if (intrin == Call::shift_right) {
1444  return std::move(arg0) >> std::move(arg1);
1445  } else if (intrin == Call::rounding_shift_left) {
1446  return rounding_shift_left(std::move(arg0), std::move(arg1));
1447  } else if (intrin == Call::rounding_shift_right) {
1448  return rounding_shift_right(std::move(arg0), std::move(arg1));
1449  }
1450 
1451  Expr arg2 = std::get<const_min(2, sizeof...(Args) - 1)>(args).make(state, type_hint);
1452  if (intrin == Call::mul_shift_right) {
1453  return mul_shift_right(std::move(arg0), std::move(arg1), std::move(arg2));
1454  } else if (intrin == Call::rounding_mul_shift_right) {
1455  return rounding_mul_shift_right(std::move(arg0), std::move(arg1), std::move(arg2));
1456  }
1457 
1458  internal_error << "Unhandled intrinsic in IRMatcher: " << intrin;
1459  return Expr();
1460  }
1461 
1462  constexpr static bool foldable = true;
1463 
    // Constant-folds shift_left / shift_right, the only intrinsics folded here.
    // A negative shift amount shifts in the opposite direction. Int types use
    // an arithmetic right shift, other types a logical one.
    // NOTE(review): a shift amount whose magnitude is >= 64 reaches the 64-bit
    // shifts below unchecked, which is undefined behavior in C++. Also,
    // `-arg1.u.i64` overflows for INT64_MIN. Presumably callers never fold such
    // shifts — confirm.
1465  halide_scalar_value_t arg1;
1466  // Assuming the args have the same type as the intrinsic is incorrect in
1467  // general. But for the intrinsics we can fold (just shifts), the LHS
1468  // has the same type as the intrinsic, and we can always treat the RHS
1469  // as a signed int, because we're using 64 bits for it.
1470  std::get<0>(args).make_folded_const(val, ty, state);
1471  halide_type_t signed_ty = ty;
1472  signed_ty.code = halide_type_int;
1473  // We can just directly get the second arg here, because we only want to
1474  // instantiate this method for shifts, which have two args.
1475  std::get<1>(args).make_folded_const(arg1, signed_ty, state);
1476 
1477  if (intrin == Call::shift_left) {
1478  if (arg1.u.i64 < 0) {
1479  if (ty.code == halide_type_int) {
1480  // Arithmetic shift
1481  val.u.i64 >>= -arg1.u.i64;
1482  } else {
1483  // Logical shift
1484  val.u.u64 >>= -arg1.u.i64;
1485  }
1486  } else {
1487  val.u.u64 <<= arg1.u.i64;
1488  }
1489  } else if (intrin == Call::shift_right) {
1490  if (arg1.u.i64 > 0) {
1491  if (ty.code == halide_type_int) {
1492  // Arithmetic shift
1493  val.u.i64 >>= arg1.u.i64;
1494  } else {
1495  // Logical shift
1496  val.u.u64 >>= arg1.u.i64;
1497  }
1498  } else {
1499  val.u.u64 <<= -arg1.u.i64;
1500  }
1501  } else {
1502  internal_error << "Folding not implemented for intrinsic: " << intrin;
1503  }
1504  }
1505 
    // Constructs the node from its argument patterns (stored in `args`).
1507  Intrin(Args... args) noexcept
1508  : args(args...) {
1509  }
1510 };
1511 
1512 template<Call::IntrinsicOp intrin, typename... Args>
1513 std::ostream &operator<<(std::ostream &s, const Intrin<intrin, Args...> &op) {
1514  s << intrin << "(";
1515  op.print_args(s);
1516  s << ")";
1517  return s;
1518 }
1519 
1520 template<typename A, typename B>
1521 auto widen_right_add(A &&a, B &&b) noexcept -> Intrin<Call::widen_right_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1522  return {pattern_arg(a), pattern_arg(b)};
1523 }
1524 template<typename A, typename B>
1525 auto widen_right_mul(A &&a, B &&b) noexcept -> Intrin<Call::widen_right_mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1526  return {pattern_arg(a), pattern_arg(b)};
1527 }
1528 template<typename A, typename B>
1529 auto widen_right_sub(A &&a, B &&b) noexcept -> Intrin<Call::widen_right_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1530  return {pattern_arg(a), pattern_arg(b)};
1531 }
1532 
1533 template<typename A, typename B>
1534 auto widening_add(A &&a, B &&b) noexcept -> Intrin<Call::widening_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1535  return {pattern_arg(a), pattern_arg(b)};
1536 }
1537 template<typename A, typename B>
1538 auto widening_sub(A &&a, B &&b) noexcept -> Intrin<Call::widening_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1539  return {pattern_arg(a), pattern_arg(b)};
1540 }
1541 template<typename A, typename B>
1542 auto widening_mul(A &&a, B &&b) noexcept -> Intrin<Call::widening_mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1543  return {pattern_arg(a), pattern_arg(b)};
1544 }
1545 template<typename A, typename B>
1546 auto saturating_add(A &&a, B &&b) noexcept -> Intrin<Call::saturating_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1547  return {pattern_arg(a), pattern_arg(b)};
1548 }
1549 template<typename A, typename B>
1550 auto saturating_sub(A &&a, B &&b) noexcept -> Intrin<Call::saturating_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1551  return {pattern_arg(a), pattern_arg(b)};
1552 }
1553 template<typename A>
1554 auto saturating_cast(const Type &t, A &&a) noexcept -> Intrin<Call::saturating_cast, decltype(pattern_arg(a))> {
1555  Intrin<Call::saturating_cast, decltype(pattern_arg(a))> p = {pattern_arg(a)};
1556  p.optional_type_hint.type = t;
1557  return p;
1558 }
1559 template<typename A, typename B>
1560 auto halving_add(A &&a, B &&b) noexcept -> Intrin<Call::halving_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1561  return {pattern_arg(a), pattern_arg(b)};
1562 }
1563 template<typename A, typename B>
1564 auto halving_sub(A &&a, B &&b) noexcept -> Intrin<Call::halving_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1565  return {pattern_arg(a), pattern_arg(b)};
1566 }
1567 template<typename A, typename B>
1568 auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin<Call::rounding_halving_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1569  return {pattern_arg(a), pattern_arg(b)};
1570 }
1571 template<typename A, typename B>
1572 auto shift_left(A &&a, B &&b) noexcept -> Intrin<Call::shift_left, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1573  return {pattern_arg(a), pattern_arg(b)};
1574 }
1575 template<typename A, typename B>
1576 auto shift_right(A &&a, B &&b) noexcept -> Intrin<Call::shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1577  return {pattern_arg(a), pattern_arg(b)};
1578 }
1579 template<typename A, typename B>
1580 auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin<Call::rounding_shift_left, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1581  return {pattern_arg(a), pattern_arg(b)};
1582 }
1583 template<typename A, typename B>
1584 auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin<Call::rounding_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1585  return {pattern_arg(a), pattern_arg(b)};
1586 }
1587 template<typename A, typename B, typename C>
1588 auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<Call::mul_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1589  return {pattern_arg(a), pattern_arg(b), pattern_arg(c)};
1590 }
1591 template<typename A, typename B, typename C>
1592 auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<Call::rounding_mul_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1593  return {pattern_arg(a), pattern_arg(b), pattern_arg(c)};
1594 }
1595 
1596 template<typename A>
1597 auto abs(A &&a) noexcept -> Intrin<Call::abs, decltype(pattern_arg(a))> {
1598  return {pattern_arg(a)};
1599 }
1600 
1601 template<typename A, typename B>
1602 auto absd(A &&a, B &&b) noexcept -> Intrin<Call::absd, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1603  return {pattern_arg(a), pattern_arg(b)};
1604 }
1605 
1606 template<typename A>
1607 auto likely(A &&a) noexcept -> Intrin<Call::likely, decltype(pattern_arg(a))> {
1608  return {pattern_arg(a)};
1609 }
1610 
1611 template<typename A>
1612 auto likely_if_innermost(A &&a) noexcept -> Intrin<Call::likely_if_innermost, decltype(pattern_arg(a))> {
1613  return {pattern_arg(a)};
1614 }
1615 
// Pattern node for logical negation `!a`.
// NOTE(review): source lines 1623-1624, 1641 and 1649 are missing from this
// listing. Line 1649 presumably holds the make_folded_const(val, ty, state)
// signature — confirm against IRMatch.h.
1616 template<typename A>
1617 struct NotOp {
1618  struct pattern_tag {};
1619  A a;
1620 
1621  constexpr static uint32_t binds = bindings<A>::mask;
1622 
1625  constexpr static bool canonical = A::canonical;
1626 
    // Matches a Not node whose operand matches `a`.
1627  template<uint32_t bound>
1628  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1629  if (e.node_type != IRNodeType::Not) {
1630  return false;
1631  }
1632  const Not &op = (const Not &)e;
1633  return (a.template match<bound>(*op.a.get(), state));
1634  }
1635 
    // Pattern-against-pattern match: compares operands only.
1636  template<uint32_t bound, typename A2>
1637  HALIDE_ALWAYS_INLINE bool match(const NotOp<A2> &op, MatcherState &state) const noexcept {
1638  return a.template match<bound>(unwrap(op.a), state);
1639  }
1640 
1642  Expr make(MatcherState &state, halide_type_t type_hint) const {
1643  return Not::make(a.make(state, type_hint));
1644  }
1645 
1646  constexpr static bool foldable = A::foldable;
1647 
1648  template<typename A1 = A>
    // Fold: invert the operand's bits, then keep only the low bit, giving
    // 0 or 1.
1650  a.make_folded_const(val, ty, state);
1651  val.u.u64 = ~val.u.u64;
1652  val.u.u64 &= 1;
1653  }
1654 };
1655 
1656 template<typename A>
1657 HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp<decltype(pattern_arg(a))> {
1658  assert_is_lvalue_if_expr<A>();
1659  return {pattern_arg(a)};
1660 }
1661 
1662 template<typename A>
1663 HALIDE_ALWAYS_INLINE auto not_op(A &&a) -> decltype(IRMatcher::operator!(a)) {
1664  assert_is_lvalue_if_expr<A>();
1665  return IRMatcher::operator!(a);
1666 }
1667 
1668 template<typename A>
1669 inline std::ostream &operator<<(std::ostream &s, const NotOp<A> &op) {
1670  s << "!(" << op.a << ")";
1671  return s;
1672 }
1673 
// Pattern node for `select(c, t, f)`.
// NOTE(review): source lines 1681, 1683-1684, 1705 and 1713 are missing from
// this listing. Line 1713 presumably holds the
// make_folded_const(val, ty, state) signature — confirm against IRMatch.h.
1674 template<typename C, typename T, typename F>
1675 struct SelectOp {
1676  struct pattern_tag {};
1677  C c;
1678  T t;
1679  F f;
1680 
1682 
1685 
1686  constexpr static bool canonical = C::canonical && T::canonical && F::canonical;
1687 
    // Matches condition, true value, then false value. Each later operand sees
    // the bindings made by the earlier ones in its `bound` mask.
1688  template<uint32_t bound>
1689  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1690  if (e.node_type != Select::_node_type) {
1691  return false;
1692  }
1693  const Select &op = (const Select &)e;
1694  return (c.template match<bound>(*op.condition.get(), state) &&
1695  t.template match<bound | bindings<C>::mask>(*op.true_value.get(), state) &&
1696  f.template match<bound | bindings<C>::mask | bindings<T>::mask>(*op.false_value.get(), state));
1697  }
1698  template<uint32_t bound, typename C2, typename T2, typename F2>
1699  HALIDE_ALWAYS_INLINE bool match(const SelectOp<C2, T2, F2> &instance, MatcherState &state) const noexcept {
1700  return (c.template match<bound>(unwrap(instance.c), state) &&
1701  t.template match<bound | bindings<C>::mask>(unwrap(instance.t), state) &&
1702  f.template match<bound | bindings<C>::mask | bindings<T>::mask>(unwrap(instance.f), state));
1703  }
1704 
    // The condition is made with an empty type hint; the value hint applies
    // only to the two branches.
1706  Expr make(MatcherState &state, halide_type_t type_hint) const {
1707  return Select::make(c.make(state, {}), t.make(state, type_hint), f.make(state, type_hint));
1708  }
1709 
1710  constexpr static bool foldable = C::foldable && T::foldable && F::foldable;
1711 
1712  template<typename C1 = C>
    // Fold: evaluate the condition, then fold only the branch its low bit
    // selects. The condition's lane bits under MatcherState::special_values_mask
    // are ORed into the result type.
    // NOTE(review): t_val and f_val are declared but never used.
1714  halide_scalar_value_t c_val, t_val, f_val;
1715  halide_type_t c_ty;
1716  c.make_folded_const(c_val, c_ty, state);
1717  if ((c_val.u.u64 & 1) == 1) {
1718  t.make_folded_const(val, ty, state);
1719  } else {
1720  f.make_folded_const(val, ty, state);
1721  }
1722  ty.lanes |= c_ty.lanes & MatcherState::special_values_mask;
1723  }
1724 };
1725 
1726 template<typename C, typename T, typename F>
1727 std::ostream &operator<<(std::ostream &s, const SelectOp<C, T, F> &op) {
1728  s << "select(" << op.c << ", " << op.t << ", " << op.f << ")";
1729  return s;
1730 }
1731 
1732 template<typename C, typename T, typename F>
1733 HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp<decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))> {
1734  assert_is_lvalue_if_expr<C>();
1735  assert_is_lvalue_if_expr<T>();
1736  assert_is_lvalue_if_expr<F>();
1737  return {pattern_arg(c), pattern_arg(t), pattern_arg(f)};
1738 }
1739 
// Pattern node for `broadcast(a, lanes)`.
// NOTE(review): source lines 1744, 1746, 1748-1749, 1771 and 1789 are missing
// from this listing. Line 1744 presumably declares the `lanes` member used
// below, and line 1789 the make_folded_const(val, ty, state) signature
// — confirm against IRMatch.h.
1740 template<typename A, typename B>
1741 struct BroadcastOp {
1742  struct pattern_tag {};
1743  A a;
1745 
1747 
1750 
1751  constexpr static bool canonical = A::canonical && B::canonical;
1752 
    // Matches a Broadcast node whose value matches `a` and whose lane count
    // matches `lanes`.
    // NOTE(review): `lanes` is matched with plain `bound` here. The
    // pattern-against-pattern overload below (and RampOp/VectorReduceOp) add
    // bindings<A>::mask first. Confirm whether this asymmetry is intended.
1753  template<uint32_t bound>
1754  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1755  if (e.node_type == Broadcast::_node_type) {
1756  const Broadcast &op = (const Broadcast &)e;
1757  if (a.template match<bound>(*op.value.get(), state) &&
1758  lanes.template match<bound>(op.lanes, state)) {
1759  return true;
1760  }
1761  }
1762  return false;
1763  }
1764 
1765  template<uint32_t bound, typename A2, typename B2>
1766  HALIDE_ALWAYS_INLINE bool match(const BroadcastOp<A2, B2> &op, MatcherState &state) const noexcept {
1767  return (a.template match<bound>(unwrap(op.a), state) &&
1768  lanes.template match<bound | bindings<A>::mask>(unwrap(op.lanes), state));
1769  }
1770 
    // Folds `lanes` to a constant and divides the hint's lane count by it before
    // making the operand. A lane count of 1 returns the operand unbroadcast.
    // NOTE(review): a folded lane count of 0 would divide by zero; presumably
    // it cannot occur — confirm.
1772  Expr make(MatcherState &state, halide_type_t type_hint) const {
1773  halide_scalar_value_t lanes_val;
1774  halide_type_t ty;
1775  lanes.make_folded_const(lanes_val, ty, state);
1776  int32_t l = (int32_t)lanes_val.u.i64;
1777  type_hint.lanes /= l;
1778  Expr val = a.make(state, type_hint);
1779  if (l == 1) {
1780  return val;
1781  } else {
1782  return Broadcast::make(std::move(val), l);
1783  }
1784  }
1785 
1786  constexpr static bool foldable = false;
1787 
1788  template<typename A1 = A>
    // Fold: the operand's value with the folded lane count, keeping the
    // operand's bits under MatcherState::special_values_mask.
1790  halide_scalar_value_t lanes_val;
1791  halide_type_t lanes_ty;
1792  lanes.make_folded_const(lanes_val, lanes_ty, state);
1793  uint16_t l = (uint16_t)lanes_val.u.i64;
1794  a.make_folded_const(val, ty, state);
1795  ty.lanes = l | (ty.lanes & MatcherState::special_values_mask);
1796  }
1797 };
1798 
1799 template<typename A, typename B>
1800 inline std::ostream &operator<<(std::ostream &s, const BroadcastOp<A, B> &op) {
1801  s << "broadcast(" << op.a << ", " << op.lanes << ")";
1802  return s;
1803 }
1804 
1805 template<typename A, typename B>
1806 HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes))> {
1807  assert_is_lvalue_if_expr<A>();
1808  return {pattern_arg(a), pattern_arg(lanes)};
1809 }
1810 
// Pattern node for `ramp(base, stride, lanes)`.
// NOTE(review): source lines 1816, 1818, 1820-1821 and 1847 are missing from
// this listing. Line 1816 presumably declares the `lanes` member used below
// — confirm against IRMatch.h.
1811 template<typename A, typename B, typename C>
1812 struct RampOp {
1813  struct pattern_tag {};
1814  A a;
1815  B b;
1817 
1819 
1822 
1823  constexpr static bool canonical = A::canonical && B::canonical && C::canonical;
1824 
    // Matches base, stride, then lane count. Each later operand sees the
    // bindings of the earlier ones in its `bound` mask.
1825  template<uint32_t bound>
1826  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1827  if (e.node_type != Ramp::_node_type) {
1828  return false;
1829  }
1830  const Ramp &op = (const Ramp &)e;
1831  if (a.template match<bound>(*op.base.get(), state) &&
1832  b.template match<bound | bindings<A>::mask>(*op.stride.get(), state) &&
1833  lanes.template match<bound | bindings<A>::mask | bindings<B>::mask>(op.lanes, state)) {
1834  return true;
1835  } else {
1836  return false;
1837  }
1838  }
1839 
1840  template<uint32_t bound, typename A2, typename B2, typename C2>
1841  HALIDE_ALWAYS_INLINE bool match(const RampOp<A2, B2, C2> &op, MatcherState &state) const noexcept {
1842  return (a.template match<bound>(unwrap(op.a), state) &&
1843  b.template match<bound | bindings<A>::mask>(unwrap(op.b), state) &&
1844  lanes.template match<bound | bindings<A>::mask | bindings<B>::mask>(unwrap(op.lanes), state));
1845  }
1846 
    // Folds `lanes` to a constant and divides the hint's lane count by it. The
    // stride is made first, and its type is then used as the hint for the base.
1848  Expr make(MatcherState &state, halide_type_t type_hint) const {
1849  halide_scalar_value_t lanes_val;
1850  halide_type_t ty;
1851  lanes.make_folded_const(lanes_val, ty, state);
1852  int32_t l = (int32_t)lanes_val.u.i64;
1853  type_hint.lanes /= l;
1854  Expr ea, eb;
1855  eb = b.make(state, type_hint);
1856  ea = a.make(state, eb.type());
1857  return Ramp::make(std::move(ea), std::move(eb), l);
1858  }
1859 
1860  constexpr static bool foldable = false;
1861 };
1862 
1863 template<typename A, typename B, typename C>
1864 std::ostream &operator<<(std::ostream &s, const RampOp<A, B, C> &op) {
1865  s << "ramp(" << op.a << ", " << op.b << ", " << op.lanes << ")";
1866  return s;
1867 }
1868 
1869 template<typename A, typename B, typename C>
1870 HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1871  assert_is_lvalue_if_expr<A>();
1872  assert_is_lvalue_if_expr<B>();
1873  assert_is_lvalue_if_expr<C>();
1874  return {pattern_arg(a), pattern_arg(b), pattern_arg(c)};
1875 }
1876 
// Pattern node for a VectorReduce with operator `reduce_op`, reducing `a` to
// `lanes` output lanes.
// NOTE(review): source lines 1878 (presumably `struct VectorReduceOp {`),
// 1881 (presumably the `lanes` member), 1883, 1885-1886, 1903 (presumably the
// pattern-against-pattern match signature) and 1909 are missing from this
// listing — confirm against IRMatch.h.
1877 template<typename A, typename B, VectorReduce::Operator reduce_op>
1879  struct pattern_tag {};
1880  A a;
1882 
1884 
1887  constexpr static bool canonical = A::canonical;
1888 
    // Matches a VectorReduce node with the same operator whose value matches
    // `a` and whose output lane count (op.type.lanes()) matches `lanes`.
1889  template<uint32_t bound>
1890  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1891  if (e.node_type == VectorReduce::_node_type) {
1892  const VectorReduce &op = (const VectorReduce &)e;
1893  if (op.op == reduce_op &&
1894  a.template match<bound>(*op.value.get(), state) &&
1895  lanes.template match<bound | bindings<A>::mask>(op.type.lanes(), state)) {
1896  return true;
1897  }
1898  }
1899  return false;
1900  }
1901 
1902  template<uint32_t bound, typename A2, typename B2, VectorReduce::Operator reduce_op_2>
1904  return (reduce_op == reduce_op_2 &&
1905  a.template match<bound>(unwrap(op.a), state) &&
1906  lanes.template match<bound | bindings<A>::mask>(unwrap(op.lanes), state));
1907  }
1908 
    // Folds `lanes` to a constant and builds the reduction.
    // NOTE(review): unlike BroadcastOp/RampOp, type_hint is passed to a.make
    // unchanged. The input presumably has more lanes than the output, so the
    // hint's lane count may not describe `a` — confirm this is harmless.
1910  Expr make(MatcherState &state, halide_type_t type_hint) const {
1911  halide_scalar_value_t lanes_val;
1912  halide_type_t ty;
1913  lanes.make_folded_const(lanes_val, ty, state);
1914  int l = (int)lanes_val.u.i64;
1915  return VectorReduce::make(reduce_op, a.make(state, type_hint), l);
1916  }
1917 
1918  constexpr static bool foldable = false;
1919 };
1920 
1921 template<typename A, typename B, VectorReduce::Operator reduce_op>
1922 inline std::ostream &operator<<(std::ostream &s, const VectorReduceOp<A, B, reduce_op> &op) {
1923  s << "vector_reduce(" << reduce_op << ", " << op.a << ", " << op.lanes << ")";
1924  return s;
1925 }
1926 
1927 template<typename A, typename B>
1928 HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add> {
1929  assert_is_lvalue_if_expr<A>();
1930  return {pattern_arg(a), pattern_arg(lanes)};
1931 }
1932 
1933 template<typename A, typename B>
1934 HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min> {
1935  assert_is_lvalue_if_expr<A>();
1936  return {pattern_arg(a), pattern_arg(lanes)};
1937 }
1938 
1939 template<typename A, typename B>
1940 HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max> {
1941  assert_is_lvalue_if_expr<A>();
1942  return {pattern_arg(a), pattern_arg(lanes)};
1943 }
1944 
1945 template<typename A, typename B>
1946 HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And> {
1947  assert_is_lvalue_if_expr<A>();
1948  return {pattern_arg(a), pattern_arg(lanes)};
1949 }
1950 
1951 template<typename A, typename B>
1952 HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or> {
1953  assert_is_lvalue_if_expr<A>();
1954  return {pattern_arg(a), pattern_arg(lanes)};
1955 }
1956 
// Pattern node for arithmetic negation `-a`, represented in the IR as
// `0 - a` (a Sub node).
// NOTE(review): source lines 1964-1965, 1984, 1994 (presumably the
// make_folded_const(val, ty, state) signature) and 2001 are missing from this
// listing — confirm against IRMatch.h.
1957 template<typename A>
1958 struct NegateOp {
1959  struct pattern_tag {};
1960  A a;
1961 
1962  constexpr static uint32_t binds = bindings<A>::mask;
1963 
1966 
1967  constexpr static bool canonical = A::canonical;
1968 
    // Matches a Sub node whose right operand matches `a` and whose left
    // operand is a constant zero. `a` is matched first; the zero check follows.
1969  template<uint32_t bound>
1970  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1971  if (e.node_type != Sub::_node_type) {
1972  return false;
1973  }
1974  const Sub &op = (const Sub &)e;
1975  return (a.template match<bound>(*op.b.get(), state) &&
1976  is_const_zero(op.a));
1977  }
1978 
    // NOTE(review): this pattern-against-pattern overload takes `NegateOp<A2> &&`,
    // whereas NotOp, SelectOp, BroadcastOp, RampOp and CastOp take the pattern
    // by const reference. As written it cannot bind an lvalue NegateOp. Confirm
    // whether that is intended.
1979  template<uint32_t bound, typename A2>
1980  HALIDE_ALWAYS_INLINE bool match(NegateOp<A2> &&p, MatcherState &state) const noexcept {
1981  return a.template match<bound>(unwrap(p.a), state);
1982  }
1983 
    // Builds `0 - a`, with the zero made in the operand's type.
1985  Expr make(MatcherState &state, halide_type_t type_hint) const {
1986  Expr ea = a.make(state, type_hint);
1987  Expr z = make_zero(ea.type());
1988  return Sub::make(std::move(z), std::move(ea));
1989  }
1990 
1991  constexpr static bool foldable = A::foldable;
1992 
1993  template<typename A1 = A>
    // Fold: negate the operand's constant, then truncate to ty.bits (dead_bits
    // is the number of unused high bits in the 64-bit payload). Signed values
    // are sign-extended back and unsigned values zero-extended. For signed
    // types of 32+ bits the condition detects the most negative value; the
    // handling of that case (source line 2001) is missing from this listing.
1995  a.make_folded_const(val, ty, state);
1996  int dead_bits = 64 - ty.bits;
1997  switch (ty.code) {
1998  case halide_type_int:
1999  if (ty.bits >= 32 && val.u.u64 && (val.u.u64 << (65 - ty.bits)) == 0) {
2000  // Trying to negate the most negative signed int for a no-overflow type.
2002  } else {
2003  // Negate, drop the high bits, and then sign-extend them back
2004  val.u.i64 = int64_t(uint64_t(-val.u.i64) << dead_bits) >> dead_bits;
2005  }
2006  break;
2007  case halide_type_uint:
2008  val.u.u64 = ((-val.u.u64) << dead_bits) >> dead_bits;
2009  break;
2010  case halide_type_float:
2011  case halide_type_bfloat:
2012  val.u.f64 = -val.u.f64;
2013  break;
2014  default:
2015  // unreachable
2016  ;
2017  }
2018  }
2019 };
2020 
2021 template<typename A>
2022 std::ostream &operator<<(std::ostream &s, const NegateOp<A> &op) {
2023  s << "-" << op.a;
2024  return s;
2025 }
2026 
2027 template<typename A>
2028 HALIDE_ALWAYS_INLINE auto operator-(A &&a) noexcept -> NegateOp<decltype(pattern_arg(a))> {
2029  assert_is_lvalue_if_expr<A>();
2030  return {pattern_arg(a)};
2031 }
2032 
2033 template<typename A>
2034 HALIDE_ALWAYS_INLINE auto negate(A &&a) -> decltype(IRMatcher::operator-(a)) {
2035  assert_is_lvalue_if_expr<A>();
2036  return IRMatcher::operator-(a);
2037 }
2038 
// Pattern node for a cast of `a` to a fixed type `t`.
// NOTE(review): source lines 2042, 2047-2048 and 2065 are missing from this
// listing. Line 2042 presumably declares the `t` member (the cast() factory
// aggregate-initializes this struct as {t, pattern_arg(a)} with a
// halide_type_t t) — confirm against IRMatch.h.
2039 template<typename A>
2040 struct CastOp {
2041  struct pattern_tag {};
2043  A a;
2044 
2045  constexpr static uint32_t binds = bindings<A>::mask;
2046 
2049  constexpr static bool canonical = A::canonical;
2050 
    // Matches a Cast node whose type is exactly `t` and whose value matches `a`.
2051  template<uint32_t bound>
2052  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2053  if (e.node_type != Cast::_node_type) {
2054  return false;
2055  }
2056  const Cast &op = (const Cast &)e;
2057  return (e.type == t &&
2058  a.template match<bound>(*op.value.get(), state));
2059  }
2060  template<uint32_t bound, typename A2>
2061  HALIDE_ALWAYS_INLINE bool match(const CastOp<A2> &op, MatcherState &state) const noexcept {
2062  return t == op.t && a.template match<bound>(unwrap(op.a), state);
2063  }
2064 
    // The operand is made with an empty type hint; type_hint is unused because
    // the result type is `t`.
2066  Expr make(MatcherState &state, halide_type_t type_hint) const {
2067  return cast(t, a.make(state, {}));
2068  }
2069 
2070  constexpr static bool foldable = false;
2071 };
2072 
2073 template<typename A>
2074 std::ostream &operator<<(std::ostream &s, const CastOp<A> &op) {
2075  s << "cast(" << op.t << ", " << op.a << ")";
2076  return s;
2077 }
2078 
2079 template<typename A>
2080 HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp<decltype(pattern_arg(a))> {
2081  assert_is_lvalue_if_expr<A>();
2082  return {t, pattern_arg(a)};
2083 }
2084 
2085 template<typename A>
2086 struct WidenOp {
2087  struct pattern_tag {};
2088  A a;
2089 
2090  constexpr static uint32_t binds = bindings<A>::mask;
2091 
2094  constexpr static bool canonical = A::canonical;
2095 
2096  template<uint32_t bound>
2097  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2098  if (e.node_type != Cast::_node_type) {
2099  return false;
2100  }
2101  const Cast &op = (const Cast &)e;
2102  return (e.type == op.value.type().widen() &&
2103  a.template match<bound>(*op.value.get(), state));
2104  }
2105  template<uint32_t bound, typename A2>
2106  HALIDE_ALWAYS_INLINE bool match(const WidenOp<A2> &op, MatcherState &state) const noexcept {
2107  return a.template match<bound>(unwrap(op.a), state);
2108  }
2109 
2111  Expr make(MatcherState &state, halide_type_t type_hint) const {
2112  Expr e = a.make(state, {});
2113  Type w = e.type().widen();
2114  return cast(w, std::move(e));
2115  }
2116 
2117  constexpr static bool foldable = false;
2118 };
2119 
2120 template<typename A>
2121 std::ostream &operator<<(std::ostream &s, const WidenOp<A> &op) {
2122  s << "widen(" << op.a << ")";
2123  return s;
2124 }
2125 
2126 template<typename A>
2127 HALIDE_ALWAYS_INLINE auto widen(A &&a) noexcept -> WidenOp<decltype(pattern_arg(a))> {
2128  assert_is_lvalue_if_expr<A>();
2129  return {pattern_arg(a)};
2130 }
2131 
2132 template<typename Vec, typename Base, typename Stride, typename Lanes>
2133 struct SliceOp {
2134  struct pattern_tag {};
2135  Vec vec;
2136  Base base;
2137  Stride stride;
2138  Lanes lanes;
2139 
2140  static constexpr uint32_t binds = Vec::binds | Base::binds | Stride::binds | Lanes::binds;
2141 
2144  constexpr static bool canonical = Vec::canonical && Base::canonical && Stride::canonical && Lanes::canonical;
2145 
2146  template<uint32_t bound>
2147  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2148  if (e.node_type != IRNodeType::Shuffle) {
2149  return false;
2150  }
2151  const Shuffle &v = (const Shuffle &)e;
2152  return v.vectors.size() == 1 &&
2153  v.is_slice() &&
2154  vec.template match<bound>(*v.vectors[0].get(), state) &&
2155  base.template match<bound | bindings<Vec>::mask>(v.slice_begin(), state) &&
2156  stride.template match<bound | bindings<Vec>::mask | bindings<Base>::mask>(v.slice_stride(), state) &&
2158  }
2159 
2161  Expr make(MatcherState &state, halide_type_t type_hint) const {
2162  halide_scalar_value_t base_val, stride_val, lanes_val;
2163  halide_type_t ty;
2164  base.make_folded_const(base_val, ty, state);
2165  int b = (int)base_val.u.i64;
2166  stride.make_folded_const(stride_val, ty, state);
2167  int s = (int)stride_val.u.i64;
2168  lanes.make_folded_const(lanes_val, ty, state);
2169  int l = (int)lanes_val.u.i64;
2170  return Shuffle::make_slice(vec.make(state, type_hint), b, s, l);
2171  }
2172 
2173  constexpr static bool foldable = false;
2174 
2176  SliceOp(Vec v, Base b, Stride s, Lanes l)
2177  : vec(v), base(b), stride(s), lanes(l) {
2178  static_assert(Base::foldable, "Base of slice should consist only of operations that constant-fold");
2179  static_assert(Stride::foldable, "Stride of slice should consist only of operations that constant-fold");
2180  static_assert(Lanes::foldable, "Lanes of slice should consist only of operations that constant-fold");
2181  }
2182 };
2183 
2184 template<typename Vec, typename Base, typename Stride, typename Lanes>
2185 std::ostream &operator<<(std::ostream &s, const SliceOp<Vec, Base, Stride, Lanes> &op) {
2186  s << "slice(" << op.vec << ", " << op.base << ", " << op.stride << ", " << op.lanes << ")";
2187  return s;
2188 }
2189 
2190 template<typename Vec, typename Base, typename Stride, typename Lanes>
2191 HALIDE_ALWAYS_INLINE auto slice(Vec vec, Base base, Stride stride, Lanes lanes) noexcept
2192  -> SliceOp<decltype(pattern_arg(vec)), decltype(pattern_arg(base)), decltype(pattern_arg(stride)), decltype(pattern_arg(lanes))> {
2193  return {pattern_arg(vec), pattern_arg(base), pattern_arg(stride), pattern_arg(lanes)};
2194 }
2195 
2196 template<typename A>
2197 struct Fold {
2198  struct pattern_tag {};
2199  A a;
2200 
2201  constexpr static uint32_t binds = bindings<A>::mask;
2202 
2205  constexpr static bool canonical = true;
2206 
2208  Expr make(MatcherState &state, halide_type_t type_hint) const noexcept {
2210  halide_type_t ty = type_hint;
2211  a.make_folded_const(c, ty, state);
2212 
2213  // The result of the fold may have an underspecified type
2214  // (e.g. because it's from an int literal). Make the type code
2215  // and bits match the required type, if there is one (we can
2216  // tell from the bits field).
2217  if (type_hint.bits) {
2218  if (((int)ty.code == (int)halide_type_int) &&
2219  ((int)type_hint.code == (int)halide_type_float)) {
2220  int64_t x = c.u.i64;
2221  c.u.f64 = (double)x;
2222  }
2223  ty.code = type_hint.code;
2224  ty.bits = type_hint.bits;
2225  }
2226 
2227  return make_const_expr(c, ty);
2228  }
2229 
2230  constexpr static bool foldable = A::foldable;
2231 
2232  template<typename A1 = A>
2234  a.make_folded_const(val, ty, state);
2235  }
2236 };
2237 
2238 template<typename A>
2239 HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold<decltype(pattern_arg(a))> {
2240  assert_is_lvalue_if_expr<A>();
2241  return {pattern_arg(a)};
2242 }
2243 
2244 template<typename A>
2245 std::ostream &operator<<(std::ostream &s, const Fold<A> &op) {
2246  s << "fold(" << op.a << ")";
2247  return s;
2248 }
2249 
2250 template<typename A>
2251 struct Overflows {
2252  struct pattern_tag {};
2253  A a;
2254 
2255  constexpr static uint32_t binds = bindings<A>::mask;
2256 
2257  // This rule is a predicate, so it always evaluates to a boolean,
2258  // which has IRNodeType UIntImm
2261  constexpr static bool canonical = true;
2262 
2263  constexpr static bool foldable = A::foldable;
2264 
2265  template<typename A1 = A>
2267  a.make_folded_const(val, ty, state);
2268  ty.code = halide_type_uint;
2269  ty.bits = 64;
2270  val.u.u64 = (ty.lanes & MatcherState::special_values_mask) != 0;
2271  ty.lanes = 1;
2272  }
2273 };
2274 
2275 template<typename A>
2276 HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows<decltype(pattern_arg(a))> {
2277  assert_is_lvalue_if_expr<A>();
2278  return {pattern_arg(a)};
2279 }
2280 
2281 template<typename A>
2282 std::ostream &operator<<(std::ostream &s, const Overflows<A> &op) {
2283  s << "overflows(" << op.a << ")";
2284  return s;
2285 }
2286 
2287 struct Overflow {
2288  struct pattern_tag {};
2289 
2290  constexpr static uint32_t binds = 0;
2291 
2292  // Overflow is an intrinsic, represented as a Call node
2295  constexpr static bool canonical = true;
2296 
2297  template<uint32_t bound>
2298  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2299  if (e.node_type != Call::_node_type) {
2300  return false;
2301  }
2302  const Call &op = (const Call &)e;
2304  }
2305 
2307  Expr make(MatcherState &state, halide_type_t type_hint) const {
2309  return make_const_special_expr(type_hint);
2310  }
2311 
2312  constexpr static bool foldable = true;
2313 
2315  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
2316  val.u.u64 = 0;
2318  }
2319 };
2320 
2321 inline std::ostream &operator<<(std::ostream &s, const Overflow &op) {
2322  s << "overflow()";
2323  return s;
2324 }
2325 
2326 template<typename A>
2327 struct IsConst {
2328  struct pattern_tag {};
2329 
2330  constexpr static uint32_t binds = bindings<A>::mask;
2331 
2332  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2335  constexpr static bool canonical = true;
2336 
2337  A a;
2338  bool check_v;
2340 
2341  constexpr static bool foldable = true;
2342 
2343  template<typename A1 = A>
2345  Expr e = a.make(state, {});
2346  ty.code = halide_type_uint;
2347  ty.bits = 64;
2348  ty.lanes = 1;
2349  if (check_v) {
2350  val.u.u64 = ::Halide::Internal::is_const(e, v) ? 1 : 0;
2351  } else {
2352  val.u.u64 = ::Halide::Internal::is_const(e) ? 1 : 0;
2353  }
2354  }
2355 };
2356 
2357 template<typename A>
2358 HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst<decltype(pattern_arg(a))> {
2359  assert_is_lvalue_if_expr<A>();
2360  return {pattern_arg(a), false, 0};
2361 }
2362 
2363 template<typename A>
2364 HALIDE_ALWAYS_INLINE auto is_const(A &&a, int64_t value) noexcept -> IsConst<decltype(pattern_arg(a))> {
2365  assert_is_lvalue_if_expr<A>();
2366  return {pattern_arg(a), true, value};
2367 }
2368 
2369 template<typename A>
2370 std::ostream &operator<<(std::ostream &s, const IsConst<A> &op) {
2371  if (op.check_v) {
2372  s << "is_const(" << op.a << ")";
2373  } else {
2374  s << "is_const(" << op.a << ", " << op.v << ")";
2375  }
2376  return s;
2377 }
2378 
2379 template<typename A, typename Prover>
2380 struct CanProve {
2381  struct pattern_tag {};
2382  A a;
2383  Prover *prover; // An existing simplifying mutator
2384 
2385  constexpr static uint32_t binds = bindings<A>::mask;
2386 
2387  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2390  constexpr static bool canonical = true;
2391 
2392  constexpr static bool foldable = true;
2393 
2394  // Includes a raw call to an inlined make method, so don't inline.
2396  Expr condition = a.make(state, {});
2397  condition = prover->mutate(condition, nullptr);
2398  val.u.u64 = is_const_one(condition);
2399  ty.code = halide_type_uint;
2400  ty.bits = 1;
2401  ty.lanes = condition.type().lanes();
2402  }
2403 };
2404 
2405 template<typename A, typename Prover>
2406 HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve<decltype(pattern_arg(a)), Prover> {
2407  assert_is_lvalue_if_expr<A>();
2408  return {pattern_arg(a), p};
2409 }
2410 
2411 template<typename A, typename Prover>
2412 std::ostream &operator<<(std::ostream &s, const CanProve<A, Prover> &op) {
2413  s << "can_prove(" << op.a << ")";
2414  return s;
2415 }
2416 
2417 template<typename A>
2418 struct IsFloat {
2419  struct pattern_tag {};
2420  A a;
2421 
2422  constexpr static uint32_t binds = bindings<A>::mask;
2423 
2424  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2427  constexpr static bool canonical = true;
2428 
2429  constexpr static bool foldable = true;
2430 
2433  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2434  Type t = a.make(state, {}).type();
2435  val.u.u64 = t.is_float();
2436  ty.code = halide_type_uint;
2437  ty.bits = 1;
2438  ty.lanes = t.lanes();
2439  }
2440 };
2441 
2442 template<typename A>
2443 HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat<decltype(pattern_arg(a))> {
2444  assert_is_lvalue_if_expr<A>();
2445  return {pattern_arg(a)};
2446 }
2447 
2448 template<typename A>
2449 std::ostream &operator<<(std::ostream &s, const IsFloat<A> &op) {
2450  s << "is_float(" << op.a << ")";
2451  return s;
2452 }
2453 
2454 template<typename A>
2455 struct IsInt {
2456  struct pattern_tag {};
2457  A a;
2460 
2461  constexpr static uint32_t binds = bindings<A>::mask;
2462 
2463  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2466  constexpr static bool canonical = true;
2467 
2468  constexpr static bool foldable = true;
2469 
2472  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2473  Type t = a.make(state, {}).type();
2474  val.u.u64 = t.is_int() && (bits == 0 || t.bits() == bits) && (lanes == 0 || t.lanes() == lanes);
2475  ty.code = halide_type_uint;
2476  ty.bits = 1;
2477  ty.lanes = t.lanes();
2478  }
2479 };
2480 
2481 template<typename A>
2482 HALIDE_ALWAYS_INLINE auto is_int(A &&a, uint8_t bits = 0, uint16_t lanes = 0) noexcept -> IsInt<decltype(pattern_arg(a))> {
2483  assert_is_lvalue_if_expr<A>();
2484  return {pattern_arg(a), bits, lanes};
2485 }
2486 
2487 template<typename A>
2488 std::ostream &operator<<(std::ostream &s, const IsInt<A> &op) {
2489  s << "is_int(" << op.a;
2490  if (op.bits > 0) {
2491  s << ", " << op.bits;
2492  }
2493  if (op.lanes > 0) {
2494  s << ", " << op.lanes;
2495  }
2496  s << ")";
2497  return s;
2498 }
2499 
2500 template<typename A>
2501 struct IsUInt {
2502  struct pattern_tag {};
2503  A a;
2506 
2507  constexpr static uint32_t binds = bindings<A>::mask;
2508 
2509  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2512  constexpr static bool canonical = true;
2513 
2514  constexpr static bool foldable = true;
2515 
2518  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2519  Type t = a.make(state, {}).type();
2520  val.u.u64 = t.is_uint() && (bits == 0 || t.bits() == bits) && (lanes == 0 || t.lanes() == lanes);
2521  ty.code = halide_type_uint;
2522  ty.bits = 1;
2523  ty.lanes = t.lanes();
2524  }
2525 };
2526 
2527 template<typename A>
2528 HALIDE_ALWAYS_INLINE auto is_uint(A &&a, uint8_t bits = 0, uint16_t lanes = 0) noexcept -> IsUInt<decltype(pattern_arg(a))> {
2529  assert_is_lvalue_if_expr<A>();
2530  return {pattern_arg(a), bits, lanes};
2531 }
2532 
2533 template<typename A>
2534 std::ostream &operator<<(std::ostream &s, const IsUInt<A> &op) {
2535  s << "is_uint(" << op.a;
2536  if (op.bits > 0) {
2537  s << ", " << op.bits;
2538  }
2539  if (op.lanes > 0) {
2540  s << ", " << op.lanes;
2541  }
2542  s << ")";
2543  return s;
2544 }
2545 
2546 template<typename A>
2547 struct IsScalar {
2548  struct pattern_tag {};
2549  A a;
2550 
2551  constexpr static uint32_t binds = bindings<A>::mask;
2552 
2553  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2556  constexpr static bool canonical = true;
2557 
2558  constexpr static bool foldable = true;
2559 
2562  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2563  Type t = a.make(state, {}).type();
2564  val.u.u64 = t.is_scalar();
2565  ty.code = halide_type_uint;
2566  ty.bits = 1;
2567  ty.lanes = t.lanes();
2568  }
2569 };
2570 
2571 template<typename A>
2572 HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar<decltype(pattern_arg(a))> {
2573  assert_is_lvalue_if_expr<A>();
2574  return {pattern_arg(a)};
2575 }
2576 
2577 template<typename A>
2578 std::ostream &operator<<(std::ostream &s, const IsScalar<A> &op) {
2579  s << "is_scalar(" << op.a << ")";
2580  return s;
2581 }
2582 
2583 template<typename A>
2584 struct IsMaxValue {
2585  struct pattern_tag {};
2586  A a;
2587 
2588  constexpr static uint32_t binds = bindings<A>::mask;
2589 
2590  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2593  constexpr static bool canonical = true;
2594 
2595  constexpr static bool foldable = true;
2596 
2599  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2600  a.make_folded_const(val, ty, state);
2601  const uint64_t max_bits = (uint64_t)(-1) >> (64 - ty.bits + (ty.code == halide_type_int));
2602  if (ty.code == halide_type_uint || ty.code == halide_type_int) {
2603  val.u.u64 = (val.u.u64 == max_bits);
2604  } else {
2605  val.u.u64 = 0;
2606  }
2607  ty.code = halide_type_uint;
2608  ty.bits = 1;
2609  }
2610 };
2611 
2612 template<typename A>
2613 HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue<decltype(pattern_arg(a))> {
2614  assert_is_lvalue_if_expr<A>();
2615  return {pattern_arg(a)};
2616 }
2617 
2618 template<typename A>
2619 std::ostream &operator<<(std::ostream &s, const IsMaxValue<A> &op) {
2620  s << "is_max_value(" << op.a << ")";
2621  return s;
2622 }
2623 
2624 template<typename A>
2625 struct IsMinValue {
2626  struct pattern_tag {};
2627  A a;
2628 
2629  constexpr static uint32_t binds = bindings<A>::mask;
2630 
2631  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2634  constexpr static bool canonical = true;
2635 
2636  constexpr static bool foldable = true;
2637 
2640  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2641  a.make_folded_const(val, ty, state);
2642  if (ty.code == halide_type_int) {
2643  const uint64_t min_bits = (uint64_t)(-1) << (ty.bits - 1);
2644  val.u.u64 = (val.u.u64 == min_bits);
2645  } else if (ty.code == halide_type_uint) {
2646  val.u.u64 = (val.u.u64 == 0);
2647  } else {
2648  val.u.u64 = 0;
2649  }
2650  ty.code = halide_type_uint;
2651  ty.bits = 1;
2652  }
2653 };
2654 
2655 template<typename A>
2656 HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue<decltype(pattern_arg(a))> {
2657  assert_is_lvalue_if_expr<A>();
2658  return {pattern_arg(a)};
2659 }
2660 
2661 template<typename A>
2662 std::ostream &operator<<(std::ostream &s, const IsMinValue<A> &op) {
2663  s << "is_min_value(" << op.a << ")";
2664  return s;
2665 }
2666 
2667 template<typename A>
2668 struct LanesOf {
2669  struct pattern_tag {};
2670  A a;
2671 
2672  constexpr static uint32_t binds = bindings<A>::mask;
2673 
    // This rule evaluates to the lane count of a as a scalar uint32 constant. Unsigned integer constants have type UIntImm.
2677  constexpr static bool canonical = true;
2678 
2679  constexpr static bool foldable = true;
2680 
2683  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2684  Type t = a.make(state, {}).type();
2685  val.u.u64 = t.lanes();
2686  ty.code = halide_type_uint;
2687  ty.bits = 32;
2688  ty.lanes = 1;
2689  }
2690 };
2691 
2692 template<typename A>
2693 HALIDE_ALWAYS_INLINE auto lanes_of(A &&a) noexcept -> LanesOf<decltype(pattern_arg(a))> {
2694  assert_is_lvalue_if_expr<A>();
2695  return {pattern_arg(a)};
2696 }
2697 
2698 template<typename A>
2699 std::ostream &operator<<(std::ostream &s, const LanesOf<A> &op) {
2700  s << "lanes_of(" << op.a << ")";
2701  return s;
2702 }
2703 
2704 // Verify properties of each rewrite rule. Currently just fuzz tests them.
2705 template<typename Before,
2706  typename After,
2707  typename Predicate,
2708  typename = typename std::enable_if<std::decay<Before>::type::foldable &&
2709  std::decay<After>::type::foldable>::type>
2710 HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred,
2711  halide_type_t wildcard_type, halide_type_t output_type) noexcept {
2712 
2713  // We only validate the rules in the scalar case
2714  wildcard_type.lanes = output_type.lanes = 1;
2715 
2716  // Track which types this rule has been tested for before
2717  static std::set<uint32_t> tested;
2718 
2719  if (!tested.insert(reinterpret_bits<uint32_t>(wildcard_type)).second) {
2720  return;
2721  }
2722 
2723  // Print it in a form where it can be piped into a python/z3 validator
2724  debug(0) << "validate('" << before << "', '" << after << "', '" << pred << "', " << Type(wildcard_type) << ", " << Type(output_type) << ")\n";
2725 
2726  // Substitute some random constants into the before and after
2727  // expressions and see if the rule holds true. This should catch
2728  // silly errors, but not necessarily corner cases.
2729  static std::mt19937_64 rng(0);
2730  MatcherState state;
2731 
2732  Expr exprs[max_wild];
2733 
2734  for (int trials = 0; trials < 100; trials++) {
2735  // We want to test small constants more frequently than
2736  // large ones, otherwise we'll just get coverage of
2737  // overflow rules.
2738  int shift = (int)(rng() & (wildcard_type.bits - 1));
2739 
2740  for (int i = 0; i < max_wild; i++) {
2741  // Bind all the exprs and constants
2742  switch (wildcard_type.code) {
2743  case halide_type_uint: {
2744  // Normalize to the type's range by adding zero
2745  uint64_t val = constant_fold_bin_op<Add>(wildcard_type, (uint64_t)rng() >> shift, 0);
2746  state.set_bound_const(i, val, wildcard_type);
2747  val = constant_fold_bin_op<Add>(wildcard_type, (uint64_t)rng() >> shift, 0);
2748  exprs[i] = make_const(wildcard_type, val);
2749  state.set_binding(i, *exprs[i].get());
2750  } break;
2751  case halide_type_int: {
2752  int64_t val = constant_fold_bin_op<Add>(wildcard_type, (int64_t)rng() >> shift, 0);
2753  state.set_bound_const(i, val, wildcard_type);
2754  val = constant_fold_bin_op<Add>(wildcard_type, (int64_t)rng() >> shift, 0);
2755  exprs[i] = make_const(wildcard_type, val);
2756  } break;
2757  case halide_type_float:
2758  case halide_type_bfloat: {
2759  // Use a very narrow range of precise floats, so
2760  // that none of the rules a human is likely to
2761  // write have instabilities.
2762  double val = ((int64_t)(rng() & 15) - 8) / 2.0;
2763  state.set_bound_const(i, val, wildcard_type);
2764  val = ((int64_t)(rng() & 15) - 8) / 2.0;
2765  exprs[i] = make_const(wildcard_type, val);
2766  } break;
2767  default:
2768  return; // Don't care about handles
2769  }
2770  state.set_binding(i, *exprs[i].get());
2771  }
2772 
2773  halide_scalar_value_t val_pred, val_before, val_after;
2774  halide_type_t type = output_type;
2775  if (!evaluate_predicate(pred, state)) {
2776  continue;
2777  }
2778  before.make_folded_const(val_before, type, state);
2779  uint16_t lanes = type.lanes;
2780  after.make_folded_const(val_after, type, state);
2781  lanes |= type.lanes;
2782 
2783  if (lanes & MatcherState::special_values_mask) {
2784  continue;
2785  }
2786 
2787  bool ok = true;
2788  switch (output_type.code) {
2789  case halide_type_uint:
2790  // Compare normalized representations
2791  ok &= (constant_fold_bin_op<Add>(output_type, val_before.u.u64, 0) ==
2792  constant_fold_bin_op<Add>(output_type, val_after.u.u64, 0));
2793  break;
2794  case halide_type_int:
2795  ok &= (constant_fold_bin_op<Add>(output_type, val_before.u.i64, 0) ==
2796  constant_fold_bin_op<Add>(output_type, val_after.u.i64, 0));
2797  break;
2798  case halide_type_float:
2799  case halide_type_bfloat: {
2800  double error = std::abs(val_before.u.f64 - val_after.u.f64);
2801  // We accept an equal bit pattern (e.g. inf vs inf),
2802  // a small floating point difference, or turning a nan into not-a-nan.
2803  ok &= (error < 0.01 ||
2804  val_before.u.u64 == val_after.u.u64 ||
2805  std::isnan(val_before.u.f64));
2806  break;
2807  }
2808  default:
2809  return;
2810  }
2811 
2812  if (!ok) {
2813  debug(0) << "Fails with values:\n";
2814  for (int i = 0; i < max_wild; i++) {
2816  state.get_bound_const(i, val, wildcard_type);
2817  debug(0) << " c" << i << ": " << make_const_expr(val, wildcard_type) << "\n";
2818  }
2819  for (int i = 0; i < max_wild; i++) {
2820  debug(0) << " _" << i << ": " << Expr(state.get_binding(i)) << "\n";
2821  }
2822  debug(0) << " Before: " << make_const_expr(val_before, output_type) << "\n";
2823  debug(0) << " After: " << make_const_expr(val_after, output_type) << "\n";
2824  debug(0) << val_before.u.u64 << " " << val_after.u.u64 << "\n";
2826  }
2827  }
2828 }
2829 
2830 template<typename Before,
2831  typename After,
2832  typename Predicate,
2833  typename = typename std::enable_if<!(std::decay<Before>::type::foldable &&
2834  std::decay<After>::type::foldable)>::type>
2835 HALIDE_ALWAYS_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred,
2836  halide_type_t, halide_type_t, int dummy = 0) noexcept {
2837  // We can't verify rewrite rules that can't be constant-folded.
2838 }
2839 
2841 bool evaluate_predicate(bool x, MatcherState &) noexcept {
2842  return x;
2843 }
2844 
2845 template<typename Pattern,
2846  typename = typename enable_if_pattern<Pattern>::type>
2849  halide_type_t ty = halide_type_of<bool>();
2850  p.make_folded_const(c, ty, state);
2851  // Overflow counts as a failed predicate
2852  return (c.u.u64 != 0) && ((ty.lanes & MatcherState::special_values_mask) == 0);
2853 }
2854 
// Compile-time switches for debugging and testing rewrite rules.
// Each one is tested with #if, so set it to 1 to turn it on.

// HALIDE_DEBUG_MATCHED_RULES: log every rule that matches and fires.
// HALIDE_DEBUG_UNMATCHED_RULES: log every rule that was tried but did not match.
#define HALIDE_DEBUG_MATCHED_RULES 0
#define HALIDE_DEBUG_UNMATCHED_RULES 0

// HALIDE_FUZZ_TEST_RULES: fuzz test every rewrite passed to
// Rewriter::operator(), checking that the input and the output have the
// same value for lots of random values of the wildcards. Run
// correctness_simplify with this on.
#define HALIDE_FUZZ_TEST_RULES 0
2866 
2867 template<typename Instance>
2868 struct Rewriter {
2869  Instance instance;
2873  bool validate;
2874 
2877  : instance(std::move(instance)), output_type(ot), wildcard_type(wt) {
2878  }
2879 
2880  template<typename After>
2882 #if HALIDE_DEBUG_MATCHED_RULES
2883  debug(0) << instance << " -> " << after << "\n";
2884 #endif
2885  result = after.make(state, output_type);
2886  }
2887 
2888  template<typename Before,
2889  typename After,
2890  typename = typename enable_if_pattern<Before>::type,
2891  typename = typename enable_if_pattern<After>::type>
2892  HALIDE_ALWAYS_INLINE bool operator()(Before before, After after) {
2893  static_assert((Before::binds & After::binds) == After::binds, "Rule result uses unbound values");
2894  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2895  static_assert(After::canonical, "RHS of rewrite rule should be in canonical form");
2896 #if HALIDE_FUZZ_TEST_RULES
2897  fuzz_test_rule(before, after, true, wildcard_type, output_type);
2898 #endif
2899  if (before.template match<0>(unwrap(instance), state)) {
2900  build_replacement(after);
2901 #if HALIDE_DEBUG_MATCHED_RULES
2902  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2903 #endif
2904  return true;
2905  } else {
2906 #if HALIDE_DEBUG_UNMATCHED_RULES
2907  debug(0) << instance << " does not match " << before << "\n";
2908 #endif
2909  return false;
2910  }
2911  }
2912 
2913  template<typename Before,
2914  typename = typename enable_if_pattern<Before>::type>
2915  HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after) noexcept {
2916  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2917  if (before.template match<0>(unwrap(instance), state)) {
2918  result = after;
2919 #if HALIDE_DEBUG_MATCHED_RULES
2920  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2921 #endif
2922  return true;
2923  } else {
2924 #if HALIDE_DEBUG_UNMATCHED_RULES
2925  debug(0) << instance << " does not match " << before << "\n";
2926 #endif
2927  return false;
2928  }
2929  }
2930 
2931  template<typename Before,
2932  typename = typename enable_if_pattern<Before>::type>
2933  HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after) noexcept {
2934  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2935 #if HALIDE_FUZZ_TEST_RULES
2936  fuzz_test_rule(before, IntLiteral(after), true, wildcard_type, output_type);
2937 #endif
2938  if (before.template match<0>(unwrap(instance), state)) {
2939  result = make_const(output_type, after);
2940 #if HALIDE_DEBUG_MATCHED_RULES
2941  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2942 #endif
2943  return true;
2944  } else {
2945 #if HALIDE_DEBUG_UNMATCHED_RULES
2946  debug(0) << instance << " does not match " << before << "\n";
2947 #endif
2948  return false;
2949  }
2950  }
2951 
2952  template<typename Before,
2953  typename After,
2954  typename Predicate,
2955  typename = typename enable_if_pattern<Before>::type,
2956  typename = typename enable_if_pattern<After>::type,
2957  typename = typename enable_if_pattern<Predicate>::type>
2958  HALIDE_ALWAYS_INLINE bool operator()(Before before, After after, Predicate pred) {
2959  static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
2960  static_assert((Before::binds & After::binds) == After::binds, "Rule result uses unbound values");
2961  static_assert((Before::binds & Predicate::binds) == Predicate::binds, "Rule predicate uses unbound values");
2962  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2963  static_assert(After::canonical, "RHS of rewrite rule should be in canonical form");
2964 
2965 #if HALIDE_FUZZ_TEST_RULES
2966  fuzz_test_rule(before, after, pred, wildcard_type, output_type);
2967 #endif
2968  if (before.template match<0>(unwrap(instance), state) &&
2969  evaluate_predicate(pred, state)) {
2970  build_replacement(after);
2971 #if HALIDE_DEBUG_MATCHED_RULES
2972  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
2973 #endif
2974  return true;
2975  } else {
2976 #if HALIDE_DEBUG_UNMATCHED_RULES
2977  debug(0) << instance << " does not match " << before << "\n";
2978 #endif
2979  return false;
2980  }
2981  }
2982 
2983  template<typename Before,
2984  typename Predicate,
2985  typename = typename enable_if_pattern<Before>::type,
2986  typename = typename enable_if_pattern<Predicate>::type>
2987  HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after, Predicate pred) {
2988  static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
2989  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2990 
2991  if (before.template match<0>(unwrap(instance), state) &&
2992  evaluate_predicate(pred, state)) {
2993  result = after;
2994 #if HALIDE_DEBUG_MATCHED_RULES
2995  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
2996 #endif
2997  return true;
2998  } else {
2999 #if HALIDE_DEBUG_UNMATCHED_RULES
3000  debug(0) << instance << " does not match " << before << "\n";
3001 #endif
3002  return false;
3003  }
3004  }
3005 
3006  template<typename Before,
3007  typename Predicate,
3008  typename = typename enable_if_pattern<Before>::type,
3009  typename = typename enable_if_pattern<Predicate>::type>
3010  HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after, Predicate pred) {
3011  static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
3012  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
3013 #if HALIDE_FUZZ_TEST_RULES
3014  fuzz_test_rule(before, IntLiteral(after), pred, wildcard_type, output_type);
3015 #endif
3016  if (before.template match<0>(unwrap(instance), state) &&
3017  evaluate_predicate(pred, state)) {
3018  result = make_const(output_type, after);
3019 #if HALIDE_DEBUG_MATCHED_RULES
3020  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
3021 #endif
3022  return true;
3023  } else {
3024 #if HALIDE_DEBUG_UNMATCHED_RULES
3025  debug(0) << instance << " does not match " << before << "\n";
3026 #endif
3027  return false;
3028  }
3029  }
3030 };
3031 
3032 /** Construct a rewriter for the given instance, which may be a pattern
 * with concrete expressions as leaves, or just an expression. The
 * optional wildcard_type argument is a hint as to what the type of the
 * wildcards is likely to be. If omitted, it defaults to the output type
 * (for a pattern instance) or to the type of the expression itself (for
 * an Expr instance). The wildcards are not required to be this type, but
 * the rule will only be tested for wildcards of that type when testing
 * is enabled.
3039  *
3040  * The rewriter can be used to check to see if the instance is one of
3041  * some number of patterns and if so rewrite it into another form,
3042  * using its operator() method. See Simplify.cpp for a bunch of
3043  * example usage.
3044  *
3045  * Important: Any Exprs in patterns are captured by reference, not by
3046  * value, so ensure they outlive the rewriter.
3047  */
3048 // @{
3049 template<typename Instance,
3050  typename = typename enable_if_pattern<Instance>::type>
3051 HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter<decltype(pattern_arg(instance))> {
3052  return {pattern_arg(instance), output_type, wildcard_type};
3053 }
3054 
3055 template<typename Instance,
3056  typename = typename enable_if_pattern<Instance>::type>
3057 HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type) noexcept -> Rewriter<decltype(pattern_arg(instance))> {
3058  return {pattern_arg(instance), output_type, output_type};
3059 }
3060 
3062 auto rewriter(const Expr &e, halide_type_t wildcard_type) noexcept -> Rewriter<decltype(pattern_arg(e))> {
3063  return {pattern_arg(e), e.type(), wildcard_type};
3064 }
3065 
3067 auto rewriter(const Expr &e) noexcept -> Rewriter<decltype(pattern_arg(e))> {
3068  return {pattern_arg(e), e.type(), e.type()};
3069 }
3070 // @}
3071 
3072 } // namespace IRMatcher
3073 
3074 } // namespace Internal
3075 } // namespace Halide
3076 
3077 #endif
#define internal_error
Definition: Errors.h:23
@ halide_type_float
IEEE floating point numbers.
@ halide_type_bfloat
floating point numbers in the bfloat format
@ halide_type_int
signed integers
@ halide_type_uint
unsigned integers
#define HALIDE_NEVER_INLINE
Definition: HalideRuntime.h:50
#define HALIDE_ALWAYS_INLINE
Definition: HalideRuntime.h:49
Subtypes for Halide expressions (Halide::Expr) and statements (Halide::Internal::Stmt)
Methods to test Exprs and Stmts for equality of value.
Defines various operator overloads and utility functions that make it more pleasant to work with Hali...
For optional debugging during codegen, use the debug class as follows:
Definition: Debug.h:49
std::ostream & operator<<(std::ostream &s, const SpecificExpr &e)
Definition: IRMatch.h:217
HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter< decltype(pattern_arg(instance))>
Construct a rewriter for the given instance, which may be a pattern with concrete expressions as leav...
Definition: IRMatch.h:3051
HALIDE_ALWAYS_INLINE T pattern_arg(T t)
Definition: IRMatch.h:567
auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin< Call::rounding_halving_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1568
HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b))
Definition: IRMatch.h:1259
HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp< decltype(pattern_arg(a))>
Definition: IRMatch.h:1657
auto shift_right(A &&a, B &&b) noexcept -> Intrin< Call::shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1576
HALIDE_ALWAYS_INLINE auto min(A &&a, B &&b) noexcept -> BinOp< Min, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1060
auto widening_add(A &&a, B &&b) noexcept -> Intrin< Call::widening_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1534
HALIDE_ALWAYS_INLINE auto is_int(A &&a, uint8_t bits=0, uint16_t lanes=0) noexcept -> IsInt< decltype(pattern_arg(a))>
Definition: IRMatch.h:2482
HALIDE_ALWAYS_INLINE bool evaluate_predicate(bool x, MatcherState &) noexcept
Definition: IRMatch.h:2841
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Div >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1016
auto abs(A &&a) noexcept -> Intrin< Call::abs, decltype(pattern_arg(a))>
Definition: IRMatch.h:1597
HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b))
Definition: IRMatch.h:1234
HALIDE_ALWAYS_INLINE auto is_uint(A &&a, uint8_t bits=0, uint16_t lanes=0) noexcept -> IsUInt< decltype(pattern_arg(a))>
Definition: IRMatch.h:2528
HALIDE_ALWAYS_INLINE auto negate(A &&a) -> decltype(IRMatcher::operator-(a))
Definition: IRMatch.h:2034
uint64_t constant_fold_cmp_op(int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp< LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1154
HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp< Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:905
HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue< decltype(pattern_arg(a))>
Definition: IRMatch.h:2613
HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b))
Definition: IRMatch.h:1285
HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And >
Definition: IRMatch.h:1946
HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b))
Definition: IRMatch.h:1134
HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst< decltype(pattern_arg(a))>
Definition: IRMatch.h:2358
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LE >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1164
HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp< Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:971
HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b))
Definition: IRMatch.h:912
auto widen_right_add(A &&a, B &&b) noexcept -> Intrin< Call::widen_right_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1521
HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b))
Definition: IRMatch.h:1011
auto widen_right_mul(A &&a, B &&b) noexcept -> Intrin< Call::widen_right_mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1525
HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b))
Definition: IRMatch.h:978
HALIDE_ALWAYS_INLINE auto max(A &&a, B &&b) noexcept -> BinOp< Max, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1082
auto absd(A &&a, B &&b) noexcept -> Intrin< Call::absd, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1602
HALIDE_ALWAYS_INLINE auto slice(Vec vec, Base base, Stride stride, Lanes lanes) noexcept -> SliceOp< decltype(pattern_arg(vec)), decltype(pattern_arg(base)), decltype(pattern_arg(stride)), decltype(pattern_arg(lanes))>
Definition: IRMatch.h:2191
HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition: IRMatch.h:1870
HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp< Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1004
HALIDE_ALWAYS_INLINE auto widen(A &&a) noexcept -> WidenOp< decltype(pattern_arg(a))>
Definition: IRMatch.h:2127
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mod >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1045
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< And >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1290
HALIDE_ALWAYS_INLINE int64_t unwrap(IntLiteral t)
Definition: IRMatch.h:559
auto widening_mul(A &&a, B &&b) noexcept -> Intrin< Call::widening_mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1542
HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp< GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1129
HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp< decltype(pattern_arg(a))>
Definition: IRMatch.h:2080
HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows< decltype(pattern_arg(a))>
Definition: IRMatch.h:2276
auto saturating_cast(const Type &t, A &&a) noexcept -> Intrin< Call::saturating_cast, decltype(pattern_arg(a))>
Definition: IRMatch.h:1554
HALIDE_ALWAYS_INLINE void assert_is_lvalue_if_expr()
Definition: IRMatch.h:576
HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp< Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1031
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Sub >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:952
auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin< Call::rounding_shift_left, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1580
HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar< decltype(pattern_arg(a))>
Definition: IRMatch.h:2572
HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold< decltype(pattern_arg(a))>
Definition: IRMatch.h:2239
HALIDE_ALWAYS_INLINE auto not_op(A &&a) -> decltype(IRMatcher::operator!(a))
Definition: IRMatch.h:1663
auto likely(A &&a) noexcept -> Intrin< Call::likely, decltype(pattern_arg(a))>
Definition: IRMatch.h:1607
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Max >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1089
constexpr bool and_reduce()
Definition: IRMatch.h:1314
HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp< Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1254
constexpr int max_wild
Definition: IRMatch.h:74
HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp< NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1229
auto halving_add(A &&a, B &&b) noexcept -> Intrin< Call::halving_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1560
auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< Call::mul_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition: IRMatch.h:1588
HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat< decltype(pattern_arg(a))>
Definition: IRMatch.h:2443
auto widening_sub(A &&a, B &&b) noexcept -> Intrin< Call::widening_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1538
HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp< GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1179
HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp< LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1104
HALIDE_ALWAYS_INLINE auto operator&&(A &&a, B &&b) noexcept -> BinOp< And, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1280
HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or >
Definition: IRMatch.h:1952
constexpr bool commutative(IRNodeType t)
Definition: IRMatch.h:615
HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b))
Definition: IRMatch.h:945
auto likely_if_innermost(A &&a) noexcept -> Intrin< Call::likely_if_innermost, decltype(pattern_arg(a))>
Definition: IRMatch.h:1612
HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max >
Definition: IRMatch.h:1940
HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes))>
Definition: IRMatch.h:1806
HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp< decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))>
Definition: IRMatch.h:1733
auto saturating_sub(A &&a, B &&b) noexcept -> Intrin< Call::saturating_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1550
HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue< decltype(pattern_arg(a))>
Definition: IRMatch.h:2656
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Min >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1067
HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred, halide_type_t wildcard_type, halide_type_t output_type) noexcept
Definition: IRMatch.h:2710
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GT >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1139
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mul >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:985
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GE >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1189
HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp< Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:938
auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< Call::rounding_mul_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition: IRMatch.h:1592
auto saturating_add(A &&a, B &&b) noexcept -> Intrin< Call::saturating_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1546
HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b))
Definition: IRMatch.h:1159
HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b))
Definition: IRMatch.h:1109
auto shift_left(A &&a, B &&b) noexcept -> Intrin< Call::shift_left, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1572
HALIDE_ALWAYS_INLINE auto is_const(A &&a, int64_t value) noexcept -> IsConst< decltype(pattern_arg(a))>
Definition: IRMatch.h:2364
HALIDE_ALWAYS_INLINE auto lanes_of(A &&a) noexcept -> LanesOf< decltype(pattern_arg(a))>
Definition: IRMatch.h:2693
auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin< Call::rounding_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1584
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LT >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1114
HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min >
Definition: IRMatch.h:1934
HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add >
Definition: IRMatch.h:1928
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Or >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1264
HALIDE_ALWAYS_INLINE Expr make_const_expr(halide_scalar_value_t val, halide_type_t ty)
Definition: IRMatch.h:160
constexpr uint32_t bitwise_or_reduce()
Definition: IRMatch.h:1305
int64_t constant_fold_bin_op(halide_type_t &, int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< EQ >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1214
HALIDE_NEVER_INLINE Expr make_const_special_expr(halide_type_t ty)
Definition: IRMatch.h:149
auto halving_sub(A &&a, B &&b) noexcept -> Intrin< Call::halving_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1564
HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b))
Definition: IRMatch.h:1184
constexpr int const_min(int a, int b)
Definition: IRMatch.h:1324
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< NE >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1239
HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b))
Definition: IRMatch.h:1038
HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp< EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1204
auto widen_right_sub(A &&a, B &&b) noexcept -> Intrin< Call::widen_right_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1529
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Add >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:919
HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve< decltype(pattern_arg(a)), Prover >
Definition: IRMatch.h:2406
HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b))
Definition: IRMatch.h:1209
T div_imp(T a, T b)
Definition: IROperator.h:273
bool is_const_zero(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to zero (in all lanes,...
Expr make_zero(Type t)
Construct the representation of zero in the given type.
void expr_match_test()
bool is_const_one(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to one (in all lanes,...
bool equal(const RDom &bounds0, const RDom &bounds1)
Return true if bounds0 and bounds1 represent the same bounds.
constexpr IRNodeType StrongestExprNodeType
Definition: Expr.h:81
Expr make_const(Type t, int64_t val)
Construct an immediate of the given type from any numeric C++ type.
T mod_imp(T a, T b)
Implementations of division and mod that are specific to Halide.
Definition: IROperator.h:252
bool sub_would_overflow(int bits, int64_t a, int64_t b)
bool add_would_overflow(int bits, int64_t a, int64_t b)
Routines to test if math would overflow for signed integers with the given number of bits.
bool mul_would_overflow(int bits, int64_t a, int64_t b)
Expr with_lanes(const Expr &x, int lanes)
Rewrite the expression x to have `lanes` lanes.
bool expr_match(const Expr &pattern, const Expr &expr, std::vector< Expr > &result)
Does the first expression have the same structure as the second? Variables in the first expression wi...
Expr make_signed_integer_overflow(Type type)
Construct a unique signed_integer_overflow Expr.
IRNodeType
All our IR node types get unique IDs for the purposes of RTTI.
Definition: Expr.h:25
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
@ Internal
Not visible externally, similar to 'static' linkage in C.
@ Predicate
Guard the loads and stores in the loop with an if statement that prevents evaluation beyond the origi...
@ C
No name mangling.
unsigned __INT64_TYPE__ uint64_t
signed __INT64_TYPE__ int64_t
signed __INT32_TYPE__ int32_t
unsigned __INT8_TYPE__ uint8_t
unsigned __INT16_TYPE__ uint16_t
unsigned __INT32_TYPE__ uint32_t
A fragment of Halide syntax.
Definition: Expr.h:258
HALIDE_ALWAYS_INLINE Type type() const
Get the type of this expression node.
Definition: Expr.h:327
HALIDE_ALWAYS_INLINE const Internal::BaseExprNode * get() const
Override get() to return a BaseExprNode * instead of an IRNode *.
Definition: Expr.h:321
The sum of two expressions.
Definition: IR.h:56
Logical and - are both expressions true.
Definition: IR.h:175
A base class for expression nodes.
Definition: Expr.h:143
A vector with 'lanes' elements, in which every element is 'value'.
Definition: IR.h:259
static Expr make(Expr value, int lanes)
static const IRNodeType _node_type
Definition: IR.h:265
A function call.
Definition: IR.h:490
@ signed_integer_overflow
Definition: IR.h:595
@ rounding_mul_shift_right
Definition: IR.h:585
bool is_intrinsic() const
Definition: IR.h:721
static const IRNodeType _node_type
Definition: IR.h:766
The actual IR nodes begin here.
Definition: IR.h:30
static const IRNodeType _node_type
Definition: IR.h:35
The ratio of two expressions.
Definition: IR.h:83
Is the first expression equal to the second.
Definition: IR.h:121
Floating point constants.
Definition: Expr.h:236
static const FloatImm * make(Type t, double value)
Is the first expression greater than or equal to the second.
Definition: IR.h:166
Is the first expression greater than the second.
Definition: IR.h:157
constexpr static uint32_t binds
Definition: IRMatch.h:633
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:636
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:664
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:645
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:635
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
Definition: IRMatch.h:707
HALIDE_ALWAYS_INLINE bool match(const BinOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:655
constexpr static bool canonical
Definition: IRMatch.h:641
constexpr static bool foldable
Definition: IRMatch.h:661
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1772
HALIDE_ALWAYS_INLINE bool match(const BroadcastOp< A2, B2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:1766
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1754
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1748
constexpr static uint32_t binds
Definition: IRMatch.h:1746
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:1789
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1749
HALIDE_NEVER_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2395
constexpr static bool foldable
Definition: IRMatch.h:2392
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2388
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2389
constexpr static uint32_t binds
Definition: IRMatch.h:2385
constexpr static bool canonical
Definition: IRMatch.h:2390
constexpr static bool canonical
Definition: IRMatch.h:2049
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2048
constexpr static bool foldable
Definition: IRMatch.h:2070
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:2052
constexpr static uint32_t binds
Definition: IRMatch.h:2045
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2047
HALIDE_ALWAYS_INLINE bool match(const CastOp< A2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:2061
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:2066
constexpr static bool canonical
Definition: IRMatch.h:740
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:800
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:738
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:739
constexpr static bool foldable
Definition: IRMatch.h:763
constexpr static uint32_t binds
Definition: IRMatch.h:736
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:747
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:766
HALIDE_ALWAYS_INLINE bool match(const CmpOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:757
constexpr static bool foldable
Definition: IRMatch.h:2230
constexpr static uint32_t binds
Definition: IRMatch.h:2201
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2203
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2204
constexpr static bool canonical
Definition: IRMatch.h:2205
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
Definition: IRMatch.h:2208
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:2233
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:504
constexpr static bool canonical
Definition: IRMatch.h:496
constexpr static uint32_t binds
Definition: IRMatch.h:492
HALIDE_ALWAYS_INLINE IntLiteral(int64_t v)
Definition: IRMatch.h:499
HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept
Definition: IRMatch.h:527
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:539
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:494
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:495
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:532
HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept
Definition: IRMatch.h:522
constexpr static bool foldable
Definition: IRMatch.h:536
HALIDE_ALWAYS_INLINE Intrin(Args... args) noexcept
Definition: IRMatch.h:1507
HALIDE_ALWAYS_INLINE void print_args(std::ostream &s) const
Definition: IRMatch.h:1399
constexpr static bool canonical
Definition: IRMatch.h:1357
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1374
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:1464
std::tuple< Args... > args
Definition: IRMatch.h:1346
static constexpr uint32_t binds
Definition: IRMatch.h:1353
HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const
Definition: IRMatch.h:1395
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1355
HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept
Definition: IRMatch.h:1362
HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept
Definition: IRMatch.h:1369
HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const
Definition: IRMatch.h:1386
constexpr static bool foldable
Definition: IRMatch.h:1462
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1356
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1404
OptionalIntrinType< intrin > optional_type_hint
Definition: IRMatch.h:1351
constexpr static bool canonical
Definition: IRMatch.h:2335
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:2344
constexpr static bool foldable
Definition: IRMatch.h:2341
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2334
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2333
constexpr static uint32_t binds
Definition: IRMatch.h:2330
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2425
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2432
constexpr static bool canonical
Definition: IRMatch.h:2427
constexpr static uint32_t binds
Definition: IRMatch.h:2422
constexpr static bool foldable
Definition: IRMatch.h:2429
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2426
constexpr static uint32_t binds
Definition: IRMatch.h:2461
constexpr static bool foldable
Definition: IRMatch.h:2468
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2464
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2471
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2465
constexpr static bool canonical
Definition: IRMatch.h:2466
constexpr static bool canonical
Definition: IRMatch.h:2593
constexpr static bool foldable
Definition: IRMatch.h:2595
constexpr static uint32_t binds
Definition: IRMatch.h:2588
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2591
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2592
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2598
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2632
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2633
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2639
constexpr static bool canonical
Definition: IRMatch.h:2634
constexpr static uint32_t binds
Definition: IRMatch.h:2629
constexpr static bool foldable
Definition: IRMatch.h:2636
constexpr static bool foldable
Definition: IRMatch.h:2558
constexpr static bool canonical
Definition: IRMatch.h:2556
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2561
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2555
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2554
constexpr static uint32_t binds
Definition: IRMatch.h:2551
constexpr static bool canonical
Definition: IRMatch.h:2512
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2517
constexpr static uint32_t binds
Definition: IRMatch.h:2507
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2510
constexpr static bool foldable
Definition: IRMatch.h:2514
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2511
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2682
constexpr static bool foldable
Definition: IRMatch.h:2679
constexpr static bool canonical
Definition: IRMatch.h:2677
constexpr static uint32_t binds
Definition: IRMatch.h:2672
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2675
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2676
To save stack space, the matcher objects are largely stateless and immutable.
Definition: IRMatch.h:82
HALIDE_ALWAYS_INLINE void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept
Definition: IRMatch.h:127
HALIDE_ALWAYS_INLINE void set_bound_const(int i, int64_t s, halide_type_t t) noexcept
Definition: IRMatch.h:103
HALIDE_ALWAYS_INLINE void set_bound_const(int i, double f, halide_type_t t) noexcept
Definition: IRMatch.h:115
static constexpr uint16_t special_values_mask
Definition: IRMatch.h:88
HALIDE_ALWAYS_INLINE void set_bound_const(int i, halide_scalar_value_t val, halide_type_t t) noexcept
Definition: IRMatch.h:121
halide_type_t bound_const_type[max_wild]
Definition: IRMatch.h:90
HALIDE_ALWAYS_INLINE void set_binding(int i, const BaseExprNode &n) noexcept
Definition: IRMatch.h:93
HALIDE_ALWAYS_INLINE MatcherState() noexcept
Definition: IRMatch.h:134
halide_scalar_value_t bound_const[max_wild]
Definition: IRMatch.h:84
HALIDE_ALWAYS_INLINE const BaseExprNode * get_binding(int i) const noexcept
Definition: IRMatch.h:98
HALIDE_ALWAYS_INLINE void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept
Definition: IRMatch.h:109
static constexpr uint16_t signed_integer_overflow
Definition: IRMatch.h:87
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1964
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1965
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1970
constexpr static uint32_t binds
Definition: IRMatch.h:1962
constexpr static bool canonical
Definition: IRMatch.h:1967
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1985
HALIDE_ALWAYS_INLINE bool match(NegateOp< A2 > &&p, MatcherState &state) const noexcept
Definition: IRMatch.h:1980
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:1994
constexpr static bool foldable
Definition: IRMatch.h:1991
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1623
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1628
HALIDE_ALWAYS_INLINE bool match(const NotOp< A2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:1637
constexpr static uint32_t binds
Definition: IRMatch.h:1621
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1624
constexpr static bool foldable
Definition: IRMatch.h:1646
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1642
constexpr static bool canonical
Definition: IRMatch.h:1625
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:1649
constexpr static bool canonical
Definition: IRMatch.h:2295
constexpr static bool foldable
Definition: IRMatch.h:2312
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:2298
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:2307
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2294
constexpr static uint32_t binds
Definition: IRMatch.h:2290
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:2315
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2293
constexpr static bool foldable
Definition: IRMatch.h:2263
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:2266
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2260
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2259
constexpr static uint32_t binds
Definition: IRMatch.h:2255
constexpr static bool canonical
Definition: IRMatch.h:2261
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1848
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1821
constexpr static bool canonical
Definition: IRMatch.h:1823
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1820
constexpr static bool foldable
Definition: IRMatch.h:1860
HALIDE_ALWAYS_INLINE bool match(const RampOp< A2, B2, C2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:1841
constexpr static uint32_t binds
Definition: IRMatch.h:1818
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1826
HALIDE_NEVER_INLINE void build_replacement(After after)
Definition: IRMatch.h:2881
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after, Predicate pred)
Definition: IRMatch.h:2958
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after) noexcept
Definition: IRMatch.h:2933
HALIDE_ALWAYS_INLINE Rewriter(Instance instance, halide_type_t ot, halide_type_t wt)
Definition: IRMatch.h:2876
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after, Predicate pred)
Definition: IRMatch.h:2987
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after) noexcept
Definition: IRMatch.h:2915
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after, Predicate pred)
Definition: IRMatch.h:3010
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after)
Definition: IRMatch.h:2892
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1684
constexpr static bool canonical
Definition: IRMatch.h:1686
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:1713
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1683
constexpr static uint32_t binds
Definition: IRMatch.h:1681
constexpr static bool foldable
Definition: IRMatch.h:1710
HALIDE_ALWAYS_INLINE bool match(const SelectOp< C2, T2, F2 > &instance, MatcherState &state) const noexcept
Definition: IRMatch.h:1699
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1689
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1706
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2143
constexpr static bool foldable
Definition: IRMatch.h:2173
HALIDE_ALWAYS_INLINE SliceOp(Vec v, Base b, Stride s, Lanes l)
Definition: IRMatch.h:2176
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:2147
static constexpr uint32_t binds
Definition: IRMatch.h:2140
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2142
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:2161
constexpr static bool canonical
Definition: IRMatch.h:2144
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:198
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:205
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:210
constexpr static uint32_t binds
Definition: IRMatch.h:195
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:199
HALIDE_ALWAYS_INLINE bool match(const VectorReduceOp< A2, B2, reduce_op_2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:1903
constexpr static uint32_t binds
Definition: IRMatch.h:1883
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1886
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1890
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1885
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1910
constexpr static bool foldable
Definition: IRMatch.h:2117
constexpr static bool canonical
Definition: IRMatch.h:2094
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:2111
HALIDE_ALWAYS_INLINE bool match(const WidenOp< A2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:2106
constexpr static uint32_t binds
Definition: IRMatch.h:2090
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:2097
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2092
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2093
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:352
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:348
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:373
constexpr static uint32_t binds
Definition: IRMatch.h:345
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:347
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:383
constexpr static uint32_t binds
Definition: IRMatch.h:399
constexpr static bool foldable
Definition: IRMatch.h:438
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:431
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:402
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:406
constexpr static bool canonical
Definition: IRMatch.h:403
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:441
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:401
HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept
Definition: IRMatch.h:425
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:228
constexpr static uint32_t binds
Definition: IRMatch.h:226
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:267
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:277
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:233
HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept
Definition: IRMatch.h:254
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:229
constexpr static uint32_t binds
Definition: IRMatch.h:292
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:295
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:294
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:299
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:330
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:320
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:460
constexpr static bool foldable
Definition: IRMatch.h:477
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:473
constexpr static bool canonical
Definition: IRMatch.h:461
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:459
constexpr static uint32_t binds
Definition: IRMatch.h:457
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:464
constexpr static uint32_t mask
Definition: IRMatch.h:146
IRNodeType node_type
Each IR node subclass has a unique identifier.
Definition: Expr.h:113
Integer constants.
Definition: Expr.h:218
static const IntImm * make(Type t, int64_t value)
Is the first expression less than or equal to the second.
Definition: IR.h:148
Is the first expression less than the second.
Definition: IR.h:139
The greater of two values.
Definition: IR.h:112
The lesser of two values.
Definition: IR.h:103
The remainder of a / b.
Definition: IR.h:94
The product of two expressions.
Definition: IR.h:74
Is the first expression not equal to the second.
Definition: IR.h:130
Logical not - true if the expression is false.
Definition: IR.h:193
static Expr make(Expr a)
Logical or - is at least one of the expressions true.
Definition: IR.h:184
A linear ramp vector node.
Definition: IR.h:247
static const IRNodeType _node_type
Definition: IR.h:253
static Expr make(Expr base, Expr stride, int lanes)
A ternary operator.
Definition: IR.h:204
static Expr make(Expr condition, Expr true_value, Expr false_value)
static const IRNodeType _node_type
Definition: IR.h:209
Construct a new vector by taking elements from another sequence of vectors.
Definition: IR.h:855
static Expr make_slice(Expr vector, int begin, int stride, int size)
Convenience constructor for making a shuffle representing a contiguous subset of a vector.
std::vector< Expr > vectors
Definition: IR.h:856
bool is_slice() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so,...
int slice_stride() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so,...
Definition: IR.h:909
int slice_begin() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so,...
Definition: IR.h:906
The difference of two expressions.
Definition: IR.h:65
static const IRNodeType _node_type
Definition: IR.h:70
static Expr make(Expr a, Expr b)
Unsigned integer constants.
Definition: Expr.h:227
static const UIntImm * make(Type t, uint64_t value)
Horizontally reduce a vector to a scalar or narrower vector using the given commutative and associati...
Definition: IR.h:979
static const IRNodeType _node_type
Definition: IR.h:998
static Expr make(Operator op, Expr vec, int lanes)
Types in the halide type system.
Definition: Type.h:283
Type widen() const
Return Type with the same type code and number of lanes, but with at least twice as many bits.
Definition: Type.h:378
HALIDE_ALWAYS_INLINE bool is_int() const
Is this type a signed integer type?
Definition: Type.h:435
HALIDE_ALWAYS_INLINE int lanes() const
Return the number of vector elements in this type.
Definition: Type.h:355
HALIDE_ALWAYS_INLINE bool is_uint() const
Is this type an unsigned integer type?
Definition: Type.h:441
HALIDE_ALWAYS_INLINE int bits() const
Return the bit size of a single element of this type.
Definition: Type.h:349
HALIDE_ALWAYS_INLINE bool is_scalar() const
Is this type a scalar type? (lanes() == 1).
Definition: Type.h:417
HALIDE_ALWAYS_INLINE bool is_float() const
Is this type a floating point type (float or double).
Definition: Type.h:423
halide_scalar_value_t is a simple union able to represent all the well-known scalar values in a filte...
union halide_scalar_value_t::@3 u
A runtime tag for a type in the halide type system.
uint8_t bits
The number of bits of precision of a single scalar value of this type.
uint16_t lanes
How many elements in a vector.
uint8_t code
The basic type code: signed integer, unsigned integer, or floating point.