from __future__ import annotations
import torch
import torch.nn as nn
from torch import Tensor
import warnings
from typing import Any
from numpy import ndarray
from tqdm import tqdm
from deepinv.utils.compat import zip_strict
class SDEOutput(dict):
    r"""
    A container for storing the output of an SDE solver, that behaves like a `dict` but allows access with the attribute syntax.

    Attributes:

    :attr torch.Tensor sample: the final samples of the sampling process, of shape ``(B, C, H, W)``.
    :attr torch.Tensor trajectory: the trajectory of the sampling process, of shape ``(num_steps, B, C, H, W)`` if ``full_trajectory`` is ``True``, otherwise of shape ``(B, C, H, W)``.
    :attr torch.Tensor timesteps: the time steps at which the samples were taken, of shape ``(num_steps,)``.
    :attr int nfe: the number of function evaluations performed during the integration.
    """

    def __init__(self, sample: Tensor, trajectory: Tensor, timesteps: Tensor, nfe: int):
        # Keyword order fixes the key order: sample, trajectory, timesteps, nfe.
        super().__init__(
            sample=sample, trajectory=trajectory, timesteps=timesteps, nfe=nfe
        )

    def __getattr__(self, name: str) -> Any:
        # Python only calls __getattr__ after normal lookup fails, so real dict
        # methods (``keys``, ``items``, ...) still win over stored keys.
        try:
            return self[name]
        except KeyError:
            # ``from None`` hides the internal KeyError from the traceback.
            raise AttributeError(name) from None

    def __setattr__(self, name: str, value: Any) -> None:
        # Attribute assignment stores a dict entry instead of an instance attribute.
        self[name] = value

    def __delattr__(self, name: str) -> None:
        try:
            del self[name]
        except KeyError:
            # Attribute deletion of a missing name must raise AttributeError,
            # not KeyError, to respect the attribute-access protocol.
            raise AttributeError(name) from None
