-
Notifications
You must be signed in to change notification settings - Fork 25.6k
[ATen] Vectorize 8 elements on 16 bit data types for sum/mean #165055
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/165055
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 910c221 with merge base 90c0825 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Looks like there's a lot of other uses of |
iter, func_wrapper<out_t>([] GPU_LAMBDA(acc_t a, acc_t b) -> acc_t { | ||
return a + b; | ||
})); | ||
// } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems like this is a typo 🤔
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes typo, need to update, thanks!
I have found regressions generally for other reductions, like argmax, min, etc. I will add some performance numbers in a bit |
[ghstack-poisoned]
template <typename scalar_t, typename acc_t=scalar_t, typename out_t=scalar_t> | ||
void mean_kernel_impl(TensorIterator& iter) { | ||
// returns acc_t for all non-complex dtypes and returns T for c10::complex<T> | ||
const bool is_16_bits = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should be constexpr bool
using factor_t = typename c10::scalar_value_type<acc_t>::type; | ||
factor_t factor = static_cast<factor_t>(iter.num_output_elements()) / iter.numel(); | ||
gpu_reduce_kernel<scalar_t, out_t>(iter, MeanOps<scalar_t, acc_t, factor_t, out_t> {factor}); | ||
if (is_16_bits) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now this becomes constexpr if too
template <typename scalar_t, typename acc_t = scalar_t, typename out_t = scalar_t> | ||
struct sum_functor { | ||
void operator()(TensorIterator& iter) { | ||
const bool is_16_bits = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here
} | ||
#endif | ||
gpu_reduce_kernel<scalar_t, out_t>( | ||
if (is_16_bits) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you have an issue where MSCV complains about constexpr use the macro used for constexpr in other CUDA kernels.
Benchmarks for a full reduction + reduction on the contiguous dimension. Vectorized loads do not occur on the non contiguous dimension: ``` Tensor Shape Operation Base Full reduce (ms) Base Contiguous dim (ms) V Full reduce (ms) V Contiguous dim (ms) Full reduce diff % Contiguous diff % ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (256, 256) mean 0.022750 0.008646 0.016023 0.008252 +41.98% +4.77% (256, 256) sum 0.023141 0.008664 0.015670 0.008269 +47.68% +4.78% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 512) mean 0.014090 0.009603 0.013006 0.008669 +8.33% +10.77% (512, 512) sum 0.014084 0.009622 0.013015 0.008672 +8.21% +10.95% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 1024) mean 0.014368 0.012647 0.013895 0.010347 +3.40% +22.23% (1024, 1024) sum 0.014700 0.012651 0.014261 0.010368 +3.08% +22.02% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 2048) mean 0.018587 0.017994 0.019311 0.015388 -3.75% +16.94% (2048, 2048) sum 0.018590 0.017995 0.018913 0.015064 -1.71% +19.46% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 4096) mean 0.034463 0.037302 0.034645 0.032360 -0.53% +15.27% (4096, 4096) sum 0.034163 0.037045 0.034679 0.031992 -1.49% +15.79% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 8192) mean 0.087827 0.095071 0.087534 0.087104 +0.33% +9.15% (8192, 8192) sum 0.088087 0.094829 0.087511 0.086730 +0.66% +9.34% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 16384) mean 0.148125 0.149172 0.150266 0.138626 -1.42% +7.61% (8192, 16384) sum 0.147821 0.149253 0.150140 0.138618 -1.54% +7.67% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 32768) mean 0.266410 0.261032 0.274723 0.271464 -3.03% -3.84% (8192, 32768) sum 0.266214 0.260893 0.275037 0.271687 -3.21% -3.97% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 65536) mean 0.501475 0.485934 0.520243 0.535353 -3.61% -9.23% (8192, 65536) sum 0.501845 0.486466 0.520737 0.534895 -3.63% -9.05% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 131072) mean 0.970737 0.942428 1.012732 1.013035 -4.15% -6.97% (8192, 131072) sum 0.970730 0.942806 1.012648 1.012854 -4.14% -6.92% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 262144) mean 1.954044 1.877875 2.019751 1.944285 -3.25% -3.42% (8192, 262144) sum 1.954403 1.877363 2.020125 1.945457 -3.25% -3.50% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 262144) mean 0.971163 0.941658 1.013148 0.977148 -4.14% -3.63% (4096, 262144) sum 0.971205 0.941575 1.013127 0.977082 -4.14% -3.63% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 262144) mean 0.501913 0.487320 0.520714 0.500667 -3.61% -2.67% (2048, 262144) sum 0.501453 0.486805 0.520148 0.500693 -3.59% -2.77% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 262144) mean 0.266118 0.257603 0.274991 0.261859 -3.23% -1.63% (1024, 262144) sum 0.266247 0.257172 0.275075 0.261925 -3.21% -1.81% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 131072) mean 0.087805 0.091222 0.087552 0.090097 +0.29% +1.25% (512, 131072) sum 0.087862 0.091237 0.087519 0.090119 +0.39% +1.24% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1000, 1000) mean 0.014582 0.012707 0.014285 0.010758 +2.08% +18.12% (1000, 1000) sum 0.014568 0.012711 0.014289 0.010754 +1.95% +18.20% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 129) mean 0.014099 0.008741 0.013312 0.008671 +5.91% +0.81% (1024, 129) sum 0.013774 0.008410 0.012965 0.008663 +6.24% -2.92% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 257) mean 0.014312 0.009436 0.013504 0.008753 +5.98% +7.80% (1024, 257) sum 0.013998 0.009111 0.013503 0.008770 +3.67% +3.89% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 587) mean 0.014163 0.011363 0.013690 0.010475 +3.46% +8.48% (1024, 587) sum 0.014172 0.011377 0.013356 0.010476 +6.11% +8.60% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 977) mean 0.015947 0.013946 0.015327 0.012085 +4.05% +15.40% (2048, 977) sum 0.015631 0.013945 0.015322 0.012464 +2.02% +11.88% ==================================================================================================================================================================================== ``` [ghstack-poisoned]
Benchmarks for a full reduction + reduction on the contiguous dimension. Vectorized loads do not occur on the non contiguous dimension: ``` Tensor Shape Operation Base Full reduce (ms) Base Contiguous dim (ms) V Full reduce (ms) V Contiguous dim (ms) Full reduce diff % Contiguous diff % ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (256, 256) mean 0.022750 0.008646 0.016023 0.008252 +41.98% +4.77% (256, 256) sum 0.023141 0.008664 0.015670 0.008269 +47.68% +4.78% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 512) mean 0.014090 0.009603 0.013006 0.008669 +8.33% +10.77% (512, 512) sum 0.014084 0.009622 0.013015 0.008672 +8.21% +10.95% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 1024) mean 0.014368 0.012647 0.013895 0.010347 +3.40% +22.23% (1024, 1024) sum 0.014700 0.012651 0.014261 0.010368 +3.08% +22.02% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 2048) mean 0.018587 0.017994 0.019311 0.015388 -3.75% +16.94% (2048, 2048) sum 0.018590 0.017995 0.018913 0.015064 -1.71% +19.46% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 4096) mean 0.034463 0.037302 0.034645 0.032360 -0.53% +15.27% (4096, 4096) sum 0.034163 0.037045 0.034679 0.031992 -1.49% +15.79% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 8192) mean 0.087827 0.095071 0.087534 0.087104 +0.33% +9.15% (8192, 8192) sum 0.088087 0.094829 0.087511 0.086730 +0.66% +9.34% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 16384) mean 0.148125 0.149172 0.150266 0.138626 -1.42% +7.61% (8192, 16384) sum 0.147821 0.149253 0.150140 0.138618 -1.54% +7.67% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 32768) mean 0.266410 0.261032 0.274723 0.271464 -3.03% -3.84% (8192, 32768) sum 0.266214 0.260893 0.275037 0.271687 -3.21% -3.97% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 65536) mean 0.501475 0.485934 0.520243 0.535353 -3.61% -9.23% (8192, 65536) sum 0.501845 0.486466 0.520737 0.534895 -3.63% -9.05% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 131072) mean 0.970737 0.942428 1.012732 1.013035 -4.15% -6.97% (8192, 131072) sum 0.970730 0.942806 1.012648 1.012854 -4.14% -6.92% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 262144) mean 1.954044 1.877875 2.019751 1.944285 -3.25% -3.42% (8192, 262144) sum 1.954403 1.877363 2.020125 1.945457 -3.25% -3.50% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 262144) mean 0.971163 0.941658 1.013148 0.977148 -4.14% -3.63% (4096, 262144) sum 0.971205 0.941575 1.013127 0.977082 -4.14% -3.63% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 262144) mean 0.501913 0.487320 0.520714 0.500667 -3.61% -2.67% (2048, 262144) sum 0.501453 0.486805 0.520148 0.500693 -3.59% -2.77% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 262144) mean 0.266118 0.257603 0.274991 0.261859 -3.23% -1.63% (1024, 262144) sum 0.266247 0.257172 0.275075 0.261925 -3.21% -1.81% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 131072) mean 0.087805 0.091222 0.087552 0.090097 +0.29% +1.25% (512, 131072) sum 0.087862 0.091237 0.087519 0.090119 +0.39% +1.24% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1000, 1000) mean 0.014582 0.012707 0.014285 0.010758 +2.08% +18.12% (1000, 1000) sum 0.014568 0.012711 0.014289 0.010754 +1.95% +18.20% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 129) mean 0.014099 0.008741 0.013312 0.008671 +5.91% +0.81% (1024, 129) sum 0.013774 0.008410 0.012965 0.008663 +6.24% -2.92% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 257) mean 0.014312 0.009436 0.013504 0.008753 +5.98% +7.80% (1024, 257) sum 0.013998 0.009111 0.013503 0.008770 +3.67% +3.89% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 587) mean 0.014163 0.011363 0.013690 0.010475 +3.46% +8.48% (1024, 587) sum 0.014172 0.011377 0.013356 0.010476 +6.11% +8.60% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 977) mean 0.015947 0.013946 0.015327 0.012085 +4.05% +15.40% (2048, 977) sum 0.015631 0.013945 0.015322 0.012464 +2.02% +11.88% ==================================================================================================================================================================================== ``` [ghstack-poisoned]
Benchmarks for a full reduction + reduction on the contiguous dimension. Vectorized loads do not occur on the non contiguous dimension: ``` Tensor Shape Operation Base Full reduce (ms) Base Contiguous dim (ms) V Full reduce (ms) V Contiguous dim (ms) Full reduce diff % Contiguous diff % ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (256, 256) mean 0.022750 0.008646 0.016023 0.008252 +41.98% +4.77% (256, 256) sum 0.023141 0.008664 0.015670 0.008269 +47.68% +4.78% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 512) mean 0.014090 0.009603 0.013006 0.008669 +8.33% +10.77% (512, 512) sum 0.014084 0.009622 0.013015 0.008672 +8.21% +10.95% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 1024) mean 0.014368 0.012647 0.013895 0.010347 +3.40% +22.23% (1024, 1024) sum 0.014700 0.012651 0.014261 0.010368 +3.08% +22.02% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 2048) mean 0.018587 0.017994 0.019311 0.015388 -3.75% +16.94% (2048, 2048) sum 0.018590 0.017995 0.018913 0.015064 -1.71% +19.46% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 4096) mean 0.034463 0.037302 0.034645 0.032360 -0.53% +15.27% (4096, 4096) sum 0.034163 0.037045 0.034679 0.031992 -1.49% +15.79% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 8192) mean 0.087827 0.095071 0.087534 0.087104 +0.33% +9.15% (8192, 8192) sum 0.088087 0.094829 0.087511 0.086730 +0.66% +9.34% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 16384) mean 0.148125 0.149172 0.150266 0.138626 -1.42% +7.61% (8192, 16384) sum 0.147821 0.149253 0.150140 0.138618 -1.54% +7.67% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 32768) mean 0.266410 0.261032 0.274723 0.271464 -3.03% -3.84% (8192, 32768) sum 0.266214 0.260893 0.275037 0.271687 -3.21% -3.97% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 65536) mean 0.501475 0.485934 0.520243 0.535353 -3.61% -9.23% (8192, 65536) sum 0.501845 0.486466 0.520737 0.534895 -3.63% -9.05% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 131072) mean 0.970737 0.942428 1.012732 1.013035 -4.15% -6.97% (8192, 131072) sum 0.970730 0.942806 1.012648 1.012854 -4.14% -6.92% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 262144) mean 1.954044 1.877875 2.019751 1.944285 -3.25% -3.42% (8192, 262144) sum 1.954403 1.877363 2.020125 1.945457 -3.25% -3.50% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 262144) mean 0.971163 0.941658 1.013148 0.977148 -4.14% -3.63% (4096, 262144) sum 0.971205 0.941575 1.013127 0.977082 -4.14% -3.63% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 262144) mean 0.501913 0.487320 0.520714 0.500667 -3.61% -2.67% (2048, 262144) sum 0.501453 0.486805 0.520148 0.500693 -3.59% -2.77% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 262144) mean 0.266118 0.257603 0.274991 0.261859 -3.23% -1.63% (1024, 262144) sum 0.266247 0.257172 0.275075 0.261925 -3.21% -1.81% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 131072) mean 0.087805 0.091222 0.087552 0.090097 +0.29% +1.25% (512, 131072) sum 0.087862 0.091237 0.087519 0.090119 +0.39% +1.24% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1000, 1000) mean 0.014582 0.012707 0.014285 0.010758 +2.08% +18.12% (1000, 1000) sum 0.014568 0.012711 0.014289 0.010754 +1.95% +18.20% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 129) mean 0.014099 0.008741 0.013312 0.008671 +5.91% +0.81% (1024, 129) sum 0.013774 0.008410 0.012965 0.008663 +6.24% -2.92% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 257) mean 0.014312 0.009436 0.013504 0.008753 +5.98% +7.80% (1024, 257) sum 0.013998 0.009111 0.013503 0.008770 +3.67% +3.89% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 587) mean 0.014163 0.011363 0.013690 0.010475 +3.46% +8.48% (1024, 587) sum 0.014172 0.011377 0.013356 0.010476 +6.11% +8.60% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 977) mean 0.015947 0.013946 0.015327 0.012085 +4.05% +15.40% (2048, 977) sum 0.015631 0.013945 0.015322 0.012464 +2.02% +11.88% ==================================================================================================================================================================================== ``` [ghstack-poisoned]
Benchmarks for a full reduction + reduction on the contiguous dimension. Vectorized loads do not occur on the non contiguous dimension: ``` Tensor Shape Operation Base Full reduce (ms) Base Contiguous dim (ms) V Full reduce (ms) V Contiguous dim (ms) Full reduce diff % Contiguous diff % ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (256, 256) mean 0.022750 0.008646 0.016023 0.008252 +41.98% +4.77% (256, 256) sum 0.023141 0.008664 0.015670 0.008269 +47.68% +4.78% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 512) mean 0.014090 0.009603 0.013006 0.008669 +8.33% +10.77% (512, 512) sum 0.014084 0.009622 0.013015 0.008672 +8.21% +10.95% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 1024) mean 0.014368 0.012647 0.013895 0.010347 +3.40% +22.23% (1024, 1024) sum 0.014700 0.012651 0.014261 0.010368 +3.08% +22.02% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 2048) mean 0.018587 0.017994 0.019311 0.015388 -3.75% +16.94% (2048, 2048) sum 0.018590 0.017995 0.018913 0.015064 -1.71% +19.46% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 4096) mean 0.034463 0.037302 0.034645 0.032360 -0.53% +15.27% (4096, 4096) sum 0.034163 0.037045 0.034679 0.031992 -1.49% +15.79% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 8192) mean 0.087827 0.095071 0.087534 0.087104 +0.33% +9.15% (8192, 8192) sum 0.088087 0.094829 0.087511 0.086730 +0.66% +9.34% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 16384) mean 0.148125 0.149172 0.150266 0.138626 -1.42% +7.61% (8192, 16384) sum 0.147821 0.149253 0.150140 0.138618 -1.54% +7.67% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 32768) mean 0.266410 0.261032 0.274723 0.271464 -3.03% -3.84% (8192, 32768) sum 0.266214 0.260893 0.275037 0.271687 -3.21% -3.97% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 65536) mean 0.501475 0.485934 0.520243 0.535353 -3.61% -9.23% (8192, 65536) sum 0.501845 0.486466 0.520737 0.534895 -3.63% -9.05% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 131072) mean 0.970737 0.942428 1.012732 1.013035 -4.15% -6.97% (8192, 131072) sum 0.970730 0.942806 1.012648 1.012854 -4.14% -6.92% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 262144) mean 1.954044 1.877875 2.019751 1.944285 -3.25% -3.42% (8192, 262144) sum 1.954403 1.877363 2.020125 1.945457 -3.25% -3.50% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 262144) mean 0.971163 0.941658 1.013148 0.977148 -4.14% -3.63% (4096, 262144) sum 0.971205 0.941575 1.013127 0.977082 -4.14% -3.63% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 262144) mean 0.501913 0.487320 0.520714 0.500667 -3.61% -2.67% (2048, 262144) sum 0.501453 0.486805 0.520148 0.500693 -3.59% -2.77% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 262144) mean 0.266118 0.257603 0.274991 0.261859 -3.23% -1.63% (1024, 262144) sum 0.266247 0.257172 0.275075 0.261925 -3.21% -1.81% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 131072) mean 0.087805 0.091222 0.087552 0.090097 +0.29% +1.25% (512, 131072) sum 0.087862 0.091237 0.087519 0.090119 +0.39% +1.24% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1000, 1000) mean 0.014582 0.012707 0.014285 0.010758 +2.08% +18.12% (1000, 1000) sum 0.014568 0.012711 0.014289 0.010754 +1.95% +18.20% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 129) mean 0.014099 0.008741 0.013312 0.008671 +5.91% +0.81% (1024, 129) sum 0.013774 0.008410 0.012965 0.008663 +6.24% -2.92% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 257) mean 0.014312 0.009436 0.013504 0.008753 +5.98% +7.80% (1024, 257) sum 0.013998 0.009111 0.013503 0.008770 +3.67% +3.89% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 587) mean 0.014163 0.011363 0.013690 0.010475 +3.46% +8.48% (1024, 587) sum 0.014172 0.011377 0.013356 0.010476 +6.11% +8.60% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 977) mean 0.015947 0.013946 0.015327 0.012085 +4.05% +15.40% (2048, 977) sum 0.015631 0.013945 0.015322 0.012464 +2.02% +11.88% ==================================================================================================================================================================================== ``` [ghstack-poisoned]
Benchmarks for a full reduction + reduction on the contiguous dimension. Vectorized loads do not occur on the non contiguous dimension: ``` Tensor Shape Operation Base Full reduce (ms) Base Contiguous dim (ms) V Full reduce (ms) V Contiguous dim (ms) Full reduce diff % Contiguous diff % ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (256, 256) mean 0.022750 0.008646 0.016023 0.008252 +41.98% +4.77% (256, 256) sum 0.023141 0.008664 0.015670 0.008269 +47.68% +4.78% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 512) mean 0.014090 0.009603 0.013006 0.008669 +8.33% +10.77% (512, 512) sum 0.014084 0.009622 0.013015 0.008672 +8.21% +10.95% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 1024) mean 0.014368 0.012647 0.013895 0.010347 +3.40% +22.23% (1024, 1024) sum 0.014700 0.012651 0.014261 0.010368 +3.08% +22.02% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 2048) mean 0.018587 0.017994 0.019311 0.015388 -3.75% +16.94% (2048, 2048) sum 0.018590 0.017995 0.018913 0.015064 -1.71% +19.46% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 4096) mean 0.034463 0.037302 0.034645 0.032360 -0.53% +15.27% (4096, 4096) sum 0.034163 0.037045 0.034679 0.031992 -1.49% +15.79% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 8192) mean 0.087827 0.095071 0.087534 0.087104 +0.33% +9.15% (8192, 8192) sum 0.088087 0.094829 0.087511 0.086730 +0.66% +9.34% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 16384) mean 0.148125 0.149172 0.150266 0.138626 -1.42% +7.61% (8192, 16384) sum 0.147821 0.149253 0.150140 0.138618 -1.54% +7.67% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 32768) mean 0.266410 0.261032 0.274723 0.271464 -3.03% -3.84% (8192, 32768) sum 0.266214 0.260893 0.275037 0.271687 -3.21% -3.97% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 65536) mean 0.501475 0.485934 0.520243 0.535353 -3.61% -9.23% (8192, 65536) sum 0.501845 0.486466 0.520737 0.534895 -3.63% -9.05% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 131072) mean 0.970737 0.942428 1.012732 1.013035 -4.15% -6.97% (8192, 131072) sum 0.970730 0.942806 1.012648 1.012854 -4.14% -6.92% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 262144) mean 1.954044 1.877875 2.019751 1.944285 -3.25% -3.42% (8192, 262144) sum 1.954403 1.877363 2.020125 1.945457 -3.25% -3.50% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 262144) mean 0.971163 0.941658 1.013148 0.977148 -4.14% -3.63% (4096, 262144) sum 0.971205 0.941575 1.013127 0.977082 -4.14% -3.63% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 262144) mean 0.501913 0.487320 0.520714 0.500667 -3.61% -2.67% (2048, 262144) sum 0.501453 0.486805 0.520148 0.500693 -3.59% -2.77% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 262144) mean 0.266118 0.257603 0.274991 0.261859 -3.23% -1.63% (1024, 262144) sum 0.266247 0.257172 0.275075 0.261925 -3.21% -1.81% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 131072) mean 0.087805 0.091222 0.087552 0.090097 +0.29% +1.25% (512, 131072) sum 0.087862 0.091237 0.087519 0.090119 +0.39% +1.24% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1000, 1000) mean 0.014582 0.012707 0.014285 0.010758 +2.08% +18.12% (1000, 1000) sum 0.014568 0.012711 0.014289 0.010754 +1.95% +18.20% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 129) mean 0.014099 0.008741 0.013312 0.008671 +5.91% +0.81% (1024, 129) sum 0.013774 0.008410 0.012965 0.008663 +6.24% -2.92% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 257) mean 0.014312 0.009436 0.013504 0.008753 +5.98% +7.80% (1024, 257) sum 0.013998 0.009111 0.013503 0.008770 +3.67% +3.89% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 587) mean 0.014163 0.011363 0.013690 0.010475 +3.46% +8.48% (1024, 587) sum 0.014172 0.011377 0.013356 0.010476 +6.11% +8.60% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 977) mean 0.015947 0.013946 0.015327 0.012085 +4.05% +15.40% (2048, 977) sum 0.015631 0.013945 0.015322 0.012464 +2.02% +11.88% ==================================================================================================================================================================================== ``` [ghstack-poisoned]
Benchmarks for a full reduction + reduction on the contiguous dimension. Vectorized loads do not occur on the non contiguous dimension: ``` Tensor Shape Operation Base Full reduce (ms) Base Contiguous dim (ms) V Full reduce (ms) V Contiguous dim (ms) Full reduce diff % Contiguous diff % ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (256, 256) mean 0.022750 0.008646 0.016023 0.008252 +41.98% +4.77% (256, 256) sum 0.023141 0.008664 0.015670 0.008269 +47.68% +4.78% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 512) mean 0.014090 0.009603 0.013006 0.008669 +8.33% +10.77% (512, 512) sum 0.014084 0.009622 0.013015 0.008672 +8.21% +10.95% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 1024) mean 0.014368 0.012647 0.013895 0.010347 +3.40% +22.23% (1024, 1024) sum 0.014700 0.012651 0.014261 0.010368 +3.08% +22.02% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 2048) mean 0.018587 0.017994 0.019311 0.015388 -3.75% +16.94% (2048, 2048) sum 0.018590 0.017995 0.018913 0.015064 -1.71% +19.46% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 4096) mean 0.034463 0.037302 0.034645 0.032360 -0.53% +15.27% (4096, 4096) sum 0.034163 0.037045 0.034679 0.031992 -1.49% +15.79% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 8192) mean 0.087827 0.095071 0.087534 0.087104 +0.33% +9.15% (8192, 8192) sum 0.088087 0.094829 0.087511 0.086730 +0.66% +9.34% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 16384) mean 0.148125 0.149172 0.150266 0.138626 -1.42% +7.61% (8192, 16384) sum 0.147821 0.149253 0.150140 0.138618 -1.54% +7.67% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 32768) mean 0.266410 0.261032 0.274723 0.271464 -3.03% -3.84% (8192, 32768) sum 0.266214 0.260893 0.275037 0.271687 -3.21% -3.97% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 65536) mean 0.501475 0.485934 0.520243 0.535353 -3.61% -9.23% (8192, 65536) sum 0.501845 0.486466 0.520737 0.534895 -3.63% -9.05% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 131072) mean 0.970737 0.942428 1.012732 1.013035 -4.15% -6.97% (8192, 131072) sum 0.970730 0.942806 1.012648 1.012854 -4.14% -6.92% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 262144) mean 1.954044 1.877875 2.019751 1.944285 -3.25% -3.42% (8192, 262144) sum 1.954403 1.877363 2.020125 1.945457 -3.25% -3.50% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 262144) mean 0.971163 0.941658 1.013148 0.977148 -4.14% -3.63% (4096, 262144) sum 0.971205 0.941575 1.013127 0.977082 -4.14% -3.63% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 262144) mean 0.501913 0.487320 0.520714 0.500667 -3.61% -2.67% (2048, 262144) sum 0.501453 0.486805 0.520148 0.500693 -3.59% -2.77% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 262144) mean 0.266118 0.257603 0.274991 0.261859 -3.23% -1.63% (1024, 262144) sum 0.266247 0.257172 0.275075 0.261925 -3.21% -1.81% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 131072) mean 0.087805 0.091222 0.087552 0.090097 +0.29% +1.25% (512, 131072) sum 0.087862 0.091237 0.087519 0.090119 +0.39% +1.24% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1000, 1000) mean 0.014582 0.012707 0.014285 0.010758 +2.08% +18.12% (1000, 1000) sum 0.014568 0.012711 0.014289 0.010754 +1.95% +18.20% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 129) mean 0.014099 0.008741 0.013312 0.008671 +5.91% +0.81% (1024, 129) sum 0.013774 0.008410 0.012965 0.008663 +6.24% -2.92% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 257) mean 0.014312 0.009436 0.013504 0.008753 +5.98% +7.80% (1024, 257) sum 0.013998 0.009111 0.013503 0.008770 +3.67% +3.89% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 587) mean 0.014163 0.011363 0.013690 0.010475 +3.46% +8.48% (1024, 587) sum 0.014172 0.011377 0.013356 0.010476 +6.11% +8.60% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 977) mean 0.015947 0.013946 0.015327 0.012085 +4.05% +15.40% (2048, 977) sum 0.015631 0.013945 0.015322 0.012464 +2.02% +11.88% ==================================================================================================================================================================================== ``` [ghstack-poisoned]
Benchmarks for a full reduction + reduction on the contiguous dimension. Vectorized loads do not occur on the non contiguous dimension: ``` Tensor Shape Operation Base Full reduce (ms) Base Contiguous dim (ms) V Full reduce (ms) V Contiguous dim (ms) Full reduce diff % Contiguous diff % ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (256, 256) mean 0.022750 0.008646 0.016023 0.008252 +41.98% +4.77% (256, 256) sum 0.023141 0.008664 0.015670 0.008269 +47.68% +4.78% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 512) mean 0.014090 0.009603 0.013006 0.008669 +8.33% +10.77% (512, 512) sum 0.014084 0.009622 0.013015 0.008672 +8.21% +10.95% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 1024) mean 0.014368 0.012647 0.013895 0.010347 +3.40% +22.23% (1024, 1024) sum 0.014700 0.012651 0.014261 0.010368 +3.08% +22.02% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 2048) mean 0.018587 0.017994 0.019311 0.015388 -3.75% +16.94% (2048, 2048) sum 0.018590 0.017995 0.018913 0.015064 -1.71% +19.46% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 4096) mean 0.034463 0.037302 0.034645 0.032360 -0.53% +15.27% (4096, 4096) sum 0.034163 0.037045 0.034679 0.031992 -1.49% +15.79% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 8192) mean 0.087827 0.095071 0.087534 0.087104 +0.33% +9.15% (8192, 8192) sum 0.088087 0.094829 0.087511 0.086730 +0.66% +9.34% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 16384) mean 0.148125 0.149172 0.150266 0.138626 -1.42% +7.61% (8192, 16384) sum 0.147821 0.149253 0.150140 0.138618 -1.54% +7.67% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 32768) mean 0.266410 0.261032 0.274723 0.271464 -3.03% -3.84% (8192, 32768) sum 0.266214 0.260893 0.275037 0.271687 -3.21% -3.97% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 65536) mean 0.501475 0.485934 0.520243 0.535353 -3.61% -9.23% (8192, 65536) sum 0.501845 0.486466 0.520737 0.534895 -3.63% -9.05% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 131072) mean 0.970737 0.942428 1.012732 1.013035 -4.15% -6.97% (8192, 131072) sum 0.970730 0.942806 1.012648 1.012854 -4.14% -6.92% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 262144) mean 1.954044 1.877875 2.019751 1.944285 -3.25% -3.42% (8192, 262144) sum 1.954403 1.877363 2.020125 1.945457 -3.25% -3.50% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 262144) mean 0.971163 0.941658 1.013148 0.977148 -4.14% -3.63% (4096, 262144) sum 0.971205 0.941575 1.013127 0.977082 -4.14% -3.63% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 262144) mean 0.501913 0.487320 0.520714 0.500667 -3.61% -2.67% (2048, 262144) sum 0.501453 0.486805 0.520148 0.500693 -3.59% -2.77% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 262144) mean 0.266118 0.257603 0.274991 0.261859 -3.23% -1.63% (1024, 262144) sum 0.266247 0.257172 0.275075 0.261925 -3.21% -1.81% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 131072) mean 0.087805 0.091222 0.087552 0.090097 +0.29% +1.25% (512, 131072) sum 0.087862 0.091237 0.087519 0.090119 +0.39% +1.24% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1000, 1000) mean 0.014582 0.012707 0.014285 0.010758 +2.08% +18.12% (1000, 1000) sum 0.014568 0.012711 0.014289 0.010754 +1.95% +18.20% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 129) mean 0.014099 0.008741 0.013312 0.008671 +5.91% +0.81% (1024, 129) sum 0.013774 0.008410 0.012965 0.008663 +6.24% -2.92% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 257) mean 0.014312 0.009436 0.013504 0.008753 +5.98% +7.80% (1024, 257) sum 0.013998 0.009111 0.013503 0.008770 +3.67% +3.89% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 587) mean 0.014163 0.011363 0.013690 0.010475 +3.46% +8.48% (1024, 587) sum 0.014172 0.011377 0.013356 0.010476 +6.11% +8.60% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 977) mean 0.015947 0.013946 0.015327 0.012085 +4.05% +15.40% (2048, 977) sum 0.015631 0.013945 0.015322 0.012464 +2.02% +11.88% ==================================================================================================================================================================================== ``` [ghstack-poisoned]
Benchmarks for a full reduction + reduction on the contiguous dimension. Vectorized loads do not occur on the non contiguous dimension: ``` ==================================================================================================================================================================================== PyTorch Reduction Operations Benchmark - Combined Comparison ==================================================================================================================================================================================== Tensor Shape Operation Full reduce (ms) Contiguous dim (ms) Full reduce (ms) Contiguous dim (ms) Full reduce diff % Contiguous diff % ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (256, 256) mean 0.015700 0.008431 0.016060 0.008293 -2.24% +1.66% (256, 256) sum 0.023307 0.008427 0.015687 0.007887 +48.58% +6.85% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 512) mean 0.013473 0.008701 0.013377 0.009102 +0.72% -4.41% (512, 512) sum 0.014294 0.009716 0.013321 0.008644 +7.30% +12.40% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 1024) mean 0.014308 0.010442 0.013879 0.010776 +3.09% -3.10% (1024, 1024) sum 0.014516 0.012448 0.013785 0.010314 +5.30% +20.69% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 2048) mean 0.019392 0.015462 0.018919 0.015022 +2.50% +2.93% (2048, 2048) sum 0.018793 0.018475 0.018607 0.015159 +1.00% +21.87% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 4096) mean 0.034354 0.032064 0.034267 0.032384 +0.25% -0.99% (4096, 4096) sum 0.034703 0.037153 0.034006 0.030732 +2.05% +20.89% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 8192) mean 0.087554 0.086913 0.087550 0.087240 +0.00% -0.37% (8192, 8192) sum 0.088265 0.095772 0.085408 0.084674 +3.35% +13.11% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 16384) mean 0.150714 0.138377 0.150194 0.138775 +0.35% -0.29% (8192, 16384) sum 0.147996 0.149958 0.146178 0.138765 +1.24% +8.07% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 32768) mean 0.274935 0.271579 0.274787 0.271480 +0.05% +0.04% (8192, 32768) sum 0.266700 0.261308 0.266002 0.253478 +0.26% +3.09% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 65536) mean 0.520137 0.535495 0.520826 0.534928 -0.13% +0.11% (8192, 65536) sum 0.501590 0.486255 0.498309 0.481378 +0.66% +1.01% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 131072) mean 1.013195 1.013046 1.012499 1.012806 +0.07% +0.02% (8192, 131072) sum 0.970992 0.943506 0.957651 0.938352 +1.39% +0.55% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 262144) mean 2.020536 1.944477 2.021117 1.944456 -0.03% +0.00% (8192, 262144) sum 1.954167 1.877884 1.904428 1.862436 +2.61% +0.83% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 262144) mean 1.013495 0.977271 1.012324 0.977721 +0.12% -0.05% (4096, 262144) sum 0.970941 0.941685 0.957323 0.936643 +1.42% +0.54% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 262144) mean 0.520851 0.500767 0.520352 0.500979 +0.10% -0.04% (2048, 262144) sum 0.501935 0.487330 0.497851 0.483599 +0.82% +0.77% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 262144) mean 0.275272 0.262046 0.275309 0.261893 -0.01% +0.06% (1024, 262144) sum 0.266447 0.257469 0.265654 0.255563 +0.30% +0.75% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 131072) mean 0.087285 0.090168 0.087124 0.089990 +0.18% +0.20% (512, 131072) sum 0.087961 0.091448 0.085851 0.088071 +2.46% +3.83% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1000, 1000) mean 0.014012 0.010440 0.013918 0.010388 +0.68% +0.50% (1000, 1000) sum 0.014668 0.012879 0.014121 0.010303 +3.87% +25.00% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 129) mean 0.013072 0.008655 0.013373 0.008693 -2.25% -0.44% (1024, 129) sum 0.014308 0.008951 0.013336 0.008592 +7.29% +4.18% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 257) mean 0.013171 0.008889 0.013516 0.009153 -2.55% -2.88% (1024, 257) sum 0.014455 0.009573 0.013460 0.008709 +7.39% +9.92% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 587) mean 0.013385 0.010576 0.013303 0.010525 +0.62% +0.48% (1024, 587) sum 0.014635 0.011530 0.013569 0.010458 +7.86% +10.25% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 977) mean 0.015431 0.012528 0.015751 0.012079 -2.03% +3.72% (2048, 977) sum 0.015732 0.014101 0.015064 0.011896 +4.43% +18.54% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 128) mean 0.013273 0.008437 0.013219 0.007888 +0.41% +6.96% (1024, 128) sum 0.013738 0.007830 0.012793 0.008284 +7.39% -5.48% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 128) mean 0.013930 0.010086 0.013843 0.009677 +0.63% +4.23% (8192, 128) sum 0.014525 0.009780 0.013758 0.009247 +5.57% +5.76% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 130) mean 0.013427 0.009024 0.013372 0.009078 +0.41% -0.59% (1024, 130) sum 0.013951 0.008799 0.013360 0.008560 +4.42% +2.79% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 130) mean 0.014121 0.010474 0.014035 0.009967 +0.61% +5.09% (8192, 130) sum 0.015139 0.010545 0.013958 0.010045 +8.46% +4.98% ==================================================================================================================================================================================== ``` [ghstack-poisoned]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pretty cool!
constexpr bool is_16_bits = | ||
( (std::is_same<at::Half, scalar_t>::value) || | ||
(std::is_same<at::BFloat16, scalar_t>::value) ); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
constexpr bool is_16_bits = | |
( (std::is_same<at::Half, scalar_t>::value) || | |
(std::is_same<at::BFloat16, scalar_t>::value) ); | |
constexpr bool is_16_bits = sizeof(scalar_t) == 16; | |
Benchmarks for a full reduction + reduction on the contiguous dimension. Vectorized loads do not occur on the non contiguous dimension. Benchmarking done for BF16, ~6% improvement on average across shapes: ``` ==================================================================================================================================================================================== PyTorch Reduction Operations Benchmark - Combined Comparison ==================================================================================================================================================================================== Tensor Shape Operation Full reduce (ms) Contiguous dim (ms) Full reduce (ms) Contiguous dim (ms) Full reduce diff % Contiguous diff % ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (256, 256) mean 0.022686 0.008263 0.015498 0.008117 +46.38% +1.80% (256, 256) sum 0.022769 0.008269 0.015628 0.008185 +45.69% +1.03% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 512) mean 0.014116 0.009545 0.012892 0.008839 +9.49% +7.99% (512, 512) sum 0.014110 0.009892 0.012891 0.008878 +9.46% +11.42% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 1024) mean 0.014727 0.012642 0.014061 0.010519 +4.74% +20.18% (1024, 1024) sum 0.014376 0.012636 0.014069 0.010595 +2.18% +19.26% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 2048) mean 0.018663 0.018294 0.018171 0.014678 +2.71% +24.64% (2048, 2048) sum 0.018638 0.017931 0.018142 0.014713 +2.73% +21.87% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 4096) mean 0.034216 0.036953 0.033520 0.030585 +2.08% +20.82% (4096, 4096) sum 0.034196 0.036942 0.033518 0.030676 +2.02% +20.43% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 8192) mean 0.087763 0.095201 0.085439 0.084960 +2.72% +12.05% (8192, 8192) sum 0.088079 0.095592 0.085353 0.084632 +3.19% +12.95% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 16384) mean 0.148174 0.149705 0.146274 0.138865 +1.30% +7.81% (8192, 16384) sum 0.147820 0.149371 0.146419 0.138752 +0.96% +7.65% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 32768) mean 0.266144 0.260807 0.265953 0.253330 +0.07% +2.95% (8192, 32768) sum 0.266572 0.261163 0.265729 0.253294 +0.32% +3.11% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 65536) mean 0.502034 0.486312 0.498417 0.481246 +0.73% +1.05% (8192, 65536) sum 0.501597 0.486351 0.497735 0.481579 +0.78% +0.99% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 131072) mean 0.971178 0.942988 0.957164 0.938316 +1.46% +0.50% (8192, 131072) sum 0.971189 0.943232 0.956814 0.937816 +1.50% +0.58% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 262144) mean 1.953728 1.877648 1.904937 1.861692 +2.56% +0.86% (8192, 262144) sum 1.953969 1.877538 1.905990 1.862547 +2.52% +0.80% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 262144) mean 0.970408 0.940965 0.957871 0.936732 +1.31% +0.45% (4096, 262144) sum 0.970919 0.941652 0.957765 0.936676 +1.37% +0.53% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 262144) mean 0.501477 0.486976 0.497964 0.483570 +0.71% +0.70% (2048, 262144) sum 0.501955 0.487213 0.498210 0.483218 +0.75% +0.83% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 262144) mean 0.266536 0.257111 0.265642 0.255439 +0.34% +0.65% (1024, 262144) sum 0.266613 0.257096 0.265427 0.255472 +0.45% +0.64% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 131072) mean 0.087805 0.091200 0.085818 0.087851 +2.32% +3.81% (512, 131072) sum 0.087788 0.091249 0.085373 0.087944 +2.83% +3.76% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1000, 1000) mean 0.014503 0.012328 0.013663 0.010190 +6.15% +20.98% (1000, 1000) sum 0.014545 0.012378 0.013662 0.010579 +6.46% +17.01% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 129) mean 0.014163 0.008371 0.012893 0.008828 +9.85% -5.18% (1024, 129) sum 0.014132 0.008751 0.013234 0.008868 +6.79% -1.32% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 257) mean 0.014296 0.009101 0.013334 0.008563 +7.21% +6.28% (1024, 257) sum 0.014302 0.009058 0.013020 0.008672 +9.85% +4.45% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 587) mean 0.014127 0.010997 0.013443 0.009944 +5.09% +10.59% (1024, 587) sum 0.014471 0.011373 0.013123 0.010354 +10.27% +9.84% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 977) mean 0.015607 0.013566 0.015089 0.012152 +3.43% +11.64% (2048, 977) sum 0.015953 0.013580 0.015039 0.011861 +6.08% +14.49% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 128) mean 0.013982 0.008058 0.012747 0.008139 +9.69% -1.00% (1024, 128) sum 0.013967 0.008071 0.012726 0.007859 +9.75% +2.70% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 128) mean 0.014378 0.009627 0.013712 0.009395 +4.86% +2.47% (8192, 128) sum 0.014389 0.009965 0.013718 0.009521 +4.89% +4.66% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 130) mean 0.014156 0.008267 0.012895 0.008833 +9.78% -6.41% (1024, 130) sum 0.013797 0.008277 0.012903 0.008512 +6.93% -2.76% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 130) mean 0.014977 0.010026 0.013911 0.009876 +7.66% +1.52% (8192, 130) sum 0.014994 0.010043 0.014235 0.009604 +5.33% +4.57% ==================================================================================================================================================================================== ``` [ghstack-poisoned]
@Skylion007 MSCV complains about the extended lambda definition in the if constexpr. What is the recommended workaround here? |
…ean" Benchmarks for a full reduction + reduction on the contiguous dimension. Vectorized loads do not occur on the non contiguous dimension. Benchmarking done for FP16/BF16, ~6% improvement on average across shapes, up to ~24% for single reduction on contiguous dimension and 46% for full reduce: **BF16** ``` Tensor Shape Operation Full reduce (ms) Contiguous dim (ms) Full reduce (ms) Contiguous dim (ms) Full reduce diff % Contiguous diff % ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (256, 256) mean 0.022686 0.008263 0.015498 0.008117 +46.38% +1.80% (256, 256) sum 0.022769 0.008269 0.015628 0.008185 +45.69% +1.03% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 512) mean 0.014116 0.009545 0.012892 0.008839 +9.49% +7.99% (512, 512) sum 0.014110 0.009892 0.012891 0.008878 +9.46% +11.42% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 1024) mean 0.014727 0.012642 0.014061 0.010519 +4.74% +20.18% (1024, 1024) sum 0.014376 0.012636 0.014069 0.010595 +2.18% +19.26% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 2048) mean 0.018663 0.018294 0.018171 0.014678 +2.71% +24.64% (2048, 2048) sum 0.018638 0.017931 0.018142 0.014713 +2.73% +21.87% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 4096) mean 0.034216 0.036953 0.033520 0.030585 +2.08% +20.82% (4096, 4096) sum 0.034196 0.036942 0.033518 0.030676 +2.02% +20.43% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 8192) mean 0.087763 0.095201 0.085439 0.084960 +2.72% +12.05% (8192, 8192) sum 0.088079 0.095592 0.085353 0.084632 +3.19% +12.95% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 16384) mean 0.148174 0.149705 0.146274 0.138865 +1.30% +7.81% (8192, 16384) sum 0.147820 0.149371 0.146419 0.138752 +0.96% +7.65% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 32768) mean 0.266144 0.260807 0.265953 0.253330 +0.07% +2.95% (8192, 32768) sum 0.266572 0.261163 0.265729 0.253294 +0.32% +3.11% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 65536) mean 0.502034 0.486312 0.498417 0.481246 +0.73% +1.05% (8192, 65536) sum 0.501597 0.486351 0.497735 0.481579 +0.78% +0.99% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 131072) mean 0.971178 0.942988 0.957164 0.938316 +1.46% +0.50% (8192, 131072) sum 0.971189 0.943232 0.956814 0.937816 +1.50% +0.58% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 262144) mean 1.953728 1.877648 1.904937 1.861692 +2.56% +0.86% (8192, 262144) sum 1.953969 1.877538 1.905990 1.862547 +2.52% +0.80% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 262144) mean 0.970408 0.940965 0.957871 0.936732 +1.31% +0.45% (4096, 262144) sum 0.970919 0.941652 0.957765 0.936676 +1.37% +0.53% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 262144) mean 0.501477 0.486976 0.497964 0.483570 +0.71% +0.70% (2048, 262144) sum 0.501955 0.487213 0.498210 0.483218 +0.75% +0.83% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 262144) mean 0.266536 0.257111 0.265642 0.255439 +0.34% +0.65% (1024, 262144) sum 0.266613 0.257096 0.265427 0.255472 +0.45% +0.64% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 131072) mean 0.087805 0.091200 0.085818 0.087851 +2.32% +3.81% (512, 131072) sum 0.087788 0.091249 0.085373 0.087944 +2.83% +3.76% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1000, 1000) mean 0.014503 0.012328 0.013663 0.010190 +6.15% +20.98% (1000, 1000) sum 0.014545 0.012378 0.013662 0.010579 +6.46% +17.01% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 129) mean 0.014163 0.008371 0.012893 0.008828 +9.85% -5.18% (1024, 129) sum 0.014132 0.008751 0.013234 0.008868 +6.79% -1.32% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 257) mean 0.014296 0.009101 0.013334 0.008563 +7.21% +6.28% (1024, 257) sum 0.014302 0.009058 0.013020 0.008672 +9.85% +4.45% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 587) mean 0.014127 0.010997 0.013443 0.009944 +5.09% +10.59% (1024, 587) sum 0.014471 0.011373 0.013123 0.010354 +10.27% +9.84% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 977) mean 0.015607 0.013566 0.015089 0.012152 +3.43% +11.64% (2048, 977) sum 0.015953 0.013580 0.015039 0.011861 +6.08% +14.49% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 128) mean 0.013982 0.008058 0.012747 0.008139 +9.69% -1.00% (1024, 128) sum 0.013967 0.008071 0.012726 0.007859 +9.75% +2.70% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 128) mean 0.014378 0.009627 0.013712 0.009395 +4.86% +2.47% (8192, 128) sum 0.014389 0.009965 0.013718 0.009521 +4.89% +4.66% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 130) mean 0.014156 0.008267 0.012895 0.008833 +9.78% -6.41% (1024, 130) sum 0.013797 0.008277 0.012903 0.008512 +6.93% -2.76% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 130) mean 0.014977 0.010026 0.013911 0.009876 +7.66% +1.52% (8192, 130) sum 0.014994 0.010043 0.014235 0.009604 +5.33% +4.57% ==================================================================================================================================================================================== ``` **FP16** ``` Tensor Shape Operation Full reduce (ms) Contiguous dim (ms) Full reduce (ms) Contiguous dim (ms) Full reduce diff % Contiguous diff % ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (256, 256) mean 0.022804 0.008298 0.015888 0.007848 +43.53% +5.73% (256, 256) sum 0.023215 0.008328 0.015677 0.007850 +48.08% +6.09% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 512) mean 0.013777 0.009988 0.012884 0.008512 +6.93% +17.34% (512, 512) sum 0.013775 0.009622 0.012870 0.009028 +7.03% +6.58% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 1024) mean 0.014740 0.012322 0.013708 0.010239 +7.53% +20.34% (1024, 1024) sum 0.014762 0.012756 0.013722 0.010307 +7.58% +23.76% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 2048) mean 0.018700 0.018364 0.018135 0.015078 +3.12% +21.79% (2048, 2048) sum 0.018276 0.018415 0.018471 0.015127 -1.06% +21.74% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 4096) mean 0.034518 0.037000 0.033838 0.030617 +2.01% +20.85% (4096, 4096) sum 0.034569 0.037448 0.033842 0.031100 +2.15% +20.41% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 8192) mean 0.087675 0.095176 0.085328 0.084105 +2.75% +13.16% (8192, 8192) sum 0.088102 0.095211 0.085707 0.084090 +2.79% +13.23% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 16384) mean 0.147800 0.149263 0.146388 0.138390 +0.96% +7.86% (8192, 16384) sum 0.148147 0.148957 0.146439 0.138801 +1.17% +7.32% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 32768) mean 0.266316 0.260294 0.265829 0.253411 +0.18% +2.72% (8192, 32768) sum 0.266562 0.260717 0.265744 0.253308 +0.31% +2.92% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 65536) mean 0.502035 0.486077 0.498139 0.481374 +0.78% +0.98% (8192, 65536) sum 0.501571 0.485733 0.498353 0.481350 +0.65% +0.91% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 131072) mean 0.971343 0.943016 0.956600 0.938622 +1.54% +0.47% (8192, 131072) sum 0.971463 0.942991 0.957352 0.938334 +1.47% +0.50% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 262144) mean 1.952722 1.877165 1.906406 1.861455 +2.43% +0.84% (8192, 262144) sum 1.952634 1.876388 1.904677 1.861282 +2.52% +0.81% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 262144) mean 0.970697 0.941298 0.956964 0.936160 +1.44% +0.55% (4096, 262144) sum 0.969981 0.941078 0.957016 0.936260 +1.35% +0.51% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 262144) mean 0.501577 0.487208 0.498422 0.483493 +0.63% +0.77% (2048, 262144) sum 0.502029 0.487124 0.497854 0.483643 +0.84% +0.72% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 262144) mean 0.266416 0.257383 0.265928 0.255140 +0.18% +0.88% (1024, 262144) sum 0.266434 0.257081 0.265817 0.255143 +0.23% +0.76% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 131072) mean 0.087858 0.091296 0.085816 0.087745 +2.38% +4.05% (512, 131072) sum 0.088144 0.091314 0.085664 0.087864 +2.90% +3.93% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1000, 1000) mean 0.014977 0.012393 0.014141 0.010614 +5.91% +16.76% (1000, 1000) sum 0.014589 0.012804 0.014118 0.010320 +3.34% +24.07% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 129) mean 0.014208 0.008383 0.013273 0.008440 +7.04% -0.68% (1024, 129) sum 0.013804 0.008863 0.013265 0.009003 +4.06% -1.56% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 257) mean 0.014378 0.009109 0.013037 0.009038 +10.29% +0.79% (1024, 257) sum 0.014387 0.009113 0.013396 0.008698 +7.40% +4.77% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 587) mean 0.014207 0.011037 0.013182 0.010391 +7.78% +6.22% (1024, 587) sum 0.014588 0.011453 0.013539 0.010049 +7.75% +13.97% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 977) mean 0.016024 0.013614 0.015448 0.011845 +3.73% +14.93% (2048, 977) sum 0.015990 0.014033 0.015406 0.012278 +3.79% +14.29% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 128) mean 0.014037 0.007804 0.013143 0.008242 +6.80% -5.31% (1024, 128) sum 0.014041 0.007847 0.012759 0.007850 +10.05% -0.04% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 128) mean 0.014361 0.009644 0.014075 0.009061 +2.03% +6.43% (8192, 128) sum 0.014366 0.010032 0.013702 0.009181 +4.85% +9.27% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 130) mean 0.014226 0.008696 0.012894 0.008835 +10.33% -1.57% (1024, 130) sum 0.013830 0.008740 0.013288 0.008989 +4.08% -2.77% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 130) mean 0.015036 0.010019 0.013917 0.009538 +8.04% +5.04% (8192, 130) sum 0.014652 0.010403 0.013900 0.009565 +5.41% +8.76% ==================================================================================================================================================================================== ``` [ghstack-poisoned]
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Performance benchmarking, perf neutral: ``` ================================================================================================================================================================================================================================================ Tensor Shape Operation Full reduce (ms) Non-Contig dim (ms) Contig dim (ms) Full reduce (ms) Non-Contig dim (ms) Contig dim (ms) Full diff % Non-Contig diff % Contig diff % ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (256, 256) mean 0.015684 0.017056 0.008287 0.016015 0.016929 0.008170 -2.07% +0.75% +1.43% (256, 256) sum 0.015774 0.016638 0.007926 0.015811 0.016935 0.008330 -0.23% -1.75% -4.85% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 512) mean 0.013385 0.025742 0.008629 0.013046 0.026005 0.008924 +2.60% -1.01% -3.31% (512, 512) sum 0.013390 0.026059 0.009116 0.013054 0.025696 0.008952 +2.57% +1.41% +1.83% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 1024) mean 0.014213 0.015467 0.010334 0.013862 0.015082 0.010318 +2.53% +2.55% +0.16% (1024, 1024) sum 0.014179 0.015446 0.010774 0.014132 0.015073 0.010350 +0.33% +2.47% +4.10% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 2048) mean 0.018234 0.019487 0.014812 0.018482 0.019397 0.014802 -1.34% +0.46% +0.07% (2048, 2048) sum 0.018202 0.019529 0.015195 0.018122 0.019485 0.015129 +0.44% +0.23% +0.44% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 4096) mean 0.033582 0.039378 0.030751 0.033810 0.039673 0.031019 -0.67% -0.74% -0.86% (4096, 4096) sum 0.033604 0.039777 0.030809 0.033530 0.039386 0.031113 +0.22% +0.99% -0.98% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 8192) mean 0.085824 0.091133 0.084200 0.085431 0.091364 0.084303 +0.46% -0.25% -0.12% (8192, 8192) sum 0.085763 0.091442 0.084180 0.085508 0.091419 0.084595 +0.30% +0.03% -0.49% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 16384) mean 0.146480 0.147666 0.138807 0.146515 0.147987 0.138930 -0.02% -0.22% -0.09% (8192, 16384) sum 0.146446 0.147593 0.138559 0.146151 0.147982 0.139120 +0.20% -0.26% -0.40% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 32768) mean 0.266047 0.265386 0.253837 0.265648 0.265885 0.253652 +0.15% -0.19% +0.07% (8192, 32768) sum 0.266093 0.265421 0.253890 0.265458 0.265591 0.253567 +0.24% -0.06% +0.13% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 65536) mean 0.498632 0.508976 0.481865 0.498237 0.508777 0.481476 +0.08% +0.04% +0.08% (8192, 65536) sum 0.498917 0.508202 0.481883 0.498104 0.508016 0.481972 +0.16% +0.04% -0.02% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 131072) mean 0.957633 0.968519 0.938172 0.956766 0.968267 0.938196 +0.09% +0.03% -0.00% (8192, 131072) sum 0.956972 0.968140 0.937741 0.957365 0.968404 0.938056 -0.04% -0.03% -0.03% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 262144) mean 1.906661 1.928377 1.861846 1.907327 1.928811 1.862083 -0.03% -0.02% -0.01% (8192, 262144) sum 1.905976 1.928362 1.862399 1.907098 1.928844 1.861782 -0.06% -0.02% +0.03% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 262144) mean 0.956852 0.970101 0.936524 0.957263 0.969809 0.936965 -0.04% +0.03% -0.05% (4096, 262144) sum 0.957117 0.969933 0.936247 0.956675 0.969451 0.936395 +0.05% +0.05% -0.02% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 262144) mean 0.498813 0.511299 0.483415 0.498567 0.511482 0.483376 +0.05% -0.04% +0.01% (2048, 262144) sum 0.498813 0.510834 0.483641 0.498875 0.511036 0.483338 -0.01% -0.04% +0.06% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 262144) mean 0.266157 0.276751 0.255192 0.265966 0.276808 0.255544 +0.07% -0.02% -0.14% (1024, 262144) sum 0.266133 0.276709 0.255528 0.265658 0.276685 0.255287 +0.18% +0.01% +0.09% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 131072) mean 0.085941 0.081184 0.087931 0.085591 0.080832 0.088008 +0.41% +0.44% -0.09% (512, 131072) sum 0.085962 0.081107 0.088045 0.085882 0.081160 0.088024 +0.09% -0.07% +0.02% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1000, 1000) mean 0.014203 0.045859 0.010310 0.013885 0.046132 0.010621 +2.29% -0.59% -2.93% (1000, 1000) sum 0.014180 0.046165 0.010756 0.013893 0.046109 0.010338 +2.07% +0.12% +4.04% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 129) mean 0.012953 0.016751 0.008536 0.012977 0.016714 0.008916 -0.18% +0.22% -4.26% (1024, 129) sum 0.013356 0.016806 0.008722 0.013003 0.017071 0.008611 +2.71% -1.55% +1.29% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 257) mean 0.013075 0.016787 0.009102 0.013116 0.016769 0.008679 -0.31% +0.11% +4.87% (1024, 257) sum 0.013092 0.016842 0.008786 0.013126 0.017128 0.008771 -0.26% -1.67% +0.17% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 587) mean 0.013662 0.017412 0.010055 0.013659 0.017019 0.010033 +0.02% +2.31% +0.22% (1024, 587) sum 0.013636 0.017473 0.010163 0.013642 0.017363 0.010101 -0.04% +0.63% +0.61% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 977) mean 0.015276 0.027873 0.012531 0.015241 0.027783 0.012467 +0.23% +0.32% +0.51% (2048, 977) sum 0.015345 0.027949 0.012192 0.015255 0.027839 0.012485 +0.59% +0.40% -2.35% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 128) mean 0.012806 0.014020 0.008291 0.013137 0.014309 0.007908 -2.52% -2.02% +4.84% (1024, 128) sum 0.012769 0.014308 0.007924 0.012788 0.014236 0.008038 -0.15% +0.51% -1.42% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 128) mean 0.014145 0.023049 0.009143 0.014104 0.023298 0.009501 +0.29% -1.07% -3.77% (8192, 128) sum 0.014132 0.023082 0.009638 0.014107 0.023331 0.009244 +0.18% -1.07% +4.26% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 130) mean 0.013420 0.025834 0.008949 0.013368 0.025724 0.008918 +0.39% +0.43% +0.35% (1024, 130) sum 0.013300 0.025940 0.009113 0.013266 0.025419 0.008922 +0.26% +2.05% +2.14% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 130) mean 0.013993 0.017883 0.009661 0.014275 0.018220 0.009596 -1.98% -1.85% +0.68% (8192, 130) sum 0.014026 0.018297 0.010066 0.014326 0.018257 0.009659 -2.09% +0.22% +4.21% ================================================================================================================================================================================================================================================ ``` Pull Request resolved: #165178 Approved by: https://github.com/ngimel ghstack dependencies: #165494, #164790, #165055
…h#165055) Benchmarks for a full reduction + reduction on the contiguous dimension. Vectorized loads do not occur on the non contiguous dimension. Benchmarking done for FP16/BF16, ~6% improvement on average across shapes, up to ~24% for single reduction on contiguous dimension and 46% for full reduce: **BF16** ``` Tensor Shape Operation Full reduce (ms) Contiguous dim (ms) Full reduce (ms) Contiguous dim (ms) Full reduce diff % Contiguous diff % ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (256, 256) mean 0.022686 0.008263 0.015498 0.008117 +46.38% +1.80% (256, 256) sum 0.022769 0.008269 0.015628 0.008185 +45.69% +1.03% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 512) mean 0.014116 0.009545 0.012892 0.008839 +9.49% +7.99% (512, 512) sum 0.014110 0.009892 0.012891 0.008878 +9.46% +11.42% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 1024) mean 0.014727 0.012642 0.014061 0.010519 +4.74% +20.18% (1024, 1024) sum 0.014376 0.012636 0.014069 0.010595 +2.18% +19.26% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 2048) mean 0.018663 0.018294 0.018171 0.014678 +2.71% +24.64% (2048, 2048) sum 0.018638 0.017931 0.018142 0.014713 +2.73% +21.87% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 4096) mean 0.034216 0.036953 0.033520 0.030585 +2.08% +20.82% (4096, 4096) sum 0.034196 0.036942 0.033518 0.030676 +2.02% +20.43% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 8192) mean 0.087763 0.095201 0.085439 0.084960 +2.72% +12.05% (8192, 8192) sum 0.088079 0.095592 0.085353 0.084632 +3.19% +12.95% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 16384) mean 0.148174 0.149705 0.146274 0.138865 +1.30% +7.81% (8192, 16384) sum 0.147820 0.149371 0.146419 0.138752 +0.96% +7.65% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 32768) mean 0.266144 0.260807 0.265953 0.253330 +0.07% +2.95% (8192, 32768) sum 0.266572 0.261163 0.265729 0.253294 +0.32% +3.11% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 65536) mean 0.502034 0.486312 0.498417 0.481246 +0.73% +1.05% (8192, 65536) sum 0.501597 0.486351 0.497735 0.481579 +0.78% +0.99% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 131072) mean 0.971178 0.942988 0.957164 0.938316 +1.46% +0.50% (8192, 131072) sum 0.971189 0.943232 0.956814 0.937816 +1.50% +0.58% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 262144) mean 1.953728 1.877648 1.904937 1.861692 +2.56% +0.86% (8192, 262144) sum 1.953969 1.877538 1.905990 1.862547 +2.52% +0.80% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 262144) mean 0.970408 0.940965 0.957871 0.936732 +1.31% +0.45% (4096, 262144) sum 0.970919 0.941652 0.957765 0.936676 +1.37% +0.53% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 262144) mean 0.501477 0.486976 0.497964 0.483570 +0.71% +0.70% (2048, 262144) sum 0.501955 0.487213 0.498210 0.483218 +0.75% +0.83% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 262144) mean 0.266536 0.257111 0.265642 0.255439 +0.34% +0.65% (1024, 262144) sum 0.266613 0.257096 0.265427 0.255472 +0.45% +0.64% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 131072) mean 0.087805 0.091200 0.085818 0.087851 +2.32% +3.81% (512, 131072) sum 0.087788 0.091249 0.085373 0.087944 +2.83% +3.76% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1000, 1000) mean 0.014503 0.012328 0.013663 0.010190 +6.15% +20.98% (1000, 1000) sum 0.014545 0.012378 0.013662 0.010579 +6.46% +17.01% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 129) mean 0.014163 0.008371 0.012893 0.008828 +9.85% -5.18% (1024, 129) sum 0.014132 0.008751 0.013234 0.008868 +6.79% -1.32% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 257) mean 0.014296 0.009101 0.013334 0.008563 +7.21% +6.28% (1024, 257) sum 0.014302 0.009058 0.013020 0.008672 +9.85% +4.45% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 587) mean 0.014127 0.010997 0.013443 0.009944 +5.09% +10.59% (1024, 587) sum 0.014471 0.011373 0.013123 0.010354 +10.27% +9.84% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 977) mean 0.015607 0.013566 0.015089 0.012152 +3.43% +11.64% (2048, 977) sum 0.015953 0.013580 0.015039 0.011861 +6.08% +14.49% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 128) mean 0.013982 0.008058 0.012747 0.008139 +9.69% -1.00% (1024, 128) sum 0.013967 0.008071 0.012726 0.007859 +9.75% +2.70% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 128) mean 0.014378 0.009627 0.013712 0.009395 +4.86% +2.47% (8192, 128) sum 0.014389 0.009965 0.013718 0.009521 +4.89% +4.66% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 130) mean 0.014156 0.008267 0.012895 0.008833 +9.78% -6.41% (1024, 130) sum 0.013797 0.008277 0.012903 0.008512 +6.93% -2.76% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 130) mean 0.014977 0.010026 0.013911 0.009876 +7.66% +1.52% (8192, 130) sum 0.014994 0.010043 0.014235 0.009604 +5.33% +4.57% ==================================================================================================================================================================================== ``` **FP16** ``` Tensor Shape Operation Full reduce (ms) Contiguous dim (ms) Full reduce (ms) Contiguous dim (ms) Full reduce diff % Contiguous diff % ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (256, 256) mean 0.022804 0.008298 0.015888 0.007848 +43.53% +5.73% (256, 256) sum 0.023215 0.008328 0.015677 0.007850 +48.08% +6.09% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 512) mean 0.013777 0.009988 0.012884 0.008512 +6.93% +17.34% (512, 512) sum 0.013775 0.009622 0.012870 0.009028 +7.03% +6.58% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 1024) mean 0.014740 0.012322 0.013708 0.010239 +7.53% +20.34% (1024, 1024) sum 0.014762 0.012756 0.013722 0.010307 +7.58% +23.76% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 2048) mean 0.018700 0.018364 0.018135 0.015078 +3.12% +21.79% (2048, 2048) sum 0.018276 0.018415 0.018471 0.015127 -1.06% +21.74% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 4096) mean 0.034518 0.037000 0.033838 0.030617 +2.01% +20.85% (4096, 4096) sum 0.034569 0.037448 0.033842 0.031100 +2.15% +20.41% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 8192) mean 0.087675 0.095176 0.085328 0.084105 +2.75% +13.16% (8192, 8192) sum 0.088102 0.095211 0.085707 0.084090 +2.79% +13.23% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 16384) mean 0.147800 0.149263 0.146388 0.138390 +0.96% +7.86% (8192, 16384) sum 0.148147 0.148957 0.146439 0.138801 +1.17% +7.32% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 32768) mean 0.266316 0.260294 0.265829 0.253411 +0.18% +2.72% (8192, 32768) sum 0.266562 0.260717 0.265744 0.253308 +0.31% +2.92% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 65536) mean 0.502035 0.486077 0.498139 0.481374 +0.78% +0.98% (8192, 65536) sum 0.501571 0.485733 0.498353 0.481350 +0.65% +0.91% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 131072) mean 0.971343 0.943016 0.956600 0.938622 +1.54% +0.47% (8192, 131072) sum 0.971463 0.942991 0.957352 0.938334 +1.47% +0.50% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 262144) mean 1.952722 1.877165 1.906406 1.861455 +2.43% +0.84% (8192, 262144) sum 1.952634 1.876388 1.904677 1.861282 +2.52% +0.81% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 262144) mean 0.970697 0.941298 0.956964 0.936160 +1.44% +0.55% (4096, 262144) sum 0.969981 0.941078 0.957016 0.936260 +1.35% +0.51% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 262144) mean 0.501577 0.487208 0.498422 0.483493 +0.63% +0.77% (2048, 262144) sum 0.502029 0.487124 0.497854 0.483643 +0.84% +0.72% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 262144) mean 0.266416 0.257383 0.265928 0.255140 +0.18% +0.88% (1024, 262144) sum 0.266434 0.257081 0.265817 0.255143 +0.23% +0.76% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 131072) mean 0.087858 0.091296 0.085816 0.087745 +2.38% +4.05% (512, 131072) sum 0.088144 0.091314 0.085664 0.087864 +2.90% +3.93% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1000, 1000) mean 0.014977 0.012393 0.014141 0.010614 +5.91% +16.76% (1000, 1000) sum 0.014589 0.012804 0.014118 0.010320 +3.34% +24.07% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 129) mean 0.014208 0.008383 0.013273 0.008440 +7.04% -0.68% (1024, 129) sum 0.013804 0.008863 0.013265 0.009003 +4.06% -1.56% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 257) mean 0.014378 0.009109 0.013037 0.009038 +10.29% +0.79% (1024, 257) sum 0.014387 0.009113 0.013396 0.008698 +7.40% +4.77% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 587) mean 0.014207 0.011037 0.013182 0.010391 +7.78% +6.22% (1024, 587) sum 0.014588 0.011453 0.013539 0.010049 +7.75% +13.97% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 977) mean 0.016024 0.013614 0.015448 0.011845 +3.73% +14.93% (2048, 977) sum 0.015990 0.014033 0.015406 0.012278 +3.79% +14.29% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 128) mean 0.014037 0.007804 0.013143 0.008242 +6.80% -5.31% (1024, 128) sum 0.014041 0.007847 0.012759 0.007850 +10.05% -0.04% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 128) mean 0.014361 0.009644 0.014075 0.009061 +2.03% +6.43% (8192, 128) sum 0.014366 0.010032 0.013702 0.009181 +4.85% +9.27% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 130) mean 0.014226 0.008696 0.012894 0.008835 +10.33% -1.57% (1024, 130) sum 0.013830 0.008740 0.013288 0.008989 +4.08% -2.77% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 130) mean 0.015036 0.010019 0.013917 0.009538 +8.04% +5.04% (8192, 130) sum 0.014652 0.010403 0.013900 0.009565 +5.41% +8.76% ==================================================================================================================================================================================== ``` Pull Request resolved: pytorch#165055 Approved by: https://github.com/ngimel ghstack dependencies: pytorch#165494, pytorch#164790
…65178) Performance benchmarking, perf neutral: ``` ================================================================================================================================================================================================================================================ Tensor Shape Operation Full reduce (ms) Non-Contig dim (ms) Contig dim (ms) Full reduce (ms) Non-Contig dim (ms) Contig dim (ms) Full diff % Non-Contig diff % Contig diff % ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (256, 256) mean 0.015684 0.017056 0.008287 0.016015 0.016929 0.008170 -2.07% +0.75% +1.43% (256, 256) sum 0.015774 0.016638 0.007926 0.015811 0.016935 0.008330 -0.23% -1.75% -4.85% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 512) mean 0.013385 0.025742 0.008629 0.013046 0.026005 0.008924 +2.60% -1.01% -3.31% (512, 512) sum 0.013390 0.026059 0.009116 0.013054 0.025696 0.008952 +2.57% +1.41% +1.83% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 1024) mean 0.014213 0.015467 0.010334 0.013862 0.015082 0.010318 +2.53% +2.55% +0.16% (1024, 1024) sum 0.014179 0.015446 0.010774 0.014132 0.015073 0.010350 +0.33% +2.47% +4.10% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 2048) mean 0.018234 0.019487 0.014812 0.018482 0.019397 0.014802 -1.34% +0.46% +0.07% (2048, 2048) sum 0.018202 0.019529 0.015195 0.018122 0.019485 0.015129 +0.44% +0.23% +0.44% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 4096) mean 0.033582 0.039378 0.030751 0.033810 0.039673 0.031019 -0.67% -0.74% -0.86% (4096, 4096) sum 0.033604 0.039777 0.030809 0.033530 0.039386 0.031113 +0.22% +0.99% -0.98% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 8192) mean 0.085824 0.091133 0.084200 0.085431 0.091364 0.084303 +0.46% -0.25% -0.12% (8192, 8192) sum 0.085763 0.091442 0.084180 0.085508 0.091419 0.084595 +0.30% +0.03% -0.49% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 16384) mean 0.146480 0.147666 0.138807 0.146515 0.147987 0.138930 -0.02% -0.22% -0.09% (8192, 16384) sum 0.146446 0.147593 0.138559 0.146151 0.147982 0.139120 +0.20% -0.26% -0.40% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 32768) mean 0.266047 0.265386 0.253837 0.265648 0.265885 0.253652 +0.15% -0.19% +0.07% (8192, 32768) sum 0.266093 0.265421 0.253890 0.265458 0.265591 0.253567 +0.24% -0.06% +0.13% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 65536) mean 0.498632 0.508976 0.481865 0.498237 0.508777 0.481476 +0.08% +0.04% +0.08% (8192, 65536) sum 0.498917 0.508202 0.481883 0.498104 0.508016 0.481972 +0.16% +0.04% -0.02% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 131072) mean 0.957633 0.968519 0.938172 0.956766 0.968267 0.938196 +0.09% +0.03% -0.00% (8192, 131072) sum 0.956972 0.968140 0.937741 0.957365 0.968404 0.938056 -0.04% -0.03% -0.03% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 262144) mean 1.906661 1.928377 1.861846 1.907327 1.928811 1.862083 -0.03% -0.02% -0.01% (8192, 262144) sum 1.905976 1.928362 1.862399 1.907098 1.928844 1.861782 -0.06% -0.02% +0.03% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (4096, 262144) mean 0.956852 0.970101 0.936524 0.957263 0.969809 0.936965 -0.04% +0.03% -0.05% (4096, 262144) sum 0.957117 0.969933 0.936247 0.956675 0.969451 0.936395 +0.05% +0.05% -0.02% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 262144) mean 0.498813 0.511299 0.483415 0.498567 0.511482 0.483376 +0.05% -0.04% +0.01% (2048, 262144) sum 0.498813 0.510834 0.483641 0.498875 0.511036 0.483338 -0.01% -0.04% +0.06% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 262144) mean 0.266157 0.276751 0.255192 0.265966 0.276808 0.255544 +0.07% -0.02% -0.14% (1024, 262144) sum 0.266133 0.276709 0.255528 0.265658 0.276685 0.255287 +0.18% +0.01% +0.09% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (512, 131072) mean 0.085941 0.081184 0.087931 0.085591 0.080832 0.088008 +0.41% +0.44% -0.09% (512, 131072) sum 0.085962 0.081107 0.088045 0.085882 0.081160 0.088024 +0.09% -0.07% +0.02% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1000, 1000) mean 0.014203 0.045859 0.010310 0.013885 0.046132 0.010621 +2.29% -0.59% -2.93% (1000, 1000) sum 0.014180 0.046165 0.010756 0.013893 0.046109 0.010338 +2.07% +0.12% +4.04% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 129) mean 0.012953 0.016751 0.008536 0.012977 0.016714 0.008916 -0.18% +0.22% -4.26% (1024, 129) sum 0.013356 0.016806 0.008722 0.013003 0.017071 0.008611 +2.71% -1.55% +1.29% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 257) mean 0.013075 0.016787 0.009102 0.013116 0.016769 0.008679 -0.31% +0.11% +4.87% (1024, 257) sum 0.013092 0.016842 0.008786 0.013126 0.017128 0.008771 -0.26% -1.67% +0.17% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 587) mean 0.013662 0.017412 0.010055 0.013659 0.017019 0.010033 +0.02% +2.31% +0.22% (1024, 587) sum 0.013636 0.017473 0.010163 0.013642 0.017363 0.010101 -0.04% +0.63% +0.61% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (2048, 977) mean 0.015276 0.027873 0.012531 0.015241 0.027783 0.012467 +0.23% +0.32% +0.51% (2048, 977) sum 0.015345 0.027949 0.012192 0.015255 0.027839 0.012485 +0.59% +0.40% -2.35% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 128) mean 0.012806 0.014020 0.008291 0.013137 0.014309 0.007908 -2.52% -2.02% +4.84% (1024, 128) sum 0.012769 0.014308 0.007924 0.012788 0.014236 0.008038 -0.15% +0.51% -1.42% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 128) mean 0.014145 0.023049 0.009143 0.014104 0.023298 0.009501 +0.29% -1.07% -3.77% (8192, 128) sum 0.014132 0.023082 0.009638 0.014107 0.023331 0.009244 +0.18% -1.07% +4.26% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (1024, 130) mean 0.013420 0.025834 0.008949 0.013368 0.025724 0.008918 +0.39% +0.43% +0.35% (1024, 130) sum 0.013300 0.025940 0.009113 0.013266 0.025419 0.008922 +0.26% +2.05% +2.14% ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ (8192, 130) mean 0.013993 0.017883 0.009661 0.014275 0.018220 0.009596 -1.98% -1.85% +0.68% (8192, 130) sum 0.014026 0.018297 0.010066 0.014326 0.018257 0.009659 -2.09% +0.22% +4.21% ================================================================================================================================================================================================================================================ ``` Pull Request resolved: pytorch#165178 Approved by: https://github.com/ngimel ghstack dependencies: pytorch#165494, pytorch#164790, pytorch#165055
Stack from ghstack (oldest at bottom):
Benchmarks for a full reduction + reduction on the contiguous dimension. Vectorized loads do not occur on the non contiguous dimension. Benchmarking done for FP16/BF16, ~6% improvement on average across shapes, up to ~24% for single reduction on contiguous dimension and 46% for full reduce:
BF16
FP16