Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Conversation

PaulZhang12
Copy link
Contributor

@PaulZhang12 PaulZhang12 commented Oct 9, 2025

Stack from ghstack (oldest at bottom):

Benchmarks for a full reduction + reduction on the contiguous dimension. Vectorized loads do not occur on the non contiguous dimension. Benchmarking done for FP16/BF16, ~6% improvement on average across shapes, up to ~24% for single reduction on contiguous dimension and 46% for full reduce:
BF16

Tensor Shape         Operation    Full reduce (ms)     Contiguous dim (ms)  Full reduce (ms)     Contiguous dim (ms)  Full reduce diff %   Contiguous diff %   
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256)           mean         0.022686             0.008263             0.015498             0.008117                          +46.38%               +1.80%
(256, 256)           sum          0.022769             0.008269             0.015628             0.008185                          +45.69%               +1.03%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 512)           mean         0.014116             0.009545             0.012892             0.008839                           +9.49%               +7.99%
(512, 512)           sum          0.014110             0.009892             0.012891             0.008878                           +9.46%              +11.42%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)         mean         0.014727             0.012642             0.014061             0.010519                           +4.74%              +20.18%
(1024, 1024)         sum          0.014376             0.012636             0.014069             0.010595                           +2.18%              +19.26%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)         mean         0.018663             0.018294             0.018171             0.014678                           +2.71%              +24.64%
(2048, 2048)         sum          0.018638             0.017931             0.018142             0.014713                           +2.73%              +21.87%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)         mean         0.034216             0.036953             0.033520             0.030585                           +2.08%              +20.82%
(4096, 4096)         sum          0.034196             0.036942             0.033518             0.030676                           +2.02%              +20.43%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)         mean         0.087763             0.095201             0.085439             0.084960                           +2.72%              +12.05%
(8192, 8192)         sum          0.088079             0.095592             0.085353             0.084632                           +3.19%              +12.95%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 16384)        mean         0.148174             0.149705             0.146274             0.138865                           +1.30%               +7.81%
(8192, 16384)        sum          0.147820             0.149371             0.146419             0.138752                           +0.96%               +7.65%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 32768)        mean         0.266144             0.260807             0.265953             0.253330                           +0.07%               +2.95%
(8192, 32768)        sum          0.266572             0.261163             0.265729             0.253294                           +0.32%               +3.11%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 65536)        mean         0.502034             0.486312             0.498417             0.481246                           +0.73%               +1.05%
(8192, 65536)        sum          0.501597             0.486351             0.497735             0.481579                           +0.78%               +0.99%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 131072)       mean         0.971178             0.942988             0.957164             0.938316                           +1.46%               +0.50%
(8192, 131072)       sum          0.971189             0.943232             0.956814             0.937816                           +1.50%               +0.58%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 262144)       mean         1.953728             1.877648             1.904937             1.861692                           +2.56%               +0.86%
(8192, 262144)       sum          1.953969             1.877538             1.905990             1.862547                           +2.52%               +0.80%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 262144)       mean         0.970408             0.940965             0.957871             0.936732                           +1.31%               +0.45%
(4096, 262144)       sum          0.970919             0.941652             0.957765             0.936676                           +1.37%               +0.53%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 262144)       mean         0.501477             0.486976             0.497964             0.483570                           +0.71%               +0.70%
(2048, 262144)       sum          0.501955             0.487213             0.498210             0.483218                           +0.75%               +0.83%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 262144)       mean         0.266536             0.257111             0.265642             0.255439                           +0.34%               +0.65%
(1024, 262144)       sum          0.266613             0.257096             0.265427             0.255472                           +0.45%               +0.64%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 131072)        mean         0.087805             0.091200             0.085818             0.087851                           +2.32%               +3.81%
(512, 131072)        sum          0.087788             0.091249             0.085373             0.087944                           +2.83%               +3.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000)         mean         0.014503             0.012328             0.013663             0.010190                           +6.15%              +20.98%
(1000, 1000)         sum          0.014545             0.012378             0.013662             0.010579                           +6.46%              +17.01%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 129)          mean         0.014163             0.008371             0.012893             0.008828                           +9.85%               -5.18%
(1024, 129)          sum          0.014132             0.008751             0.013234             0.008868                           +6.79%               -1.32%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 257)          mean         0.014296             0.009101             0.013334             0.008563                           +7.21%               +6.28%
(1024, 257)          sum          0.014302             0.009058             0.013020             0.008672                           +9.85%               +4.45%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 587)          mean         0.014127             0.010997             0.013443             0.009944                           +5.09%              +10.59%
(1024, 587)          sum          0.014471             0.011373             0.013123             0.010354                          +10.27%               +9.84%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 977)          mean         0.015607             0.013566             0.015089             0.012152                           +3.43%              +11.64%
(2048, 977)          sum          0.015953             0.013580             0.015039             0.011861                           +6.08%              +14.49%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 128)          mean         0.013982             0.008058             0.012747             0.008139                           +9.69%               -1.00%
(1024, 128)          sum          0.013967             0.008071             0.012726             0.007859                           +9.75%               +2.70%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 128)          mean         0.014378             0.009627             0.013712             0.009395                           +4.86%               +2.47%
(8192, 128)          sum          0.014389             0.009965             0.013718             0.009521                           +4.89%               +4.66%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 130)          mean         0.014156             0.008267             0.012895             0.008833                           +9.78%               -6.41%
(1024, 130)          sum          0.013797             0.008277             0.012903             0.008512                           +6.93%               -2.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 130)          mean         0.014977             0.010026             0.013911             0.009876                           +7.66%               +1.52%
(8192, 130)          sum          0.014994             0.010043             0.014235             0.009604                           +5.33%               +4.57%
====================================================================================================================================================================================

FP16

Tensor Shape         Operation    Full reduce (ms)     Contiguous dim (ms)  Full reduce (ms)     Contiguous dim (ms)  Full reduce diff %   Contiguous diff %   
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256)           mean         0.022804             0.008298             0.015888             0.007848                          +43.53%               +5.73%
(256, 256)           sum          0.023215             0.008328             0.015677             0.007850                          +48.08%               +6.09%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 512)           mean         0.013777             0.009988             0.012884             0.008512                           +6.93%              +17.34%
(512, 512)           sum          0.013775             0.009622             0.012870             0.009028                           +7.03%               +6.58%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)         mean         0.014740             0.012322             0.013708             0.010239                           +7.53%              +20.34%
(1024, 1024)         sum          0.014762             0.012756             0.013722             0.010307                           +7.58%              +23.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)         mean         0.018700             0.018364             0.018135             0.015078                           +3.12%              +21.79%
(2048, 2048)         sum          0.018276             0.018415             0.018471             0.015127                           -1.06%              +21.74%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)         mean         0.034518             0.037000             0.033838             0.030617                           +2.01%              +20.85%
(4096, 4096)         sum          0.034569             0.037448             0.033842             0.031100                           +2.15%              +20.41%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)         mean         0.087675             0.095176             0.085328             0.084105                           +2.75%              +13.16%
(8192, 8192)         sum          0.088102             0.095211             0.085707             0.084090                           +2.79%              +13.23%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 16384)        mean         0.147800             0.149263             0.146388             0.138390                           +0.96%               +7.86%
(8192, 16384)        sum          0.148147             0.148957             0.146439             0.138801                           +1.17%               +7.32%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 32768)        mean         0.266316             0.260294             0.265829             0.253411                           +0.18%               +2.72%
(8192, 32768)        sum          0.266562             0.260717             0.265744             0.253308                           +0.31%               +2.92%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 65536)        mean         0.502035             0.486077             0.498139             0.481374                           +0.78%               +0.98%
(8192, 65536)        sum          0.501571             0.485733             0.498353             0.481350                           +0.65%               +0.91%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 131072)       mean         0.971343             0.943016             0.956600             0.938622                           +1.54%               +0.47%
(8192, 131072)       sum          0.971463             0.942991             0.957352             0.938334                           +1.47%               +0.50%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 262144)       mean         1.952722             1.877165             1.906406             1.861455                           +2.43%               +0.84%
(8192, 262144)       sum          1.952634             1.876388             1.904677             1.861282                           +2.52%               +0.81%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 262144)       mean         0.970697             0.941298             0.956964             0.936160                           +1.44%               +0.55%
(4096, 262144)       sum          0.969981             0.941078             0.957016             0.936260                           +1.35%               +0.51%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 262144)       mean         0.501577             0.487208             0.498422             0.483493                           +0.63%               +0.77%
(2048, 262144)       sum          0.502029             0.487124             0.497854             0.483643                           +0.84%               +0.72%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 262144)       mean         0.266416             0.257383             0.265928             0.255140                           +0.18%               +0.88%
(1024, 262144)       sum          0.266434             0.257081             0.265817             0.255143                           +0.23%               +0.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 131072)        mean         0.087858             0.091296             0.085816             0.087745                           +2.38%               +4.05%
(512, 131072)        sum          0.088144             0.091314             0.085664             0.087864                           +2.90%               +3.93%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000)         mean         0.014977             0.012393             0.014141             0.010614                           +5.91%              +16.76%
(1000, 1000)         sum          0.014589             0.012804             0.014118             0.010320                           +3.34%              +24.07%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 129)          mean         0.014208             0.008383             0.013273             0.008440                           +7.04%               -0.68%
(1024, 129)          sum          0.013804             0.008863             0.013265             0.009003                           +4.06%               -1.56%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 257)          mean         0.014378             0.009109             0.013037             0.009038                          +10.29%               +0.79%
(1024, 257)          sum          0.014387             0.009113             0.013396             0.008698                           +7.40%               +4.77%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 587)          mean         0.014207             0.011037             0.013182             0.010391                           +7.78%               +6.22%
(1024, 587)          sum          0.014588             0.011453             0.013539             0.010049                           +7.75%              +13.97%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 977)          mean         0.016024             0.013614             0.015448             0.011845                           +3.73%              +14.93%
(2048, 977)          sum          0.015990             0.014033             0.015406             0.012278                           +3.79%              +14.29%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 128)          mean         0.014037             0.007804             0.013143             0.008242                           +6.80%               -5.31%
(1024, 128)          sum          0.014041             0.007847             0.012759             0.007850                          +10.05%               -0.04%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 128)          mean         0.014361             0.009644             0.014075             0.009061                           +2.03%               +6.43%
(8192, 128)          sum          0.014366             0.010032             0.013702             0.009181                           +4.85%               +9.27%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 130)          mean         0.014226             0.008696             0.012894             0.008835                          +10.33%               -1.57%
(1024, 130)          sum          0.013830             0.008740             0.013288             0.008989                           +4.08%               -2.77%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 130)          mean         0.015036             0.010019             0.013917             0.009538                           +8.04%               +5.04%
(8192, 130)          sum          0.014652             0.010403             0.013900             0.009565                           +5.41%               +8.76%
====================================================================================================================================================================================

@pytorch-bot pytorch-bot bot added the release notes: cuda release notes category label Oct 9, 2025
PaulZhang12 added a commit that referenced this pull request Oct 9, 2025
ghstack-source-id: 42dabad
Pull Request resolved: #165055
Copy link

pytorch-bot bot commented Oct 9, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/165055

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 910c221 with merge base 90c0825 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@eqy
Copy link
Collaborator

eqy commented Oct 9, 2025

Looks like there's a lot of other uses of gpu_reduce_kernel, should these also all be updated to 128-bit vectorization?

iter, func_wrapper<out_t>([] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t {
return a + b;
}));
// }
Copy link
Collaborator

@Aidyn-A Aidyn-A Oct 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like this is a typo 🤔

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes typo, need to update, thanks!

@PaulZhang12
Copy link
Contributor Author

PaulZhang12 commented Oct 9, 2025

Looks like there's a lot of other uses of gpu_reduce_kernel, should these also all be updated to 128-bit vectorization?

