Description
Bug summary
I'm using matplotlib to create plots whilst training neural networks using PyTorch and pytorch-lightning. If I create specific, simple matplotlib plots during `training_step`, then I get a segfault after training for several iterations. I originally raised this issue on the pytorch-lightning issue queue, but it was concluded that this error is likely coming from matplotlib rather than from pytorch-lightning.

Specifically, if I call `plot_data()` from within `training_step`, then I get a segfault:
```python
def plot_data():
    """Plot random data."""
    fig, axes = plt.subplots(ncols=4)  # See note 1 below.
    N_PERIODS = 16
    x = pd.date_range(START_DATE, periods=N_PERIODS, freq="30 min")  # See note 2 below.
    y = np.ones(N_PERIODS)
    for ax in axes:
        ax.plot(x, y)
    plt.close(fig)
```
Interestingly, I've found two ways to stop the segfaults; doing either one (or both) makes the crash go away:

- Change `ncols` to one of {2, 3, 5, 6, 7, 9, 10}. Segfaults only seem to appear if `ncols` is one of {4, 8, 16, 32}.
- Convert `x` from a `pd.DatetimeIndex` to a numpy array of matplotlib dates by doing `x = matplotlib.dates.date2num(x)` after the `pd.date_range(...)` call (see the sketch after this list).
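For concreteness, here is `plot_data()` with the second workaround applied; this mirrors the commented-out line in the full reproduction script below:

```python
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

START_DATE = pd.Timestamp("2020-01-01")


def plot_data():
    """Plot random data, with the date2num workaround applied."""
    fig, axes = plt.subplots(ncols=4)
    N_PERIODS = 16
    x = pd.date_range(START_DATE, periods=N_PERIODS, freq="30 min")
    # Workaround 2: pre-convert the DatetimeIndex to plain matplotlib
    # date floats, so the plot call receives a numpy float array rather
    # than a DatetimeIndex. With this line, the segfaults go away.
    x = mdates.date2num(x)
    y = np.ones(N_PERIODS)
    for ax in axes:
        ax.plot(x, y)
    plt.close(fig)
```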
If I run `gdb --args python script.py` I get this:
```
Epoch 0:   1%|█▏| 7/1024 [00:01<02:52, 5.90it/s, loss=0.281, v_num=34]

Thread 1 "python" received signal SIGSEGV, Segmentation fault.
0x000055555568bb4b in _PyObject_GC_UNTRACK_impl (filename=0x55555587fef0
    "/home/conda/feedstock_root/build_artifacts/python-split_1643749964416/work/Modules/gcmodule.c",
    lineno=2236, op=0x7fffd05f5e00)
    at /home/conda/feedstock_root/build_artifacts/python-split_1643749964416/work/Include/internal/pycore_object.h:76
76	/home/conda/feedstock_root/build_artifacts/python-split_1643749964416/work/Include/internal/pycore_object.h: No such file or directory.
```
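(For anyone who wants to reproduce the backtrace: this is just the standard gdb workflow, nothing matplotlib-specific.)

```
$ gdb --args python script.py
(gdb) run
# ... training runs until the SIGSEGV shown above ...
(gdb) bt    # print the C-level backtrace at the crash site
```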
The minimal code example below seems to always crash whilst training on epoch 7.
In my "real" code, the crash was intermittent.
If `ncols` is large (e.g. 32) then the crash changes randomly between:

- a `SIGSEGV` segfault
- `AttributeError: 'weakref' object has no attribute 'grad_fn'`
- `AttributeError: 'builtin_function_or_method' object has no attribute 'grad_fn'`

The two `AttributeError`s originate from the same bit of code, line 213 of `torch/optim/optimizer.py`: `if p.grad.grad_fn is not None:`. My guess is that the `SIGSEGV` is the root problem, and the `AttributeError`s are symptoms of the underlying `SIGSEGV` error. I have (once) seen the code complain, at the same time, about both a `SIGSEGV` and an `AttributeError`.
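As a possible next debugging step (an idea, not something I've run here), Python's standard-library `faulthandler` module can dump a Python-level traceback for every thread at the moment the SIGSEGV arrives, which might show which plotting call is live when the crash happens:

```python
import faulthandler

# Install handlers for fatal signals (SIGSEGV, SIGFPE, SIGABRT, ...)
# so that a Python traceback for all threads is written to stderr
# at the moment of the crash. Call once at the top of script.py.
faulthandler.enable()
```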
Please see this comment for some debugging that @justusschock did this morning.
Code for reproduction
These scripts are in a tiny GitHub repo for this bug report: https://github.com/JackKelly/pytorch-lighting-segfault
```python
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import pandas as pd
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl

START_DATE = pd.Timestamp("2020-01-01")
N_EXAMPLES_PER_BATCH = 32
N_FEATURES = 1


class MyDataset(Dataset):
    def __init__(self, n_batches_per_epoch: int):
        self.n_batches_per_epoch = n_batches_per_epoch

    def __len__(self):
        return self.n_batches_per_epoch

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        x = torch.rand(N_EXAMPLES_PER_BATCH, N_FEATURES, 1)
        y = torch.rand(N_EXAMPLES_PER_BATCH, 1)
        return x, y


class LitNeuralNetwork(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.nn = nn.Linear(in_features=N_FEATURES, out_features=1)

    def forward(self, x):
        x = self.flatten(x)
        return self.nn(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.mse_loss(y_hat, y)
        plot_data()
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)


def plot_data():
    """Plot random data."""
    # ncols needs to be 4 or higher to trigger a segfault.
    fig, axes = plt.subplots(ncols=4)
    N_PERIODS = 16
    x = pd.date_range(START_DATE, periods=N_PERIODS, freq="30 min")
    # The segfaults go away if I do:
    # x = mdates.date2num(x)
    y = np.ones(N_PERIODS)
    for ax in axes:
        ax.plot(x, y)
    plt.close(fig)


dataloader = DataLoader(
    MyDataset(n_batches_per_epoch=1024),
    batch_size=None,
    num_workers=2,
)
model = LitNeuralNetwork()
trainer = pl.Trainer()
trainer.fit(model=model, train_dataloader=dataloader)
```
Here's my `environment.yml` file:
```yaml
name: segfault
channels:
  - conda-forge
  - pytorch
dependencies:
  - matplotlib
  - pandas
  - pytorch
  - cpuonly
  - pytorch-lightning
```
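To reproduce, the standard conda workflow should be all that's needed (the environment name `segfault` comes from the file above):

```
conda env create -f environment.yml
conda activate segfault
python script.py
```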
Actual outcome
If `ncols` is large (e.g. 32) then the crash changes randomly between:

- a `SIGSEGV` segfault
- `AttributeError: 'weakref' object has no attribute 'grad_fn'`
- `AttributeError: 'builtin_function_or_method' object has no attribute 'grad_fn'`
Expected outcome
No segfault 🙂
Additional information
I have tried and failed to reproduce this problem in two simpler scripts:

- A script which just uses matplotlib and pandas (no PyTorch, no pytorch-lightning); see the sketch after this list.
- A script which just uses matplotlib, pandas and PyTorch (no pytorch-lightning).
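The matplotlib-and-pandas-only attempt was roughly along these lines (a sketch, not the exact script); it ran to completion without crashing:

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

START_DATE = pd.Timestamp("2020-01-01")


def plot_data():
    """Same plotting code as in the reproduction script."""
    fig, axes = plt.subplots(ncols=4)
    N_PERIODS = 16
    x = pd.date_range(START_DATE, periods=N_PERIODS, freq="30 min")
    y = np.ones(N_PERIODS)
    for ax in axes:
        ax.plot(x, y)
    plt.close(fig)


# Calling plot_data() thousands of times in a plain loop never segfaulted;
# the crash only appears when the same code runs inside training_step.
for _ in range(10_000):
    plot_data()
```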
Operating system
Ubuntu 21.10
Matplotlib Version
3.5.1
Matplotlib Backend
QtAgg
Python version
3.9.10
Jupyter version
No response
Installation
conda