# Required so ArrayLike does not look terrible in docs
from __future__ import annotations
import copy
import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable
import numpy as np
import scipy.stats
from numpy.random import Generator
from numpy.typing import ArrayLike
from scipy.special import erf, erfinv
from scipy.stats import loguniform, norm, rv_continuous, truncnorm, uniform
import simpple.utils as ut
def _handle_scipy_yaml_dict(yaml_dict: dict) -> dict:
    """Resolve the scipy distribution name of a YAML dict into a ``scipy.stats`` object.

    The distribution name is searched, in order, in:

    1. the first element of ``args`` when ``args`` is a non-empty list or tuple,
    2. the ``"dist"`` key of ``args`` when ``args`` is a dict,
    3. the ``"dist"`` key of ``kwargs``.

    A missing or ``None`` ``args``/``kwargs`` entry (what an empty YAML value
    loads as) is treated as empty. An entry that is not a string is assumed to
    be already resolved and is left untouched.

    :param yaml_dict: Yaml dictionary, typically loaded directly from a file
    :return: A deep copy of ``yaml_dict`` with the scipy distribution name replaced
             by the attribute of the same name in ``scipy.stats``. Tuple ``args``
             are converted to a list in the copy.
    :raises TypeError: If ``args`` is not a list, a tuple or a dict
    :raises ValueError: If no distribution is specified in ``args`` or ``kwargs``
    :raises AttributeError: If the name does not exist in ``scipy.stats``
    """
    # Work on a copy so the caller's dictionary is never modified
    yaml_dict = copy.deepcopy(yaml_dict)

    args = yaml_dict.get("args")
    kwargs = yaml_dict.get("kwargs")
    if args is None:
        args = []
    if kwargs is None:
        kwargs = {}
    if not isinstance(args, (list, tuple, dict)):
        raise TypeError(
            "args for ScipyDistribution should be a list, tuple or a dict"
        )

    # Locate the container/key pair holding the distribution name
    if isinstance(args, dict):
        container, key = args, "dist"
        found = "dist" in args
    else:
        if isinstance(args, tuple):
            # Cannot assign to tuples so convert to list
            args = list(args)
            yaml_dict["args"] = args
        container, key = args, 0
        found = len(args) > 0
    if not found:
        # Nothing usable in args: fall back to the keyword arguments
        container, key = kwargs, "dist"
        found = "dist" in kwargs
    if not found:
        raise ValueError(
            "ScipyDistribution should have a distribution specified in 'args' or 'kwargs'"
        )

    name = container[key]
    if isinstance(name, str):
        try:
            container[key] = getattr(scipy.stats, name)
        except AttributeError as err:
            raise AttributeError(
                f"'{name}' was not found in scipy.stats"
            ) from err
    return yaml_dict