I have found regressions generally for other reductions, like argmax, min, etc. I will add some performance numbers in a bit

@PaulZhang12 PaulZhang12 changed the title Vectorize 8 elements on 16 bit data types [ATen] Vectorize 8 elements on 16 bit data types Oct 9, 2025
PaulZhang12 added a commit that referenced this pull request Oct 9, 2025
ghstack-source-id: 801892e
Pull Request resolved: #165055
template <typename scalar_t, typename acc_t=scalar_t, typename out_t=scalar_t>
void mean_kernel_impl(TensorIterator& iter) {
// returns acc_t for all non-complex dtypes and returns T for c10::complex<T>
const bool is_16_bits =
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be constexpr bool

using factor_t = typename c10::scalar_value_type<acc_t>::type;
factor_t factor = static_cast<factor_t>(iter.num_output_elements()) / iter.numel();
gpu_reduce_kernel<scalar_t, out_t>(iter, MeanOps<scalar_t, acc_t, factor_t, out_t> {factor});
if (is_16_bits) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now this becomes constexpr if too

template <typename scalar_t, typename acc_t = scalar_t, typename out_t = scalar_t>
struct sum_functor {
void operator()(TensorIterator& iter) {
const bool is_16_bits =
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here

}
#endif
gpu_reduce_kernel<scalar_t, out_t>(
if (is_16_bits) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here

Copy link
Collaborator

@Skylion007 Skylion007 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you have an issue where MSCV complains about constexpr use the macro used for constexpr in other CUDA kernels.

Benchmarks for a full reduction + reduction on the contiguous dimension. Vectorized loads do not occur on the non contiguous dimension:
```
Tensor Shape         Operation Base Full reduce (ms)  Base Contiguous dim (ms)  V Full reduce (ms)  V Contiguous dim (ms)  Full reduce diff %   Contiguous diff %   
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256)           mean         0.022750             0.008646             0.016023             0.008252                          +41.98%               +4.77%
(256, 256)           sum          0.023141             0.008664             0.015670             0.008269                          +47.68%               +4.78%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 512)           mean         0.014090             0.009603             0.013006             0.008669                           +8.33%              +10.77%
(512, 512)           sum          0.014084             0.009622             0.013015             0.008672                           +8.21%              +10.95%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)         mean         0.014368             0.012647             0.013895             0.010347                           +3.40%              +22.23%
(1024, 1024)         sum          0.014700             0.012651             0.014261             0.010368                           +3.08%              +22.02%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)         mean         0.018587             0.017994             0.019311             0.015388                           -3.75%              +16.94%
(2048, 2048)         sum          0.018590             0.017995             0.018913             0.015064                           -1.71%              +19.46%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)         mean         0.034463             0.037302             0.034645             0.032360                           -0.53%              +15.27%
(4096, 4096)         sum          0.034163             0.037045             0.034679             0.031992                           -1.49%              +15.79%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)         mean         0.087827             0.095071             0.087534             0.087104                           +0.33%               +9.15%
(8192, 8192)         sum          0.088087             0.094829             0.087511             0.086730                           +0.66%               +9.34%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 16384)        mean         0.148125             0.149172             0.150266             0.138626                           -1.42%               +7.61%
(8192, 16384)        sum          0.147821             0.149253             0.150140             0.138618                           -1.54%               +7.67%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 32768)        mean         0.266410             0.261032             0.274723             0.271464                           -3.03%               -3.84%
(8192, 32768)        sum          0.266214             0.260893             0.275037             0.271687                           -3.21%               -3.97%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 65536)        mean         0.501475             0.485934             0.520243             0.535353                           -3.61%               -9.23%
(8192, 65536)        sum          0.501845             0.486466             0.520737             0.534895                           -3.63%               -9.05%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 131072)       mean         0.970737             0.942428             1.012732             1.013035                           -4.15%               -6.97%
(8192, 131072)       sum          0.970730             0.942806             1.012648             1.012854                           -4.14%               -6.92%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 262144)       mean         1.954044             1.877875             2.019751             1.944285                           -3.25%               -3.42%
(8192, 262144)       sum          1.954403             1.877363             2.020125             1.945457                           -3.25%               -3.50%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 262144)       mean         0.971163             0.941658             1.013148             0.977148                           -4.14%               -3.63%
(4096, 262144)       sum          0.971205             0.941575             1.013127             0.977082                           -4.14%               -3.63%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 262144)       mean         0.501913             0.487320             0.520714             0.500667                           -3.61%               -2.67%
(2048, 262144)       sum          0.501453             0.486805             0.520148             0.500693                           -3.59%               -2.77%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 262144)       mean         0.266118             0.257603             0.274991             0.261859                           -3.23%               -1.63%
(1024, 262144)       sum          0.266247             0.257172             0.275075             0.261925                           -3.21%               -1.81%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 131072)        mean         0.087805             0.091222             0.087552             0.090097                           +0.29%               +1.25%
(512, 131072)        sum          0.087862             0.091237             0.087519             0.090119                           +0.39%               +1.24%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000)         mean         0.014582             0.012707             0.014285             0.010758                           +2.08%              +18.12%
(1000, 1000)         sum          0.014568             0.012711             0.014289             0.010754                           +1.95%              +18.20%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 129)          mean         0.014099             0.008741             0.013312             0.008671                           +5.91%               +0.81%
(1024, 129)          sum          0.013774             0.008410             0.012965             0.008663                           +6.24%               -2.92%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 257)          mean         0.014312             0.009436             0.013504             0.008753                           +5.98%               +7.80%
(1024, 257)          sum          0.013998             0.009111             0.013503             0.008770                           +3.67%               +3.89%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 587)          mean         0.014163             0.011363             0.013690             0.010475                           +3.46%               +8.48%
(1024, 587)          sum          0.014172             0.011377             0.013356             0.010476                           +6.11%               +8.60%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 977)          mean         0.015947             0.013946             0.015327             0.012085                           +4.05%              +15.40%
(2048, 977)          sum          0.015631             0.013945             0.015322             0.012464                           +2.02%              +11.88%
====================================================================================================================================================================================

```

[ghstack-poisoned]
Benchmarks for a full reduction + reduction on the contiguous dimension. Vectorized loads do not occur on the non contiguous dimension:
```
Tensor Shape         Operation Base Full reduce (ms)  Base Contiguous dim (ms)  V Full reduce (ms)  V Contiguous dim (ms)  Full reduce diff %   Contiguous diff %   
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256)           mean         0.022750             0.008646             0.016023             0.008252                          +41.98%               +4.77%
(256, 256)           sum          0.023141             0.008664             0.015670             0.008269                          +47.68%               +4.78%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 512)           mean         0.014090             0.009603             0.013006             0.008669                           +8.33%              +10.77%
(512, 512)           sum          0.014084             0.009622             0.013015             0.008672                           +8.21%              +10.95%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)         mean         0.014368             0.012647             0.013895             0.010347                           +3.40%              +22.23%
(1024, 1024)         sum          0.014700             0.012651             0.014261             0.010368                           +3.08%              +22.02%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)         mean         0.018587             0.017994             0.019311             0.015388                           -3.75%              +16.94%
(2048, 2048)         sum          0.018590             0.017995             0.018913             0.015064                           -1.71%              +19.46%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)         mean         0.034463             0.037302             0.034645             0.032360                           -0.53%              +15.27%
(4096, 4096)         sum          0.034163             0.037045             0.034679             0.031992                           -1.49%              +15.79%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)         mean         0.087827             0.095071             0.087534             0.087104                           +0.33%               +9.15%
(8192, 8192)         sum          0.088087             0.094829             0.087511             0.086730                           +0.66%               +9.34%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 16384)        mean         0.148125             0.149172             0.150266             0.138626                           -1.42%               +7.61%
(8192, 16384)        sum          0.147821             0.149253             0.150140             0.138618                           -1.54%               +7.67%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 32768)        mean         0.266410             0.261032             0.274723             0.271464                           -3.03%               -3.84%
(8192, 32768)        sum          0.266214             0.260893             0.275037             0.271687                           -3.21%               -3.97%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 65536)        mean         0.501475             0.485934             0.520243             0.535353                           -3.61%               -9.23%
(8192, 65536)        sum          0.501845             0.486466             0.520737             0.534895                           -3.63%               -9.05%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 131072)       mean         0.970737             0.942428             1.012732             1.013035                           -4.15%               -6.97%
(8192, 131072)       sum          0.970730             0.942806             1.012648             1.012854                           -4.14%               -6.92%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 262144)       mean         1.954044             1.877875             2.019751             1.944285                           -3.25%               -3.42%
(8192, 262144)       sum          1.954403             1.877363             2.020125             1.945457                           -3.25%               -3.50%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 262144)       mean         0.971163             0.941658             1.013148             0.977148                           -4.14%               -3.63%
(4096, 262144)       sum          0.971205             0.941575             1.013127             0.977082                           -4.14%               -3.63%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 262144)       mean         0.501913             0.487320             0.520714             0.500667                           -3.61%               -2.67%
(2048, 262144)       sum          0.501453             0.486805             0.520148             0.500693                           -3.59%               -2.77%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 262144)       mean         0.266118             0.257603             0.274991             0.261859                           -3.23%               -1.63%
(1024, 262144)       sum          0.266247             0.257172             0.275075             0.261925                           -3.21%               -1.81%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 131072)        mean         0.087805             0.091222             0.087552             0.090097                           +0.29%               +1.25%
(512, 131072)        sum          0.087862             0.091237             0.087519             0.090119                           +0.39%               +1.24%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000)         mean         0.014582             0.012707             0.014285             0.010758                           +2.08%              +18.12%
(1000, 1000)         sum          0.014568             0.012711             0.014289             0.010754                           +1.95%              +18.20%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 129)          mean         0.014099             0.008741             0.013312             0.008671                           +5.91%               +0.81%
(1024, 129)          sum          0.013774             0.008410             0.012965             0.008663                           +6.24%               -2.92%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 257)          mean         0.014312             0.009436             0.013504             0.008753                           +5.98%               +7.80%
(1024, 257)          sum          0.013998             0.009111             0.013503             0.008770                           +3.67%               +3.89%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 587)          mean         0.014163             0.011363             0.013690             0.010475                           +3.46%               +8.48%
(1024, 587)          sum          0.014172             0.011377             0.013356             0.010476                           +6.11%               +8.60%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 977)          mean         0.015947             0.013946             0.015327             0.012085                           +4.05%              +15.40%
(2048, 977)          sum          0.015631             0.013945             0.015322             0.012464                           +2.02%              +11.88%
====================================================================================================================================================================================

```

[ghstack-poisoned]
Benchmarks for a full reduction + reduction on the contiguous dimension. Vectorized loads do not occur on the non contiguous dimension:
```
Tensor Shape         Operation Base Full reduce (ms)  Base Contiguous dim (ms)  V Full reduce (ms)  V Contiguous dim (ms)  Full reduce diff %   Contiguous diff %   
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256)           mean         0.022750             0.008646             0.016023             0.008252                          +41.98%               +4.77%
(256, 256)           sum          0.023141             0.008664             0.015670             0.008269                          +47.68%               +4.78%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 512)           mean         0.014090             0.009603             0.013006             0.008669                           +8.33%              +10.77%
(512, 512)           sum          0.014084             0.009622             0.013015             0.008672                           +8.21%              +10.95%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)         mean         0.014368             0.012647             0.013895             0.010347                           +3.40%              +22.23%
(1024, 1024)         sum          0.014700             0.012651             0.014261             0.010368                           +3.08%              +22.02%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)         mean         0.018587             0.017994             0.019311             0.015388                           -3.75%              +16.94%
(2048, 2048)         sum          0.018590             0.017995             0.018913             0.015064                           -1.71%              +19.46%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)         mean         0.034463             0.037302             0.034645             0.032360                           -0.53%              +15.27%
(4096, 4096)         sum          0.034163             0.037045             0.034679             0.031992                           -1.49%              +15.79%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)         mean         0.087827             0.095071             0.087534             0.087104                           +0.33%               +9.15%
(8192, 8192)         sum          0.088087             0.094829             0.087511             0.086730                           +0.66%               +9.34%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 16384)        mean         0.148125             0.149172             0.150266             0.138626                           -1.42%               +7.61%
(8192, 16384)        sum          0.147821             0.149253             0.150140             0.138618                           -1.54%               +7.67%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 32768)        mean         0.266410             0.261032             0.274723             0.271464                           -3.03%               -3.84%
(8192, 32768)        sum          0.266214             0.260893             0.275037             0.271687                           -3.21%               -3.97%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 65536)        mean         0.501475             0.485934             0.520243             0.535353                           -3.61%               -9.23%
(8192, 65536)        sum          0.501845             0.486466             0.520737             0.534895                           -3.63%               -9.05%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 131072)       mean         0.970737             0.942428             1.012732             1.013035                           -4.15%               -6.97%
(8192, 131072)       sum          0.970730             0.942806             1.012648             1.012854                           -4.14%               -6.92%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 262144)       mean         1.954044             1.877875             2.019751             1.944285                           -3.25%               -3.42%
(8192, 262144)       sum          1.954403             1.877363             2.020125             1.945457                           -3.25%               -3.50%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 262144)       mean         0.971163             0.941658             1.013148             0.977148                           -4.14%               -3.63%
(4096, 262144)       sum          0.971205             0.941575             1.013127             0.977082                           -4.14%               -3.63%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 262144)       mean         0.501913             0.487320             0.520714             0.500667                           -3.61%               -2.67%
(2048, 262144)       sum          0.501453             0.486805             0.520148             0.500693                           -3.59%               -2.77%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 262144)       mean         0.266118             0.257603             0.274991             0.261859                           -3.23%               -1.63%
(1024, 262144)       sum          0.266247             0.257172             0.275075             0.261925                           -3.21%               -1.81%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 131072)        mean         0.087805             0.091222             0.087552             0.090097                           +0.29%               +1.25%
(512, 131072)        sum          0.087862             0.091237             0.087519             0.090119                           +0.39%               +1.24%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000)         mean         0.014582             0.012707             0.014285             0.010758                           +2.08%              +18.12%
(1000, 1000)         sum          0.014568             0.012711             0.014289             0.010754                           +1.95%              +18.20%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 129)          mean         0.014099             0.008741             0.013312             0.008671                           +5.91%               +0.81%
(1024, 129)          sum          0.013774             0.008410             0.012965             0.008663                           +6.24%               -2.92%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 257)          mean         0.014312             0.009436             0.013504             0.008753                           +5.98%               +7.80%
(1024, 257)          sum          0.013998             0.009111             0.013503             0.008770                           +3.67%               +3.89%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 587)          mean         0.014163             0.011363             0.013690             0.010475                           +3.46%               +8.48%
(1024, 587)          sum          0.014172             0.011377             0.013356             0.010476                           +6.11%               +8.60%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 977)          mean         0.015947             0.013946             0.015327             0.012085                           +4.05%              +15.40%
(2048, 977)          sum          0.015631             0.013945             0.015322             0.012464                           +2.02%              +11.88%
====================================================================================================================================================================================

```

[ghstack-poisoned]
Benchmarks for a full reduction + reduction on the contiguous dimension. Vectorized loads do not occur on the non contiguous dimension:
```
Tensor Shape         Operation Base Full reduce (ms)  Base Contiguous dim (ms)  V Full reduce (ms)  V Contiguous dim (ms)  Full reduce diff %   Contiguous diff %   
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256)           mean         0.022750             0.008646             0.016023             0.008252                          +41.98%               +4.77%
(256, 256)           sum          0.023141             0.008664             0.015670             0.008269                          +47.68%               +4.78%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 512)           mean         0.014090             0.009603             0.013006             0.008669                           +8.33%              +10.77%
(512, 512)           sum          0.014084             0.009622             0.013015             0.008672                           +8.21%              +10.95%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)         mean         0.014368             0.012647             0.013895             0.010347                           +3.40%              +22.23%
(1024, 1024)         sum          0.014700             0.012651             0.014261             0.010368                           +3.08%              +22.02%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)         mean         0.018587             0.017994             0.019311             0.015388                           -3.75%              +16.94%
(2048, 2048)         sum          0.018590             0.017995             0.018913             0.015064                           -1.71%              +19.46%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)         mean         0.034463             0.037302             0.034645             0.032360                           -0.53%              +15.27%
(4096, 4096)         sum          0.034163             0.037045             0.034679             0.031992                           -1.49%              +15.79%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)         mean         0.087827             0.095071             0.087534             0.087104                           +0.33%               +9.15%
(8192, 8192)         sum          0.088087             0.094829             0.087511             0.086730                           +0.66%               +9.34%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 16384)        mean         0.148125             0.149172             0.150266             0.138626                           -1.42%               +7.61%
(8192, 16384)        sum          0.147821             0.149253             0.150140             0.138618                           -1.54%               +7.67%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 32768)        mean         0.266410             0.261032             0.274723             0.271464                           -3.03%               -3.84%
(8192, 32768)        sum          0.266214             0.260893             0.275037             0.271687                           -3.21%               -3.97%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 65536)        mean         0.501475             0.485934             0.520243             0.535353                           -3.61%               -9.23%
(8192, 65536)        sum          0.501845             0.486466             0.520737             0.534895                           -3.63%               -9.05%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 131072)       mean         0.970737             0.942428             1.012732             1.013035                           -4.15%               -6.97%
(8192, 131072)       sum          0.970730             0.942806             1.012648             1.012854                           -4.14%               -6.92%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 262144)       mean         1.954044             1.877875             2.019751             1.944285                           -3.25%               -3.42%
(8192, 262144)       sum          1.954403             1.877363             2.020125             1.945457                           -3.25%               -3.50%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 262144)       mean         0.971163             0.941658             1.013148             0.977148                           -4.14%               -3.63%
(4096, 262144)       sum          0.971205             0.941575             1.013127             0.977082                           -4.14%               -3.63%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 262144)       mean         0.501913             0.487320             0.520714             0.500667                           -3.61%               -2.67%
(2048, 262144)       sum          0.501453             0.486805             0.520148             0.500693                           -3.59%               -2.77%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 262144)       mean         0.266118             0.257603             0.274991             0.261859                           -3.23%               -1.63%
(1024, 262144)       sum          0.266247             0.257172             0.275075             0.261925                           -3.21%               -1.81%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 131072)        mean         0.087805             0.091222             0.087552             0.090097                           +0.29%               +1.25%
(512, 131072)        sum          0.087862             0.091237             0.087519             0.090119                           +0.39%               +1.24%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000)         mean         0.014582             0.012707             0.014285             0.010758                           +2.08%              +18.12%
(1000, 1000)         sum          0.014568             0.012711             0.014289             0.010754                           +1.95%              +18.20%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 129)          mean         0.014099             0.008741             0.013312             0.008671                           +5.91%               +0.81%
(1024, 129)          sum          0.013774             0.008410             0.012965             0.008663                           +6.24%               -2.92%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 257)          mean         0.014312             0.009436             0.013504             0.008753                           +5.98%               +7.80%
(1024, 257)          sum          0.013998             0.009111             0.013503             0.008770                           +3.67%               +3.89%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 587)          mean         0.014163             0.011363             0.013690             0.010475                           +3.46%               +8.48%
(1024, 587)          sum          0.014172             0.011377             0.013356             0.010476                           +6.11%               +8.60%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 977)          mean         0.015947             0.013946             0.015327             0.012085                           +4.05%              +15.40%
(2048, 977)          sum          0.015631             0.013945             0.015322             0.012464                           +2.02%              +11.88%
====================================================================================================================================================================================

```

[ghstack-poisoned]
Benchmarks for a full reduction + reduction on the contiguous dimension. Vectorized loads do not occur on the non contiguous dimension:
```
Tensor Shape         Operation Base Full reduce (ms)  Base Contiguous dim (ms)  V Full reduce (ms)  V Contiguous dim (ms)  Full reduce diff %   Contiguous diff %   
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256)           mean         0.022750             0.008646             0.016023             0.008252                          +41.98%               +4.77%
(256, 256)           sum          0.023141             0.008664             0.015670             0.008269                          +47.68%               +4.78%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 512)           mean         0.014090             0.009603             0.013006             0.008669                           +8.33%              +10.77%
(512, 512)           sum          0.014084             0.009622             0.013015             0.008672                           +8.21%              +10.95%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)         mean         0.014368             0.012647             0.013895             0.010347                           +3.40%              +22.23%
(1024, 1024)         sum          0.014700             0.012651             0.014261             0.010368                           +3.08%              +22.02%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)         mean         0.018587             0.017994             0.019311             0.015388                           -3.75%              +16.94%
(2048, 2048)         sum          0.018590             0.017995             0.018913             0.015064                           -1.71%              +19.46%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)         mean         0.034463             0.037302             0.034645             0.032360                           -0.53%              +15.27%
(4096, 4096)         sum          0.034163             0.037045             0.034679             0.031992                           -1.49%              +15.79%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)         mean         0.087827             0.095071             0.087534             0.087104                           +0.33%               +9.15%
(8192, 8192)         sum          0.088087             0.094829             0.087511             0.086730                           +0.66%               +9.34%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 16384)        mean         0.148125             0.149172             0.150266             0.138626                           -1.42%               +7.61%
(8192, 16384)        sum          0.147821             0.149253             0.150140             0.138618                           -1.54%               +7.67%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 32768)        mean         0.266410             0.261032             0.274723             0.271464                           -3.03%               -3.84%
(8192, 32768)        sum          0.266214             0.260893             0.275037             0.271687                           -3.21%               -3.97%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 65536)        mean         0.501475             0.485934             0.520243             0.535353                           -3.61%               -9.23%
(8192, 65536)        sum          0.501845             0.486466             0.520737             0.534895                           -3.63%               -9.05%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 131072)       mean         0.970737             0.942428             1.012732             1.013035                           -4.15%               -6.97%
(8192, 131072)       sum          0.970730             0.942806             1.012648             1.012854                           -4.14%               -6.92%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 262144)       mean         1.954044             1.877875             2.019751             1.944285                           -3.25%               -3.42%
(8192, 262144)       sum          1.954403             1.877363             2.020125             1.945457                           -3.25%               -3.50%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 262144)       mean         0.971163             0.941658             1.013148             0.977148                           -4.14%               -3.63%
(4096, 262144)       sum          0.971205             0.941575             1.013127             0.977082                           -4.14%               -3.63%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 262144)       mean         0.501913             0.487320             0.520714             0.500667                           -3.61%               -2.67%
(2048, 262144)       sum          0.501453             0.486805             0.520148             0.500693                           -3.59%               -2.77%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 262144)       mean         0.266118             0.257603             0.274991             0.261859                           -3.23%               -1.63%
(1024, 262144)       sum          0.266247             0.257172             0.275075             0.261925                           -3.21%               -1.81%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 131072)        mean         0.087805             0.091222             0.087552             0.090097                           +0.29%               +1.25%
(512, 131072)        sum          0.087862             0.091237             0.087519             0.090119                           +0.39%               +1.24%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000)         mean         0.014582             0.012707             0.014285             0.010758                           +2.08%              +18.12%
(1000, 1000)         sum          0.014568             0.012711             0.014289             0.010754                           +1.95%              +18.20%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 129)          mean         0.014099             0.008741             0.013312             0.008671                           +5.91%               +0.81%
(1024, 129)          sum          0.013774             0.008410             0.012965             0.008663                           +6.24%               -2.92%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 257)          mean         0.014312             0.009436             0.013504             0.008753                           +5.98%               +7.80%
(1024, 257)          sum          0.013998             0.009111             0.013503             0.008770                           +3.67%               +3.89%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 587)          mean         0.014163             0.011363             0.013690             0.010475                           +3.46%               +8.48%
(1024, 587)          sum          0.014172             0.011377             0.013356             0.010476                           +6.11%               +8.60%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 977)          mean         0.015947             0.013946             0.015327             0.012085                           +4.05%              +15.40%
(2048, 977)          sum          0.015631             0.013945             0.015322             0.012464                           +2.02%              +11.88%
====================================================================================================================================================================================

```

[ghstack-poisoned]
Benchmarks for a full reduction + reduction on the contiguous dimension. Vectorized loads do not occur on the non contiguous dimension:
```
Tensor Shape         Operation Base Full reduce (ms)  Base Contiguous dim (ms)  V Full reduce (ms)  V Contiguous dim (ms)  Full reduce diff %   Contiguous diff %   
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256)           mean         0.022750             0.008646             0.016023             0.008252                          +41.98%               +4.77%
(256, 256)           sum          0.023141             0.008664             0.015670             0.008269                          +47.68%               +4.78%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 512)           mean         0.014090             0.009603             0.013006             0.008669                           +8.33%              +10.77%
(512, 512)           sum          0.014084             0.009622             0.013015             0.008672                           +8.21%              +10.95%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)         mean         0.014368             0.012647             0.013895             0.010347                           +3.40%              +22.23%
(1024, 1024)         sum          0.014700             0.012651             0.014261             0.010368                           +3.08%              +22.02%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)         mean         0.018587             0.017994             0.019311             0.015388                           -3.75%              +16.94%
(2048, 2048)         sum          0.018590             0.017995             0.018913             0.015064                           -1.71%              +19.46%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)         mean         0.034463             0.037302             0.034645             0.032360                           -0.53%              +15.27%
(4096, 4096)         sum          0.034163             0.037045             0.034679             0.031992                           -1.49%              +15.79%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)         mean         0.087827             0.095071             0.087534             0.087104                           +0.33%               +9.15%
(8192, 8192)         sum          0.088087             0.094829             0.087511             0.086730                           +0.66%               +9.34%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 16384)        mean         0.148125             0.149172             0.150266             0.138626                           -1.42%               +7.61%
(8192, 16384)        sum          0.147821             0.149253             0.150140             0.138618                           -1.54%               +7.67%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 32768)        mean         0.266410             0.261032             0.274723             0.271464                           -3.03%               -3.84%
(8192, 32768)        sum          0.266214             0.260893             0.275037             0.271687                           -3.21%               -3.97%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 65536)        mean         0.501475             0.485934             0.520243             0.535353                           -3.61%               -9.23%
(8192, 65536)        sum          0.501845             0.486466             0.520737             0.534895                           -3.63%               -9.05%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 131072)       mean         0.970737             0.942428             1.012732             1.013035                           -4.15%               -6.97%
(8192, 131072)       sum          0.970730             0.942806             1.012648             1.012854                           -4.14%               -6.92%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 262144)       mean         1.954044             1.877875             2.019751             1.944285                           -3.25%               -3.42%
(8192, 262144)       sum          1.954403             1.877363             2.020125             1.945457                           -3.25%               -3.50%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 262144)       mean         0.971163             0.941658             1.013148             0.977148                           -4.14%               -3.63%
(4096, 262144)       sum          0.971205             0.941575             1.013127             0.977082                           -4.14%               -3.63%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 262144)       mean         0.501913             0.487320             0.520714             0.500667                           -3.61%               -2.67%
(2048, 262144)       sum          0.501453             0.486805             0.520148             0.500693                           -3.59%               -2.77%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 262144)       mean         0.266118             0.257603             0.274991             0.261859                           -3.23%               -1.63%
(1024, 262144)       sum          0.266247             0.257172             0.275075             0.261925                           -3.21%               -1.81%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 131072)        mean         0.087805             0.091222             0.087552             0.090097                           +0.29%               +1.25%
(512, 131072)        sum          0.087862             0.091237             0.087519             0.090119                           +0.39%               +1.24%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000)         mean         0.014582             0.012707             0.014285             0.010758                           +2.08%              +18.12%
(1000, 1000)         sum          0.014568             0.012711             0.014289             0.010754                           +1.95%              +18.20%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 129)          mean         0.014099             0.008741             0.013312             0.008671                           +5.91%               +0.81%
(1024, 129)          sum          0.013774             0.008410             0.012965             0.008663                           +6.24%               -2.92%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 257)          mean         0.014312             0.009436             0.013504             0.008753                           +5.98%               +7.80%
(1024, 257)          sum          0.013998             0.009111             0.013503             0.008770                           +3.67%               +3.89%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 587)          mean         0.014163             0.011363             0.013690             0.010475                           +3.46%               +8.48%
(1024, 587)          sum          0.014172             0.011377             0.013356             0.010476                           +6.11%               +8.60%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 977)          mean         0.015947             0.013946             0.015327             0.012085                           +4.05%              +15.40%
(2048, 977)          sum          0.015631             0.013945             0.015322             0.012464                           +2.02%              +11.88%
====================================================================================================================================================================================

```

[ghstack-poisoned]
Benchmarks for a full reduction + reduction on the contiguous dimension. Vectorized loads do not occur on the non contiguous dimension:
```
Tensor Shape         Operation Base Full reduce (ms)  Base Contiguous dim (ms)  V Full reduce (ms)  V Contiguous dim (ms)  Full reduce diff %   Contiguous diff %   
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256)           mean         0.022750             0.008646             0.016023             0.008252                          +41.98%               +4.77%
(256, 256)           sum          0.023141             0.008664             0.015670             0.008269                          +47.68%               +4.78%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 512)           mean         0.014090             0.009603             0.013006             0.008669                           +8.33%              +10.77%
(512, 512)           sum          0.014084             0.009622             0.013015             0.008672                           +8.21%              +10.95%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)         mean         0.014368             0.012647             0.013895             0.010347                           +3.40%              +22.23%
(1024, 1024)         sum          0.014700             0.012651             0.014261             0.010368                           +3.08%              +22.02%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)         mean         0.018587             0.017994             0.019311             0.015388                           -3.75%              +16.94%
(2048, 2048)         sum          0.018590             0.017995             0.018913             0.015064                           -1.71%              +19.46%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)         mean         0.034463             0.037302             0.034645             0.032360                           -0.53%              +15.27%
(4096, 4096)         sum          0.034163             0.037045             0.034679             0.031992                           -1.49%              +15.79%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)         mean         0.087827             0.095071             0.087534             0.087104                           +0.33%               +9.15%
(8192, 8192)         sum          0.088087             0.094829             0.087511             0.086730                           +0.66%               +9.34%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 16384)        mean         0.148125             0.149172             0.150266             0.138626                           -1.42%               +7.61%
(8192, 16384)        sum          0.147821             0.149253             0.150140             0.138618                           -1.54%               +7.67%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 32768)        mean         0.266410             0.261032             0.274723             0.271464                           -3.03%               -3.84%
(8192, 32768)        sum          0.266214             0.260893             0.275037             0.271687                           -3.21%               -3.97%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 65536)        mean         0.501475             0.485934             0.520243             0.535353                           -3.61%               -9.23%
(8192, 65536)        sum          0.501845             0.486466             0.520737             0.534895                           -3.63%               -9.05%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 131072)       mean         0.970737             0.942428             1.012732             1.013035                           -4.15%               -6.97%
(8192, 131072)       sum          0.970730             0.942806             1.012648             1.012854                           -4.14%               -6.92%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 262144)       mean         1.954044             1.877875             2.019751             1.944285                           -3.25%               -3.42%
(8192, 262144)       sum          1.954403             1.877363             2.020125             1.945457                           -3.25%               -3.50%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 262144)       mean         0.971163             0.941658             1.013148             0.977148                           -4.14%               -3.63%
(4096, 262144)       sum          0.971205             0.941575             1.013127             0.977082                           -4.14%               -3.63%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 262144)       mean         0.501913             0.487320             0.520714             0.500667                           -3.61%               -2.67%
(2048, 262144)       sum          0.501453             0.486805             0.520148             0.500693                           -3.59%               -2.77%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 262144)       mean         0.266118             0.257603             0.274991             0.261859                           -3.23%               -1.63%
(1024, 262144)       sum          0.266247             0.257172             0.275075             0.261925                           -3.21%               -1.81%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 131072)        mean         0.087805             0.091222             0.087552             0.090097                           +0.29%               +1.25%
(512, 131072)        sum          0.087862             0.091237             0.087519             0.090119                           +0.39%               +1.24%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000)         mean         0.014582             0.012707             0.014285             0.010758                           +2.08%              +18.12%
(1000, 1000)         sum          0.014568             0.012711             0.014289             0.010754                           +1.95%              +18.20%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 129)          mean         0.014099             0.008741             0.013312             0.008671                           +5.91%               +0.81%
(1024, 129)          sum          0.013774             0.008410             0.012965             0.008663                           +6.24%               -2.92%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 257)          mean         0.014312             0.009436             0.013504             0.008753                           +5.98%               +7.80%
(1024, 257)          sum          0.013998             0.009111             0.013503             0.008770                           +3.67%               +3.89%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 587)          mean         0.014163             0.011363             0.013690             0.010475                           +3.46%               +8.48%
(1024, 587)          sum          0.014172             0.011377             0.013356             0.010476                           +6.11%               +8.60%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 977)          mean         0.015947             0.013946             0.015327             0.012085                           +4.05%              +15.40%
(2048, 977)          sum          0.015631             0.013945             0.015322             0.012464                           +2.02%              +11.88%
====================================================================================================================================================================================

```

[ghstack-poisoned]
PaulZhang12 added a commit that referenced this pull request Oct 16, 2025
ghstack-source-id: 8f47441
Pull Request resolved: #165055
Benchmarks for a full reduction + reduction on the contiguous dimension. Vectorized loads do not occur on the non contiguous dimension:
```
====================================================================================================================================================================================
PyTorch Reduction Operations Benchmark - Combined Comparison
====================================================================================================================================================================================
Tensor Shape         Operation    Full reduce (ms)     Contiguous dim (ms)  Full reduce (ms)     Contiguous dim (ms)  Full reduce diff %   Contiguous diff %   
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256)           mean         0.015700             0.008431             0.016060             0.008293                           -2.24%               +1.66%
(256, 256)           sum          0.023307             0.008427             0.015687             0.007887                          +48.58%               +6.85%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 512)           mean         0.013473             0.008701             0.013377             0.009102                           +0.72%               -4.41%
(512, 512)           sum          0.014294             0.009716             0.013321             0.008644                           +7.30%              +12.40%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)         mean         0.014308             0.010442             0.013879             0.010776                           +3.09%               -3.10%
(1024, 1024)         sum          0.014516             0.012448             0.013785             0.010314                           +5.30%              +20.69%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)         mean         0.019392             0.015462             0.018919             0.015022                           +2.50%               +2.93%
(2048, 2048)         sum          0.018793             0.018475             0.018607             0.015159                           +1.00%              +21.87%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)         mean         0.034354             0.032064             0.034267             0.032384                           +0.25%               -0.99%
(4096, 4096)         sum          0.034703             0.037153             0.034006             0.030732                           +2.05%              +20.89%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)         mean         0.087554             0.086913             0.087550             0.087240                           +0.00%               -0.37%
(8192, 8192)         sum          0.088265             0.095772             0.085408             0.084674                           +3.35%              +13.11%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 16384)        mean         0.150714             0.138377             0.150194             0.138775                           +0.35%               -0.29%
(8192, 16384)        sum          0.147996             0.149958             0.146178             0.138765                           +1.24%               +8.07%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 32768)        mean         0.274935             0.271579             0.274787             0.271480                           +0.05%               +0.04%
(8192, 32768)        sum          0.266700             0.261308             0.266002             0.253478                           +0.26%               +3.09%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 65536)        mean         0.520137             0.535495             0.520826             0.534928                           -0.13%               +0.11%
(8192, 65536)        sum          0.501590             0.486255             0.498309             0.481378                           +0.66%               +1.01%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 131072)       mean         1.013195             1.013046             1.012499             1.012806                           +0.07%               +0.02%
(8192, 131072)       sum          0.970992             0.943506             0.957651             0.938352                           +1.39%               +0.55%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 262144)       mean         2.020536             1.944477             2.021117             1.944456                           -0.03%               +0.00%
(8192, 262144)       sum          1.954167             1.877884             1.904428             1.862436                           +2.61%               +0.83%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 262144)       mean         1.013495             0.977271             1.012324             0.977721                           +0.12%               -0.05%
(4096, 262144)       sum          0.970941             0.941685             0.957323             0.936643                           +1.42%               +0.54%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 262144)       mean         0.520851             0.500767             0.520352             0.500979                           +0.10%               -0.04%
(2048, 262144)       sum          0.501935             0.487330             0.497851             0.483599                           +0.82%               +0.77%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 262144)       mean         0.275272             0.262046             0.275309             0.261893                           -0.01%               +0.06%
(1024, 262144)       sum          0.266447             0.257469             0.265654             0.255563                           +0.30%               +0.75%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 131072)        mean         0.087285             0.090168             0.087124             0.089990                           +0.18%               +0.20%
(512, 131072)        sum          0.087961             0.091448             0.085851             0.088071                           +2.46%               +3.83%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000)         mean         0.014012             0.010440             0.013918             0.010388                           +0.68%               +0.50%
(1000, 1000)         sum          0.014668             0.012879             0.014121             0.010303                           +3.87%              +25.00%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 129)          mean         0.013072             0.008655             0.013373             0.008693                           -2.25%               -0.44%
(1024, 129)          sum          0.014308             0.008951             0.013336             0.008592                           +7.29%               +4.18%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 257)          mean         0.013171             0.008889             0.013516             0.009153                           -2.55%               -2.88%
(1024, 257)          sum          0.014455             0.009573             0.013460             0.008709                           +7.39%               +9.92%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 587)          mean         0.013385             0.010576             0.013303             0.010525                           +0.62%               +0.48%
(1024, 587)          sum          0.014635             0.011530             0.013569             0.010458                           +7.86%              +10.25%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 977)          mean         0.015431             0.012528             0.015751             0.012079                           -2.03%               +3.72%
(2048, 977)          sum          0.015732             0.014101             0.015064             0.011896                           +4.43%              +18.54%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 128)          mean         0.013273             0.008437             0.013219             0.007888                           +0.41%               +6.96%
(1024, 128)          sum          0.013738             0.007830             0.012793             0.008284                           +7.39%               -5.48%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 128)          mean         0.013930             0.010086             0.013843             0.009677                           +0.63%               +4.23%
(8192, 128)          sum          0.014525             0.009780             0.013758             0.009247                           +5.57%               +5.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 130)          mean         0.013427             0.009024             0.013372             0.009078                           +0.41%               -0.59%
(1024, 130)          sum          0.013951             0.008799             0.013360             0.008560                           +4.42%               +2.79%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 130)          mean         0.014121             0.010474             0.014035             0.009967                           +0.61%               +5.09%
(8192, 130)          sum          0.015139             0.010545             0.013958             0.010045                           +8.46%               +4.98%
====================================================================================================================================================================================

```

[ghstack-poisoned]
PaulZhang12 added a commit that referenced this pull request Oct 16, 2025
ghstack-source-id: 1bdfd2f
Pull Request resolved: #165055
Copy link
Collaborator

@ngimel ngimel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pretty cool!

Comment on lines 42 to 44
constexpr bool is_16_bits =
( (std::is_same<at::Half, scalar_t>::value) ||
(std::is_same<at::BFloat16, scalar_t>::value) );
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
constexpr bool is_16_bits =
( (std::is_same<at::Half, scalar_t>::value) ||
(std::is_same<at::BFloat16, scalar_t>::value) );
constexpr bool is_16_bits = sizeof(scalar_t) == 16;

Benchmarks for a full reduction + reduction on the contiguous dimension. Vectorized loads do not occur on the non contiguous dimension. Benchmarking done for BF16, ~6% improvement on average across shapes:
```
====================================================================================================================================================================================
PyTorch Reduction Operations Benchmark - Combined Comparison
====================================================================================================================================================================================
Tensor Shape         Operation    Full reduce (ms)     Contiguous dim (ms)  Full reduce (ms)     Contiguous dim (ms)  Full reduce diff %   Contiguous diff %   
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256)           mean         0.022686             0.008263             0.015498             0.008117                          +46.38%               +1.80%
(256, 256)           sum          0.022769             0.008269             0.015628             0.008185                          +45.69%               +1.03%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 512)           mean         0.014116             0.009545             0.012892             0.008839                           +9.49%               +7.99%
(512, 512)           sum          0.014110             0.009892             0.012891             0.008878                           +9.46%              +11.42%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)         mean         0.014727             0.012642             0.014061             0.010519                           +4.74%              +20.18%
(1024, 1024)         sum          0.014376             0.012636             0.014069             0.010595                           +2.18%              +19.26%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)         mean         0.018663             0.018294             0.018171             0.014678                           +2.71%              +24.64%
(2048, 2048)         sum          0.018638             0.017931             0.018142             0.014713                           +2.73%              +21.87%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)         mean         0.034216             0.036953             0.033520             0.030585                           +2.08%              +20.82%
(4096, 4096)         sum          0.034196             0.036942             0.033518             0.030676                           +2.02%              +20.43%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)         mean         0.087763             0.095201             0.085439             0.084960                           +2.72%              +12.05%
(8192, 8192)         sum          0.088079             0.095592             0.085353             0.084632                           +3.19%              +12.95%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 16384)        mean         0.148174             0.149705             0.146274             0.138865                           +1.30%               +7.81%
(8192, 16384)        sum          0.147820             0.149371             0.146419             0.138752                           +0.96%               +7.65%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 32768)        mean         0.266144             0.260807             0.265953             0.253330                           +0.07%               +2.95%
(8192, 32768)        sum          0.266572             0.261163             0.265729             0.253294                           +0.32%               +3.11%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 65536)        mean         0.502034             0.486312             0.498417             0.481246                           +0.73%               +1.05%
(8192, 65536)        sum          0.501597             0.486351             0.497735             0.481579                           +0.78%               +0.99%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 131072)       mean         0.971178             0.942988             0.957164             0.938316                           +1.46%               +0.50%
(8192, 131072)       sum          0.971189             0.943232             0.956814             0.937816                           +1.50%               +0.58%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 262144)       mean         1.953728             1.877648             1.904937             1.861692                           +2.56%               +0.86%
(8192, 262144)       sum          1.953969             1.877538             1.905990             1.862547                           +2.52%               +0.80%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 262144)       mean         0.970408             0.940965             0.957871             0.936732                           +1.31%               +0.45%
(4096, 262144)       sum          0.970919             0.941652             0.957765             0.936676                           +1.37%               +0.53%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 262144)       mean         0.501477             0.486976             0.497964             0.483570                           +0.71%               +0.70%
(2048, 262144)       sum          0.501955             0.487213             0.498210             0.483218                           +0.75%               +0.83%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 262144)       mean         0.266536             0.257111             0.265642             0.255439                           +0.34%               +0.65%
(1024, 262144)       sum          0.266613             0.257096             0.265427             0.255472                           +0.45%               +0.64%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 131072)        mean         0.087805             0.091200             0.085818             0.087851                           +2.32%               +3.81%
(512, 131072)        sum          0.087788             0.091249             0.085373             0.087944                           +2.83%               +3.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000)         mean         0.014503             0.012328             0.013663             0.010190                           +6.15%              +20.98%
(1000, 1000)         sum          0.014545             0.012378             0.013662             0.010579                           +6.46%              +17.01%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 129)          mean         0.014163             0.008371             0.012893             0.008828                           +9.85%               -5.18%
(1024, 129)          sum          0.014132             0.008751             0.013234             0.008868                           +6.79%               -1.32%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 257)          mean         0.014296             0.009101             0.013334             0.008563                           +7.21%               +6.28%
(1024, 257)          sum          0.014302             0.009058             0.013020             0.008672                           +9.85%               +4.45%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 587)          mean         0.014127             0.010997             0.013443             0.009944                           +5.09%              +10.59%
(1024, 587)          sum          0.014471             0.011373             0.013123             0.010354                          +10.27%               +9.84%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 977)          mean         0.015607             0.013566             0.015089             0.012152                           +3.43%              +11.64%
(2048, 977)          sum          0.015953             0.013580             0.015039             0.011861                           +6.08%              +14.49%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 128)          mean         0.013982             0.008058             0.012747             0.008139                           +9.69%               -1.00%
(1024, 128)          sum          0.013967             0.008071             0.012726             0.007859                           +9.75%               +2.70%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 128)          mean         0.014378             0.009627             0.013712             0.009395                           +4.86%               +2.47%
(8192, 128)          sum          0.014389             0.009965             0.013718             0.009521                           +4.89%               +4.66%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 130)          mean         0.014156             0.008267             0.012895             0.008833                           +9.78%               -6.41%
(1024, 130)          sum          0.013797             0.008277             0.012903             0.008512                           +6.93%               -2.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 130)          mean         0.014977             0.010026             0.013911             0.009876                           +7.66%               +1.52%
(8192, 130)          sum          0.014994             0.010043             0.014235             0.009604                           +5.33%               +4.57%
====================================================================================================================================================================================
```

[ghstack-poisoned]
PaulZhang12 added a commit that referenced this pull request Oct 16, 2025
ghstack-source-id: 0fba0d5
Pull Request resolved: #165055
@PaulZhang12 PaulZhang12 changed the title [ATen] Vectorize 8 elements on 16 bit data types [ATen] Vectorize 8 elements on 16 bit data types for sum/mean Oct 16, 2025
@PaulZhang12 PaulZhang12 added the ciflow/rocm Trigger "default" config CI on ROCm label Oct 16, 2025
@PaulZhang12
Copy link
Contributor Author

If you have an issue where MSCV complains about constexpr use the macro used for constexpr in other CUDA kernels.

@Skylion007 MSCV complains about the extended lambda definition in the if constexpr. What is the recommended workaround here?

…ean"


Benchmarks for a full reduction + reduction on the contiguous dimension. Vectorized loads do not occur on the non contiguous dimension. Benchmarking done for FP16/BF16, ~6% improvement on average across shapes, up to ~24% for single reduction on contiguous dimension and 46% for full reduce:
**BF16**
```
Tensor Shape         Operation    Full reduce (ms)     Contiguous dim (ms)  Full reduce (ms)     Contiguous dim (ms)  Full reduce diff %   Contiguous diff %   
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256)           mean         0.022686             0.008263             0.015498             0.008117                          +46.38%               +1.80%
(256, 256)           sum          0.022769             0.008269             0.015628             0.008185                          +45.69%               +1.03%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 512)           mean         0.014116             0.009545             0.012892             0.008839                           +9.49%               +7.99%
(512, 512)           sum          0.014110             0.009892             0.012891             0.008878                           +9.46%              +11.42%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)         mean         0.014727             0.012642             0.014061             0.010519                           +4.74%              +20.18%
(1024, 1024)         sum          0.014376             0.012636             0.014069             0.010595                           +2.18%              +19.26%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)         mean         0.018663             0.018294             0.018171             0.014678                           +2.71%              +24.64%
(2048, 2048)         sum          0.018638             0.017931             0.018142             0.014713                           +2.73%              +21.87%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)         mean         0.034216             0.036953             0.033520             0.030585                           +2.08%              +20.82%
(4096, 4096)         sum          0.034196             0.036942             0.033518             0.030676                           +2.02%              +20.43%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)         mean         0.087763             0.095201             0.085439             0.084960                           +2.72%              +12.05%
(8192, 8192)         sum          0.088079             0.095592             0.085353             0.084632                           +3.19%              +12.95%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 16384)        mean         0.148174             0.149705             0.146274             0.138865                           +1.30%               +7.81%
(8192, 16384)        sum          0.147820             0.149371             0.146419             0.138752                           +0.96%               +7.65%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 32768)        mean         0.266144             0.260807             0.265953             0.253330                           +0.07%               +2.95%
(8192, 32768)        sum          0.266572             0.261163             0.265729             0.253294                           +0.32%               +3.11%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 65536)        mean         0.502034             0.486312             0.498417             0.481246                           +0.73%               +1.05%
(8192, 65536)        sum          0.501597             0.486351             0.497735             0.481579                           +0.78%               +0.99%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 131072)       mean         0.971178             0.942988             0.957164             0.938316                           +1.46%               +0.50%
(8192, 131072)       sum          0.971189             0.943232             0.956814             0.937816                           +1.50%               +0.58%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 262144)       mean         1.953728             1.877648             1.904937             1.861692                           +2.56%               +0.86%
(8192, 262144)       sum          1.953969             1.877538             1.905990             1.862547                           +2.52%               +0.80%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 262144)       mean         0.970408             0.940965             0.957871             0.936732                           +1.31%               +0.45%
(4096, 262144)       sum          0.970919             0.941652             0.957765             0.936676                           +1.37%               +0.53%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 262144)       mean         0.501477             0.486976             0.497964             0.483570                           +0.71%               +0.70%
(2048, 262144)       sum          0.501955             0.487213             0.498210             0.483218                           +0.75%               +0.83%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 262144)       mean         0.266536             0.257111             0.265642             0.255439                           +0.34%               +0.65%
(1024, 262144)       sum          0.266613             0.257096             0.265427             0.255472                           +0.45%               +0.64%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 131072)        mean         0.087805             0.091200             0.085818             0.087851                           +2.32%               +3.81%
(512, 131072)        sum          0.087788             0.091249             0.085373             0.087944                           +2.83%               +3.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000)         mean         0.014503             0.012328             0.013663             0.010190                           +6.15%              +20.98%
(1000, 1000)         sum          0.014545             0.012378             0.013662             0.010579                           +6.46%              +17.01%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 129)          mean         0.014163             0.008371             0.012893             0.008828                           +9.85%               -5.18%
(1024, 129)          sum          0.014132             0.008751             0.013234             0.008868                           +6.79%               -1.32%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 257)          mean         0.014296             0.009101             0.013334             0.008563                           +7.21%               +6.28%
(1024, 257)          sum          0.014302             0.009058             0.013020             0.008672                           +9.85%               +4.45%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 587)          mean         0.014127             0.010997             0.013443             0.009944                           +5.09%              +10.59%
(1024, 587)          sum          0.014471             0.011373             0.013123             0.010354                          +10.27%               +9.84%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 977)          mean         0.015607             0.013566             0.015089             0.012152                           +3.43%              +11.64%
(2048, 977)          sum          0.015953             0.013580             0.015039             0.011861                           +6.08%              +14.49%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 128)          mean         0.013982             0.008058             0.012747             0.008139                           +9.69%               -1.00%
(1024, 128)          sum          0.013967             0.008071             0.012726             0.007859                           +9.75%               +2.70%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 128)          mean         0.014378             0.009627             0.013712             0.009395                           +4.86%               +2.47%
(8192, 128)          sum          0.014389             0.009965             0.013718             0.009521                           +4.89%               +4.66%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 130)          mean         0.014156             0.008267             0.012895             0.008833                           +9.78%               -6.41%
(1024, 130)          sum          0.013797             0.008277             0.012903             0.008512                           +6.93%               -2.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 130)          mean         0.014977             0.010026             0.013911             0.009876                           +7.66%               +1.52%
(8192, 130)          sum          0.014994             0.010043             0.014235             0.009604                           +5.33%               +4.57%
====================================================================================================================================================================================
```

**FP16**
```
Tensor Shape         Operation    Full reduce (ms)     Contiguous dim (ms)  Full reduce (ms)     Contiguous dim (ms)  Full reduce diff %   Contiguous diff %   
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256)           mean         0.022804             0.008298             0.015888             0.007848                          +43.53%               +5.73%
(256, 256)           sum          0.023215             0.008328             0.015677             0.007850                          +48.08%               +6.09%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 512)           mean         0.013777             0.009988             0.012884             0.008512                           +6.93%              +17.34%
(512, 512)           sum          0.013775             0.009622             0.012870             0.009028                           +7.03%               +6.58%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)         mean         0.014740             0.012322             0.013708             0.010239                           +7.53%              +20.34%
(1024, 1024)         sum          0.014762             0.012756             0.013722             0.010307                           +7.58%              +23.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)         mean         0.018700             0.018364             0.018135             0.015078                           +3.12%              +21.79%
(2048, 2048)         sum          0.018276             0.018415             0.018471             0.015127                           -1.06%              +21.74%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)         mean         0.034518             0.037000             0.033838             0.030617                           +2.01%              +20.85%
(4096, 4096)         sum          0.034569             0.037448             0.033842             0.031100                           +2.15%              +20.41%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)         mean         0.087675             0.095176             0.085328             0.084105                           +2.75%              +13.16%
(8192, 8192)         sum          0.088102             0.095211             0.085707             0.084090                           +2.79%              +13.23%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 16384)        mean         0.147800             0.149263             0.146388             0.138390                           +0.96%               +7.86%
(8192, 16384)        sum          0.148147             0.148957             0.146439             0.138801                           +1.17%               +7.32%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 32768)        mean         0.266316             0.260294             0.265829             0.253411                           +0.18%               +2.72%
(8192, 32768)        sum          0.266562             0.260717             0.265744             0.253308                           +0.31%               +2.92%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 65536)        mean         0.502035             0.486077             0.498139             0.481374                           +0.78%               +0.98%
(8192, 65536)        sum          0.501571             0.485733             0.498353             0.481350                           +0.65%               +0.91%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 131072)       mean         0.971343             0.943016             0.956600             0.938622                           +1.54%               +0.47%
(8192, 131072)       sum          0.971463             0.942991             0.957352             0.938334                           +1.47%               +0.50%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 262144)       mean         1.952722             1.877165             1.906406             1.861455                           +2.43%               +0.84%
(8192, 262144)       sum          1.952634             1.876388             1.904677             1.861282                           +2.52%               +0.81%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 262144)       mean         0.970697             0.941298             0.956964             0.936160                           +1.44%               +0.55%
(4096, 262144)       sum          0.969981             0.941078             0.957016             0.936260                           +1.35%               +0.51%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 262144)       mean         0.501577             0.487208             0.498422             0.483493                           +0.63%               +0.77%
(2048, 262144)       sum          0.502029             0.487124             0.497854             0.483643                           +0.84%               +0.72%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 262144)       mean         0.266416             0.257383             0.265928             0.255140                           +0.18%               +0.88%
(1024, 262144)       sum          0.266434             0.257081             0.265817             0.255143                           +0.23%               +0.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 131072)        mean         0.087858             0.091296             0.085816             0.087745                           +2.38%               +4.05%
(512, 131072)        sum          0.088144             0.091314             0.085664             0.087864                           +2.90%               +3.93%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000)         mean         0.014977             0.012393             0.014141             0.010614                           +5.91%              +16.76%
(1000, 1000)         sum          0.014589             0.012804             0.014118             0.010320                           +3.34%              +24.07%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 129)          mean         0.014208             0.008383             0.013273             0.008440                           +7.04%               -0.68%
(1024, 129)          sum          0.013804             0.008863             0.013265             0.009003                           +4.06%               -1.56%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 257)          mean         0.014378             0.009109             0.013037             0.009038                          +10.29%               +0.79%
(1024, 257)          sum          0.014387             0.009113             0.013396             0.008698                           +7.40%               +4.77%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 587)          mean         0.014207             0.011037             0.013182             0.010391                           +7.78%               +6.22%
(1024, 587)          sum          0.014588             0.011453             0.013539             0.010049                           +7.75%              +13.97%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 977)          mean         0.016024             0.013614             0.015448             0.011845                           +3.73%              +14.93%
(2048, 977)          sum          0.015990             0.014033             0.015406             0.012278                           +3.79%              +14.29%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 128)          mean         0.014037             0.007804             0.013143             0.008242                           +6.80%               -5.31%
(1024, 128)          sum          0.014041             0.007847             0.012759             0.007850                          +10.05%               -0.04%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 128)          mean         0.014361             0.009644             0.014075             0.009061                           +2.03%               +6.43%
(8192, 128)          sum          0.014366             0.010032             0.013702             0.009181                           +4.85%               +9.27%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 130)          mean         0.014226             0.008696             0.012894             0.008835                          +10.33%               -1.57%
(1024, 130)          sum          0.013830             0.008740             0.013288             0.008989                           +4.08%               -2.77%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 130)          mean         0.015036             0.010019             0.013917             0.009538                           +8.04%               +5.04%
(8192, 130)          sum          0.014652             0.010403             0.013900             0.009565                           +5.41%               +8.76%
====================================================================================================================================================================================
```

[ghstack-poisoned]
PaulZhang12 added a commit that referenced this pull request Oct 16, 2025
ghstack-source-id: b3de3a1
Pull Request resolved: #165055
@PaulZhang12
Copy link
Contributor Author

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pytorchmergebot pushed a commit that referenced this pull request Oct 19, 2025
Performance benchmarking, perf neutral:
```
================================================================================================================================================================================================================================================
Tensor Shape         Operation    Full reduce (ms)     Non-Contig dim (ms)    Contig dim (ms)      Full reduce (ms)     Non-Contig dim (ms)    Contig dim (ms)      Full diff %     Non-Contig diff %    Contig diff %
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256)           mean         0.015684             0.017056               0.008287             0.016015             0.016929               0.008170                      -2.07%               +0.75%          +1.43%
(256, 256)           sum          0.015774             0.016638               0.007926             0.015811             0.016935               0.008330                      -0.23%               -1.75%          -4.85%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 512)           mean         0.013385             0.025742               0.008629             0.013046             0.026005               0.008924                      +2.60%               -1.01%          -3.31%
(512, 512)           sum          0.013390             0.026059               0.009116             0.013054             0.025696               0.008952                      +2.57%               +1.41%          +1.83%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)         mean         0.014213             0.015467               0.010334             0.013862             0.015082               0.010318                      +2.53%               +2.55%          +0.16%
(1024, 1024)         sum          0.014179             0.015446               0.010774             0.014132             0.015073               0.010350                      +0.33%               +2.47%          +4.10%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)         mean         0.018234             0.019487               0.014812             0.018482             0.019397               0.014802                      -1.34%               +0.46%          +0.07%
(2048, 2048)         sum          0.018202             0.019529               0.015195             0.018122             0.019485               0.015129                      +0.44%               +0.23%          +0.44%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)         mean         0.033582             0.039378               0.030751             0.033810             0.039673               0.031019                      -0.67%               -0.74%          -0.86%
(4096, 4096)         sum          0.033604             0.039777               0.030809             0.033530             0.039386               0.031113                      +0.22%               +0.99%          -0.98%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)         mean         0.085824             0.091133               0.084200             0.085431             0.091364               0.084303                      +0.46%               -0.25%          -0.12%
(8192, 8192)         sum          0.085763             0.091442               0.084180             0.085508             0.091419               0.084595                      +0.30%               +0.03%          -0.49%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 16384)        mean         0.146480             0.147666               0.138807             0.146515             0.147987               0.138930                      -0.02%               -0.22%          -0.09%
(8192, 16384)        sum          0.146446             0.147593               0.138559             0.146151             0.147982               0.139120                      +0.20%               -0.26%          -0.40%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 32768)        mean         0.266047             0.265386               0.253837             0.265648             0.265885               0.253652                      +0.15%               -0.19%          +0.07%
(8192, 32768)        sum          0.266093             0.265421               0.253890             0.265458             0.265591               0.253567                      +0.24%               -0.06%          +0.13%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 65536)        mean         0.498632             0.508976               0.481865             0.498237             0.508777               0.481476                      +0.08%               +0.04%          +0.08%
(8192, 65536)        sum          0.498917             0.508202               0.481883             0.498104             0.508016               0.481972                      +0.16%               +0.04%          -0.02%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 131072)       mean         0.957633             0.968519               0.938172             0.956766             0.968267               0.938196                      +0.09%               +0.03%          -0.00%
(8192, 131072)       sum          0.956972             0.968140               0.937741             0.957365             0.968404               0.938056                      -0.04%               -0.03%          -0.03%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 262144)       mean         1.906661             1.928377               1.861846             1.907327             1.928811               1.862083                      -0.03%               -0.02%          -0.01%
(8192, 262144)       sum          1.905976             1.928362               1.862399             1.907098             1.928844               1.861782                      -0.06%               -0.02%          +0.03%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 262144)       mean         0.956852             0.970101               0.936524             0.957263             0.969809               0.936965                      -0.04%               +0.03%          -0.05%
(4096, 262144)       sum          0.957117             0.969933               0.936247             0.956675             0.969451               0.936395                      +0.05%               +0.05%          -0.02%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 262144)       mean         0.498813             0.511299               0.483415             0.498567             0.511482               0.483376                      +0.05%               -0.04%          +0.01%
(2048, 262144)       sum          0.498813             0.510834               0.483641             0.498875             0.511036               0.483338                      -0.01%               -0.04%          +0.06%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 262144)       mean         0.266157             0.276751               0.255192             0.265966             0.276808               0.255544                      +0.07%               -0.02%          -0.14%
(1024, 262144)       sum          0.266133             0.276709               0.255528             0.265658             0.276685               0.255287                      +0.18%               +0.01%          +0.09%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 131072)        mean         0.085941             0.081184               0.087931             0.085591             0.080832               0.088008                      +0.41%               +0.44%          -0.09%
(512, 131072)        sum          0.085962             0.081107               0.088045             0.085882             0.081160               0.088024                      +0.09%               -0.07%          +0.02%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000)         mean         0.014203             0.045859               0.010310             0.013885             0.046132               0.010621                      +2.29%               -0.59%          -2.93%
(1000, 1000)         sum          0.014180             0.046165               0.010756             0.013893             0.046109               0.010338                      +2.07%               +0.12%          +4.04%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 129)          mean         0.012953             0.016751               0.008536             0.012977             0.016714               0.008916                      -0.18%               +0.22%          -4.26%
(1024, 129)          sum          0.013356             0.016806               0.008722             0.013003             0.017071               0.008611                      +2.71%               -1.55%          +1.29%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 257)          mean         0.013075             0.016787               0.009102             0.013116             0.016769               0.008679                      -0.31%               +0.11%          +4.87%
(1024, 257)          sum          0.013092             0.016842               0.008786             0.013126             0.017128               0.008771                      -0.26%               -1.67%          +0.17%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 587)          mean         0.013662             0.017412               0.010055             0.013659             0.017019               0.010033                      +0.02%               +2.31%          +0.22%
(1024, 587)          sum          0.013636             0.017473               0.010163             0.013642             0.017363               0.010101                      -0.04%               +0.63%          +0.61%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 977)          mean         0.015276             0.027873               0.012531             0.015241             0.027783               0.012467                      +0.23%               +0.32%          +0.51%
(2048, 977)          sum          0.015345             0.027949               0.012192             0.015255             0.027839               0.012485                      +0.59%               +0.40%          -2.35%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 128)          mean         0.012806             0.014020               0.008291             0.013137             0.014309               0.007908                      -2.52%               -2.02%          +4.84%
(1024, 128)          sum          0.012769             0.014308               0.007924             0.012788             0.014236               0.008038                      -0.15%               +0.51%          -1.42%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 128)          mean         0.014145             0.023049               0.009143             0.014104             0.023298               0.009501                      +0.29%               -1.07%          -3.77%
(8192, 128)          sum          0.014132             0.023082               0.009638             0.014107             0.023331               0.009244                      +0.18%               -1.07%          +4.26%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 130)          mean         0.013420             0.025834               0.008949             0.013368             0.025724               0.008918                      +0.39%               +0.43%          +0.35%
(1024, 130)          sum          0.013300             0.025940               0.009113             0.013266             0.025419               0.008922                      +0.26%               +2.05%          +2.14%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 130)          mean         0.013993             0.017883               0.009661             0.014275             0.018220               0.009596                      -1.98%               -1.85%          +0.68%
(8192, 130)          sum          0.014026             0.018297               0.010066             0.014326             0.018257               0.009659                      -2.09%               +0.22%          +4.21%
================================================================================================================================================================================================================================================
```

Pull Request resolved: #165178
Approved by: https://github.com/ngimel
ghstack dependencies: #165494, #164790, #165055
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
…h#165055)

Benchmarks for a full reduction + reduction on the contiguous dimension. Vectorized loads do not occur on the non contiguous dimension. Benchmarking done for FP16/BF16, ~6% improvement on average across shapes, up to ~24% for single reduction on contiguous dimension and 46% for full reduce:
**BF16**
```
Tensor Shape         Operation    Full reduce (ms)     Contiguous dim (ms)  Full reduce (ms)     Contiguous dim (ms)  Full reduce diff %   Contiguous diff %
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256)           mean         0.022686             0.008263             0.015498             0.008117                          +46.38%               +1.80%
(256, 256)           sum          0.022769             0.008269             0.015628             0.008185                          +45.69%               +1.03%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 512)           mean         0.014116             0.009545             0.012892             0.008839                           +9.49%               +7.99%
(512, 512)           sum          0.014110             0.009892             0.012891             0.008878                           +9.46%              +11.42%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)         mean         0.014727             0.012642             0.014061             0.010519                           +4.74%              +20.18%
(1024, 1024)         sum          0.014376             0.012636             0.014069             0.010595                           +2.18%              +19.26%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)         mean         0.018663             0.018294             0.018171             0.014678                           +2.71%              +24.64%
(2048, 2048)         sum          0.018638             0.017931             0.018142             0.014713                           +2.73%              +21.87%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)         mean         0.034216             0.036953             0.033520             0.030585                           +2.08%              +20.82%
(4096, 4096)         sum          0.034196             0.036942             0.033518             0.030676                           +2.02%              +20.43%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)         mean         0.087763             0.095201             0.085439             0.084960                           +2.72%              +12.05%
(8192, 8192)         sum          0.088079             0.095592             0.085353             0.084632                           +3.19%              +12.95%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 16384)        mean         0.148174             0.149705             0.146274             0.138865                           +1.30%               +7.81%
(8192, 16384)        sum          0.147820             0.149371             0.146419             0.138752                           +0.96%               +7.65%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 32768)        mean         0.266144             0.260807             0.265953             0.253330                           +0.07%               +2.95%
(8192, 32768)        sum          0.266572             0.261163             0.265729             0.253294                           +0.32%               +3.11%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 65536)        mean         0.502034             0.486312             0.498417             0.481246                           +0.73%               +1.05%
(8192, 65536)        sum          0.501597             0.486351             0.497735             0.481579                           +0.78%               +0.99%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 131072)       mean         0.971178             0.942988             0.957164             0.938316                           +1.46%               +0.50%
(8192, 131072)       sum          0.971189             0.943232             0.956814             0.937816                           +1.50%               +0.58%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 262144)       mean         1.953728             1.877648             1.904937             1.861692                           +2.56%               +0.86%
(8192, 262144)       sum          1.953969             1.877538             1.905990             1.862547                           +2.52%               +0.80%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 262144)       mean         0.970408             0.940965             0.957871             0.936732                           +1.31%               +0.45%
(4096, 262144)       sum          0.970919             0.941652             0.957765             0.936676                           +1.37%               +0.53%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 262144)       mean         0.501477             0.486976             0.497964             0.483570                           +0.71%               +0.70%
(2048, 262144)       sum          0.501955             0.487213             0.498210             0.483218                           +0.75%               +0.83%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 262144)       mean         0.266536             0.257111             0.265642             0.255439                           +0.34%               +0.65%
(1024, 262144)       sum          0.266613             0.257096             0.265427             0.255472                           +0.45%               +0.64%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 131072)        mean         0.087805             0.091200             0.085818             0.087851                           +2.32%               +3.81%
(512, 131072)        sum          0.087788             0.091249             0.085373             0.087944                           +2.83%               +3.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000)         mean         0.014503             0.012328             0.013663             0.010190                           +6.15%              +20.98%
(1000, 1000)         sum          0.014545             0.012378             0.013662             0.010579                           +6.46%              +17.01%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 129)          mean         0.014163             0.008371             0.012893             0.008828                           +9.85%               -5.18%
(1024, 129)          sum          0.014132             0.008751             0.013234             0.008868                           +6.79%               -1.32%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 257)          mean         0.014296             0.009101             0.013334             0.008563                           +7.21%               +6.28%
(1024, 257)          sum          0.014302             0.009058             0.013020             0.008672                           +9.85%               +4.45%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 587)          mean         0.014127             0.010997             0.013443             0.009944                           +5.09%              +10.59%
(1024, 587)          sum          0.014471             0.011373             0.013123             0.010354                          +10.27%               +9.84%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 977)          mean         0.015607             0.013566             0.015089             0.012152                           +3.43%              +11.64%
(2048, 977)          sum          0.015953             0.013580             0.015039             0.011861                           +6.08%              +14.49%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 128)          mean         0.013982             0.008058             0.012747             0.008139                           +9.69%               -1.00%
(1024, 128)          sum          0.013967             0.008071             0.012726             0.007859                           +9.75%               +2.70%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 128)          mean         0.014378             0.009627             0.013712             0.009395                           +4.86%               +2.47%
(8192, 128)          sum          0.014389             0.009965             0.013718             0.009521                           +4.89%               +4.66%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 130)          mean         0.014156             0.008267             0.012895             0.008833                           +9.78%               -6.41%
(1024, 130)          sum          0.013797             0.008277             0.012903             0.008512                           +6.93%               -2.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 130)          mean         0.014977             0.010026             0.013911             0.009876                           +7.66%               +1.52%
(8192, 130)          sum          0.014994             0.010043             0.014235             0.009604                           +5.33%               +4.57%
====================================================================================================================================================================================
```

**FP16**
```
Tensor Shape         Operation    Full reduce (ms)     Contiguous dim (ms)  Full reduce (ms)     Contiguous dim (ms)  Full reduce diff %   Contiguous diff %
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256)           mean         0.022804             0.008298             0.015888             0.007848                          +43.53%               +5.73%
(256, 256)           sum          0.023215             0.008328             0.015677             0.007850                          +48.08%               +6.09%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 512)           mean         0.013777             0.009988             0.012884             0.008512                           +6.93%              +17.34%
(512, 512)           sum          0.013775             0.009622             0.012870             0.009028                           +7.03%               +6.58%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)         mean         0.014740             0.012322             0.013708             0.010239                           +7.53%              +20.34%
(1024, 1024)         sum          0.014762             0.012756             0.013722             0.010307                           +7.58%              +23.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)         mean         0.018700             0.018364             0.018135             0.015078                           +3.12%              +21.79%
(2048, 2048)         sum          0.018276             0.018415             0.018471             0.015127                           -1.06%              +21.74%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)         mean         0.034518             0.037000             0.033838             0.030617                           +2.01%              +20.85%
(4096, 4096)         sum          0.034569             0.037448             0.033842             0.031100                           +2.15%              +20.41%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)         mean         0.087675             0.095176             0.085328             0.084105                           +2.75%              +13.16%
(8192, 8192)         sum          0.088102             0.095211             0.085707             0.084090                           +2.79%              +13.23%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 16384)        mean         0.147800             0.149263             0.146388             0.138390                           +0.96%               +7.86%
(8192, 16384)        sum          0.148147             0.148957             0.146439             0.138801                           +1.17%               +7.32%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 32768)        mean         0.266316             0.260294             0.265829             0.253411                           +0.18%               +2.72%
(8192, 32768)        sum          0.266562             0.260717             0.265744             0.253308                           +0.31%               +2.92%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 65536)        mean         0.502035             0.486077             0.498139             0.481374                           +0.78%               +0.98%
(8192, 65536)        sum          0.501571             0.485733             0.498353             0.481350                           +0.65%               +0.91%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 131072)       mean         0.971343             0.943016             0.956600             0.938622                           +1.54%               +0.47%
(8192, 131072)       sum          0.971463             0.942991             0.957352             0.938334                           +1.47%               +0.50%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 262144)       mean         1.952722             1.877165             1.906406             1.861455                           +2.43%               +0.84%
(8192, 262144)       sum          1.952634             1.876388             1.904677             1.861282                           +2.52%               +0.81%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 262144)       mean         0.970697             0.941298             0.956964             0.936160                           +1.44%               +0.55%
(4096, 262144)       sum          0.969981             0.941078             0.957016             0.936260                           +1.35%               +0.51%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 262144)       mean         0.501577             0.487208             0.498422             0.483493                           +0.63%               +0.77%
(2048, 262144)       sum          0.502029             0.487124             0.497854             0.483643                           +0.84%               +0.72%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 262144)       mean         0.266416             0.257383             0.265928             0.255140                           +0.18%               +0.88%
(1024, 262144)       sum          0.266434             0.257081             0.265817             0.255143                           +0.23%               +0.76%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 131072)        mean         0.087858             0.091296             0.085816             0.087745                           +2.38%               +4.05%
(512, 131072)        sum          0.088144             0.091314             0.085664             0.087864                           +2.90%               +3.93%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000)         mean         0.014977             0.012393             0.014141             0.010614                           +5.91%              +16.76%
(1000, 1000)         sum          0.014589             0.012804             0.014118             0.010320                           +3.34%              +24.07%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 129)          mean         0.014208             0.008383             0.013273             0.008440                           +7.04%               -0.68%
(1024, 129)          sum          0.013804             0.008863             0.013265             0.009003                           +4.06%               -1.56%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 257)          mean         0.014378             0.009109             0.013037             0.009038                          +10.29%               +0.79%
(1024, 257)          sum          0.014387             0.009113             0.013396             0.008698                           +7.40%               +4.77%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 587)          mean         0.014207             0.011037             0.013182             0.010391                           +7.78%               +6.22%
(1024, 587)          sum          0.014588             0.011453             0.013539             0.010049                           +7.75%              +13.97%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 977)          mean         0.016024             0.013614             0.015448             0.011845                           +3.73%              +14.93%
(2048, 977)          sum          0.015990             0.014033             0.015406             0.012278                           +3.79%              +14.29%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 128)          mean         0.014037             0.007804             0.013143             0.008242                           +6.80%               -5.31%
(1024, 128)          sum          0.014041             0.007847             0.012759             0.007850                          +10.05%               -0.04%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 128)          mean         0.014361             0.009644             0.014075             0.009061                           +2.03%               +6.43%
(8192, 128)          sum          0.014366             0.010032             0.013702             0.009181                           +4.85%               +9.27%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 130)          mean         0.014226             0.008696             0.012894             0.008835                          +10.33%               -1.57%
(1024, 130)          sum          0.013830             0.008740             0.013288             0.008989                           +4.08%               -2.77%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 130)          mean         0.015036             0.010019             0.013917             0.009538                           +8.04%               +5.04%
(8192, 130)          sum          0.014652             0.010403             0.013900             0.009565                           +5.41%               +8.76%
====================================================================================================================================================================================
```
Pull Request resolved: pytorch#165055
Approved by: https://github.com/ngimel
ghstack dependencies: pytorch#165494, pytorch#164790
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
…65178)

Performance benchmarking, perf neutral:
```
================================================================================================================================================================================================================================================
Tensor Shape         Operation    Full reduce (ms)     Non-Contig dim (ms)    Contig dim (ms)      Full reduce (ms)     Non-Contig dim (ms)    Contig dim (ms)      Full diff %     Non-Contig diff %    Contig diff %
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(256, 256)           mean         0.015684             0.017056               0.008287             0.016015             0.016929               0.008170                      -2.07%               +0.75%          +1.43%
(256, 256)           sum          0.015774             0.016638               0.007926             0.015811             0.016935               0.008330                      -0.23%               -1.75%          -4.85%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 512)           mean         0.013385             0.025742               0.008629             0.013046             0.026005               0.008924                      +2.60%               -1.01%          -3.31%
(512, 512)           sum          0.013390             0.026059               0.009116             0.013054             0.025696               0.008952                      +2.57%               +1.41%          +1.83%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)         mean         0.014213             0.015467               0.010334             0.013862             0.015082               0.010318                      +2.53%               +2.55%          +0.16%
(1024, 1024)         sum          0.014179             0.015446               0.010774             0.014132             0.015073               0.010350                      +0.33%               +2.47%          +4.10%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)         mean         0.018234             0.019487               0.014812             0.018482             0.019397               0.014802                      -1.34%               +0.46%          +0.07%
(2048, 2048)         sum          0.018202             0.019529               0.015195             0.018122             0.019485               0.015129                      +0.44%               +0.23%          +0.44%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)         mean         0.033582             0.039378               0.030751             0.033810             0.039673               0.031019                      -0.67%               -0.74%          -0.86%
(4096, 4096)         sum          0.033604             0.039777               0.030809             0.033530             0.039386               0.031113                      +0.22%               +0.99%          -0.98%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)         mean         0.085824             0.091133               0.084200             0.085431             0.091364               0.084303                      +0.46%               -0.25%          -0.12%
(8192, 8192)         sum          0.085763             0.091442               0.084180             0.085508             0.091419               0.084595                      +0.30%               +0.03%          -0.49%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 16384)        mean         0.146480             0.147666               0.138807             0.146515             0.147987               0.138930                      -0.02%               -0.22%          -0.09%
(8192, 16384)        sum          0.146446             0.147593               0.138559             0.146151             0.147982               0.139120                      +0.20%               -0.26%          -0.40%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 32768)        mean         0.266047             0.265386               0.253837             0.265648             0.265885               0.253652                      +0.15%               -0.19%          +0.07%
(8192, 32768)        sum          0.266093             0.265421               0.253890             0.265458             0.265591               0.253567                      +0.24%               -0.06%          +0.13%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 65536)        mean         0.498632             0.508976               0.481865             0.498237             0.508777               0.481476                      +0.08%               +0.04%          +0.08%
(8192, 65536)        sum          0.498917             0.508202               0.481883             0.498104             0.508016               0.481972                      +0.16%               +0.04%          -0.02%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 131072)       mean         0.957633             0.968519               0.938172             0.956766             0.968267               0.938196                      +0.09%               +0.03%          -0.00%
(8192, 131072)       sum          0.956972             0.968140               0.937741             0.957365             0.968404               0.938056                      -0.04%               -0.03%          -0.03%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 262144)       mean         1.906661             1.928377               1.861846             1.907327             1.928811               1.862083                      -0.03%               -0.02%          -0.01%
(8192, 262144)       sum          1.905976             1.928362               1.862399             1.907098             1.928844               1.861782                      -0.06%               -0.02%          +0.03%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 262144)       mean         0.956852             0.970101               0.936524             0.957263             0.969809               0.936965                      -0.04%               +0.03%          -0.05%
(4096, 262144)       sum          0.957117             0.969933               0.936247             0.956675             0.969451               0.936395                      +0.05%               +0.05%          -0.02%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 262144)       mean         0.498813             0.511299               0.483415             0.498567             0.511482               0.483376                      +0.05%               -0.04%          +0.01%
(2048, 262144)       sum          0.498813             0.510834               0.483641             0.498875             0.511036               0.483338                      -0.01%               -0.04%          +0.06%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 262144)       mean         0.266157             0.276751               0.255192             0.265966             0.276808               0.255544                      +0.07%               -0.02%          -0.14%
(1024, 262144)       sum          0.266133             0.276709               0.255528             0.265658             0.276685               0.255287                      +0.18%               +0.01%          +0.09%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(512, 131072)        mean         0.085941             0.081184               0.087931             0.085591             0.080832               0.088008                      +0.41%               +0.44%          -0.09%
(512, 131072)        sum          0.085962             0.081107               0.088045             0.085882             0.081160               0.088024                      +0.09%               -0.07%          +0.02%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000)         mean         0.014203             0.045859               0.010310             0.013885             0.046132               0.010621                      +2.29%               -0.59%          -2.93%
(1000, 1000)         sum          0.014180             0.046165               0.010756             0.013893             0.046109               0.010338                      +2.07%               +0.12%          +4.04%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 129)          mean         0.012953             0.016751               0.008536             0.012977             0.016714               0.008916                      -0.18%               +0.22%          -4.26%
(1024, 129)          sum          0.013356             0.016806               0.008722             0.013003             0.017071               0.008611                      +2.71%               -1.55%          +1.29%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 257)          mean         0.013075             0.016787               0.009102             0.013116             0.016769               0.008679                      -0.31%               +0.11%          +4.87%
(1024, 257)          sum          0.013092             0.016842               0.008786             0.013126             0.017128               0.008771                      -0.26%               -1.67%          +0.17%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 587)          mean         0.013662             0.017412               0.010055             0.013659             0.017019               0.010033                      +0.02%               +2.31%          +0.22%
(1024, 587)          sum          0.013636             0.017473               0.010163             0.013642             0.017363               0.010101                      -0.04%               +0.63%          +0.61%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 977)          mean         0.015276             0.027873               0.012531             0.015241             0.027783               0.012467                      +0.23%               +0.32%          +0.51%
(2048, 977)          sum          0.015345             0.027949               0.012192             0.015255             0.027839               0.012485                      +0.59%               +0.40%          -2.35%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 128)          mean         0.012806             0.014020               0.008291             0.013137             0.014309               0.007908                      -2.52%               -2.02%          +4.84%
(1024, 128)          sum          0.012769             0.014308               0.007924             0.012788             0.014236               0.008038                      -0.15%               +0.51%          -1.42%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 128)          mean         0.014145             0.023049               0.009143             0.014104             0.023298               0.009501                      +0.29%               -1.07%          -3.77%
(8192, 128)          sum          0.014132             0.023082               0.009638             0.014107             0.023331               0.009244                      +0.18%               -1.07%          +4.26%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 130)          mean         0.013420             0.025834               0.008949             0.013368             0.025724               0.008918                      +0.39%               +0.43%          +0.35%
(1024, 130)          sum          0.013300             0.025940               0.009113             0.013266             0.025419               0.008922                      +0.26%               +2.05%          +2.14%
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 130)          mean         0.013993             0.017883               0.009661             0.014275             0.018220               0.009596                      -1.98%               -1.85%          +0.68%
(8192, 130)          sum          0.014026             0.018297               0.010066             0.014326             0.018257               0.009659                      -2.09%               +0.22%          +4.21%
================================================================================================================================================================================================================================================
```

Pull Request resolved: pytorch#165178
Approved by: https://github.com/ngimel
ghstack dependencies: pytorch#165494, pytorch#164790, pytorch#165055
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/h100 ciflow/rocm Trigger "default" config CI on ROCm ciflow/trunk Trigger trunk jobs on your pull request Merged release notes: cuda release notes category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants

Morty Proxy This is a proxified and sanitized view of the page, visit original site.