MPS memory issue, MPS backend out of memory, but works if I empty the MPS cache #105839

@Vargol

Description


🐛 Describe the bug

There appears to be something wrong with the MPS cache: either it is not releasing memory when it ideally should, or the freeable memory in the cache is not being taken into account when the check for available space occurs.
The issue occurs on the current nightly (see Versions below) and on 2.0.1.

This issue degrades performance at best and terminates the application at worst.

Here's an example...

from diffusers import KandinskyV22PriorPipeline, KandinskyV22Pipeline
from torch import mps
import torch
import fp16fixes
import gc

fp16fixes.fp16_fixes()

# First model: the prior turns the prompt into image embeddings.
pipe_prior = KandinskyV22PriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-prior", torch_dtype=torch.float16)
pipe_prior.to("mps")
prompt = "A car exploding into colorful dust"
out = pipe_prior(prompt)
image_emb = out.image_embeds
zero_image_emb = out.negative_image_embeds

# Release the prior and flush the MPS cache before loading the decoder.
pipe_prior = None
gc.collect()
mps.empty_cache()

# Second model: the decoder turns the embeddings into an image.
pipe = KandinskyV22Pipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16)
pipe.to("mps")
pipe.enable_attention_slicing()

image = pipe(
    image_embeds=image_emb,
    negative_image_embeds=zero_image_emb,
    height=1024,
    width=1024,
    num_inference_steps=30,
).images

image[0].save("cat.png")

This works on an 8GB M1 Mac Mini without issue; the two models run at:

100%|████████| 25/25 [00:07<00:00,  3.15it/s]
100%|████████| 30/30 [04:24<00:00,  8.82s/it]

Remove the mps.empty_cache() and it fails during the second model run.
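
To be explicit, the only change from the working script is dropping the flush after releasing the prior:

pipe_prior = None
gc.collect()
# mps.empty_cache()  # removed; the 1024x1024 decoder run below then OOMs

The failing run: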

  0%|                                                                                                                                    | 0/30 [00:03<?, ?it/s]
Traceback (most recent call last):
  File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/8GB_M1_Diffusers_Scripts/sag/k2img.py", line 25, in <module>
    image = pipe(
  File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py", line 272, in __call__
    noise_pred = self.unet(
  File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/diffusers/models/unet_2d_condition.py", line 905, in forward
    sample, res_samples = downsample_block(
  File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/diffusers/models/unet_2d_blocks.py", line 1662, in forward
    hidden_states = attn(
  File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/diffusers/models/attention_processor.py", line 321, in forward
    return self.processor(
  File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/diffusers/models/attention_processor.py", line 1590, in __call__
    attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
  File "/Volumes/Sabrent Media/Documents/Source/Python/Diffusers/lib/python3.10/site-packages/diffusers/models/attention_processor.py", line 374, in get_attention_scores
    attention_probs = attention_scores.softmax(dim=-1)
RuntimeError: MPS backend out of memory (MPS allocated: 3.90 GB, other allocations: 4.94 GB, max allowed: 9.07 GB). Tried to allocate 387.00 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

If I reduce the height and width to 512 it runs to completion, but the second model runs at about 40 seconds per iteration with a lot of swap-file access. With the cache emptied manually it runs at around 2 seconds per iteration.
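
For what it's worth, the cache behaviour can be watched directly with the torch.mps introspection helpers. A minimal sketch, independent of diffusers (the ~1 GiB tensor size is arbitrary):

import gc
import torch

def report(tag):
    # current_allocated_memory: bytes held by live tensors
    # driver_allocated_memory: total bytes held by the Metal driver, including the cache
    cur = torch.mps.current_allocated_memory() / 2**20
    drv = torch.mps.driver_allocated_memory() / 2**20
    print(f"{tag}: current={cur:.0f} MiB, driver={drv:.0f} MiB")

x = torch.empty(1024, 1024, 512, dtype=torch.float16, device="mps")  # ~1 GiB
report("after alloc")

del x
gc.collect()
report("after del")          # driver figure typically stays high: freed blocks sit in the cache

torch.mps.empty_cache()
report("after empty_cache")  # cached blocks are handed back to the system

If the allocator took those cached-but-freeable blocks into account when checking against the high-watermark limit, the manual mps.empty_cache() between the two pipelines should not be necessary.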

The fp16fixes file is required to work around some issues with using fp16 on MPS: without it, 2.0.1 fails with a broadcast error, and the nightly I'm currently using produces a bad image. If I remove it, the memory issue still occurs on the nightly.

% cat fp16fixes.py 
import torch

def fp16_fixes():
  if torch.backends.mps.is_available():
      # Replace torch.empty with torch.zeros so new MPS tensors are zero-initialised.
      torch.empty = torch.zeros

  # fp16 layer_norm is unreliable on MPS, so compute in fp32 and cast the result back.
  _torch_layer_norm = torch.nn.functional.layer_norm
  def new_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
      if input.device.type == "mps" and input.dtype == torch.float16:
          input = input.float()
          if weight is not None:
              weight = weight.float()
          if bias is not None:
              bias = bias.float()
          return _torch_layer_norm(input, normalized_shape, weight, bias, eps).half()
      else:
          return _torch_layer_norm(input, normalized_shape, weight, bias, eps)

  torch.nn.functional.layer_norm = new_layer_norm


  def new_torch_tensor_permute(input, *dims):
      result = torch.permute(input, tuple(dims))
      # Non-contiguous fp16 views misbehave on MPS, so force a contiguous copy.
      if input.device.type == "mps" and input.dtype == torch.float16:
          result = result.contiguous()
      return result

  torch.Tensor.permute = new_torch_tensor_permute
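
For completeness, a quick check that the patched layer_norm path round-trips through fp32 (a hypothetical snippet, not part of the original script):

import torch
import fp16fixes

fp16fixes.fp16_fixes()

x = torch.randn(2, 4, device="mps", dtype=torch.float16)
y = torch.nn.functional.layer_norm(x, (4,))
print(y.dtype)  # torch.float16: computed in fp32 internally, cast back on return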

Versions

Collecting environment information...
PyTorch version: 2.1.0.dev20230724
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 13.4.1 (arm64)
GCC version: Could not collect
Clang version: 14.0.3 (clang-1403.0.22.14.1)
CMake version: version 3.24.4
Libc version: N/A

Python version: 3.10.11 (main, Apr 8 2023, 02:11:11) [Clang 14.0.0 (clang-1400.0.29.202)] (64-bit runtime)
Python platform: macOS-13.4.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1

Versions of relevant libraries:
[pip3] numpy==1.25.1
[pip3] torch==2.1.0.dev20230724
[pip3] torchvision==0.15.2
[conda] Could not collect

cc @ezyang @gchanan @zou3519 @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr @abhudev

Labels: module: memory usage, module: mps, triaged
