[CUDA][CUTLASS][submodule] Fixes for CUTLASS upgrade #131493
Conversation
LGTM - as far as changes in …
@@ -141,13 +141,13 @@ void f8f8bf16_rowwise_impl(
      cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;

  using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
      PONG ? 2 : 1,
does this have perf impact?
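For context, a rough sketch of what the updated broadcast node looks like: the first template argument is the number of broadcast pipeline stages, and the change makes it depend on the kernel schedule (2 for pingpong, 1 otherwise) instead of being hard-coded. Names such as TileShape and ElementComputeEpilogue are placeholders for whatever the surrounding kernel defines, and the exact parameter list may differ between CUTLASS versions:

    // Sketch only: row-wise scale broadcast used in the epilogue visitor tree.
    using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
        PONG ? 2 : 1,            // broadcast pipeline stages: 2 for pingpong, 1 otherwise
        TileShape,               // CTA tile shape (placeholder name)
        ElementComputeEpilogue,  // element type of the scales (placeholder name)
        ElementComputeEpilogue,
        cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;  // broadcast a row vector

Any perf effect would presumably come from the extra shared-memory staging the pingpong schedule now gets for the scale loads.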
I had this PR: #131687, hopefully it will be able to autoclose after this lands :)
@eqy Unfortunately, it didn't fix the build issues.
I'll see if I can reproduce the build issues on 11.8/12.1, which I haven't tried locally yet...
Update: this is because …
Oof, that Windows failure doesn't look so fun.
@eqy Sigh, seems like the CUTLASS issue is still open: NVIDIA/cutlass#1571
@eqy The problematic kernels seem to be copy-pasted from xformers, and xformers has already updated to CUTLASS 3.5.0 - are there any diffs from there that could be useful to fix the error?
@pytorchmergebot merge
This PR updates submodules third_party/cutlass. If those updates are intentional, please add the "submodule" keyword to the PR title/description.
@pytorchmergebot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Unblocks/unbreaks against newer CUTLASS (3.5+)
CC @nWEIdia @xwang233 @ptrblck @thakkarV
Pull Request resolved: pytorch#131493
Approved by: https://github.com/Skylion007
This reverts commit 4aa66f6. Reverted #131493 on behalf of https://github.com/izaitsevfb because it breaks internal builds with: identifier "std::numeric_limits< ::cutlass::half_t>::infinity" is undefined in device code ([comment](#131493 (comment)))
@eqy Anything we need for CUTLASS 3.6?
@pytorchmergebot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased from 137ceb9 to eb63912.
Looks like some annoying missing …
@@ -174,7 +174,7 @@ void f8f8bf16_rowwise_impl(

  // Implement rowwise scaling epilogue.
  constexpr int ColBroadcastStages = 0;
  constexpr int RowBroadcastStages = PingPong::value ? 2 : 1;
  constexpr int RowBroadcastStages = 0;
NOTE: This change is required to compile, but I get "misaligned address" errors after that when I use CUTLASS 3.6 and the PingPong kernels. I found a workaround by inverting the order of the scales, e.g.:
using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT<Cast,
cutlass::epilogue::fusion::Sm90EVT<Add, Bias,
cutlass::epilogue::fusion::Sm90EVT<Multiply, WScale,
cutlass::epilogue::fusion::Sm90EVT<Multiply, XScale,
Accum>>>>;
(This also requires changing how the arguments are supplied to the epilogue.)
cc @drisspg
@@ -9,16 +9,6 @@
 // sparsification, as a bitmask.
 // NOTE: Algorithms might select LESS than 8 values in total in some cases.

 namespace platform {
@eqy If the improved 9.5.1 fixes for CUTLASS are blocking the changes, can we at least merge the earlier PR that was passing tests and still upgrades CUTLASS a little bit? Until we can have the warnings fixed in CUTLASS, of course.
@@ -264,13 +270,32 @@ void f8f8bf16_rowwise_impl(
        stride_b},
       {{{{bias.has_value() ? reinterpret_cast<DtypeBias*>(bias->data_ptr())
               : nullptr},
         {{reinterpret_cast<DtypeScale*>(x_scale.data_ptr())},
          {{reinterpret_cast<DtypeScale*>(w_scale.data_ptr())}}}}},
         {{reinterpret_cast<DtypeScale*>(w_scale.data_ptr())},
These are flipped?
Yes, that's because the order of the epilogues got flipped (see just above) to work around a CUTLASS bug in this new version.
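To spell the dependency out, here is a tiny self-contained toy (not CUTLASS code, just an illustration): an EVT-style node's argument struct nests its children's arguments in the same order as its type parameters, so once WScale moves to the outer Multiply, the w_scale pointer has to move up front in the argument list too.

    // Toy illustration: nested argument aggregates mirror the nested types.
    #include <cstdio>

    struct Accum  { struct Arguments {}; };
    struct XScale { struct Arguments { const float* ptr; }; };
    struct WScale { struct Arguments { const float* ptr; }; };

    // An EVT-like node: Arguments nests the children's Arguments in type order.
    template <class Child0, class Child1>
    struct Multiply {
      struct Arguments {
        typename Child0::Arguments child0;
        typename Child1::Arguments child1;
      };
    };

    int main() {
      float x_scale = 2.0f, w_scale = 3.0f;

      // Old nesting: Multiply<XScale, Multiply<WScale, Accum>> -> x_scale first.
      using OldTree = Multiply<XScale, Multiply<WScale, Accum>>;
      OldTree::Arguments old_args{{&x_scale}, {{&w_scale}, {}}};

      // New nesting: Multiply<WScale, Multiply<XScale, Accum>> -> w_scale first.
      using NewTree = Multiply<WScale, Multiply<XScale, Accum>>;
      NewTree::Arguments new_args{{&w_scale}, {{&x_scale}, {}}};

      std::printf("old outer scale = %f, new outer scale = %f\n",
                  *old_args.child0.ptr, *new_args.child0.ptr);
      return 0;
    }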
Redundant given recent CUTLASS merges.
Unblocks/unbreaks against newer CUTLASS (3.5+)
CC @nWEIdia @xwang233 @ptrblck @thakkarV
cc @alexsamardzic @nikitaved @pearu @cpuhrsch @amjames @bhosmer @jcaip @ptrblck @msaroufim @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @desertfire @chauhang