[docs]
class Distribution(ABC):
    """Abstract base class for distributions.

    Defines the interface distributions must implement.
    All distributions should have a `__repr__()` showing initialization parameters,
    as well as `log_prob()`, `prior_transform()` and `sample()` methods.
    """

    # NOTE(review): sample() calls self.prior_transform(), which is not declared
    # in the part of this class visible here; subclasses have to provide it.

    @abstractmethod
    def __repr__(self) -> str:
        """Return a string showing the initialization parameters."""
        pass

    def __eq__(self, other):
        """Compare two distributions through their instance attributes.

        Objects that are not instances of this distribution's class are never equal.
        """
        if not isinstance(other, self.__class__):
            return False
        return ut.make_hashable(self.__dict__) == ut.make_hashable(other.__dict__)

    def __hash__(self):
        """Hash derived from the instance attributes, consistent with ``__eq__``."""
        return hash(ut.make_hashable(self.__dict__))

    @property
    def required_args(self) -> list[str]:
        """List of required arguments at initialization, generated with :func:`simpple.utils.find_args`"""
        return ut.find_args(self, argtype="args")

    @property
    def optional_args(self) -> list[str]:
        """List of optional (keyword) arguments at initialization, generated with :func:`simpple.utils.find_args`"""
        return ut.find_args(self, argtype="kwargs")

    @classmethod
    def from_yaml_dict(cls, yaml_dict: dict):
        """Create a distribution from a YAML dictionary

        The dictionary should have a ``dist`` key with the class name,
        an ``args`` key with a list of required arguments
        and optionally a ``kwargs`` key with optional arguments.
        Missing or empty (``None``) ``args`` and ``kwargs`` entries are treated
        as "no arguments".

        See also: :doc:`Writing Models to and from YAML Files <../tutorials/yaml>`.

        :param yaml_dict: YAML dictionary typically directly loaded from a file
        :return: A distribution of the type specified by the "dist" key in the dictionary
        :raises KeyError: If the ``dist`` key is missing or names an unknown distribution
        :raises TypeError: If ``args`` is not a list, a tuple or a dict
        """
        if "dist" not in yaml_dict:
            raise KeyError(
                "Distribution dictionaries should have a 'dist' key for the distribution type"
            )
        dist = yaml_dict["dist"]

        # Used below as a mapping from class name to class
        all_distributions = ut.get_subclasses(Distribution)
        if dist not in all_distributions:
            raise KeyError(
                f"Distribution name '{dist}' not found. Available distributions are {list(all_distributions)}"
            )

        if dist == "ScipyDistribution":
            # The scipy distribution is stored by name and must be turned
            # back into a Python object before instantiation
            yaml_dict = _handle_scipy_yaml_dict(yaml_dict)

        dist_cls = all_distributions[dist]

        # An empty ``args:`` or ``kwargs:`` entry in a YAML file loads as None
        args = yaml_dict.get("args")
        if args is None:
            args = []
        kwargs = yaml_dict.get("kwargs")
        if kwargs is None:
            kwargs = {}

        if isinstance(args, (list, tuple)):
            return dist_cls(*args, **kwargs)
        elif isinstance(args, dict):
            return dist_cls(**args, **kwargs)
        else:
            raise TypeError(
                "Distribution arguments should be a list, a tuple or a dict "
            )

    def to_yaml_dict(self) -> dict:
        """Generate YAML dict

        The dictionary will have a ``dist`` key with the class name,
        an ``args`` key with a list of required arguments
        and a ``kwargs`` key with optional arguments.
        :attr:`~required_args` and :attr:`~optional_args` are used to infer the ``args`` and ``kwargs``.

        See also: :doc:`Writing Models to and from YAML Files <../tutorials/yaml>`.

        :return: YAML dictionary describing the distribution.
        """
        yaml_dict = {}
        yaml_dict["dist"] = self.__class__.__name__
        # Argument values are read from instance attributes of the same name
        yaml_dict["args"] = [getattr(self, arg) for arg in self.required_args]
        yaml_dict["kwargs"] = {
            kwarg: getattr(self, kwarg) for kwarg in self.optional_args
        }
        return yaml_dict

    @abstractmethod
    def log_prob(self, x: float | ArrayLike) -> float | np.ndarray:
        """Log probability density of the distribution

        This should be implemented by subclasses.
        """
        pass

    def sample(
        self,
        size: int | None = None,
        seed: int | Generator | np.ndarray[int] | None = None,
    ) -> np.ndarray:
        """Draw samples from the distribution

        Generates random uniform samples and calls the prior transform.

        :param size: Shape of the samples
        :param seed: Random seed to use. Passed to :func:`numpy.random.default_rng`.
        :return: Random samples with shape `size`
        """
        rng = np.random.default_rng(seed=seed)
        u = rng.uniform(low=0.0, high=1.0, size=size)
        p = self.prior_transform(u)
        return p
