[Bug]: SIGSEGV segfault when plotting multiple subplots #22478

Open
@JackKelly

Description


Bug summary

🐛 Bug

I'm using matplotlib to create plots whilst training neural networks using PyTorch and pytorch-lightning.

If I create specific, simple matplotlib plots during training_step, then I get a segfault after training for several iterations.

I originally raised this issue on the pytorch-lightning issue queue, but it was concluded that the error is likely coming from matplotlib rather than from pytorch-lightning.

Specifically, if I call plot_data() from within training_step, then I get a segfault:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

START_DATE = pd.Timestamp("2020-01-01")


def plot_data():
    """Plot random data."""
    fig, axes = plt.subplots(ncols=4)  # See note 1 below.
    N_PERIODS = 16
    x = pd.date_range(START_DATE, periods=N_PERIODS, freq="30 min")  # See note 2 below.
    y = np.ones(N_PERIODS)
    for ax in axes:
        ax.plot(x, y)
    plt.close(fig)

Interestingly, I've found two ways to stop the segfaults; doing either one (or both) of the following works:

  1. Change ncols to one of {2, 3, 5, 6, 7, 9, 10}. Segfaults only seem to appear if ncols is one of {4, 8, 16, 32}.
  2. Convert x from a pd.DatetimeIndex to a numpy array of matplotlib date numbers by calling x = matplotlib.dates.date2num(x) immediately after x is created (see the sketch after this list).
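
To make workaround 2 concrete, here is a sketch of plot_data() with the date2num conversion applied; only the conversion line differs from the snippet above:

import matplotlib.dates as mdates


def plot_data():
    """Plot random data, converting the dates to matplotlib date numbers first."""
    fig, axes = plt.subplots(ncols=4)
    N_PERIODS = 16
    x = pd.date_range(START_DATE, periods=N_PERIODS, freq="30 min")
    x = mdates.date2num(x)  # Workaround 2: this conversion stops the segfaults.
    y = np.ones(N_PERIODS)
    for ax in axes:
        ax.plot(x, y)
    plt.close(fig)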

If I run gdb --args python script.py I get this:

Epoch 0:   1%|█▏| 7/1024 [00:01<02:52,  5.90it/s, loss=0.281, v_num=34]
Thread 1 "python" received signal SIGSEGV, Segmentation fault.
0x000055555568bb4b in _PyObject_GC_UNTRACK_impl (filename=0x55555587fef0 "/home/conda/feedstock_root/build_artifacts/python-split_1643749964416/work/Modules/gcmodule.c", lineno=2236, op=0x7fffd05f5e00) at /home/conda/feedstock_root/build_artifacts/python-split_1643749964416/work/Include/internal/pycore_object.h:76
76  /home/conda/feedstock_root/build_artifacts/python-split_1643749964416/work/Include/internal/pycore_object.h: No such file or directory.

The minimal code example below seems to always crash whilst training on epoch 7.

In my "real" code, the crash was intermittent.

If ncols is large (e.g. 32) then the crash changes randomly between:

  • SIGSEGV segfault
  • AttributeError: 'weakref' object has no attribute 'grad_fn'
  • AttributeError: 'builtin_function_or_method' object has no attribute 'grad_fn'

The two AttributeErrors originate from the same bit of code: line 213 of torch/optim/optimizer.py, if p.grad.grad_fn is not None:. My guess is that the SIGSEGV is the root problem, and the AttributeErrors are symptoms of the same underlying error. I have (once) seen the code complain about a SIGSEGV and an AttributeError at the same time.

Please see this comment for some debugging that @justusschock did this morning.

Code for reproduction

These scripts are in a tiny GitHub repo for this bug report: https://github.com/JackKelly/pytorch-lighting-segfault

import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import pandas as pd
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl


START_DATE = pd.Timestamp("2020-01-01")
N_EXAMPLES_PER_BATCH = 32
N_FEATURES = 1


class MyDataset(Dataset):
    def __init__(self, n_batches_per_epoch: int):
        self.n_batches_per_epoch = n_batches_per_epoch

    def __len__(self):
        return self.n_batches_per_epoch

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        x = torch.rand(N_EXAMPLES_PER_BATCH, N_FEATURES, 1)
        y = torch.rand(N_EXAMPLES_PER_BATCH, 1)
        return x, y


class LitNeuralNetwork(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.nn = nn.Linear(in_features=N_FEATURES, out_features=1)

    def forward(self, x):
        x = self.flatten(x)
        return self.nn(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.mse_loss(y_hat, y)
        plot_data()
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)


def plot_data():
    """Plot random data."""
    # ncols needs to be 4 or higher to trigger a segfault.
    fig, axes = plt.subplots(ncols=4)
    N_PERIODS = 16
    x = pd.date_range(START_DATE, periods=N_PERIODS, freq="30 min")
    # The segfaults go away if I do:
    # x = mdates.date2num(x)
    y = np.ones(N_PERIODS)
    for ax in axes:
        ax.plot(x, y)
    plt.close(fig)


dataloader = DataLoader(
    MyDataset(n_batches_per_epoch=1024),
    batch_size=None,
    num_workers=2,
)

model = LitNeuralNetwork()
trainer = pl.Trainer()
trainer.fit(model=model, train_dataloader=dataloader)

Here's my environment.yml file:

name: segfault
channels:
  - conda-forge
  - pytorch
dependencies:
  - matplotlib
  - pandas
  - pytorch
  - cpuonly 
  - pytorch-lightning

Actual outcome

If ncols is large (e.g. 32) then the crash changes randomly between:

  • SIGSEGV segfault
  • AttributeError: 'weakref' object has no attribute 'grad_fn'
  • AttributeError: 'builtin_function_or_method' object has no attribute 'grad_fn'

Expected outcome

No segfault 🙂

Additional information

I have tried and failed to reproduce this problem in two simpler scripts:

  • A script which just uses matplotlib and pandas (no pytorch, no pytorch-lightning); a rough sketch of this attempt is shown after this list.
  • A script which just uses matplotlib, pandas and pytorch (no pytorch-lightning).
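
For illustration, the matplotlib-and-pandas-only attempt was essentially plot_data() called in a plain loop, along the lines of the sketch below (a rough reconstruction, not the exact file), and a script along these lines did not reproduce the crash:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

START_DATE = pd.Timestamp("2020-01-01")


def plot_data():
    """Plot random data, exactly as in the reproduction script above."""
    fig, axes = plt.subplots(ncols=4)
    N_PERIODS = 16
    x = pd.date_range(START_DATE, periods=N_PERIODS, freq="30 min")
    y = np.ones(N_PERIODS)
    for ax in axes:
        ax.plot(x, y)
    plt.close(fig)


# Calling plot_data() repeatedly, without PyTorch or pytorch-lightning
# in the process, does not segfault.
for _ in range(10_000):
    plot_data()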

Operating system

Ubuntu 21.10

Matplotlib Version

3.5.1

Matplotlib Backend

QtAgg

Python version

3.9.10

Jupyter version

No response

Installation

conda
