2 files changed: +14 −7 lines

aten/src/ATen/native/mps/operations
@@ -153,12 +153,6 @@ static void get_shapes(MPSShape* input_shape_readonly,
     else
       channelsDim = num_input_dims - 1;
 
-    bool executeGatherOp = true;
-    if (self.is_contiguous(memory_format)) {
-      memory_format = MemoryFormat::Contiguous;
-      executeGatherOp = false;
-    }
-
     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_mps_dtype, input_shape);
       MPSGraphTensor* weightTensor = nil;
@@ -302,7 +296,7 @@ Check if running mean exists (maybe do this check before making graph)
       newCachedGraph->runningVarInplaceUpdate_ = runningVarInplaceUpdate;
     });
 
-    auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, self, input_shape, executeGatherOp);
+    auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, self, input_shape);
     auto weightPlaceholder = Placeholder();
     if (has_weight)
       weightPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_opt.value(), new_mean_shape);
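The removed fast path skipped the gather whenever self.is_contiguous(memory_format) returned true and fed the input buffer to the graph directly. That predicate does not rule out views with a nonzero storage offset, such as a slice along the batch dimension, which is presumably how the wrong results in #133520 arose: the kernel read from the start of the underlying buffer rather than from the slice. A minimal sketch of the distinction, plain PyTorch on CPU (tensor names are illustrative):

    import torch

    x = torch.randn(100, 100, 35, 45)
    y = x[5:]                     # a view into the same storage

    print(y.is_contiguous())      # True: the strides still describe a contiguous layout
    print(y.storage_offset())     # 787500 (= 5 * 100 * 35 * 45): data does not begin at the buffer head

Dropping the shortcut and constructing the Placeholder without executeGatherOp (the + line above) sidesteps that case, presumably at the cost of an extra gather for inputs that really do start at offset zero.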
test/test_mps.py
@@ -2541,6 +2541,19 @@ def test_batch_norm_backward(self):
         # This used to crash, see https://github.com/pytorch/pytorch/issues/98602
         outputs.sum().backward()
 
+    # Regression test for https://github.com/pytorch/pytorch/issues/133520
+    def test_batch_norm_slices(self):
+        bn_cpu = nn.BatchNorm2d(100, affine=False, device='cpu')
+        bn_mps = nn.BatchNorm2d(100, affine=False, device='mps')
+
+        x_cpu = torch.randn(100, 100, 35, 45).to('cpu')
+        x_mps = x_cpu.to('mps')
+
+        res_cpu = bn_cpu(x_cpu[5:])
+        res_mps = bn_mps(x_mps[5:])
+
+        self.assertEqual(res_cpu, res_mps)
+
     def test_layer_norm_backward(self):
         inputs = torch.rand(4, 4, device="mps", requires_grad=True)
         x = torch.nn.LayerNorm(4).to("mps")
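Note that the new test slices the tensor only after moving it to MPS (x_mps[5:]). The order matters for reproducing the bug: copying a CPU slice to MPS materializes a fresh, offset-free buffer, so the old fast path would have handled it correctly. A small sketch of the difference (names illustrative, guarded on MPS availability):

    import torch

    x = torch.randn(100, 100, 35, 45)

    if torch.backends.mps.is_available():
        a = x.to('mps')[5:]   # slice of the MPS tensor: keeps a nonzero storage offset
        b = x[5:].to('mps')   # copy of a CPU slice: fresh storage, offset 0
        print(a.storage_offset(), b.storage_offset())  # 787500 vs. 0

To run just this test on an MPS-capable machine, something like python test/test_mps.py -k test_batch_norm_slices should work, assuming the runner forwards the stock unittest -k filter.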