Halide 17.0.1
Halide compiler and libraries
Loading...
Searching...
No Matches
simd_op_check.h
Go to the documentation of this file.
1#ifndef SIMD_OP_CHECK_H
2#define SIMD_OP_CHECK_H
3
4#include "Halide.h"
5#include "halide_test_dirs.h"
6#include "test_sharding.h"
7
8#include <fstream>
9#include <iostream>
10
11namespace Halide {
// Outcome of checking one op. test_all() treats a non-empty error_msg as a
// failure and prints it to stderr.
struct TestResult {
    std::string op;         // the op pattern that was checked
    std::string error_msg;  // accumulated failure text; empty when the check passed
};
16
17struct Task {
18 std::string op;
19 std::string name;
22};
23
25public:
26 static constexpr int max_i8 = 127;
27 static constexpr int max_i16 = 32767;
28 static constexpr int max_i32 = 0x7fffffff;
29 static constexpr int max_u8 = 255;
30 static constexpr int max_u16 = 65535;
31 const Expr max_u32 = UInt(32).max();
32
33 std::string filter{"*"};
35 std::vector<Task> tasks;
36 std::mt19937 rng;
37
39
40 ImageParam in_f32{Float(32), 1, "in_f32"};
41 ImageParam in_f64{Float(64), 1, "in_f64"};
42 ImageParam in_f16{Float(16), 1, "in_f16"};
43 ImageParam in_bf16{BFloat(16), 1, "in_bf16"};
44 ImageParam in_i8{Int(8), 1, "in_i8"};
45 ImageParam in_u8{UInt(8), 1, "in_u8"};
46 ImageParam in_i16{Int(16), 1, "in_i16"};
47 ImageParam in_u16{UInt(16), 1, "in_u16"};
48 ImageParam in_i32{Int(32), 1, "in_i32"};
49 ImageParam in_u32{UInt(32), 1, "in_u32"};
50 ImageParam in_i64{Int(64), 1, "in_i64"};
51 ImageParam in_u64{UInt(64), 1, "in_u64"};
52
55 int W;
56 int H;
57
59
67 virtual ~SimdOpCheckTest() = default;
68
69 void set_seed(int seed) {
70 rng.seed(seed);
71 }
72
73 virtual bool can_run_code() const {
76 }
77 // If we can (target matches host), run the error checking Halide::Func.
79 bool can_run_the_code =
80 (target.arch == host_target.arch &&
81 target.bits == host_target.bits &&
82 target.os == host_target.os);
83 // A bunch of feature flags also need to match between the
84 // compiled code and the host in order to run the code.
85 for (Target::Feature f : {
107 }) {
108 if (target.has_feature(f) != host_target.has_feature(f)) {
109 can_run_the_code = false;
110 }
111 }
112 return can_run_the_code;
113 }
114
115 virtual void compile_and_check(Func error, const std::string &op, const std::string &name, int vector_width, std::ostringstream &error_msg) {
116 std::string fn_name = "test_" + name;
117 std::string file_name = output_directory + fn_name;
118
120 std::map<OutputFileType, std::string> outputs = {
124 };
125 error.compile_to(outputs, arg_types, fn_name, target);
126
127 std::ifstream asm_file;
128 asm_file.open(file_name + ".s");
129
130 bool found_it = false;
131
132 std::ostringstream msg;
133 msg << op << " did not generate for target=" << get_run_target().to_string() << " vector_width=" << vector_width << ". Instead we got:\n";
134
135 std::string line;
136 while (getline(asm_file, line)) {
137 msg << line << "\n";
138
139 // Check for the op in question
140 found_it |= wildcard_search(op, line) && !wildcard_search("_" + op, line);
141 }
142
143 if (!found_it) {
144 error_msg << "Failed: " << msg.str() << "\n";
145 }
146
147 asm_file.close();
148 }
149
150 // Check if pattern p matches str, allowing for wildcards (*).
151 bool wildcard_match(const char *p, const char *str) const {
152 // Match all non-wildcard characters.
153 while (*p && *str && *p == *str && *p != '*') {
154 str++;
155 p++;
156 }
157
158 if (!*p) {
159 return *str == 0;
160 } else if (*p == '*') {
161 p++;
162 do {
163 if (wildcard_match(p, str)) {
164 return true;
165 }
166 } while (*str++);
167 } else if (*p == ' ') { // ignore whitespace in pattern
168 p++;
169 if (wildcard_match(p, str)) {
170 return true;
171 }
172 } else if (*str == ' ') { // ignore whitespace in string
173 str++;
174 if (wildcard_match(p, str)) {
175 return true;
176 }
177 }
178 return !*p;
179 }
180
181 bool wildcard_match(const std::string &p, const std::string &str) const {
182 return wildcard_match(p.c_str(), str.c_str());
183 }
184
185 // Check if a substring of str matches a pattern p.
186 bool wildcard_search(const std::string &p, const std::string &str) const {
187 return wildcard_match("*" + p + "*", str);
188 }
189
196
197 TestResult check_one(const std::string &op, const std::string &name, int vector_width, Expr e) {
198 std::ostringstream error_msg;
199
201 using Internal::IRVisitor::visit;
202 void visit(const Internal::Call *op) override {
203 if (op->call_type == Internal::Call::Halide) {
205 if (f.has_update_definition()) {
206 inline_reduction = f;
207 result = true;
208 }
209 }
210 IRVisitor::visit(op);
211 }
212
213 public:
214 Internal::Function inline_reduction;
215 bool result = false;
218
219 // Define a vectorized Halide::Func that uses the pattern.
220 Halide::Func f(name);
221 f(x, y) = e;
222 f.bound(x, 0, W).vectorize(x, vector_width);
223 f.compute_root();
224
225 // Include a scalar version
226 Halide::Func f_scalar("scalar_" + name);
227 f_scalar(x, y) = e;
228
229 if (has_inline_reduction.result) {
230 // If there's an inline reduction, we want to vectorize it
231 // over the RVar.
232 Var xo, xi;
233 RVar rxi;
234 Func g{has_inline_reduction.inline_reduction};
235
236 // Do the reduction separately in f_scalar
237 g.clone_in(f_scalar);
238
239 g.compute_at(f, x)
240 .update()
241 .split(x, xo, xi, vector_width)
242 .atomic(true)
243 .vectorize(g.rvars()[0])
244 .vectorize(xi);
245 }
246
247 // The output to the pipeline is the maximum absolute difference as a double.
248 RDom r_check(0, W, 0, H);
249 Halide::Func error("error_" + name);
251
252 setup_images();
253 compile_and_check(error, op, name, vector_width, error_msg);
254
256 if (can_run_the_code) {
258
260 // Fill the inputs with noise
261 for (auto p : image_params) {
262 Halide::Buffer<> buf = p.get();
263 if (!buf.defined()) continue;
264 assert(buf.data());
265 Type t = buf.type();
266 // For floats/doubles, we only use values that aren't
267 // subject to rounding error that may differ between
268 // vectorized and non-vectorized versions
269 if (t == Float(32)) {
270 buf.as<float>().for_each_value([&](float &f) { f = (rng() & 0xfff) / 8.0f - 0xff; });
271 } else if (t == Float(64)) {
272 buf.as<double>().for_each_value([&](double &f) { f = (rng() & 0xfff) / 8.0 - 0xff; });
273 } else if (t == Float(16)) {
274 buf.as<float16_t>().for_each_value([&](float16_t &f) { f = float16_t((rng() & 0xff) / 8.0f - 0xf); });
275 } else {
276 // Random bits is fine
277 for (uint32_t *ptr = (uint32_t *)buf.data();
278 ptr != (uint32_t *)buf.data() + buf.size_in_bytes() / 4;
279 ptr++) {
280 // Never use the top four bits, to avoid
281 // signed integer overflow.
282 *ptr = ((uint32_t)rng()) & 0x0fffffff;
283 }
284 }
285 }
286 Realization r = error.realize();
287 double e = Buffer<double>(r[0])();
288 // Use a very loose tolerance for floating point tests. The
289 // kinds of bugs we're looking for are codegen bugs that
290 // return the wrong value entirely, not floating point
291 // accuracy differences between vectors and scalars.
292 if (e > 0.001) {
293 error_msg << "The vector and scalar versions of " << name << " disagree. Maximum error: " << e << "\n";
294
295 std::string error_filename = output_directory + "error_" + name + ".s";
297
298 std::ifstream error_file;
300
301 error_msg << "Error assembly: \n";
302 std::string line;
303 while (getline(error_file, line)) {
304 error_msg << line << "\n";
305 }
306
307 error_file.close();
308 }
309 }
310
311 return {op, error_msg.str()};
312 }
313
314 void check(std::string op, int vector_width, Expr e) {
315 // Make a name for the test by uniquing then sanitizing the op name
316 std::string name = "op_" + op;
317 for (size_t i = 0; i < name.size(); i++) {
318 if (!isalnum(name[i])) name[i] = '_';
319 }
320
321 name += "_" + std::to_string(tasks.size());
322
323 // Bail out after generating the unique_name, so that names are
324 // unique across different processes and don't depend on filter
325 // settings.
326 if (!wildcard_match(filter, op)) return;
327
328 tasks.emplace_back(Task{op, name, vector_width, e});
329 }
330 virtual void add_tests() = 0;
331 virtual void setup_images() {
332 for (auto p : image_params) {
333 p.reset();
334
335 const int alignment_bytes = 16;
336 p.set_host_alignment(alignment_bytes);
337 const int alignment = alignment_bytes / p.type().bytes();
338 p.dim(0).set_min((p.dim(0).min() / alignment) * alignment);
339 }
340 }
341 virtual bool test_all() {
342 /* First add some tests based on the target */
343 add_tests();
344
345 // Remove irrelevant noise from output
347 const std::string run_target_str = run_target.to_string();
348
350 bool success = true;
351 for (size_t t = 0; t < tasks.size(); t++) {
352 if (!sharder.should_run(t)) continue;
353 const auto &task = tasks.at(t);
354 auto result = check_one(task.op, task.name, task.vector_width, task.expr);
355 constexpr int tabstop = 32;
356 const int spaces = std::max(1, tabstop - (int)result.op.size());
357 std::cout << result.op << std::string(spaces, ' ') << "(" << run_target_str << ")\n";
358 if (!result.error_msg.empty()) {
359 std::cerr << result.error_msg;
360 success = false;
361 }
362 }
363
364 return success;
365 }
366
367 template<typename SIMDOpCheckT>
368 static int main(int argc, char **argv, const std::vector<Target> &targets_to_test) {
369 Target host = get_host_target();
370 std::cout << "host is: " << host << "\n";
371
372 const int seed = argc > 2 ? atoi(argv[2]) : time(nullptr);
373 std::cout << "simd_op_check test seed: " << seed << "\n";
374
375 for (const auto &t : targets_to_test) {
376 if (!t.supported()) {
377 std::cout << "[SKIP] Unsupported target: " << t << "\n";
378 return 0;
379 }
380 SIMDOpCheckT test(t);
381
382 if (!t.supported()) {
383 std::cout << "Halide was compiled without support for " << t.to_string() << ". Skipping.\n";
384 continue;
385 }
386
387 if (argc > 1) {
388 test.filter = argv[1];
389 }
390
391 if (getenv("HL_SIMD_OP_CHECK_FILTER")) {
392 test.filter = getenv("HL_SIMD_OP_CHECK_FILTER");
393 }
394
395 test.set_seed(seed);
396
397 if (argc > 2) {
398 // Don't forget: if you want to run the standard tests to a specific output
399 // directory, you'll need to invoke with the first arg enclosed
400 // in quotes (to avoid it being wildcard-expanded by the shell):
401 //
402 // correctness_simd_op_check "*" /path/to/output
403 //
404 test.output_directory = argv[2];
405 }
406
407 bool success = test.test_all();
408
409 // Compile a runtime for this target, for use in the static test.
410 compile_standalone_runtime(test.output_directory + "simd_op_check_runtime.o", test.target);
411
412 if (!success) {
413 return 1;
414 }
415 }
416
417 std::cout << "Success!\n";
418 return 0;
419 }
420
421private:
422 const Halide::Var x{"x"}, y{"y"};
423};
424
425} // namespace Halide
426
427#endif // SIMD_OP_CHECK_H
A halide function.
Definition Func.h:706
void compile_to_assembly(const std::string &filename, const std::vector< Argument > &, const std::string &fn_name, const Target &target=get_target_from_environment())
Statically compile this function to text assembly equivalent to the object file generated by compile_to_object.
Func & compute_root()
Compute all of this function once ahead of time.
void infer_input_bounds(const std::vector< int32_t > &sizes, const Target &target=get_jit_target_from_environment())
For a given size of output, or a given output buffer, determine the bounds required of all unbound Im...
void compile_to(const std::map< OutputFileType, std::string > &output_files, const std::vector< Argument > &args, const std::string &fn_name, const Target &target=get_target_from_environment())
Compile and generate multiple target files with single call.
Realization realize(std::vector< int32_t > sizes={}, const Target &target=Target())
Evaluate this function over some rectangular domain and return the resulting buffer or buffers.
Func & vectorize(const VarOrRVar &var)
Mark a dimension to be computed all-at-once as a single vector.
Func & bound(const Var &var, Expr min, Expr extent)
Statically declare that the range over which a function should be evaluated is given by the second an...
An Image parameter to a halide pipeline.
Definition ImageParam.h:23
A reference-counted handle to Halide's internal representation of a function.
Definition Function.h:38
bool has_update_definition() const
Does this function have an update definition?
A base class for algorithms that need to recursively walk over the IR.
Definition IRVisitor.h:19
A multi-dimensional domain over which to iterate.
Definition RDom.h:193
A reduction variable represents a single dimension of a reduction domain (RDom).
Definition RDom.h:29
A Realization is a vector of references to existing Buffer objects.
Definition Realization.h:19
virtual void compile_and_check(Func error, const std::string &op, const std::string &name, int vector_width, std::ostringstream &error_msg)
static constexpr int max_u8
const std::vector< Argument > arg_types
static constexpr int max_i32
virtual void setup_images()
void set_seed(int seed)
virtual void add_tests()=0
bool wildcard_match(const std::string &p, const std::string &str) const
static constexpr int max_i8
static constexpr int max_u16
bool wildcard_search(const std::string &p, const std::string &str) const
bool wildcard_match(const char *p, const char *str) const
virtual ~SimdOpCheckTest()=default
SimdOpCheckTest(const Target t, int w, int h)
void check(std::string op, int vector_width, Expr e)
Target get_run_target() const
static int main(int argc, char **argv, const std::vector< Target > &targets_to_test)
const std::vector< ImageParam > image_params
static constexpr int max_i16
TestResult check_one(const std::string &op, const std::string &name, int vector_width, Expr e)
virtual bool can_run_code() const
std::vector< Task > tasks
A Halide variable, to be used when defining functions.
Definition Var.h:19
std::map< OutputFileType, const OutputInfo > get_output_info(const Target &target)
std::string get_test_tmp_dir()
Return the path to a directory that can be safely written to when running tests; the contents directo...
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
Target get_host_target()
Return the target corresponding to the host machine.
Type BFloat(int bits, int lanes=1)
Construct a floating-point type in the bfloat format.
Definition Type.h:545
Type UInt(int bits, int lanes=1)
Constructing an unsigned integer type.
Definition Type.h:535
Type Float(int bits, int lanes=1)
Construct a floating-point type.
Definition Type.h:540
Expr maximum(Expr, const std::string &s="maximum")
Expr cast(Expr a)
Cast an expression to the halide type corresponding to the C++ type T.
Definition IROperator.h:364
Type Int(int bits, int lanes=1)
Constructing a signed integer type.
Definition Type.h:530
Expr absd(Expr a, Expr b)
Return the absolute difference between two values.
void compile_standalone_runtime(const std::string &object_filename, const Target &t)
Create an object file containing the Halide runtime for a given target.
int atoi(const char *)
unsigned __INT32_TYPE__ uint32_t
char * getenv(const char *)
A fragment of Halide syntax.
Definition Expr.h:258
HALIDE_ALWAYS_INLINE Type type() const
Get the type of this expression node.
Definition Expr.h:322
HALIDE_ALWAYS_INLINE const Internal::BaseExprNode * get() const
Override get() to return a BaseExprNode * instead of an IRNode *.
Definition Expr.h:316
A function call.
Definition IR.h:490
FunctionPtr func
Definition IR.h:660
CallType call_type
Definition IR.h:501
void accept(IRVisitor *v) const
Dispatch to the correct visitor method for this node.
Definition Expr.h:192
static bool can_jit_target(const Target &target)
If the given target can be executed via the wasm executor, return true.
A struct representing a target machine and os to generate code for.
Definition Target.h:19
enum Halide::Target::Arch arch
bool has_feature(Feature f) const
int bits
The bit-width of the target machine.
Definition Target.h:50
enum Halide::Target::OS os
std::string to_string() const
Convert the Target into a string form that can be reconstituted by merge_string(),...
Target without_feature(Feature f) const
Return a copy of the target with the given feature cleared.
Feature
Optional features a target can have.
Definition Target.h:83
@ AVX512_Cannonlake
Definition Target.h:132
@ AVX512_SapphireRapids
Definition Target.h:133
@ POWER_ARCH_2_07
Definition Target.h:97
Target with_feature(Feature f) const
Return a copy of the target with the given feature set.
std::string op
std::string name
std::string error_msg
Types in the halide type system.
Definition Type.h:276
int bytes() const
The number of bytes required to store a single scalar value of this type.
Definition Type.h:292
Expr max() const
Return an expression which is the maximum value of this type.
Class that provides a type that implements half precision floating point (IEEE754 2008 binary16) in software.
Definition Float16.h:17