[ATen] Fix CUDA reduction warp shuffle order #164790
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164790
Note: links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: there is 1 currently active SEV. If your PR is affected, please view it below.
⏳ No Failures, 1 Pending as of commit bece7bb with merge base 602ace5.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
aten/src/ATen/native/cuda/Reduce.cuh (Outdated)
```diff
  __syncthreads();

- for (int offset = 1; offset < dim_x; offset <<= 1) {
+ for (int offset = dim_x / 2; offset > 0; offset >>= 1) {
```
old: dim_x = 3
-> offset takes 1, 2
(loop exits once 2 << 1 = 4 > dim_x)
new: dim_x = 3
-> offset = 3 / 2 = 1
(loop exits once 1 >> 1 == 0)
Naively, the values offset takes seem to change between old and new.
Typical warp shuffle reduction has the following pattern:

<img width="1138" height="501" alt="image" src="https://github.com/user-attachments/assets/3bd176dc-0ad2-4df6-90c7-06e467337166" />

which is exhibited in Triton generated by torch.compile:

<img width="663" height="403" alt="image" src="https://github.com/user-attachments/assets/7f9f36cd-b9eb-44c1-879e-b469668a2ea8" />

Switch the warp shuffle order to make bitwise equivalence between the two easier.
aten/src/ATen/native/cuda/Reduce.cuh (Outdated)
```cpp
arg_t other = ops.warp_shfl_down(value[i], offset);
value[i] = ops.combine(value[i], other);
// Only combine if the source thread (threadIdx.x + offset) is within bounds
if (threadIdx.x + offset < dim_x) {
```
this would likely break unrolling, you'd need to check the generated PTX
Typical warp shuffle reduction has the following pattern:

<img width="1138" height="501" alt="image" src="https://github.com/user-attachments/assets/3bd176dc-0ad2-4df6-90c7-06e467337166" />

which is exhibited in Triton generated by torch.compile:

<img width="663" height="403" alt="image" src="https://github.com/user-attachments/assets/7f9f36cd-b9eb-44c1-879e-b469668a2ea8" />

Switch the warp shuffle order to make bitwise equivalence between the two easier. In the PTX difference between old and new we see a few extra instructions: https://www.diffchecker.com/h6ly3INC/

Comparing the performance on different reduction operations, we see minimal differences. "New" represents the changes in this PR, "Old" represents the past warp shuffle order:

```
Tensor Shape        Operation  New all dims (ms)  New dim=0 (ms)  New dim=1 (ms)  Old all dims (ms)  Old dim=0 (ms)  Old dim=1 (ms)
------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)        mean    0.015817  0.016259  0.013642  0.015990  0.016258  0.013631
(1024, 1024)        sum     0.015917  0.015906  0.013359  0.015707  0.016266  0.013226
(1024, 1024)        min     0.016021  0.024625  0.015631  0.015761  0.024485  0.015317
(1024, 1024)        max     0.016349  0.024971  0.015972  0.015771  0.025001  0.015314
(1024, 1024)        argmin  0.018070  0.024448  0.015578  0.018135  0.025370  0.015322
(1024, 1024)        argmax  0.018427  0.024859  0.015932  0.018164  0.024452  0.015639
(1024, 1024)        var     0.020078  0.026413  0.020295  0.020199  0.026381  0.020214

(2048, 2048)        mean    0.023826  0.023726  0.022273  0.023236  0.023776  0.022248
(2048, 2048)        sum     0.023840  0.023355  0.021974  0.023294  0.023354  0.021884
(2048, 2048)        min     0.024519  0.041263  0.024620  0.023292  0.041491  0.024358
(2048, 2048)        max     0.024509  0.041670  0.024277  0.023334  0.041231  0.024395
(2048, 2048)        argmin  0.026125  0.041282  0.024567  0.026772  0.041773  0.024296
(2048, 2048)        argmax  0.026117  0.041487  0.024572  0.026412  0.041477  0.024273
(2048, 2048)        var     0.026603  0.048581  0.031308  0.027587  0.048603  0.030860

(4096, 4096)        mean    0.053927  0.057070  0.054073  0.053028  0.057544  0.053935
(4096, 4096)        sum     0.053604  0.057410  0.054451  0.053076  0.057033  0.054266
(4096, 4096)        min     0.054293  0.109122  0.058363  0.053821  0.108689  0.058382
(4096, 4096)        max     0.054258  0.108035  0.058703  0.053492  0.110552  0.058376
(4096, 4096)        argmin  0.056805  0.111167  0.058301  0.056836  0.112325  0.058292
(4096, 4096)        argmax  0.056488  0.110958  0.058636  0.056844  0.111000  0.057928
(4096, 4096)        var     0.058936  0.141755  0.068693  0.059735  0.141284  0.068500

(8192, 8192)        mean    0.145552  0.148082  0.138647  0.145364  0.147818  0.138207
(8192, 8192)        sum     0.145985  0.147900  0.138714  0.145755  0.148031  0.138616
(8192, 8192)        min     0.146566  0.205359  0.192739  0.145611  0.205237  0.182335
(8192, 8192)        max     0.146526  0.204844  0.193050  0.146073  0.205457  0.182697
(8192, 8192)        argmin  0.150190  0.206605  0.192543  0.150654  0.206847  0.182007
(8192, 8192)        argmax  0.150481  0.206368  0.192535  0.150845  0.206430  0.182022
(8192, 8192)        var     0.150884  0.184546  0.203900  0.151594  0.184172  0.197983

(1, 1024, 128)      mean    0.014293  0.008119  0.014533  0.013861  0.008022  0.014449
(1, 1024, 128)      sum     0.014039  0.007877  0.014111  0.014219  0.008227  0.014045
(1, 1024, 128)      min     0.014159  0.011354  0.023493  0.014271  0.010862  0.023644
(1, 1024, 128)      max     0.014154  0.011027  0.023368  0.014259  0.011234  0.023692
(1, 1024, 128)      argmin  0.016403  0.005677  0.023328  0.016273  0.005683  0.024073
(1, 1024, 128)      argmax  0.016734  0.005675  0.023437  0.016580  0.005318  0.023331
(1, 1024, 128)      var     0.018338  0.009549  0.025538  0.018528  0.009391  0.024777

(5, 1024, 128)      mean    0.014873  0.010131  0.015546  0.015123  0.010131  0.015481
(5, 1024, 128)      sum     0.015334  0.009673  0.015824  0.014736  0.009671  0.015438
(5, 1024, 128)      min     0.015047  0.013252  0.024573  0.014803  0.013163  0.024551
(5, 1024, 128)      max     0.015050  0.013339  0.024197  0.014810  0.013525  0.024230
(5, 1024, 128)      argmin  0.017341  0.012737  0.024306  0.017471  0.012379  0.024991
(5, 1024, 128)      argmax  0.017345  0.012411  0.024421  0.017422  0.012471  0.024237
(5, 1024, 128)      var     0.019973  0.011453  0.026188  0.020050  0.011438  0.026282

(10, 1024, 128)     mean    0.016976  0.011575  0.016831  0.016722  0.011927  0.017173
(10, 1024, 128)     sum     0.017039  0.011841  0.017159  0.016385  0.011860  0.016753
(10, 1024, 128)     min     0.017036  0.015331  0.026770  0.016944  0.015205  0.027166
(10, 1024, 128)     max     0.017369  0.015348  0.027077  0.016531  0.015716  0.026819
(10, 1024, 128)     argmin  0.019203  0.014447  0.026813  0.018994  0.014497  0.027313
(10, 1024, 128)     argmax  0.019563  0.014795  0.027140  0.019460  0.014912  0.026733
(10, 1024, 128)     var     0.020529  0.014316  0.030405  0.020719  0.013960  0.029964

(100, 1024, 128)    mean    0.045046  0.039168  0.046082  0.044839  0.039217  0.045782
(100, 1024, 128)    sum     0.045094  0.039150  0.045777  0.044496  0.039542  0.046083
(100, 1024, 128)    min     0.045768  0.054466  0.076244  0.044915  0.053943  0.076599
(100, 1024, 128)    max     0.045748  0.054459  0.076188  0.044931  0.053949  0.076856
(100, 1024, 128)    argmin  0.048275  0.054046  0.076647  0.048694  0.054105  0.077004
(100, 1024, 128)    argmax  0.048267  0.054395  0.077401  0.048691  0.054131  0.076751
(100, 1024, 128)    var     0.049710  0.043254  0.083077  0.050971  0.043251  0.082378

(1000, 1000, 100)   mean    0.202312  0.196723  0.197765  0.201774  0.196641  0.197459
(1000, 1000, 100)   sum     0.202651  0.196682  0.197736  0.202175  0.196313  0.197523
(1000, 1000, 100)   min     0.203022  0.264762  0.269200  0.202729  0.264129  0.268694
(1000, 1000, 100)   max     0.202864  0.264396  0.269388  0.202486  0.263896  0.268720
(1000, 1000, 100)   argmin  0.226727  0.263781  0.268651  0.226597  0.264676  0.268983
(1000, 1000, 100)   argmax  0.226412  0.264469  0.269090  0.226570  0.264595  0.269178
(1000, 1000, 100)   var     0.243223  0.204079  0.216096  0.241942  0.204079  0.215925

(10000, 100)        mean    0.016193  0.020277  0.014316  0.016152  0.020324  0.013712
(10000, 100)        sum     0.016289  0.020237  0.014034  0.016168  0.020265  0.013708
(10000, 100)        min     0.016046  0.030872  0.019609  0.016208  0.030867  0.018627
(10000, 100)        max     0.016369  0.030835  0.019257  0.016218  0.030861  0.018209
(10000, 100)        argmin  0.017957  0.031171  0.019517  0.018050  0.031556  0.018077
(10000, 100)        argmax  0.017961  0.031658  0.019521  0.018060  0.031564  0.018087
(10000, 100)        var     0.020393  0.035652  0.019339  0.020144  0.035987  0.019171

(100000, 10)        mean    0.015718  0.016576  0.016555  0.015999  0.016246  0.014869
(100000, 10)        sum     0.015833  0.016247  0.016572  0.016007  0.016627  0.014872
(100000, 10)        min     0.015888  0.020510  0.023920  0.015671  0.020821  0.021417
(100000, 10)        max     0.015889  0.020479  0.023918  0.016077  0.020386  0.021421
(100000, 10)        argmin  0.018233  0.020863  0.023647  0.017574  0.020864  0.021103
(100000, 10)        argmax  0.017896  0.020527  0.023296  0.017569  0.020447  0.021098
(100000, 10)        var     0.020005  0.024198  0.024372  0.020075  0.024167  0.022415

(1023, 1023, 1023)  mean    1.874816  1.963506  1.903909  1.873279  1.963859  1.903230
(1023, 1023, 1023)  sum     1.875030  1.965716  1.902458  1.873566  1.960730  1.901642
(1023, 1023, 1023)  min     1.878563  2.473455  2.179092  1.875174  2.482086  2.183027
(1023, 1023, 1023)  max     1.879128  2.474803  2.178895  1.874831  2.482253  2.183884
(1023, 1023, 1023)  argmin  1.921800  2.476629  2.174831  1.923987  2.472641  2.170453
(1023, 1023, 1023)  argmax  1.922605  2.476688  2.177927  1.923366  2.472808  2.172979
(1023, 1023, 1023)  var     1.972606  3.088695  2.758797  1.978679  3.095658  2.762243

(1023, 1023, 255)   mean    0.489984  0.500954  0.492957  0.489891  0.500654  0.491971
(1023, 1023, 255)   sum     0.490228  0.500764  0.492289  0.489624  0.501089  0.492824
(1023, 1023, 255)   min     0.491457  0.563560  0.553334  0.490355  0.564709  0.554754
(1023, 1023, 255)   max     0.491396  0.563628  0.553345  0.490017  0.565004  0.554947
(1023, 1023, 255)   argmin  0.503666  0.561512  0.551831  0.503845  0.560972  0.551017
(1023, 1023, 255)   argmax  0.503602  0.561185  0.551407  0.504328  0.561267  0.551448
(1023, 1023, 255)   var     0.510844  0.709452  0.701630  0.512693  0.710365  0.701965

(1023, 1023, 377)   mean    0.707439  0.727646  0.712019  0.706769  0.727101  0.711632
(1023, 1023, 377)   sum     0.707780  0.727453  0.711554  0.706807  0.726656  0.711729
(1023, 1023, 377)   min     0.709423  0.819809  0.794379  0.707847  0.822086  0.796664
(1023, 1023, 377)   max     0.709297  0.819780  0.794308  0.707566  0.821913  0.796690
(1023, 1023, 377)   argmin  0.725028  0.817088  0.791695  0.726039  0.816445  0.790828
(1023, 1023, 377)   argmax  0.725301  0.817011  0.791420  0.726040  0.816917  0.791143
(1023, 1023, 377)   var     0.740859  1.034165  1.006712  0.743413  1.035506  1.007638
```
@eqy made the relevant changes to ensure consistency + added some benchmarking/PTX comparisons, if you can review again!
aten/src/ATen/native/cuda/Reduce.cuh (Outdated)
```cpp
// Warp-level reduction for remaining threads
// For non-power-of-2 sizes, we start from the next power-of-2 divided by 2
// and use a boundary check to avoid out-of-bounds access
for (size_t offset = warpSize / 2; offset > 0; offset >>= 1) {
```
This now looks like it potentially has an extra iteration that is wasted (OK sure, if blockIdx.x isn't expected to be < 32 then it makes no difference).
dim_x = 15, before: 1 2 4 8
now: 16 8 4 2 1
Yeah, a previous commit had it so that we were calculating the next power of two each time, but this is cleaner in most situations. Happy to change in a follow-up if you prefer.
Why not do this instead (and remove the `if` below)?
`offset = min(warpSize / 2, dim_x - 1);`
This will make us go out of bounds and have bad values?
@pytorchbot merge
```cpp
for (int offset = 1; offset < dim_x; offset <<= 1) {
#else
int offset = warpSize / 2;
while (offset >= dim_x)
```
dim_x is guaranteed to be
- a power of 2
- <= warpSize (due to the branch above)

so

`int offset = dim_x >> 1;`

should do the job? You can add a check at launch time that dim_x is a power of 2, but this loop would already break if it weren't.
Ah yes, that's right! As `block_width = std::min(dim0_pow2, int(at::cuda::warp_size()))`. Thanks!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Benchmarks for a full reduction + a reduction on the contiguous dimension. Vectorized loads do not occur on the non-contiguous dimension. Benchmarking was done for FP16/BF16: ~6% improvement on average across shapes, up to ~24% for a single reduction on the contiguous dimension and ~46% for a full reduce. "Old" is the baseline, "New" includes this change; diff % is old/new - 1, so positive means a speedup:

**BF16**
```
Tensor Shape     Operation  Old full reduce (ms)  Old contiguous dim (ms)  New full reduce (ms)  New contiguous dim (ms)  Full reduce diff %  Contiguous diff %
----------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256)       mean  0.022686  0.008263  0.015498  0.008117  +46.38%  +1.80%
(256, 256)       sum   0.022769  0.008269  0.015628  0.008185  +45.69%  +1.03%
(512, 512)       mean  0.014116  0.009545  0.012892  0.008839  +9.49%   +7.99%
(512, 512)       sum   0.014110  0.009892  0.012891  0.008878  +9.46%   +11.42%
(1024, 1024)     mean  0.014727  0.012642  0.014061  0.010519  +4.74%   +20.18%
(1024, 1024)     sum   0.014376  0.012636  0.014069  0.010595  +2.18%   +19.26%
(2048, 2048)     mean  0.018663  0.018294  0.018171  0.014678  +2.71%   +24.64%
(2048, 2048)     sum   0.018638  0.017931  0.018142  0.014713  +2.73%   +21.87%
(4096, 4096)     mean  0.034216  0.036953  0.033520  0.030585  +2.08%   +20.82%
(4096, 4096)     sum   0.034196  0.036942  0.033518  0.030676  +2.02%   +20.43%
(8192, 8192)     mean  0.087763  0.095201  0.085439  0.084960  +2.72%   +12.05%
(8192, 8192)     sum   0.088079  0.095592  0.085353  0.084632  +3.19%   +12.95%
(8192, 16384)    mean  0.148174  0.149705  0.146274  0.138865  +1.30%   +7.81%
(8192, 16384)    sum   0.147820  0.149371  0.146419  0.138752  +0.96%   +7.65%
(8192, 32768)    mean  0.266144  0.260807  0.265953  0.253330  +0.07%   +2.95%
(8192, 32768)    sum   0.266572  0.261163  0.265729  0.253294  +0.32%   +3.11%
(8192, 65536)    mean  0.502034  0.486312  0.498417  0.481246  +0.73%   +1.05%
(8192, 65536)    sum   0.501597  0.486351  0.497735  0.481579  +0.78%   +0.99%
(8192, 131072)   mean  0.971178  0.942988  0.957164  0.938316  +1.46%   +0.50%
(8192, 131072)   sum   0.971189  0.943232  0.956814  0.937816  +1.50%   +0.58%
(8192, 262144)   mean  1.953728  1.877648  1.904937  1.861692  +2.56%   +0.86%
(8192, 262144)   sum   1.953969  1.877538  1.905990  1.862547  +2.52%   +0.80%
(4096, 262144)   mean  0.970408  0.940965  0.957871  0.936732  +1.31%   +0.45%
(4096, 262144)   sum   0.970919  0.941652  0.957765  0.936676  +1.37%   +0.53%
(2048, 262144)   mean  0.501477  0.486976  0.497964  0.483570  +0.71%   +0.70%
(2048, 262144)   sum   0.501955  0.487213  0.498210  0.483218  +0.75%   +0.83%
(1024, 262144)   mean  0.266536  0.257111  0.265642  0.255439  +0.34%   +0.65%
(1024, 262144)   sum   0.266613  0.257096  0.265427  0.255472  +0.45%   +0.64%
(512, 131072)    mean  0.087805  0.091200  0.085818  0.087851  +2.32%   +3.81%
(512, 131072)    sum   0.087788  0.091249  0.085373  0.087944  +2.83%   +3.76%
(1000, 1000)     mean  0.014503  0.012328  0.013663  0.010190  +6.15%   +20.98%
(1000, 1000)     sum   0.014545  0.012378  0.013662  0.010579  +6.46%   +17.01%
(1024, 129)      mean  0.014163  0.008371  0.012893  0.008828  +9.85%   -5.18%
(1024, 129)      sum   0.014132  0.008751  0.013234  0.008868  +6.79%   -1.32%
(1024, 257)      mean  0.014296  0.009101  0.013334  0.008563  +7.21%   +6.28%
(1024, 257)      sum   0.014302  0.009058  0.013020  0.008672  +9.85%   +4.45%
(1024, 587)      mean  0.014127  0.010997  0.013443  0.009944  +5.09%   +10.59%
(1024, 587)      sum   0.014471  0.011373  0.013123  0.010354  +10.27%  +9.84%
(2048, 977)      mean  0.015607  0.013566  0.015089  0.012152  +3.43%   +11.64%
(2048, 977)      sum   0.015953  0.013580  0.015039  0.011861  +6.08%   +14.49%
(1024, 128)      mean  0.013982  0.008058  0.012747  0.008139  +9.69%   -1.00%
(1024, 128)      sum   0.013967  0.008071  0.012726  0.007859  +9.75%   +2.70%
(8192, 128)      mean  0.014378  0.009627  0.013712  0.009395  +4.86%   +2.47%
(8192, 128)      sum   0.014389  0.009965  0.013718  0.009521  +4.89%   +4.66%
(1024, 130)      mean  0.014156  0.008267  0.012895  0.008833  +9.78%   -6.41%
(1024, 130)      sum   0.013797  0.008277  0.012903  0.008512  +6.93%   -2.76%
(8192, 130)      mean  0.014977  0.010026  0.013911  0.009876  +7.66%   +1.52%
(8192, 130)      sum   0.014994  0.010043  0.014235  0.009604  +5.33%   +4.57%
```

**FP16**
```
Tensor Shape     Operation  Old full reduce (ms)  Old contiguous dim (ms)  New full reduce (ms)  New contiguous dim (ms)  Full reduce diff %  Contiguous diff %
----------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256)       mean  0.022804  0.008298  0.015888  0.007848  +43.53%  +5.73%
(256, 256)       sum   0.023215  0.008328  0.015677  0.007850  +48.08%  +6.09%
(512, 512)       mean  0.013777  0.009988  0.012884  0.008512  +6.93%   +17.34%
(512, 512)       sum   0.013775  0.009622  0.012870  0.009028  +7.03%   +6.58%
(1024, 1024)     mean  0.014740  0.012322  0.013708  0.010239  +7.53%   +20.34%
(1024, 1024)     sum   0.014762  0.012756  0.013722  0.010307  +7.58%   +23.76%
(2048, 2048)     mean  0.018700  0.018364  0.018135  0.015078  +3.12%   +21.79%
(2048, 2048)     sum   0.018276  0.018415  0.018471  0.015127  -1.06%   +21.74%
```
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 4096) mean 0.034518 0.037000 0.033838 0.030617 +2.01% +20.85% (4096, 4096) sum 0.034569 0.037448 0.033842 0.031100 +2.15% +20.41% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 8192) mean 0.087675 0.095176 0.085328 0.084105 +2.75% +13.16% (8192, 8192) sum 0.088102 0.095211 0.085707 0.084090 +2.79% +13.23% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 16384) mean 0.147800 0.149263 0.146388 0.138390 +0.96% +7.86% (8192, 16384) sum 0.148147 0.148957 0.146439 0.138801 +1.17% +7.32% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 32768) mean 0.266316 0.260294 0.265829 0.253411 +0.18% +2.72% (8192, 32768) sum 0.266562 0.260717 0.265744 0.253308 +0.31% +2.92% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 65536) mean 0.502035 0.486077 0.498139 0.481374 +0.78% +0.98% (8192, 65536) sum 0.501571 0.485733 0.498353 0.481350 +0.65% +0.91% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 131072) mean 0.971343 0.943016 0.956600 0.938622 +1.54% +0.47% (8192, 131072) sum 0.971463 0.942991 0.957352 0.938334 +1.47% +0.50% 
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 262144) mean 1.952722 1.877165 1.906406 1.861455 +2.43% +0.84% (8192, 262144) sum 1.952634 1.876388 1.904677 1.861282 +2.52% +0.81% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 262144) mean 0.970697 0.941298 0.956964 0.936160 +1.44% +0.55% (4096, 262144) sum 0.969981 0.941078 0.957016 0.936260 +1.35% +0.51% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 262144) mean 0.501577 0.487208 0.498422 0.483493 +0.63% +0.77% (2048, 262144) sum 0.502029 0.487124 0.497854 0.483643 +0.84% +0.72% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 262144) mean 0.266416 0.257383 0.265928 0.255140 +0.18% +0.88% (1024, 262144) sum 0.266434 0.257081 0.265817 0.255143 +0.23% +0.76% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 131072) mean 0.087858 0.091296 0.085816 0.087745 +2.38% +4.05% (512, 131072) sum 0.088144 0.091314 0.085664 0.087864 +2.90% +3.93% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1000, 1000) mean 0.014977 0.012393 0.014141 0.010614 +5.91% +16.76% (1000, 1000) sum 0.014589 0.012804 0.014118 0.010320 +3.34% +24.07% 
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 129) mean 0.014208 0.008383 0.013273 0.008440 +7.04% -0.68% (1024, 129) sum 0.013804 0.008863 0.013265 0.009003 +4.06% -1.56% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 257) mean 0.014378 0.009109 0.013037 0.009038 +10.29% +0.79% (1024, 257) sum 0.014387 0.009113 0.013396 0.008698 +7.40% +4.77% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 587) mean 0.014207 0.011037 0.013182 0.010391 +7.78% +6.22% (1024, 587) sum 0.014588 0.011453 0.013539 0.010049 +7.75% +13.97% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 977) mean 0.016024 0.013614 0.015448 0.011845 +3.73% +14.93% (2048, 977) sum 0.015990 0.014033 0.015406 0.012278 +3.79% +14.29% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 128) mean 0.014037 0.007804 0.013143 0.008242 +6.80% -5.31% (1024, 128) sum 0.014041 0.007847 0.012759 0.007850 +10.05% -0.04% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 128) mean 0.014361 0.009644 0.014075 0.009061 +2.03% +6.43% (8192, 128) sum 0.014366 0.010032 0.013702 0.009181 +4.85% +9.27% 
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 130) mean 0.014226 0.008696 0.012894 0.008835 +10.33% -1.57% (1024, 130) sum 0.013830 0.008740 0.013288 0.008989 +4.08% -2.77% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 130) mean 0.015036 0.010019 0.013917 0.009538 +8.04% +5.04% (8192, 130) sum 0.014652 0.010403 0.013900 0.009565 +5.41% +8.76% ==================================================================================================================================================================================== ``` Pull Request resolved: #165055 Approved by: https://github.com/ngimel ghstack dependencies: #165494, #164790
Performance benchmarking, perf neutral:

```
======================================================================
Tensor Shape     Operation  Full reduce (ms)  Non-Contig dim (ms)  Contig dim (ms)  Full reduce (ms)  Non-Contig dim (ms)  Contig dim (ms)  Full diff %  Non-Contig diff %  Contig diff %
----------------------------------------------------------------------
(256, 256)       mean  0.015684  0.017056  0.008287  0.016015  0.016929  0.008170  -2.07%  +0.75%  +1.43%
(256, 256)       sum   0.015774  0.016638  0.007926  0.015811  0.016935  0.008330  -0.23%  -1.75%  -4.85%
----------------------------------------------------------------------
(512, 512)       mean  0.013385  0.025742  0.008629  0.013046  0.026005  0.008924  +2.60%  -1.01%  -3.31%
(512, 512)       sum   0.013390  0.026059  0.009116  0.013054  0.025696  0.008952  +2.57%  +1.41%  +1.83%
----------------------------------------------------------------------
(1024, 1024)     mean  0.014213  0.015467  0.010334  0.013862  0.015082  0.010318  +2.53%  +2.55%  +0.16%
(1024, 1024)     sum   0.014179  0.015446  0.010774  0.014132  0.015073  0.010350  +0.33%  +2.47%  +4.10%
----------------------------------------------------------------------
(2048, 2048)     mean  0.018234  0.019487  0.014812  0.018482  0.019397  0.014802  -1.34%  +0.46%  +0.07%
(2048, 2048)     sum   0.018202  0.019529  0.015195  0.018122  0.019485  0.015129  +0.44%  +0.23%  +0.44%
----------------------------------------------------------------------
(4096, 4096)     mean  0.033582  0.039378  0.030751  0.033810  0.039673  0.031019  -0.67%  -0.74%  -0.86%
(4096, 4096)     sum   0.033604  0.039777  0.030809  0.033530  0.039386  0.031113  +0.22%  +0.99%  -0.98%
----------------------------------------------------------------------
(8192, 8192)     mean  0.085824  0.091133  0.084200  0.085431  0.091364  0.084303  +0.46%  -0.25%  -0.12%
(8192, 8192)     sum   0.085763  0.091442  0.084180  0.085508  0.091419  0.084595  +0.30%  +0.03%  -0.49%
----------------------------------------------------------------------
(8192, 16384)    mean  0.146480  0.147666  0.138807  0.146515  0.147987  0.138930  -0.02%  -0.22%  -0.09%
(8192, 16384)    sum   0.146446  0.147593  0.138559  0.146151  0.147982  0.139120  +0.20%  -0.26%  -0.40%
----------------------------------------------------------------------
(8192, 32768)    mean  0.266047  0.265386  0.253837  0.265648  0.265885  0.253652  +0.15%  -0.19%  +0.07%
(8192, 32768)    sum   0.266093  0.265421  0.253890  0.265458  0.265591  0.253567  +0.24%  -0.06%  +0.13%
----------------------------------------------------------------------
(8192, 65536)    mean  0.498632  0.508976  0.481865  0.498237  0.508777  0.481476  +0.08%  +0.04%  +0.08%
(8192, 65536)    sum   0.498917  0.508202  0.481883  0.498104  0.508016  0.481972  +0.16%  +0.04%  -0.02%
----------------------------------------------------------------------
(8192, 131072)   mean  0.957633  0.968519  0.938172  0.956766  0.968267  0.938196  +0.09%  +0.03%  -0.00%
(8192, 131072)   sum   0.956972  0.968140  0.937741  0.957365  0.968404  0.938056  -0.04%  -0.03%  -0.03%
----------------------------------------------------------------------
(8192, 262144)   mean  1.906661  1.928377  1.861846  1.907327  1.928811  1.862083  -0.03%  -0.02%  -0.01%
(8192, 262144)   sum   1.905976  1.928362  1.862399  1.907098  1.928844  1.861782  -0.06%  -0.02%  +0.03%
----------------------------------------------------------------------
(4096, 262144)   mean  0.956852  0.970101  0.936524  0.957263  0.969809  0.936965  -0.04%  +0.03%  -0.05%
(4096, 262144)   sum   0.957117  0.969933  0.936247  0.956675  0.969451  0.936395  +0.05%  +0.05%  -0.02%
----------------------------------------------------------------------
(2048, 262144)   mean  0.498813  0.511299  0.483415  0.498567  0.511482  0.483376  +0.05%  -0.04%  +0.01%
(2048, 262144)   sum   0.498813  0.510834  0.483641  0.498875  0.511036  0.483338  -0.01%  -0.04%  +0.06%
----------------------------------------------------------------------
(1024, 262144)   mean  0.266157  0.276751  0.255192  0.265966  0.276808  0.255544  +0.07%  -0.02%  -0.14%
(1024, 262144)   sum   0.266133  0.276709  0.255528  0.265658  0.276685  0.255287  +0.18%  +0.01%  +0.09%
----------------------------------------------------------------------
(512, 131072)    mean  0.085941  0.081184  0.087931  0.085591  0.080832  0.088008  +0.41%  +0.44%  -0.09%
(512, 131072)    sum   0.085962  0.081107  0.088045  0.085882  0.081160  0.088024  +0.09%  -0.07%  +0.02%
----------------------------------------------------------------------
(1000, 1000)     mean  0.014203  0.045859  0.010310  0.013885  0.046132  0.010621  +2.29%  -0.59%  -2.93%
(1000, 1000)     sum   0.014180  0.046165  0.010756  0.013893  0.046109  0.010338  +2.07%  +0.12%  +4.04%
----------------------------------------------------------------------
(1024, 129)      mean  0.012953  0.016751  0.008536  0.012977  0.016714  0.008916  -0.18%  +0.22%  -4.26%
(1024, 129)      sum   0.013356  0.016806  0.008722  0.013003  0.017071  0.008611  +2.71%  -1.55%  +1.29%
----------------------------------------------------------------------
(1024, 257)      mean  0.013075  0.016787  0.009102  0.013116  0.016769  0.008679  -0.31%  +0.11%  +4.87%
(1024, 257)      sum   0.013092  0.016842  0.008786  0.013126  0.017128  0.008771  -0.26%  -1.67%  +0.17%
----------------------------------------------------------------------
(1024, 587)      mean  0.013662  0.017412  0.010055  0.013659  0.017019  0.010033  +0.02%  +2.31%  +0.22%
(1024, 587)      sum   0.013636  0.017473  0.010163  0.013642  0.017363  0.010101  -0.04%  +0.63%  +0.61%
----------------------------------------------------------------------
(2048, 977)      mean  0.015276  0.027873  0.012531  0.015241  0.027783  0.012467  +0.23%  +0.32%  +0.51%
(2048, 977)      sum   0.015345  0.027949  0.012192  0.015255  0.027839  0.012485  +0.59%  +0.40%  -2.35%
----------------------------------------------------------------------
(1024, 128)      mean  0.012806  0.014020  0.008291  0.013137  0.014309  0.007908  -2.52%  -2.02%  +4.84%
(1024, 128)      sum   0.012769  0.014308  0.007924  0.012788  0.014236  0.008038  -0.15%  +0.51%  -1.42%
----------------------------------------------------------------------
(8192, 128)      mean  0.014145  0.023049  0.009143  0.014104  0.023298  0.009501  +0.29%  -1.07%  -3.77%
(8192, 128)      sum   0.014132  0.023082  0.009638  0.014107  0.023331  0.009244  +0.18%  -1.07%  +4.26%
----------------------------------------------------------------------
(1024, 130)      mean  0.013420  0.025834  0.008949  0.013368  0.025724  0.008918  +0.39%  +0.43%  +0.35%
(1024, 130)      sum   0.013300  0.025940  0.009113  0.013266  0.025419  0.008922  +0.26%  +2.05%  +2.14%
----------------------------------------------------------------------
(8192, 130)      mean  0.013993  0.017883  0.009661  0.014275  0.018220  0.009596  -1.98%  -1.85%  +0.68%
(8192, 130)      sum   0.014026  0.018297  0.010066  0.014326  0.018257  0.009659  -2.09%  +0.22%  +4.21%
======================================================================
```

Pull Request resolved: #165178
Approved by: https://github.com/ngimel
ghstack dependencies: #165494, #164790, #165055
@pytorchbot revert -m "was reverted due to failing internal tests after merge D84992607" -c ghfirst

@pytorchbot successfully started a revert job. Check the current status here.

@PaulZhang12 your PR has been successfully reverted.
This reverts commit 36371b8. Reverted #164790 on behalf of https://github.com/clee2000 due to was reverted due to failing internal tests after merge D84992607 ([comment](#164790 (comment)))
Typical warp shuffle reduction has the following pattern:

<img width="1138" height="501" alt="image" src="https://github.com/user-attachments/assets/3bd176dc-0ad2-4df6-90c7-06e467337166" />

which is exhibited in the Triton code generated by torch.compile:

<img width="663" height="403" alt="image" src="https://github.com/user-attachments/assets/7f9f36cd-b9eb-44c1-879e-b469668a2ea8" />

Switch the warp shuffle order to match, making bitwise equivalence between the two easier to achieve. The PTX difference between old and new shows a few extra instructions: https://www.diffchecker.com/h6ly3INC/

Comparing performance across different reduction operations, we see minimal differences. "New" represents the changes in this PR, "Old" represents the past warp shuffle order:

```
Tensor Shape        Operation  New all dims (ms)  New dim=0 (ms)  New dim=1 (ms)  Old all dims (ms)  Old dim=0 (ms)  Old dim=1 (ms)
----------------------------------------------------------------------
(1024, 1024)        mean    0.015817  0.016259  0.013642  0.015990  0.016258  0.013631
(1024, 1024)        sum     0.015917  0.015906  0.013359  0.015707  0.016266  0.013226
(1024, 1024)        min     0.016021  0.024625  0.015631  0.015761  0.024485  0.015317
(1024, 1024)        max     0.016349  0.024971  0.015972  0.015771  0.025001  0.015314
(1024, 1024)        argmin  0.018070  0.024448  0.015578  0.018135  0.025370  0.015322
(1024, 1024)        argmax  0.018427  0.024859  0.015932  0.018164  0.024452  0.015639
(1024, 1024)        var     0.020078  0.026413  0.020295  0.020199  0.026381  0.020214
----------------------------------------------------------------------
(2048, 2048)        mean    0.023826  0.023726  0.022273  0.023236  0.023776  0.022248
(2048, 2048)        sum     0.023840  0.023355  0.021974  0.023294  0.023354  0.021884
(2048, 2048)        min     0.024519  0.041263  0.024620  0.023292  0.041491  0.024358
(2048, 2048)        max     0.024509  0.041670  0.024277  0.023334  0.041231  0.024395
(2048, 2048)        argmin  0.026125  0.041282  0.024567  0.026772  0.041773  0.024296
(2048, 2048)        argmax  0.026117  0.041487  0.024572  0.026412  0.041477  0.024273
(2048, 2048)        var     0.026603  0.048581  0.031308  0.027587  0.048603  0.030860
----------------------------------------------------------------------
(4096, 4096)        mean    0.053927  0.057070  0.054073  0.053028  0.057544  0.053935
(4096, 4096)        sum     0.053604  0.057410  0.054451  0.053076  0.057033  0.054266
(4096, 4096)        min     0.054293  0.109122  0.058363  0.053821  0.108689  0.058382
(4096, 4096)        max     0.054258  0.108035  0.058703  0.053492  0.110552  0.058376
(4096, 4096)        argmin  0.056805  0.111167  0.058301  0.056836  0.112325  0.058292
(4096, 4096)        argmax  0.056488  0.110958  0.058636  0.056844  0.111000  0.057928
(4096, 4096)        var     0.058936  0.141755  0.068693  0.059735  0.141284  0.068500
----------------------------------------------------------------------
(8192, 8192)        mean    0.145552  0.148082  0.138647  0.145364  0.147818  0.138207
(8192, 8192)        sum     0.145985  0.147900  0.138714  0.145755  0.148031  0.138616
(8192, 8192)        min     0.146566  0.205359  0.192739  0.145611  0.205237  0.182335
(8192, 8192)        max     0.146526  0.204844  0.193050  0.146073  0.205457  0.182697
(8192, 8192)        argmin  0.150190  0.206605  0.192543  0.150654  0.206847  0.182007
(8192, 8192)        argmax  0.150481  0.206368  0.192535  0.150845  0.206430  0.182022
(8192, 8192)        var     0.150884  0.184546  0.203900  0.151594  0.184172  0.197983
----------------------------------------------------------------------
(1, 1024, 128)      mean    0.014293  0.008119  0.014533  0.013861  0.008022  0.014449
(1, 1024, 128)      sum     0.014039  0.007877  0.014111  0.014219  0.008227  0.014045
(1, 1024, 128)      min     0.014159  0.011354  0.023493  0.014271  0.010862  0.023644
(1, 1024, 128)      max     0.014154  0.011027  0.023368  0.014259  0.011234  0.023692
(1, 1024, 128)      argmin  0.016403  0.005677  0.023328  0.016273  0.005683  0.024073
(1, 1024, 128)      argmax  0.016734  0.005675  0.023437  0.016580  0.005318  0.023331
(1, 1024, 128)      var     0.018338  0.009549  0.025538  0.018528  0.009391  0.024777
----------------------------------------------------------------------
(5, 1024, 128)      mean    0.014873  0.010131  0.015546  0.015123  0.010131  0.015481
(5, 1024, 128)      sum     0.015334  0.009673  0.015824  0.014736  0.009671  0.015438
(5, 1024, 128)      min     0.015047  0.013252  0.024573  0.014803  0.013163  0.024551
(5, 1024, 128)      max     0.015050  0.013339  0.024197  0.014810  0.013525  0.024230
(5, 1024, 128)      argmin  0.017341  0.012737  0.024306  0.017471  0.012379  0.024991
(5, 1024, 128)      argmax  0.017345  0.012411  0.024421  0.017422  0.012471  0.024237
(5, 1024, 128)      var     0.019973  0.011453  0.026188  0.020050  0.011438  0.026282
----------------------------------------------------------------------
(10, 1024, 128)     mean    0.016976  0.011575  0.016831  0.016722  0.011927  0.017173
(10, 1024, 128)     sum     0.017039  0.011841  0.017159  0.016385  0.011860  0.016753
(10, 1024, 128)     min     0.017036  0.015331  0.026770  0.016944  0.015205  0.027166
(10, 1024, 128)     max     0.017369  0.015348  0.027077  0.016531  0.015716  0.026819
(10, 1024, 128)     argmin  0.019203  0.014447  0.026813  0.018994  0.014497  0.027313
(10, 1024, 128)     argmax  0.019563  0.014795  0.027140  0.019460  0.014912  0.026733
(10, 1024, 128)     var     0.020529  0.014316  0.030405  0.020719  0.013960  0.029964
----------------------------------------------------------------------
(100, 1024, 128)    mean    0.045046  0.039168  0.046082  0.044839  0.039217  0.045782
(100, 1024, 128)    sum     0.045094  0.039150  0.045777  0.044496  0.039542  0.046083
(100, 1024, 128)    min     0.045768  0.054466  0.076244  0.044915  0.053943  0.076599
(100, 1024, 128)    max     0.045748  0.054459  0.076188  0.044931  0.053949  0.076856
(100, 1024, 128)    argmin  0.048275  0.054046  0.076647  0.048694  0.054105  0.077004
(100, 1024, 128)    argmax  0.048267  0.054395  0.077401  0.048691  0.054131  0.076751
(100, 1024, 128)    var     0.049710  0.043254  0.083077  0.050971  0.043251  0.082378
----------------------------------------------------------------------
(1000, 1000, 100)   mean    0.202312  0.196723  0.197765  0.201774  0.196641  0.197459
(1000, 1000, 100)   sum     0.202651  0.196682  0.197736  0.202175  0.196313  0.197523
(1000, 1000, 100)   min     0.203022  0.264762  0.269200  0.202729  0.264129  0.268694
(1000, 1000, 100)   max     0.202864  0.264396  0.269388  0.202486  0.263896  0.268720
(1000, 1000, 100)   argmin  0.226727  0.263781  0.268651  0.226597  0.264676  0.268983
(1000, 1000, 100)   argmax  0.226412  0.264469  0.269090  0.226570  0.264595  0.269178
(1000, 1000, 100)   var     0.243223  0.204079  0.216096  0.241942  0.204079  0.215925
----------------------------------------------------------------------
(10000, 100)        mean    0.016193  0.020277  0.014316  0.016152  0.020324  0.013712
(10000, 100)        sum     0.016289  0.020237  0.014034  0.016168  0.020265  0.013708
(10000, 100)        min     0.016046  0.030872  0.019609  0.016208  0.030867  0.018627
(10000, 100)        max     0.016369  0.030835  0.019257  0.016218  0.030861  0.018209
(10000, 100)        argmin  0.017957  0.031171  0.019517  0.018050  0.031556  0.018077
(10000, 100)        argmax  0.017961  0.031658  0.019521  0.018060  0.031564  0.018087
(10000, 100)        var     0.020393  0.035652  0.019339  0.020144  0.035987  0.019171
----------------------------------------------------------------------
(100000, 10)        mean    0.015718  0.016576  0.016555  0.015999  0.016246  0.014869
(100000, 10)        sum     0.015833  0.016247  0.016572  0.016007  0.016627  0.014872
(100000, 10)        min     0.015888  0.020510  0.023920  0.015671  0.020821  0.021417
(100000, 10)        max     0.015889  0.020479  0.023918  0.016077  0.020386  0.021421
(100000, 10)        argmin  0.018233  0.020863  0.023647  0.017574  0.020864  0.021103
(100000, 10)        argmax  0.017896  0.020527  0.023296  0.017569  0.020447  0.021098
(100000, 10)        var     0.020005  0.024198  0.024372  0.020075  0.024167  0.022415
----------------------------------------------------------------------
(1023, 1023, 1023)  mean    1.874816  1.963506  1.903909  1.873279  1.963859  1.903230
(1023, 1023, 1023)  sum     1.875030  1.965716  1.902458  1.873566  1.960730  1.901642
(1023, 1023, 1023)  min     1.878563  2.473455  2.179092  1.875174  2.482086  2.183027
(1023, 1023, 1023)  max     1.879128  2.474803  2.178895  1.874831  2.482253  2.183884
(1023, 1023, 1023)  argmin  1.921800  2.476629  2.174831  1.923987  2.472641  2.170453
(1023, 1023, 1023)  argmax  1.922605  2.476688  2.177927  1.923366  2.472808  2.172979
(1023, 1023, 1023)  var     1.972606  3.088695  2.758797  1.978679  3.095658  2.762243
----------------------------------------------------------------------
(1023, 1023, 255)   mean    0.489984  0.500954  0.492957  0.489891  0.500654  0.491971
(1023, 1023, 255)   sum     0.490228  0.500764  0.492289  0.489624  0.501089  0.492824
(1023, 1023, 255)   min     0.491457  0.563560  0.553334  0.490355  0.564709  0.554754
(1023, 1023, 255)   max     0.491396  0.563628  0.553345  0.490017  0.565004  0.554947
(1023, 1023, 255)   argmin  0.503666  0.561512  0.551831  0.503845  0.560972  0.551017
(1023, 1023, 255)   argmax  0.503602  0.561185  0.551407  0.504328  0.561267  0.551448
(1023, 1023, 255)   var     0.510844  0.709452  0.701630  0.512693  0.710365  0.701965
----------------------------------------------------------------------
(1023, 1023, 377)   mean    0.707439  0.727646  0.712019  0.706769  0.727101  0.711632
(1023, 1023, 377)   sum     0.707780  0.727453  0.711554  0.706807  0.726656  0.711729
(1023, 1023, 377)   min     0.709423  0.819809  0.794379  0.707847  0.822086  0.796664
(1023, 1023, 377)   max     0.709297  0.819780  0.794308  0.707566  0.821913  0.796690
(1023, 1023, 377)   argmin  0.725028  0.817088  0.791695  0.726039  0.816445  0.790828
(1023, 1023, 377)   argmax  0.725301  0.817011  0.791420  0.726040  0.816917  0.791143
(1023, 1023, 377)   var     0.740859  1.034165  1.006712  0.743413  1.035506  1.007638
```

[ghstack-poisoned]
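The change between the two shuffle orders can be illustrated with a small host-side sketch. This is plain Python standing in for `__shfl_down_sync`: `shfl_down_reduce`, `width`, and `vals` are illustrative names (not part of the PR), and lanes whose shuffle source falls outside the warp simply keep their value, a simplification of the hardware behavior.

```python
def shfl_down_reduce(vals, offsets):
    """Simulate a warp tree reduction built on __shfl_down_sync:
    at each step, lane i adds in the value held by lane i + offset.
    Lanes whose source lane is out of range keep their value (simplified)."""
    v = list(vals)
    n = len(v)
    for off in offsets:
        v = [v[i] + v[i + off] if i + off < n else v[i] for i in range(n)]
    return v[0]  # lane 0 ends up holding the full reduction

width = 32
vals = [(i * 0.1 + 1.0) / 7.0 for i in range(width)]

old_order = shfl_down_reduce(vals, [1, 2, 4, 8, 16])   # offset <<= 1 (old ATen order)
new_order = shfl_down_reduce(vals, [16, 8, 4, 2, 1])   # offset >>= 1 (this PR / Triton)

# Both orders are valid sums of the same 32 lanes, so they agree to
# rounding error -- but they associate the additions differently, so
# the float results are not guaranteed to be bitwise identical.
assert abs(old_order - new_order) < 1e-12
print(old_order, new_order)
```

The descending order pairs lane `i` with lane `i + width/2` first, i.e. the same halving tree Triton emits, which is what makes bitwise matching between ATen and torch.compile reductions feasible.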
Typical warp shuffle reduction has the following pattern:

<img width="1138" height="501" alt="image" src="https://github.com/user-attachments/assets/3bd176dc-0ad2-4df6-90c7-06e467337166" />

which is exhibited in the Triton code generated by torch.compile:

<img width="663" height="403" alt="image" src="https://github.com/user-attachments/assets/7f9f36cd-b9eb-44c1-879e-b469668a2ea8" />

Switch the warp shuffle order to make bitwise equivalence between the two easier. Comparing the PTX between old and new, we see a few extra instructions: https://www.diffchecker.com/h6ly3INC/

Comparing the performance of different reduction operations, we see minimal differences. "New" represents the changes in this PR; "old" represents the previous warp shuffle order:

```
Tensor Shape         Operation  New all dims (ms)  New dim=0 (ms)  New dim=1 (ms)  Old all dims (ms)  Old dim=0 (ms)  Old dim=1 (ms)
------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)         mean       0.015817           0.016259        0.013642        0.015990           0.016258        0.013631
(1024, 1024)         sum        0.015917           0.015906        0.013359        0.015707           0.016266        0.013226
(1024, 1024)         min        0.016021           0.024625        0.015631        0.015761           0.024485        0.015317
(1024, 1024)         max        0.016349           0.024971        0.015972        0.015771           0.025001        0.015314
(1024, 1024)         argmin     0.018070           0.024448        0.015578        0.018135           0.025370        0.015322
(1024, 1024)         argmax     0.018427           0.024859        0.015932        0.018164           0.024452        0.015639
(1024, 1024)         var        0.020078           0.026413        0.020295        0.020199           0.026381        0.020214
------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)         mean       0.023826           0.023726        0.022273        0.023236           0.023776        0.022248
(2048, 2048)         sum        0.023840           0.023355        0.021974        0.023294           0.023354        0.021884
(2048, 2048)         min        0.024519           0.041263        0.024620        0.023292           0.041491        0.024358
(2048, 2048)         max        0.024509           0.041670        0.024277        0.023334           0.041231        0.024395
(2048, 2048)         argmin     0.026125           0.041282        0.024567        0.026772           0.041773        0.024296
(2048, 2048)         argmax     0.026117           0.041487        0.024572        0.026412           0.041477        0.024273
(2048, 2048)         var        0.026603           0.048581        0.031308        0.027587           0.048603        0.030860
------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)         mean       0.053927           0.057070        0.054073        0.053028           0.057544        0.053935
(4096, 4096)         sum        0.053604           0.057410        0.054451        0.053076           0.057033        0.054266
(4096, 4096)         min        0.054293           0.109122        0.058363        0.053821           0.108689        0.058382
(4096, 4096)         max        0.054258           0.108035        0.058703        0.053492           0.110552        0.058376
(4096, 4096)         argmin     0.056805           0.111167        0.058301        0.056836           0.112325        0.058292
(4096, 4096)         argmax     0.056488           0.110958        0.058636        0.056844           0.111000        0.057928
(4096, 4096)         var        0.058936           0.141755        0.068693        0.059735           0.141284        0.068500
------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)         mean       0.145552           0.148082        0.138647        0.145364           0.147818        0.138207
(8192, 8192)         sum        0.145985           0.147900        0.138714        0.145755           0.148031        0.138616
(8192, 8192)         min        0.146566           0.205359        0.192739        0.145611           0.205237        0.182335
(8192, 8192)         max        0.146526           0.204844        0.193050        0.146073           0.205457        0.182697
(8192, 8192)         argmin     0.150190           0.206605        0.192543        0.150654           0.206847        0.182007
(8192, 8192)         argmax     0.150481           0.206368        0.192535        0.150845           0.206430        0.182022
(8192, 8192)         var        0.150884           0.184546        0.203900        0.151594           0.184172        0.197983
------------------------------------------------------------------------------------------------------------------------------------
(1, 1024, 128)       mean       0.014293           0.008119        0.014533        0.013861           0.008022        0.014449
(1, 1024, 128)       sum        0.014039           0.007877        0.014111        0.014219           0.008227        0.014045
(1, 1024, 128)       min        0.014159           0.011354        0.023493        0.014271           0.010862        0.023644
(1, 1024, 128)       max        0.014154           0.011027        0.023368        0.014259           0.011234        0.023692
(1, 1024, 128)       argmin     0.016403           0.005677        0.023328        0.016273           0.005683        0.024073
(1, 1024, 128)       argmax     0.016734           0.005675        0.023437        0.016580           0.005318        0.023331
(1, 1024, 128)       var        0.018338           0.009549        0.025538        0.018528           0.009391        0.024777
------------------------------------------------------------------------------------------------------------------------------------
(5, 1024, 128)       mean       0.014873           0.010131        0.015546        0.015123           0.010131        0.015481
(5, 1024, 128)       sum        0.015334           0.009673        0.015824        0.014736           0.009671        0.015438
(5, 1024, 128)       min        0.015047           0.013252        0.024573        0.014803           0.013163        0.024551
(5, 1024, 128)       max        0.015050           0.013339        0.024197        0.014810           0.013525        0.024230
(5, 1024, 128)       argmin     0.017341           0.012737        0.024306        0.017471           0.012379        0.024991
(5, 1024, 128)       argmax     0.017345           0.012411        0.024421        0.017422           0.012471        0.024237
(5, 1024, 128)       var        0.019973           0.011453        0.026188        0.020050           0.011438        0.026282
------------------------------------------------------------------------------------------------------------------------------------
(10, 1024, 128)      mean       0.016976           0.011575        0.016831        0.016722           0.011927        0.017173
(10, 1024, 128)      sum        0.017039           0.011841        0.017159        0.016385           0.011860        0.016753
(10, 1024, 128)      min        0.017036           0.015331        0.026770        0.016944           0.015205        0.027166
(10, 1024, 128)      max        0.017369           0.015348        0.027077        0.016531           0.015716        0.026819
(10, 1024, 128)      argmin     0.019203           0.014447        0.026813        0.018994           0.014497        0.027313
(10, 1024, 128)      argmax     0.019563           0.014795        0.027140        0.019460           0.014912        0.026733
(10, 1024, 128)      var        0.020529           0.014316        0.030405        0.020719           0.013960        0.029964
------------------------------------------------------------------------------------------------------------------------------------
(100, 1024, 128)     mean       0.045046           0.039168        0.046082        0.044839           0.039217        0.045782
(100, 1024, 128)     sum        0.045094           0.039150        0.045777        0.044496           0.039542        0.046083
(100, 1024, 128)     min        0.045768           0.054466        0.076244        0.044915           0.053943        0.076599
(100, 1024, 128)     max        0.045748           0.054459        0.076188        0.044931           0.053949        0.076856
(100, 1024, 128)     argmin     0.048275           0.054046        0.076647        0.048694           0.054105        0.077004
(100, 1024, 128)     argmax     0.048267           0.054395        0.077401        0.048691           0.054131        0.076751
(100, 1024, 128)     var        0.049710           0.043254        0.083077        0.050971           0.043251        0.082378
------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000, 100)    mean       0.202312           0.196723        0.197765        0.201774           0.196641        0.197459
(1000, 1000, 100)    sum        0.202651           0.196682        0.197736        0.202175           0.196313        0.197523
(1000, 1000, 100)    min        0.203022           0.264762        0.269200        0.202729           0.264129        0.268694
(1000, 1000, 100)    max        0.202864           0.264396        0.269388        0.202486           0.263896        0.268720
(1000, 1000, 100)    argmin     0.226727           0.263781        0.268651        0.226597           0.264676        0.268983
(1000, 1000, 100)    argmax     0.226412           0.264469        0.269090        0.226570           0.264595        0.269178
(1000, 1000, 100)    var        0.243223           0.204079        0.216096        0.241942           0.204079        0.215925
------------------------------------------------------------------------------------------------------------------------------------
(10000, 100)         mean       0.016193           0.020277        0.014316        0.016152           0.020324        0.013712
(10000, 100)         sum        0.016289           0.020237        0.014034        0.016168           0.020265        0.013708
(10000, 100)         min        0.016046           0.030872        0.019609        0.016208           0.030867        0.018627
(10000, 100)         max        0.016369           0.030835        0.019257        0.016218           0.030861        0.018209
(10000, 100)         argmin     0.017957           0.031171        0.019517        0.018050           0.031556        0.018077
(10000, 100)         argmax     0.017961           0.031658        0.019521        0.018060           0.031564        0.018087
(10000, 100)         var        0.020393           0.035652        0.019339        0.020144           0.035987        0.019171
------------------------------------------------------------------------------------------------------------------------------------
(100000, 10)         mean       0.015718           0.016576        0.016555        0.015999           0.016246        0.014869
(100000, 10)         sum        0.015833           0.016247        0.016572        0.016007           0.016627        0.014872
(100000, 10)         min        0.015888           0.020510        0.023920        0.015671           0.020821        0.021417
(100000, 10)         max        0.015889           0.020479        0.023918        0.016077           0.020386        0.021421
(100000, 10)         argmin     0.018233           0.020863        0.023647        0.017574           0.020864        0.021103
(100000, 10)         argmax     0.017896           0.020527        0.023296        0.017569           0.020447        0.021098
(100000, 10)         var        0.020005           0.024198        0.024372        0.020075           0.024167        0.022415
------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 1023)   mean       1.874816           1.963506        1.903909        1.873279           1.963859        1.903230
(1023, 1023, 1023)   sum        1.875030           1.965716        1.902458        1.873566           1.960730        1.901642
(1023, 1023, 1023)   min        1.878563           2.473455        2.179092        1.875174           2.482086        2.183027
(1023, 1023, 1023)   max        1.879128           2.474803        2.178895        1.874831           2.482253        2.183884
(1023, 1023, 1023)   argmin     1.921800           2.476629        2.174831        1.923987           2.472641        2.170453
(1023, 1023, 1023)   argmax     1.922605           2.476688        2.177927        1.923366           2.472808        2.172979
(1023, 1023, 1023)   var        1.972606           3.088695        2.758797        1.978679           3.095658        2.762243
------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 255)    mean       0.489984           0.500954        0.492957        0.489891           0.500654        0.491971
(1023, 1023, 255)    sum        0.490228           0.500764        0.492289        0.489624           0.501089        0.492824
(1023, 1023, 255)    min        0.491457           0.563560        0.553334        0.490355           0.564709        0.554754
(1023, 1023, 255)    max        0.491396           0.563628        0.553345        0.490017           0.565004        0.554947
(1023, 1023, 255)    argmin     0.503666           0.561512        0.551831        0.503845           0.560972        0.551017
(1023, 1023, 255)    argmax     0.503602           0.561185        0.551407        0.504328           0.561267        0.551448
(1023, 1023, 255)    var        0.510844           0.709452        0.701630        0.512693           0.710365        0.701965
------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 377)    mean       0.707439           0.727646        0.712019        0.706769           0.727101        0.711632
(1023, 1023, 377)    sum        0.707780           0.727453        0.711554        0.706807           0.726656        0.711729
(1023, 1023, 377)    min        0.709423           0.819809        0.794379        0.707847           0.822086        0.796664
(1023, 1023, 377)    max        0.709297           0.819780        0.794308        0.707566           0.821913        0.796690
(1023, 1023, 377)    argmin     0.725028           0.817088        0.791695        0.726039           0.816445        0.790828
(1023, 1023, 377)    argmax     0.725301           0.817011        0.791420        0.726040           0.816917        0.791143
(1023, 1023, 377)    var        0.740859           1.034165        1.006712        0.743413           1.035506        1.007638
```

[ghstack-poisoned]
@PaulZhang12 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
This reverts commit 36371b8. Reverted pytorch#164790 on behalf of https://github.com/clee2000 due to failing internal tests after merge, D84992607 ([comment](pytorch#164790 (comment)))
@PaulZhang12 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
@pytorchbot merge |
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Stack from ghstack (oldest at bottom):
Typical warp shuffle reduction has the following pattern:

which is exhibited in the Triton code generated by torch.compile:

Switch the warp shuffle order to make bitwise equivalence between the two easier.
Comparing the PTX between old and new, we see a few extra instructions: https://www.diffchecker.com/h6ly3INC/
Comparing the performance of different reduction operations, we see minimal differences (full benchmark table in the earlier comment). "New" represents the changes in this PR; "old" represents the previous warp shuffle order.
Differential Revision: D85022826