Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 43f78bf

Browse files
hvaara authored
and pytorchmergebot committed
[MPS] Gather sliced inputs to batch norm (#133610)
This PR removes the `executeGatherOp` flag from batch norm in favor of relying on the logic in https://github.com/pytorch/pytorch/blob/4aa66f68a803927ddd127ceaaa1521b8d6e90e5f/aten/src/ATen/native/mps/OperationUtils.mm#L372 to decide if gathering is necessary. It's not the most efficient way to solve this issue, but it assures correctness for sliced inputs. ### Performance impact #### With fix ``` python -m timeit -n 100 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x)" 100 loops, best of 5: 282 usec per loop python -m timeit -n 100 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x[5:])" 100 loops, best of 5: 448 usec per loop python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x)" 1000 loops, best of 5: 705 usec per loop python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x[5:])" 1000 loops, best of 5: 1.11 msec per loop python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(1000, 100, 35, 45).to('mps')" "bn(x)" 1000 loops, best of 5: 7.16 msec per loop python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(1000, 100, 35, 45).to('mps')" "bn(x[5:])" 1000 loops, best of 5: 11.7 msec per loop ``` #### Without fix ``` python -m timeit -n 100 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x)" 100 loops, best of 5: 284 usec per loop python -m timeit -n 100 -s "import torch; import torch.nn as nn; bn = 
nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x[5:])" 100 loops, best of 5: 265 usec per loop python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x)" 1000 loops, best of 5: 715 usec per loop python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x[5:])" 1000 loops, best of 5: 675 usec per loop python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(1000, 100, 35, 45).to('mps')" "bn(x)" 1000 loops, best of 5: 7.19 msec per loop python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(1000, 100, 35, 45).to('mps')" "bn(x[5:])" 1000 loops, best of 5: 7.13 msec per loop ``` Please feel free to push back or request changes. Fixes #133520 Pull Request resolved: #133610 Approved by: https://github.com/malfet
1 parent 278bc98 commit 43f78bf
Copy full SHA for 43f78bf

File tree

Expand file tree / Collapse file tree

2 files changed

+14
-7
lines changed
Filter options
Expand file tree / Collapse file tree

2 files changed

+14
-7
lines changed

‎aten/src/ATen/native/mps/operations/Normalization.mm

Copy file name to clipboardExpand all lines: aten/src/ATen/native/mps/operations/Normalization.mm
+1 −7 — Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -153,12 +153,6 @@ static void get_shapes(MPSShape* input_shape_readonly,
153153
else
154154
channelsDim = num_input_dims - 1;
155155

156-
bool executeGatherOp = true;
157-
if (self.is_contiguous(memory_format)) {
158-
memory_format = MemoryFormat::Contiguous;
159-
executeGatherOp = false;
160-
}
161-
162156
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
163157
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_mps_dtype, input_shape);
164158
MPSGraphTensor* weightTensor = nil;
@@ -302,7 +296,7 @@ Check if running mean exists (maybe do this check before making graph)
302296
newCachedGraph->runningVarInplaceUpdate_ = runningVarInplaceUpdate;
303297
});
304298

305-
auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, self, input_shape, executeGatherOp);
299+
auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, self, input_shape);
306300
auto weightPlaceholder = Placeholder();
307301
if (has_weight)
308302
weightPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_opt.value(), new_mean_shape);

‎test/test_mps.py

Copy file name to clipboardExpand all lines: test/test_mps.py
+13 −0 — Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2541,6 +2541,19 @@ def test_batch_norm_backward(self):
25412541
# This used to crash, see https://github.com/pytorch/pytorch/issues/98602
25422542
outputs.sum().backward()
25432543

2544+
# Regression test for https://github.com/pytorch/pytorch/issues/133520
2545+
def test_batch_norm_slices(self):
2546+
bn_cpu = nn.BatchNorm2d(100, affine=False, device='cpu')
2547+
bn_mps = nn.BatchNorm2d(100, affine=False, device='mps')
2548+
2549+
x_cpu = torch.randn(100, 100, 35, 45).to('cpu')
2550+
x_mps = x_cpu.to('mps')
2551+
2552+
res_cpu = bn_cpu(x_cpu[5:])
2553+
res_mps = bn_mps(x_mps[5:])
2554+
2555+
self.assertEqual(res_cpu, res_mps)
2556+
25442557
def test_layer_norm_backward(self):
25452558
inputs = torch.rand(4, 4, device="mps", requires_grad=True)
25462559
x = torch.nn.LayerNorm(4).to("mps")

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.