Source code for simpple.load

import copy
import inspect
import sys
from collections.abc import Callable
from importlib import import_module
from pathlib import Path

import yaml

from simpple.distributions import Distribution


def parse_parameters(pdict: dict) -> dict[str, Distribution]:
    """Build distribution objects from a YAML-style parameter dictionary

    Every spec found in ``pdict`` is handed to
    :meth:`simpple.distributions.Distribution.from_yaml_dict`.

    :param pdict: Dictionary mapping parameter names to distribution specs
    :return: Dictionary mapping parameter names to
             :class:`simpple.distributions.Distribution` objects
    """
    # Parse a deep copy rather than the caller's dictionary, in case
    # ``from_yaml_dict`` modifies the spec it is given.
    specs = copy.deepcopy(pdict)
    return {
        pname: Distribution.from_yaml_dict(pspec) for pname, pspec in specs.items()
    }
def load_parameters(path: Path | str) -> dict[str, Distribution]:
    """Load parameter dictionary from YAML file

    The YAML file should contain a parameter specification consistent with
    :func:`parse_parameters`, either under the ``parameters`` key or at the
    top level. This allows users to read only the parameters from a full
    model YAML file.

    :param path: Path to a YAML file.
    :return: Dictionary mapping parameter names to
             :class:`simpple.distributions.Distribution` objects
    :raises TypeError: If the top level of the YAML document is not a mapping
                       (for example when the file is empty).
    """
    with open(path) as f:
        mdict = yaml.safe_load(f)

    # ``yaml.safe_load`` returns ``None`` for an empty document and may return
    # a list or a scalar; fail with an explicit message instead of an obscure
    # error from the membership test or from ``parse_parameters``.
    if not isinstance(mdict, dict):
        raise TypeError(
            f"Expected a mapping at the top level of {path}, "
            f"got {type(mdict).__name__}."
        )

    # Use the ``parameters`` section when present, otherwise the whole document.
    pdict = mdict.get("parameters", mdict)
    return parse_parameters(pdict)
def unparse_parameters(parameters: dict[str, Distribution]) -> dict:
    """Turn a parameter dictionary into a YAML-compatible dictionary

    This is the reverse operation of :func:`parse_parameters`: each
    distribution is converted with
    :meth:`simpple.distributions.Distribution.to_yaml_dict`. (Note: this link
    points to the default implementation, see the docs of each class to see
    if it overrides it).

    :param parameters: Dictionary mapping parameter names to
                       :class:`simpple.distributions.Distribution` objects
    :return: Dictionary mapping parameter names to YAML specifications
    """
    return {name: dist.to_yaml_dict() for name, dist in parameters.items()}
def write_parameters(
    path: Path | str, parameters: dict[str, Distribution], overwrite: bool = False
):
    """Dump parameters to a YAML file

    The parameters are first converted with :func:`unparse_parameters` and the
    resulting dictionary is written to the YAML file.

    :param path: Path of the YAML file
    :param parameters: Dictionary mapping parameter names to
                       :class:`simpple.distributions.Distribution` objects
    :param overwrite: Overwrite the YAML file if ``True``
    :raises FileExistsError: If ``path`` exists and ``overwrite`` is ``False``
    """
    # Convert first, so a conversion failure surfaces before the path is checked.
    spec_dict = unparse_parameters(parameters)

    path = Path(path)
    refuse_to_clobber = (not overwrite) and path.exists()
    if refuse_to_clobber:
        raise FileExistsError(
            f"The file {path} already exists. Use overwrite=True to overwrite it."
        )

    with path.open(mode="w") as stream:
        yaml.dump(spec_dict, stream)
def resolve(func_str: str) -> Callable:
    """Resolve a function based on its name

    This function tries to resolve a function based on its name. It is used by
    :meth:`simpple.model.Model.from_yaml` to resolve the likelihood and forward
    model functions based on the name given in a YAML file.

    If the name contains ``<locals>`` (qualified name of a nested function),
    only the part after the last dot is looked up. Otherwise, if there is a dot
    (``.``) in the string, everything before the last dot is treated as the
    module name and the function is imported from it; errors raised by the
    import or by the attribute lookup are propagated unchanged.

    Names without a module are searched, in order, in the ``globals()`` of this
    module, in the ``__main__`` namespace, and then in the locals and globals
    of each calling frame, starting from the nearest caller. The first
    *callable* object bound to the name is returned; non-callable objects that
    happen to share the name are skipped.

    :param func_str: Name of the function
    :return: The function object that the name refers to
    :raises ValueError: If no callable with that name is found
    """
    if "<locals>" in func_str:
        # A nested function cannot be imported from its qualified name, so
        # fall back to a lookup by bare name in the namespaces below.
        func_str = func_str.split(".")[-1]

    if "." in func_str:
        module_str, func_str = func_str.rsplit(".", 1)
        func = getattr(import_module(module_str), func_str)
        if not callable(func):
            raise ValueError(f"Could not find function {func_str}")
        return func

    # Static namespaces: this module first, then the running script/session.
    namespaces = [globals()]
    main_module = sys.modules.get("__main__")
    if main_module is not None:
        namespaces.append(vars(main_module))
    for namespace in namespaces:
        if func_str in namespace and callable(namespace[func_str]):
            return namespace[func_str]

    # Walk up the call stack and stop at the first callable match, so the
    # nearest definition wins and a non-callable name further up the stack
    # cannot replace a function that was already found.
    frame = inspect.currentframe()
    if frame is not None:
        frame = frame.f_back  # skip the frame of resolve() itself
    try:
        while frame is not None:
            for namespace in (frame.f_locals, frame.f_globals):
                if func_str in namespace and callable(namespace[func_str]):
                    return namespace[func_str]
            frame = frame.f_back
    finally:
        # Drop the frame reference to avoid keeping a reference cycle alive.
        del frame

    raise ValueError(f"Could not find function {func_str}")
def get_func_str(func: Callable) -> str:
    """Build the string that identifies a function

    The result is the qualified name of ``func``, prefixed with its module
    name unless the function lives in ``__main__``.

    :param func: Function object
    :return: String representing the function
    :raises TypeError: If ``func`` is not callable
    """
    if not callable(func):
        raise TypeError("func should be a callable")

    module_name, qualified_name = func.__module__, func.__qualname__
    # Functions from the running script are referred to by bare name.
    return qualified_name if module_name == "__main__" else f"{module_name}.{qualified_name}"