Fixed Parameters in simpple Models#
As shown in previous tutorials, simpple model parameters are defined directly as distributions.
However, we sometimes build a model and would like to test it with only a subset of its parameters free to vary.
In such cases, rebuilding the entire model is unnecessarily complicated, and freezing certain parameters is preferable.
To enable fixed parameters, simpple provides a Fixed distribution, which is essentially a delta function centered on a given parameter value.
To keep sampling efficient, these fixed parameters are not treated in the same way as other parameters:
- During initialization, model.fixed_p and model.vary_p are created under the hood to separate fixed and variable parameters.
- Fixed parameters are not included in model.ndim.
- Fixed parameters are not included in model.keys() by default. They will be included with model.keys(fixed=True).
- When calling model.log_likelihood(), model.log_prob() and model.forward(), fixed parameters can optionally be included, but are not required. If they are not included, the fixed value is used.
- When calling model.log_prior(), model.prior_transform() and model.nautilus_priors(), fixed parameters are ignored.
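The bookkeeping described in the list above can be sketched in a few lines of plain Python. This is only an illustration of the idea, not simpple's actual implementation: the classes UniformSketch, FixedSketch and ModelSketch, and all of their methods (including fill_fixed), are hypothetical names made up for this example.

```python
import math


class UniformSketch:
    """Hypothetical uniform prior on [low, high]; not simpple's class."""

    def __init__(self, low, high):
        self.low, self.high = low, high

    def log_prob(self, value):
        if self.low <= value <= self.high:
            return -math.log(self.high - self.low)
        return -math.inf


class FixedSketch:
    """Hypothetical stand-in for a delta function centered on `value`."""

    def __init__(self, value):
        self.value = value


class ModelSketch:
    """Hypothetical model that separates fixed and variable parameters."""

    def __init__(self, parameters):
        self._all_keys = list(parameters)
        # Split the parameters once, at initialization.
        self.fixed_p = {k: d for k, d in parameters.items() if isinstance(d, FixedSketch)}
        self.vary_p = {k: d for k, d in parameters.items() if k not in self.fixed_p}
        # Only variable parameters count towards the dimensionality.
        self.ndim = len(self.vary_p)

    def keys(self, fixed=False):
        # Fixed keys are hidden unless explicitly requested.
        return [k for k in self._all_keys if fixed or k in self.vary_p]

    def fill_fixed(self, p):
        # Assumption: a fixed value is used only when the key is missing from `p`.
        full = {k: d.value for k, d in self.fixed_p.items()}
        full.update(p)
        return full

    def log_prior(self, p):
        # Fixed parameters are ignored: only variable priors contribute.
        return sum(d.log_prob(p[k]) for k, d in self.vary_p.items())


sketch = ModelSketch(
    {"A": UniformSketch(0.0, 100.0), "P": UniformSketch(0.1, 10.0), "phi": FixedSketch(math.pi)}
)
print(sketch.keys(), sketch.keys(fixed=True), sketch.ndim)
print(sketch.fill_fixed({"A": 10.0, "P": 3.0}))
```

Because the split happens once at construction, a sampler only ever sees the sketch.ndim variable dimensions, while the forward model can still receive a complete parameter dictionary through fill_fixed.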
We will explore this functionality below with a simple sinusoidal model.
Simulated Data#
We will first simulate sinusoidal data that we will fit with a simple forward model.
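Written out, the data-generating process implemented by the code that follows is, using the same symbols as the p_true dictionary and writing sigma for yerr:

```latex
y_i = A \sin\!\left(\frac{2\pi x_i}{P} - \phi\right) + \epsilon_i,
\qquad \epsilon_i \sim \mathcal{N}(0, \sigma^2),
\qquad A = 10,\; P = 3,\; \phi = \pi,\; \sigma = 1 .
```

The 100 values of x_i are drawn uniformly between 0 and 10 and then sorted.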
import numpy as np
import matplotlib.pyplot as plt
rng = np.random.default_rng(819)
def forward_sine(p: dict, x: np.ndarray) -> np.ndarray:
    return p["A"] * np.sin(2 * np.pi * x / p["P"] - p["phi"])
p_true = {
    "A": 10.0,
    "P": 3.0,
    "phi": np.pi,
}
x = np.sort(rng.uniform(low=0.0, high=10.0, size=100))
x_mod = np.linspace(0.0, 10.0, num=1000)
y_true = forward_sine(p_true, x)
y_true_mod = forward_sine(p_true, x_mod)
yerr = 1
y = y_true + yerr * rng.normal(size=x.size)
def plot_data():
    plt.plot(x_mod, y_true_mod, "k", alpha=0.5, zorder=10000, label="True model")
    plt.errorbar(x, y, yerr=yerr, label="Data", fmt="k.", mfc="w", capsize=2)
    plt.legend(loc=1)
    plt.xlabel("x")
    plt.ylabel("y")
plot_data()
plt.show()
simpple Model without Fixed Parameters#
First, we will fit the data with a fully variable model where none of the sinusoidal parameters are fixed. This is very similar to the line-fitting tutorial, but with a different forward model.
We first build our simpple model and test it with a set of trial parameter values.
from simpple.model import ForwardModel
import simpple.distributions as sdist
def log_likelihood(
    p: dict, x: np.ndarray, y: np.ndarray, yerr: np.ndarray
) -> np.ndarray:
    ymod = forward_sine(p, x)
    s2 = yerr**2
    return -0.5 * np.sum(np.log(2 * np.pi * s2) + (y - ymod) ** 2 / s2)
parameters = {
    "A": sdist.LogUniform(1e-4, 1e2),
    "P": sdist.LogUniform(1e-1, 1e1),
    "phi": sdist.Uniform(0, 2 * np.pi),
}
model = ForwardModel(parameters, log_likelihood, forward_sine)
test_p = {"A": 11.0, "P": 3.5, "phi": 2.5}
plot_data()
plt.plot(x_mod, model.forward(test_p, x_mod), label="Test model")
plt.show()
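The log_likelihood function defined above is the usual Gaussian log-likelihood for independent errors: minus one half of the sum of ln(2 pi s2) + (y - ymod)^2 / s2 over all data points. As a dependency-free sanity check, the sketch below compares that expression with a point-by-point sum of log normal densities from the standard library's statistics.NormalDist. The helper name gaussian_loglike and the toy numbers are made up for illustration and are not part of simpple.

```python
import math
from statistics import NormalDist


def gaussian_loglike(y, ymod, yerr):
    """Same expression as log_likelihood above, written with plain Python lists."""
    s2 = yerr**2
    return -0.5 * sum(
        math.log(2 * math.pi * s2) + (yi - mi) ** 2 / s2 for yi, mi in zip(y, ymod)
    )


# Toy data and model values (illustrative only, not the tutorial's dataset).
y_toy = [0.3, -1.2, 2.5]
ymod_toy = [0.0, -1.0, 2.0]
yerr_toy = 0.5

closed_form = gaussian_loglike(y_toy, ymod_toy, yerr_toy)
# Reference: sum of log N(y_i | m_i, yerr^2), evaluated one point at a time.
reference = sum(
    math.log(NormalDist(mi, yerr_toy).pdf(yi)) for yi, mi in zip(y_toy, ymod_toy)
)
print(closed_form, reference)
```

The two numbers should agree to floating-point precision, which confirms that the normalization term is included and that the function returns a proper log-density rather than only a chi-square term.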
Sampling of the Model without Fixed Parameters#
Now that we have a working model, we can sample the parameters with MCMC or nested sampling. Here we will use emcee to keep things simple.
import emcee
nwalkers = 50
nsteps = 5000
sampler = emcee.EnsembleSampler(nwalkers, model.ndim, model.log_prob, args=(x, y, yerr))
p0 = np.array(list(test_p.values())) + 1e-4 * rng.normal(size=(nwalkers, model.ndim))
_ = sampler.run_mcmc(p0, nsteps, progress=True)
100%|██████████| 5000/5000 [00:09<00:00, 528.57it/s]
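The initial positions above are built from test_p.values(), which implicitly assumes that the dictionary order matches the parameter order expected by the sampler. A slightly more defensive pattern, sketched below with only the standard library, builds the starting vector from an explicit key list. The names init_ball, test_p_demo and keys_demo are illustrative; in the tutorial the key list would presumably come from model.keys().

```python
import random


def init_ball(test_p, keys, nwalkers, scale=1e-4, seed=819):
    """Return nwalkers starting points scattered tightly around test_p."""
    gen = random.Random(seed)
    # Order the central values by `keys`, not by dict insertion order.
    center = [test_p[k] for k in keys]
    return [[c + scale * gen.gauss(0.0, 1.0) for c in center] for _ in range(nwalkers)]


# The dictionary is deliberately written in a different order than `keys_demo`.
test_p_demo = {"phi": 2.5, "A": 11.0, "P": 3.5}
keys_demo = ["A", "P", "phi"]
p0_demo = init_ball(test_p_demo, keys_demo, nwalkers=50)
print(len(p0_demo), len(p0_demo[0]))
```

The nested list has shape (nwalkers, ndim) and can be converted with np.array before being passed to the sampler.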
Let us have a look at the chains and the posterior samples, both in parameter space and in model predictions.
from simpple.plot import chainplot
chainplot(sampler.get_chain(), labels=model.keys())
plt.show()
import corner
chain = sampler.get_chain(flat=True, discard=1000, thin=5)
corner.corner(chain, labels=model.keys(), truths=list(p_true.values()))
plt.show()
pred_samples = model.get_posterior_pred(chain.T, 100, x_mod)
plt.plot(x_mod, pred_samples[0], "C0-", label="Posterior samples")
plt.plot(x_mod, pred_samples[1:].T, "C0-", alpha=0.2)
plot_data()
plt.show()
Model with Fixed Parameters#
Let us say that for some reason, we want to keep one of the parameters fixed. Maybe we have an extremely good constraint from another dataset, or we want to test whether a model comparison favors this parameter as fixed or variable.
For a simple model like the one used in this tutorial, we could easily rewrite the model with one of the parameters fixed. However, as we build more complicated models, rewriting them every time we wish to freeze a parameter quickly becomes inconvenient.
As mentioned above, using the Fixed distribution will freeze specific parameters in our model.
For example, in the model below we fix the phase to its true value.
parameters_fix = {
    "A": sdist.LogUniform(1e-4, 1e2),
    "P": sdist.LogUniform(1e-1, 1e1),
    "phi": sdist.Fixed(np.pi),
}
model_fix = ForwardModel(parameters_fix, log_likelihood, forward_sine)
As explained above, the number of dimensions and the keys of the model will not account for the fixed parameter.
print("Keys of the fixed phase model:", model_fix.keys())
print("Keys of the fixed phase model (including fixed):", model_fix.keys(fixed=True))
print("ndim of the fixed phase model:", model_fix.ndim)
Keys of the fixed phase model: ['A', 'P']
Keys of the fixed phase model (including fixed): ['A', 'P', 'phi']
ndim of the fixed phase model: 2
There are also two extra dictionaries, fixed_p and vary_p, which are mostly for internal use but can be useful to filter parameters:
print("Fixed parameters", model_fix.fixed_p)
print("Variable parameters:", model_fix.vary_p)
Fixed parameters {'phi': Fixed(value=3.141592653589793)}
Variable parameters: {'A': LogUniform(low=0.0001, high=100.0), 'P': LogUniform(low=0.1, high=10.0)}
The nice thing about this is that we should be able to reuse most of the code from the sampling section with very few modifications. Let us try!
We need to change the model name from model to model_fix first.
import emcee
nwalkers = 50
nsteps = 5000
sampler = emcee.EnsembleSampler(
    nwalkers, model_fix.ndim, model_fix.log_prob, args=(x, y, yerr)
)
In our test parameters, we must also keep only the variable parameters to initialize the sampler.
test_p_vary = {k: v for k, v in test_p.items() if k in model_fix.keys()}
p0 = np.array(list(test_p_vary.values())) + 1e-5 * rng.normal(
    size=(nwalkers, model_fix.ndim)
)
_ = sampler.run_mcmc(p0, nsteps, progress=True)
100%|██████████| 5000/5000 [00:08<00:00, 607.48it/s]
That worked! Let us see what the results look like.
from simpple.plot import chainplot
chainplot(sampler.get_chain(), labels=model_fix.keys())
plt.show()
For the corner plot, we will use a dictionary to filter the parameters automatically.
import corner
chain = sampler.get_chain(flat=True, discard=1000, thin=5)
chain_dict = dict(zip(model_fix.keys(), chain.T))
corner.corner(chain_dict, labels=model_fix.keys(), truths=p_true)
plt.show()
pred_samples = model_fix.get_posterior_pred(chain.T, 100, x_mod)
plt.plot(x_mod, pred_samples[0], "C0-", label="Posterior samples")
plt.plot(x_mod, pred_samples[1:].T, "C0-", alpha=0.2)
plot_data()
plt.show()
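The dictionary trick used for the corner plot above can be illustrated without any plotting. The sketch below uses a small made-up chain stored as a list of rows; with NumPy arrays the transpose would simply be chain.T, as in the tutorial. All names ending in _demo, and vary_keys, are illustrative.

```python
import math

# Made-up flat chain: each row is one sample of the variable parameters (A, P).
chain_demo = [[10.1, 3.01], [9.9, 2.99], [10.0, 3.00]]
vary_keys = ["A", "P"]  # stands in for model_fix.keys()
p_true_demo = {"A": 10.0, "P": 3.0, "phi": math.pi}

# Transpose the chain so that each key maps to the samples of one parameter.
chain_dict_demo = dict(zip(vary_keys, (list(col) for col in zip(*chain_demo))))
# Keep only the true values of parameters that were actually sampled.
truths_demo = {k: v for k, v in p_true_demo.items() if k in chain_dict_demo}
print(chain_dict_demo)
print(truths_demo)
```

Keying both the samples and the true values by parameter name means the fixed phase drops out of the comparison by name rather than by position, so no manual index bookkeeping is needed when a parameter is frozen.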
That’s it!
Hopefully this can be useful when building complicated models for various use-cases with simpple.