[docs]
class ScipyDistribution(Distribution):
    """Distribution based on a scipy random variable.

    .. note::
        While this class can be convenient, custom ``log_prob()`` and ``prior_transform()`` functions written with Numpy are usually faster.

    :param dist: Scipy random variable (RV), either already instantiated or not.
                 If the RV is not instantiated, it will be during init and all args
                 and kwargs are passed to it.
    :raises TypeError: If ``dist`` is neither a frozen distribution nor an
                       instance of :class:`scipy.stats.rv_continuous`
    """

    def __init__(self, dist: rv_continuous | Callable, *args, **kwargs):
        # Objects exposing both ``dist`` and ``args`` are treated as frozen
        # (already instantiated) scipy distributions
        if hasattr(dist, "dist") and hasattr(dist, "args"):
            if len(kwargs) > 0 or len(args) > 0:
                warnings.warn(
                    "The scipy distribution is 'frozen' (already instantiated). "
                    "Extra arguments will be ignored",
                    category=RuntimeWarning,
                    stacklevel=2,
                )
            self._dist = dist
        elif isinstance(dist, rv_continuous):
            if len(args) == 0 and len(kwargs) == 0:
                warnings.warn(
                    "The scipy distribution has no arguments and will use the default parameters. "
                    "Make sure this is what you want.",
                    category=RuntimeWarning,
                    stacklevel=2,
                )
            # Calling the RV with its parameters freezes it
            self._dist = dist(*args, **kwargs)
        else:
            raise TypeError(f"Invalid type {type(dist)} for the scipy distribution.")

    @property
    def dist(self):
        """Underlying scipy distribution"""
        return self._dist

    def __repr__(self):
        """Show the scipy distribution name with its positional and keyword arguments."""
        frozen = self.dist
        # Build the signature by hand: formatting a tuple directly would quote
        # the ``key=value`` strings and add a trailing comma for a single item
        parts = [repr(arg) for arg in frozen.args]
        parts.extend(f"{key}={value!r}" for key, value in frozen.kwds.items())
        return f"ScipyDistribution({frozen.dist.name}({', '.join(parts)}))"

    def to_yaml_dict(self) -> dict:
        """Generate YAML dict

        Serves the same purpose as :meth:`Distribution.to_yaml_dict()`, but the arguments
        and keyword arguments are inferred from the :attr:`~dist` attribute.

        :return: YAML dictionary describing the distribution
        """
        # Subclasses have their own init signature, so the generic
        # implementation based on init arguments applies to them
        if type(self) is not ScipyDistribution:
            return Distribution.to_yaml_dict(self)
        yaml_dict = {}
        yaml_dict["dist"] = self.__class__.__name__
        # The scipy distribution name goes first, followed by its own arguments
        yaml_dict["args"] = [self.dist.dist.name] + list(self.dist.args)
        # Copy so that editing the YAML dict cannot alter the frozen distribution
        yaml_dict["kwargs"] = dict(self.dist.kwds)
        return yaml_dict

    def log_prob(self, x: float | ArrayLike, *args, **kwargs) -> np.ndarray:
        """Log probability density function.

        Calls the scipy distribution's `logpdf()`.
        Extra positional and keyword arguments are forwarded to it.

        :param x: Value(s) at which to evaluate the log probability.
        :return: Log probability value(s)
        """
        return self.dist.logpdf(x, *args, **kwargs)

    def sample(
        self,
        size: int | None = None,
        seed: int | Generator | np.ndarray[int] | None = None,
        **kwargs,
    ) -> np.ndarray:
        """Draw samples from the distribution

        Calls the scipy distribution's ``rvs()`` method.
        Extra keyword arguments are forwarded to it.

        :param size: Shape of the samples
        :param seed: Random seed to use (int or numpy generator)
        :return: Random samples with shape `size`
        """
        # Anything that is not already a generator is used to seed a new one
        if not isinstance(seed, Generator):
            seed = np.random.default_rng(seed=seed)
        return self.dist.rvs(size=size, random_state=seed, **kwargs)
# TODO: Consider not making most distributions scipy subclasses and just using a ``dist`` attribute instead?
# One downside is that the workarounds implemented here (e.g. ``to_yaml_dict()`` being different for scipy subclasses) can be unexpected if users subclass a distribution.
# So, regardless, we need to work around that.
# We would then need to ignore or handle ``_dist`` in comparisons.
[docs]
class Normal(ScipyDistribution):
    r"""Normal distribution

    .. math::
        p(x) = \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left( -\frac{(x - \mu)^2}{2\sigma^2} \right)

    :param mu: Mean :math:`\mu`
    :param sigma: Standard deviation :math:`\sigma`
    :raises ValueError: If ``sigma`` is not strictly positive
    """

    def __init__(self, mu: float, sigma: float):
        self.mu = mu
        self.sigma = sigma
        if self.sigma <= 0.0:
            raise ValueError(
                "Standard deviation sigma must be positive for normal distribution"
            )
        # Hand an already frozen scipy normal to the parent class
        super().__init__(norm(self.mu, self.sigma))

    def log_prob(self, x: float | ArrayLike, *args, **kwargs) -> np.ndarray:
        """Log probability density function of the normal distribution

        Computed directly with Numpy instead of going through scipy.
        Extra positional and keyword arguments are accepted for signature
        compatibility with the parent class and are ignored.

        :param x: Value(s) at which to evaluate the log probability.
        :return: Log probability value(s)
        """
        # Lists and tuples do not support arithmetic with floats
        x = np.asarray(x)
        return -0.5 * (
            np.log(2 * np.pi * self.sigma**2) + ((x - self.mu) / self.sigma) ** 2
        )

    def __repr__(self) -> str:
        return f"Normal(mu={self.mu}, sigma={self.sigma})"
