# Source code for simpple.utils
import inspect
from typing import Any
import numpy as np
from scipy.stats._distn_infrastructure import rv_continuous_frozen
[docs]
def get_subclasses(cls: type) -> dict[str, type]:
    """Collect all direct and indirect subclasses of a class, keyed by name

    This is used internally when reading models and distributions from YAML.

    :param cls: Class whose subclass tree is walked
    :return: Dictionary mapping each subclass ``__name__`` to the subclass
        itself. Empty when ``cls`` has no subclasses. When two subclasses
        share a name, the entry merged later overwrites the earlier one.
    """
    children = cls.__subclasses__()
    # Direct subclasses are registered first; each child's own descendants
    # are merged in afterwards, in the order the children were reported.
    found = {child.__name__: child for child in children}
    for child in children:
        found.update(get_subclasses(child))
    return found
[docs]
def scipy_dist_to_dict(dist) -> dict:
    """Flatten the attributes of a scipy distribution into a dictionary

    Reads the ``__dict__`` attribute of ``dist`` and unrolls it one level:

    - the ``dist`` attribute is replaced by its type, stored under
      ``"scipy_dist_type"``
    - attributes that are dictionaries are expanded into one
      ``"scipy_dist_<attribute>_<key>"`` entry per item
    - every other attribute is stored as ``"scipy_dist_<attribute>"``

    Used internally by :func:`make_hashable`.

    :param dist: Any scipy distribution created through scipy.stats.
    :return: Dictionary with the distribution attributes
    """
    flat = {}
    for attr, value in dist.__dict__.items():
        if attr == "dist":
            # Only the type of the underlying distribution is kept.
            flat["scipy_dist_type"] = type(value)
        elif isinstance(value, dict):
            flat.update(
                {f"scipy_dist_{attr}_{kwd}": item for kwd, item in value.items()}
            )
        else:
            flat[f"scipy_dist_{attr}"] = value
    return flat
[docs]
def make_hashable(obj: Any) -> tuple | bytes | Any:
"""Return an hashable version of an object
Used internally to hash and compare models and distributions.
Handles the following objects:
- dict is converted to a tuple recursively
- list and tuples are converted to tuples recursively
- numpy arrays are converted to bytes with ``.tobytes()``
- Scipy distribution are converted to dictionaries with :func:`scipy_dist_to_dict`
and then made hashable.
- Any other object is returned as is
:param obj: Object to be made hashable
:return: Hashable version of the object
"""
# vibe-coded
if isinstance(obj, dict):
return tuple(sorted((k, make_hashable(v)) for k, v in obj.items()))
elif isinstance(obj, (list, tuple)):
return tuple(make_hashable(x) for x in obj)
elif isinstance(obj, np.ndarray):
return obj.tobytes()
elif isinstance(obj, rv_continuous_frozen):
return make_hashable(scipy_dist_to_dict(obj))
else:
return obj
[docs]
def find_args(obj: Any, argtype: str = "args") -> list[str]:
    """Find arguments that an object requires at initialization

    If ``obj`` is callable (e.g. a class or a function), its own signature is
    inspected. Otherwise the ``__init__`` method of the object class is used.

    Parameters named ``self``, ``args``, ``kwargs`` or ``parameters`` are
    always skipped, as are variadic parameters (``*name`` / ``**name``)
    whatever their name, since they are neither required arguments nor
    keyword arguments with a default.

    Used internally by models and distributions.

    :param obj: Any object
    :param argtype: Type of argument desired: ``"args"`` for parameters
        without a default value, ``"kwargs"`` for parameters with one
    :return: List of argument names
    :raises ValueError: If ``argtype`` is not ``"args"`` or ``"kwargs"``
    """
    # Validate up front so a bad argtype is reported even when the signature
    # has no parameter that would reach the per-parameter check.
    if argtype not in ("args", "kwargs"):
        raise ValueError("argtype must be one of 'args' or 'kwargs'")
    want_required = argtype == "args"

    func = obj if callable(obj) else obj.__class__.__init__
    sig = inspect.signature(func)

    ignored_args = {"self", "args", "kwargs", "parameters"}
    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )
    return [
        pname
        for pname, pval in sig.parameters.items()
        if pname not in ignored_args
        and pval.kind not in variadic_kinds
        # "args" keeps parameters with no default, "kwargs" those with one.
        and (pval.default is pval.empty) == want_required
    ]