@PaulZhang12 PaulZhang12 commented Oct 6, 2025

Stack from ghstack (oldest at bottom):

Typical warp shuffle reduction has the following pattern:
image

which is exhibited in Triton generated by torch.compile:
image

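This PR switches the eager warp shuffle order to match, which makes bitwise equivalence between the two (eager CUDA reductions and Triton generated by torch.compile) easier to achieve.
Diffing the PTX of the old and new code shows a few extra instructions: https://www.diffchecker.com/h6ly3INC/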
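
Why the shuffle order matters for bitwise equivalence can be sketched on the host. The block below is a minimal simulation, not the ATen kernel: `kWarpSize`, `shfl_down`, `combine_step`, `reduce_descending` and `reduce_ascending` are hypothetical names, a 32-lane warp is assumed, and `shfl_down` mimics the documented behaviour of `__shfl_down_sync` (a lane whose source is out of range keeps its own value). It also assumes IEEE-754 single precision without fast-math.

```cpp
#include <array>
#include <cassert>

constexpr int kWarpSize = 32;  // assumption: 32 lanes per warp
using Warp = std::array<float, kWarpSize>;

// Host stand-in for a shuffle-down: lane i reads the value held by lane i + offset.
// Lanes whose source lane would be out of range keep their own value.
Warp shfl_down(const Warp& v, int offset) {
  assert(offset > 0 && offset < kWarpSize);
  Warp out = v;
  for (int lane = 0; lane + offset < kWarpSize; ++lane) out[lane] = v[lane + offset];
  return out;
}

// One reduction step: every lane adds the value it received from the shuffle.
void combine_step(Warp& v, int offset) {
  const Warp other = shfl_down(v, offset);
  for (int lane = 0; lane < kWarpSize; ++lane) v[lane] += other[lane];
}

// Offsets 16, 8, 4, 2, 1: the order this PR moves to.
float reduce_descending(Warp v) {
  for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) combine_step(v, offset);
  return v[0];  // lane 0 holds the result
}

// Offsets 1, 2, 4, 8, 16: the previous order.
float reduce_ascending(Warp v) {
  for (int offset = 1; offset < kWarpSize; offset <<= 1) combine_step(v, offset);
  return v[0];
}
```

Both functions add the same 32 values into lane 0, but they associate the additions differently. With inputs such as lane 0 = 1e8f, lane 1 = 1.0f, lane 16 = -1e8f and zeros elsewhere, the descending order cancels the large values first and returns 1.0f, while the ascending order first rounds 1e8f + 1.0f to 1e8f and returns 0.0f. Two implementations can therefore only be expected to agree bit-for-bit when they use the same order.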

Benchmarks across different reduction operations show minimal differences. "New" is the warp shuffle order in this PR; "Old" is the previous order:

Tensor Shape              Operation            New all dims (ms)    New dim=0 (ms)       New dim=1 (ms)       Old all dims (ms)    Old dim=0 (ms)       Old dim=1 (ms)
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)              mean                 0.015817             0.016259             0.013642             0.015990             0.016258             0.013631             
(1024, 1024)              sum                  0.015917             0.015906             0.013359             0.015707             0.016266             0.013226             
(1024, 1024)              min                  0.016021             0.024625             0.015631             0.015761             0.024485             0.015317             
(1024, 1024)              max                  0.016349             0.024971             0.015972             0.015771             0.025001             0.015314             
(1024, 1024)              argmin               0.018070             0.024448             0.015578             0.018135             0.025370             0.015322             
(1024, 1024)              argmax               0.018427             0.024859             0.015932             0.018164             0.024452             0.015639             
(1024, 1024)              var                  0.020078             0.026413             0.020295             0.020199             0.026381             0.020214             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)              mean                 0.023826             0.023726             0.022273             0.023236             0.023776             0.022248             
(2048, 2048)              sum                  0.023840             0.023355             0.021974             0.023294             0.023354             0.021884             
(2048, 2048)              min                  0.024519             0.041263             0.024620             0.023292             0.041491             0.024358             
(2048, 2048)              max                  0.024509             0.041670             0.024277             0.023334             0.041231             0.024395             
(2048, 2048)              argmin               0.026125             0.041282             0.024567             0.026772             0.041773             0.024296             
(2048, 2048)              argmax               0.026117             0.041487             0.024572             0.026412             0.041477             0.024273             
(2048, 2048)              var                  0.026603             0.048581             0.031308             0.027587             0.048603             0.030860             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)              mean                 0.053927             0.057070             0.054073             0.053028             0.057544             0.053935             
(4096, 4096)              sum                  0.053604             0.057410             0.054451             0.053076             0.057033             0.054266             
(4096, 4096)              min                  0.054293             0.109122             0.058363             0.053821             0.108689             0.058382             
(4096, 4096)              max                  0.054258             0.108035             0.058703             0.053492             0.110552             0.058376             
(4096, 4096)              argmin               0.056805             0.111167             0.058301             0.056836             0.112325             0.058292             
(4096, 4096)              argmax               0.056488             0.110958             0.058636             0.056844             0.111000             0.057928             
(4096, 4096)              var                  0.058936             0.141755             0.068693             0.059735             0.141284             0.068500             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)              mean                 0.145552             0.148082             0.138647             0.145364             0.147818             0.138207             
(8192, 8192)              sum                  0.145985             0.147900             0.138714             0.145755             0.148031             0.138616             
(8192, 8192)              min                  0.146566             0.205359             0.192739             0.145611             0.205237             0.182335             
(8192, 8192)              max                  0.146526             0.204844             0.193050             0.146073             0.205457             0.182697             
(8192, 8192)              argmin               0.150190             0.206605             0.192543             0.150654             0.206847             0.182007             
(8192, 8192)              argmax               0.150481             0.206368             0.192535             0.150845             0.206430             0.182022             
(8192, 8192)              var                  0.150884             0.184546             0.203900             0.151594             0.184172             0.197983             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1, 1024, 128)            mean                 0.014293             0.008119             0.014533             0.013861             0.008022             0.014449             
(1, 1024, 128)            sum                  0.014039             0.007877             0.014111             0.014219             0.008227             0.014045             
(1, 1024, 128)            min                  0.014159             0.011354             0.023493             0.014271             0.010862             0.023644             
(1, 1024, 128)            max                  0.014154             0.011027             0.023368             0.014259             0.011234             0.023692             
(1, 1024, 128)            argmin               0.016403             0.005677             0.023328             0.016273             0.005683             0.024073             
(1, 1024, 128)            argmax               0.016734             0.005675             0.023437             0.016580             0.005318             0.023331             
(1, 1024, 128)            var                  0.018338             0.009549             0.025538             0.018528             0.009391             0.024777             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(5, 1024, 128)            mean                 0.014873             0.010131             0.015546             0.015123             0.010131             0.015481             
(5, 1024, 128)            sum                  0.015334             0.009673             0.015824             0.014736             0.009671             0.015438             
(5, 1024, 128)            min                  0.015047             0.013252             0.024573             0.014803             0.013163             0.024551             
(5, 1024, 128)            max                  0.015050             0.013339             0.024197             0.014810             0.013525             0.024230             
(5, 1024, 128)            argmin               0.017341             0.012737             0.024306             0.017471             0.012379             0.024991             
(5, 1024, 128)            argmax               0.017345             0.012411             0.024421             0.017422             0.012471             0.024237             
(5, 1024, 128)            var                  0.019973             0.011453             0.026188             0.020050             0.011438             0.026282             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(10, 1024, 128)           mean                 0.016976             0.011575             0.016831             0.016722             0.011927             0.017173             
(10, 1024, 128)           sum                  0.017039             0.011841             0.017159             0.016385             0.011860             0.016753             
(10, 1024, 128)           min                  0.017036             0.015331             0.026770             0.016944             0.015205             0.027166             
(10, 1024, 128)           max                  0.017369             0.015348             0.027077             0.016531             0.015716             0.026819             
(10, 1024, 128)           argmin               0.019203             0.014447             0.026813             0.018994             0.014497             0.027313             
(10, 1024, 128)           argmax               0.019563             0.014795             0.027140             0.019460             0.014912             0.026733             
(10, 1024, 128)           var                  0.020529             0.014316             0.030405             0.020719             0.013960             0.029964             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(100, 1024, 128)          mean                 0.045046             0.039168             0.046082             0.044839             0.039217             0.045782             
(100, 1024, 128)          sum                  0.045094             0.039150             0.045777             0.044496             0.039542             0.046083             
(100, 1024, 128)          min                  0.045768             0.054466             0.076244             0.044915             0.053943             0.076599             
(100, 1024, 128)          max                  0.045748             0.054459             0.076188             0.044931             0.053949             0.076856             
(100, 1024, 128)          argmin               0.048275             0.054046             0.076647             0.048694             0.054105             0.077004             
(100, 1024, 128)          argmax               0.048267             0.054395             0.077401             0.048691             0.054131             0.076751             
(100, 1024, 128)          var                  0.049710             0.043254             0.083077             0.050971             0.043251             0.082378             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000, 100)         mean                 0.202312             0.196723             0.197765             0.201774             0.196641             0.197459             
(1000, 1000, 100)         sum                  0.202651             0.196682             0.197736             0.202175             0.196313             0.197523             
(1000, 1000, 100)         min                  0.203022             0.264762             0.269200             0.202729             0.264129             0.268694             
(1000, 1000, 100)         max                  0.202864             0.264396             0.269388             0.202486             0.263896             0.268720             
(1000, 1000, 100)         argmin               0.226727             0.263781             0.268651             0.226597             0.264676             0.268983             
(1000, 1000, 100)         argmax               0.226412             0.264469             0.269090             0.226570             0.264595             0.269178             
(1000, 1000, 100)         var                  0.243223             0.204079             0.216096             0.241942             0.204079             0.215925             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(10000, 100)              mean                 0.016193             0.020277             0.014316             0.016152             0.020324             0.013712             
(10000, 100)              sum                  0.016289             0.020237             0.014034             0.016168             0.020265             0.013708             
(10000, 100)              min                  0.016046             0.030872             0.019609             0.016208             0.030867             0.018627             
(10000, 100)              max                  0.016369             0.030835             0.019257             0.016218             0.030861             0.018209             
(10000, 100)              argmin               0.017957             0.031171             0.019517             0.018050             0.031556             0.018077             
(10000, 100)              argmax               0.017961             0.031658             0.019521             0.018060             0.031564             0.018087             
(10000, 100)              var                  0.020393             0.035652             0.019339             0.020144             0.035987             0.019171             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(100000, 10)              mean                 0.015718             0.016576             0.016555             0.015999             0.016246             0.014869             
(100000, 10)              sum                  0.015833             0.016247             0.016572             0.016007             0.016627             0.014872             
(100000, 10)              min                  0.015888             0.020510             0.023920             0.015671             0.020821             0.021417             
(100000, 10)              max                  0.015889             0.020479             0.023918             0.016077             0.020386             0.021421             
(100000, 10)              argmin               0.018233             0.020863             0.023647             0.017574             0.020864             0.021103             
(100000, 10)              argmax               0.017896             0.020527             0.023296             0.017569             0.020447             0.021098             
(100000, 10)              var                  0.020005             0.024198             0.024372             0.020075             0.024167             0.022415             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 1023)        mean                 1.874816             1.963506             1.903909             1.873279             1.963859             1.903230             
(1023, 1023, 1023)        sum                  1.875030             1.965716             1.902458             1.873566             1.960730             1.901642             
(1023, 1023, 1023)        min                  1.878563             2.473455             2.179092             1.875174             2.482086             2.183027             
(1023, 1023, 1023)        max                  1.879128             2.474803             2.178895             1.874831             2.482253             2.183884             
(1023, 1023, 1023)        argmin               1.921800             2.476629             2.174831             1.923987             2.472641             2.170453             
(1023, 1023, 1023)        argmax               1.922605             2.476688             2.177927             1.923366             2.472808             2.172979             
(1023, 1023, 1023)        var                  1.972606             3.088695             2.758797             1.978679             3.095658             2.762243             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 255)         mean                 0.489984             0.500954             0.492957             0.489891             0.500654             0.491971             
(1023, 1023, 255)         sum                  0.490228             0.500764             0.492289             0.489624             0.501089             0.492824             
(1023, 1023, 255)         min                  0.491457             0.563560             0.553334             0.490355             0.564709             0.554754             
(1023, 1023, 255)         max                  0.491396             0.563628             0.553345             0.490017             0.565004             0.554947             
(1023, 1023, 255)         argmin               0.503666             0.561512             0.551831             0.503845             0.560972             0.551017             
(1023, 1023, 255)         argmax               0.503602             0.561185             0.551407             0.504328             0.561267             0.551448             
(1023, 1023, 255)         var                  0.510844             0.709452             0.701630             0.512693             0.710365             0.701965             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 377)         mean                 0.707439             0.727646             0.712019             0.706769             0.727101             0.711632             
(1023, 1023, 377)         sum                  0.707780             0.727453             0.711554             0.706807             0.726656             0.711729             
(1023, 1023, 377)         min                  0.709423             0.819809             0.794379             0.707847             0.822086             0.796664             
(1023, 1023, 377)         max                  0.709297             0.819780             0.794308             0.707566             0.821913             0.796690             
(1023, 1023, 377)         argmin               0.725028             0.817088             0.791695             0.726039             0.816445             0.790828             
(1023, 1023, 377)         argmax               0.725301             0.817011             0.791420             0.726040             0.816917             0.791143             
(1023, 1023, 377)         var                  0.740859             1.034165             1.006712             0.743413             1.035506             1.007638             
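
To put a number on "minimal", the relative change of each New timing against its Old counterpart can be computed. This is a small sketch with a hypothetical helper, `percent_change`; the timings are single figures copied from the table, so run-to-run noise is not accounted for.

```cpp
#include <cassert>
#include <cmath>

// Relative change of the new timing versus the old one, in percent.
// Positive means the new shuffle order was slower in that measurement.
double percent_change(double new_ms, double old_ms) {
  assert(old_ms > 0.0);
  return 100.0 * (new_ms - old_ms) / old_ms;
}
```

For example, `percent_change(0.015817, 0.015990)` for (1024, 1024) mean over all dims is about -1.1%, while `percent_change(0.192739, 0.182335)` for (8192, 8192) min with dim=1 is about +5.7%.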

Differential Revision: D85022826

pytorch-bot bot commented Oct 6, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164790

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 1 Pending

As of commit bece7bb with merge base 602ace5:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

PaulZhang12 added a commit that referenced this pull request Oct 6, 2025
__syncthreads();

- for (int offset = 1; offset < dim_x; offset <<= 1) {
+ for (int offset = dim_x / 2; offset > 0; offset >>= 1) {

@eqy eqy Oct 6, 2025

old: dim_x = 3 -> offset = 1, 2 (2 << 1 > dim_x)
new: dim_x = 3 -> offset = 3 / 2 = 1 (1 >> 1 == 0)

Naively, the values that offset takes seem to change between old and new.
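
The difference raised here can be checked by enumerating the offsets each loop header visits. This is a host-side sketch that copies only the two loop headers from the diff; `old_offsets` and `new_offsets` are hypothetical helper names, and the sketch says nothing about which values of `dim_x` the kernel actually launches with.

```cpp
#include <cassert>
#include <vector>

// Offsets visited by the old loop: for (offset = 1; offset < dim_x; offset <<= 1)
std::vector<int> old_offsets(int dim_x) {
  assert(dim_x >= 1);
  std::vector<int> out;
  for (int offset = 1; offset < dim_x; offset <<= 1) out.push_back(offset);
  return out;
}

// Offsets visited by the new loop: for (offset = dim_x / 2; offset > 0; offset >>= 1)
std::vector<int> new_offsets(int dim_x) {
  assert(dim_x >= 1);
  std::vector<int> out;
  for (int offset = dim_x / 2; offset > 0; offset >>= 1) out.push_back(offset);
  return out;
}
```

For a power-of-two `dim_x` such as 32, the two loops visit the same set in opposite order (1, 2, 4, 8, 16 versus 16, 8, 4, 2, 1). For `dim_x = 3`, the old loop visits 1, 2 and the new loop visits only 1; for `dim_x = 6`, the old loop visits 1, 2, 4 and the new loop visits 3, 1. The sets therefore match only when `dim_x` is a power of two.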

PaulZhang12 added a commit that referenced this pull request Oct 7, 2025
@PaulZhang12 PaulZhang12 changed the title from "Fix eager reduction warp shuffle order to start from offset=16" to "[ATen] Fix CUDA reduction warp shuffle order with traditional offsets instead of beginning at offset=1" Oct 8, 2025
@PaulZhang12 PaulZhang12 changed the title from "[ATen] Fix CUDA reduction warp shuffle order with traditional offsets instead of beginning at offset=1" to "[ATen] Fix CUDA reduction warp shuffle order" Oct 8, 2025
@PaulZhang12 PaulZhang12 requested review from eqy and ngimel October 8, 2025 15:58
PaulZhang12 added a commit that referenced this pull request Oct 8, 2025
arg_t other = ops.warp_shfl_down(value[i], offset);
value[i] = ops.combine(value[i], other);
// Only combine if the source thread (threadIdx.x + offset) is within bounds
if (threadIdx.x + offset < dim_x) {

This would likely break loop unrolling; you'd need to check the generated PTX.

Typical warp shuffle reduction has the following pattern:
<img width="1138" height="501" alt="image" src="https://github.com/user-attachments/assets/3bd176dc-0ad2-4df6-90c7-06e467337166" />

which is exhibited in Triton generated by torch.compile:
<img width="663" height="403" alt="image" src="https://github.com/user-attachments/assets/7f9f36cd-b9eb-44c1-879e-b469668a2ea8" />

This PR switches the eager warp shuffle order to match, which makes bitwise equivalence between the two (eager CUDA reductions and Triton generated by torch.compile) easier to achieve.
Diffing the PTX of the old and new code shows a few extra instructions: https://www.diffchecker.com/h6ly3INC/

Benchmarks across different reduction operations show minimal differences. "New" is the warp shuffle order in this PR; "Old" is the previous order:
```
Tensor Shape              Operation            New all dims (ms)    New dim=0 (ms)       New dim=1 (ms)       Old all dims (ms)    Old dim=0 (ms)       Old dim=1 (ms)
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)              mean                 0.015817             0.016259             0.013642             0.015990             0.016258             0.013631             
(1024, 1024)              sum                  0.015917             0.015906             0.013359             0.015707             0.016266             0.013226             
(1024, 1024)              min                  0.016021             0.024625             0.015631             0.015761             0.024485             0.015317             
(1024, 1024)              max                  0.016349             0.024971             0.015972             0.015771             0.025001             0.015314             
(1024, 1024)              argmin               0.018070             0.024448             0.015578             0.018135             0.025370             0.015322             
(1024, 1024)              argmax               0.018427             0.024859             0.015932             0.018164             0.024452             0.015639             
(1024, 1024)              var                  0.020078             0.026413             0.020295             0.020199             0.026381             0.020214             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)              mean                 0.023826             0.023726             0.022273             0.023236             0.023776             0.022248             
(2048, 2048)              sum                  0.023840             0.023355             0.021974             0.023294             0.023354             0.021884             
(2048, 2048)              min                  0.024519             0.041263             0.024620             0.023292             0.041491             0.024358             
(2048, 2048)              max                  0.024509             0.041670             0.024277             0.023334             0.041231             0.024395             
(2048, 2048)              argmin               0.026125             0.041282             0.024567             0.026772             0.041773             0.024296             
(2048, 2048)              argmax               0.026117             0.041487             0.024572             0.026412             0.041477             0.024273             
(2048, 2048)              var                  0.026603             0.048581             0.031308             0.027587             0.048603             0.030860             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)              mean                 0.053927             0.057070             0.054073             0.053028             0.057544             0.053935             
(4096, 4096)              sum                  0.053604             0.057410             0.054451             0.053076             0.057033             0.054266             
(4096, 4096)              min                  0.054293             0.109122             0.058363             0.053821             0.108689             0.058382             
(4096, 4096)              max                  0.054258             0.108035             0.058703             0.053492             0.110552             0.058376             
(4096, 4096)              argmin               0.056805             0.111167             0.058301             0.056836             0.112325             0.058292             
(4096, 4096)              argmax               0.056488             0.110958             0.058636             0.056844             0.111000             0.057928             
(4096, 4096)              var                  0.058936             0.141755             0.068693             0.059735             0.141284             0.068500             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)              mean                 0.145552             0.148082             0.138647             0.145364             0.147818             0.138207             
(8192, 8192)              sum                  0.145985             0.147900             0.138714             0.145755             0.148031             0.138616             
(8192, 8192)              min                  0.146566             0.205359             0.192739             0.145611             0.205237             0.182335             
(8192, 8192)              max                  0.146526             0.204844             0.193050             0.146073             0.205457             0.182697             
(8192, 8192)              argmin               0.150190             0.206605             0.192543             0.150654             0.206847             0.182007             
(8192, 8192)              argmax               0.150481             0.206368             0.192535             0.150845             0.206430             0.182022             
(8192, 8192)              var                  0.150884             0.184546             0.203900             0.151594             0.184172             0.197983             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1, 1024, 128)            mean                 0.014293             0.008119             0.014533             0.013861             0.008022             0.014449             
(1, 1024, 128)            sum                  0.014039             0.007877             0.014111             0.014219             0.008227             0.014045             
(1, 1024, 128)            min                  0.014159             0.011354             0.023493             0.014271             0.010862             0.023644             
(1, 1024, 128)            max                  0.014154             0.011027             0.023368             0.014259             0.011234             0.023692             
(1, 1024, 128)            argmin               0.016403             0.005677             0.023328             0.016273             0.005683             0.024073             
(1, 1024, 128)            argmax               0.016734             0.005675             0.023437             0.016580             0.005318             0.023331             
(1, 1024, 128)            var                  0.018338             0.009549             0.025538             0.018528             0.009391             0.024777             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(5, 1024, 128)            mean                 0.014873             0.010131             0.015546             0.015123             0.010131             0.015481             
(5, 1024, 128)            sum                  0.015334             0.009673             0.015824             0.014736             0.009671             0.015438             
(5, 1024, 128)            min                  0.015047             0.013252             0.024573             0.014803             0.013163             0.024551             
(5, 1024, 128)            max                  0.015050             0.013339             0.024197             0.014810             0.013525             0.024230             
(5, 1024, 128)            argmin               0.017341             0.012737             0.024306             0.017471             0.012379             0.024991             
(5, 1024, 128)            argmax               0.017345             0.012411             0.024421             0.017422             0.012471             0.024237             
(5, 1024, 128)            var                  0.019973             0.011453             0.026188             0.020050             0.011438             0.026282             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(10, 1024, 128)           mean                 0.016976             0.011575             0.016831             0.016722             0.011927             0.017173             
(10, 1024, 128)           sum                  0.017039             0.011841             0.017159             0.016385             0.011860             0.016753             
(10, 1024, 128)           min                  0.017036             0.015331             0.026770             0.016944             0.015205             0.027166             
(10, 1024, 128)           max                  0.017369             0.015348             0.027077             0.016531             0.015716             0.026819             
(10, 1024, 128)           argmin               0.019203             0.014447             0.026813             0.018994             0.014497             0.027313             
(10, 1024, 128)           argmax               0.019563             0.014795             0.027140             0.019460             0.014912             0.026733             
(10, 1024, 128)           var                  0.020529             0.014316             0.030405             0.020719             0.013960             0.029964             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(100, 1024, 128)          mean                 0.045046             0.039168             0.046082             0.044839             0.039217             0.045782             
(100, 1024, 128)          sum                  0.045094             0.039150             0.045777             0.044496             0.039542             0.046083             
(100, 1024, 128)          min                  0.045768             0.054466             0.076244             0.044915             0.053943             0.076599             
(100, 1024, 128)          max                  0.045748             0.054459             0.076188             0.044931             0.053949             0.076856             
(100, 1024, 128)          argmin               0.048275             0.054046             0.076647             0.048694             0.054105             0.077004             
(100, 1024, 128)          argmax               0.048267             0.054395             0.077401             0.048691             0.054131             0.076751             
(100, 1024, 128)          var                  0.049710             0.043254             0.083077             0.050971             0.043251             0.082378             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000, 100)         mean                 0.202312             0.196723             0.197765             0.201774             0.196641             0.197459             
(1000, 1000, 100)         sum                  0.202651             0.196682             0.197736             0.202175             0.196313             0.197523             
(1000, 1000, 100)         min                  0.203022             0.264762             0.269200             0.202729             0.264129             0.268694             
(1000, 1000, 100)         max                  0.202864             0.264396             0.269388             0.202486             0.263896             0.268720             
(1000, 1000, 100)         argmin               0.226727             0.263781             0.268651             0.226597             0.264676             0.268983             
(1000, 1000, 100)         argmax               0.226412             0.264469             0.269090             0.226570             0.264595             0.269178             
(1000, 1000, 100)         var                  0.243223             0.204079             0.216096             0.241942             0.204079             0.215925             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(10000, 100)              mean                 0.016193             0.020277             0.014316             0.016152             0.020324             0.013712             
(10000, 100)              sum                  0.016289             0.020237             0.014034             0.016168             0.020265             0.013708             
(10000, 100)              min                  0.016046             0.030872             0.019609             0.016208             0.030867             0.018627             
(10000, 100)              max                  0.016369             0.030835             0.019257             0.016218             0.030861             0.018209             
(10000, 100)              argmin               0.017957             0.031171             0.019517             0.018050             0.031556             0.018077             
(10000, 100)              argmax               0.017961             0.031658             0.019521             0.018060             0.031564             0.018087             
(10000, 100)              var                  0.020393             0.035652             0.019339             0.020144             0.035987             0.019171             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(100000, 10)              mean                 0.015718             0.016576             0.016555             0.015999             0.016246             0.014869             
(100000, 10)              sum                  0.015833             0.016247             0.016572             0.016007             0.016627             0.014872             
(100000, 10)              min                  0.015888             0.020510             0.023920             0.015671             0.020821             0.021417             
(100000, 10)              max                  0.015889             0.020479             0.023918             0.016077             0.020386             0.021421             
(100000, 10)              argmin               0.018233             0.020863             0.023647             0.017574             0.020864             0.021103             
(100000, 10)              argmax               0.017896             0.020527             0.023296             0.017569             0.020447             0.021098             
(100000, 10)              var                  0.020005             0.024198             0.024372             0.020075             0.024167             0.022415             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 1023)        mean                 1.874816             1.963506             1.903909             1.873279             1.963859             1.903230             
(1023, 1023, 1023)        sum                  1.875030             1.965716             1.902458             1.873566             1.960730             1.901642             
(1023, 1023, 1023)        min                  1.878563             2.473455             2.179092             1.875174             2.482086             2.183027             
(1023, 1023, 1023)        max                  1.879128             2.474803             2.178895             1.874831             2.482253             2.183884             
(1023, 1023, 1023)        argmin               1.921800             2.476629             2.174831             1.923987             2.472641             2.170453             
(1023, 1023, 1023)        argmax               1.922605             2.476688             2.177927             1.923366             2.472808             2.172979             
(1023, 1023, 1023)        var                  1.972606             3.088695             2.758797             1.978679             3.095658             2.762243             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 255)         mean                 0.489984             0.500954             0.492957             0.489891             0.500654             0.491971             
(1023, 1023, 255)         sum                  0.490228             0.500764             0.492289             0.489624             0.501089             0.492824             
(1023, 1023, 255)         min                  0.491457             0.563560             0.553334             0.490355             0.564709             0.554754             
(1023, 1023, 255)         max                  0.491396             0.563628             0.553345             0.490017             0.565004             0.554947             
(1023, 1023, 255)         argmin               0.503666             0.561512             0.551831             0.503845             0.560972             0.551017             
(1023, 1023, 255)         argmax               0.503602             0.561185             0.551407             0.504328             0.561267             0.551448             
(1023, 1023, 255)         var                  0.510844             0.709452             0.701630             0.512693             0.710365             0.701965             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 377)         mean                 0.707439             0.727646             0.712019             0.706769             0.727101             0.711632             
(1023, 1023, 377)         sum                  0.707780             0.727453             0.711554             0.706807             0.726656             0.711729             
(1023, 1023, 377)         min                  0.709423             0.819809             0.794379             0.707847             0.822086             0.796664             
(1023, 1023, 377)         max                  0.709297             0.819780             0.794308             0.707566             0.821913             0.796690             
(1023, 1023, 377)         argmin               0.725028             0.817088             0.791695             0.726039             0.816445             0.790828             
(1023, 1023, 377)         argmax               0.725301             0.817011             0.791420             0.726040             0.816917             0.791143             
(1023, 1023, 377)         var                  0.740859             1.034165             1.006712             0.743413             1.035506             1.007638             
```

[ghstack-poisoned]
PaulZhang12 added a commit that referenced this pull request Oct 8, 2025
@PaulZhang12 PaulZhang12 requested a review from ngimel October 8, 2025 18:59
@PaulZhang12
Contributor Author

@eqy made relevant changes to ensure consistency + added some benchmarking/ptx comparisons if you can review again!

// Warp-level reduction for remaining threads
// For non-power-of-2 sizes, we start from the next power-of-2 divided by 2
// and use a boundary check to avoid out-of-bounds access
for (size_t offset = warpSize / 2; offset > 0; offset >>= 1) {
Collaborator

This now looks like it potentially has a wasted extra iteration (though if blockIdx.x isn't expected to be < 32, it makes no difference).
dim_x = 15, before: 1 2 4 8
now: 16 8 4 2 1
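
The offset sequences quoted above can be reproduced with a small host-side sketch. The function names are hypothetical and the loops are simplified stand-ins for the kernel loops under review, assuming a warp size of 32; this is not the PR's code.

```cpp
#include <cassert>
#include <vector>

// Offsets visited by an ascending loop: 1, 2, 4, ... while offset < dim_x.
std::vector<int> ascending_offsets(int dim_x) {
    assert(dim_x > 0);
    std::vector<int> out;
    for (int offset = 1; offset < dim_x; offset <<= 1) {
        out.push_back(offset);
    }
    return out;
}

// Offsets visited by a descending loop that always starts at warp_size / 2,
// regardless of how many lanes are actually in use.
std::vector<int> descending_offsets(int warp_size) {
    assert(warp_size > 0);
    std::vector<int> out;
    for (int offset = warp_size / 2; offset > 0; offset >>= 1) {
        out.push_back(offset);
    }
    return out;
}
```

With `dim_x = 15`, `ascending_offsets(15)` visits four offsets, while `descending_offsets(32)` visits five, the first of which (16) cannot pair any two of the 15 lanes.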

Contributor Author

Yeah, a previous commit calculated the next power of two each time, but this is cleaner in most situations. Happy to change it in a follow-up if you prefer.

Contributor

Why not do this instead (and remove the if below)?
offset = min(warpSize / 2, dim_x-1);

Contributor Author

Wouldn't this make us go out of bounds and get bad values?
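
One way to look at this concern is a host-side model of a shuffle-down reduction. The sketch below rests on two assumptions that the thread does not confirm: the offset is still halved every iteration, and each read is guarded by a bounds check. `reduce_ones` is a hypothetical helper, not code from the PR.

```cpp
#include <cassert>
#include <vector>

// Host-side model of a guarded shuffle-down reduction over dim_x lanes.
// Every lane starts at 1, so a correct reduction leaves dim_x in lane 0.
int reduce_ones(int dim_x, int start_offset) {
    assert(dim_x > 0 && start_offset >= 0);
    std::vector<int> lanes(dim_x, 1);
    for (int offset = start_offset; offset > 0; offset >>= 1) {
        // Copy first so that all lanes read old values before any lane writes.
        std::vector<int> next = lanes;
        for (int i = 0; i + offset < dim_x; ++i) {
            next[i] = lanes[i] + lanes[i + offset];
        }
        lanes = next;
    }
    return lanes[0];
}
```

In this model, `reduce_ones(15, 16)` (power-of-two offsets 16, 8, 4, 2, 1) returns 15, while `reduce_ones(15, 14)` (offsets 14, 7, 3, 1, as `min(warpSize / 2, dim_x - 1)` would give for `dim_x = 15`) returns 9: even with the guard in place, some lanes never reach lane 0. Without the guard, the reads would additionally touch lanes past `dim_x`.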

@PaulZhang12
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 9, 2025
for (int offset = 1; offset < dim_x; offset <<= 1) {
#else
int offset = warpSize / 2;
while (offset >= dim_x)
Collaborator

dim_x is guaranteed to be

  1. a power of 2
  2. <= warpSize (due to the branch above)

so
int offset = dim_x >> 1

should do the job?
You can add a check at launch time that dim_x is a power of 2, but this loop would already break if it weren't.
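
The equivalence claimed above can be checked on the host with a short sketch. The helper names are hypothetical, and the halving loop body is an assumption, since the snippet under review shows only the loop header.

```cpp
#include <cassert>

// Loop form: start at warp_size / 2 and halve until offset < dim_x.
// The halving body is assumed; only the loop header appears in the thread.
int start_offset_loop(int dim_x, int warp_size) {
    assert(dim_x > 0 && dim_x <= warp_size);
    int offset = warp_size / 2;
    while (offset >= dim_x) {
        offset >>= 1;
    }
    return offset;
}

// Closed form suggested in the review comment.
int start_offset_shift(int dim_x) {
    return dim_x >> 1;
}
```

For every power-of-two `dim_x` up to a warp size of 32 the two forms agree. They differ once `dim_x` is not a power of two (for `dim_x = 15` the loop gives 8 and the shift gives 7), which is why the power-of-two guarantee matters.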

Contributor Author

Ah yes, that's right, since block_width = std::min(dim0_pow2, int(at::cuda::warp_size())). Thanks!

Typical warp shuffle reduction has the following pattern:
<img width="1138" height="501" alt="image" src="https://github.com/user-attachments/assets/3bd176dc-0ad2-4df6-90c7-06e467337166" />

which is exhibited in Triton generated by torch.compile:
<img width="663" height="403" alt="image" src="https://github.com/user-attachments/assets/7f9f36cd-b9eb-44c1-879e-b469668a2ea8" />

Switch the warp shuffle order to make bitwise equivalence between the two easier.
The PTX diff between old and new shows a few extra instructions: https://www.diffchecker.com/h6ly3INC/
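
Shuffle order matters for bitwise equivalence because floating-point addition is not associative. The host-side sketch below models a shuffle-down reduction with a hypothetical helper, `reduce_lane0`; it does not claim which order Triton or the eager kernel uses, only that the two orders can produce different bits.

```cpp
#include <cassert>
#include <vector>

// Host-side model of a shuffle-down reduction; returns lane 0's value.
// Each step adds the value held `offset` lanes away, when that lane exists.
float reduce_lane0(std::vector<float> lanes, const std::vector<int>& offsets) {
    assert(!lanes.empty());
    const int n = static_cast<int>(lanes.size());
    for (int offset : offsets) {
        // Copy first so that all lanes read old values before any lane writes.
        std::vector<float> next = lanes;
        for (int i = 0; i + offset < n; ++i) {
            next[i] = lanes[i] + lanes[i + offset];
        }
        lanes = next;
    }
    return lanes[0];
}
```

For the four lanes `{1e8f, 1.0f, -1e8f, 1.0f}`, offsets `{1, 2}` compute `(1e8 + 1) + (-1e8 + 1)`, which rounds to 0 in single precision, while offsets `{2, 1}` compute `(1e8 - 1e8) + (1 + 1)`, which is 2. Two implementations therefore need the same offset order to match bit for bit.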

Comparing the performance on different reduction operations, we see minimal differences. New represents the changes in this PR, old represents the past warp shuffle order:
```
Tensor Shape              Operation            New all dims (ms)    New dim=0 (ms)       New dim=1 (ms)       Old all dims (ms)    Old dim=0 (ms)       Old dim=1 (ms)
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)              mean                 0.015817             0.016259             0.013642             0.015990             0.016258             0.013631             
(1024, 1024)              sum                  0.015917             0.015906             0.013359             0.015707             0.016266             0.013226             
(1024, 1024)              min                  0.016021             0.024625             0.015631             0.015761             0.024485             0.015317             
(1024, 1024)              max                  0.016349             0.024971             0.015972             0.015771             0.025001             0.015314             
(1024, 1024)              argmin               0.018070             0.024448             0.015578             0.018135             0.025370             0.015322             
(1024, 1024)              argmax               0.018427             0.024859             0.015932             0.018164             0.024452             0.015639             
(1024, 1024)              var                  0.020078             0.026413             0.020295             0.020199             0.026381             0.020214             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)              mean                 0.023826             0.023726             0.022273             0.023236             0.023776             0.022248             
(2048, 2048)              sum                  0.023840             0.023355             0.021974             0.023294             0.023354             0.021884             
(2048, 2048)              min                  0.024519             0.041263             0.024620             0.023292             0.041491             0.024358             
(2048, 2048)              max                  0.024509             0.041670             0.024277             0.023334             0.041231             0.024395             
(2048, 2048)              argmin               0.026125             0.041282             0.024567             0.026772             0.041773             0.024296             
(2048, 2048)              argmax               0.026117             0.041487             0.024572             0.026412             0.041477             0.024273             
(2048, 2048)              var                  0.026603             0.048581             0.031308             0.027587             0.048603             0.030860             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)              mean                 0.053927             0.057070             0.054073             0.053028             0.057544             0.053935             
(4096, 4096)              sum                  0.053604             0.057410             0.054451             0.053076             0.057033             0.054266             
(4096, 4096)              min                  0.054293             0.109122             0.058363             0.053821             0.108689             0.058382             
(4096, 4096)              max                  0.054258             0.108035             0.058703             0.053492             0.110552             0.058376             
(4096, 4096)              argmin               0.056805             0.111167             0.058301             0.056836             0.112325             0.058292             
(4096, 4096)              argmax               0.056488             0.110958             0.058636             0.056844             0.111000             0.057928             
(4096, 4096)              var                  0.058936             0.141755             0.068693             0.059735             0.141284             0.068500             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)              mean                 0.145552             0.148082             0.138647             0.145364             0.147818             0.138207             
(8192, 8192)              sum                  0.145985             0.147900             0.138714             0.145755             0.148031             0.138616             
(8192, 8192)              min                  0.146566             0.205359             0.192739             0.145611             0.205237             0.182335             
(8192, 8192)              max                  0.146526             0.204844             0.193050             0.146073             0.205457             0.182697             
(8192, 8192)              argmin               0.150190             0.206605             0.192543             0.150654             0.206847             0.182007             
(8192, 8192)              argmax               0.150481             0.206368             0.192535             0.150845             0.206430             0.182022             
(8192, 8192)              var                  0.150884             0.184546             0.203900             0.151594             0.184172             0.197983             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1, 1024, 128)            mean                 0.014293             0.008119             0.014533             0.013861             0.008022             0.014449             
(1, 1024, 128)            sum                  0.014039             0.007877             0.014111             0.014219             0.008227             0.014045             
(1, 1024, 128)            min                  0.014159             0.011354             0.023493             0.014271             0.010862             0.023644             
(1, 1024, 128)            max                  0.014154             0.011027             0.023368             0.014259             0.011234             0.023692             
(1, 1024, 128)            argmin               0.016403             0.005677             0.023328             0.016273             0.005683             0.024073             
(1, 1024, 128)            argmax               0.016734             0.005675             0.023437             0.016580             0.005318             0.023331             
(1, 1024, 128)            var                  0.018338             0.009549             0.025538             0.018528             0.009391             0.024777             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(5, 1024, 128)            mean                 0.014873             0.010131             0.015546             0.015123             0.010131             0.015481             
(5, 1024, 128)            sum                  0.015334             0.009673             0.015824             0.014736             0.009671             0.015438             
(5, 1024, 128)            min                  0.015047             0.013252             0.024573             0.014803             0.013163             0.024551             
(5, 1024, 128)            max                  0.015050             0.013339             0.024197             0.014810             0.013525             0.024230             
(5, 1024, 128)            argmin               0.017341             0.012737             0.024306             0.017471             0.012379             0.024991             
(5, 1024, 128)            argmax               0.017345             0.012411             0.024421             0.017422             0.012471             0.024237             
(5, 1024, 128)            var                  0.019973             0.011453             0.026188             0.020050             0.011438             0.026282             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(10, 1024, 128)           mean                 0.016976             0.011575             0.016831             0.016722             0.011927             0.017173             
(10, 1024, 128)           sum                  0.017039             0.011841             0.017159             0.016385             0.011860             0.016753             
(10, 1024, 128)           min                  0.017036             0.015331             0.026770             0.016944             0.015205             0.027166             
(10, 1024, 128)           max                  0.017369             0.015348             0.027077             0.016531             0.015716             0.026819             
(10, 1024, 128)           argmin               0.019203             0.014447             0.026813             0.018994             0.014497             0.027313             
(10, 1024, 128)           argmax               0.019563             0.014795             0.027140             0.019460             0.014912             0.026733             
(10, 1024, 128)           var                  0.020529             0.014316             0.030405             0.020719             0.013960             0.029964             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(100, 1024, 128)          mean                 0.045046             0.039168             0.046082             0.044839             0.039217             0.045782             
(100, 1024, 128)          sum                  0.045094             0.039150             0.045777             0.044496             0.039542             0.046083             
(100, 1024, 128)          min                  0.045768             0.054466             0.076244             0.044915             0.053943             0.076599             
(100, 1024, 128)          max                  0.045748             0.054459             0.076188             0.044931             0.053949             0.076856             
(100, 1024, 128)          argmin               0.048275             0.054046             0.076647             0.048694             0.054105             0.077004             
(100, 1024, 128)          argmax               0.048267             0.054395             0.077401             0.048691             0.054131             0.076751             
(100, 1024, 128)          var                  0.049710             0.043254             0.083077             0.050971             0.043251             0.082378             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000, 100)         mean                 0.202312             0.196723             0.197765             0.201774             0.196641             0.197459             
(1000, 1000, 100)         sum                  0.202651             0.196682             0.197736             0.202175             0.196313             0.197523             
(1000, 1000, 100)         min                  0.203022             0.264762             0.269200             0.202729             0.264129             0.268694             
(1000, 1000, 100)         max                  0.202864             0.264396             0.269388             0.202486             0.263896             0.268720             
(1000, 1000, 100)         argmin               0.226727             0.263781             0.268651             0.226597             0.264676             0.268983             
(1000, 1000, 100)         argmax               0.226412             0.264469             0.269090             0.226570             0.264595             0.269178             
(1000, 1000, 100)         var                  0.243223             0.204079             0.216096             0.241942             0.204079             0.215925             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(10000, 100)              mean                 0.016193             0.020277             0.014316             0.016152             0.020324             0.013712             
(10000, 100)              sum                  0.016289             0.020237             0.014034             0.016168             0.020265             0.013708             
(10000, 100)              min                  0.016046             0.030872             0.019609             0.016208             0.030867             0.018627             
(10000, 100)              max                  0.016369             0.030835             0.019257             0.016218             0.030861             0.018209             
(10000, 100)              argmin               0.017957             0.031171             0.019517             0.018050             0.031556             0.018077             
(10000, 100)              argmax               0.017961             0.031658             0.019521             0.018060             0.031564             0.018087             
(10000, 100)              var                  0.020393             0.035652             0.019339             0.020144             0.035987             0.019171             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(100000, 10)              mean                 0.015718             0.016576             0.016555             0.015999             0.016246             0.014869             
(100000, 10)              sum                  0.015833             0.016247             0.016572             0.016007             0.016627             0.014872             
(100000, 10)              min                  0.015888             0.020510             0.023920             0.015671             0.020821             0.021417             
(100000, 10)              max                  0.015889             0.020479             0.023918             0.016077             0.020386             0.021421             
(100000, 10)              argmin               0.018233             0.020863             0.023647             0.017574             0.020864             0.021103             
(100000, 10)              argmax               0.017896             0.020527             0.023296             0.017569             0.020447             0.021098             
(100000, 10)              var                  0.020005             0.024198             0.024372             0.020075             0.024167             0.022415             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 1023)        mean                 1.874816             1.963506             1.903909             1.873279             1.963859             1.903230             
(1023, 1023, 1023)        sum                  1.875030             1.965716             1.902458             1.873566             1.960730             1.901642             
(1023, 1023, 1023)        min                  1.878563             2.473455             2.179092             1.875174             2.482086             2.183027             
(1023, 1023, 1023)        max                  1.879128             2.474803             2.178895             1.874831             2.482253             2.183884             
(1023, 1023, 1023)        argmin               1.921800             2.476629             2.174831             1.923987             2.472641             2.170453             
(1023, 1023, 1023)        argmax               1.922605             2.476688             2.177927             1.923366             2.472808             2.172979             
(1023, 1023, 1023)        var                  1.972606             3.088695             2.758797             1.978679             3.095658             2.762243             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 255)         mean                 0.489984             0.500954             0.492957             0.489891             0.500654             0.491971             
(1023, 1023, 255)         sum                  0.490228             0.500764             0.492289             0.489624             0.501089             0.492824             
(1023, 1023, 255)         min                  0.491457             0.563560             0.553334             0.490355             0.564709             0.554754             
(1023, 1023, 255)         max                  0.491396             0.563628             0.553345             0.490017             0.565004             0.554947             
(1023, 1023, 255)         argmin               0.503666             0.561512             0.551831             0.503845             0.560972             0.551017             
(1023, 1023, 255)         argmax               0.503602             0.561185             0.551407             0.504328             0.561267             0.551448             
(1023, 1023, 255)         var                  0.510844             0.709452             0.701630             0.512693             0.710365             0.701965             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 377)         mean                 0.707439             0.727646             0.712019             0.706769             0.727101             0.711632             
(1023, 1023, 377)         sum                  0.707780             0.727453             0.711554             0.706807             0.726656             0.711729             
(1023, 1023, 377)         min                  0.709423             0.819809             0.794379             0.707847             0.822086             0.796664             
(1023, 1023, 377)         max                  0.709297             0.819780             0.794308             0.707566             0.821913             0.796690             
(1023, 1023, 377)         argmin               0.725028             0.817088             0.791695             0.726039             0.816445             0.790828             
(1023, 1023, 377)         argmax               0.725301             0.817011             0.791420             0.726040             0.816917             0.791143             
(1023, 1023, 377)         var                  0.740859             1.034165             1.006712             0.743413             1.035506             1.007638             
```

@PaulZhang12
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

pytorchmergebot pushed a commit that referenced this pull request Oct 17, 2025
Benchmarks cover a full reduction and a reduction over the contiguous dimension; vectorized loads do not occur on the non-contiguous dimension. Benchmarking was done for FP16 and BF16: ~6% improvement on average across shapes, up to ~24% for a single reduction on the contiguous dimension and ~46% (BF16) to ~48% (FP16) for a full reduce:
**BF16**
```
Tensor Shape         Operation    Old full reduce (ms) Old contig dim (ms)  New full reduce (ms) New contig dim (ms)  Full reduce diff %   Contiguous diff %
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256)           mean         0.022686             0.008263             0.015498             0.008117                          +46.38%               +1.80%
(256, 256)           sum          0.022769             0.008269             0.015628             0.008185                          +45.69%               +1.03%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 512)           mean         0.014116             0.009545             0.012892             0.008839                           +9.49%               +7.99%
(512, 512)           sum          0.014110             0.009892             0.012891             0.008878                           +9.46%              +11.42%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)         mean         0.014727             0.012642             0.014061             0.010519                           +4.74%              +20.18%
(1024, 1024)         sum          0.014376             0.012636             0.014069             0.010595                           +2.18%              +19.26%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)         mean         0.018663             0.018294             0.018171             0.014678                           +2.71%              +24.64%
(2048, 2048)         sum          0.018638             0.017931             0.018142             0.014713                           +2.73%              +21.87%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)         mean         0.034216             0.036953             0.033520             0.030585                           +2.08%              +20.82%
(4096, 4096)         sum          0.034196             0.036942             0.033518             0.030676                           +2.02%              +20.43%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)         mean         0.087763             0.095201             0.085439             0.084960                           +2.72%              +12.05%
(8192, 8192)         sum          0.088079             0.095592             0.085353             0.084632                           +3.19%              +12.95%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 16384)        mean         0.148174             0.149705             0.146274             0.138865                           +1.30%               +7.81%
(8192, 16384)        sum          0.147820             0.149371             0.146419             0.138752                           +0.96%               +7.65%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 32768)        mean         0.266144             0.260807             0.265953             0.253330                           +0.07%               +2.95%
(8192, 32768)        sum          0.266572             0.261163             0.265729             0.253294                           +0.32%               +3.11%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 65536)        mean         0.502034             0.486312             0.498417             0.481246                           +0.73%               +1.05%
(8192, 65536)        sum          0.501597             0.486351             0.497735             0.481579                           +0.78%               +0.99%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 131072)       mean         0.971178             0.942988             0.957164             0.938316                           +1.46%               +0.50%
(8192, 131072)       sum          0.971189             0.943232             0.956814             0.937816                           +1.50%               +0.58%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 262144)       mean         1.953728             1.877648             1.904937             1.861692                           +2.56%               +0.86%
(8192, 262144)       sum          1.953969             1.877538             1.905990             1.862547                           +2.52%               +0.80%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 262144)       mean         0.970408             0.940965             0.957871             0.936732                           +1.31%               +0.45%
(4096, 262144)       sum          0.970919             0.941652             0.957765             0.936676                           +1.37%               +0.53%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 262144)       mean         0.501477             0.486976             0.497964             0.483570                           +0.71%               +0.70%
(2048, 262144)       sum          0.501955             0.487213             0.498210             0.483218                           +0.75%               +0.83%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 262144)       mean         0.266536             0.257111             0.265642             0.255439                           +0.34%               +0.65%
(1024, 262144)       sum          0.266613             0.257096             0.265427             0.255472                           +0.45%               +0.64%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 131072)        mean         0.087805             0.091200             0.085818             0.087851                           +2.32%               +3.81%
(512, 131072)        sum          0.087788             0.091249             0.085373             0.087944                           +2.83%               +3.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000)         mean         0.014503             0.012328             0.013663             0.010190                           +6.15%              +20.98%
(1000, 1000)         sum          0.014545             0.012378             0.013662             0.010579                           +6.46%              +17.01%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 129)          mean         0.014163             0.008371             0.012893             0.008828                           +9.85%               -5.18%
(1024, 129)          sum          0.014132             0.008751             0.013234             0.008868                           +6.79%               -1.32%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 257)          mean         0.014296             0.009101             0.013334             0.008563                           +7.21%               +6.28%
(1024, 257)          sum          0.014302             0.009058             0.013020             0.008672                           +9.85%               +4.45%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 587)          mean         0.014127             0.010997             0.013443             0.009944                           +5.09%              +10.59%
(1024, 587)          sum          0.014471             0.011373             0.013123             0.010354                          +10.27%               +9.84%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 977)          mean         0.015607             0.013566             0.015089             0.012152                           +3.43%              +11.64%
(2048, 977)          sum          0.015953             0.013580             0.015039             0.011861                           +6.08%              +14.49%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 128)          mean         0.013982             0.008058             0.012747             0.008139                           +9.69%               -1.00%
(1024, 128)          sum          0.013967             0.008071             0.012726             0.007859                           +9.75%               +2.70%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 128)          mean         0.014378             0.009627             0.013712             0.009395                           +4.86%               +2.47%
(8192, 128)          sum          0.014389             0.009965             0.013718             0.009521                           +4.89%               +4.66%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 130)          mean         0.014156             0.008267             0.012895             0.008833                           +9.78%               -6.41%
(1024, 130)          sum          0.013797             0.008277             0.012903             0.008512                           +6.93%               -2.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 130)          mean         0.014977             0.010026             0.013911             0.009876                           +7.66%               +1.52%
(8192, 130)          sum          0.014994             0.010043             0.014235             0.009604                           +5.33%               +4.57%
====================================================================================================================================================================================
```
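
For readers checking the numbers: the diff columns in the table above are consistent with the relative change of the first timing column over the second, `(first / second - 1) * 100`, so a positive value means the second measurement is faster. The sketch below reproduces a few entries from the BF16 table. `pct_diff` is a hypothetical helper name, not taken from the benchmark script, and the formula is inferred from the printed values rather than from the script itself.

```python
def pct_diff(first_ms: float, second_ms: float) -> float:
    """Relative change of first_ms over second_ms, in percent.

    Assumption: this mirrors how the diff columns were derived; it is
    inferred from the table values, not from the benchmark source.
    """
    return (first_ms / second_ms - 1.0) * 100.0


# (label, first timing in ms, second timing in ms), copied from the BF16 table
rows = [
    ("(256, 256) mean, full reduce", 0.022686, 0.015498),
    ("(512, 512) mean, full reduce", 0.014116, 0.012892),
    ("(1024, 129) mean, contiguous dim", 0.008371, 0.008828),
]

for label, first_ms, second_ms in rows:
    print(f"{label}: {pct_diff(first_ms, second_ms):+.2f}%")
```

Because the timings are printed to six decimals, recomputed percentages may differ from the table in the last digit for other rows.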

**FP16**
```
Tensor Shape         Operation    Old full reduce (ms) Old contig dim (ms)  New full reduce (ms) New contig dim (ms)  Full reduce diff %   Contiguous diff %
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256)           mean         0.022804             0.008298             0.015888             0.007848                          +43.53%               +5.73%
(256, 256)           sum          0.023215             0.008328             0.015677             0.007850                          +48.08%               +6.09%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 512)           mean         0.013777             0.009988             0.012884             0.008512                           +6.93%              +17.34%
(512, 512)           sum          0.013775             0.009622             0.012870             0.009028                           +7.03%               +6.58%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)         mean         0.014740             0.012322             0.013708             0.010239                           +7.53%              +20.34%
(1024, 1024)         sum          0.014762             0.012756             0.013722             0.010307                           +7.58%              +23.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)         mean         0.018700             0.018364             0.018135             0.015078                           +3.12%              +21.79%
(2048, 2048)         sum          0.018276             0.018415             0.018471             0.015127                           -1.06%              +21.74%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)         mean         0.034518             0.037000             0.033838             0.030617                           +2.01%              +20.85%
(4096, 4096)         sum          0.034569             0.037448             0.033842             0.031100                           +2.15%              +20.41%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)         mean         0.087675             0.095176             0.085328             0.084105                           +2.75%              +13.16%
(8192, 8192)         sum          0.088102             0.095211             0.085707             0.084090                           +2.79%              +13.23%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 16384)        mean         0.147800             0.149263             0.146388             0.138390                           +0.96%               +7.86%
(8192, 16384)        sum          0.148147             0.148957             0.146439             0.138801                           +1.17%               +7.32%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 32768)        mean         0.266316             0.260294             0.265829             0.253411                           +0.18%               +2.72%
(8192, 32768)        sum          0.266562             0.260717             0.265744             0.253308                           +0.31%               +2.92%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 65536)        mean         0.502035             0.486077             0.498139             0.481374                           +0.78%               +0.98%
(8192, 65536)        sum          0.501571             0.485733             0.498353             0.481350                           +0.65%               +0.91%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 131072)       mean         0.971343             0.943016             0.956600             0.938622                           +1.54%               +0.47%
(8192, 131072)       sum          0.971463             0.942991             0.957352             0.938334                           +1.47%               +0.50%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 262144)       mean         1.952722             1.877165             1.906406             1.861455                           +2.43%               +0.84%
(8192, 262144)       sum          1.952634             1.876388             1.904677             1.861282                           +2.52%               +0.81%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 262144)       mean         0.970697             0.941298             0.956964             0.936160                           +1.44%               +0.55%
(4096, 262144)       sum          0.969981             0.941078             0.957016             0.936260                           +1.35%               +0.51%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 262144)       mean         0.501577             0.487208             0.498422             0.483493                           +0.63%               +0.77%
(2048, 262144)       sum          0.502029             0.487124             0.497854             0.483643                           +0.84%               +0.72%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 262144)       mean         0.266416             0.257383             0.265928             0.255140                           +0.18%               +0.88%
(1024, 262144)       sum          0.266434             0.257081             0.265817             0.255143                           +0.23%               +0.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 131072)        mean         0.087858             0.091296             0.085816             0.087745                           +2.38%               +4.05%
(512, 131072)        sum          0.088144             0.091314             0.085664             0.087864                           +2.90%               +3.93%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000)         mean         0.014977             0.012393             0.014141             0.010614                           +5.91%              +16.76%
(1000, 1000)         sum          0.014589             0.012804             0.014118             0.010320                           +3.34%              +24.07%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 129)          mean         0.014208             0.008383             0.013273             0.008440                           +7.04%               -0.68%
(1024, 129)          sum          0.013804             0.008863             0.013265             0.009003                           +4.06%               -1.56%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 257)          mean         0.014378             0.009109             0.013037             0.009038                          +10.29%               +0.79%
(1024, 257)          sum          0.014387             0.009113             0.013396             0.008698                           +7.40%               +4.77%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 587)          mean         0.014207             0.011037             0.013182             0.010391                           +7.78%               +6.22%
(1024, 587)          sum          0.014588             0.011453             0.013539             0.010049                           +7.75%              +13.97%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 977)          mean         0.016024             0.013614             0.015448             0.011845                           +3.73%              +14.93%
(2048, 977)          sum          0.015990             0.014033             0.015406             0.012278                           +3.79%              +14.29%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 128)          mean         0.014037             0.007804             0.013143             0.008242                           +6.80%               -5.31%
(1024, 128)          sum          0.014041             0.007847             0.012759             0.007850                          +10.05%               -0.04%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 128)          mean         0.014361             0.009644             0.014075             0.009061                           +2.03%               +6.43%
(8192, 128)          sum          0.014366             0.010032             0.013702             0.009181                           +4.85%               +9.27%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 130)          mean         0.014226             0.008696             0.012894             0.008835                          +10.33%               -1.57%
(1024, 130)          sum          0.013830             0.008740             0.013288             0.008989                           +4.08%               -2.77%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 130)          mean         0.015036             0.010019             0.013917             0.009538                           +8.04%               +5.04%
(8192, 130)          sum          0.014652             0.010403             0.013900             0.009565                           +5.41%               +8.76%
====================================================================================================================================================================================
```
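
The remark above that vectorized loads do not occur on the non-contiguous dimension can be pictured with a toy stride model. This is only a sketch under assumptions: a dense row-major tensor, a 16-byte vector width chosen for illustration, and hypothetical helper names (`row_major_strides`, `elements_per_vector_load`). It ignores alignment and is not the kernel logic from this change.

```python
def row_major_strides(shape):
    """Element strides of a dense row-major (C-contiguous) tensor."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def elements_per_vector_load(shape, dim, itemsize, vector_bytes=16):
    """Elements a single vector load could cover when reducing over `dim`.

    Returns 1 (scalar loads) when `dim` is not the unit-stride dimension.
    vector_bytes=16 is an illustrative assumption, not a value from the PR.
    """
    if row_major_strides(shape)[dim] != 1:
        return 1
    return vector_bytes // itemsize


shape = (1024, 1024)
print(row_major_strides(shape))
# itemsize=2 corresponds to a 16-bit dtype such as FP16 or BF16
print(elements_per_vector_load(shape, dim=1, itemsize=2))  # contiguous dim
print(elements_per_vector_load(shape, dim=0, itemsize=2))  # non-contiguous dim
```

In this model only the unit-stride dimension has neighbouring elements adjacent in memory, which is why a wider load can help there and not on the other dimension.
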
Pull Request resolved: #165055
Approved by: https://github.com/ngimel
ghstack dependencies: #165494, #164790
pytorchmergebot pushed a commit that referenced this pull request Oct 19, 2025
Performance benchmarking shows the change is perf neutral:
```
================================================================================================================================================================================================================================================
Tensor Shape         Operation    Full reduce (ms)     Non-Contig dim (ms)    Contig dim (ms)      Full reduce (ms)     Non-Contig dim (ms)    Contig dim (ms)      Full diff %     Non-Contig diff %    Contig diff %
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256)           mean         0.015684             0.017056               0.008287             0.016015             0.016929               0.008170                      -2.07%               +0.75%          +1.43%
(256, 256)           sum          0.015774             0.016638               0.007926             0.015811             0.016935               0.008330                      -0.23%               -1.75%          -4.85%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 512)           mean         0.013385             0.025742               0.008629             0.013046             0.026005               0.008924                      +2.60%               -1.01%          -3.31%
(512, 512)           sum          0.013390             0.026059               0.009116             0.013054             0.025696               0.008952                      +2.57%               +1.41%          +1.83%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)         mean         0.014213             0.015467               0.010334             0.013862             0.015082               0.010318                      +2.53%               +2.55%          +0.16%
(1024, 1024)         sum          0.014179             0.015446               0.010774             0.014132             0.015073               0.010350                      +0.33%               +2.47%          +4.10%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)         mean         0.018234             0.019487               0.014812             0.018482             0.019397               0.014802                      -1.34%               +0.46%          +0.07%
(2048, 2048)         sum          0.018202             0.019529               0.015195             0.018122             0.019485               0.015129                      +0.44%               +0.23%          +0.44%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)         mean         0.033582             0.039378               0.030751             0.033810             0.039673               0.031019                      -0.67%               -0.74%          -0.86%
(4096, 4096)         sum          0.033604             0.039777               0.030809             0.033530             0.039386               0.031113                      +0.22%               +0.99%          -0.98%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)         mean         0.085824             0.091133               0.084200             0.085431             0.091364               0.084303                      +0.46%               -0.25%          -0.12%
(8192, 8192)         sum          0.085763             0.091442               0.084180             0.085508             0.091419               0.084595                      +0.30%               +0.03%          -0.49%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 16384)        mean         0.146480             0.147666               0.138807             0.146515             0.147987               0.138930                      -0.02%               -0.22%          -0.09%
(8192, 16384)        sum          0.146446             0.147593               0.138559             0.146151             0.147982               0.139120                      +0.20%               -0.26%          -0.40%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 32768)        mean         0.266047             0.265386               0.253837             0.265648             0.265885               0.253652                      +0.15%               -0.19%          +0.07%
(8192, 32768)        sum          0.266093             0.265421               0.253890             0.265458             0.265591               0.253567                      +0.24%               -0.06%          +0.13%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 65536)        mean         0.498632             0.508976               0.481865             0.498237             0.508777               0.481476                      +0.08%               +0.04%          +0.08%
(8192, 65536)        sum          0.498917             0.508202               0.481883             0.498104             0.508016               0.481972                      +0.16%               +0.04%          -0.02%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 131072)       mean         0.957633             0.968519               0.938172             0.956766             0.968267               0.938196                      +0.09%               +0.03%          -0.00%
(8192, 131072)       sum          0.956972             0.968140               0.937741             0.957365             0.968404               0.938056                      -0.04%               -0.03%          -0.03%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 262144)       mean         1.906661             1.928377               1.861846             1.907327             1.928811               1.862083                      -0.03%               -0.02%          -0.01%
(8192, 262144)       sum          1.905976             1.928362               1.862399             1.907098             1.928844               1.861782                      -0.06%               -0.02%          +0.03%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 262144)       mean         0.956852             0.970101               0.936524             0.957263             0.969809               0.936965                      -0.04%               +0.03%          -0.05%
(4096, 262144)       sum          0.957117             0.969933               0.936247             0.956675             0.969451               0.936395                      +0.05%               +0.05%          -0.02%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 262144)       mean         0.498813             0.511299               0.483415             0.498567             0.511482               0.483376                      +0.05%               -0.04%          +0.01%
(2048, 262144)       sum          0.498813             0.510834               0.483641             0.498875             0.511036               0.483338                      -0.01%               -0.04%          +0.06%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 262144)       mean         0.266157             0.276751               0.255192             0.265966             0.276808               0.255544                      +0.07%               -0.02%          -0.14%
(1024, 262144)       sum          0.266133             0.276709               0.255528             0.265658             0.276685               0.255287                      +0.18%               +0.01%          +0.09%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 131072)        mean         0.085941             0.081184               0.087931             0.085591             0.080832               0.088008                      +0.41%               +0.44%          -0.09%
(512, 131072)        sum          0.085962             0.081107               0.088045             0.085882             0.081160               0.088024                      +0.09%               -0.07%          +0.02%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000)         mean         0.014203             0.045859               0.010310             0.013885             0.046132               0.010621                      +2.29%               -0.59%          -2.93%
(1000, 1000)         sum          0.014180             0.046165               0.010756             0.013893             0.046109               0.010338                      +2.07%               +0.12%          +4.04%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 129)          mean         0.012953             0.016751               0.008536             0.012977             0.016714               0.008916                      -0.18%               +0.22%          -4.26%
(1024, 129)          sum          0.013356             0.016806               0.008722             0.013003             0.017071               0.008611                      +2.71%               -1.55%          +1.29%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 257)          mean         0.013075             0.016787               0.009102             0.013116             0.016769               0.008679                      -0.31%               +0.11%          +4.87%
(1024, 257)          sum          0.013092             0.016842               0.008786             0.013126             0.017128               0.008771                      -0.26%               -1.67%          +0.17%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 587)          mean         0.013662             0.017412               0.010055             0.013659             0.017019               0.010033                      +0.02%               +2.31%          +0.22%
(1024, 587)          sum          0.013636             0.017473               0.010163             0.013642             0.017363               0.010101                      -0.04%               +0.63%          +0.61%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 977)          mean         0.015276             0.027873               0.012531             0.015241             0.027783               0.012467                      +0.23%               +0.32%          +0.51%
(2048, 977)          sum          0.015345             0.027949               0.012192             0.015255             0.027839               0.012485                      +0.59%               +0.40%          -2.35%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 128)          mean         0.012806             0.014020               0.008291             0.013137             0.014309               0.007908                      -2.52%               -2.02%          +4.84%
(1024, 128)          sum          0.012769             0.014308               0.007924             0.012788             0.014236               0.008038                      -0.15%               +0.51%          -1.42%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 128)          mean         0.014145             0.023049               0.009143             0.014104             0.023298               0.009501                      +0.29%               -1.07%          -3.77%
(8192, 128)          sum          0.014132             0.023082               0.009638             0.014107             0.023331               0.009244                      +0.18%               -1.07%          +4.26%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 130)          mean         0.013420             0.025834               0.008949             0.013368             0.025724               0.008918                      +0.39%               +0.43%          +0.35%
(1024, 130)          sum          0.013300             0.025940               0.009113             0.013266             0.025419               0.008922                      +0.26%               +2.05%          +2.14%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 130)          mean         0.013993             0.017883               0.009661             0.014275             0.018220               0.009596                      -1.98%               -1.85%          +0.68%
(8192, 130)          sum          0.014026             0.018297               0.010066             0.014326             0.018257               0.009659                      -2.09%               +0.22%          +4.21%
================================================================================================================================================================================================================================================
```
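
The three percentage columns in the table above are consistent with a relative difference of the new timing against the old one, `(new - old) / old * 100`, so a positive value means the new order was slower in that run. The minimal sketch below reproduces that arithmetic for the `(512, 512)` `mean` row. The function names are hypothetical and are not taken from the benchmark script, which is not shown here.

```python
def percent_diff(new_time: float, old_time: float) -> float:
    """Relative change of the new timing versus the old one, in percent."""
    return (new_time - old_time) / old_time * 100.0


def format_diff(new_time: float, old_time: float) -> str:
    """Format the change with an explicit sign and two decimals, as in the table."""
    return f"{percent_diff(new_time, old_time):+.2f}%"


if __name__ == "__main__":
    # The three (new, old) timing pairs from the (512, 512) mean row.
    row = [(0.013385, 0.013046), (0.025742, 0.026005), (0.008629, 0.008924)]
    print([format_diff(new, old) for new, old in row])
```

Most differences for the larger shapes stay well under 1%, while the small shapes swing by a few percent in both directions, which suggests run-to-run noise rather than a systematic slowdown.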

Pull Request resolved: #165178
Approved by: https://github.com/ngimel
ghstack dependencies: #165494, #164790, #165055
@clee2000
Contributor

@pytorchbot revert -m "was reverted due to failing internal tests after merge D84992607" -c ghfirst

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job.

@pytorchmergebot
Collaborator

@PaulZhang12 your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Oct 20, 2025
This reverts commit 36371b8.

Reverted #164790 on behalf of https://github.com/clee2000 due to failing internal tests after merge D84992607 ([comment](#164790 (comment)))
Typical warp shuffle reduction has the following pattern:
<img width="1138" height="501" alt="image" src="https://github.com/user-attachments/assets/3bd176dc-0ad2-4df6-90c7-06e467337166" />

The same pattern appears in the Triton code generated by torch.compile:
<img width="663" height="403" alt="image" src="https://github.com/user-attachments/assets/7f9f36cd-b9eb-44c1-879e-b469668a2ea8" />

Switch the warp shuffle order so that bitwise equivalence between the two is easier to achieve.
Comparing the PTX generated before and after the change shows a few extra instructions: https://www.diffchecker.com/h6ly3INC/
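
Why the shuffle order matters for bitwise equivalence can be illustrated with a small host-side model. The sketch below is not the kernel code from this PR: it assumes a shuffle-down style reduction and compares two offset orders, descending (`width/2, ..., 2, 1`) and ascending (`1, 2, ..., width/2`). The helper names (`shfl_down`, `warp_reduce`, `descending_offsets`, `ascending_offsets`) are hypothetical, and the model keeps a lane's own value when its source lane is out of range.

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")


def shfl_down(values: Sequence[T], offset: int) -> List[T]:
    """Model a shuffle-down: lane i reads the value held by lane i + offset.

    Lanes whose source would be out of range keep their own value.
    """
    n = len(values)
    return [values[i + offset] if i + offset < n else values[i] for i in range(n)]


def warp_reduce(values: Sequence[T], combine: Callable[[T, T], T],
                offsets: Sequence[int]) -> T:
    """Reduce all lanes into lane 0 by applying the given shuffle offsets in order."""
    lanes = list(values)
    for offset in offsets:
        shifted = shfl_down(lanes, offset)
        lanes = [combine(a, b) for a, b in zip(lanes, shifted)]
    return lanes[0]


def descending_offsets(width: int) -> List[int]:
    """Offsets width/2, ..., 2, 1 (width is assumed to be a power of two)."""
    offsets = []
    offset = width // 2
    while offset > 0:
        offsets.append(offset)
        offset //= 2
    return offsets


def ascending_offsets(width: int) -> List[int]:
    """Offsets 1, 2, ..., width/2 (width is assumed to be a power of two)."""
    return descending_offsets(width)[::-1]


if __name__ == "__main__":
    # Symbolic lanes expose how the additions are grouped in lane 0.
    names = ["v0", "v1", "v2", "v3"]
    paren = lambda a, b: f"({a}+{b})"
    print(warp_reduce(names, paren, descending_offsets(4)))
    print(warp_reduce(names, paren, ascending_offsets(4)))

    # Floating-point addition is not associative, so the grouping changes the result.
    data = [1e17, 1.0, -1e17, 1.0]
    add = lambda a, b: a + b
    print(warp_reduce(data, add, descending_offsets(4)))
    print(warp_reduce(data, add, ascending_offsets(4)))
```

Both orders combine every lane exactly once into lane 0, but they group the additions differently. With floating-point inputs the two groupings can round differently, so two implementations only match bit for bit when they use the same order.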

Performance across different reduction operations shows minimal differences. "New" refers to the changes in this PR; "Old" refers to the previous warp shuffle order:
```
Tensor Shape              Operation            New all dims (ms)    New dim=0 (ms)       New dim=1 (ms)       Old all dims (ms)    Old dim=0 (ms)       Old dim=1 (ms)
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)              mean                 0.015817             0.016259             0.013642             0.015990             0.016258             0.013631             
(1024, 1024)              sum                  0.015917             0.015906             0.013359             0.015707             0.016266             0.013226             
(1024, 1024)              min                  0.016021             0.024625             0.015631             0.015761             0.024485             0.015317             
(1024, 1024)              max                  0.016349             0.024971             0.015972             0.015771             0.025001             0.015314             
(1024, 1024)              argmin               0.018070             0.024448             0.015578             0.018135             0.025370             0.015322             
(1024, 1024)              argmax               0.018427             0.024859             0.015932             0.018164             0.024452             0.015639             
(1024, 1024)              var                  0.020078             0.026413             0.020295             0.020199             0.026381             0.020214             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)              mean                 0.023826             0.023726             0.022273             0.023236             0.023776             0.022248             
(2048, 2048)              sum                  0.023840             0.023355             0.021974             0.023294             0.023354             0.021884             
(2048, 2048)              min                  0.024519             0.041263             0.024620             0.023292             0.041491             0.024358             
(2048, 2048)              max                  0.024509             0.041670             0.024277             0.023334             0.041231             0.024395             
(2048, 2048)              argmin               0.026125             0.041282             0.024567             0.026772             0.041773             0.024296             
(2048, 2048)              argmax               0.026117             0.041487             0.024572             0.026412             0.041477             0.024273             
(2048, 2048)              var                  0.026603             0.048581             0.031308             0.027587             0.048603             0.030860             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)              mean                 0.053927             0.057070             0.054073             0.053028             0.057544             0.053935             
(4096, 4096)              sum                  0.053604             0.057410             0.054451             0.053076             0.057033             0.054266             
(4096, 4096)              min                  0.054293             0.109122             0.058363             0.053821             0.108689             0.058382             
(4096, 4096)              max                  0.054258             0.108035             0.058703             0.053492             0.110552             0.058376             
(4096, 4096)              argmin               0.056805             0.111167             0.058301             0.056836             0.112325             0.058292             
(4096, 4096)              argmax               0.056488             0.110958             0.058636             0.056844             0.111000             0.057928             
(4096, 4096)              var                  0.058936             0.141755             0.068693             0.059735             0.141284             0.068500             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)              mean                 0.145552             0.148082             0.138647             0.145364             0.147818             0.138207             
(8192, 8192)              sum                  0.145985             0.147900             0.138714             0.145755             0.148031             0.138616             
(8192, 8192)              min                  0.146566             0.205359             0.192739             0.145611             0.205237             0.182335             
(8192, 8192)              max                  0.146526             0.204844             0.193050             0.146073             0.205457             0.182697             
(8192, 8192)              argmin               0.150190             0.206605             0.192543             0.150654             0.206847             0.182007             
(8192, 8192)              argmax               0.150481             0.206368             0.192535             0.150845             0.206430             0.182022             
(8192, 8192)              var                  0.150884             0.184546             0.203900             0.151594             0.184172             0.197983             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1, 1024, 128)            mean                 0.014293             0.008119             0.014533             0.013861             0.008022             0.014449             
(1, 1024, 128)            sum                  0.014039             0.007877             0.014111             0.014219             0.008227             0.014045             
(1, 1024, 128)            min                  0.014159             0.011354             0.023493             0.014271             0.010862             0.023644             
(1, 1024, 128)            max                  0.014154             0.011027             0.023368             0.014259             0.011234             0.023692             
(1, 1024, 128)            argmin               0.016403             0.005677             0.023328             0.016273             0.005683             0.024073             
(1, 1024, 128)            argmax               0.016734             0.005675             0.023437             0.016580             0.005318             0.023331             
(1, 1024, 128)            var                  0.018338             0.009549             0.025538             0.018528             0.009391             0.024777             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(5, 1024, 128)            mean                 0.014873             0.010131             0.015546             0.015123             0.010131             0.015481             
(5, 1024, 128)            sum                  0.015334             0.009673             0.015824             0.014736             0.009671             0.015438             
(5, 1024, 128)            min                  0.015047             0.013252             0.024573             0.014803             0.013163             0.024551             
(5, 1024, 128)            max                  0.015050             0.013339             0.024197             0.014810             0.013525             0.024230             
(5, 1024, 128)            argmin               0.017341             0.012737             0.024306             0.017471             0.012379             0.024991             
(5, 1024, 128)            argmax               0.017345             0.012411             0.024421             0.017422             0.012471             0.024237             
(5, 1024, 128)            var                  0.019973             0.011453             0.026188             0.020050             0.011438             0.026282             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(10, 1024, 128)           mean                 0.016976             0.011575             0.016831             0.016722             0.011927             0.017173             
(10, 1024, 128)           sum                  0.017039             0.011841             0.017159             0.016385             0.011860             0.016753             
(10, 1024, 128)           min                  0.017036             0.015331             0.026770             0.016944             0.015205             0.027166             
(10, 1024, 128)           max                  0.017369             0.015348             0.027077             0.016531             0.015716             0.026819             
(10, 1024, 128)           argmin               0.019203             0.014447             0.026813             0.018994             0.014497             0.027313             
(10, 1024, 128)           argmax               0.019563             0.014795             0.027140             0.019460             0.014912             0.026733             
(10, 1024, 128)           var                  0.020529             0.014316             0.030405             0.020719             0.013960             0.029964             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(100, 1024, 128)          mean                 0.045046             0.039168             0.046082             0.044839             0.039217             0.045782             
(100, 1024, 128)          sum                  0.045094             0.039150             0.045777             0.044496             0.039542             0.046083             
(100, 1024, 128)          min                  0.045768             0.054466             0.076244             0.044915             0.053943             0.076599             
(100, 1024, 128)          max                  0.045748             0.054459             0.076188             0.044931             0.053949             0.076856             
(100, 1024, 128)          argmin               0.048275             0.054046             0.076647             0.048694             0.054105             0.077004             
(100, 1024, 128)          argmax               0.048267             0.054395             0.077401             0.048691             0.054131             0.076751             
(100, 1024, 128)          var                  0.049710             0.043254             0.083077             0.050971             0.043251             0.082378             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000, 100)         mean                 0.202312             0.196723             0.197765             0.201774             0.196641             0.197459             
(1000, 1000, 100)         sum                  0.202651             0.196682             0.197736             0.202175             0.196313             0.197523             
(1000, 1000, 100)         min                  0.203022             0.264762             0.269200             0.202729             0.264129             0.268694             
(1000, 1000, 100)         max                  0.202864             0.264396             0.269388             0.202486             0.263896             0.268720             
(1000, 1000, 100)         argmin               0.226727             0.263781             0.268651             0.226597             0.264676             0.268983             
(1000, 1000, 100)         argmax               0.226412             0.264469             0.269090             0.226570             0.264595             0.269178             
(1000, 1000, 100)         var                  0.243223             0.204079             0.216096             0.241942             0.204079             0.215925             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(10000, 100)              mean                 0.016193             0.020277             0.014316             0.016152             0.020324             0.013712             
(10000, 100)              sum                  0.016289             0.020237             0.014034             0.016168             0.020265             0.013708             
(10000, 100)              min                  0.016046             0.030872             0.019609             0.016208             0.030867             0.018627             
(10000, 100)              max                  0.016369             0.030835             0.019257             0.016218             0.030861             0.018209             
(10000, 100)              argmin               0.017957             0.031171             0.019517             0.018050             0.031556             0.018077             
(10000, 100)              argmax               0.017961             0.031658             0.019521             0.018060             0.031564             0.018087             
(10000, 100)              var                  0.020393             0.035652             0.019339             0.020144             0.035987             0.019171             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(100000, 10)              mean                 0.015718             0.016576             0.016555             0.015999             0.016246             0.014869             
(100000, 10)              sum                  0.015833             0.016247             0.016572             0.016007             0.016627             0.014872             
(100000, 10)              min                  0.015888             0.020510             0.023920             0.015671             0.020821             0.021417             
(100000, 10)              max                  0.015889             0.020479             0.023918             0.016077             0.020386             0.021421             
(100000, 10)              argmin               0.018233             0.020863             0.023647             0.017574             0.020864             0.021103             
(100000, 10)              argmax               0.017896             0.020527             0.023296             0.017569             0.020447             0.021098             
(100000, 10)              var                  0.020005             0.024198             0.024372             0.020075             0.024167             0.022415             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 1023)        mean                 1.874816             1.963506             1.903909             1.873279             1.963859             1.903230             
(1023, 1023, 1023)        sum                  1.875030             1.965716             1.902458             1.873566             1.960730             1.901642             
(1023, 1023, 1023)        min                  1.878563             2.473455             2.179092             1.875174             2.482086             2.183027             
(1023, 1023, 1023)        max                  1.879128             2.474803             2.178895             1.874831             2.482253             2.183884             
(1023, 1023, 1023)        argmin               1.921800             2.476629             2.174831             1.923987             2.472641             2.170453             
(1023, 1023, 1023)        argmax               1.922605             2.476688             2.177927             1.923366             2.472808             2.172979             
(1023, 1023, 1023)        var                  1.972606             3.088695             2.758797             1.978679             3.095658             2.762243             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 255)         mean                 0.489984             0.500954             0.492957             0.489891             0.500654             0.491971             
(1023, 1023, 255)         sum                  0.490228             0.500764             0.492289             0.489624             0.501089             0.492824             
(1023, 1023, 255)         min                  0.491457             0.563560             0.553334             0.490355             0.564709             0.554754             
(1023, 1023, 255)         max                  0.491396             0.563628             0.553345             0.490017             0.565004             0.554947             
(1023, 1023, 255)         argmin               0.503666             0.561512             0.551831             0.503845             0.560972             0.551017             
(1023, 1023, 255)         argmax               0.503602             0.561185             0.551407             0.504328             0.561267             0.551448             
(1023, 1023, 255)         var                  0.510844             0.709452             0.701630             0.512693             0.710365             0.701965             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 377)         mean                 0.707439             0.727646             0.712019             0.706769             0.727101             0.711632             
(1023, 1023, 377)         sum                  0.707780             0.727453             0.711554             0.706807             0.726656             0.711729             
(1023, 1023, 377)         min                  0.709423             0.819809             0.794379             0.707847             0.822086             0.796664             
(1023, 1023, 377)         max                  0.709297             0.819780             0.794308             0.707566             0.821913             0.796690             
(1023, 1023, 377)         argmin               0.725028             0.817088             0.791695             0.726039             0.816445             0.790828             
(1023, 1023, 377)         argmax               0.725301             0.817011             0.791420             0.726040             0.816917             0.791143             
(1023, 1023, 377)         var                  0.740859             1.034165             1.006712             0.743413             1.035506             1.007638             
```

PaulZhang12 added a commit that referenced this pull request Oct 20, 2025
PaulZhang12 added a commit that referenced this pull request Oct 20, 2025
@PaulZhang12
Contributor Author

@PaulZhang12 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

pytorchmergebot added a commit to renato-arantes/pytorch that referenced this pull request Oct 20, 2025
This reverts commit 36371b8.

Reverted pytorch#164790 on behalf of https://github.com/clee2000 because it failed internal tests after merge (D84992607) ([comment](pytorch#164790 (comment)))
Typical warp shuffle reduction has the following pattern:
<img width="1138" height="501" alt="image" src="https://github.com/user-attachments/assets/3bd176dc-0ad2-4df6-90c7-06e467337166" />

which is exhibited in Triton generated by torch.compile:
<img width="663" height="403" alt="image" src="https://github.com/user-attachments/assets/7f9f36cd-b9eb-44c1-879e-b469668a2ea8" />

Switch the warp shuffle order to match this pattern, which makes bitwise equivalence between the two reduction implementations easier to achieve.
The PTX diff between old and new shows a few extra instructions: https://www.diffchecker.com/h6ly3INC/
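
Why the shuffle order matters for bitwise equivalence can be illustrated with a small simulation. This is only a sketch, not the kernel code from this PR: it assumes a 32-lane warp, a butterfly (XOR) exchange, and double-precision Python floats, and the helper name `warp_reduce_sum` is hypothetical. The two offset sequences pair lanes in a different order, so the floating-point additions are associated differently and can round differently.

```python
# Simulated butterfly (XOR) warp reduction over 32 lanes, in pure Python.
# Helper names are hypothetical; nothing here is taken from PyTorch or Triton.

WARP_SIZE = 32


def warp_reduce_sum(values, offsets):
    """Return the per-lane partial sums after exchanging at each offset."""
    lanes = list(values)
    for offset in offsets:
        # Every lane adds the value currently held by its partner (lane XOR offset).
        lanes = [lanes[i] + lanes[i ^ offset] for i in range(len(lanes))]
    return lanes


# Values chosen so that rounding depends on which lanes are paired first.
values = [1.0] * WARP_SIZE
values[0] = 1e16
values[16] = -1e16

# Halving offsets: lanes 0 and 16 cancel exactly in the first step.
descending = warp_reduce_sum(values, [16, 8, 4, 2, 1])[0]
# Doubling offsets: 1e16 + 1.0 rounds back to 1e16, so two units are lost.
ascending = warp_reduce_sum(values, [1, 2, 4, 8, 16])[0]

print(descending, ascending)  # prints 30.0 28.0
```

Both orders are valid reductions of the same inputs, yet they disagree in the final result, which is why two implementations need the same shuffle order before their outputs can match bit for bit.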

Comparing performance across different reduction operations, we see minimal differences. "New" is the change in this PR; "Old" is the previous warp shuffle order:
```
Tensor Shape              Operation            New all dims (ms)    New dim=0 (ms)       New dim=1 (ms)       Old all dims (ms)    Old dim=0 (ms)       Old dim=1 (ms)
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)              mean                 0.015817             0.016259             0.013642             0.015990             0.016258             0.013631             
(1024, 1024)              sum                  0.015917             0.015906             0.013359             0.015707             0.016266             0.013226             
(1024, 1024)              min                  0.016021             0.024625             0.015631             0.015761             0.024485             0.015317             
(1024, 1024)              max                  0.016349             0.024971             0.015972             0.015771             0.025001             0.015314             
(1024, 1024)              argmin               0.018070             0.024448             0.015578             0.018135             0.025370             0.015322             
(1024, 1024)              argmax               0.018427             0.024859             0.015932             0.018164             0.024452             0.015639             
(1024, 1024)              var                  0.020078             0.026413             0.020295             0.020199             0.026381             0.020214             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)              mean                 0.023826             0.023726             0.022273             0.023236             0.023776             0.022248             
(2048, 2048)              sum                  0.023840             0.023355             0.021974             0.023294             0.023354             0.021884             
(2048, 2048)              min                  0.024519             0.041263             0.024620             0.023292             0.041491             0.024358             
(2048, 2048)              max                  0.024509             0.041670             0.024277             0.023334             0.041231             0.024395             
(2048, 2048)              argmin               0.026125             0.041282             0.024567             0.026772             0.041773             0.024296             
(2048, 2048)              argmax               0.026117             0.041487             0.024572             0.026412             0.041477             0.024273             
(2048, 2048)              var                  0.026603             0.048581             0.031308             0.027587             0.048603             0.030860             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)              mean                 0.053927             0.057070             0.054073             0.053028             0.057544             0.053935             
(4096, 4096)              sum                  0.053604             0.057410             0.054451             0.053076             0.057033             0.054266             
(4096, 4096)              min                  0.054293             0.109122             0.058363             0.053821             0.108689             0.058382             
(4096, 4096)              max                  0.054258             0.108035             0.058703             0.053492             0.110552             0.058376             
(4096, 4096)              argmin               0.056805             0.111167             0.058301             0.056836             0.112325             0.058292             
(4096, 4096)              argmax               0.056488             0.110958             0.058636             0.056844             0.111000             0.057928             
(4096, 4096)              var                  0.058936             0.141755             0.068693             0.059735             0.141284             0.068500             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)              mean                 0.145552             0.148082             0.138647             0.145364             0.147818             0.138207             
(8192, 8192)              sum                  0.145985             0.147900             0.138714             0.145755             0.148031             0.138616             
(8192, 8192)              min                  0.146566             0.205359             0.192739             0.145611             0.205237             0.182335             
(8192, 8192)              max                  0.146526             0.204844             0.193050             0.146073             0.205457             0.182697             
(8192, 8192)              argmin               0.150190             0.206605             0.192543             0.150654             0.206847             0.182007             
(8192, 8192)              argmax               0.150481             0.206368             0.192535             0.150845             0.206430             0.182022             
(8192, 8192)              var                  0.150884             0.184546             0.203900             0.151594             0.184172             0.197983             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1, 1024, 128)            mean                 0.014293             0.008119             0.014533             0.013861             0.008022             0.014449             
(1, 1024, 128)            sum                  0.014039             0.007877             0.014111             0.014219             0.008227             0.014045             
(1, 1024, 128)            min                  0.014159             0.011354             0.023493             0.014271             0.010862             0.023644             
(1, 1024, 128)            max                  0.014154             0.011027             0.023368             0.014259             0.011234             0.023692             
(1, 1024, 128)            argmin               0.016403             0.005677             0.023328             0.016273             0.005683             0.024073             
(1, 1024, 128)            argmax               0.016734             0.005675             0.023437             0.016580             0.005318             0.023331             
(1, 1024, 128)            var                  0.018338             0.009549             0.025538             0.018528             0.009391             0.024777             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(5, 1024, 128)            mean                 0.014873             0.010131             0.015546             0.015123             0.010131             0.015481             
(5, 1024, 128)            sum                  0.015334             0.009673             0.015824             0.014736             0.009671             0.015438             
(5, 1024, 128)            min                  0.015047             0.013252             0.024573             0.014803             0.013163             0.024551             
(5, 1024, 128)            max                  0.015050             0.013339             0.024197             0.014810             0.013525             0.024230             
(5, 1024, 128)            argmin               0.017341             0.012737             0.024306             0.017471             0.012379             0.024991             
(5, 1024, 128)            argmax               0.017345             0.012411             0.024421             0.017422             0.012471             0.024237             
(5, 1024, 128)            var                  0.019973             0.011453             0.026188             0.020050             0.011438             0.026282             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(10, 1024, 128)           mean                 0.016976             0.011575             0.016831             0.016722             0.011927             0.017173             
(10, 1024, 128)           sum                  0.017039             0.011841             0.017159             0.016385             0.011860             0.016753             
(10, 1024, 128)           min                  0.017036             0.015331             0.026770             0.016944             0.015205             0.027166             
(10, 1024, 128)           max                  0.017369             0.015348             0.027077             0.016531             0.015716             0.026819             
(10, 1024, 128)           argmin               0.019203             0.014447             0.026813             0.018994             0.014497             0.027313             
(10, 1024, 128)           argmax               0.019563             0.014795             0.027140             0.019460             0.014912             0.026733             
(10, 1024, 128)           var                  0.020529             0.014316             0.030405             0.020719             0.013960             0.029964             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(100, 1024, 128)          mean                 0.045046             0.039168             0.046082             0.044839             0.039217             0.045782             
(100, 1024, 128)          sum                  0.045094             0.039150             0.045777             0.044496             0.039542             0.046083             
(100, 1024, 128)          min                  0.045768             0.054466             0.076244             0.044915             0.053943             0.076599             
(100, 1024, 128)          max                  0.045748             0.054459             0.076188             0.044931             0.053949             0.076856             
(100, 1024, 128)          argmin               0.048275             0.054046             0.076647             0.048694             0.054105             0.077004             
(100, 1024, 128)          argmax               0.048267             0.054395             0.077401             0.048691             0.054131             0.076751             
(100, 1024, 128)          var                  0.049710             0.043254             0.083077             0.050971             0.043251             0.082378             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000, 100)         mean                 0.202312             0.196723             0.197765             0.201774             0.196641             0.197459             
(1000, 1000, 100)         sum                  0.202651             0.196682             0.197736             0.202175             0.196313             0.197523             
(1000, 1000, 100)         min                  0.203022             0.264762             0.269200             0.202729             0.264129             0.268694             
(1000, 1000, 100)         max                  0.202864             0.264396             0.269388             0.202486             0.263896             0.268720             
(1000, 1000, 100)         argmin               0.226727             0.263781             0.268651             0.226597             0.264676             0.268983             
(1000, 1000, 100)         argmax               0.226412             0.264469             0.269090             0.226570             0.264595             0.269178             
(1000, 1000, 100)         var                  0.243223             0.204079             0.216096             0.241942             0.204079             0.215925             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(10000, 100)              mean                 0.016193             0.020277             0.014316             0.016152             0.020324             0.013712             
(10000, 100)              sum                  0.016289             0.020237             0.014034             0.016168             0.020265             0.013708             
(10000, 100)              min                  0.016046             0.030872             0.019609             0.016208             0.030867             0.018627             
(10000, 100)              max                  0.016369             0.030835             0.019257             0.016218             0.030861             0.018209             
(10000, 100)              argmin               0.017957             0.031171             0.019517             0.018050             0.031556             0.018077             
(10000, 100)              argmax               0.017961             0.031658             0.019521             0.018060             0.031564             0.018087             
(10000, 100)              var                  0.020393             0.035652             0.019339             0.020144             0.035987             0.019171             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(100000, 10)              mean                 0.015718             0.016576             0.016555             0.015999             0.016246             0.014869             
(100000, 10)              sum                  0.015833             0.016247             0.016572             0.016007             0.016627             0.014872             
(100000, 10)              min                  0.015888             0.020510             0.023920             0.015671             0.020821             0.021417             
(100000, 10)              max                  0.015889             0.020479             0.023918             0.016077             0.020386             0.021421             
(100000, 10)              argmin               0.018233             0.020863             0.023647             0.017574             0.020864             0.021103             
(100000, 10)              argmax               0.017896             0.020527             0.023296             0.017569             0.020447             0.021098             
(100000, 10)              var                  0.020005             0.024198             0.024372             0.020075             0.024167             0.022415             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 1023)        mean                 1.874816             1.963506             1.903909             1.873279             1.963859             1.903230             
(1023, 1023, 1023)        sum                  1.875030             1.965716             1.902458             1.873566             1.960730             1.901642             
(1023, 1023, 1023)        min                  1.878563             2.473455             2.179092             1.875174             2.482086             2.183027             
(1023, 1023, 1023)        max                  1.879128             2.474803             2.178895             1.874831             2.482253             2.183884             
(1023, 1023, 1023)        argmin               1.921800             2.476629             2.174831             1.923987             2.472641             2.170453             
(1023, 1023, 1023)        argmax               1.922605             2.476688             2.177927             1.923366             2.472808             2.172979             
(1023, 1023, 1023)        var                  1.972606             3.088695             2.758797             1.978679             3.095658             2.762243             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 255)         mean                 0.489984             0.500954             0.492957             0.489891             0.500654             0.491971             
(1023, 1023, 255)         sum                  0.490228             0.500764             0.492289             0.489624             0.501089             0.492824             
(1023, 1023, 255)         min                  0.491457             0.563560             0.553334             0.490355             0.564709             0.554754             
(1023, 1023, 255)         max                  0.491396             0.563628             0.553345             0.490017             0.565004             0.554947             
(1023, 1023, 255)         argmin               0.503666             0.561512             0.551831             0.503845             0.560972             0.551017             
(1023, 1023, 255)         argmax               0.503602             0.561185             0.551407             0.504328             0.561267             0.551448             
(1023, 1023, 255)         var                  0.510844             0.709452             0.701630             0.512693             0.710365             0.701965             
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 377)         mean                 0.707439             0.727646             0.712019             0.706769             0.727101             0.711632             
(1023, 1023, 377)         sum                  0.707780             0.727453             0.711554             0.706807             0.726656             0.711729             
(1023, 1023, 377)         min                  0.709423             0.819809             0.794379             0.707847             0.822086             0.796664             
(1023, 1023, 377)         max                  0.709297             0.819780             0.794308             0.707566             0.821913             0.796690             
(1023, 1023, 377)         argmin               0.725028             0.817088             0.791695             0.726039             0.816445             0.790828             
(1023, 1023, 377)         argmax               0.725301             0.817011             0.791420             0.726040             0.816917             0.791143             
(1023, 1023, 377)         var                  0.740859             1.034165             1.006712             0.743413             1.035506             1.007638             
```
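To make the "minimal differences" comparison easier to check row by row, the table above can be post-processed into percentage changes between the new and old timings. The following is only a sketch: the helper names `parse_row` and `relative_change` are hypothetical and not part of this PR, and it assumes each row keeps the column order from the table header (New all dims, New dim=0, New dim=1, then the same three for Old).

```python
import re

# Column order assumed from the table header in the PR description.
COLUMNS = ("all dims", "dim=0", "dim=1")

# A row looks like: "(1023, 1023, 1023)   mean   1.87   1.96   1.90   1.87   1.96   1.90"
ROW_RE = re.compile(r"^\((?P<shape>[\d, ]+)\)\s+(?P<op>\w+)\s+(?P<nums>.+)$")


def parse_row(line):
    """Split one benchmark row into (shape, op, new_ms, old_ms)."""
    match = ROW_RE.match(line.strip())
    if match is None:
        raise ValueError(f"not a benchmark row: {line!r}")
    shape = tuple(int(part) for part in match.group("shape").split(","))
    values = [float(part) for part in match.group("nums").split()]
    if len(values) != 6:
        raise ValueError(f"expected 6 timings, got {len(values)}")
    # First three timings are "New", last three are "Old".
    return shape, match.group("op"), values[:3], values[3:]


def relative_change(line):
    """Return {column: percent change of New relative to Old} for one row."""
    _, _, new_ms, old_ms = parse_row(line)
    return {
        column: (new - old) / old * 100.0
        for column, new, old in zip(COLUMNS, new_ms, old_ms)
    }


# One row copied verbatim from the table above.
ROW = (
    "(1023, 1023, 1023)        mean                 1.874816             "
    "1.963506             1.903909             1.873279             "
    "1.963859             1.903230"
)

for column, percent in relative_change(ROW).items():
    # Positive values mean the new shuffle order was slower in this run.
    print(f"{column}: {percent:+.3f}%")
```

Applying `relative_change` to every row gives a quick per-shape view of where the two shuffle orders differ; single-run timings at this scale are noisy, so small percentages in either direction should not be over-interpreted.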

Differential Revision: [D85022826](https://our.internmc.facebook.com/intern/diff/D85022826)

PaulZhang12 added a commit that referenced this pull request Oct 20, 2025
@PaulZhang12
Contributor Author

@PaulZhang12 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@PaulZhang12
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Labels

ci-no-td (Do not run TD on this PR), ciflow/h100, ciflow/rocm (Trigger "default" config CI on ROCm), ciflow/trunk (Trigger trunk jobs on your pull request), Merged, release notes: cuda (release notes category), Reverted

8 participants
