Commit e2d141d

pytorchbot and ngimel authored
set thread_work_size to 4 for unrolled kernel (#154541)
set thread_work_size to 4 for unrolled kernel (#152396)

Previous PRs enabling 8-vectorization inadvertently regressed unrolled kernel perf.

Pull Request resolved: #152396
Approved by: https://github.com/BoyuanFeng, https://github.com/msaroufim, https://github.com/malfet, https://github.com/Aidyn-A, https://github.com/atalman
(cherry picked from commit adebb8b)

Co-authored-by: Natalia Gimelshein <ngimel@meta.com>
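For readers unfamiliar with the term, the sketch below illustrates what a thread work size of 4 means for an unrolled elementwise kernel: each thread processes 4 elements spaced a block-width apart, so one block covers 4 * num_threads elements. This is a toy kernel written only for illustration, not PyTorch's actual unrolled_elementwise_kernel; the block size of 128 and the exact indexing scheme are assumptions.

#include <cstdint>
#include <cuda_runtime.h>

constexpr int kNumThreads = 128;     // assumed block size, for illustration only
constexpr int kThreadWorkSize = 4;   // the per-thread work size this commit sets
constexpr int kBlockWorkSize = kThreadWorkSize * kNumThreads;  // 512 elements per block

// Toy unrolled elementwise op: each thread handles kThreadWorkSize elements,
// strided kNumThreads apart, so a whole block covers kBlockWorkSize elements.
// Launched as, e.g.:
//   add_one_unrolled<<<(n + kBlockWorkSize - 1) / kBlockWorkSize, kNumThreads>>>(in, out, n);
__global__ void add_one_unrolled(const float* in, float* out, int64_t n) {
  int64_t block_start = static_cast<int64_t>(blockIdx.x) * kBlockWorkSize;
#pragma unroll
  for (int i = 0; i < kThreadWorkSize; ++i) {
    int64_t idx = block_start + threadIdx.x + i * kNumThreads;
    if (idx < n) {
      out[idx] = in[idx] + 1.0f;
    }
  }
}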
1 parent 1214198 · commit e2d141d

1 file changed: +11 -2 lines changed
‎aten/src/ATen/native/cuda/CUDALoops.cuh

+11 -2 lines changed: 11 additions & 2 deletions
@@ -83,6 +83,14 @@ constexpr auto elems_per_thread(){
 }
 #endif
 
+
+//thread work size of 8 regresses the perf of elementwise kernel on cuda
+//this doesn't change ROCm behavior as thread_work_size is already 4 on ROCm
+constexpr int elementwise_thread_work_size() {return 4;}
+constexpr int elementwise_block_work_size() {
+  return elementwise_thread_work_size() * num_threads();
+}
+
 template <int io_sizes>
 constexpr auto io_block_work_size() {
   return num_threads() * elems_per_thread<io_sizes>();
@@ -336,9 +344,10 @@ static inline void launch_unrolled_kernel(
     loader_t l,
     storer_t s) {
   TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
-  int64_t grid = (N + block_work_size() - 1) / block_work_size();
+  int64_t grid =
+      (N + elementwise_block_work_size() - 1) / elementwise_block_work_size();
   auto stream = at::cuda::getCurrentCUDAStream();
-  unrolled_elementwise_kernel<func_t, array_t, thread_work_size()>
+  unrolled_elementwise_kernel<func_t, array_t, elementwise_thread_work_size()>
       <<<grid, num_threads(), 0, stream>>>(N, f, data, ic, oc, l, s);
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
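
To make the launch arithmetic in the second hunk concrete: with a thread work size of 4, each block covers elementwise_block_work_size() = 4 * num_threads() elements, and the grid is the ceiling division of N by that. The host-side sketch below reproduces only that computation; the num_threads() value of 128 and the sample N are assumptions for illustration, not values taken from this file.

#include <cstdint>
#include <cstdio>

constexpr int num_threads() { return 128; }                 // assumed block size
constexpr int elementwise_thread_work_size() { return 4; }  // value set by this commit
constexpr int elementwise_block_work_size() {
  return elementwise_thread_work_size() * num_threads();    // 4 * 128 = 512 elements per block
}

int main() {
  const int64_t N = 1000000;  // sample element count (assumed)
  // Ceiling division, mirroring launch_unrolled_kernel above: enough blocks
  // so that grid * elementwise_block_work_size() >= N.
  const int64_t grid =
      (N + elementwise_block_work_size() - 1) / elementwise_block_work_size();
  std::printf("grid = %lld blocks of %d threads\n",
              static_cast<long long>(grid), num_threads());  // prints 1954 for N = 1e6
  return 0;
}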

0 commit comments
