Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit fe590fe

Browse filesBrowse files
committed
Bump Cutlass to 3.5.1 OSS PR
ghstack-source-id: 962588b Pull Request resolved: #144000
1 parent 1ab9e83 commit fe590fe
Copy full SHA for fe590fe

File tree

Expand file treeCollapse file tree

2 files changed

+5
-39
lines changed
Filter options
Expand file treeCollapse file tree

2 files changed

+5
-39
lines changed

‎aten/src/ATen/native/cuda/RowwiseScaledMM.cu

Copy file name to clipboardExpand all lines: aten/src/ATen/native/cuda/RowwiseScaledMM.cu
+4-38Lines changed: 4 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -5,58 +5,24 @@
55
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
66

77
// Determine if the architecture supports rowwise scaled mm
8-
// Currenlty failing on windows with: https://github.com/NVIDIA/cutlass/issues/1571
8+
// Currently failing on windows with:
9+
// https://github.com/NVIDIA/cutlass/issues/1571
910
#if !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000
1011

1112
#define BUILD_ROWWISE_FP8_KERNEL
1213
#endif
1314

1415
#if defined(BUILD_ROWWISE_FP8_KERNEL)
1516

16-
// We are going to override the cuTensorMapEncodeTiled driver api with our lazy loader
17-
static CUresult CUDAAPI nvrtc_cuTensorMapEncodeTiled(
18-
CUtensorMap* tensorMap,
19-
CUtensorMapDataType tensorDataType,
20-
cuuint32_t tensorRank,
21-
void* globalAddress,
22-
const cuuint64_t* globalDim,
23-
const cuuint64_t* globalStrides,
24-
const cuuint32_t* boxDim,
25-
const cuuint32_t* elementStrides,
26-
CUtensorMapInterleave interleave,
27-
CUtensorMapSwizzle swizzle,
28-
CUtensorMapL2promotion l2Promotion,
29-
CUtensorMapFloatOOBfill oobFill) {
30-
return at::globalContext().getNVRTC().cuTensorMapEncodeTiled(
31-
tensorMap,
32-
tensorDataType,
33-
tensorRank,
34-
globalAddress,
35-
globalDim,
36-
globalStrides,
37-
boxDim,
38-
elementStrides,
39-
interleave,
40-
swizzle,
41-
l2Promotion,
42-
oobFill);
43-
}
44-
45-
46-
#include <cutlass/version.h>
17+
#include <cute/tensor.hpp>
4718
#include <cutlass/core_io.h>
4819
#include <cutlass/cutlass.h>
4920
#include <cutlass/gemm/device/gemm.h>
5021
#include <cutlass/half.h>
5122
#include <cutlass/numeric_types.h>
5223
#include <cutlass/trace.h>
5324
#include <cutlass/util/host_tensor.h>
54-
55-
// Rename the global function symbol
56-
#define cuTensorMapEncodeTiled nvrtc_cuTensorMapEncodeTiled
57-
#include <cute/tensor.hpp>
58-
#undef cuTensorMapEncodeTiled
59-
// Set everything back to normal
25+
#include <cutlass/version.h>
6026

6127
#include <cutlass/gemm/collective/collective_builder.hpp>
6228
#include <cutlass/gemm/device/gemm_universal_adapter.h>

‎third_party/cutlass

Copy file name to clipboard
Submodule cutlass updated 696 files

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.