Description
🐛 Describe the bug
When using `torch.compile(backend="inductor")` on MPS devices, `F.adaptive_max_pool1d` and `F.adaptive_max_pool2d` produce incorrect results, diverging significantly from eager mode (max abs diff > 1.0).
This behavior is observed in both the 1D and 2D cases.
Reproduce script

```python
import torch
import torch.nn.functional as F

def fn(x):
    return F.adaptive_max_pool1d(x, output_size=3)

x = torch.randn(4, 10, 8, device="mps")

# 2D case reproduces the same divergence:
# def fn(x):
#     return F.adaptive_max_pool2d(x, output_size=(3, 3))
# x = torch.randn(4, 10, 8, 8, device="mps")

eager_out = fn(x)
opt_fn = torch.compile(fn, backend="inductor")
try:
    compiled_out = opt_fn(x)
    diff = (eager_out - compiled_out).abs().max().item()
    print(f"Max Difference: {diff}")
except Exception as e:
    print(f"Crashed during execution: {e}")
```

Output

```
Max Difference: 2.4042608737945557
```

What's more
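For context on what the eager result should look like, adaptive max pooling selects each output element from a window determined by a floor/ceil rule over the input length. A minimal pure-Python sketch of that semantics (my own reconstruction for illustration, not PyTorch code) is:

```python
import math

def adaptive_max_pool1d_ref(row, output_size):
    """Reference semantics for 1D adaptive max pooling over one row.

    For output index i over an input of length L, the window is
    [floor(i * L / output_size), ceil((i + 1) * L / output_size)).
    """
    L = len(row)
    out = []
    for i in range(output_size):
        start = (i * L) // output_size
        end = math.ceil((i + 1) * L / output_size)
        out.append(max(row[start:end]))
    return out

row = [0.5, -1.0, 2.0, 0.0, 3.0, -2.0, 1.0, 0.25]
# Windows for L=8, output_size=3 are [0:3], [2:6], [5:8]
print(adaptive_max_pool1d_ref(row, 3))  # → [2.0, 3.0, 1.0]
```

A compiled kernel that computes these window boundaries differently (e.g. with uniform, non-overlapping windows) would explain a large elementwise divergence rather than mere floating-point noise.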
I also found that the related Inductor tests are already marked as expected failures on MPS:
pytorch/test/inductor/test_torchinductor.py
Lines 5030 to 5065 in d3944da
```python
@xfail_if_mps  # Non-divisible input sizes are not implemented on MPS device
def test_adaptive_avg_pool2d2(self):
    # Big kernel size, use fallback
    def fn(x):
        return aten._adaptive_avg_pool2d(x, (4, 4))

    torch._inductor.metrics.generated_kernel_count = 0
    self.common(
        fn,
        (torch.randn(2, 4, 21, 21),),
        check_lowp=False,
    )
    assertGeneratedKernelCountEqual(self, 0)

@xfail_if_mps
@skip_if_gpu_halide  # slow
def test_adaptive_max_pool2d1(self):
    def fn(x):
        return aten.adaptive_max_pool2d(x, (6, 6))

    self.common(
        fn,
        (torch.randn(2, 4, 16, 16),),
        check_lowp=False,
    )
    self.common(
        fn,
        (torch.randn(2, 4, 3, 3),),
    )
    # no-op case
    self.common(
        fn,
        (torch.randn(2, 4, 6, 6),),
    )
```
Versions
PyTorch version: 2.10.0.dev20251202
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 26.1 (arm64)
GCC version: Could not collect
Clang version: 17.0.0 (clang-1700.4.4.1)
CMake version: Could not collect
Libc version: N/A
Python version: 3.12.12 (main, Oct 28 2025, 11:52:25) [Clang 20.1.4 ] (64-bit runtime)
Python platform: macOS-26.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M4
Versions of relevant libraries:
[pip3] Could not collect
[conda] Could not collect
cc @kulinseth @malfet @DenisVieriu97 @jhavukainen @chauhang @penguinwu