[docs]
class TruncatedNormal(ScipyDistribution):
    r"""Truncated normal distribution

    .. math::
        f(x; \mu, \sigma, a, b) =
        \begin{cases}
        \dfrac{1}{\sigma} \dfrac{\phi\left(\frac{x - \mu}{\sigma}\right)}
        {\Phi\left(\frac{b - \mu}{\sigma}\right) - \Phi\left(\frac{a - \mu}{\sigma}\right)}
        & \text{if } a \leq x \leq b \\
        0 & \text{otherwise}
        \end{cases}

    where :math:`\phi` is the standard normal distribution and :math:`\Phi` the standard normal CDF.

    See `Truncated normal distribution on Wikipedia <https://en.wikipedia.org/wiki/Truncated_normal_distribution>`_.

    :param mu: Mean :math:`\mu` of the underlying normal distribution
    :param sigma: Standard deviation :math:`\sigma` of the underlying normal distribution
    :param low: Lower bound :math:`a`. ``None`` means no lower bound (``-np.inf``).
    :param high: Upper bound :math:`b`. ``None`` means no upper bound (``np.inf``).
    :raises ValueError: If ``sigma`` is not strictly positive or if ``low`` is
                        not strictly lower than ``high``
    """

    def __init__(
        self,
        mu: float,
        sigma: float,
        low: float | None = None,
        high: float | None = None,
    ):
        self.mu = mu
        self.sigma = sigma
        self.low = -np.inf if low is None else low
        self.high = np.inf if high is None else high
        if self.sigma <= 0.0:
            raise ValueError(
                "Standard deviation sigma must be positive for truncated normal distribution"
            )
        # Equal bounds give a zero-width support whose normalization is log(0)
        if self.low >= self.high:
            raise ValueError(
                "lower bound should be lower than upper bound for truncated normal distribution"
            )
        # scipy's truncnorm expects bounds in units of sigma relative to the mean
        a, b = (self.low - self.mu) / self.sigma, (self.high - self.mu) / self.sigma
        super().__init__(truncnorm(a=a, b=b, loc=self.mu, scale=self.sigma))

    def __repr__(self) -> str:
        # Infinite bounds are the defaults and are left out of the signature
        arg_signature = f"mu={self.mu}, sigma={self.sigma}"
        if np.isfinite(self.low):
            arg_signature += f", low={self.low}"
        if np.isfinite(self.high):
            arg_signature += f", high={self.high}"
        return f"TruncatedNormal({arg_signature})"

    def log_prob(
        self, x: float | ArrayLike, *args, **kwargs
    ) -> float | np.ndarray:
        """Log probability density function.

        Both bounds are part of the support: values with ``low <= x <= high``
        get a finite log probability, everything else gets ``-np.inf``.
        Extra positional and keyword arguments are accepted for signature
        compatibility with the parent class and are ignored.

        :param x: Value(s) at which to evaluate the log probability.
        :return: Log probability value(s). A Python float for scalar input.
        """
        # Lists and tuples do not support arithmetic with floats
        x = np.asarray(x)
        # Standard normal CDF evaluated at both bounds
        phi_low = 0.5 * (1 + erf((self.low - self.mu) / self.sigma / np.sqrt(2)))
        phi_high = 0.5 * (1 + erf((self.high - self.mu) / self.sigma / np.sqrt(2)))
        # Log density of the untruncated normal
        log_norm = -0.5 * (
            np.log(2 * np.pi * self.sigma**2) + ((x - self.mu) / self.sigma) ** 2
        )
        in_support = np.logical_and(
            np.greater_equal(x, self.low), np.less_equal(x, self.high)
        )
        lp = np.where(
            in_support,
            log_norm - np.log(phi_high - phi_low),
            -np.inf,
        )
        if lp.ndim == 0:
            return lp.item()
        return lp
[docs]
class Fixed(Distribution):
    """Fixed distribution

    This is essentially a delta function centered on ``value``.
    Is handled as a special case in :class:`simpple.model.Model` to avoid inefficient sampling.

    :param value: Value to which the parameter is fixed.
    """

    def __init__(self, value: float):
        self.value = value

    def __repr__(self) -> str:
        return f"Fixed(value={self.value})"

    def log_prob(self, x: float | ArrayLike) -> float | np.ndarray:
        """Log probability of a fixed variable.

        Returns ``np.inf`` at ``value`` and ``-np.inf`` elsewhere.

        :param x: Value(s) at which to evaluate the log probability.
        :return: Log probability value(s). A Python float for scalar input.
        """
        # Convert first so that lists and tuples are compared element by element
        lp = np.where(
            np.asarray(x) == self.value,
            np.inf,
            -np.inf,
        )
        if lp.ndim == 0:
            return lp.item()
        return lp