|
5 | 5 | #include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
|
6 | 6 |
|
7 | 7 | // Determine if the architecture supports rowwise scaled mm
|
8 |
| -// Currenlty failing on windows with: https://github.com/NVIDIA/cutlass/issues/1571 |
| 8 | +// Currently failing on windows with: |
| 9 | +// https://github.com/NVIDIA/cutlass/issues/1571 |
9 | 10 | #if !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000
|
10 | 11 |
|
11 | 12 | #define BUILD_ROWWISE_FP8_KERNEL
|
12 | 13 | #endif
|
13 | 14 |
|
14 | 15 | #if defined(BUILD_ROWWISE_FP8_KERNEL)
|
15 | 16 |
|
16 |
| -// We are going to override the cuTensorMapEncodeTiled driver api with our lazy loader |
17 |
| -static CUresult CUDAAPI nvrtc_cuTensorMapEncodeTiled( |
18 |
| - CUtensorMap* tensorMap, |
19 |
| - CUtensorMapDataType tensorDataType, |
20 |
| - cuuint32_t tensorRank, |
21 |
| - void* globalAddress, |
22 |
| - const cuuint64_t* globalDim, |
23 |
| - const cuuint64_t* globalStrides, |
24 |
| - const cuuint32_t* boxDim, |
25 |
| - const cuuint32_t* elementStrides, |
26 |
| - CUtensorMapInterleave interleave, |
27 |
| - CUtensorMapSwizzle swizzle, |
28 |
| - CUtensorMapL2promotion l2Promotion, |
29 |
| - CUtensorMapFloatOOBfill oobFill) { |
30 |
| - return at::globalContext().getNVRTC().cuTensorMapEncodeTiled( |
31 |
| - tensorMap, |
32 |
| - tensorDataType, |
33 |
| - tensorRank, |
34 |
| - globalAddress, |
35 |
| - globalDim, |
36 |
| - globalStrides, |
37 |
| - boxDim, |
38 |
| - elementStrides, |
39 |
| - interleave, |
40 |
| - swizzle, |
41 |
| - l2Promotion, |
42 |
| - oobFill); |
43 |
| -} |
44 |
| - |
45 |
| - |
46 |
| -#include <cutlass/version.h> |
| 17 | +#include <cute/tensor.hpp> |
47 | 18 | #include <cutlass/core_io.h>
|
48 | 19 | #include <cutlass/cutlass.h>
|
49 | 20 | #include <cutlass/gemm/device/gemm.h>
|
50 | 21 | #include <cutlass/half.h>
|
51 | 22 | #include <cutlass/numeric_types.h>
|
52 | 23 | #include <cutlass/trace.h>
|
53 | 24 | #include <cutlass/util/host_tensor.h>
|
54 |
| - |
55 |
| -// Rename the global function symbol |
56 |
| -#define cuTensorMapEncodeTiled nvrtc_cuTensorMapEncodeTiled |
57 |
| -#include <cute/tensor.hpp> |
58 |
| -#undef cuTensorMapEncodeTiled |
59 |
| -// Set everything back to normal |
| 25 | +#include <cutlass/version.h> |
60 | 26 |
|
61 | 27 | #include <cutlass/gemm/collective/collective_builder.hpp>
|
62 | 28 | #include <cutlass/gemm/device/gemm_universal_adapter.h>
|
|
0 commit comments