Commit 4aa66f6

eqy authored and pytorchmergebot committed

[CUDA][CUTLASS][submodule] Fixes for CUTLASS upgrade (#131493)

Unblocks/unbreaks against newer CUTLASS (3.5+).

CC @nWEIdia @xwang233 @ptrblck @thakkarV

Pull Request resolved: #131493
Approved by: https://github.com/Skylion007

1 parent 41d6cab

9 files changed: 50 additions, 19 deletions

aten/src/ATen/native/cuda/MixedDtypesLinear.cu (+2, -0)

@@ -6,6 +6,8 @@
 // Doesn't work on ROCm or Windows yet
 // TODO: Add compiler warning? Add PyTorch config flag?
 #else
+#include <cuda_fp16.h>
+
#include <cuda_runtime.h>
 #include <cutlass/cutlass.h>
 #include <cutlass/tensor_ref.h>
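Note on this fix: newer CUTLASS releases presumably no longer pull in <cuda_fp16.h> transitively from the headers used in this file, so a translation unit that touches the CUDA half types must include it directly. A minimal sketch (not from the commit) of code that needs the explicit include:

    // sketch.cu -- hypothetical example; __half and __hadd are declared in
    // <cuda_fp16.h>, so this fails to compile if nothing includes that header.
    #include <cuda_fp16.h>
    #include <cuda_runtime.h>

    __global__ void half_add(const __half* a, const __half* b, __half* out, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
        out[i] = __hadd(a[i], b[i]);  // half-precision add intrinsic
      }
    }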

aten/src/ATen/native/cuda/RowwiseScaledMM.cu (+2, -2)

@@ -141,13 +141,13 @@ void f8f8bf16_rowwise_impl(
       cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;

   using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
-      PONG ? 2 : 1,
+      0,
       TileShape,
       ElementComputeEpilogue,
       cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

   using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
-      PONG ? 2 : 1,
+      0,
       TileShape,
       ElementBias,
       cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;
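The first template argument of Sm90RowBroadcast is its staging count. A hedged reading of this change: under CUTLASS 3.5+ the row broadcast no longer takes a schedule-dependent staging depth (previously PONG ? 2 : 1, where PONG selects the ping-pong kernel schedule), so both the w-scale and the bias broadcast pass 0 and let the epilogue manage staging. A fragment mirroring the updated usage above, generalized into an alias (TileShape and Element are placeholders):

    // Sketch only: Stride<_0, _1, _0> marks a row vector -- constant over M
    // and the batch mode L, varying along N.
    template <typename TileShape, typename Element>
    using RowBroadcast = cutlass::epilogue::fusion::Sm90RowBroadcast<
        0,          // Stages: no explicit staging (was PONG ? 2 : 1)
        TileShape,  // CTA tile shape
        Element,    // element type of the broadcast vector
        cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;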

aten/src/ATen/native/sparse/cuda/ComputeSparseTile.h (+2, -12)

@@ -9,16 +9,6 @@
 // sparsification, as a bitmask.
 // NOTE: Algorithms might select LESS than 8 values in total in some cases.

-namespace platform {
-template <>
-struct numeric_limits<cutlass::bfloat16_t> {
-  CUTLASS_HOST_DEVICE
-  static cutlass::bfloat16_t infinity() {
-    return cutlass::bfloat16_t::bitcast(0x7f80);
-  }
-};
-} // namespace platform
-
 namespace at::native{

 template <typename Element, typename Pointwise>
@@ -68,7 +58,7 @@ template <typename Op = IdentityOp>
 struct LargestValuesGreedy {
   template <typename T>
   static CUTLASS_DEVICE T outOfBoundsFillValue() {
-    return -platform::numeric_limits<T>::infinity();
+    return -std::numeric_limits<T>::infinity();
   }

   template <typename Tile4x4Accessor>
@@ -128,7 +118,7 @@ template <typename Op = IdentityOp>
 struct Causal1122 {
   template <typename T>
   static CUTLASS_DEVICE T outOfBoundsFillValue() {
-    return -platform::numeric_limits<T>::infinity();
+    return -std::numeric_limits<T>::infinity();
   }

   template <typename Tile4x4Accessor>
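The removed block hand-rolled a platform::numeric_limits specialization for cutlass::bfloat16_t; the replacement relies on CUTLASS providing std::numeric_limits specializations for its numeric types (which the +std::numeric_limits lines above imply newer releases do). A hedged sketch of the equivalence -- the bf16 bit pattern 0x7f80 is +infinity (sign 0, exponent all ones, mantissa 0):

    // Sketch: assumes <cutlass/bfloat16.h> specializes
    // std::numeric_limits<cutlass::bfloat16_t>, as the diff above relies on.
    #include <limits>
    #include <cutlass/bfloat16.h>

    cutlass::bfloat16_t out_of_bounds_fill() {
      // -inf guarantees padded elements are never selected as "largest
      // values" when sparsifying a 4x4 tile; equivalent to negating the
      // removed bitcast(0x7f80) specialization.
      return -std::numeric_limits<cutlass::bfloat16_t>::infinity();
    }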

aten/src/ATen/native/sparse/cuda/SparseSemiStructuredLinear.cu (+22, -1)

@@ -44,6 +44,7 @@ template <
     typename ThreadblockShape,
     typename WarpShape,
     typename InstructionShape,
+    typename Operator,
     typename LayoutInputA,
     typename LayoutInputB,
     bool use_bias,
@@ -62,7 +63,6 @@ Tensor two_four_sgemm(
     using SmArch = cutlass::arch::Sm80; // Only CC 8.x devices are supported at the moment.
     using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // This choice provides good performance across wide range of operand sizes.
     constexpr int NumStages = 3; // This choice provides good performance across wide range of operand sizes.
-    using Operator = cutlass::arch::OpMultiplyAdd;
     constexpr int NumEVTEpilogueStages = 1;

     constexpr int AlignmentInputA = 128 / cutlass::sizeof_bits<ElementInputA>::value;
@@ -317,6 +317,7 @@ template <
     typename ThreadblockShape,
     typename WarpShape,
     typename InstructionShape,
+    typename Operator,
     bool EnableRowMajorRowMajorLayouts,
     bool EnableRowMajorColumnMajorLayouts,
     bool EnableColumnMajorRowMajorLayouts,
@@ -345,6 +346,7 @@ Tensor two_four_sgemm_dispatch_layouts(
         ThreadblockShape,
         WarpShape,
         InstructionShape,
+        Operator,
         cutlass::layout::RowMajor,
         cutlass::layout::RowMajor,
         use_bias,
@@ -367,6 +369,7 @@
         ThreadblockShape,
         WarpShape,
         InstructionShape,
+        Operator,
         cutlass::layout::RowMajor,
         cutlass::layout::ColumnMajor,
         use_bias,
@@ -389,6 +392,7 @@
         ThreadblockShape,
         WarpShape,
         InstructionShape,
+        Operator,
         cutlass::layout::ColumnMajor,
         cutlass::layout::RowMajor,
         use_bias,
@@ -411,6 +415,7 @@
         ThreadblockShape,
         WarpShape,
         InstructionShape,
+        Operator,
         cutlass::layout::ColumnMajor,
         cutlass::layout::ColumnMajor,
         use_bias,
@@ -440,6 +445,7 @@ template <
     typename ThreadblockShape,
     typename WarpShape,
     typename InstructionShape,
+    typename Operator,
     bool EnableRowMajorRowMajorLayouts,
     bool EnableRowMajorColumnMajorLayouts,
     bool EnableColumnMajorRowMajorLayouts,
@@ -457,6 +463,7 @@ Tensor two_four_sgemm_dispatch_layouts_bias(
         ThreadblockShape,
         WarpShape,
         InstructionShape,
+        Operator,
         EnableRowMajorRowMajorLayouts,
         EnableRowMajorColumnMajorLayouts,
         EnableColumnMajorRowMajorLayouts,
@@ -476,6 +483,7 @@
         ThreadblockShape,
         WarpShape,
         InstructionShape,
+        Operator,
         EnableRowMajorRowMajorLayouts,
         EnableRowMajorColumnMajorLayouts,
         EnableColumnMajorRowMajorLayouts,
@@ -498,6 +506,7 @@ template <
     typename ThreadblockShape,
     typename WarpShape,
     typename InstructionShape,
+    typename Operator,
     bool EnableRowMajorRowMajorLayouts,
     bool EnableRowMajorColumnMajorLayouts,
     bool EnableColumnMajorRowMajorLayouts,
@@ -519,6 +528,7 @@ Tensor two_four_sgemm_dispatch_layouts_bias_activation(
         ThreadblockShape,
         WarpShape,
         InstructionShape,
+        Operator,
         EnableRowMajorRowMajorLayouts,
         EnableRowMajorColumnMajorLayouts,
         EnableColumnMajorRowMajorLayouts,
@@ -540,6 +550,7 @@
         ThreadblockShape,
         WarpShape,
         InstructionShape,
+        Operator,
         EnableRowMajorRowMajorLayouts,
         EnableRowMajorColumnMajorLayouts,
         EnableColumnMajorRowMajorLayouts,
@@ -561,6 +572,7 @@
         ThreadblockShape,
         WarpShape,
         InstructionShape,
+        Operator,
         EnableRowMajorRowMajorLayouts,
         EnableRowMajorColumnMajorLayouts,
         EnableColumnMajorRowMajorLayouts,
@@ -717,6 +729,7 @@ Tensor _sparse_semi_structured_linear(
         cutlass::gemm::GemmShape<128, 128, 128>;
     using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>;
     using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
+    using Operator = cutlass::arch::OpMultiplyAddSaturate;
     const auto EnableRowMajorRowMajorLayouts = false;
     const auto EnableRowMajorColumnMajorLayouts = true;
     const auto EnableColumnMajorRowMajorLayouts = false;
@@ -734,6 +747,7 @@
         ThreadblockShape,
         WarpShape,
         InstructionShape,
+        Operator,
         EnableRowMajorRowMajorLayouts,
         EnableRowMajorColumnMajorLayouts,
         EnableColumnMajorRowMajorLayouts,
@@ -756,6 +770,7 @@
         ThreadblockShape,
         WarpShape,
         InstructionShape,
+        Operator,
         EnableRowMajorRowMajorLayouts,
         EnableRowMajorColumnMajorLayouts,
         EnableColumnMajorRowMajorLayouts,
@@ -781,6 +796,7 @@
     using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
     using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
     using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
+    using Operator = cutlass::arch::OpMultiplyAdd;
     const auto EnableRowMajorRowMajorLayouts = true;
     const auto EnableRowMajorColumnMajorLayouts = true;
     const auto EnableColumnMajorRowMajorLayouts = true;
@@ -796,6 +812,7 @@
         ThreadblockShape,
         WarpShape,
         InstructionShape,
+        Operator,
         EnableRowMajorRowMajorLayouts,
         EnableRowMajorColumnMajorLayouts,
         EnableColumnMajorRowMajorLayouts,
@@ -820,6 +837,7 @@
     using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
     using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
     using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
+    using Operator = cutlass::arch::OpMultiplyAdd;
     const auto EnableRowMajorRowMajorLayouts = true;
     const auto EnableRowMajorColumnMajorLayouts = true;
     const auto EnableColumnMajorRowMajorLayouts = true;
@@ -835,6 +853,7 @@
         ThreadblockShape,
         WarpShape,
         InstructionShape,
+        Operator,
         EnableRowMajorRowMajorLayouts,
         EnableRowMajorColumnMajorLayouts,
         EnableColumnMajorRowMajorLayouts,
@@ -859,6 +878,7 @@
     using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>;
     using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>;
     using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
+    using Operator = cutlass::arch::OpMultiplyAdd;
     const auto EnableRowMajorRowMajorLayouts = true;
     const auto EnableRowMajorColumnMajorLayouts = true;
     const auto EnableColumnMajorRowMajorLayouts = true;
@@ -874,6 +894,7 @@
         ThreadblockShape,
         WarpShape,
         InstructionShape,
+        Operator,
         EnableRowMajorRowMajorLayouts,
         EnableRowMajorColumnMajorLayouts,
         EnableColumnMajorRowMajorLayouts,
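The pattern running through all of these hunks: two_four_sgemm previously hardcoded using Operator = cutlass::arch::OpMultiplyAdd internally; the commit hoists Operator into a template parameter and threads it through every dispatch layer, so the int8 branch (the GemmShape<16, 8, 64> instruction shape) can select OpMultiplyAddSaturate while the floating-point branches keep OpMultiplyAdd. A simplified sketch of the shape of that refactor (build_kernel and dispatch are illustrative names, not the real functions):

    #include <cstdint>
    #include <type_traits>
    #include <cutlass/arch/mma.h>

    // Innermost kernel builder: the MMA operator is now a caller choice.
    template <typename Element, typename Operator>
    void build_kernel() {
      // ... instantiate the CUTLASS GEMM with `Operator` as its math op ...
    }

    // Dispatch layer: pick the operator per dtype, mirroring the diff, where
    // the int8 branch sets OpMultiplyAddSaturate and the others OpMultiplyAdd.
    template <typename Element>
    void dispatch() {
      if constexpr (std::is_same_v<Element, int8_t>) {
        build_kernel<Element, cutlass::arch::OpMultiplyAddSaturate>();
      } else {
        build_kernel<Element, cutlass::arch::OpMultiplyAdd>();
      }
    }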

aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu (+18, -1)

@@ -41,6 +41,7 @@ template <
     typename ThreadblockShape,
     typename WarpShape,
     typename InstructionShape,
+    typename Operator,
     typename LayoutInputA,
     typename LayoutInputB,
     bool use_tensor_c>
@@ -57,7 +58,6 @@ void spgemm_cutlass(
     using SmArch = cutlass::arch::Sm80; // Only CC 8.x devices are supported at the moment.
     using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // This choice provides good performance across wide range of operand sizes.
     constexpr int NumStages = 3; // This choice provides good performance across wide range of operand sizes.
-    using Operator = cutlass::arch::OpMultiplyAdd;
     constexpr int NumEVTEpilogueStages = 1;

     constexpr int AlignmentInputA = 128 / cutlass::sizeof_bits<ElementInputA>::value;
@@ -305,6 +305,7 @@ template <
     typename ThreadblockShape,
     typename WarpShape,
     typename InstructionShape,
+    typename Operator,
     bool EnableRowMajorRowMajorLayouts,
     bool EnableRowMajorColumnMajorLayouts,
     bool EnableColumnMajorRowMajorLayouts,
@@ -333,6 +334,7 @@ void spgemm_cutlass_dispatch_layouts(
         ThreadblockShape,
         WarpShape,
         InstructionShape,
+        Operator,
         cutlass::layout::RowMajor,
         cutlass::layout::RowMajor,
         use_tensor_c>(
@@ -358,6 +360,7 @@
         ThreadblockShape,
         WarpShape,
         InstructionShape,
+        Operator,
         cutlass::layout::RowMajor,
         cutlass::layout::ColumnMajor,
         use_tensor_c>(
@@ -383,6 +386,7 @@
         ThreadblockShape,
         WarpShape,
         InstructionShape,
+        Operator,
         cutlass::layout::ColumnMajor,
         cutlass::layout::RowMajor,
         use_tensor_c>(
@@ -408,6 +412,7 @@
         ThreadblockShape,
         WarpShape,
         InstructionShape,
+        Operator,
         cutlass::layout::ColumnMajor,
         cutlass::layout::ColumnMajor,
         use_tensor_c>(
@@ -439,6 +444,7 @@ template <
     typename ThreadblockShape,
     typename WarpShape,
     typename InstructionShape,
+    typename Operator,
     bool EnableRowMajorRowMajorLayouts,
     bool EnableRowMajorColumnMajorLayouts,
     bool EnableColumnMajorRowMajorLayouts,
@@ -456,6 +462,7 @@ void spgemm_cutlass_dispatch_layouts_tensor_c(
         ThreadblockShape,
         WarpShape,
         InstructionShape,
+        Operator,
         EnableRowMajorRowMajorLayouts,
         EnableRowMajorColumnMajorLayouts,
         EnableColumnMajorRowMajorLayouts,
@@ -477,6 +484,7 @@
         ThreadblockShape,
         WarpShape,
         InstructionShape,
+        Operator,
         EnableRowMajorRowMajorLayouts,
         EnableRowMajorColumnMajorLayouts,
         EnableColumnMajorRowMajorLayouts,
@@ -629,6 +637,7 @@ Tensor sparse_semi_structured_mad_op(
         cutlass::gemm::GemmShape<128, 128, 128>;
     using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>;
     using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
+    using Operator = cutlass::arch::OpMultiplyAddSaturate;
     const auto EnableRowMajorRowMajorLayouts = false;
     const auto EnableRowMajorColumnMajorLayouts = true;
     const auto EnableColumnMajorRowMajorLayouts = false;
@@ -643,6 +652,7 @@
         ThreadblockShape,
         WarpShape,
         InstructionShape,
+        Operator,
         EnableRowMajorRowMajorLayouts,
         EnableRowMajorColumnMajorLayouts,
         EnableColumnMajorRowMajorLayouts,
@@ -664,6 +674,7 @@
         ThreadblockShape,
         WarpShape,
         InstructionShape,
+        Operator,
         EnableRowMajorRowMajorLayouts,
         EnableRowMajorColumnMajorLayouts,
         EnableColumnMajorRowMajorLayouts,
@@ -687,6 +698,7 @@
     using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
     using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
     using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
+    using Operator = cutlass::arch::OpMultiplyAdd;
     const auto EnableRowMajorRowMajorLayouts = true;
     const auto EnableRowMajorColumnMajorLayouts = true;
     const auto EnableColumnMajorRowMajorLayouts = true;
@@ -699,6 +711,7 @@
         ThreadblockShape,
         WarpShape,
         InstructionShape,
+        Operator,
         EnableRowMajorRowMajorLayouts,
         EnableRowMajorColumnMajorLayouts,
         EnableColumnMajorRowMajorLayouts,
@@ -721,6 +734,7 @@
     using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
     using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
     using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
+    using Operator = cutlass::arch::OpMultiplyAdd;
     const auto EnableRowMajorRowMajorLayouts = true;
     const auto EnableRowMajorColumnMajorLayouts = true;
     const auto EnableColumnMajorRowMajorLayouts = true;
@@ -733,6 +747,7 @@
         ThreadblockShape,
         WarpShape,
         InstructionShape,
+        Operator,
         EnableRowMajorRowMajorLayouts,
         EnableRowMajorColumnMajorLayouts,
         EnableColumnMajorRowMajorLayouts,
@@ -755,6 +770,7 @@
     using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>;
     using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>;
     using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
+    using Operator = cutlass::arch::OpMultiplyAdd;
     const auto EnableRowMajorRowMajorLayouts = true;
     const auto EnableRowMajorColumnMajorLayouts = true;
     const auto EnableColumnMajorRowMajorLayouts = true;
@@ -767,6 +783,7 @@
         ThreadblockShape,
         WarpShape,
         InstructionShape,
+        Operator,
         EnableRowMajorRowMajorLayouts,
         EnableRowMajorColumnMajorLayouts,
         EnableColumnMajorRowMajorLayouts,
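This file gets the identical treatment for spgemm_cutlass and its dispatch helpers. For intuition on why the integer path wants the saturating operator: converting an out-of-range int32 accumulator straight to a narrow integer type wraps, whereas a saturating conversion clamps to the representable range, which is presumably what OpMultiplyAddSaturate requests from the tensor-core MMA. A plain-C++ illustration of the difference (not CUTLASS code):

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // Clamp an int32 accumulator into int8 range instead of wrapping.
    int8_t saturate_to_i8(int32_t acc) {
      return static_cast<int8_t>(std::clamp(acc, -128, 127));
    }

    int main() {
      int32_t acc = 300;  // e.g. a dot product that overflowed int8 range
      printf("wrapping cast: %d\n", static_cast<int8_t>(acc));  // 44 (300 mod 256)
      printf("saturating:    %d\n", saturate_to_i8(acc));       // 127
      return 0;
    }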
