Halide  20.0.0
Halide compiler and libraries
Simplify_Internal.h
Go to the documentation of this file.
1 #ifndef HALIDE_SIMPLIFY_VISITORS_H
2 #define HALIDE_SIMPLIFY_VISITORS_H
3 
4 /** \file
5  * The simplifier is separated into multiple compilation units with
6  * this single shared header to speed up the build. This file is not
7  * exported in Halide.h. */
8 
9 #include "Bounds.h"
10 #include "ConstantInterval.h"
11 #include "IRMatch.h"
12 #include "IRPrinter.h"
13 #include "IRVisitor.h"
14 #include "Scope.h"
15 
16 // Because this file is only included by the simplify methods and
17 // doesn't go into Halide.h, we're free to use any old names for our
18 // macros.
19 
20 #define LOG_EXPR_MUTATIONS 0
21 #define LOG_STMT_MUTATIONS 0
22 
23 // On old compilers, some visitors would use large stack frames,
24 // because they use expression templates that generate large numbers
25 // of temporary objects when they are built and matched against. If we
26 // wrap the expressions that imply lots of temporaries in a lambda, we
27 // can get these large frames out of the recursive path.
28 #define EVAL_IN_LAMBDA(x) (([&]() HALIDE_NEVER_INLINE { return (x); })())
29 
30 namespace Halide {
31 namespace Internal {
32 
33 class Simplify : public VariadicVisitor<Simplify, Expr, Stmt> {
35 
36 public:
37  Simplify(bool r, const Scope<Interval> *bi, const Scope<ModulusRemainder> *ai);
38 
39  struct ExprInfo {
40  // We track constant integer bounds when they exist
42  // And the alignment of integer variables
44 
46  if (alignment.modulus == 0) {
48  } else if (alignment.modulus > 1) {
49  if (bounds.min_defined) {
50  int64_t adjustment;
52  adjustment = mod_imp(adjustment, alignment.modulus);
53  int64_t new_min;
54  no_overflow &= add_with_overflow(64, bounds.min, adjustment, &new_min);
55  if (no_overflow) {
56  bounds.min = new_min;
57  }
58  }
59  if (bounds.max_defined) {
60  int64_t adjustment;
62  adjustment = mod_imp(adjustment, alignment.modulus);
63  int64_t new_max;
64  no_overflow &= sub_with_overflow(64, bounds.max, adjustment, &new_max);
65  if (no_overflow) {
66  bounds.max = new_max;
67  }
68  }
69  }
70 
71  if (bounds.is_single_point()) {
72  alignment.modulus = 0;
74  }
75 
76  if (bounds.is_bounded() && bounds.min > bounds.max) {
77  // Impossible, we must be in unreachable code. TODO: surface
78  // this to the simplify instance's in_unreachable flag.
79  bounds.max = bounds.min;
80  }
81  }
82 
83  void cast_to(Type t) {
84  if ((!t.is_int() && !t.is_uint()) || (t.is_int() && t.bits() >= 32)) {
85  return;
86  }
87 
88  // We've just done some infinite-integer operation on a bounded
89  // integer type, and we need to project the bounds and alignment
90  // back in-range.
91 
92  if (!t.can_represent(bounds)) {
93  if (t.bits() >= 64) {
94  // Just preserve any power-of-two factor in the modulus. When
95  // alignment.modulus == 0, the value is some positive constant
96  // representable as any 64-bit integer type, so there's no
97  // wraparound.
98  if (alignment.modulus > 0) {
99  // This masks off all bits except for the lowest set one,
100  // giving the largest power-of-two factor of a number.
103  }
104  } else {
105  // A narrowing integer cast that could possibly overflow adds
106  // some unknown multiple of 2^bits
107  alignment = alignment + ModulusRemainder(((int64_t)1 << t.bits()), 0);
108  }
109  }
110 
111  // Truncate the bounds to the new type.
112  bounds.cast_to(t);
113  }
114 
115  // Mix in existing knowledge about this Expr
116  void intersect(const ExprInfo &other) {
117  if (bounds < other.bounds || other.bounds < bounds) {
118  // Impossible. We must be in unreachable code. TODO: It might
119  // be nice to surface this to the simplify instance's
120  // in_unreachable flag, but we'd have to be sure that it's going
121  // to be caught at the right place.
122  return;
123  }
127  }
128  };
129 
132  if (b) {
133  *b = ExprInfo{};
134  }
135  }
136 
137 #if (LOG_EXPR_MUTATIONS || LOG_STMT_MUTATIONS)
138  int debug_indent = 0;
139 #endif
140 
141 #if LOG_EXPR_MUTATIONS
142  Expr mutate(const Expr &e, ExprInfo *b) {
143  internal_assert(debug_indent >= 0);
144  const std::string spaces(debug_indent, ' ');
145  debug(1) << spaces << "Simplifying Expr: " << e << "\n";
146  debug_indent++;
147  Expr new_e = Super::dispatch(e, b);
148  debug_indent--;
149  if (!new_e.same_as(e)) {
150  debug(1)
151  << spaces << "Before: " << e << "\n"
152  << spaces << "After: " << new_e << "\n";
153  if (b) {
154  debug(1)
155  << spaces << "Bounds: " << b->bounds << " " << b->alignment << "\n";
156  if (auto i = as_const_int(new_e)) {
157  internal_assert(b->bounds.contains(*i)) << e << "\n"
158  << new_e << "\n"
159  << b->bounds;
160  } else if (auto i = as_const_uint(new_e)) {
161  internal_assert(b->bounds.contains(*i)) << e << "\n"
162  << new_e << "\n"
163  << b->bounds;
164  }
165  }
166  }
167  internal_assert(e.type() == new_e.type());
168  return new_e;
169  }
170 
171 #else
173  Expr mutate(const Expr &e, ExprInfo *b) {
174  // This gets inlined into every call to mutate, so do not add any code here.
175  return Super::dispatch(e, b);
176  }
177 #endif
178 
179 #if LOG_STMT_MUTATIONS
180  Stmt mutate(const Stmt &s) {
181  const std::string spaces(debug_indent, ' ');
182  debug(1) << spaces << "Simplifying Stmt: " << s << "\n";
183  debug_indent++;
184  Stmt new_s = Super::dispatch(s);
185  debug_indent--;
186  if (!new_s.same_as(s)) {
187  debug(1)
188  << spaces << "Before: " << s << "\n"
189  << spaces << "After: " << new_s << "\n";
190  }
191  return new_s;
192  }
193 #else
194  Stmt mutate(const Stmt &s) {
195  return Super::dispatch(s);
196  }
197 #endif
198 
200  bool no_float_simplify = false;
201 
203  bool may_simplify(const Type &t) const {
204  return !no_float_simplify || !t.is_float();
205  }
206 
207  // Returns true iff t is an integral type where overflow is undefined
210  return t.is_int() && t.bits() >= 32;
211  }
212 
215  return t.is_scalar() && no_overflow_int(t);
216  }
217 
218  // Returns true iff t does not have a well defined overflow behavior.
220  bool no_overflow(Type t) {
221  return t.is_float() || no_overflow_int(t);
222  }
223 
224  struct VarInfo {
227  };
228 
229  // Tracked for all let vars
231 
232  // Only tracked for integer let vars
234 
235  // Symbols used by rewrite rules
248 
249  // Tracks whether or not we're inside a vector loop. Certain
250  // transformations are not a good idea if the code is to be
251  // vectorized.
252  bool in_vector_loop = false;
253 
254  // Tracks whether or not the current IR is unconditionally unreachable.
255  bool in_unreachable = false;
256 
257  // If we encounter a reference to a buffer (a Load, Store, Call,
258  // or Provide), there's an implicit dependence on some associated
259  // symbols.
260  void found_buffer_reference(const std::string &name, size_t dimensions = 0);
261 
262  // Put the args to a commutative op in a canonical order
264  bool should_commute(const Expr &a, const Expr &b) {
265  if (a.node_type() < b.node_type()) {
266  return true;
267  }
268  if (a.node_type() > b.node_type()) {
269  return false;
270  }
271 
272  if (a.node_type() == IRNodeType::Variable) {
273  const Variable *va = a.as<Variable>();
274  const Variable *vb = b.as<Variable>();
275  return va->name.compare(vb->name) > 0;
276  }
277 
278  return false;
279  }
280 
281  std::set<Expr, IRDeepCompare> truths, falsehoods;
282 
283  struct ScopedFact {
285 
286  std::vector<const Variable *> pop_list;
287  std::vector<const Variable *> bounds_pop_list;
288  std::set<Expr, IRDeepCompare> truths, falsehoods;
289 
290  void learn_false(const Expr &fact);
291  void learn_true(const Expr &fact);
292  void learn_upper_bound(const Variable *v, int64_t val);
293  void learn_lower_bound(const Variable *v, int64_t val);
294 
295  // Replace exprs known to be truths or falsehoods with const_true or const_false.
298 
300  : simplify(s) {
301  }
303 
304  // allow move but not copy
305  ScopedFact(const ScopedFact &that) = delete;
306  ScopedFact(ScopedFact &&that) = default;
307  };
308 
309  // Tell the simplifier to learn from and exploit a boolean
310  // condition, over the lifetime of the returned object.
311  ScopedFact scoped_truth(const Expr &fact) {
312  ScopedFact f(this);
313  f.learn_true(fact);
314  return f;
315  }
316 
317  // Tell the simplifier to assume a boolean condition is false over
318  // the lifetime of the returned object.
320  ScopedFact f(this);
321  f.learn_false(fact);
322  return f;
323  }
324 
326  return mutate(s);
327  }
328  Expr mutate_let_body(const Expr &e, ExprInfo *info) {
329  return mutate(e, info);
330  }
331 
332  template<typename T, typename Body>
333  Body simplify_let(const T *op, ExprInfo *info);
334 
335  Expr visit(const IntImm *op, ExprInfo *info);
336  Expr visit(const UIntImm *op, ExprInfo *info);
337  Expr visit(const FloatImm *op, ExprInfo *info);
338  Expr visit(const StringImm *op, ExprInfo *info);
339  Expr visit(const Broadcast *op, ExprInfo *info);
340  Expr visit(const Cast *op, ExprInfo *info);
341  Expr visit(const Reinterpret *op, ExprInfo *info);
342  Expr visit(const Variable *op, ExprInfo *info);
343  Expr visit(const Add *op, ExprInfo *info);
344  Expr visit(const Sub *op, ExprInfo *info);
345  Expr visit(const Mul *op, ExprInfo *info);
346  Expr visit(const Div *op, ExprInfo *info);
347  Expr visit(const Mod *op, ExprInfo *info);
348  Expr visit(const Min *op, ExprInfo *info);
349  Expr visit(const Max *op, ExprInfo *info);
350  Expr visit(const EQ *op, ExprInfo *info);
351  Expr visit(const NE *op, ExprInfo *info);
352  Expr visit(const LT *op, ExprInfo *info);
353  Expr visit(const LE *op, ExprInfo *info);
354  Expr visit(const GT *op, ExprInfo *info);
355  Expr visit(const GE *op, ExprInfo *info);
356  Expr visit(const And *op, ExprInfo *info);
357  Expr visit(const Or *op, ExprInfo *info);
358  Expr visit(const Not *op, ExprInfo *info);
359  Expr visit(const Select *op, ExprInfo *info);
360  Expr visit(const Ramp *op, ExprInfo *info);
361  Stmt visit(const IfThenElse *op);
362  Expr visit(const Load *op, ExprInfo *info);
363  Expr visit(const Call *op, ExprInfo *info);
364  Expr visit(const Shuffle *op, ExprInfo *info);
365  Expr visit(const VectorReduce *op, ExprInfo *info);
366  Expr visit(const Let *op, ExprInfo *info);
367  Stmt visit(const LetStmt *op);
368  Stmt visit(const AssertStmt *op);
369  Stmt visit(const For *op);
370  Stmt visit(const Provide *op);
371  Stmt visit(const Store *op);
372  Stmt visit(const Allocate *op);
373  Stmt visit(const Evaluate *op);
375  Stmt visit(const Block *op);
376  Stmt visit(const Realize *op);
377  Stmt visit(const Prefetch *op);
378  Stmt visit(const Free *op);
379  Stmt visit(const Acquire *op);
380  Stmt visit(const Fork *op);
381  Stmt visit(const Atomic *op);
383 
384  std::pair<std::vector<Expr>, bool> mutate_with_changes(const std::vector<Expr> &old_exprs);
385 };
386 
387 } // namespace Internal
388 } // namespace Halide
389 
390 #endif
Methods for computing the upper and lower bounds of an expression, and the regions of a function read...
Defines the ConstantInterval class, and operators on it.
#define internal_assert(c)
Definition: Errors.h:19
#define HALIDE_ALWAYS_INLINE
Definition: HalideRuntime.h:49
Defines a method to match a fragment of IR against a pattern containing wildcards.
This header file defines operators that let you dump a Halide expression, statement,...
Defines the base class for things that recursively walk over the IR.
Defines the Scope class, which is used for keeping track of names in a scope while traversing IR.
A common pattern when traversing Halide IR is that you need to keep track of stuff when you find a Le...
Definition: Scope.h:94
Stmt visit(const HoistedStorage *op)
Expr visit(const LE *op, ExprInfo *info)
Stmt visit(const ProducerConsumer *op)
HALIDE_ALWAYS_INLINE Expr mutate(const Expr &e, ExprInfo *b)
Scope< ExprInfo > bounds_and_alignment_info
Expr visit(const Variable *op, ExprInfo *info)
IRMatcher::WildConst< 5 > c5
void found_buffer_reference(const std::string &name, size_t dimensions=0)
Stmt visit(const Block *op)
Expr visit(const Load *op, ExprInfo *info)
Expr visit(const Cast *op, ExprInfo *info)
Stmt visit(const AssertStmt *op)
Stmt mutate(const Stmt &s)
Stmt visit(const Evaluate *op)
Expr visit(const Let *op, ExprInfo *info)
Expr visit(const LT *op, ExprInfo *info)
Simplify(bool r, const Scope< Interval > *bi, const Scope< ModulusRemainder > *ai)
Expr visit(const Ramp *op, ExprInfo *info)
Stmt visit(const Prefetch *op)
HALIDE_ALWAYS_INLINE bool no_overflow(Type t)
IRMatcher::WildConst< 1 > c1
Expr visit(const Shuffle *op, ExprInfo *info)
Stmt visit(const IfThenElse *op)
Expr visit(const Mod *op, ExprInfo *info)
IRMatcher::WildConst< 0 > c0
ScopedFact scoped_truth(const Expr &fact)
IRMatcher::WildConst< 3 > c3
Expr visit(const UIntImm *op, ExprInfo *info)
Expr visit(const Add *op, ExprInfo *info)
Expr visit(const Max *op, ExprInfo *info)
HALIDE_ALWAYS_INLINE void clear_expr_info(ExprInfo *b)
Expr visit(const Mul *op, ExprInfo *info)
IRMatcher::WildConst< 2 > c2
Expr visit(const StringImm *op, ExprInfo *info)
Expr visit(const VectorReduce *op, ExprInfo *info)
std::pair< std::vector< Expr >, bool > mutate_with_changes(const std::vector< Expr > &old_exprs)
Expr visit(const Min *op, ExprInfo *info)
Expr visit(const Reinterpret *op, ExprInfo *info)
Expr visit(const NE *op, ExprInfo *info)
HALIDE_ALWAYS_INLINE bool may_simplify(const Type &t) const
Stmt visit(const For *op)
Expr visit(const Select *op, ExprInfo *info)
Stmt visit(const Atomic *op)
Expr visit(const Div *op, ExprInfo *info)
Expr visit(const GE *op, ExprInfo *info)
Stmt visit(const Provide *op)
Expr visit(const Not *op, ExprInfo *info)
Body simplify_let(const T *op, ExprInfo *info)
Expr mutate_let_body(const Expr &e, ExprInfo *info)
Expr visit(const Or *op, ExprInfo *info)
Expr visit(const FloatImm *op, ExprInfo *info)
Stmt mutate_let_body(const Stmt &s, ExprInfo *)
Stmt visit(const Acquire *op)
Stmt visit(const Fork *op)
HALIDE_ALWAYS_INLINE bool no_overflow_int(Type t)
std::set< Expr, IRDeepCompare > truths
ScopedFact scoped_falsehood(const Expr &fact)
HALIDE_ALWAYS_INLINE bool should_commute(const Expr &a, const Expr &b)
Expr visit(const Broadcast *op, ExprInfo *info)
Expr visit(const Sub *op, ExprInfo *info)
Stmt visit(const Store *op)
Expr visit(const GT *op, ExprInfo *info)
HALIDE_ALWAYS_INLINE bool no_overflow_scalar_int(Type t)
Stmt visit(const Free *op)
IRMatcher::WildConst< 4 > c4
Expr visit(const Call *op, ExprInfo *info)
Stmt visit(const Allocate *op)
Stmt visit(const Realize *op)
Expr visit(const EQ *op, ExprInfo *info)
Stmt visit(const LetStmt *op)
std::set< Expr, IRDeepCompare > falsehoods
Expr visit(const IntImm *op, ExprInfo *info)
Expr visit(const And *op, ExprInfo *info)
A visitor/mutator capable of passing arbitrary arguments to the visit methods using CRTP and returnin...
Definition: IRVisitor.h:161
HALIDE_ALWAYS_INLINE Stmt dispatch(const Stmt &s, Args &&...args)
Definition: IRVisitor.h:335
For optional debugging during codegen, use the debug class as follows:
Definition: Debug.h:49
T mod_imp(T a, T b)
Implementations of division and mod that are specific to Halide.
Definition: IROperator.h:252
std::optional< uint64_t > as_const_uint(const Expr &e)
If an expression is a UIntImm or a Broadcast of a UIntImm, return its value.
HALIDE_MUST_USE_RESULT bool add_with_overflow(int bits, int64_t a, int64_t b, int64_t *result)
Routines to perform arithmetic on signed types without triggering signed overflow.
std::optional< int64_t > as_const_int(const Expr &e)
If an expression is an IntImm or a Broadcast of an IntImm, return its value.
HALIDE_MUST_USE_RESULT bool sub_with_overflow(int bits, int64_t a, int64_t b, int64_t *result)
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
signed __INT64_TYPE__ int64_t
A fragment of Halide syntax.
Definition: Expr.h:258
HALIDE_ALWAYS_INLINE Type type() const
Get the type of this expression node.
Definition: Expr.h:327
The sum of two expressions.
Definition: IR.h:56
Allocate a scratch area with the given name, type, and size.
Definition: IR.h:371
Logical and - are both expressions true.
Definition: IR.h:175
If the 'condition' is false, then evaluate and return the message, which should be a call to an error...
Definition: IR.h:294
Lock all the Store nodes in the body statement.
Definition: IR.h:961
A sequence of statements to be executed in-order.
Definition: IR.h:442
A vector with 'lanes' elements, in which every element is 'value'.
Definition: IR.h:259
A function call.
Definition: IR.h:490
The actual IR nodes begin here.
Definition: IR.h:30
A class to represent ranges of integers.
int64_t min
The lower and upper bound of the interval.
bool is_bounded() const
Does the interval have a finite upper and lower bound.
static ConstantInterval make_intersection(const ConstantInterval &a, const ConstantInterval &b)
Construct the largest interval contained within two intervals.
bool is_single_point() const
Is the interval just a single value (min == max)
void cast_to(const Type &t)
Track what happens if a constant integer interval is forced to fit into a concrete integer type.
static ConstantInterval single_point(int64_t x)
Construct an interval representing a single point.
The ratio of two expressions.
Definition: IR.h:83
Is the first expression equal to the second.
Definition: IR.h:121
Evaluate and discard an expression, presumably because it has some side-effect.
Definition: IR.h:476
Floating point constants.
Definition: Expr.h:236
A for loop.
Definition: IR.h:819
A pair of statements executed concurrently.
Definition: IR.h:457
Free the resources associated with the given buffer.
Definition: IR.h:413
Is the first expression greater than or equal to the second.
Definition: IR.h:166
Is the first expression greater than the second.
Definition: IR.h:157
Represents a location where storage will be hoisted to for a Func / Realize node with a given name.
Definition: IR.h:945
IRNodeType node_type() const
Definition: Expr.h:212
const T * as() const
Downcast this ir node to its actual type (e.g.
Definition: Expr.h:205
An if-then-else block.
Definition: IR.h:466
Integer constants.
Definition: Expr.h:218
HALIDE_ALWAYS_INLINE bool same_as(const IntrusivePtr &other) const
Definition: IntrusivePtr.h:171
Is the first expression less than or equal to the second.
Definition: IR.h:148
Is the first expression less than the second.
Definition: IR.h:139
A let expression, like you might find in a functional language.
Definition: IR.h:271
The statement form of a let node.
Definition: IR.h:282
Load a value from a named symbol if predicate is true.
Definition: IR.h:217
The greater of two values.
Definition: IR.h:112
The lesser of two values.
Definition: IR.h:103
The remainder of a / b.
Definition: IR.h:94
The result of modulus_remainder analysis.
static ModulusRemainder intersect(const ModulusRemainder &a, const ModulusRemainder &b)
The product of two expressions.
Definition: IR.h:74
Is the first expression not equal to the second.
Definition: IR.h:130
Logical not - true if the expression is false.
Definition: IR.h:193
Logical or - is at least one of the expressions true.
Definition: IR.h:184
Represent a multi-dimensional region of a Func or an ImageParam that needs to be prefetched.
Definition: IR.h:923
This node is a helpful annotation to do with permissions.
Definition: IR.h:315
This defines the value of a function at a multi-dimensional location.
Definition: IR.h:354
A linear ramp vector node.
Definition: IR.h:247
Allocate a multi-dimensional buffer of the given type and size.
Definition: IR.h:427
Reinterpret value as another type, without affecting any of the bits (on little-endian systems).
Definition: IR.h:47
A ternary operator.
Definition: IR.h:204
Construct a new vector by taking elements from another sequence of vectors.
Definition: IR.h:855
void intersect(const ExprInfo &other)
ScopedFact(ScopedFact &&that)=default
void learn_false(const Expr &fact)
std::vector< const Variable * > bounds_pop_list
ScopedFact(const ScopedFact &that)=delete
std::set< Expr, IRDeepCompare > truths
std::vector< const Variable * > pop_list
void learn_lower_bound(const Variable *v, int64_t val)
std::set< Expr, IRDeepCompare > falsehoods
void learn_upper_bound(const Variable *v, int64_t val)
A reference-counted handle to a statement node.
Definition: Expr.h:427
Store a 'value' to the buffer called 'name' at a given 'index' if 'predicate' is true.
Definition: IR.h:333
String constants.
Definition: Expr.h:245
The difference of two expressions.
Definition: IR.h:65
Unsigned integer constants.
Definition: Expr.h:227
A named variable.
Definition: IR.h:772
std::string name
Definition: IR.h:773
Horizontally reduce a vector to a scalar or narrower vector using the given commutative and associati...
Definition: IR.h:979
Types in the halide type system.
Definition: Type.h:283
HALIDE_ALWAYS_INLINE bool is_int() const
Is this type a signed integer type?
Definition: Type.h:435
HALIDE_ALWAYS_INLINE bool is_uint() const
Is this type an unsigned integer type?
Definition: Type.h:441
HALIDE_ALWAYS_INLINE int bits() const
Return the bit size of a single element of this type.
Definition: Type.h:349
bool can_represent(Type other) const
Can this type represent all values of another type?
HALIDE_ALWAYS_INLINE bool is_scalar() const
Is this type a scalar type? (lanes() == 1).
Definition: Type.h:417
HALIDE_ALWAYS_INLINE bool is_float() const
Is this type a floating point type (float or double).
Definition: Type.h:423