diff --git a/include/bfloat16.hpp b/include/bfloat16.hpp new file mode 100644 index 00000000..5e9541bd --- /dev/null +++ b/include/bfloat16.hpp @@ -0,0 +1,125 @@ +#pragma once + +#include +#include + +class bfloat16; + +class bfloat16 { + using StorageType = uint16_t; + StorageType value; + + static StorageType from_float(const float &a) { + if (std::isnan(a)) + return 0xffc1; + union { + uint32_t intStorage; + float floatValue; + }; + floatValue = a; + // Do RNE and truncate + uint32_t roundingBias = ((intStorage >> 16) & 0x1) + 0x00007FFF; + return static_cast((intStorage + roundingBias) >> 16); + } + + static float to_float(const StorageType &a) { + union { + uint32_t intStorage; + float floatValue; + }; + intStorage = a << 16; + return floatValue; + } + +public: + bfloat16() = default; + bfloat16(const bfloat16 &) = default; + ~bfloat16() = default; + + // Implicit conversion from float to bfloat16 + bfloat16(const float &a) { value = from_float(a); } + + bfloat16 &operator=(const float &rhs) { + value = from_float(rhs); + return *this; + } + + // Implicit conversion from bfloat16 to float + operator float() const { return to_float(value); } + + // Logical operators (!,||,&&) are covered if we can cast to bool + explicit operator bool() { return to_float(value) != 0.0f; } + + // Unary minus operator overloading + friend bfloat16 operator-(bfloat16 &lhs) { + return -to_float(lhs.value); + } + + // Increment and decrement operators overloading +#define OP(op) \ + friend bfloat16 &operator op(bfloat16 &lhs) { \ + float f = to_float(lhs.value); \ + lhs.value = from_float(op f); \ + return lhs; \ + } \ + friend bfloat16 operator op(bfloat16 &lhs, int) { \ + bfloat16 old = lhs; \ + operator op(lhs); \ + return old; \ + } + OP(++) + OP(--) +#undef OP + + // Assignment operators overloading +#define OP(op) \ + friend bfloat16 &operator op(bfloat16 &lhs, const bfloat16 &rhs) { \ + float f = static_cast(lhs); \ + f op static_cast(rhs); \ + return lhs = f; \ + } \ + template \ + friend bfloat16 &operator op(bfloat16 &lhs, const T &rhs) { \ + float f = static_cast(lhs); \ + f op static_cast(rhs); \ + return lhs = f; \ + } \ + template friend T &operator op(T &lhs, const bfloat16 &rhs) { \ + float f = static_cast(lhs); \ + f op static_cast(rhs); \ + return lhs = f; \ + } + OP(+=) + OP(-=) + OP(*=) + OP(/=) +#undef OP + +// Binary operators overloading +#define OP(type, op) \ + friend type operator op(const bfloat16 &lhs, const bfloat16 &rhs) { \ + return type{static_cast(lhs) op static_cast(rhs)}; \ + } \ + template \ + friend type operator op(const bfloat16 &lhs, const T &rhs) { \ + return type{static_cast(lhs) op static_cast(rhs)}; \ + } \ + template \ + friend type operator op(const T &lhs, const bfloat16 &rhs) { \ + return type{static_cast(lhs) op static_cast(rhs)}; \ + } + OP(bfloat16, +) + OP(bfloat16, -) + OP(bfloat16, *) + OP(bfloat16, /) + OP(bool, ==) + OP(bool, !=) + OP(bool, <) + OP(bool, >) + OP(bool, <=) + OP(bool, >=) +#undef OP + + // Bitwise(|,&,~,^), modulo(%) and shift(<<,>>) operations are not supported + // for floating-point types. +}; \ No newline at end of file diff --git a/include/util.hpp b/include/util.hpp index 8947dc91..2b65600f 100644 --- a/include/util.hpp +++ b/include/util.hpp @@ -6,6 +6,8 @@ #pragma once #include + +#include #include static cl_version getDeviceOpenCLVersion( @@ -68,6 +70,27 @@ static bool checkDeviceForExtension( return supported; } +static std::string readStringFromFile( + const std::string& filename ) +{ + std::ifstream is(filename, std::ios::binary); + if (!is.good()) { + printf("Couldn't open file '%s'!\n", filename.c_str()); + return ""; + } + + size_t filesize = 0; + is.seekg(0, std::ios::end); + filesize = (size_t)is.tellg(); + is.seekg(0, std::ios::beg); + + std::string source{ + std::istreambuf_iterator(is), + std::istreambuf_iterator() }; + + return source; +} + static bool checkPlatformIndex( const std::vector& platforms, int platformIndex) diff --git a/samples/05_kernelfromfile/main.cpp b/samples/05_kernelfromfile/main.cpp index 4a665bcd..575e6a62 100644 --- a/samples/05_kernelfromfile/main.cpp +++ b/samples/05_kernelfromfile/main.cpp @@ -13,27 +13,6 @@ #include "util.hpp" -static std::string readStringFromFile( - const std::string& filename ) -{ - std::ifstream is(filename, std::ios::binary); - if (!is.good()) { - printf("Couldn't open file '%s'!\n", filename.c_str()); - return ""; - } - - size_t filesize = 0; - is.seekg(0, std::ios::end); - filesize = (size_t)is.tellg(); - is.seekg(0, std::ios::beg); - - std::string source{ - std::istreambuf_iterator(is), - std::istreambuf_iterator() }; - - return source; -} - int main( int argc, char** argv ) diff --git a/samples/06_ndrangekernelfromfile/main.cpp b/samples/06_ndrangekernelfromfile/main.cpp index 9968079e..db7bc09f 100644 --- a/samples/06_ndrangekernelfromfile/main.cpp +++ b/samples/06_ndrangekernelfromfile/main.cpp @@ -13,27 +13,6 @@ #include "util.hpp" -static std::string readStringFromFile( - const std::string& filename ) -{ - std::ifstream is(filename, std::ios::binary); - if (!is.good()) { - printf("Couldn't open file '%s'!\n", filename.c_str()); - return ""; - } - - size_t filesize = 0; - is.seekg(0, std::ios::end); - filesize = (size_t)is.tellg(); - is.seekg(0, std::ios::beg); - - std::string source{ - std::istreambuf_iterator(is), - std::istreambuf_iterator() }; - - return source; -} - int main( int argc, char** argv ) diff --git a/samples/20_matrixexperiments-bf16/CMakeLists.txt b/samples/20_matrixexperiments-bf16/CMakeLists.txt new file mode 100644 index 00000000..0c08c2bc --- /dev/null +++ b/samples/20_matrixexperiments-bf16/CMakeLists.txt @@ -0,0 +1,11 @@ +# Copyright (c) 2024-2026 Ben Ashbaugh +# +# SPDX-License-Identifier: MIT + +add_opencl_sample( + TEST + NUMBER 20 + TARGET matrixexperiments-bf16 + VERSION 200 # for clSetKernelExecInfo + SOURCES main.cpp + KERNELS matrix_helpers_bf16.cl matrix_kernels_bf16.cl matrix_kernel_tiled_bf16.cl) diff --git a/samples/20_matrixexperiments-bf16/README.md b/samples/20_matrixexperiments-bf16/README.md new file mode 100644 index 00000000..793bba22 --- /dev/null +++ b/samples/20_matrixexperiments-bf16/README.md @@ -0,0 +1,60 @@ +# matrixexperiments-bf16 + +## Sample Purpose + +This sample demonstrates various techniques to perform a large matrix multiplcation where the matrix elements contain 16-bit `bfloat16` data. +The sample includes many different implementations: + +1. The "naive" implementation is a very simple implementation. +It is not very fast, but it is easy to understand, and it has no extension dependencies so it will run on many devices. +2. The "dpas" kernels use sub-group extensions to improve performance. +On some devices, they will also use specialized matrix multiplication extensions to further improve performance. +Because these kernels require certain extensions or a specific sub-group size, they may not run on all devices. +3. The "dpas blockread" kernels use additional sub-group extensions to further improve performance. + +Most of the optimized kernels operate on fixed size tiles of matrix data. +For some of these kernels, parameters such as the number of matrix tiles per-sub-group or the number of sub-groups per work-group may be modified via program build options. +Experiment with different options to see what performs the best! + +A good place to start for some devices is: + +```sh +./matrixexperiments-bf16 -m4096 --options="-DSGS_PER_WG_X=4 -DSGS_PER_WG_Y=8 -DKK=2 -cl-intel-256-GRF-per-thread" --zero +``` + +## Key APIs and Concepts + +This sample will optionally use the following OpenCL extensions: + +* cl_intel_bfloat16_conversions +* cl_intel_required_subgroup_size +* cl_intel_split_work_group_barrier +* cl_intel_subgroup_2d_block_io +* cl_intel_subgroup_matrix_multiply_accumulate +* cl_intel_subgroups +* cl_intel_subgroups_short + +## Command Line Options + +| Option | Default Value | Description | +|:--|:-:|:--| +| `-p ` | 0 | Specify the index of the OpenCL platform to execute the sample on. +| `-d ` | 0 | Specify the index of the OpenCL device in the platform to execute on the sample on. +| `--file ` | `matrix_kernels_bf16.cl` | Specify the name of the file with the OpenCL kernel source. +| `--options ` | None | Specify optional program build options. +| `--matrixsize ` | 512 | Specify the dimensions of the matrix. +| `--iterations ` | 16 | Specify the number of iterations for performance testing. +| `--validate` | n/a | Validate results for correctness. +| `--zero` | n/a | Initialize all matrices to zero. +| `--identity` | n/a | Initialize all matrices to to one. +| `--fixed` | n/a | Initialize all matrices to values computed from the matrix row and column. +| `--emulate` | n/a | Do not use specialized matrix multiplication extensions. +| `--wallclock` | n/a | Measure performance using wallclock time instead of event profiling. +| `--skipinit` | n/a | Skip initialization of source matrices. +| `--roundrobin` | n/a | Use round robin thread scheduling. +| `--threshold ` | 0.01 | Set the threshold used when validating results. +| `--mask ` | ~0 | Set a mask to only run a subset of tests. + +By default, the source matrices are populated with random data. +When validating results, it is recommended to use either "fixed" or "identity" data. +For best performance, use "zero" data". diff --git a/samples/20_matrixexperiments-bf16/main.cpp b/samples/20_matrixexperiments-bf16/main.cpp new file mode 100644 index 00000000..c0c4dce4 --- /dev/null +++ b/samples/20_matrixexperiments-bf16/main.cpp @@ -0,0 +1,953 @@ +/* +// Copyright (c) 2024-2026 Ben Ashbaugh +// +// SPDX-License-Identifier: MIT +*/ + +#include + +#include + +#include +#include +#include +#include +#include +#include + +#include "bfloat16.hpp" +#include "util.hpp" + +using test_clock = std::chrono::high_resolution_clock; + +bool zeroData = false; +bool identityData = false; +bool fixedData = false; +bool validate = false; +bool emulate = false; +bool wallclock = false; +bool skipinit = false; +bool roundRobin = false; +int testIterations = 16; +float threshold = 0.01f; + +std::string makeTestName( + const std::string &func, + size_t M, size_t N, size_t K) +{ + std::ostringstream ret; + ret << func; + ret << " (M=" << M << ", N=" << N << ", K=" << K << ")"; + return ret.str(); +} + +std::string makeTestName( + const std::string &func, + int tM, int tN, + size_t M, size_t N, size_t K) +{ + std::ostringstream ret; + ret << func; + ret << ""; + ret << " (M=" << M << ", N=" << N << ", K=" << K << ")"; + return ret.str(); +} + +std::string makeTestName( + const std::string &func, + int tM, int tN, + int MM, int NN, + size_t M, size_t N, size_t K) +{ + std::ostringstream ret; + ret << func; + ret << ""; + ret << " (M=" << M << ", N=" << N << ", K=" << K << ")"; + return ret.str(); +} + +static size_t findMinSubGroupSize(cl::Device& device) +{ + auto s = device.getInfo(); + auto it = std::min_element(std::begin(s), std::end(s)); + if (it != std::end(s)) { + return *it; + } + return 0; +} + +static bool supportsSubgroupSize(cl::Device& device, size_t subgroupSize) +{ + auto s = device.getInfo(); + return std::find(std::begin(s), std::end(s), subgroupSize) != std::end(s); +} + +static void setRoundRobin(cl::Kernel& kernel) +{ + constexpr cl_kernel_exec_info CL_KERNEL_EXEC_INFO_THREAD_ARBITRATION_POLICY_INTEL = 0x10025; + constexpr cl_uint CL_KERNEL_EXEC_INFO_THREAD_ARBITRATION_POLICY_ROUND_ROBIN_INTEL = 0x10023; + const cl_uint policy = CL_KERNEL_EXEC_INFO_THREAD_ARBITRATION_POLICY_ROUND_ROBIN_INTEL; + clSetKernelExecInfo( + kernel(), + CL_KERNEL_EXEC_INFO_THREAD_ARBITRATION_POLICY_INTEL, + sizeof(policy), + &policy); +} + +template +static void fill_matrix(std::vector& M, size_t numRows, size_t numCols) +{ + if (zeroData) { + std::generate(std::begin(M), std::end(M), [&]{ return 0.0f; }); + } + else if (identityData) { + std::generate(std::begin(M), std::end(M), [&]{ return 1.0f; }); + } else if (fixedData) { + for (size_t r = 0; r < numRows; r++) { + for (size_t c = 0; c < numCols; c++) { + M[r * numCols + c] = static_cast(r + c); + } + } + } else { + std::random_device dev; + std::mt19937 rng(dev()); + std::uniform_real_distribution dist(-1.0, 1.0); + std::generate(std::begin(M), std::end(M), [&]{ return dist(rng); }); + } +} + +template +static void vnni_matrix( + std::vector &dst, const std::vector &src, + size_t numRows, size_t numCols, size_t factor) +{ + for (size_t r = 0; r < numRows / factor; r++) { + for (size_t c = 0; c < numCols; c++) { + for (size_t k = 0; k < factor; k++) { + dst[r * numCols * factor + c * factor + k] = + src[(r * factor + k) * numCols + c]; + } + } + } +} + +template +static void compute_reference( + std::vector& C, + const std::vector& A, const std::vector& B, + size_t M, size_t N, size_t K) +{ + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + DstT sum = 0; + for (size_t k = 0; k < K; k++) { + sum = std::fma(static_cast(A[m * K + k]), + static_cast(B[k * N + n]), sum); + } + C[m * N + n] = sum; + } + } +} + +template +void check_results( + size_t M, + size_t N, + const std::vector& C, + const std::vector& C_ref) +{ + float err = 0.f; + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + auto index = m * N + n; + auto localErr = std::fabs(C[index] - C_ref[index]) / + std::max(std::fabs(C[index]), + std::fabs(C_ref[index])); + err = std::max(localErr, err); + if (localErr >= threshold) { + std::cerr << "Error at m = " << m << ", n = " << n + << ": (local error " << localErr << "): Wanted " + << C_ref[index] << ", got " << C[index] << std::endl; + return; + } + } + } +} + +static float hw_time(cl::Event& event) +{ + auto ns = event.getProfilingInfo() - + event.getProfilingInfo(); + return ns / 1e9f; +} + +static cl::NDRange getRequiredLocalWorkSize(cl::Kernel& kernel, cl::CommandQueue queue) +{ + // Note: This shouldn't be necessary, and the OpenCL implementation should + // automatically choose the required local work-group size when the local + // work-group size is `nullptr`. This is not working for some OpenCL + // implementations, though, so we will just query and use the required local + // work-group size explicitly. + auto device = queue.getInfo(); + auto reqd_wgs = kernel.getWorkGroupInfo(device); + + if (reqd_wgs[0] > 0 && reqd_wgs[1] > 0 && reqd_wgs[2] > 0) { + return cl::NDRange(reqd_wgs[0], reqd_wgs[1], reqd_wgs[2]); + } + + return cl::NullRange; +} + +static void bfloat16_naive( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, M, N, K).c_str()); fflush(stdout); + + cl::Kernel kernel{program, "bfloat16_naive"}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N, M}, localWorkSize, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... "); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void bfloat16_dpas_rowmajor( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "bfloat16_dpas_rowmajor"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N, M/tM}, localWorkSize, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... "); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void bfloat16_dpas_rowmajor_tiled( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, MM, NN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "bfloat16_dpas_rowmajor_tiled"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + kernelName += "_" + std::to_string(MM); + kernelName += "x" + std::to_string(NN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else if (tM * MM > M) { + printf("M is too small.\n"); + } else if (tN * NN > N) { + printf("N is too small.\n"); + } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N/NN, M/tM/MM}, localWorkSize, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... "); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void bfloat16_dpas_vnni( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "bfloat16_dpas_vnni"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N, M/tM}, localWorkSize, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... "); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void bfloat16_dpas_vnni_tiled( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, MM, NN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "bfloat16_dpas_vnni_tiled"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + kernelName += "_" + std::to_string(MM); + kernelName += "x" + std::to_string(NN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else if (tM * MM > M) { + printf("M is too small.\n"); + } else if (tN * NN > N) { + printf("N is too small.\n"); + } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N/NN, M/tM/MM}, localWorkSize, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... "); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void bfloat16_dpas_blockread_rowmajor( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "bfloat16_dpas_blockread_rowmajor"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + if (roundRobin) { + setRoundRobin(kernel); + } + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N, M/tM}, localWorkSize, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... "); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void bfloat16_dpas_blockread_rowmajor_tiled( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, MM, NN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "bfloat16_dpas_blockread_rowmajor_tiled"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + kernelName += "_" + std::to_string(MM); + kernelName += "x" + std::to_string(NN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else if (tM * MM > M) { + printf("M is too small.\n"); + } else if (tN * NN > N) { + printf("N is too small.\n"); + } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + if (roundRobin) { + setRoundRobin(kernel); + } + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N/NN, M/tM/MM}, localWorkSize, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... "); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void bfloat16_dpas_blockread_vnni( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "bfloat16_dpas_blockread_vnni"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + if (roundRobin) { + setRoundRobin(kernel); + } + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N, M/tM}, localWorkSize, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... "); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void bfloat16_dpas_blockread_vnni_tiled( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, MM, NN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "bfloat16_dpas_blockread_vnni_tiled"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + kernelName += "_" + std::to_string(MM); + kernelName += "x" + std::to_string(NN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else if (tM * MM > M) { + printf("M is too small.\n"); + } else if (tN * NN > N) { + printf("N is too small.\n"); + } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + if (roundRobin) { + setRoundRobin(kernel); + } + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N/NN, M/tM/MM}, localWorkSize, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... "); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +int main(int argc, char** argv) +{ + int platformIndex = 0; + int deviceIndex = 0; + + std::string fileName("matrix_kernels_bf16.cl"); + std::string buildOptions; + size_t matrixSize = 512; + + size_t mask = ~0; + + { + popl::OptionParser op("Supported Options"); + op.add>("p", "platform", "Platform Index", platformIndex, &platformIndex); + op.add>("d", "device", "Device Index", deviceIndex, &deviceIndex); + op.add>("", "file", "Kernel File Name", fileName, &fileName); + op.add>("", "options", "Program Build Options", buildOptions, &buildOptions); + op.add>("m", "matrixsize", "Matrix Size", matrixSize, &matrixSize); + op.add>("i", "iterations", "Test Iterations", testIterations, &testIterations); + op.add("", "validate", "Validate Results", &validate); + op.add("", "zero", "Use Zero Data", &zeroData); + op.add("", "identity", "Use Identity Data", &identityData); + op.add("", "fixed", "Use Fixed Data", &fixedData); + op.add("", "emulate", "Unconditionally Emulate dpas", &emulate); + op.add("", "wallclock", "Measure Wallclock Time", &wallclock); + op.add("", "skipinit", "Do Not Initialize Buffers", &skipinit); + op.add("", "roundrobin", "Use Round Robin Scheduling", &roundRobin); + op.add>("", "threshold", "Local Error Threshold", threshold, &threshold); + op.add, popl::Attribute::advanced>("", "mask", "Test Mask", mask, &mask); + bool printUsage = false; + try { + op.parse(argc, argv); + } catch (std::exception& e) { + fprintf(stderr, "Error: %s\n\n", e.what()); + printUsage = true; + } + + if (printUsage || !op.unknown_options().empty() || !op.non_option_args().empty()) { + fprintf(stderr, + "Usage: matrixexperiments-bf16 [options]\n" + "%s", op.help().c_str()); + return -1; + } + } + + std::vector platforms; + cl::Platform::get(&platforms); + + if (!checkPlatformIndex(platforms, platformIndex)) { + return -1; + } + + printf("Running on platform: %s\n", + platforms[platformIndex].getInfo().c_str() ); + + std::vector devices; + platforms[platformIndex].getDevices(CL_DEVICE_TYPE_ALL, &devices); + if (deviceIndex >= devices.size()) { + printf("Requested device index is %d, but only %zu devices were found.\n", + deviceIndex, devices.size()); + } + + cl::Device& device = devices[deviceIndex]; + printf("Running on device: %s (%uCUs, %uMHz)\n", + device.getInfo().c_str(), + device.getInfo(), + device.getInfo()); + printf("Running on drivers: %s\n", + device.getInfo().c_str()); + + auto minSubGroupSize = findMinSubGroupSize(device); + + bool has_sg8 = supportsSubgroupSize(device, 8); + bool emulate_tN8 = true; + bool emulate_tN16 = true; + if (!emulate && checkDeviceForExtension(device, "cl_intel_subgroup_matrix_multiply_accumulate")) { + printf("Found support for cl_intel_subgroup_matrix_multiply_accumulate, min sub-group size is: %zu\n", minSubGroupSize); + switch(minSubGroupSize) { + case 8: emulate_tN8 = false; break; + case 16: emulate_tN16 = false; break; + default: break; + } + } + + buildOptions += " -DHAS_SG8=" + std::to_string(has_sg8); + buildOptions += " -DEMULATE_tN8=" + std::to_string(emulate_tN8); + buildOptions += " -DEMULATE_tN16=" + std::to_string(emulate_tN16); + + printf("Config:\n"); + printf("\tTest Iterations: %d\n", testIterations); + printf("\tValidating data?: %s\n", validate ? "true" : "false"); + printf("\tFixed data?: %s\n", fixedData ? "true" : "false"); + printf("\tWallclock time?: %s\n", wallclock ? "true" : "false"); + printf("\tEmulate dpas for tN=8?: %s\n", emulate_tN8 ? "true" : "false"); + printf("\tEmulate dpas for tN=16?: %s\n", emulate_tN16 ? "true" : "false"); + + cl::Context context{device}; + cl::CommandQueue queue{context, device, CL_QUEUE_PROFILING_ENABLE}; + + printf("Reading program source from file: %s\n", fileName.c_str() ); + std::string kernelString = readStringFromFile(fileName.c_str()); + + printf("Building program with build options: %s\n", + buildOptions.empty() ? "(none)" : buildOptions.c_str() ); + cl::Program program{ context, kernelString }; + program.build(buildOptions.c_str()); + for( auto& device : program.getInfo() ) + { + printf("Program build log for device %s:\n", + device.getInfo().c_str() ); + printf("%s\n", + program.getBuildInfo(device).c_str() ); + } + + const auto M = matrixSize; + const auto N = matrixSize; + const auto K = matrixSize; + + std::vector A_vec(M * K); + std::vector B_vec(K * N); + std::vector Bvnni_vec(K * N); + + std::vector C_ref(M * N); + + printf("Initializing source matrices...\n"); + fill_matrix(A_vec, M, K); + fill_matrix(B_vec, K, N); + + vnni_matrix(Bvnni_vec, B_vec, K, N, 2); + + if (validate) { + printf("Computing reference...\n"); + compute_reference(C_ref, A_vec, B_vec, M, N, K); + } + + printf("Creating source buffers...\n"); + cl::Buffer A{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, A_vec.size() * sizeof(A_vec[0]), A_vec.data()}; + cl::Buffer B{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, B_vec.size() * sizeof(B_vec[0]), B_vec.data()}; + cl::Buffer Bvnni{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, Bvnni_vec.size() * sizeof(Bvnni_vec[0]), Bvnni_vec.data()}; + cl::Buffer C{context, CL_MEM_WRITE_ONLY, C_ref.size() * sizeof(C_ref[0])}; + + printf("Running tests...\n"); + + if (mask & 0x1) { + bfloat16_naive(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x2) { + bfloat16_dpas_rowmajor<1, 8>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor<2, 8>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor<4, 8>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor<8, 8>(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x4) { + bfloat16_dpas_rowmajor_tiled<8, 8, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 8, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 8, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 8, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 8, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 8, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 8, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x8) { + bfloat16_dpas_vnni<1, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni<2, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni<4, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni<8, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + } + + if (mask & 0x10) { + bfloat16_dpas_vnni_tiled<8, 8, 1, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 8, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 8, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 8, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 8, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 8, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 8, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + } + + if (mask & 0x20) { + bfloat16_dpas_rowmajor<1, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor<2, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor<4, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor<8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x40) { + bfloat16_dpas_rowmajor_tiled<8, 16, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x80) { + bfloat16_dpas_vnni<1, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni<2, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni<4, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni<8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + } + + if (mask & 0x100) { + bfloat16_dpas_vnni_tiled<8, 16, 1, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + } + + if (mask & 0x200) { + bfloat16_dpas_blockread_rowmajor<1, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor<2, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor<4, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor<8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x400) { + bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x800) { + bfloat16_dpas_blockread_vnni<1, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni<2, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni<4, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni<8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + } + + if (mask & 0x1000) { + bfloat16_dpas_blockread_vnni_tiled<8, 16, 1, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni_tiled<8, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni_tiled<8, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni_tiled<8, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni_tiled<8, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni_tiled<8, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni_tiled<8, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + } + + printf("Done.\n"); + + return 0; +} diff --git a/samples/20_matrixexperiments-bf16/matrix_helpers_bf16.cl b/samples/20_matrixexperiments-bf16/matrix_helpers_bf16.cl new file mode 100644 index 00000000..7dcb2e27 --- /dev/null +++ b/samples/20_matrixexperiments-bf16/matrix_helpers_bf16.cl @@ -0,0 +1,561 @@ +/* +// Copyright (c) 2024-2026 Ben Ashbaugh +// +// SPDX-License-Identifier: MIT +*/ + +float bf16_to_fp32(ushort u) +{ +#if defined(cl_intel_bfloat16_conversions) + return intel_convert_as_bfloat16_float(u); +#else + return as_float(u << 16); +#endif +} + +__attribute__((overloadable)) +float activation(float f) +{ +#if defined(ACTIVATION_RELU) + return fmax(f, 0); +#else // identity + return f; +#endif +} + +__attribute__((overloadable)) +float2 activation(float2 f) +{ + float2 res; + res.s0 = activation(f.s0); + res.s1 = activation(f.s1); + return res; +} + +__attribute__((overloadable)) +float4 activation(float4 f) +{ + float4 res; + res.s0 = activation(f.s0); + res.s1 = activation(f.s1); + res.s2 = activation(f.s2); + res.s3 = activation(f.s3); + return res; +} + +float8 activation(float8 f) +{ + float8 res; + res.s0 = activation(f.s0); + res.s1 = activation(f.s1); + res.s2 = activation(f.s2); + res.s3 = activation(f.s3); + res.s4 = activation(f.s4); + res.s5 = activation(f.s5); + res.s6 = activation(f.s6); + res.s7 = activation(f.s7); + return res; +} + +#ifndef __has_builtin +#define __has_builtin(x) 0 +#endif +#if __has_builtin(__builtin_expect) == 0 +#define __builtin_expect(x) +#endif + +#if defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) + +typedef global ushort* global_aligned_ushort_ptr __attribute__((align_value(4))); + +inline int compute_m(const int num_sgs_x, const int num_sgs_y, const int tM, const int MM) +{ + const int m_start = get_group_id(1) * num_sgs_y; + const int m_index = num_sgs_y > 1 ? m_start + get_sub_group_id() / num_sgs_x : m_start; + return m_index * tM * MM; +} + +inline int compute_n(const int num_sgs_x, const int num_sgs_y, const int tN, const int NN) +{ + const int n_start = get_group_id(0) * num_sgs_x; + const int n_index = num_sgs_x > 1 ? n_start + get_sub_group_id() % num_sgs_x : n_start; + return n_index * tN * NN; +} + +// Emulated SIMD8 dpas: +__attribute__((overloadable)) +float emu_sub_group_bf16_bf16_matrix_mad_k16(int a, int8 b, float acc) +{ + float res = acc; + + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 0)).x), bf16_to_fp32(as_ushort2(b.s0).x), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 0)).y), bf16_to_fp32(as_ushort2(b.s0).y), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 1)).x), bf16_to_fp32(as_ushort2(b.s1).x), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 1)).y), bf16_to_fp32(as_ushort2(b.s1).y), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 2)).x), bf16_to_fp32(as_ushort2(b.s2).x), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 2)).y), bf16_to_fp32(as_ushort2(b.s2).y), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 3)).x), bf16_to_fp32(as_ushort2(b.s3).x), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 3)).y), bf16_to_fp32(as_ushort2(b.s3).y), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 4)).x), bf16_to_fp32(as_ushort2(b.s4).x), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 4)).y), bf16_to_fp32(as_ushort2(b.s4).y), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 5)).x), bf16_to_fp32(as_ushort2(b.s5).x), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 5)).y), bf16_to_fp32(as_ushort2(b.s5).y), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 6)).x), bf16_to_fp32(as_ushort2(b.s6).x), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 6)).y), bf16_to_fp32(as_ushort2(b.s6).y), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 7)).x), bf16_to_fp32(as_ushort2(b.s7).x), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 7)).y), bf16_to_fp32(as_ushort2(b.s7).y), res); + + return res; +} + +__attribute__((overloadable)) +float2 emu_sub_group_bf16_bf16_matrix_mad_k16(int2 a, int8 b, float2 acc) +{ + float2 res; + + res.s0 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s0, b, acc.s0); + res.s1 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s1, b, acc.s1); + + return res; +} + +__attribute__((overloadable)) +float4 emu_sub_group_bf16_bf16_matrix_mad_k16(int4 a, int8 b, float4 acc) +{ + float4 res; + + res.s0 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s0, b, acc.s0); + res.s1 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s1, b, acc.s1); + res.s2 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s2, b, acc.s2); + res.s3 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s3, b, acc.s3); + + return res; +} + +__attribute__((overloadable)) +float8 emu_sub_group_bf16_bf16_matrix_mad_k16(int8 a, int8 b, float8 acc) +{ + float8 res; + + res.s0 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s0, b, acc.s0); + res.s1 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s1, b, acc.s1); + res.s2 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s2, b, acc.s2); + res.s3 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s3, b, acc.s3); + res.s4 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s4, b, acc.s4); + res.s5 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s5, b, acc.s5); + res.s6 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s6, b, acc.s6); + res.s7 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s7, b, acc.s7); + + return res; +} + +// Emulated SIMD16 dpas: +__attribute__((overloadable)) +float emu_sub_group_bf16_bf16_matrix_mad_k16(short a, int8 b, float acc) +{ + float res = acc; + + res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 0)), bf16_to_fp32(as_ushort2(b.s0).x), res); + res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 1)), bf16_to_fp32(as_ushort2(b.s0).y), res); + res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 2)), bf16_to_fp32(as_ushort2(b.s1).x), res); + res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 3)), bf16_to_fp32(as_ushort2(b.s1).y), res); + res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 4)), bf16_to_fp32(as_ushort2(b.s2).x), res); + res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 5)), bf16_to_fp32(as_ushort2(b.s2).y), res); + res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 6)), bf16_to_fp32(as_ushort2(b.s3).x), res); + res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 7)), bf16_to_fp32(as_ushort2(b.s3).y), res); + res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 8)), bf16_to_fp32(as_ushort2(b.s4).x), res); + res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 9)), bf16_to_fp32(as_ushort2(b.s4).y), res); + res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 10)), bf16_to_fp32(as_ushort2(b.s5).x), res); + res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 11)), bf16_to_fp32(as_ushort2(b.s5).y), res); + res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 12)), bf16_to_fp32(as_ushort2(b.s6).x), res); + res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 13)), bf16_to_fp32(as_ushort2(b.s6).y), res); + res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 14)), bf16_to_fp32(as_ushort2(b.s7).x), res); + res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 15)), bf16_to_fp32(as_ushort2(b.s7).y), res); + + return res; +} + +__attribute__((overloadable)) +float2 emu_sub_group_bf16_bf16_matrix_mad_k16(short2 a, int8 b, float2 acc) +{ + float2 res; + + res.s0 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s0, b, acc.s0); + res.s1 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s1, b, acc.s1); + + return res; +} + +__attribute__((overloadable)) +float4 emu_sub_group_bf16_bf16_matrix_mad_k16(short4 a, int8 b, float4 acc) +{ + float4 res; + + res.s0 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s0, b, acc.s0); + res.s1 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s1, b, acc.s1); + res.s2 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s2, b, acc.s2); + res.s3 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s3, b, acc.s3); + + return res; +} + +__attribute__((overloadable)) +float8 emu_sub_group_bf16_bf16_matrix_mad_k16(short8 a, int8 b, float8 acc) +{ + float8 res; + + res.s0 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s0, b, acc.s0); + res.s1 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s1, b, acc.s1); + res.s2 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s2, b, acc.s2); + res.s3 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s3, b, acc.s3); + res.s4 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s4, b, acc.s4); + res.s5 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s5, b, acc.s5); + res.s6 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s6, b, acc.s6); + res.s7 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s7, b, acc.s7); + + return res; +} + +// M rows x K columns +// This is the SIMD8 version, where each work-item loads two values. +int load_a_rowmajor_16b_1r16c_sg8(global ushort* A, int rowStart, int colStart, int stride) +{ + int ret; + + global uint* A_ui = (global uint*)A; + uint offset_ui = rowStart * stride / 2 + colStart / 2; + ret = intel_sub_group_block_read(A_ui + offset_ui); + + return ret; +} + +// M rows x K columns +// This is the SIMD8 version, where each work-item loads two values. +int2 load_a_rowmajor_16b_2r16c_sg8(global ushort* A, int rowStart, int colStart, int stride) +{ + int2 ret; + + global uint* A_ui = (global uint*)A; + uint offset_ui = rowStart * stride / 2 + colStart / 2; + + ret.s0 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + ret.s1 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + + return ret; +} + +// M rows x K columns +// This is the SIMD8 version, where each work-item loads two values. +int4 load_a_rowmajor_16b_4r16c_sg8(global ushort* A, int rowStart, int colStart, int stride) +{ + int4 ret; + + global uint* A_ui = (global uint*)A; + uint offset_ui = rowStart * stride / 2 + colStart / 2; + + ret.s0 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + ret.s1 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + ret.s2 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + ret.s3 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + + return ret; +} + +// M rows x K columns +// This is the SIMD8 version, where each work-item loads two values. +int8 load_a_rowmajor_16b_8r16c_sg8(global ushort* A, int rowStart, int colStart, int stride) +{ + int8 ret; + + global uint* A_ui = (global uint*)A; + uint offset_ui = rowStart * stride / 2 + colStart / 2; + + ret.s0 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + ret.s1 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + ret.s2 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + ret.s3 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + ret.s4 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + ret.s5 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + ret.s6 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + ret.s7 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + + return ret; +} + +// M rows x K columns x V tiles (in the K dimension) +// This is the SIMD8 version, where each work-item loads two values. +// The first tile is returned the first components of the return value, the the next tile, etc. +int16 load_a_rowmajor_16b_8r16x2c_sg8(global ushort* A, int rowStart, int colStart, int stride) +{ + uint16 ret; + + global uint* A_ui = (global uint*)A; + uint offset_ui = rowStart * stride / 2 + colStart / 2; + + ret.s08 = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s19 = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s2a = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s3b = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s4c = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s5d = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s6e = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s7f = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + + return as_int16(ret); +} + +// M rows x K columns x V tiles (in the K dimension) +void prefetch_a_rowmajor_16b_8r16x2c_sg8(global ushort* A, int rowStart, int colStart, int stride) +{ +#if defined(PREFETCH_DEFAULT) + uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; + __builtin_assume((ulong)(A + offset) % 4 == 0); + prefetch(A + offset, 2); +#endif // defined(PREFETCH_DEFAULT) +} + +// M rows x K columns +// This is the SIMD16 version, where each work-item loads one value. +short load_a_rowmajor_16b_1r16c_sg16(global ushort* A, int rowStart, int colStart, int stride) +{ + ushort ret; + + uint offset = rowStart * stride + colStart; + ret = intel_sub_group_block_read_us(A + offset); + + return as_short(ret); +} + +// M rows x K columns +// This is the SIMD16 version, where each work-item loads one value. +short2 load_a_rowmajor_16b_2r16c_sg16(global ushort* A, int rowStart, int colStart, int stride) +{ + ushort2 ret; + + uint offset = rowStart * stride + colStart; + ret.s0 = intel_sub_group_block_read_us(A + offset); offset += stride; + ret.s1 = intel_sub_group_block_read_us(A + offset); offset += stride; + + return as_short2(ret); +} + +// M rows x K columns +// This is the SIMD16 version, where each work-item loads one value. +short4 load_a_rowmajor_16b_4r16c_sg16(global ushort* A, int rowStart, int colStart, int stride) +{ + ushort4 ret; + + uint offset = rowStart * stride + colStart; + ret.s0 = intel_sub_group_block_read_us(A + offset); offset += stride; + ret.s1 = intel_sub_group_block_read_us(A + offset); offset += stride; + ret.s2 = intel_sub_group_block_read_us(A + offset); offset += stride; + ret.s3 = intel_sub_group_block_read_us(A + offset); offset += stride; + + return as_short4(ret); +} + +// M rows x K columns +// This is the SIMD16 version, where each work-item loads one value. +short8 load_a_rowmajor_16b_8r16c_sg16(global ushort* A, int rowStart, int colStart, int stride) +{ + ushort8 ret; + + uint offset = rowStart * stride + colStart; + ret.s0 = intel_sub_group_block_read_us(A + offset); offset += stride; + ret.s1 = intel_sub_group_block_read_us(A + offset); offset += stride; + ret.s2 = intel_sub_group_block_read_us(A + offset); offset += stride; + ret.s3 = intel_sub_group_block_read_us(A + offset); offset += stride; + ret.s4 = intel_sub_group_block_read_us(A + offset); offset += stride; + ret.s5 = intel_sub_group_block_read_us(A + offset); offset += stride; + ret.s6 = intel_sub_group_block_read_us(A + offset); offset += stride; + ret.s7 = intel_sub_group_block_read_us(A + offset); offset += stride; + + return as_short8(ret); +} + +// M rows x K columns x V tiles (in the K dimension) +// This is the SIMD16 version, where each work-item loads one value. +// The first tile is returned the first components of the return value, the the next tile, etc. +short16 load_a_rowmajor_16b_8r16x2c_sg16(global ushort* A, int rowStart, int colStart, int stride) +{ + ushort16 ret; + + uint offset = rowStart * stride + colStart; + ret.s08 = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s19 = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s2a = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s3b = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s4c = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s5d = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s6e = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s7f = intel_sub_group_block_read_us2(A + offset); offset += stride; + + return as_short16(ret); +} + +// M rows x K columns x V tiles (in the M and K dimensions) +void prefetch_a_rowmajor_16b_8x2r16x2c_sg16(global ushort* A, int rowStart, int colStart, int stride) +{ +#if defined(PREFETCH_DEFAULT) + uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; + __builtin_assume((ulong)(A + offset) % 4 == 0); + prefetch(A + offset, 2); +#endif // defined(PREFETCH_DEFAULT) +} + +// K rows x N columns: +// Each work-item loads K values and packs into 32-bits. +// Stride is in units of elements. +int8 load_b_rowmajor_16b_16rNc(global ushort* B, int rowStart, int colStart, int stride) +{ + int8 ret; + + uint offset = rowStart * stride + colStart; + + ushort row0 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row1 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row2 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row3 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row4 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row5 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row6 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row7 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row8 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row9 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row10 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row11 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row12 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row13 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row14 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row15 = intel_sub_group_block_read_us(B + offset); offset += stride; + + ret.s0 = as_int((ushort2)(row0, row1 )); + ret.s1 = as_int((ushort2)(row2, row3 )); + ret.s2 = as_int((ushort2)(row4, row5 )); + ret.s3 = as_int((ushort2)(row6, row7 )); + ret.s4 = as_int((ushort2)(row8, row9 )); + ret.s5 = as_int((ushort2)(row10, row11)); + ret.s6 = as_int((ushort2)(row12, row13)); + ret.s7 = as_int((ushort2)(row14, row15)); + + return ret; +} + +// K rows x N columns: +// Each work-item loads K values that have already been packed into 32-bits. +// Stride is in units of elements. +int8 load_b_packed_16b_16rNc(global ushort* B, int rowStart, int colStart, int stride) +{ + int8 ret; + + global uint* B_ui = (global uint*)B; + uint offset_ui = rowStart / 2 * stride + colStart; + + ret.s0 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s1 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s2 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s3 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s4 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s5 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s6 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s7 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + + return ret; +} + +// K rows x N columns x V tiles (in the N dimension) +void prefetch_b_rowmajor_16b_16r8x4c_sg8(global ushort* B, int rowStart, int colStart, int stride) +{ +#if defined(PREFETCH_DEFAULT) + uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; + __builtin_assume((ulong)(B + offset) % 4 == 0); + prefetch(B + offset, 2); offset += 8 * stride; + __builtin_assume((ulong)(B + offset) % 4 == 0); + prefetch(B + offset, 2); offset += 8 * stride; +#endif // defined(PREFETCH_DEFAULT) +} + +// K rows x N columns x V tiles (in the N dimension) +void prefetch_b_rowmajor_16b_16r16x2c_sg16(global ushort* B, int rowStart, int colStart, int stride) +{ +#if defined(PREFETCH_DEFAULT) + uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; + __builtin_assume((ulong)(B + offset) % 4 == 0); + prefetch(B + offset, 2); +#endif // defined(PREFETCH_DEFAULT) +} + +// K rows x N columns x V tiles (in the N dimension) +void prefetch_b_packed_16b_16r8x2c_sg8(global ushort* B, int rowStart, int colStart, int stride) +{ +#if defined(PREFETCH_DEFAULT) + global uint* B_ui = (global uint*)B; + uint offset_ui = colStart + (rowStart / 2 + get_sub_group_local_id()) * stride; + __builtin_assume((ulong)(B_ui + offset_ui) % 4 == 0); + prefetch(B_ui + offset_ui, 1); +#endif // defined(PREFETCH_DEFAULT) +} + +// K rows x N columns x V tiles (in the K dimension) +void prefetch_b_packed_16b_16x2r16c_sg16(global ushort* B, int rowStart, int colStart, int stride) +{ +#if defined(PREFETCH_DEFAULT) + global uint* B_ui = (global uint*)B; + uint offset_ui = colStart + (rowStart / 2 + get_sub_group_local_id()) * stride; + __builtin_assume((ulong)(B_ui + offset_ui) % 4 == 0); + prefetch(B_ui + offset_ui, 1); +#endif // defined(PREFETCH_DEFAULT) +} + +void store_c_rowmajor_fp32_1rNc(global float* C, float v, int rowStart, int colStart, int stride) +{ + global uint* C_ui = (global uint*)C; + uint v_ui = as_uint(v); + + uint offset = rowStart * stride + colStart; + + intel_sub_group_block_write(C_ui + offset, v_ui); offset += stride; +} + +void store_c_rowmajor_fp32_2rNc(global float* C, float2 v, int rowStart, int colStart, int stride) +{ + global uint* C_ui = (global uint*)C; + uint2 v_ui = as_uint2(v); + + uint offset = rowStart * stride + colStart; + + intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride; +} + +void store_c_rowmajor_fp32_4rNc(global float* C, float4 v, int rowStart, int colStart, int stride) +{ + global uint* C_ui = (global uint*)C; + uint4 v_ui = as_uint4(v); + + uint offset = rowStart * stride + colStart; + + intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s2); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s3); offset += stride; +} + +void store_c_rowmajor_fp32_8rNc(global float* C, float8 v, int rowStart, int colStart, int stride) +{ + global uint* C_ui = (global uint*)C; + uint8 v_ui = as_uint8(v); + + uint offset = rowStart * stride + colStart; + + intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s2); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s3); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s4); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s5); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s6); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s7); offset += stride; +} + +#endif // defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) diff --git a/samples/20_matrixexperiments-bf16/matrix_kernel_tiled_bf16.cl b/samples/20_matrixexperiments-bf16/matrix_kernel_tiled_bf16.cl new file mode 100644 index 00000000..d76ee526 --- /dev/null +++ b/samples/20_matrixexperiments-bf16/matrix_kernel_tiled_bf16.cl @@ -0,0 +1,763 @@ +/* +// Copyright (c) 2024-2026 Ben Ashbaugh +// +// SPDX-License-Identifier: MIT +*/ + +#if !defined(tK) +#error "tK is undefined! This should be defined as the K dimension of the matrix tiles, which is dependent on the elemement type, likely 16 or 32." +#endif + +#if !defined(MM) +#error "MM is undefined! This should be defined as the number of matrix tiles in the M dimension." +#endif + +#if !defined(NN) +#error "NN is undefined! This should be defined as the number of matrix tiles in the N dimension." +#endif + +#if !defined(KK) +#define KK 1 +#endif + +#if !defined(cl_intel_split_work_group_barrier) || defined(NO_SPLIT_BARRIERS) +#define split_barrier_arrive() +#define split_barrier_wait() +#else +#define split_barrier_arrive() intel_work_group_barrier_arrive(0) +#define split_barrier_wait() intel_work_group_barrier_wait(0) +#endif + +#define MM_KERNEL_NAMEX(PREFIX, tM, tN, MM, NN) PREFIX ## _m ## tM ## _n ## tN ## _ ## MM ## x ## NN +#define MM_KERNEL_NAME(PREFIX, tM, tN, MM, NN) MM_KERNEL_NAMEX(PREFIX, tM, tN, MM, NN) + +#define HELPER_NAMEX(PREFIX, MM, NN) PREFIX ## _m ## MM ## _n ## NN +#define HELPER_NAME(PREFIX, MM, NN) HELPER_NAMEX(PREFIX, MM, NN) + +#if !defined(SGS_PER_WG_X) +#define SGS_PER_WG_X 1 +#endif + +#if !defined(SGS_PER_WG_Y) +#define SGS_PER_WG_Y 4 +#endif + +#if !defined(PREFETCH_DISTANCE) +#define PREFETCH_DISTANCE 1 +#endif + +void HELPER_NAME(btile_load_rowmajor, MM, NN)(global ushort* B, int tN, int N, int k, int n, int8 bData[NN][KK]) +{ + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + bData[nn][kk] = load_b_rowmajor_16b_16rNc(B, k + kk * tK, n + nn * tN, N); + } + } +} + +void HELPER_NAME(btile_load_packed, MM, NN)(global ushort* B, int tN, int N, int k, int n, int8 bData[NN][KK]) +{ + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + bData[nn][kk] = load_b_packed_16b_16rNc(B, k + kk * tK, n + nn * tN, N); + } + } +} + +#if HAS_SG8 + +void HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(global ushort* A, int tM, int K, int m, int prefetch_k) +{ + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + prefetch_a_rowmajor_16b_8r16x2c_sg8(A, m + mm * tM, prefetch_k + kk * tK, K); + } + } +} + +void HELPER_NAME(btile_prefetch_rowmajor_sg8, MM, NN)(global ushort* B, int tN, int N, int prefetch_k, int n) +{ + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn+=4) { + prefetch_b_rowmajor_16b_16r8x4c_sg8(B, prefetch_k + kk * tK, n + nn * tN, N); + } + } +} + +void HELPER_NAME(btile_prefetch_packed_sg8, MM, NN)(global ushort* B, int tN, int N, int prefetch_k, int n) +{ + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn+=2) { + prefetch_b_packed_16b_16r8x2c_sg8(B, prefetch_k + kk * tK, n + nn * tN, N); + } + } +} + +void HELPER_NAME(atile_load_rowmajor_sg8, MM, NN)(global ushort* A, int tM, int K, int m, int k, int8 aData[KK][MM]) +{ + if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + int16 aTemp = load_a_rowmajor_16b_8r16x2c_sg8(A, m + mm * tM, k + kk * tK, K); + aData[kk + 0][mm] = aTemp.lo; + aData[kk + 1][mm] = aTemp.hi; + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = load_a_rowmajor_16b_8r16c_sg8(A, m + mm * tM, k + kk * tK, K); + } + } + } +} + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) +kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 8; + const int N = get_global_size(0) * NN; + const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); + const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); + + // Initial prefetch: + int prefetch_k = 0; + for (int p = 0; p < PREFETCH_DISTANCE; p++) { + HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_rowmajor_sg8, MM, NN)(B, tN, N, prefetch_k, n); + prefetch_k += tK * KK; + } + + float8 sum[NN][MM]; + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[nn][mm] = 0; + } + } + + split_barrier_arrive(); + + for (int k = 0; k < K; k += tK * KK) { + // Next prefetch: + // TODO: skip prefetch on the last iterations. + HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_rowmajor_sg8, MM, NN)(B, tN, N, prefetch_k, n); + prefetch_k += tK * KK; + + int8 aData[KK][MM]; + HELPER_NAME(atile_load_rowmajor_sg8, MM, NN)(A, tM, K, m, k, aData); + + int8 bData[NN][KK]; + HELPER_NAME(btile_load_rowmajor, MM, NN)(B, tN, N, k, n, bData); + + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[nn][mm] = mat_mul_sg8(aData[kk][mm], bData[nn][kk], sum[nn][mm]); + } + } + } + + split_barrier_wait(); + split_barrier_arrive(); + } + + split_barrier_wait(); + + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { + sum[nn][mm] = activation(sum[nn][mm]); + store_c_rowmajor_fp32_8rNc(C, sum[nn][mm], m + mm * tM, n + nn * tN, N); + } + } +} + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) +kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 8; + const int N = get_global_size(0) * NN; + const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); + const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); + + // Initial prefetch: + int prefetch_k = 0; + for (int p = 0; p < PREFETCH_DISTANCE; p++) { + HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_packed_sg8, MM, NN)(B, tN, N, prefetch_k, n); + prefetch_k += tK * KK; + } + + float8 sum[NN][MM]; + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[nn][mm] = 0; + } + } + + split_barrier_arrive(); + + for (int k = 0; k < K; k += tK * KK) { + // Next prefetch: + // TODO: skip prefetch on the last iterations. + HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_rowmajor_sg8, MM, NN)(B, tN, N, prefetch_k, n); + prefetch_k += tK * KK; + + int8 aData[KK][MM]; + HELPER_NAME(atile_load_rowmajor_sg8, MM, NN)(A, tM, K, m, k, aData); + + int8 bData[NN][KK]; + HELPER_NAME(btile_load_packed, MM, NN)(B, tN, N, k, n, bData); + + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { + sum[nn][mm] = mat_mul_sg8(aData[kk][mm], bData[nn][kk], sum[nn][mm]); + } + } + } + + split_barrier_wait(); + split_barrier_arrive(); + } + + split_barrier_wait(); + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[nn][mm] = activation(sum[nn][mm]); + store_c_rowmajor_fp32_8rNc(C, sum[nn][mm], m + mm * tM, n + nn * tN, N); + } + } +} + +#endif // HAS_SG8 + +void HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(global ushort* A, int tM, int K, int m, int prefetch_k) +{ + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=2) { + prefetch_a_rowmajor_16b_8x2r16x2c_sg16(A, m + mm * tM, prefetch_k + kk * tK, K); + } + } +} + +void HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(global ushort* B, int tN, int N, int prefetch_k, int n) +{ + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn+=2) { + prefetch_b_rowmajor_16b_16r16x2c_sg16(B, prefetch_k + kk * tK, n + nn * tN, N); + } + } +} + +void HELPER_NAME(btile_prefetch_packed, MM, NN)(global ushort* B, int tN, int N, int prefetch_k, int n) +{ + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn++) { + prefetch_b_packed_16b_16x2r16c_sg16(B, prefetch_k + kk * tK, n + nn * tN, N); + } + } +} + +void HELPER_NAME(atile_load_rowmajor, MM, NN)(global ushort* A, int tM, int K, int m, int k, short8 aData[KK][MM]) +{ + if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + short16 aTemp = load_a_rowmajor_16b_8r16x2c_sg16(A, m + mm * tM, k + kk * tK, K); + aData[kk + 0][mm] = aTemp.lo; + aData[kk + 1][mm] = aTemp.hi; + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = load_a_rowmajor_16b_8r16c_sg16(A, m + mm * tM, k + kk * tK, K); + } + } + } +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) +kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 16; + const int N = get_global_size(0) * NN; + const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); + const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); + + // Initial prefetch: + int prefetch_k = 0; + for (int p = 0; p < PREFETCH_DISTANCE; p++) { + HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n); + prefetch_k += tK * KK; + } + + float8 sum[NN][MM]; + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[nn][mm] = 0; + } + } + + split_barrier_arrive(); + + for (int k = 0; k < K; k += tK * KK) { + // Next prefetch: + // TODO: skip prefetch on the last iterations. + HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n); + prefetch_k += tK * KK; + + short8 aData[KK][MM]; + HELPER_NAME(atile_load_rowmajor, MM, NN)(A, tM, K, m, k, aData); + + int8 bData[NN][KK]; + HELPER_NAME(btile_load_rowmajor, MM, NN)(B, tN, N, k, n, bData); + + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { + sum[nn][mm] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[nn][mm]); + } + } + } + + split_barrier_wait(); + split_barrier_arrive(); + } + + split_barrier_wait(); + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[nn][mm] = activation(sum[nn][mm]); + store_c_rowmajor_fp32_8rNc(C, sum[nn][mm], m + mm * tM, n + nn * tN, N); + } + } +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) +kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 16; + const int N = get_global_size(0) * NN; + const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); + const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); + + // Initial prefetch: + int prefetch_k = 0; + for (int p = 0; p < PREFETCH_DISTANCE; p++) { + HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_packed, MM, NN)(B, tN, N, prefetch_k, n); + prefetch_k += tK * KK; + } + + float8 sum[NN][MM]; + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[nn][mm] = 0; + } + } + + split_barrier_arrive(); + + for (int k = 0; k < K; k += tK * KK) { + // Next prefetch: + // TODO: skip prefetch on the last iterations. + HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_packed, MM, NN)(B, tN, N, prefetch_k, n); + prefetch_k += tK * KK; + + short8 aData[KK][MM]; + HELPER_NAME(atile_load_rowmajor, MM, NN)(A, tM, K, m, k, aData); + + int8 bData[NN][KK]; + HELPER_NAME(btile_load_packed, MM, NN)(B, tN, N, k, n, bData); + + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { + sum[nn][mm] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[nn][mm]); + } + } + } + + split_barrier_wait(); + split_barrier_arrive(); + } + + split_barrier_wait(); + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[nn][mm] = activation(sum[nn][mm]); + store_c_rowmajor_fp32_8rNc(C, sum[nn][mm], m + mm * tM, n + nn * tN, N); + } + } +} + +#ifdef cl_intel_subgroup_2d_block_io + +void HELPER_NAME(atile_block_load_rowmajor, MM, NN)(global ushort* A, int tM, int M, int K, int m, int k, short8 aData[KK][MM]) +{ + if (KK % 2 == 0 & MM % 4 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=4) { + //if (get_sub_group_local_id() == 0) { + // printf("atile block load : %d, %d, %2d: m = %3d, k = %3d, mm = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), m, k, mm, kk, k + kk * tK, m + mm * tM); + //} + short8 aTemp[2][4]; + intel_sub_group_2d_block_read_16b_32r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), (ushort*)aTemp); + for (int tkk = 0; tkk < 2; tkk++) { + for (int tmm = 0; tmm < 4; tmm++) { + aData[kk + tkk][mm + tmm] = aTemp[tkk][tmm]; + } + } + } + } + } else if (KK % 2 == 0 & MM % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=2) { + short8 aTemp[2][2]; + intel_sub_group_2d_block_read_16b_16r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), (ushort*)aTemp); + for (int tkk = 0; tkk < 2; tkk++) { + for (int tmm = 0; tmm < 2; tmm++) { + aData[kk + tkk][mm + tmm] = aTemp[tkk][tmm]; + } + } + } + } + } else if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + short8 aTemp[2]; + intel_sub_group_2d_block_read_16b_8r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), (ushort*)aTemp); + aData[kk + 0][mm] = aTemp[0]; + aData[kk + 1][mm] = aTemp[1]; + } + } + } else if (MM % 4 == 0) { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm+=4) { + short8 aTemp[4]; + intel_sub_group_2d_block_read_16b_32r16x1c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), (ushort*)aTemp); + for (int tmm = 0; tmm < 4; tmm++) { + aData[kk][mm + tmm] = aTemp[tmm]; + } + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + short8 aTemp[1]; + intel_sub_group_2d_block_read_16b_8r16x1c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), (ushort*)aTemp); + aData[kk][mm] = aTemp[0]; + } + } + } +} + +void HELPER_NAME(btile_block_load_rowmajor, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n, int8 bData[NN][KK]) +{ + if (KK % 2 == 0 & NN % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn+=2) { + //if (get_sub_group_local_id() == 0) { + // printf("btile block load: %d, %d, %2d: n = %3d, k = %3d, nn = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), n, k, nn, kk, n + nn * tN, k + kk * tK); + //} + int8 bTemp[2][2]; + intel_sub_group_2d_block_read_transform_16b_32r16x2c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK), (uint*)bTemp); + for (int tnn = 0; tnn < 2; tnn++) { + for (int tkk = 0; tkk < 2; tkk++) { + bData[nn + tnn][kk + tkk] = bTemp[tnn][tkk]; + } + } + } + } + } else if (NN % 2 == 0) { + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn+=2) { + int8 bTemp[2]; + intel_sub_group_2d_block_read_transform_16b_16r16x2c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK), (uint*)bTemp); + bData[nn + 0][kk] = bTemp[0]; + bData[nn + 1][kk] = bTemp[1]; + } + } + } else if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn++) { + int8 bTemp[2]; + intel_sub_group_2d_block_read_transform_16b_32r16x1c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK), (uint*)bTemp); + bData[nn][kk + 0] = bTemp[0]; + bData[nn][kk + 1] = bTemp[1]; + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + int8 bTemp[1]; + intel_sub_group_2d_block_read_transform_16b_16r16x1c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK), (uint*)bTemp); + bData[nn][kk] = bTemp[0]; + } + } + } +} + +void HELPER_NAME(btile_block_load_packed, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n, int8 bData[NN][KK]) +{ + if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn++) { + int8 bTemp[2]; + intel_sub_group_2d_block_read_32b_16r16x1c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2), (uint*)bTemp); + bData[nn][kk + 0] = bTemp[0]; + bData[nn][kk + 1] = bTemp[1]; + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + int8 bTemp[1]; + intel_sub_group_2d_block_read_32b_8r16x1c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2), (uint*)bTemp); + bData[nn][kk] = bTemp[0]; + } + } + } +} + +void HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(global ushort* A, int tM, int M, int K, int m, int k) +{ + if (KK == 2 & MM == 4 & SGS_PER_WG_X >= 4) { + const int sg_index_x = get_sub_group_id() % SGS_PER_WG_X; // index in [0, SGS_PER_WG_X) + const int kk = 0; + const int mm = sg_index_x % 4; + //if (get_sub_group_local_id() == 0) { + // printf("atile block prefetch: %d, %d, %2d: sg_x = %d, m = %3d, k = %3d, mm = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), sg_index_x, m, k, mm, kk, k + kk * tK, m + mm * tM); + //} + intel_sub_group_2d_block_prefetch_16b_8r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + } else if (KK % 2 == 0 & MM % 4 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=4) { + intel_sub_group_2d_block_prefetch_16b_32r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + } + } + } else if (KK % 2 == 0 & MM % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=2) { + intel_sub_group_2d_block_prefetch_16b_16r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + } + } + } else if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + intel_sub_group_2d_block_prefetch_16b_8r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + } + } + } else if (MM % 4 == 0) { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm+=4) { + intel_sub_group_2d_block_prefetch_16b_32r16x1c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + intel_sub_group_2d_block_prefetch_16b_8r16x1c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + } + } + } +} + +void HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n) +{ + if (KK == 2 & NN == 4 & SGS_PER_WG_Y >= 4) { + const int sg_index_y = get_sub_group_id() / SGS_PER_WG_X; // index in [0, SGS_PER_WG_Y) + const int nn = sg_index_y % 2 * 2; // nn(sg_index_y) == 0, 2, 0, 2, 0, 2, 0, 2, ... + const int kk = sg_index_y / 2 % 2; // kk(sg_index_y) == 0, 0, 1, 1, 0, 0, 1, 1, ... + //if (get_sub_group_local_id() == 0) { + // printf("btile block prefetch: %d, %d, %2d: sg_y = %d, n = %3d, k = %3d, nn = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), sg_index_y, n, k, nn, kk, n + nn * tN, k + kk * tK); + //} + intel_sub_group_2d_block_prefetch_16b_16r16x2c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + } else if (KK % 2 == 0 & NN % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn += 2) { + intel_sub_group_2d_block_prefetch_16b_32r16x2c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + } + } + } else if (NN % 2 == 0) { + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn+=2) { + intel_sub_group_2d_block_prefetch_16b_16r16x2c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + } + } + } else if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn++) { + intel_sub_group_2d_block_prefetch_16b_32r16x1c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + intel_sub_group_2d_block_prefetch_16b_16r16x1c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + } + } + } +} + +void HELPER_NAME(btile_block_prefetch_packed, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n) +{ + if (KK == 2 & NN == 4 & SGS_PER_WG_Y >= 4) { + const int sg_index_y = get_sub_group_id() / SGS_PER_WG_X; // index in [0, SGS_PER_WG_Y) + const int nn = sg_index_y % 4; // nn(sg_index_y) == 0, 1, 2, 3, 0, 1, 2, 3 + const int kk = 0; // kk(sg_index_y) == 0, 0, 0, 0, 0, 0, 0, 0 + intel_sub_group_2d_block_prefetch_32b_16r16x1c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)); + } else if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn++) { + intel_sub_group_2d_block_prefetch_32b_16r16x1c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)); + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + intel_sub_group_2d_block_prefetch_32b_8r16x1c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)); + } + } + } +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) +kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 16; + const int M = get_global_size(1) * tM * MM; + const int N = get_global_size(0) * NN; + const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); + const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); + + int prefetch_k = 0; + for (int p = 0; p < PREFETCH_DISTANCE; p++) { + HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(B, tN, K, N, prefetch_k, n); + HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); + prefetch_k += tK * KK; + } + + float8 sum[NN][MM]; + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[nn][mm] = 0; + } + } + + split_barrier_arrive(); + + for (int k = 0; k < K; k += tK * KK) { + int8 bData[NN][KK]; + HELPER_NAME(btile_block_load_rowmajor, MM, NN)(B, tN, K, N, k, n, bData); + + short8 aData[KK][MM]; + HELPER_NAME(atile_block_load_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); + + HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(B, tN, K, N, prefetch_k, n); + HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); + prefetch_k += tK * KK; + + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { + sum[nn][mm] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[nn][mm]); + } + } + } + + split_barrier_wait(); + split_barrier_arrive(); + } + + split_barrier_wait(); + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[nn][mm] = activation(sum[nn][mm]); + intel_sub_group_2d_block_write_32b_8r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), (uint*)&sum[nn][mm]); + } + } +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) +kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 16; + const int M = get_global_size(1) * tM * MM; + const int N = get_global_size(0) * NN; + const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); + const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); + + int prefetch_k = 0; + for (int p = 0; p < PREFETCH_DISTANCE; p++) { + HELPER_NAME(btile_block_prefetch_packed, MM, NN)(B, tN, K, N, prefetch_k, n); + HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); + prefetch_k += tK * KK; + } + + float8 sum[NN][MM]; + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[nn][mm] = 0; + } + } + + split_barrier_arrive(); + + for (int k = 0; k < K; k += tK * KK) { + int8 bData[NN][KK]; + HELPER_NAME(btile_block_load_packed, MM, NN)(B, tN, K, N, k, n, bData); + + short8 aData[KK][MM]; + HELPER_NAME(atile_block_load_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); + + // TODO: skip prefetch on the last iterations. + HELPER_NAME(btile_block_prefetch_packed, MM, NN)(B, tN, K, N, prefetch_k, n); + HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); + prefetch_k += tK * KK; + + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { + sum[nn][mm] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[nn][mm]); + } + } + } + + split_barrier_wait(); + split_barrier_arrive(); + } + + split_barrier_wait(); + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[nn][mm] = activation(sum[nn][mm]); + intel_sub_group_2d_block_write_32b_8r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), (uint*)&sum[nn][mm]); + } + } +} + +#endif // cl_intel_subgroup_2d_block_io diff --git a/samples/20_matrixexperiments-bf16/matrix_kernels_bf16.cl b/samples/20_matrixexperiments-bf16/matrix_kernels_bf16.cl new file mode 100644 index 00000000..b98711b6 --- /dev/null +++ b/samples/20_matrixexperiments-bf16/matrix_kernels_bf16.cl @@ -0,0 +1,631 @@ +/* +// Copyright (c) 2024-2026 Ben Ashbaugh +// +// SPDX-License-Identifier: MIT +*/ + +#include "matrix_helpers_bf16.cl" + +#if EMULATE_tN8 +#define mat_mul_sg8 emu_sub_group_bf16_bf16_matrix_mad_k16 +#else +#define mat_mul_sg8 intel_sub_group_bf16_bf16_matrix_mad_k16 +#endif + +#if EMULATE_tN16 +#define mat_mul_sg16 emu_sub_group_bf16_bf16_matrix_mad_k16 +#else +#define mat_mul_sg16 intel_sub_group_bf16_bf16_matrix_mad_k16 +#endif + +kernel void bfloat16_naive(global float* C, global ushort* A, global ushort* B, int K) +{ + const int N = get_global_size(0); + const int m = get_global_id(1); + const int n = get_global_id(0); + + float sum = 0; + for (int k = 0; k < K; k++) { + sum = fma(bf16_to_fp32(A[m * K + k]), bf16_to_fp32(B[k * N + n]), sum); + } + + sum = activation(sum); + C[m * N + n] = sum; +} + +// For all bfloat16 kernels tK == 16: +#define tK 16 + +#if defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_required_subgroup_size) + +#if HAS_SG8 + +// rowmajor kernels: + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +kernel void bfloat16_dpas_rowmajor_m1_n8(global float* C, global ushort* A, global ushort* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 1; + const int tN = 8; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + float sum = 0; + for (int k = 0; k < K; k += tK) { + int aData = load_a_rowmajor_16b_1r16c_sg8(A, m, k, K); + int8 bData = load_b_rowmajor_16b_16rNc(B, k, n, N); + sum = mat_mul_sg8(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_fp32_1rNc(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +kernel void bfloat16_dpas_rowmajor_m2_n8(global float* C, global ushort* A, global ushort* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 2; + const int tN = 8; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + float2 sum = 0; + for (int k = 0; k < K; k += tK) { + int2 aData = load_a_rowmajor_16b_2r16c_sg8(A, m, k, K); + int8 bData = load_b_rowmajor_16b_16rNc(B, k, n, N); + sum = mat_mul_sg8(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_fp32_2rNc(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +kernel void bfloat16_dpas_rowmajor_m4_n8(global float* C, global ushort* A, global ushort* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 4; + const int tN = 8; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + float4 sum = 0; + for (int k = 0; k < K; k += tK) { + int4 aData = load_a_rowmajor_16b_4r16c_sg8(A, m, k, K); + int8 bData = load_b_rowmajor_16b_16rNc(B, k, n, N); + sum = mat_mul_sg8(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_fp32_4rNc(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +kernel void bfloat16_dpas_rowmajor_m8_n8(global float* C, global ushort* A, global ushort* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 8; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + float8 sum = 0; + for (int k = 0; k < K; k += tK) { + int8 aData = load_a_rowmajor_16b_8r16c_sg8(A, m, k, K); + int8 bData = load_b_rowmajor_16b_16rNc(B, k, n, N); + sum = mat_mul_sg8(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_fp32_8rNc(C, sum, m, n, N); +} + +// pre-packed kernels: + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +kernel void bfloat16_dpas_vnni_m1_n8(global float* C, global ushort* A, global ushort* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 1; + const int tN = 8; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + float sum = 0; + for (int k = 0; k < K; k += tK) { + int aData = load_a_rowmajor_16b_1r16c_sg8(A, m, k, K); + int8 bData = load_b_packed_16b_16rNc(B, k, n, N); + sum = mat_mul_sg8(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_fp32_1rNc(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +kernel void bfloat16_dpas_vnni_m2_n8(global float* C, global ushort* A, global ushort* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 2; + const int tN = 8; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + float2 sum = 0; + for (int k = 0; k < K; k += tK) { + int2 aData = load_a_rowmajor_16b_2r16c_sg8(A, m, k, K); + int8 bData = load_b_packed_16b_16rNc(B, k, n, N); + sum = mat_mul_sg8(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_fp32_2rNc(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +kernel void bfloat16_dpas_vnni_m4_n8(global float* C, global ushort* A, global ushort* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 4; + const int tN = 8; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + float4 sum = 0; + for (int k = 0; k < K; k += tK) { + int4 aData = load_a_rowmajor_16b_4r16c_sg8(A, m, k, K); + int8 bData = load_b_packed_16b_16rNc(B, k, n, N); + sum = mat_mul_sg8(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_fp32_4rNc(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +kernel void bfloat16_dpas_vnni_m8_n8(global float* C, global ushort* A, global ushort* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 8; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + float8 sum = 0; + for (int k = 0; k < K; k += tK) { + int8 aData = load_a_rowmajor_16b_8r16c_sg8(A, m, k, K); + int8 bData = load_b_packed_16b_16rNc(B, k, n, N); + sum = mat_mul_sg8(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_fp32_8rNc(C, sum, m, n, N); +} + +#endif // HAS_SG8 + +// rowmajor kernels: + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_rowmajor_m1_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 1; + const int tN = 16; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + float sum = 0; + for (int k = 0; k < K; k += tK) { + short aData = load_a_rowmajor_16b_1r16c_sg16(A, m, k, K); + int8 bData = load_b_rowmajor_16b_16rNc(B, k, n, N); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_fp32_1rNc(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_rowmajor_m2_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 2; + const int tN = 16; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + float2 sum = 0; + for (int k = 0; k < K; k += tK) { + short2 aData = load_a_rowmajor_16b_2r16c_sg16(A, m, k, K); + int8 bData = load_b_rowmajor_16b_16rNc(B, k, n, N); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_fp32_2rNc(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_rowmajor_m4_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 4; + const int tN = 16; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + float4 sum = 0; + for (int k = 0; k < K; k += tK) { + short4 aData = load_a_rowmajor_16b_4r16c_sg16(A, m, k, K); + int8 bData = load_b_rowmajor_16b_16rNc(B, k, n, N); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_fp32_4rNc(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_rowmajor_m8_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 16; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + float8 sum = 0; + for (int k = 0; k < K; k += tK) { + short8 aData = load_a_rowmajor_16b_8r16c_sg16(A, m, k, K); + int8 bData = load_b_rowmajor_16b_16rNc(B, k, n, N); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_fp32_8rNc(C, sum, m, n, N); +} + +// pre-packed kernels: + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_vnni_m1_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 1; + const int tN = 16; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + float sum = 0; + for (int k = 0; k < K; k += tK) { + short aData = load_a_rowmajor_16b_1r16c_sg16(A, m, k, K); + int8 bData = load_b_packed_16b_16rNc(B, k, n, N); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_fp32_1rNc(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_vnni_m2_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 2; + const int tN = 16; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + float2 sum = 0; + for (int k = 0; k < K; k += tK) { + short2 aData = load_a_rowmajor_16b_2r16c_sg16(A, m, k, K); + int8 bData = load_b_packed_16b_16rNc(B, k, n, N); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_fp32_2rNc(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_vnni_m4_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 4; + const int tN = 16; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + float4 sum = 0; + for (int k = 0; k < K; k += tK) { + short4 aData = load_a_rowmajor_16b_4r16c_sg16(A, m, k, K); + int8 bData = load_b_packed_16b_16rNc(B, k, n, N); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_fp32_4rNc(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_vnni_m8_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 16; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + float8 sum = 0; + for (int k = 0; k < K; k += tK) { + short8 aData = load_a_rowmajor_16b_8r16c_sg16(A, m, k, K); + int8 bData = load_b_packed_16b_16rNc(B, k, n, N); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_fp32_8rNc(C, sum, m, n, N); +} + +#ifdef cl_intel_subgroup_2d_block_io + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_blockread_rowmajor_m1_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 1; + const int tN = 16; + const int M = get_global_size(1); + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + float sum = 0; + for (int k = 0; k < K; k += tK) { + short aData; + intel_sub_group_2d_block_read_16b_1r16x1c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m), (ushort*)&aData); + int8 bData; + intel_sub_group_2d_block_read_transform_16b_16r16x1c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k), (uint*)&bData); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_sub_group_2d_block_write_32b_1r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_blockread_rowmajor_m2_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 2; + const int tN = 16; + const int M = get_global_size(1) * tM; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + float2 sum = 0; + for (int k = 0; k < K; k += tK) { + short2 aData; + intel_sub_group_2d_block_read_16b_2r16x1c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m), (ushort*)&aData); + int8 bData; + intel_sub_group_2d_block_read_transform_16b_16r16x1c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k), (uint*)&bData); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_sub_group_2d_block_write_32b_2r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_blockread_rowmajor_m4_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 4; + const int tN = 16; + const int M = get_global_size(1) * tM; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + float4 sum = 0; + for (int k = 0; k < K; k += tK) { + short4 aData; + intel_sub_group_2d_block_read_16b_4r16x1c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m), (ushort*)&aData); + int8 bData; + intel_sub_group_2d_block_read_transform_16b_16r16x1c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k), (uint*)&bData); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_sub_group_2d_block_write_32b_4r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_blockread_rowmajor_m8_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 16; + const int M = get_global_size(1) * tM; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + float8 sum = 0; + for (int k = 0; k < K; k += tK) { + short8 aData; + intel_sub_group_2d_block_read_16b_8r16x1c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m), (ushort*)&aData);; + int8 bData; + intel_sub_group_2d_block_read_transform_16b_16r16x1c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k), (uint*)&bData); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_sub_group_2d_block_write_32b_8r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_blockread_vnni_m1_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 1; + const int tN = 16; + const int M = get_global_size(1) * tM; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + float sum = 0; + for (int k = 0; k < K; k += tK) { + short aData; + intel_sub_group_2d_block_read_16b_1r16x1c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m), (ushort*)&aData); + int8 bData; + intel_sub_group_2d_block_read_32b_8r16x1c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2), (uint*)&bData); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_sub_group_2d_block_write_32b_1r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_blockread_vnni_m2_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 2; + const int tN = 16; + const int M = get_global_size(1) * tM; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + float2 sum = 0; + for (int k = 0; k < K; k += tK) { + short2 aData; + intel_sub_group_2d_block_read_16b_2r16x1c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m), (ushort*)&aData); + int8 bData; + intel_sub_group_2d_block_read_32b_8r16x1c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2), (uint*)&bData); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_sub_group_2d_block_write_32b_2r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_blockread_vnni_m4_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 4; + const int tN = 16; + const int M = get_global_size(1) * tM; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + float4 sum = 0; + for (int k = 0; k < K; k += tK) { + short4 aData; + intel_sub_group_2d_block_read_16b_4r16x1c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m), (ushort*)&aData); + int8 bData; + intel_sub_group_2d_block_read_32b_8r16x1c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2), (uint*)&bData); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_sub_group_2d_block_write_32b_4r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_blockread_vnni_m8_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 16; + const int M = get_global_size(1) * tM; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + float8 sum = 0; + for (int k = 0; k < K; k += tK) { + short8 aData; + intel_sub_group_2d_block_read_16b_8r16x1c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m), (ushort*)&aData); + int8 bData; + intel_sub_group_2d_block_read_32b_8r16x1c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2), (uint*)&bData); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_sub_group_2d_block_write_32b_8r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); +} + +#endif // cl_intel_subgroup_2d_block_io + +// Tiled matrix multiplication kernels, generated from a template: + +#define MM 1 +#define NN 1 +#include "matrix_kernel_tiled_bf16.cl" +#undef MM +#undef NN + +#define MM 2 +#define NN 1 +#include "matrix_kernel_tiled_bf16.cl" +#undef MM +#undef NN + +#define MM 1 +#define NN 2 +#include "matrix_kernel_tiled_bf16.cl" +#undef MM +#undef NN + +#define MM 2 +#define NN 2 +#include "matrix_kernel_tiled_bf16.cl" +#undef MM +#undef NN + +#define MM 4 +#define NN 2 +#include "matrix_kernel_tiled_bf16.cl" +#undef MM +#undef NN + +#define MM 2 +#define NN 4 +#include "matrix_kernel_tiled_bf16.cl" +#undef MM +#undef NN + +#define MM 4 +#define NN 4 +#include "matrix_kernel_tiled_bf16.cl" +#undef MM +#undef NN + +#endif // defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_required_subgroup_size) + +#undef tK diff --git a/samples/20_matrixexperiments-i8/CMakeLists.txt b/samples/20_matrixexperiments-i8/CMakeLists.txt new file mode 100644 index 00000000..cc59c28d --- /dev/null +++ b/samples/20_matrixexperiments-i8/CMakeLists.txt @@ -0,0 +1,11 @@ +# Copyright (c) 2024-2026 Ben Ashbaugh +# +# SPDX-License-Identifier: MIT + +add_opencl_sample( + TEST + NUMBER 20 + TARGET matrixexperiments-i8 + VERSION 200 # for clSetKernelExecInfo + SOURCES main.cpp + KERNELS matrix_helpers_i8.cl matrix_kernels_i8.cl) diff --git a/samples/20_matrixexperiments-i8/README.md b/samples/20_matrixexperiments-i8/README.md new file mode 100644 index 00000000..8fe63a4b --- /dev/null +++ b/samples/20_matrixexperiments-i8/README.md @@ -0,0 +1,60 @@ +# matrixexperiments-i8 + +## Sample Purpose + +This sample demonstrates various techniques to perform a large matrix multiplcation where the matrix elements contain 8-bit integer data. +The sample includes many different implementations: + +1. The "naive" implementation is a very simple implementation. +It is not very fast, but it is easy to understand, and it has no extension dependencies so it will run on many devices. +2. The "dpas" kernels use sub-group extensions to improve performance. +On some devices, they will also use specialized matrix multiplication extensions to further improve performance. +Because these kernels require certain extensions or a specific sub-group size, they may not run on all devices. +3. The "dpas blockread" kernels use additional sub-group extensions to further improve performance. + +Most of the optimized kernels operate on fixed size tiles of matrix data. +For some of these kernels, parameters such as the number of matrix tiles per-sub-group or the number of sub-groups per work-group may be modified via program build options. +Experiment with different options to see what performs the best! + +Note, these kernels are not as highly tuned as the kernels for `bfloat16` and `tf32`! +A good place to start for some devices is: + +```sh +./matrixexperiments-i8 -m4096 --zero +``` + +## Key APIs and Concepts + +This sample will optionally use the following OpenCL extensions: + +* cl_intel_required_subgroup_size +* cl_intel_split_work_group_barrier +* cl_intel_subgroup_2d_block_io +* cl_intel_subgroup_matrix_multiply_accumulate +* cl_intel_subgroups +* cl_intel_subgroups_char + +## Command Line Options + +| Option | Default Value | Description | +|:--|:-:|:--| +| `-p ` | 0 | Specify the index of the OpenCL platform to execute the sample on. +| `-d ` | 0 | Specify the index of the OpenCL device in the platform to execute on the sample on. +| `--file ` | `matrix_kernels_bf16.cl` | Specify the name of the file with the OpenCL kernel source. +| `--options ` | None | Specify optional program build options. +| `--matrixsize ` | 512 | Specify the dimensions of the matrix. +| `--iterations ` | 16 | Specify the number of iterations for performance testing. +| `--validate` | n/a | Validate results for correctness. +| `--zero` | n/a | Initialize all matrices to zero. +| `--identity` | n/a | Initialize all matrices to to one. +| `--fixed` | n/a | Initialize all matrices to values computed from the matrix row and column. +| `--emulate` | n/a | Do not use specialized matrix multiplication extensions. +| `--wallclock` | n/a | Measure performance using wallclock time instead of event profiling. +| `--skipinit` | n/a | Skip initialization of source matrices. +| `--roundrobin` | n/a | Use round robin thread scheduling. +| `--threshold ` | 0.01 | Set the threshold used when validating results. +| `--mask ` | ~0 | Set a mask to only run a subset of tests. + +By default, the source matrices are populated with random data. +When validating results, it is recommended to use either "fixed" or "identity" data. +For best performance, use "zero" data". diff --git a/samples/20_matrixexperiments-i8/main.cpp b/samples/20_matrixexperiments-i8/main.cpp new file mode 100644 index 00000000..8bea9828 --- /dev/null +++ b/samples/20_matrixexperiments-i8/main.cpp @@ -0,0 +1,647 @@ +/* +// Copyright (c) 2024-2026 Ben Ashbaugh +// +// SPDX-License-Identifier: MIT +*/ + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "util.hpp" + +using test_clock = std::chrono::high_resolution_clock; + +bool zeroData = false; +bool identityData = false; +bool fixedData = false; +bool validate = false; +bool emulate = false; +bool wallclock = false; +bool skipinit = false; +bool roundRobin = false; +int testIterations = 16; +float threshold = 0.01f; + +std::string makeTestName( + const std::string &func, + size_t M, size_t N, size_t K) +{ + std::ostringstream ret; + ret << func; + ret << " (M=" << M << ", N=" << N << ", K=" << K << ")"; + return ret.str(); +} + +std::string makeTestName( + const std::string &func, + int tM, int tN, + size_t M, size_t N, size_t K) +{ + std::ostringstream ret; + ret << func; + ret << ""; + ret << " (M=" << M << ", N=" << N << ", K=" << K << ")"; + return ret.str(); +} + +std::string makeTestName( + const std::string &func, + int tM, int tN, + int MM, int NN, + size_t M, size_t N, size_t K) +{ + std::ostringstream ret; + ret << func; + ret << ""; + ret << " (M=" << M << ", N=" << N << ", K=" << K << ")"; + return ret.str(); +} + +static size_t findMinSubGroupSize(cl::Device& device) +{ + auto s = device.getInfo(); + auto it = std::min_element(std::begin(s), std::end(s)); + if (it != std::end(s)) { + return *it; + } + return 0; +} + +static bool supportsSubgroupSize(cl::Device& device, size_t subgroupSize) +{ + auto s = device.getInfo(); + return std::find(std::begin(s), std::end(s), subgroupSize) != std::end(s); +} + +static void setRoundRobin(cl::Kernel& kernel) +{ + constexpr cl_kernel_exec_info CL_KERNEL_EXEC_INFO_THREAD_ARBITRATION_POLICY_INTEL = 0x10025; + constexpr cl_uint CL_KERNEL_EXEC_INFO_THREAD_ARBITRATION_POLICY_ROUND_ROBIN_INTEL = 0x10023; + const cl_uint policy = CL_KERNEL_EXEC_INFO_THREAD_ARBITRATION_POLICY_ROUND_ROBIN_INTEL; + clSetKernelExecInfo( + kernel(), + CL_KERNEL_EXEC_INFO_THREAD_ARBITRATION_POLICY_INTEL, + sizeof(policy), + &policy); +} + +template +static void fill_matrix(std::vector& M, size_t numRows, size_t numCols) +{ + if (zeroData) { + std::generate(std::begin(M), std::end(M), [&]{ return 0; }); + } + else if (identityData) { + std::generate(std::begin(M), std::end(M), [&]{ return 1; }); + } else if (fixedData) { + for (size_t r = 0; r < numRows; r++) { + for (size_t c = 0; c < numCols; c++) { + M[r * numCols + c] = static_cast(r + c); + } + } + } else { + std::random_device dev; + std::mt19937 rng(dev()); + std::uniform_int_distribution dist(-64, 64); + std::generate(std::begin(M), std::end(M), [&]{ return dist(rng); }); + } +} + +template +static void vnni_matrix( + std::vector &dst, const std::vector &src, + size_t numRows, size_t numCols, size_t factor) +{ + for (size_t r = 0; r < numRows / factor; r++) { + for (size_t c = 0; c < numCols; c++) { + for (size_t k = 0; k < factor; k++) { + dst[r * numCols * factor + c * factor + k] = + src[(r * factor + k) * numCols + c]; + } + } + } +} + +template +static void compute_reference( + std::vector& C, + const std::vector& A, const std::vector& B, + size_t M, size_t N, size_t K) +{ + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + DstT sum = 0; + for (size_t k = 0; k < K; k++) { + sum = A[m * K + k] * B[k * N + n] + sum; + } + C[m * N + n] = sum; + } + } +} + +template +void check_results( + size_t M, + size_t N, + const std::vector& C, + const std::vector& C_ref) +{ + float err = 0.f; + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + auto index = m * N + n; + if (C[index] != C_ref[index]) { + std::cerr << "Error at m = " << m << ", n = " << n + << ": Wanted " + << C_ref[index] << ", got " << C[index] << std::endl; + return; + } + } + } +} + +static float hw_time(cl::Event& event) +{ + auto ns = event.getProfilingInfo() - + event.getProfilingInfo(); + return ns / 1e9f; +} + +static cl::NDRange getRequiredLocalWorkSize(cl::Kernel& kernel, cl::CommandQueue queue) +{ + // Note: This shouldn't be necessary, and the OpenCL implementation should + // automatically choose the required local work-group size when the local + // work-group size is `nullptr`. This is not working for some OpenCL + // implementations, though, so we will just query and use the required local + // work-group size explicitly. + auto device = queue.getInfo(); + auto reqd_wgs = kernel.getWorkGroupInfo(device); + + if (reqd_wgs[0] > 0 && reqd_wgs[1] > 0 && reqd_wgs[2] > 0) { + return cl::NDRange(reqd_wgs[0], reqd_wgs[1], reqd_wgs[2]); + } + + return cl::NullRange; +} + +static void i8_naive( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, M, N, K).c_str()); fflush(stdout); + + cl::Kernel kernel{program, "i8_naive"}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N, M}, localWorkSize, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... "); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void i8_dpas_rowmajor( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "i8_dpas_rowmajor"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N, M/tM}, localWorkSize, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... "); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void i8_dpas_vnni( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "i8_dpas_vnni"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N, M/tM}, localWorkSize, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... "); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void i8_dpas_blockread_rowmajor( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "i8_dpas_blockread_rowmajor"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else if (K < 64 || N < 64) { + printf("matrix pitch for block reads must be >= 64 bytes.\n"); + } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + if (roundRobin) { + setRoundRobin(kernel); + } + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N, M/tM}, localWorkSize, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... "); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void i8_dpas_blockread_vnni( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "i8_dpas_blockread_vnni"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else if (K < 64 || N < 64/4) { + printf("matrix pitch for block reads must be >= 64 bytes.\n"); + } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + if (roundRobin) { + setRoundRobin(kernel); + } + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N, M/tM}, localWorkSize, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... "); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +int main(int argc, char** argv) +{ + int platformIndex = 0; + int deviceIndex = 0; + + std::string fileName("matrix_kernels_i8.cl"); + std::string buildOptions; + size_t matrixSize = 512; + + size_t mask = ~0; + + { + popl::OptionParser op("Supported Options"); + op.add>("p", "platform", "Platform Index", platformIndex, &platformIndex); + op.add>("d", "device", "Device Index", deviceIndex, &deviceIndex); + op.add>("", "file", "Kernel File Name", fileName, &fileName); + op.add>("", "options", "Program Build Options", buildOptions, &buildOptions); + op.add>("m", "matrixsize", "Matrix Size", matrixSize, &matrixSize); + op.add>("i", "iterations", "Test Iterations", testIterations, &testIterations); + op.add("", "validate", "Validate Results", &validate); + op.add("", "zero", "Use Zero Data", &zeroData); + op.add("", "identity", "Use Identity Data", &identityData); + op.add("", "fixed", "Use Fixed Data", &fixedData); + op.add("", "emulate", "Unconditionally Emulate dpas", &emulate); + op.add("", "wallclock", "Measure Wallclock Time", &wallclock); + op.add("", "skipinit", "Do Not Initialize Buffers", &skipinit); + op.add("", "roundrobin", "Use Round Robin Scheduling", &roundRobin); + op.add>("", "threshold", "Local Error Threshold", threshold, &threshold); + op.add, popl::Attribute::advanced>("", "mask", "Test Mask", mask, &mask); + bool printUsage = false; + try { + op.parse(argc, argv); + } catch (std::exception& e) { + fprintf(stderr, "Error: %s\n\n", e.what()); + printUsage = true; + } + + if (printUsage || !op.unknown_options().empty() || !op.non_option_args().empty()) { + fprintf(stderr, + "Usage: matrixexperiments-i8 [options]\n" + "%s", op.help().c_str()); + return -1; + } + } + + std::vector platforms; + cl::Platform::get(&platforms); + + if (!checkPlatformIndex(platforms, platformIndex)) { + return -1; + } + + printf("Running on platform: %s\n", + platforms[platformIndex].getInfo().c_str() ); + + std::vector devices; + platforms[platformIndex].getDevices(CL_DEVICE_TYPE_ALL, &devices); + if (deviceIndex >= devices.size()) { + printf("Requested device index is %d, but only %zu devices were found.\n", + deviceIndex, devices.size()); + } + + cl::Device& device = devices[deviceIndex]; + printf("Running on device: %s (%uCUs, %uMHz)\n", + device.getInfo().c_str(), + device.getInfo(), + device.getInfo()); + printf("Running on drivers: %s\n", + device.getInfo().c_str()); + + auto minSubGroupSize = findMinSubGroupSize(device); + + bool has_sg8 = supportsSubgroupSize(device, 8); + bool emulate_tN8 = true; + bool emulate_tN16 = true; + if (!emulate && checkDeviceForExtension(device, "cl_intel_subgroup_matrix_multiply_accumulate")) { + printf("Found support for cl_intel_subgroup_matrix_multiply_accumulate, min sub-group size is: %zu\n", minSubGroupSize); + switch(minSubGroupSize) { + case 8: emulate_tN8 = false; break; + case 16: emulate_tN16 = false; break; + default: break; + } + } + + buildOptions += " -DHAS_SG8=" + std::to_string(has_sg8); + buildOptions += " -DEMULATE_tN8=" + std::to_string(emulate_tN8); + buildOptions += " -DEMULATE_tN16=" + std::to_string(emulate_tN16); + + printf("Config:\n"); + printf("\tTest Iterations: %d\n", testIterations); + printf("\tValidating data?: %s\n", validate ? "true" : "false"); + printf("\tFixed data?: %s\n", fixedData ? "true" : "false"); + printf("\tWallclock time?: %s\n", wallclock ? "true" : "false"); + printf("\tEmulate dpas for tN=8?: %s\n", emulate_tN8 ? "true" : "false"); + printf("\tEmulate dpas for tN=16?: %s\n", emulate_tN16 ? "true" : "false"); + + cl::Context context{device}; + cl::CommandQueue queue{context, device, CL_QUEUE_PROFILING_ENABLE}; + + printf("Reading program source from file: %s\n", fileName.c_str() ); + std::string kernelString = readStringFromFile(fileName.c_str()); + + printf("Building program with build options: %s\n", + buildOptions.empty() ? "(none)" : buildOptions.c_str() ); + cl::Program program{ context, kernelString }; + program.build(buildOptions.c_str()); + for( auto& device : program.getInfo() ) + { + printf("Program build log for device %s:\n", + device.getInfo().c_str() ); + printf("%s\n", + program.getBuildInfo(device).c_str() ); + } + + const auto M = matrixSize; + const auto N = matrixSize; + const auto K = matrixSize; + + std::vector A_vec(M * K); + std::vector B_vec(K * N); + std::vector Bvnni_vec(K * N); + + std::vector C_ref(M * N); + + printf("Initializing source matrices...\n"); + fill_matrix(A_vec, M, K); + fill_matrix(B_vec, K, N); + + vnni_matrix(Bvnni_vec, B_vec, K, N, 4); + + if (validate) { + printf("Computing reference...\n"); + compute_reference(C_ref, A_vec, B_vec, M, N, K); + } + + printf("Creating source buffers...\n"); + cl::Buffer A{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, A_vec.size() * sizeof(A_vec[0]), A_vec.data()}; + cl::Buffer B{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, B_vec.size() * sizeof(B_vec[0]), B_vec.data()}; + cl::Buffer Bvnni{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, Bvnni_vec.size() * sizeof(Bvnni_vec[0]), Bvnni_vec.data()}; + cl::Buffer C{context, CL_MEM_WRITE_ONLY, C_ref.size() * sizeof(C_ref[0])}; + + printf("Running tests...\n"); + + if (mask & 0x1) { + i8_naive(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x2) { + i8_dpas_rowmajor<1, 8>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor<2, 8>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor<4, 8>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor<8, 8>(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x20) { + i8_dpas_rowmajor<1, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor<2, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor<4, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor<8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x80) { + i8_dpas_vnni<1, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni<2, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni<4, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni<8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + } + + if (mask & 0x200) { + i8_dpas_blockread_rowmajor<1, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_blockread_rowmajor<2, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_blockread_rowmajor<4, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_blockread_rowmajor<8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x800) { + i8_dpas_blockread_vnni<1, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_blockread_vnni<2, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_blockread_vnni<4, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_blockread_vnni<8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + } + + printf("Done.\n"); + + return 0; +} diff --git a/samples/20_matrixexperiments-i8/matrix_helpers_i8.cl b/samples/20_matrixexperiments-i8/matrix_helpers_i8.cl new file mode 100644 index 00000000..9aa6ffcb --- /dev/null +++ b/samples/20_matrixexperiments-i8/matrix_helpers_i8.cl @@ -0,0 +1,590 @@ +/* +// Copyright (c) 2024-2026 Ben Ashbaugh +// +// SPDX-License-Identifier: MIT +*/ + +__attribute__((overloadable)) +int activation(int i) +{ +#if defined(ACTIVATION_RELU) + return max(i, 0); +#else // identity + return i; +#endif +} + +__attribute__((overloadable)) +int2 activation(int2 i) +{ + int2 res; + res.s0 = activation(i.s0); + res.s1 = activation(i.s1); + return res; +} + +__attribute__((overloadable)) +int4 activation(int4 i) +{ + int4 res; + res.s0 = activation(i.s0); + res.s1 = activation(i.s1); + res.s2 = activation(i.s2); + res.s3 = activation(i.s3); + return res; +} + +int8 activation(int8 i) +{ + int8 res; + res.s0 = activation(i.s0); + res.s1 = activation(i.s1); + res.s2 = activation(i.s2); + res.s3 = activation(i.s3); + res.s4 = activation(i.s4); + res.s5 = activation(i.s5); + res.s6 = activation(i.s6); + res.s7 = activation(i.s7); + return res; +} + +#ifndef __has_builtin +#define __has_builtin(x) 0 +#endif +#if __has_builtin(__builtin_expect) == 0 +#define __builtin_expect(x) +#endif + +#if defined(__opencl_c_integer_dot_product_input_4x8bit_packed) +#define dp4 dot_4x8packed_ss_int +#else +#define dp4 emu_dot_4x8packed_ss_int + +int emu_dot_4x8packed_ss_int(const uint a, const uint b) +{ + const char4 a_c4 = as_char4(a); + const char4 b_c4 = as_char4(b); + + return a_c4.x * b_c4.x + + a_c4.y * b_c4.y + + a_c4.z * b_c4.z + + a_c4.w * b_c4.w; +} +#endif + +#if defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_subgroups_char) + +typedef global char* global_aligned_char_ptr __attribute__((align_value(4))); + +inline int compute_m(const int num_sgs_x, const int num_sgs_y, const int tM, const int MM) +{ + const int m_start = get_group_id(1) * num_sgs_y; + const int m_index = num_sgs_y > 1 ? m_start + get_sub_group_id() / num_sgs_x : m_start; + return m_index * tM * MM; +} + +inline int compute_n(const int num_sgs_x, const int num_sgs_y, const int tN, const int NN) +{ + const int n_start = get_group_id(0) * num_sgs_x; + const int n_index = num_sgs_x > 1 ? n_start + get_sub_group_id() % num_sgs_x : n_start; + return n_index * tN * NN; +} + +// Emulated SIMD8 dpas: +__attribute__((overloadable)) +int emu_sub_group_i8_i8_matrix_mad_k32(int a, int8 b, int acc) +{ + int res = acc; + + res = dp4(sub_group_broadcast(a, 0), b.s0) + res; + res = dp4(sub_group_broadcast(a, 1), b.s1) + res; + res = dp4(sub_group_broadcast(a, 2), b.s2) + res; + res = dp4(sub_group_broadcast(a, 3), b.s3) + res; + res = dp4(sub_group_broadcast(a, 4), b.s4) + res; + res = dp4(sub_group_broadcast(a, 5), b.s5) + res; + res = dp4(sub_group_broadcast(a, 6), b.s6) + res; + res = dp4(sub_group_broadcast(a, 7), b.s7) + res; + + return res; +} + +__attribute__((overloadable)) +int2 emu_sub_group_i8_i8_matrix_mad_k32(int2 a, int8 b, int2 acc) +{ + int2 res; + + res.s0 = emu_sub_group_i8_i8_matrix_mad_k32(a.s0, b, acc.s0); + res.s1 = emu_sub_group_i8_i8_matrix_mad_k32(a.s1, b, acc.s1); + + return res; +} + +__attribute__((overloadable)) +int4 emu_sub_group_i8_i8_matrix_mad_k32(int4 a, int8 b, int4 acc) +{ + int4 res; + + res.s0 = emu_sub_group_i8_i8_matrix_mad_k32(a.s0, b, acc.s0); + res.s1 = emu_sub_group_i8_i8_matrix_mad_k32(a.s1, b, acc.s1); + res.s2 = emu_sub_group_i8_i8_matrix_mad_k32(a.s2, b, acc.s2); + res.s3 = emu_sub_group_i8_i8_matrix_mad_k32(a.s3, b, acc.s3); + + return res; +} + +__attribute__((overloadable)) +int8 emu_sub_group_i8_i8_matrix_mad_k32(int8 a, int8 b, int8 acc) +{ + int8 res; + + res.s0 = emu_sub_group_i8_i8_matrix_mad_k32(a.s0, b, acc.s0); + res.s1 = emu_sub_group_i8_i8_matrix_mad_k32(a.s1, b, acc.s1); + res.s2 = emu_sub_group_i8_i8_matrix_mad_k32(a.s2, b, acc.s2); + res.s3 = emu_sub_group_i8_i8_matrix_mad_k32(a.s3, b, acc.s3); + res.s4 = emu_sub_group_i8_i8_matrix_mad_k32(a.s4, b, acc.s4); + res.s5 = emu_sub_group_i8_i8_matrix_mad_k32(a.s5, b, acc.s5); + res.s6 = emu_sub_group_i8_i8_matrix_mad_k32(a.s6, b, acc.s6); + res.s7 = emu_sub_group_i8_i8_matrix_mad_k32(a.s7, b, acc.s7); + + return res; +} + +// Emulated SIMD16 dpas: +__attribute__((overloadable)) +int emu_sub_group_i8_i8_matrix_mad_k32(short a, int8 b, int acc) +{ + float res = acc; + + res = dp4(as_uint((short2)(sub_group_broadcast(a, 0), sub_group_broadcast(a, 1))), b.s0) + res; + res = dp4(as_uint((short2)(sub_group_broadcast(a, 2), sub_group_broadcast(a, 3))), b.s1) + res; + res = dp4(as_uint((short2)(sub_group_broadcast(a, 4), sub_group_broadcast(a, 5))), b.s2) + res; + res = dp4(as_uint((short2)(sub_group_broadcast(a, 6), sub_group_broadcast(a, 7))), b.s3) + res; + res = dp4(as_uint((short2)(sub_group_broadcast(a, 8), sub_group_broadcast(a, 9))), b.s4) + res; + res = dp4(as_uint((short2)(sub_group_broadcast(a, 10), sub_group_broadcast(a, 11))), b.s5) + res; + res = dp4(as_uint((short2)(sub_group_broadcast(a, 12), sub_group_broadcast(a, 13))), b.s6) + res; + res = dp4(as_uint((short2)(sub_group_broadcast(a, 14), sub_group_broadcast(a, 15))), b.s7) + res; + + return res; +} + +__attribute__((overloadable)) +int2 emu_sub_group_i8_i8_matrix_mad_k32(short2 a, int8 b, int2 acc) +{ + int2 res; + + res.s0 = emu_sub_group_i8_i8_matrix_mad_k32(a.s0, b, acc.s0); + res.s1 = emu_sub_group_i8_i8_matrix_mad_k32(a.s1, b, acc.s1); + + return res; +} + +__attribute__((overloadable)) +int4 emu_sub_group_i8_i8_matrix_mad_k32(short4 a, int8 b, int4 acc) +{ + int4 res; + + res.s0 = emu_sub_group_i8_i8_matrix_mad_k32(a.s0, b, acc.s0); + res.s1 = emu_sub_group_i8_i8_matrix_mad_k32(a.s1, b, acc.s1); + res.s2 = emu_sub_group_i8_i8_matrix_mad_k32(a.s2, b, acc.s2); + res.s3 = emu_sub_group_i8_i8_matrix_mad_k32(a.s3, b, acc.s3); + + return res; +} + +__attribute__((overloadable)) +int8 emu_sub_group_i8_i8_matrix_mad_k32(short8 a, int8 b, int8 acc) +{ + int8 res; + + res.s0 = emu_sub_group_i8_i8_matrix_mad_k32(a.s0, b, acc.s0); + res.s1 = emu_sub_group_i8_i8_matrix_mad_k32(a.s1, b, acc.s1); + res.s2 = emu_sub_group_i8_i8_matrix_mad_k32(a.s2, b, acc.s2); + res.s3 = emu_sub_group_i8_i8_matrix_mad_k32(a.s3, b, acc.s3); + res.s4 = emu_sub_group_i8_i8_matrix_mad_k32(a.s4, b, acc.s4); + res.s5 = emu_sub_group_i8_i8_matrix_mad_k32(a.s5, b, acc.s5); + res.s6 = emu_sub_group_i8_i8_matrix_mad_k32(a.s6, b, acc.s6); + res.s7 = emu_sub_group_i8_i8_matrix_mad_k32(a.s7, b, acc.s7); + + return res; +} + +// M rows x K columns +// This is the SIMD8 version, where each work-item loads four values. +int load_a_rowmajor_d8_m1_k32_sg8(global char* A, int rowStart, int colStart, int stride) +{ + int ret; + + global uint* A_ui = (global uint*)A; + uint offset_ui = rowStart * stride / 4 + colStart / 4; + ret = intel_sub_group_block_read(A_ui + offset_ui); + + return ret; +} + +// M rows x K columns +// This is the SIMD8 version, where each work-item loads four values. +int2 load_a_rowmajor_d8_m2_k32_sg8(global char* A, int rowStart, int colStart, int stride) +{ + int2 ret; + + global uint* A_ui = (global uint*)A; + uint offset_ui = rowStart * stride / 4 + colStart / 4; + + ret.s0 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 4; + ret.s1 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 4; + + return ret; +} + +// M rows x K columns +// This is the SIMD8 version, where each work-item loads four values. +int4 load_a_rowmajor_d8_m4_k32_sg8(global char* A, int rowStart, int colStart, int stride) +{ + int4 ret; + + global uint* A_ui = (global uint*)A; + uint offset_ui = rowStart * stride / 4 + colStart / 4; + + ret.s0 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 4; + ret.s1 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 4; + ret.s2 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 4; + ret.s3 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 4; + + return ret; +} + +// M rows x K columns +// This is the SIMD8 version, where each work-item loads four values. +int8 load_a_rowmajor_d8_m8_k32_sg8(global char* A, int rowStart, int colStart, int stride) +{ + int8 ret; + + global uint* A_ui = (global uint*)A; + uint offset_ui = rowStart * stride / 4 + colStart / 4; + + ret.s0 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 4; + ret.s1 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 4; + ret.s2 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 4; + ret.s3 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 4; + ret.s4 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 4; + ret.s5 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 4; + ret.s6 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 4; + ret.s7 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 4; + + return ret; +} + +#if 0 + +// M rows x K columns x V tiles (in the K dimension) +// This is the SIMD8 version, where each work-item loads two values. +// The first tile is returned the first components of the return value, the the next tile, etc. +int16 load_a_rowmajor_d16_m8_k16v2_sg8(global ushort* A, int rowStart, int colStart, int stride) +{ + uint16 ret; + + global uint* A_ui = (global uint*)A; + uint offset_ui = rowStart * stride / 2 + colStart / 2; + + ret.s08 = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s19 = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s2a = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s3b = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s4c = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s5d = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s6e = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s7f = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + + return as_int16(ret); +} + +// M rows x K columns x V tiles (in the K dimension) +void prefetch_a_rowmajor_d16_m8_k16v2_sg8(global ushort* A, int rowStart, int colStart, int stride) +{ +#if defined(PREFETCH_DEFAULT) + uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; + __builtin_assume((ulong)(A + offset) % 4 == 0); + prefetch(A + offset, 2); +#endif // defined(PREFETCH_DEFAULT) +} + +#endif + +// M rows x K columns +// This is the SIMD16 version, where each work-item loads two values. +short load_a_rowmajor_d8_m1_k32_sg16(global char* A, int rowStart, int colStart, int stride) +{ + ushort ret; + + global ushort* A_us = (global ushort*)A; + uint offset_us = rowStart * stride / 2 + colStart / 2; + + ret = intel_sub_group_block_read_us(A_us + offset_us); + + return as_short(ret); +} + +// M rows x K columns +// This is the SIMD16 version, where each work-item loads two values. +short2 load_a_rowmajor_d8_m2_k32_sg16(global char* A, int rowStart, int colStart, int stride) +{ + ushort2 ret; + + global ushort* A_us = (global ushort*)A; + uint offset_us = rowStart * stride / 2 + colStart / 2; + + ret.s0 = intel_sub_group_block_read_us(A_us + offset_us); offset_us += stride / 2; + ret.s1 = intel_sub_group_block_read_us(A_us + offset_us); offset_us += stride / 2; + + return as_short2(ret); +} + +// M rows x K columns +// This is the SIMD16 version, where each work-item loads two values. +short4 load_a_rowmajor_d8_m4_k32_sg16(global char* A, int rowStart, int colStart, int stride) +{ + ushort4 ret; + + global ushort* A_us = (global ushort*)A; + uint offset_us = rowStart * stride / 2 + colStart / 2; + + ret.s0 = intel_sub_group_block_read_us(A_us + offset_us); offset_us += stride / 2; + ret.s1 = intel_sub_group_block_read_us(A_us + offset_us); offset_us += stride / 2; + ret.s2 = intel_sub_group_block_read_us(A_us + offset_us); offset_us += stride / 2; + ret.s3 = intel_sub_group_block_read_us(A_us + offset_us); offset_us += stride / 2; + + return as_short4(ret); +} + +// M rows x K columns +// This is the SIMD16 version, where each work-item loads two values. +short8 load_a_rowmajor_d8_m8_k32_sg16(global char* A, int rowStart, int colStart, int stride) +{ + ushort8 ret; + + global ushort* A_us = (global ushort*)A; + uint offset_us = rowStart * stride / 2 + colStart / 2; + + ret.s0 = intel_sub_group_block_read_us(A_us + offset_us); offset_us += stride / 2; + ret.s1 = intel_sub_group_block_read_us(A_us + offset_us); offset_us += stride / 2; + ret.s2 = intel_sub_group_block_read_us(A_us + offset_us); offset_us += stride / 2; + ret.s3 = intel_sub_group_block_read_us(A_us + offset_us); offset_us += stride / 2; + ret.s4 = intel_sub_group_block_read_us(A_us + offset_us); offset_us += stride / 2; + ret.s5 = intel_sub_group_block_read_us(A_us + offset_us); offset_us += stride / 2; + ret.s6 = intel_sub_group_block_read_us(A_us + offset_us); offset_us += stride / 2; + ret.s7 = intel_sub_group_block_read_us(A_us + offset_us); offset_us += stride / 2; + + return as_short8(ret); +} + +#if 0 + +// M rows x K columns x V tiles (in the K dimension) +// This is the SIMD16 version, where each work-item loads one value. +// The first tile is returned the first components of the return value, the the next tile, etc. +short16 load_a_rowmajor_d16_m8_k16v2_sg16(global ushort* A, int rowStart, int colStart, int stride) +{ + ushort16 ret; + + uint offset = rowStart * stride + colStart; + ret.s08 = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s19 = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s2a = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s3b = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s4c = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s5d = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s6e = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s7f = intel_sub_group_block_read_us2(A + offset); offset += stride; + + return as_short16(ret); +} + +// M rows x K columns x V tiles (in the M and K dimensions) +void prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(global ushort* A, int rowStart, int colStart, int stride) +{ +#if defined(PREFETCH_DEFAULT) + uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; + __builtin_assume((ulong)(A + offset) % 4 == 0); + prefetch(A + offset, 2); +#endif // defined(PREFETCH_DEFAULT) +} + +#endif + +// K rows x N columns: +// Each work-item loads K values and packs into 32-bits. +// Stride is in units of elements. +int8 load_b_rowmajor_8b_32rNc(global char* B, int rowStart, int colStart, int stride) +{ + int8 ret; + + global uchar* B_uc = (global uchar*)B; + uint offset = rowStart * stride + colStart; + + uchar row0 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row1 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row2 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row3 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row4 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row5 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row6 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row7 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row8 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row9 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row10 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row11 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row12 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row13 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row14 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row15 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row16 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row17 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row18 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row19 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row20 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row21 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row22 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row23 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row24 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row25 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row26 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row27 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row28 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row29 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row30 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + uchar row31 = intel_sub_group_block_read_uc(B_uc + offset); offset += stride; + + ret.s0 = as_int((uchar4)(row0, row1, row2, row3)); + ret.s1 = as_int((uchar4)(row4, row5, row6, row7)); + ret.s2 = as_int((uchar4)(row8, row9, row10, row11)); + ret.s3 = as_int((uchar4)(row12, row13, row14, row15)); + ret.s4 = as_int((uchar4)(row16, row17, row18, row19)); + ret.s5 = as_int((uchar4)(row20, row21, row22, row23)); + ret.s6 = as_int((uchar4)(row24, row25, row26, row27)); + ret.s7 = as_int((uchar4)(row28, row29, row30, row31)); + + return ret; +} + +// K rows x N columns: +// Each work-item loads K values that has already been converted to VNNI. +// Stride is in units of elements. +int8 load_b_vnni_d8_k32_nx(global char* B, int rowStart, int colStart, int stride) +{ + int8 ret; + + global uint* B_ui = (global uint*)B; + uint offset_ui = rowStart / 4 * stride + colStart; + + ret.s0 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s1 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s2 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s3 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s4 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s5 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s6 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s7 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + + return ret; +} + +#if 0 + +// K rows x N columns x V tiles (in the N dimension) +void prefetch_b_rowmajor_d16_k16_n8v4_sg8(global ushort* B, int rowStart, int colStart, int stride) +{ +#if defined(PREFETCH_DEFAULT) + uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; + __builtin_assume((ulong)(B + offset) % 4 == 0); + prefetch(B + offset, 2); offset += 8 * stride; + __builtin_assume((ulong)(B + offset) % 4 == 0); + prefetch(B + offset, 2); offset += 8 * stride; +#endif // defined(PREFETCH_DEFAULT) +} + +// K rows x N columns x V tiles (in the N dimension) +void prefetch_b_rowmajor_d16_k16_n16v2_sg16(global ushort* B, int rowStart, int colStart, int stride) +{ +#if defined(PREFETCH_DEFAULT) + uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; + __builtin_assume((ulong)(B + offset) % 4 == 0); + prefetch(B + offset, 2); +#endif // defined(PREFETCH_DEFAULT) +} + +// K rows x N columns x V tiles (in the N dimension) +void prefetch_b_vnni_d16_k16_n8v2_sg8(global ushort* B, int rowStart, int colStart, int stride) +{ +#if defined(PREFETCH_DEFAULT) + global uint* B_ui = (global uint*)B; + uint offset_ui = colStart + (rowStart / 2 + get_sub_group_local_id()) * stride; + __builtin_assume((ulong)(B_ui + offset_ui) % 4 == 0); + prefetch(B_ui + offset_ui, 1); +#endif // defined(PREFETCH_DEFAULT) +} + +// K rows x N columns x V tiles (in the K dimension) +void prefetch_b_vnni_d16_k16v2_n16_sg16(global ushort* B, int rowStart, int colStart, int stride) +{ +#if defined(PREFETCH_DEFAULT) + global uint* B_ui = (global uint*)B; + uint offset_ui = colStart + (rowStart / 2 + get_sub_group_local_id()) * stride; + __builtin_assume((ulong)(B_ui + offset_ui) % 4 == 0); + prefetch(B_ui + offset_ui, 1); +#endif // defined(PREFETCH_DEFAULT) +} + +#endif + +void store_c_rowmajor_int32_m1_nx(global int* C, int v, int rowStart, int colStart, int stride) +{ + global uint* C_ui = (global uint*)C; + uint v_ui = as_uint(v); + + uint offset = rowStart * stride + colStart; + + intel_sub_group_block_write(C_ui + offset, v_ui); offset += stride; +} + +void store_c_rowmajor_int32_m2_nx(global int* C, int2 v, int rowStart, int colStart, int stride) +{ + global uint* C_ui = (global uint*)C; + uint2 v_ui = as_uint2(v); + + uint offset = rowStart * stride + colStart; + + intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride; +} + +void store_c_rowmajor_int32_m4_nx(global int* C, int4 v, int rowStart, int colStart, int stride) +{ + global uint* C_ui = (global uint*)C; + uint4 v_ui = as_uint4(v); + + uint offset = rowStart * stride + colStart; + + intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s2); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s3); offset += stride; +} + +void store_c_rowmajor_int32_m8_nx(global int* C, int8 v, int rowStart, int colStart, int stride) +{ + global uint* C_ui = (global uint*)C; + uint8 v_ui = as_uint8(v); + + uint offset = rowStart * stride + colStart; + + intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s2); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s3); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s4); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s5); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s6); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s7); offset += stride; +} + +#endif // defined(cl_intel_subgroups) && defined(cl_intel_subgroups_char) diff --git a/samples/20_matrixexperiments-i8/matrix_kernels_i8.cl b/samples/20_matrixexperiments-i8/matrix_kernels_i8.cl new file mode 100644 index 00000000..6bec5f2f --- /dev/null +++ b/samples/20_matrixexperiments-i8/matrix_kernels_i8.cl @@ -0,0 +1,587 @@ +/* +// Copyright (c) 2024-2026 Ben Ashbaugh +// +// SPDX-License-Identifier: MIT +*/ + +#include "matrix_helpers_i8.cl" + +#if EMULATE_tN8 +#define mat_mul_sg8 emu_sub_group_i8_i8_matrix_mad_k32 +#else +#define mat_mul_sg8 intel_sub_group_i8_i8_matrix_mad_k32 +#endif + +#if EMULATE_tN16 +#define mat_mul_sg16 emu_sub_group_i8_i8_matrix_mad_k32 +#else +#define mat_mul_sg16 intel_sub_group_i8_i8_matrix_mad_k32 +#endif + +kernel void i8_naive(global int* C, global char* A, global char* B, int K) +{ + const int N = get_global_size(0); + const int m = get_global_id(1); + const int n = get_global_id(0); + + int sum = 0; + for (int k = 0; k < K; k++) { + sum = A[m * K + k] * B[k * N + n] + sum; + } + + sum = activation(sum); + C[m * N + n] = sum; +} + +// For all i8 kernels tK == 32: +#define tK 32 + +#if defined(cl_intel_subgroups) && defined(cl_intel_subgroups_char) && defined(cl_intel_required_subgroup_size) + +#if HAS_SG8 + +// rowmajor kernels: + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +kernel void i8_dpas_rowmajor_m1_n8(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 1; + const int tN = 8; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int sum = 0; + for (int k = 0; k < K; k += tK) { + int aData = load_a_rowmajor_d8_m1_k32_sg8(A, m, k, K); + int8 bData = load_b_rowmajor_8b_32rNc(B, k, n, N); + sum = mat_mul_sg8(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m1_nx(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +kernel void i8_dpas_rowmajor_m2_n8(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 2; + const int tN = 8; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int2 sum = 0; + for (int k = 0; k < K; k += tK) { + int2 aData = load_a_rowmajor_d8_m2_k32_sg8(A, m, k, K); + int8 bData = load_b_rowmajor_8b_32rNc(B, k, n, N); + sum = mat_mul_sg8(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m2_nx(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +kernel void i8_dpas_rowmajor_m4_n8(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 4; + const int tN = 8; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int4 sum = 0; + for (int k = 0; k < K; k += tK) { + int4 aData = load_a_rowmajor_d8_m4_k32_sg8(A, m, k, K); + int8 bData = load_b_rowmajor_8b_32rNc(B, k, n, N); + sum = mat_mul_sg8(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m4_nx(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +kernel void i8_dpas_rowmajor_m8_n8(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 8; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int8 sum = 0; + for (int k = 0; k < K; k += tK) { + int8 aData = load_a_rowmajor_d8_m8_k32_sg8(A, m, k, K); + int8 bData = load_b_rowmajor_8b_32rNc(B, k, n, N); + sum = mat_mul_sg8(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m8_nx(C, sum, m, n, N); +} + +// vnni kernels: + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +kernel void i8_dpas_vnni_m1_n8(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 1; + const int tN = 8; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int sum = 0; + for (int k = 0; k < K; k += tK) { + int aData = load_a_rowmajor_d8_m1_k32_sg8(A, m, k, K); + int8 bData = load_b_vnni_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg8(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m1_nx(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +kernel void i8_dpas_vnni_m2_n8(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 2; + const int tN = 8; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int2 sum = 0; + for (int k = 0; k < K; k += tK) { + int2 aData = load_a_rowmajor_d8_m2_k32_sg8(A, m, k, K); + int8 bData = load_b_vnni_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg8(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m2_nx(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +kernel void i8_dpas_vnni_m4_n8(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 4; + const int tN = 8; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int4 sum = 0; + for (int k = 0; k < K; k += tK) { + int4 aData = load_a_rowmajor_d8_m4_k32_sg8(A, m, k, K); + int8 bData = load_b_vnni_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg8(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m4_nx(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +kernel void i8_dpas_vnni_m8_n8(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 8; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int8 sum = 0; + for (int k = 0; k < K; k += tK) { + int8 aData = load_a_rowmajor_d8_m8_k32_sg8(A, m, k, K); + int8 bData = load_b_vnni_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg8(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m8_nx(C, sum, m, n, N); +} + +#endif // HAS_SG8 + +// rowmajor kernels: + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_rowmajor_m1_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 1; + const int tN = 16; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int sum = 0; + for (int k = 0; k < K; k += tK) { + short aData = load_a_rowmajor_d8_m1_k32_sg16(A, m, k, K); + int8 bData = load_b_rowmajor_8b_32rNc(B, k, n, N); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m1_nx(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_rowmajor_m2_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 2; + const int tN = 16; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int2 sum = 0; + for (int k = 0; k < K; k += tK) { + short2 aData = load_a_rowmajor_d8_m2_k32_sg16(A, m, k, K); + int8 bData = load_b_rowmajor_8b_32rNc(B, k, n, N); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m2_nx(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_rowmajor_m4_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 4; + const int tN = 16; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int4 sum = 0; + for (int k = 0; k < K; k += tK) { + short4 aData = load_a_rowmajor_d8_m4_k32_sg16(A, m, k, K); + int8 bData = load_b_rowmajor_8b_32rNc(B, k, n, N); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m4_nx(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_rowmajor_m8_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 16; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int8 sum = 0; + for (int k = 0; k < K; k += tK) { + short8 aData = load_a_rowmajor_d8_m8_k32_sg16(A, m, k, K); + int8 bData = load_b_rowmajor_8b_32rNc(B, k, n, N); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m8_nx(C, sum, m, n, N); +} + +// vnni kernels: + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_vnni_m1_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 1; + const int tN = 16; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int sum = 0; + for (int k = 0; k < K; k += tK) { + short aData = load_a_rowmajor_d8_m1_k32_sg16(A, m, k, K); + int8 bData = load_b_vnni_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m1_nx(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_vnni_m2_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 2; + const int tN = 16; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int2 sum = 0; + for (int k = 0; k < K; k += tK) { + short2 aData = load_a_rowmajor_d8_m2_k32_sg16(A, m, k, K); + int8 bData = load_b_vnni_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m2_nx(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_vnni_m4_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 4; + const int tN = 16; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int4 sum = 0; + for (int k = 0; k < K; k += tK) { + short4 aData = load_a_rowmajor_d8_m4_k32_sg16(A, m, k, K); + int8 bData = load_b_vnni_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m4_nx(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_vnni_m8_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 16; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int8 sum = 0; + for (int k = 0; k < K; k += tK) { + short8 aData = load_a_rowmajor_d8_m8_k32_sg16(A, m, k, K); + int8 bData = load_b_vnni_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m8_nx(C, sum, m, n, N); +} + +#ifdef cl_intel_subgroup_2d_block_io + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_blockread_rowmajor_m1_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 1; + const int tN = 16; + const int M = get_global_size(1); + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int sum = 0; + for (int k = 0; k < K; k += tK) { + short aData; + intel_sub_group_2d_block_read_8b_1r32x1c(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m), (ushort*)&aData); + int8 bData; + intel_sub_group_2d_block_read_transform_8b_32r16x1c(B, N * sizeof(char), K, N * sizeof(char), (int2)(n, k), (uint*)&bData); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_sub_group_2d_block_write_32b_1r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_blockread_rowmajor_m2_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 2; + const int tN = 16; + const int M = get_global_size(1) * tM; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int2 sum = 0; + for (int k = 0; k < K; k += tK) { + short2 aData; + intel_sub_group_2d_block_read_8b_2r32x1c(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m), (ushort*)&aData); + int8 bData; + intel_sub_group_2d_block_read_transform_8b_32r16x1c(B, N * sizeof(char), K, N * sizeof(char), (int2)(n, k), (uint*)&bData); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_sub_group_2d_block_write_32b_2r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_blockread_rowmajor_m4_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 4; + const int tN = 16; + const int M = get_global_size(1) * tM; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int4 sum = 0; + for (int k = 0; k < K; k += tK) { + short4 aData; + intel_sub_group_2d_block_read_8b_4r32x1c(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m), (ushort*)&aData); + int8 bData; + intel_sub_group_2d_block_read_transform_8b_32r16x1c(B, N * sizeof(char), K, N * sizeof(char), (int2)(n, k), (uint*)&bData); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_sub_group_2d_block_write_32b_4r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_blockread_rowmajor_m8_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 16; + const int M = get_global_size(1) * tM; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int8 sum = 0; + for (int k = 0; k < K; k += tK) { + short8 aData; + intel_sub_group_2d_block_read_8b_8r32x1c(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m), (ushort*)&aData); + int8 bData; + intel_sub_group_2d_block_read_transform_8b_32r16x1c(B, N * sizeof(char), K, N * sizeof(char), (int2)(n, k), (uint*)&bData); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_sub_group_2d_block_write_32b_8r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_blockread_vnni_m1_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 1; + const int tN = 16; + const int M = get_global_size(1) * tM; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int sum = 0; + for (int k = 0; k < K; k += tK) { + short aData; + intel_sub_group_2d_block_read_8b_1r32x1c(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m), (ushort*)&aData); + int8 bData; + intel_sub_group_2d_block_read_32b_8r16x1c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 4), (uint*)&bData); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_sub_group_2d_block_write_32b_1r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_blockread_vnni_m2_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 2; + const int tN = 16; + const int M = get_global_size(1) * tM; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int2 sum = 0; + for (int k = 0; k < K; k += tK) { + short2 aData; + intel_sub_group_2d_block_read_8b_2r32x1c(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m), (ushort*)&aData); + int8 bData; + intel_sub_group_2d_block_read_32b_8r16x1c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 4), (uint*)&bData); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_sub_group_2d_block_write_32b_2r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_blockread_vnni_m4_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 4; + const int tN = 16; + const int M = get_global_size(1) * tM; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int4 sum = 0; + for (int k = 0; k < K; k += tK) { + short4 aData; + intel_sub_group_2d_block_read_8b_4r32x1c(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m), (ushort*)&aData); + int8 bData; + intel_sub_group_2d_block_read_32b_8r16x1c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 4), (uint*)&bData); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_sub_group_2d_block_write_32b_4r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_blockread_vnni_m8_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 16; + const int M = get_global_size(1) * tM; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int8 sum = 0; + for (int k = 0; k < K; k += tK) { + short8 aData; + intel_sub_group_2d_block_read_8b_8r32x1c(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m), (ushort*)&aData); + int8 bData; + intel_sub_group_2d_block_read_32b_8r16x1c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 4), (uint*)&bData); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_sub_group_2d_block_write_32b_8r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); +} + +#endif // cl_intel_subgroup_2d_block_io + +#endif // defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_required_subgroup_size) + +#undef tK diff --git a/samples/20_matrixexperiments-tf32/CMakeLists.txt b/samples/20_matrixexperiments-tf32/CMakeLists.txt new file mode 100644 index 00000000..fe34bea2 --- /dev/null +++ b/samples/20_matrixexperiments-tf32/CMakeLists.txt @@ -0,0 +1,11 @@ +# Copyright (c) 2024-2026 Ben Ashbaugh +# +# SPDX-License-Identifier: MIT + +add_opencl_sample( + TEST + NUMBER 20 + TARGET matrixexperiments-tf32 + VERSION 200 # for clSetKernelExecInfo + SOURCES main.cpp + KERNELS matrix_helpers_tf32.cl matrix_kernels_tf32.cl matrix_kernel_tiled_tf32.cl) diff --git a/samples/20_matrixexperiments-tf32/README.md b/samples/20_matrixexperiments-tf32/README.md new file mode 100644 index 00000000..62ffadf5 --- /dev/null +++ b/samples/20_matrixexperiments-tf32/README.md @@ -0,0 +1,58 @@ +# matrixexperiments-tf32 + +## Sample Purpose + +This sample demonstrates various techniques to perform a large matrix multiplcation where the matrix elements contain 32-bit `tf32` data. +The sample includes many different implementations: + +1. The "naive" implementation is a very simple implementation. +It is not very fast, but it is easy to understand, and it has no extension dependencies so it will run on many devices. +2. The "dpas" kernels use sub-group extensions to improve performance. +On some devices, they will also use specialized matrix multiplication extensions to further improve performance. +Because these kernels require certain extensions or a specific sub-group size, they may not run on all devices. +3. The "dpas blockread" kernels use additional sub-group extensions to further improve performance. + +Most of the optimized kernels operate on fixed size tiles of matrix data. +For some of these kernels, parameters such as the number of matrix tiles per-sub-group or the number of sub-groups per work-group may be modified via program build options. +Experiment with different options to see what performs the best! + +A good place to start for some devices is: + +```sh +./matrixexperiments-tf32 -m4096 --options="-DSGS_PER_WG_X=4 -DSGS_PER_WG_Y=8 -DKK=2 -cl-intel-256-GRF-per-thread" --zero +``` + +## Key APIs and Concepts + +This sample will optionally use the following OpenCL extensions: + +* cl_intel_required_subgroup_size +* cl_intel_split_work_group_barrier +* cl_intel_subgroup_2d_block_io +* cl_intel_subgroup_matrix_multiply_accumulate_tf32 +* cl_intel_subgroups + +## Command Line Options + +| Option | Default Value | Description | +|:--|:-:|:--| +| `-p ` | 0 | Specify the index of the OpenCL platform to execute the sample on. +| `-d ` | 0 | Specify the index of the OpenCL device in the platform to execute on the sample on. +| `--file ` | `matrix_kernels_tf32.cl` | Specify the name of the file with the OpenCL kernel source. +| `--options ` | None | Specify optional program build options. +| `--matrixsize ` | 512 | Specify the dimensions of the matrix. +| `--iterations ` | 16 | Specify the number of iterations for performance testing. +| `--validate` | n/a | Validate results for correctness. +| `--zero` | n/a | Initialize all matrices to zero. +| `--identity` | n/a | Initialize all matrices to to one. +| `--fixed` | n/a | Initialize all matrices to values computed from the matrix row and column. +| `--emulate` | n/a | Do not use specialized matrix multiplication extensions. +| `--wallclock` | n/a | Measure performance using wallclock time instead of event profiling. +| `--skipinit` | n/a | Skip initialization of source matrices. +| `--roundrobin` | n/a | Use round robin thread scheduling. +| `--threshold ` | 0.01 | Set the threshold used when validating results. +| `--mask ` | ~0 | Set a mask to only run a subset of tests. + +By default, the source matrices are populated with random data. +When validating results, it is recommended to use either "fixed" or "identity" data. +For best performance, use "zero" data". diff --git a/samples/20_matrixexperiments-tf32/main.cpp b/samples/20_matrixexperiments-tf32/main.cpp new file mode 100644 index 00000000..fcfcc2fc --- /dev/null +++ b/samples/20_matrixexperiments-tf32/main.cpp @@ -0,0 +1,647 @@ +/* +// Copyright (c) 2024-2026 Ben Ashbaugh +// +// SPDX-License-Identifier: MIT +*/ + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "util.hpp" + +using test_clock = std::chrono::high_resolution_clock; + +bool zeroData = false; +bool identityData = false; +bool fixedData = false; +bool validate = false; +bool emulate = false; +bool wallclock = false; +bool skipinit = false; +bool roundRobin = false; +int testIterations = 16; +float threshold = 0.01f; + +std::string makeTestName( + const std::string &func, + size_t M, size_t N, size_t K) +{ + std::ostringstream ret; + ret << func; + ret << " (M=" << M << ", N=" << N << ", K=" << K << ")"; + return ret.str(); +} + +std::string makeTestName( + const std::string &func, + int tM, int tN, + size_t M, size_t N, size_t K) +{ + std::ostringstream ret; + ret << func; + ret << ""; + ret << " (M=" << M << ", N=" << N << ", K=" << K << ")"; + return ret.str(); +} + +std::string makeTestName( + const std::string &func, + int tM, int tN, + int MM, int NN, + size_t M, size_t N, size_t K) +{ + std::ostringstream ret; + ret << func; + ret << ""; + ret << " (M=" << M << ", N=" << N << ", K=" << K << ")"; + return ret.str(); +} + +static size_t findMinSubGroupSize(cl::Device& device) +{ + auto s = device.getInfo(); + auto it = std::min_element(std::begin(s), std::end(s)); + if (it != std::end(s)) { + return *it; + } + return 0; +} + +static void setRoundRobin(cl::Kernel& kernel) +{ + constexpr cl_kernel_exec_info CL_KERNEL_EXEC_INFO_THREAD_ARBITRATION_POLICY_INTEL = 0x10025; + constexpr cl_uint CL_KERNEL_EXEC_INFO_THREAD_ARBITRATION_POLICY_ROUND_ROBIN_INTEL = 0x10023; + const cl_uint policy = CL_KERNEL_EXEC_INFO_THREAD_ARBITRATION_POLICY_ROUND_ROBIN_INTEL; + clSetKernelExecInfo( + kernel(), + CL_KERNEL_EXEC_INFO_THREAD_ARBITRATION_POLICY_INTEL, + sizeof(policy), + &policy); +} + +float to_tf32(float f) +{ + union { + uint32_t u; + float f; + } value; + + value.f = f; + value.u &= 0xFFFFE000; + + // Be careful not to convert NAN to INF: + if (std::isnan(f) && !std::isnan(value.f)) { + value.u |= 0x00002000; + } + + return value.f; +} + +template +static void fill_matrix(std::vector& M, size_t numRows, size_t numCols) +{ + if (zeroData) { + std::generate(std::begin(M), std::end(M), [&]{ return 0.0f; }); + } + else if (identityData) { + std::generate(std::begin(M), std::end(M), [&]{ return to_tf32(1.0f); }); + } else if (fixedData) { + for (size_t r = 0; r < numRows; r++) { + for (size_t c = 0; c < numCols; c++) { + M[r * numCols + c] = to_tf32(static_cast(r) + static_cast(c) / 64.0f); + } + } + } else { + std::random_device dev; + std::mt19937 rng(dev()); + std::uniform_real_distribution dist(-1.0, 1.0); + std::generate(std::begin(M), std::end(M), [&]{ return to_tf32(dist(rng)); }); + } +} + +template +static void compute_reference( + std::vector& C, + const std::vector& A, const std::vector& B, + size_t M, size_t N, size_t K) +{ + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + DstT sum = 0; + for (size_t k = 0; k < K; k++) { + sum = std::fma(static_cast(A[m * K + k]), + static_cast(B[k * N + n]), sum); + } + C[m * N + n] = sum; + } + } +} + +template +void check_results( + size_t M, + size_t N, + const std::vector& C, + const std::vector& C_ref) +{ + float err = 0.f; + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + auto index = m * N + n; + auto localErr = std::fabs(C[index] - C_ref[index]) / + std::max(std::fabs(C[index]), + std::fabs(C_ref[index])); + err = std::max(localErr, err); + if (localErr >= threshold) { + std::cerr << "Error at m = " << m << ", n = " << n + << ": (local error " << localErr << "): Wanted " + << C_ref[index] << ", got " << C[index] << std::endl; + return; + } + } + } +} + +static float hw_time(cl::Event& event) +{ + auto ns = event.getProfilingInfo() - + event.getProfilingInfo(); + return ns / 1e9f; +} + +static cl::NDRange getRequiredLocalWorkSize(cl::Kernel& kernel, cl::CommandQueue queue) +{ + // Note: This shouldn't be necessary, and the OpenCL implementation should + // automatically choose the required local work-group size when the local + // work-group size is `nullptr`. This is not working for some OpenCL + // implementations, though, so we will just query and use the required local + // work-group size explicitly. + auto device = queue.getInfo(); + auto reqd_wgs = kernel.getWorkGroupInfo(device); + + if (reqd_wgs[0] > 0 && reqd_wgs[1] > 0 && reqd_wgs[2] > 0) { + return cl::NDRange(reqd_wgs[0], reqd_wgs[1], reqd_wgs[2]); + } + + return cl::NullRange; +} + +static void tf32_naive( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, M, N, K).c_str()); fflush(stdout); + + cl::Kernel kernel{program, "tf32_naive"}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N, M}, localWorkSize, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... "); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void tf32_dpas_rowmajor( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "tf32_dpas_rowmajor"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N, M/tM}, localWorkSize, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... "); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void tf32_dpas_rowmajor_tiled( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, MM, NN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "tf32_dpas_rowmajor_tiled"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + kernelName += "_" + std::to_string(MM); + kernelName += "x" + std::to_string(NN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else if (tM * MM > M) { + printf("M is too small.\n"); + } else if (tN * NN > N) { + printf("N is too small.\n"); + } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N/NN, M/tM/MM}, localWorkSize, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... "); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void tf32_dpas_blockread_rowmajor( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "tf32_dpas_blockread_rowmajor"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + if (roundRobin) { + setRoundRobin(kernel); + } + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N, M/tM}, localWorkSize, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... "); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void tf32_dpas_blockread_rowmajor_tiled( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, MM, NN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "tf32_dpas_blockread_rowmajor_tiled"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + kernelName += "_" + std::to_string(MM); + kernelName += "x" + std::to_string(NN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else if (tM * MM > M) { + printf("M is too small.\n"); + } else if (tN * NN > N) { + printf("N is too small.\n"); + } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + if (roundRobin) { + setRoundRobin(kernel); + } + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N/NN, M/tM/MM}, localWorkSize, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... "); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +int main(int argc, char** argv) +{ + int platformIndex = 0; + int deviceIndex = 0; + + std::string fileName("matrix_kernels_tf32.cl"); + std::string buildOptions; + size_t matrixSize = 512; + + size_t mask = ~0; + + { + popl::OptionParser op("Supported Options"); + op.add>("p", "platform", "Platform Index", platformIndex, &platformIndex); + op.add>("d", "device", "Device Index", deviceIndex, &deviceIndex); + op.add>("", "file", "Kernel File Name", fileName, &fileName); + op.add>("", "options", "Program Build Options", buildOptions, &buildOptions); + op.add>("m", "matrixsize", "Matrix Size", matrixSize, &matrixSize); + op.add>("i", "iterations", "Test Iterations", testIterations, &testIterations); + op.add("", "validate", "Validate Results", &validate); + op.add("", "zero", "Use Zero Data", &zeroData); + op.add("", "identity", "Use Identity Data", &identityData); + op.add("", "fixed", "Use Fixed Data", &fixedData); + op.add("", "emulate", "Unconditionally Emulate dpas", &emulate); + op.add("", "wallclock", "Measure Wallclock Time", &wallclock); + op.add("", "skipinit", "Do Not Initialize Buffers", &skipinit); + op.add("", "roundrobin", "Use Round Robin Scheduling", &roundRobin); + op.add>("", "threshold", "Local Error Threshold", threshold, &threshold); + op.add, popl::Attribute::advanced>("", "mask", "Test Mask", mask, &mask); + bool printUsage = false; + try { + op.parse(argc, argv); + } catch (std::exception& e) { + fprintf(stderr, "Error: %s\n\n", e.what()); + printUsage = true; + } + + if (printUsage || !op.unknown_options().empty() || !op.non_option_args().empty()) { + fprintf(stderr, + "Usage: matrixexperiments-tf32 [options]\n" + "%s", op.help().c_str()); + return -1; + } + } + + std::vector platforms; + cl::Platform::get(&platforms); + + if (!checkPlatformIndex(platforms, platformIndex)) { + return -1; + } + + printf("Running on platform: %s\n", + platforms[platformIndex].getInfo().c_str() ); + + std::vector devices; + platforms[platformIndex].getDevices(CL_DEVICE_TYPE_ALL, &devices); + if (deviceIndex >= devices.size()) { + printf("Requested device index is %d, but only %zu devices were found.\n", + deviceIndex, devices.size()); + } + + cl::Device& device = devices[deviceIndex]; + printf("Running on device: %s (%uCUs, %uMHz)\n", + device.getInfo().c_str(), + device.getInfo(), + device.getInfo()); + printf("Running on drivers: %s\n", + device.getInfo().c_str()); + + auto minSubGroupSize = findMinSubGroupSize(device); + + bool emulate_tN16 = true; + if (!emulate && checkDeviceForExtension(device, "cl_intel_subgroup_matrix_multiply_accumulate_tf32")) { + printf("Found support for cl_intel_subgroup_matrix_multiply_accumulate_tf32, min sub-group size is: %zu\n", minSubGroupSize); + switch(minSubGroupSize) { + case 16: emulate_tN16 = false; break; + default: break; + } + } + + buildOptions += " -DEMULATE_tN16=" + std::to_string(emulate_tN16); + + printf("Config:\n"); + printf("\tTest Iterations: %d\n", testIterations); + printf("\tValidating data?: %s\n", validate ? "true" : "false"); + printf("\tFixed data?: %s\n", fixedData ? "true" : "false"); + printf("\tWallclock time?: %s\n", wallclock ? "true" : "false"); + printf("\tEmulate dpas for tN=16?: %s\n", emulate_tN16 ? "true" : "false"); + + cl::Context context{device}; + cl::CommandQueue queue{context, device, CL_QUEUE_PROFILING_ENABLE}; + + printf("Reading program source from file: %s\n", fileName.c_str() ); + std::string kernelString = readStringFromFile(fileName.c_str()); + + printf("Building program with build options: %s\n", + buildOptions.empty() ? "(none)" : buildOptions.c_str() ); + cl::Program program{ context, kernelString }; + program.build(buildOptions.c_str()); + for( auto& device : program.getInfo() ) + { + printf("Program build log for device %s:\n", + device.getInfo().c_str() ); + printf("%s\n", + program.getBuildInfo(device).c_str() ); + } + + const auto M = matrixSize; + const auto N = matrixSize; + const auto K = matrixSize; + + std::vector A_vec(M * K); + std::vector B_vec(K * N); + + std::vector C_ref(M * N); + + printf("Initializing source matrices...\n"); + fill_matrix(A_vec, M, K); + fill_matrix(B_vec, K, N); + + if (validate) { + printf("Computing reference...\n"); + compute_reference(C_ref, A_vec, B_vec, M, N, K); + } + + printf("Creating source buffers...\n"); + cl::Buffer A{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, A_vec.size() * sizeof(A_vec[0]), A_vec.data()}; + cl::Buffer B{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, B_vec.size() * sizeof(B_vec[0]), B_vec.data()}; + cl::Buffer C{context, CL_MEM_WRITE_ONLY, C_ref.size() * sizeof(C_ref[0])}; + + printf("Running tests...\n"); + + if (mask & 0x1) { + tf32_naive(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x20) { + tf32_dpas_rowmajor<1, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + tf32_dpas_rowmajor<2, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + tf32_dpas_rowmajor<4, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + tf32_dpas_rowmajor<8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x40) { + tf32_dpas_rowmajor_tiled<8, 16, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + tf32_dpas_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + tf32_dpas_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + tf32_dpas_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + tf32_dpas_rowmajor_tiled<8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + tf32_dpas_rowmajor_tiled<8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + tf32_dpas_rowmajor_tiled<8, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x200) { + tf32_dpas_blockread_rowmajor<1, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + tf32_dpas_blockread_rowmajor<2, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + tf32_dpas_blockread_rowmajor<4, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + tf32_dpas_blockread_rowmajor<8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x400) { + tf32_dpas_blockread_rowmajor_tiled<8, 16, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + tf32_dpas_blockread_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + tf32_dpas_blockread_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + tf32_dpas_blockread_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + tf32_dpas_blockread_rowmajor_tiled<8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + tf32_dpas_blockread_rowmajor_tiled<8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + tf32_dpas_blockread_rowmajor_tiled<8, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + } + + printf("Done.\n"); + + return 0; +} diff --git a/samples/20_matrixexperiments-tf32/matrix_helpers_tf32.cl b/samples/20_matrixexperiments-tf32/matrix_helpers_tf32.cl new file mode 100644 index 00000000..da7c1f8a --- /dev/null +++ b/samples/20_matrixexperiments-tf32/matrix_helpers_tf32.cl @@ -0,0 +1,292 @@ +/* +// Copyright (c) 2024-2026 Ben Ashbaugh +// +// SPDX-License-Identifier: MIT +*/ + +__attribute__((overloadable)) +float activation(float f) +{ +#if defined(ACTIVATION_RELU) + return fmax(f, 0); +#else // identity + return f; +#endif +} + +__attribute__((overloadable)) +float2 activation(float2 f) +{ + float2 res; + res.s0 = activation(f.s0); + res.s1 = activation(f.s1); + return res; +} + +__attribute__((overloadable)) +float4 activation(float4 f) +{ + float4 res; + res.s0 = activation(f.s0); + res.s1 = activation(f.s1); + res.s2 = activation(f.s2); + res.s3 = activation(f.s3); + return res; +} + +float8 activation(float8 f) +{ + float8 res; + res.s0 = activation(f.s0); + res.s1 = activation(f.s1); + res.s2 = activation(f.s2); + res.s3 = activation(f.s3); + res.s4 = activation(f.s4); + res.s5 = activation(f.s5); + res.s6 = activation(f.s6); + res.s7 = activation(f.s7); + return res; +} + +#ifndef __has_builtin +#define __has_builtin(x) 0 +#endif +#if __has_builtin(__builtin_expect) == 0 +#define __builtin_expect(x) +#endif + +#if defined(cl_intel_subgroups) + +inline int compute_m(const int num_sgs_x, const int num_sgs_y, const int tM, const int MM) +{ + const int m_start = get_group_id(1) * num_sgs_y; + const int m_index = num_sgs_y > 1 ? m_start + get_sub_group_id() / num_sgs_x : m_start; + return m_index * tM * MM; +} + +inline int compute_n(const int num_sgs_x, const int num_sgs_y, const int tN, const int NN) +{ + const int n_start = get_group_id(0) * num_sgs_x; + const int n_index = num_sgs_x > 1 ? n_start + get_sub_group_id() % num_sgs_x : n_start; + return n_index * tN * NN; +} + +// Emulated dpas: +__attribute__((overloadable)) +float emu_sub_group_tf32_tf32_matrix_mad_k8(float a, float8 b, float acc) +{ + float res = acc; + + res = fma(sub_group_broadcast(a, 0), b.s0, res); + res = fma(sub_group_broadcast(a, 1), b.s1, res); + res = fma(sub_group_broadcast(a, 2), b.s2, res); + res = fma(sub_group_broadcast(a, 3), b.s3, res); + res = fma(sub_group_broadcast(a, 4), b.s4, res); + res = fma(sub_group_broadcast(a, 5), b.s5, res); + res = fma(sub_group_broadcast(a, 6), b.s6, res); + res = fma(sub_group_broadcast(a, 7), b.s7, res); + + return res; +} + +__attribute__((overloadable)) +float2 emu_sub_group_tf32_tf32_matrix_mad_k8(float a, float8 b, float2 acc) +{ + float2 res = acc; + + res.s0 = fma(sub_group_broadcast(a, 0), b.s0, res.s0); + res.s0 = fma(sub_group_broadcast(a, 1), b.s1, res.s0); + res.s0 = fma(sub_group_broadcast(a, 2), b.s2, res.s0); + res.s0 = fma(sub_group_broadcast(a, 3), b.s3, res.s0); + res.s0 = fma(sub_group_broadcast(a, 4), b.s4, res.s0); + res.s0 = fma(sub_group_broadcast(a, 5), b.s5, res.s0); + res.s0 = fma(sub_group_broadcast(a, 6), b.s6, res.s0); + res.s0 = fma(sub_group_broadcast(a, 7), b.s7, res.s0); + + res.s1 = fma(sub_group_broadcast(a, 8), b.s0, res.s1); + res.s1 = fma(sub_group_broadcast(a, 9), b.s1, res.s1); + res.s1 = fma(sub_group_broadcast(a, 10), b.s2, res.s1); + res.s1 = fma(sub_group_broadcast(a, 11), b.s3, res.s1); + res.s1 = fma(sub_group_broadcast(a, 12), b.s4, res.s1); + res.s1 = fma(sub_group_broadcast(a, 13), b.s5, res.s1); + res.s1 = fma(sub_group_broadcast(a, 14), b.s6, res.s1); + res.s1 = fma(sub_group_broadcast(a, 15), b.s7, res.s1); + + return res; +} + +__attribute__((overloadable)) +float4 emu_sub_group_tf32_tf32_matrix_mad_k8(float2 a, float8 b, float4 acc) +{ + float4 res; + + res.s01 = emu_sub_group_tf32_tf32_matrix_mad_k8(a.s0, b, acc.s01); + res.s23 = emu_sub_group_tf32_tf32_matrix_mad_k8(a.s1, b, acc.s23); + + return res; +} + +__attribute__((overloadable)) +float8 emu_sub_group_tf32_tf32_matrix_mad_k8(float4 a, float8 b, float8 acc) +{ + float8 res; + + res.s01 = emu_sub_group_tf32_tf32_matrix_mad_k8(a.s0, b, acc.s01); + res.s23 = emu_sub_group_tf32_tf32_matrix_mad_k8(a.s1, b, acc.s23); + res.s45 = emu_sub_group_tf32_tf32_matrix_mad_k8(a.s2, b, acc.s45); + res.s67 = emu_sub_group_tf32_tf32_matrix_mad_k8(a.s3, b, acc.s67); + + return res; +} + +// M rows x K columns +float load_a_rowmajor_32b_1r8c_sg16(global float* A, int rowStart, int colStart, int stride) +{ + float ret; + + // Note: only the low eight channels should be used. + uint offset = rowStart * stride + colStart; + offset += (get_sub_group_local_id() % 8); + + ret = A[offset]; + + return ret; +} + +// M rows x K columns +float load_a_rowmajor_32b_2r8c_sg16(global float* A, int rowStart, int colStart, int stride) +{ + float ret; + + uint offset = rowStart * stride + colStart; + offset += (get_sub_group_local_id() < 8) ? 0 : stride; + offset += (get_sub_group_local_id() % 8); + + ret = A[offset]; + + return ret; +} + +// M rows x K columns +float2 load_a_rowmajor_32b_4r8c_sg16(global float* A, int rowStart, int colStart, int stride) +{ + float2 ret; + + uint offset = rowStart * stride + colStart; + offset += (get_sub_group_local_id() < 8) ? 0 : stride; + offset += (get_sub_group_local_id() % 8); + + ret.s0 = A[offset]; offset += stride * 2; + ret.s1 = A[offset]; offset += stride * 2; + + return ret; +} + +// M rows x K columns +float4 load_a_rowmajor_32b_8r8c_sg16(global float* A, int rowStart, int colStart, int stride) +{ + float4 ret; + + uint offset = rowStart * stride + colStart; + offset += (get_sub_group_local_id() < 8) ? 0 : stride; + offset += (get_sub_group_local_id() % 8); + + ret.s0 = A[offset]; offset += stride * 2; + ret.s1 = A[offset]; offset += stride * 2; + ret.s2 = A[offset]; offset += stride * 2; + ret.s3 = A[offset]; offset += stride * 2; + + return ret; +} + +// M rows x K columns x V tiles (in the M and K dimensions) +void prefetch_a_rowmajor_32b_8x2r8x2c_sg16(global float* A, int rowStart, int colStart, int stride) +{ +#if defined(PREFETCH_DEFAULT) + uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; + prefetch(A + offset, 1); +#endif // defined(PREFETCH_DEFAULT) +} + +// K rows x N columns: +// Each work-item loads K values. +// Stride is in units of elements. +float8 load_b_rowmajor_32b_8rNc(global float* B, int rowStart, int colStart, int stride) +{ + float8 ret; + + uint offset = rowStart * stride + colStart; + + ret.s0 = as_float(intel_sub_group_block_read((global uint*)B + offset)); offset += stride; + ret.s1 = as_float(intel_sub_group_block_read((global uint*)B + offset)); offset += stride; + ret.s2 = as_float(intel_sub_group_block_read((global uint*)B + offset)); offset += stride; + ret.s3 = as_float(intel_sub_group_block_read((global uint*)B + offset)); offset += stride; + ret.s4 = as_float(intel_sub_group_block_read((global uint*)B + offset)); offset += stride; + ret.s5 = as_float(intel_sub_group_block_read((global uint*)B + offset)); offset += stride; + ret.s6 = as_float(intel_sub_group_block_read((global uint*)B + offset)); offset += stride; + ret.s7 = as_float(intel_sub_group_block_read((global uint*)B + offset)); offset += stride; + + return ret; +} + +// K rows x N columns x V tiles (in the K and N dimensions) +void prefetch_b_rowmajor_32b_8x2r8x2c_sg16(global float* B, int rowStart, int colStart, int stride) +{ +#if defined(PREFETCH_DEFAULT) + uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; + prefetch(B + offset, 1); +#endif // defined(PREFETCH_DEFAULT) +} + +void store_c_rowmajor_fp32_1rNc(global float* C, float v, int rowStart, int colStart, int stride) +{ + global uint* C_ui = (global uint*)C; + uint v_ui = as_uint(v); + + uint offset = rowStart * stride + colStart; + + intel_sub_group_block_write(C_ui + offset, v_ui); offset += stride; +} + +void store_c_rowmajor_fp32_2rNc(global float* C, float2 v, int rowStart, int colStart, int stride) +{ + global uint* C_ui = (global uint*)C; + uint2 v_ui = as_uint2(v); + + uint offset = rowStart * stride + colStart; + + intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride; +} + +void store_c_rowmajor_fp32_4rNc(global float* C, float4 v, int rowStart, int colStart, int stride) +{ + global uint* C_ui = (global uint*)C; + uint4 v_ui = as_uint4(v); + + uint offset = rowStart * stride + colStart; + + intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s2); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s3); offset += stride; +} + +void store_c_rowmajor_fp32_8rNc(global float* C, float8 v, int rowStart, int colStart, int stride) +{ + global uint* C_ui = (global uint*)C; + uint8 v_ui = as_uint8(v); + + uint offset = rowStart * stride + colStart; + + intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s2); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s3); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s4); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s5); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s6); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s7); offset += stride; +} + +#endif // defined(cl_intel_subgroups) diff --git a/samples/20_matrixexperiments-tf32/matrix_kernel_tiled_tf32.cl b/samples/20_matrixexperiments-tf32/matrix_kernel_tiled_tf32.cl new file mode 100644 index 00000000..e5fb52ee --- /dev/null +++ b/samples/20_matrixexperiments-tf32/matrix_kernel_tiled_tf32.cl @@ -0,0 +1,228 @@ +/* +// Copyright (c) 2024-2026 Ben Ashbaugh +// +// SPDX-License-Identifier: MIT +*/ + +#if !defined(tK) +#error "tK is undefined! This should be defined as the K dimension of the matrix tiles, which is dependent on the elemement type, likely 16 or 32." +#endif + +#if !defined(MM) +#error "MM is undefined! This should be defined as the number of matrix tiles in the M dimension." +#endif + +#if !defined(NN) +#error "NN is undefined! This should be defined as the number of matrix tiles in the N dimension." +#endif + +#if !defined(KK) +#define KK 1 +#endif + +#if !defined(cl_intel_split_work_group_barrier) || defined(NO_SPLIT_BARRIERS) +#define split_barrier_arrive() +#define split_barrier_wait() +#else +#define split_barrier_arrive() intel_work_group_barrier_arrive(0) +#define split_barrier_wait() intel_work_group_barrier_wait(0) +#endif + +#define MM_KERNEL_NAMEX(PREFIX, tM, tN, MM, NN) PREFIX ## _m ## tM ## _n ## tN ## _ ## MM ## x ## NN +#define MM_KERNEL_NAME(PREFIX, tM, tN, MM, NN) MM_KERNEL_NAMEX(PREFIX, tM, tN, MM, NN) + +#define HELPER_NAMEX(PREFIX, MM, NN) PREFIX ## _m ## MM ## _n ## NN +#define HELPER_NAME(PREFIX, MM, NN) HELPER_NAMEX(PREFIX, MM, NN) + +#if !defined(SGS_PER_WG_X) +#define SGS_PER_WG_X 1 +#endif + +#if !defined(SGS_PER_WG_Y) +#define SGS_PER_WG_Y 4 +#endif + +#if !defined(PREFETCH_DISTANCE) +#define PREFETCH_DISTANCE 1 +#endif + +void HELPER_NAME(btile_load_rowmajor, MM, NN)(global float* B, int tN, int N, int k, int n, float8 bData[NN][KK]) +{ + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + bData[nn][kk] = load_b_rowmajor_32b_8rNc(B, k + kk * tK, n + nn * tN, N); + } + } +} + +void HELPER_NAME(atile_prefetch_rowmajor_sg16, MM, NN)(global float* A, int tM, int K, int m, int prefetch_k) +{ + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=2) { + prefetch_a_rowmajor_32b_8x2r8x2c_sg16(A, m + mm * tM, prefetch_k + kk * tK, K); + } + } +} + +void HELPER_NAME(btile_prefetch_rowmajor_sg16, MM, NN)(global float* B, int tN, int N, int prefetch_k, int n) +{ + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn+=2) { + prefetch_b_rowmajor_32b_8x2r8x2c_sg16(B, prefetch_k + kk * tK, n + nn * tN, N); + } + } +} + +void HELPER_NAME(atile_load_rowmajor_sg16, MM, NN)(global float* A, int tM, int K, int m, int k, float4 aData[KK][MM]) +{ + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = load_a_rowmajor_32b_8r8c_sg16(A, m + mm * tM, k + kk * tK, K); + } + } +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) +kernel void MM_KERNEL_NAME(tf32_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global float* A, global float* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 16; + const int N = get_global_size(0) * NN; + const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); + const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); + + // Initial prefetch: + int prefetch_k = 0; + for (int p = 0; p < PREFETCH_DISTANCE; p++) { + HELPER_NAME(atile_prefetch_rowmajor_sg16, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_rowmajor_sg16, MM, NN)(B, tN, N, prefetch_k, n); + prefetch_k += tK * KK; + } + + float8 sum[NN][MM]; + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[nn][mm] = 0; + } + } + + split_barrier_arrive(); + + for (int k = 0; k < K; k += tK * KK) { + // Next prefetch: + // TODO: skip prefetch on the last iterations. + HELPER_NAME(atile_prefetch_rowmajor_sg16, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_rowmajor_sg16, MM, NN)(B, tN, N, prefetch_k, n); + prefetch_k += tK * KK; + + float4 aData[KK][MM]; + HELPER_NAME(atile_load_rowmajor_sg16, MM, NN)(A, tM, K, m, k, aData); + + float8 bData[NN][KK]; + HELPER_NAME(btile_load_rowmajor, MM, NN)(B, tN, N, k, n, bData); + + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[nn][mm] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[nn][mm]); + } + } + } + + split_barrier_wait(); + split_barrier_arrive(); + } + + split_barrier_wait(); + + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { + sum[nn][mm] = activation(sum[nn][mm]); + store_c_rowmajor_fp32_8rNc(C, sum[nn][mm], m + mm * tM, n + nn * tN, N); + } + } +} + +#ifdef cl_intel_subgroup_2d_block_io + +void HELPER_NAME(atile_block_load_rowmajor, MM, NN)(global float* A, int tM, int M, int K, int m, int k, float4 aData[KK][MM]) +{ + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + intel_sub_group_2d_block_read_32b_8r8x1c(A, K * sizeof(float), M, K * sizeof(float), (int2)(k + kk * tK, m + mm * tM), (uint*)&aData[kk][mm]); + } + } +} + +void HELPER_NAME(btile_block_load_rowmajor, MM, NN)(global float* B, int tN, int K, int N, int k, int n, float8 bData[NN][KK]) +{ + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + intel_sub_group_2d_block_read_32b_8r16x1c(B, N * sizeof(float), K, N * sizeof(float), (int2)(n + nn * tN, k + kk * tK), (uint*)&bData[nn][kk]); + } + } +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) +kernel void MM_KERNEL_NAME(tf32_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global float* A, global float* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 16; + const int M = get_global_size(1) * tM * MM; + const int N = get_global_size(0) * NN; + const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); + const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); + + int prefetch_k = 0; + for (int p = 0; p < PREFETCH_DISTANCE; p++) { + HELPER_NAME(atile_prefetch_rowmajor_sg16, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_rowmajor_sg16, MM, NN)(B, tN, N, prefetch_k, n); + prefetch_k += tK * KK; + } + + float8 sum[NN][MM]; + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[nn][mm] = 0; + } + } + + split_barrier_arrive(); + + for (int k = 0; k < K; k += tK * KK) { + // TODO: skip prefetch on the last iterations. + HELPER_NAME(atile_prefetch_rowmajor_sg16, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_rowmajor_sg16, MM, NN)(B, tN, N, prefetch_k, n); + prefetch_k += tK * KK; + + float4 aData[KK][MM]; + HELPER_NAME(atile_block_load_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); + + float8 bData[NN][KK]; + HELPER_NAME(btile_block_load_rowmajor, MM, NN)(B, tN, K, N, k, n, bData); + + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { + sum[nn][mm] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[nn][mm]); + } + } + } + + split_barrier_wait(); + split_barrier_arrive(); + } + + split_barrier_wait(); + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[nn][mm] = activation(sum[nn][mm]); + intel_sub_group_2d_block_write_32b_8r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), (uint*)&sum[nn][mm]); + } + } +} + +#endif // cl_intel_subgroup_2d_block_io diff --git a/samples/20_matrixexperiments-tf32/matrix_kernels_tf32.cl b/samples/20_matrixexperiments-tf32/matrix_kernels_tf32.cl new file mode 100644 index 00000000..bdac8e37 --- /dev/null +++ b/samples/20_matrixexperiments-tf32/matrix_kernels_tf32.cl @@ -0,0 +1,267 @@ +/* +// Copyright (c) 2024-2026 Ben Ashbaugh +// +// SPDX-License-Identifier: MIT +*/ + +#include "matrix_helpers_tf32.cl" + +#if EMULATE_tN16 +#define mat_mul_sg16 emu_sub_group_tf32_tf32_matrix_mad_k8 +#else +#define mat_mul_sg16 intel_sub_group_tf32_tf32_matrix_mad_k8 +#endif + +kernel void tf32_naive(global float* C, global float* A, global float* B, int K) +{ + const int N = get_global_size(0); + const int m = get_global_id(1); + const int n = get_global_id(0); + + float sum = 0; + for (int k = 0; k < K; k++) { + sum = fma(A[m * K + k], B[k * N + n], sum); + } + + sum = activation(sum); + C[m * N + n] = sum; +} + +// For all tf32 kernels tK == 8: +#define tK 8 + +#if defined(cl_intel_subgroups) && defined(cl_intel_required_subgroup_size) + +// rowmajor kernels: + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void tf32_dpas_rowmajor_m1_n16(global float* C, global float* A, global float* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 1; + const int tN = 16; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + float sum = 0; + for (int k = 0; k < K; k += tK) { + float aData = load_a_rowmajor_32b_1r8c_sg16(A, m, k, K); + float8 bData = load_b_rowmajor_32b_8rNc(B, k, n, N); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_fp32_1rNc(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void tf32_dpas_rowmajor_m2_n16(global float* C, global float* A, global float* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 2; + const int tN = 16; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + float2 sum = 0; + for (int k = 0; k < K; k += tK) { + float aData = load_a_rowmajor_32b_2r8c_sg16(A, m, k, K); + float8 bData = load_b_rowmajor_32b_8rNc(B, k, n, N); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_fp32_2rNc(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void tf32_dpas_rowmajor_m4_n16(global float* C, global float* A, global float* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 4; + const int tN = 16; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + float4 sum = 0; + for (int k = 0; k < K; k += tK) { + float2 aData = load_a_rowmajor_32b_4r8c_sg16(A, m, k, K); + float8 bData = load_b_rowmajor_32b_8rNc(B, k, n, N); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_fp32_4rNc(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void tf32_dpas_rowmajor_m8_n16(global float* C, global float* A, global float* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 16; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + float8 sum = 0; + for (int k = 0; k < K; k += tK) { + float4 aData = load_a_rowmajor_32b_8r8c_sg16(A, m, k, K); + float8 bData = load_b_rowmajor_32b_8rNc(B, k, n, N); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_fp32_8rNc(C, sum, m, n, N); +} + +#ifdef cl_intel_subgroup_2d_block_io + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void tf32_dpas_blockread_rowmajor_m1_n16(global float* C, global float* A, global float* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 1; + const int tN = 16; + const int M = get_global_size(1); + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + float sum = 0; + for (int k = 0; k < K; k += tK) { + float aData; + intel_sub_group_2d_block_read_32b_1r8x1c(A, K * sizeof(float), M, K * sizeof(float), (int2)(k, m), (uint*)&aData); + float8 bData; + intel_sub_group_2d_block_read_32b_8r16x1c(B, N * sizeof(float), K, N * sizeof(float), (int2)(n, k), (uint*)&bData); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_sub_group_2d_block_write_32b_1r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void tf32_dpas_blockread_rowmajor_m2_n16(global float* C, global float* A, global float* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 2; + const int tN = 16; + const int M = get_global_size(1) * tM; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + float2 sum = 0; + for (int k = 0; k < K; k += tK) { + float aData; + intel_sub_group_2d_block_read_32b_2r8x1c(A, K * sizeof(float), M, K * sizeof(float), (int2)(k, m), (uint*)&aData); + float8 bData; + intel_sub_group_2d_block_read_32b_8r16x1c(B, N * sizeof(float), K, N * sizeof(float), (int2)(n, k), (uint*)&bData); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_sub_group_2d_block_write_32b_2r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void tf32_dpas_blockread_rowmajor_m4_n16(global float* C, global float* A, global float* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 4; + const int tN = 16; + const int M = get_global_size(1) * tM; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + float4 sum = 0; + for (int k = 0; k < K; k += tK) { + float2 aData; + intel_sub_group_2d_block_read_32b_4r8x1c(A, K * sizeof(float), M, K * sizeof(float), (int2)(k, m), (uint*)&aData); + float8 bData; + intel_sub_group_2d_block_read_32b_8r16x1c(B, N * sizeof(float), K, N * sizeof(float), (int2)(n, k), (uint*)&bData); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_sub_group_2d_block_write_32b_4r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void tf32_dpas_blockread_rowmajor_m8_n16(global float* C, global float* A, global float* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 16; + const int M = get_global_size(1) * tM; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + float8 sum = 0; + for (int k = 0; k < K; k += tK) { + float4 aData; + intel_sub_group_2d_block_read_32b_8r8x1c(A, K * sizeof(float), M, K * sizeof(float), (int2)(k, m), (uint*)&aData); + float8 bData; + intel_sub_group_2d_block_read_32b_8r16x1c(B, N * sizeof(float), K, N * sizeof(float), (int2)(n, k), (uint*)&bData); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_sub_group_2d_block_write_32b_8r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); +} + +#endif // cl_intel_subgroup_2d_block_io + +// Tiled matrix multiplication kernels, generated from a template: + +#define MM 1 +#define NN 1 +#include "matrix_kernel_tiled_tf32.cl" +#undef MM +#undef NN + +#define MM 2 +#define NN 1 +#include "matrix_kernel_tiled_tf32.cl" +#undef MM +#undef NN + +#define MM 1 +#define NN 2 +#include "matrix_kernel_tiled_tf32.cl" +#undef MM +#undef NN + +#define MM 2 +#define NN 2 +#include "matrix_kernel_tiled_tf32.cl" +#undef MM +#undef NN + +#define MM 4 +#define NN 2 +#include "matrix_kernel_tiled_tf32.cl" +#undef MM +#undef NN + +#define MM 2 +#define NN 4 +#include "matrix_kernel_tiled_tf32.cl" +#undef MM +#undef NN + +#define MM 4 +#define NN 4 +#include "matrix_kernel_tiled_tf32.cl" +#undef MM +#undef NN + +#endif // defined(cl_intel_subgroups) && defined(cl_intel_required_subgroup_size) + +#undef tK diff --git a/samples/CMakeLists.txt b/samples/CMakeLists.txt index d0d8ac35..8e8d5282 100644 --- a/samples/CMakeLists.txt +++ b/samples/CMakeLists.txt @@ -78,6 +78,10 @@ add_subdirectory( 06_ndrangekernelfromfile ) add_subdirectory( 10_queueexperiments ) add_subdirectory( 16_floatatomics ) +add_subdirectory( 20_matrixexperiments-bf16 ) +add_subdirectory( 20_matrixexperiments-i8 ) +add_subdirectory( 20_matrixexperiments-tf32 ) + set(BUILD_EXTENSION_SAMPLES TRUE) if(NOT TARGET OpenCLExt) message(STATUS "Skipping Extension Samples - OpenCL Extension Loader is not found.")