from __future__ import annotations
from pathlib import Path
from collections.abc import Callable
from typing import TYPE_CHECKING
import numpy as np
import yaml
from numpy.random import Generator
from numpy.typing import ArrayLike
from simpple.distributions import Distribution, Fixed
from simpple.load import (
get_func_str,
parse_parameters,
resolve,
unparse_parameters,
)
import simpple.utils as ut
if TYPE_CHECKING:
from nautilus import Prior
[docs]
class Model:
"""Simpple model
.. note::
Subclasses should call ``super().__init__()`` to ensure that the ``fixed_p`` and ``vary_p`` attributes are properly set.
See :doc:`Writing Model Classes <../tutorials/writing-model-classes>`.
:param parameters: dictionary of parameters mapping parameter names to a
prior (``simpple.Distribution`` object).
:param log_likelihood: log-likelihood function that accepts a dictionary of parameters.
"""
def __init__(
self,
parameters: dict[str, Distribution],
log_likelihood: Callable | None = None,
):
self.parameters = parameters
if log_likelihood is not None:
self._log_likelihood = log_likelihood
# Attempt to make log-likelihood inherit docstring.
# Works at runtime but not for static (LSP) tools
self.log_likelihood.__func__.__doc__ = self._log_likelihood.__doc__
self.fixed_p = {}
self.fixed_p_vals = {}
self.vary_p = {}
for pname, pdist in self.parameters.items():
if isinstance(pdist, Fixed):
self.fixed_p[pname] = pdist
self.fixed_p_vals[pname] = pdist.value
else:
self.vary_p[pname] = pdist
def __repr__(self):
class_name = self.__class__.__name__
prior_str = f"parameters={self.parameters}"
try:
likelihood_name = getattr(
self._log_likelihood, "__name__", "_log_likelihood"
)
likelihood_str = f", log_likelihood={likelihood_name}"
except RecursionError:
likelihood_str = ""
return f"{class_name}({prior_str}{likelihood_str})"
def __eq__(self, other):
if not isinstance(other, self.__class__):
return False
return ut.make_hashable(self.__dict__) == ut.make_hashable(other.__dict__)
def __hash__(self):
return hash(ut.make_hashable(self.__dict__))
@property
def required_args(self) -> list[str]:
"""List of required arguments at initialization, generated with :func:`simpple.utils.find_args`"""
return ut.find_args(self, argtype="args")
@property
def optional_args(self) -> list[str]:
"""List of optional (keyword) arguments at initialization, generated with :func:`simpple.utils.find_args`"""
return ut.find_args(self, argtype="kwargs")
[docs]
@classmethod
def from_yaml(cls, path: Path | str, *args, **kwargs) -> Model:
"""Create a model from a YAML file
The YAML file should have:
- A ``class`` field mapping to the name of the Model class
- A ``parameters`` field mapping parameter names to YAML distribution specs.
See also :meth:`simpple.distributions.Distribution.from_yaml_dict()`.
- An ``args`` field listing the required arguments for initialization.
Required only for model subclasses that accept arguments.
- A ``kwargs`` field listing the keyword arguments for initialization.
Accepted only for model subclasses that accept keyword arguments.
``log_likelihood`` and ``forward`` are treated as special keyword arguments
that refer to functions and we try to resolve them with :func:`simpple.load.resolve`.
**Note**: All extra arguments (``args``) and keyword arguments (``kwargs``)
are passed to this function are passed to the model initialization,
and they override their YAML counterpart.
See also: :doc:`Writing Models to and from YAML Files <../tutorials/yaml>`.
:param path: Path of the YAML file
:return: Model object
"""
with open(path) as f:
mdict = yaml.safe_load(f)
parameters = parse_parameters(mdict["parameters"])
yaml_args = mdict.get("args", [])
if len(args) == 0 and len(yaml_args) > 0:
args = yaml_args
kwargs = mdict.get("kwargs", {}) | kwargs
func_kwargs = ["log_likelihood", "forward"]
for kwarg in func_kwargs:
if kwarg in kwargs and isinstance(kwargs[kwarg], str):
kwargs[kwarg] = resolve(kwargs[kwarg])
model_classes = ut.get_subclasses(Model)
class_name = mdict.get("class", None)
if class_name is not None and class_name != "Model":
model_cls = model_classes[class_name]
else:
model_cls = cls
return model_cls(parameters, *args, **kwargs)
def to_yaml(self, path: Path | str, overwrite: bool = False):
model_dict = {}
model_dict["class"] = self.__class__.__name__
model_dict["parameters"] = unparse_parameters(self.parameters)
func_kwargs = ["log_likelihood", "forward"]
args_list = []
for arg in self.required_args:
arg_attr = getattr(self, arg)
if arg in func_kwargs:
arg_attr = get_func_str(arg_attr)
args_list.append(arg_attr)
kwargs_dict = {}
for kwarg in self.optional_args:
if kwarg in func_kwargs:
kwarg_attr = getattr(self, f"_{kwarg}")
kwarg_attr = get_func_str(kwarg_attr)
else:
kwarg_attr = getattr(self, kwarg)
kwargs_dict[kwarg] = kwarg_attr
if len(args_list) > 0:
model_dict["args"] = args_list
if len(kwargs_dict) > 0:
model_dict["kwargs"] = kwargs_dict
path = Path(path)
if path.exists() and not overwrite:
raise FileExistsError(
f"The file {path} already exists. Use overwrite=True to overwrite it."
)
with open(path, mode="w") as f:
yaml.dump(model_dict, f)
def _log_likelihood(self, parameters, *args, **kwargs) -> float:
raise NotImplementedError(
"log_likelihood must be passed to init or _log_likelihood must be "
"implemented by subclasses."
)
@property
def ndim(self):
"""Number of dimensions (**variable** parameters) in the model"""
return len(self.keys())
[docs]
def keys(self, fixed: bool = False) -> list[str]:
"""Get the ordered list of parameter names
:param fixed: Whether fixed parameters should be included. Defaults to ``False``.
:return: List of parameter names
"""
if fixed:
return list(self.parameters.keys())
else:
return list(self.vary_p.keys())
[docs]
def log_likelihood(self, parameters: dict | ArrayLike, *args, **kwargs) -> float:
"""Calculate the log-likelihood of the model
Internally, this wraps the ``log_likelihood`` function that was given at initialization
and passes all arguments to it.
:param parameters: Dictionary or array of parameters. If an array is used,
the order must be the same as ``Model.keys()``, without the fixed parameters.
Dictionaries can optionally include fixed parameter.s
:return: The log-likelihood value at ``parameters``.
"""
if not isinstance(parameters, dict):
parameters = dict(zip(self.keys(), parameters, strict=True))
# RHS has precedence
parameters = self.fixed_p_vals | parameters
return self._log_likelihood(parameters, *args, **kwargs)
[docs]
def log_prior(self, parameters: dict | ArrayLike) -> float:
"""Log of the prior probability for the model
:param parameters: Dictionary or array of parameters. If an array is used,
the order must be the same as ``Model.keys()``, without the fixed parameters.
Fixed parameters are always ignored even in dictionaries.
:return: Log-prior probability
"""
if not isinstance(parameters, dict):
parameters = dict(zip(self.keys(), parameters, strict=True))
lp = 0.0
for pname, pval in parameters.items():
# Skip fixed parameters
if pname in self.fixed_p:
continue
pdist = self.parameters[pname]
lp += pdist.log_prob(pval)
return lp
[docs]
def nautilus_prior(self) -> "Prior":
"""Builds and return a ``nautilus.Prior`` for the model.
Fixed parameters are not included.
:return: Nautilus Prior object.
"""
from nautilus import Prior
prior = Prior()
for pname, pdist in self.parameters.items():
if isinstance(pdist, Fixed):
continue
if not hasattr(pdist, "dist"):
raise AttributeError(
f"distribution {pdist} for parameter {pname} has no scipy distribution. "
f"This is required to build a Nautilus prior. "
"Either use the prior_transform() with nautilus or add a 'dist' attribute with the scipy distribution."
)
prior.add_parameter(pname, pdist.dist)
return prior
[docs]
def log_prob(self, parameters: dict | ArrayLike, *args, **kwargs) -> float:
"""Log posterior probability for the model
All extra arguments are passed to the ``self.log_likelihood()``.
:param parameters: Dictionary or array of parameters. If an array is used,
the order must be the same as ``Model.keys()``, without the fixed parameters.
Dictionaries can optionally include fixed parameter.s
:return: Log posterior probability
"""
if not isinstance(parameters, dict):
parameters = dict(zip(self.keys(), parameters, strict=True))
# RHS has precedence
parameters = self.fixed_p_vals | parameters
lp = self.log_prior(parameters)
if np.isfinite(lp):
return lp + self.log_likelihood(parameters, *args, **kwargs)
return lp
[docs]
def get_prior_samples(
self,
n_samples: int,
fmt: str = "dict",
seed: int | Generator | np.ndarray[int] | None = None,
fixed: bool = False,
) -> dict | np.ndarray:
"""Generate prior samples for all parameters
:param n_samples: Number of samples
:param fmt: Format of the samples (dict or array)
:param fixed: Whether fixed parameters should be included. Defaults to False.
:return: Dictionary of prior samples
"""
rng = np.random.default_rng(seed=seed)
u = rng.uniform(size=(len(self.keys(fixed=fixed)), n_samples))
if fmt == "dict":
u = dict(zip(self.keys(fixed=fixed), u, strict=True))
elif fmt != "array":
raise ValueError(f"Invalid format: {fmt}. Use 'dict' or 'array'.")
return self.prior_transform(u, fixed=fixed)
[docs]
class ForwardModel(Model):
"""A model whose likelihood calls a forward model as the mean.
.. note::
Subclasses should call ``super().__init__()`` to ensure that the ``fixed_p`` and ``vary_p`` attributes are properly set.
See :doc:`Writing Model Classes <../tutorials/writing-model-classes>`.
:param parameters: dictionary of parameters mapping parameter names to a
prior (``simpple.Distribution`` object).
:param log_likelihood: log-likelihood function that accepts a dictionary of parameters. This function should call ``forward``.
:param forward: Forward model function that accepts a dictionary of parameters as first argument.
"""
def __init__(
self,
parameters: dict,
log_likelihood: Callable | None = None,
forward: Callable | None = None,
):
super().__init__(parameters, log_likelihood=log_likelihood)
if forward is not None:
self._forward = forward
self.forward.__func__.doc__ = self._forward.__doc__
def __repr__(self):
class_name = self.__class__.__name__
prior_str = f"parameters={self.parameters}"
try:
likelihood_name = getattr(
self._log_likelihood, "__name__", "_log_likelihood"
)
likelihood_str = f", log_likelihood={likelihood_name}"
except RecursionError:
likelihood_str = ""
try:
forward_name = getattr(self._forward, "__name__", "_forward")
forward_str = f", forward={forward_name}"
except RecursionError:
forward_str = ""
return f"{class_name}({prior_str}{likelihood_str}{forward_str})"
def _forward(self, parameters, *args, **kwargs) -> float:
raise NotImplementedError(
"forward must be passed to init or _forward must be "
"implemented by subclasses."
)
[docs]
def forward(self, parameters: dict | ArrayLike, *args, **kwargs) -> np.ndarray:
"""Evaluate the forward model
Internally, this wraps the ``forward`` function that was given at initialization
and passes all arguments to it.
:param parameters: Dictionary or array of parameters. If an array is used,
the order must be the same as ``Model.keys()``, without the fixed parameters.
Dictionaries can optionally include fixed parameter.s
:return: The forward model evalulated at ``parameters``.
"""
if not isinstance(parameters, dict):
parameters = dict(zip(self.keys(), parameters, strict=True))
# RHS has precedence
parameters = self.fixed_p_vals | parameters
return self._forward(parameters, *args, **kwargs)
[docs]
def get_prior_pred(self, n_samples: int, *args, **kwargs) -> np.ndarray:
"""Get prior predictive samples
:param n_samples: Number of samples to generate
:return: Forward model realizations corresponding to prior samples
"""
prior_params = self.get_prior_samples(n_samples, fmt="array")
pred = []
for p in prior_params.T:
pred.append(self.forward(p, *args, **kwargs))
return np.array(pred)
[docs]
def get_posterior_pred(
self, chains: dict | ArrayLike, n_samples: int, *args, **kwargs
) -> np.ndarray:
"""Get posterior predictive samples
:param chains: Posterior chains organized as a dictionary or an array.
If an array, should have shape ``(ndim, nsamples)`` where
nsamples is the total number of samples in the chain and
unrelated to the ``n_samples`` argument below.
:param n_samples: Number of samples to pick from the chains.
:return: Forward model realizations corresponding to posterior samples
"""
if isinstance(chains, dict):
chains = np.array([a for a in chains.values()])
elif chains.ndim != 2 or chains.shape[0] != len(self.keys()):
raise ValueError("chains must have shape (ndim, nsamples)")
rng = np.random.default_rng()
show_idx = rng.choice(chains.shape[1], n_samples, replace=False)
pred = []
for i in show_idx:
pred.append(self.forward(chains[:, i], *args, **kwargs))
return np.array(pred)