class BaseSDESolver(nn.Module):
    r"""
    Base class for solving Stochastic Differential Equations (SDEs) from :class:`deepinv.sampling.BaseSDE` of the form:

    .. math::

        d x_{t} = f(x_t, t) dt + g(t) d w_{t}

    where :math:`f` is the drift term, :math:`g` is the diffusion coefficient, and :math:`w_t` is a standard Brownian process.

    Currently only supported for fixed time steps for numerical integration.

    :param torch.Tensor, numpy.ndarray, list timesteps: time steps at which the SDE will be discretized.
    :param torch.Generator rng: a random number generator for reproducibility, optional.
    """

    def __init__(
        self,
        timesteps: Tensor | ndarray | list,
        rng: torch.Generator | None = None,
    ):
        super().__init__()
        self.timesteps = self._to_timesteps_tensor(timesteps)
        self.rng = rng
        if rng is not None:
            # Snapshot the generator state so that `reset_rng` can restore it.
            self.initial_random_state = rng.get_state()
            self.timesteps = self.timesteps.to(rng.device)

    @staticmethod
    def _to_timesteps_tensor(timesteps) -> Tensor:
        r"""
        Convert time steps given as a tensor, NumPy array or sequence to a :class:`torch.Tensor`.

        Tensors are returned unchanged. NumPy arrays are copied before :func:`torch.from_numpy`,
        so the result does not share memory with the caller's array (this also makes
        negatively-strided arrays, e.g. from ``np.flip``, acceptable). Anything else, such as
        a list, goes through :func:`torch.as_tensor`.

        :param torch.Tensor, numpy.ndarray, list timesteps: the time steps to convert.
        :return: the time steps as a tensor.
        :rtype: torch.Tensor
        """
        if isinstance(timesteps, Tensor):
            return timesteps
        if isinstance(timesteps, ndarray):
            return torch.from_numpy(timesteps.copy())
        return torch.as_tensor(timesteps)

    def step(
        self, sde, t0: float, t1: float, x0: Tensor, *args, **kwargs
    ) -> tuple[Tensor, int]:
        r"""
        Perform a single step with step size from time `t0` to time `t1`, with current state `x0`.

        Subclasses must implement this method.

        :param deepinv.sampling.BaseSDE sde: the SDE to solve.
        :param float or torch.Tensor t0: Time at the start of the step, of size (,).
        :param float or torch.Tensor t1: Time at the end of the step, of size (,).
        :param torch.Tensor x0: Current state of the system, of size (batch_size, d).
        :return: a tuple ``(x1, nfe)`` holding the updated state of the system after the step
            and the number of function evaluations used by the step.
        :rtype: tuple[torch.Tensor, int]
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    @torch.no_grad()
    def sample(
        self,
        sde,
        x_init: Tensor,
        seed: int | None = None,
        *args,
        timesteps: Tensor | ndarray | list | None = None,
        get_trajectory: bool = False,
        verbose: bool = False,
        **kwargs,
    ) -> SDEOutput:
        r"""
        Solve the Stochastic Differential Equation (SDE) with given time steps.

        This function iteratively applies the SDE solver step for each time interval
        defined by the provided timesteps.

        :param deepinv.sampling.BaseSDE sde: the SDE to solve.
        :param torch.Tensor x_init: The initial state of the system.
        :param int seed: The seed for the random number generator, if `rng` is provided.
        :param torch.Tensor, numpy.ndarray, list timesteps: A sequence of time points at which to solve the SDE. If None, the timesteps given at construction are used.
        :param bool get_trajectory: whether to return the full trajectory of the SDE or only the last sample, optional. Default to False.
        :param bool verbose: whether to display a progress bar during the sampling process, optional. Default to False.
        :param \*args: Variable length argument list to be passed to the step function.
        :param \*\*kwargs: Arbitrary keyword arguments to be passed to the step function.
        :return: SDEOutput
        """
        self.rng_manual_seed(seed)
        if timesteps is None:
            timesteps = self.timesteps
        # Accept tensors, NumPy arrays and plain sequences alike, then match the SDE.
        timesteps = self._to_timesteps_tensor(timesteps).to(sde.device, sde.dtype)

        x = x_init
        nfe = 0
        trajectory = [x_init.clone()] if get_trajectory else []
        for t_cur, t_next in tqdm(
            zip_strict(timesteps[:-1], timesteps[1:]),
            total=len(timesteps) - 1,
            disable=not verbose,
        ):
            x, cur_nfe = self.step(sde, t_cur, t_next, x, *args, **kwargs)
            nfe += cur_nfe
            if get_trajectory:
                trajectory.append(x.clone())

        # Without `get_trajectory`, the "trajectory" field is just the final sample.
        trajectory = torch.stack(trajectory, dim=0) if get_trajectory else x
        return SDEOutput(
            sample=x, trajectory=trajectory, timesteps=timesteps, nfe=nfe
        )

    def rng_manual_seed(self, seed: int | None = None):
        r"""
        Sets the seed for the random number generator.

        :param int seed: the seed to set for the random number generator. If not provided, the current state of the random number generator is used.

        Note: it will be ignored, with a warning, if the random number generator is not initialized.
        """
        if seed is not None:
            if self.rng is not None:
                self.rng = self.rng.manual_seed(seed)
            else:
                warnings.warn(
                    "Cannot set seed for random number generator because it is not initialized. The `seed` parameter is ignored."
                )

    def reset_rng(self):
        r"""
        Reset the random number generator to the state it had when the solver was constructed.

        Note: it will be ignored, with a warning, if the random number generator is not initialized.
        """
        if self.rng is None:
            # Same policy as `rng_manual_seed`: nothing to reset, so warn instead of crashing.
            warnings.warn(
                "Cannot reset random number generator because it is not initialized. The call is ignored."
            )
            return
        self.rng.set_state(self.initial_random_state)

    def randn_like(self, input: torch.Tensor, seed: int | None = None):
        r"""
        Equivalent to :func:`torch.randn_like` but supports a pseudorandom number generator argument.

        :param torch.Tensor input: The input tensor whose size will be used.
        :param int seed: The seed for the random number generator, if `rng` is provided.
        :return: A tensor of the same size as input filled with random numbers from a normal distribution.
        :rtype: torch.Tensor

        This method uses the `rng` attribute of the class, which is a pseudo-random number generator
        for reproducibility. If a seed is provided, it will be used to set the state of `rng` before
        generating the random numbers.

        .. note::

            The `rng` attribute must be initialized for this method to be reproducible;
            otherwise the global PyTorch generator is used.
        """
        self.rng_manual_seed(seed)
        return torch.empty_like(input).normal_(generator=self.rng)
class EulerSolver(BaseSDESolver):
    r"""
    Euler-Maruyama solver for SDEs.

    This solver uses the Euler-Maruyama method to numerically integrate SDEs. It is a first-order method that
    approximates the solution using the following update rule:

    .. math::

        x_{t+dt} = x_t + f(x_t,t)dt + g(t) W_{dt}

    where :math:`W_{dt}` is a Gaussian random variable with mean 0 and variance :math:`dt`.

    :param torch.Tensor timesteps: The time steps at which to evaluate the solution.
    :param torch.Generator rng: A random number generator for reproducibility.
    """

    def __init__(self, timesteps, rng: torch.Generator = None):
        super().__init__(timesteps, rng=rng)

    def step(self, sde, t0, t1, x0: Tensor, *args, **kwargs):
        r"""
        Advance ``x0`` from ``t0`` to ``t1`` with a single Euler-Maruyama update.

        The Gaussian increment is drawn first (via :meth:`randn_like`), then ``sde.discretize``
        is evaluated once at ``(x0, t0)``.

        :return: tuple ``(x1, 1)``: the new state and the number of function evaluations.
        """
        h = abs(t1 - t0)
        noise = self.randn_like(x0) * h**0.5
        f, g = sde.discretize(x0, t0, *args, **kwargs)
        x1 = x0 + f * h + g * noise
        return x1, 1
class HeunSolver(BaseSDESolver):
    r"""
    Heun solver for SDEs.

    This solver uses the second-order Heun method to numerically integrate SDEs, defined as:

    .. math::

        \tilde{x}_{t+dt} &= x_t + f(x_t,t)dt + g(t) W_{dt} \\
        x_{t+dt} &= x_t + \frac{1}{2}[f(x_t,t) + f(\tilde{x}_{t+dt},t+dt)]dt + \frac{1}{2}[g(t) + g(t+dt)] W_{dt}

    where :math:`W_{dt}` is a Gaussian random variable with mean 0 and variance :math:`dt`.

    :param torch.Tensor timesteps: The time steps at which to evaluate the solution.
    :param torch.Generator rng: A random number generator for reproducibility.
    """

    def __init__(self, timesteps, rng: torch.Generator = None):
        super().__init__(timesteps, rng=rng)

    def step(self, sde, t0, t1, x0: Tensor, *args, **kwargs):
        r"""
        Advance ``x0`` from ``t0`` to ``t1`` with one Heun (predictor-corrector) update.

        A single Gaussian increment is drawn and shared by the predictor and the corrector.

        :return: tuple ``(x1, 2)``: the new state and the number of function evaluations.
        """
        h = abs(t1 - t0)
        noise = self.randn_like(x0) * h**0.5

        # Predictor: a plain Euler-Maruyama step evaluated at (x0, t0).
        f_start, g_start = sde.discretize(x0, t0, *args, **kwargs)
        predicted = x0 + f_start * h + g_start * noise

        # Corrector: average the coefficients at both ends, reusing the same noise.
        f_end, g_end = sde.discretize(predicted, t1, *args, **kwargs)
        mean_drift = 0.5 * (f_start + f_end)
        mean_diffusion = 0.5 * (g_start + g_end)
        updated = x0 + mean_drift * h + mean_diffusion * noise
        return updated, 2