Fitting a spline with PyMC3

How to fit a spline using the Bayesian modeling framework PyMC3.


Often, the model we want to fit is not a perfect line between some $x$ and $y$. Instead, the parameters of the model are expected to vary over $x$. There are multiple ways to handle this situation, one of which is to fit a spline. A spline is effectively a set of individual curves, each fit to a different section of $x$, that are joined at their boundaries (called knots) so that the overall curve stays continuous and smooth. Below is an example of how to fit a spline using the Bayesian framework PyMC.
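To make the idea of lines tied together at knots concrete, below is a minimal sketch of a piecewise-linear (degree-1) spline written with hinge functions $\max(0, x - k)$, one per knot. The function name linear_spline and the knots and coefficients in the example are hypothetical and only for illustration; the model later in this post uses a cubic B-spline basis from ‘patsy’ instead, which gives a smoother curve but rests on the same idea.

```python
import numpy as np


def linear_spline(x, knots, intercept, slope, slope_changes):
    """Evaluate a piecewise-linear spline using a hinge (truncated power) basis.

    f(x) = intercept + slope * x + sum_k slope_changes[k] * max(0, x - knots[k])
    """
    x = np.asarray(x, dtype=float)
    knots = np.asarray(knots, dtype=float)
    # One hinge column per knot: zero left of the knot, rising linearly to its right.
    hinges = np.maximum(0.0, x[:, None] - knots[None, :])
    return intercept + slope * x + hinges @ np.asarray(slope_changes, dtype=float)


# Hypothetical example: three line segments joined at x = 3 and x = 7.
knots = [3.0, 7.0]
f = linear_spline(np.linspace(0, 10, 11), knots, intercept=1.0, slope=0.5, slope_changes=[-1.0, 2.0])
```

Each hinge term is zero to the left of its knot, so adding it changes the slope from that knot onward without introducing a jump. That is why the segments stay connected.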


The following is a complete, working example of fitting a spline with the probabilistic programming language PyMC (v4.0.0b2). The data and model are taken from Statistical Rethinking 2e by Richard McElreath. Because the book uses Stan (another probabilistic programming language), the modeling code is mainly taken from the GitHub repository of the PyMC3 implementation of Statistical Rethinking. My contributions are mostly the explanations and the additional analyses of the data and results.


Below is the code to import packages and set some variables used in the analysis. Most of the libraries and modules will likely be familiar; the ones that may be less well known are ‘ArviZ’, ‘patsy’, and ‘plotnine’. ‘ArviZ’ is a library for the exploratory analysis of Bayesian models; I will use it to manage the results of fitting the model and for some standard visualizations. ‘patsy’ is a library for describing statistical models with a formula language similar to the one used in R. Finally, ‘plotnine’ is a plotting library that implements the “Grammar of Graphics” system, based on the R package ‘ggplot2’. Because I have a lot of experience with R, I find ‘plotnine’ far more natural than ‘matplotlib’, the “standard” plotting library in Python data science.

from pathlib import Path

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotnine as gg
import pymc as pm
import seaborn as sns
from patsy import dmatrix

# Set default theme for 'plotnine'.
gg.theme_set(gg.theme_minimal())

# For reproducibility.
RANDOM_SEED = 847  # arbitrary value; any fixed integer works
# Path to the data used in Statistical Rethinking.
rethinking_data_path = Path("../data/rethinking_data")


The data for this example are the day of the year (doy) on which the cherry trees first bloomed in each year (year). Years without a doy value were dropped.

d = pd.read_csv(rethinking_data_path / "cherry_blossoms.csv")
d2 = d.dropna(subset=["doy"]).reset_index(drop=True)

There are 827 years with doy data.

>>> d2.shape
(827, 5)

Below, the doy values are plotted against year.

(
    gg.ggplot(d2, gg.aes(x="year", y="doy"))
    + gg.geom_point(color="black", alpha=0.4, size=1.3)
    + gg.theme(figure_size=(10, 5))
    + gg.labs(x="year", y="day of year", title="Cherry blossom data")
)



We will fit the following model.

$D \sim \mathcal{N}(\mu, \sigma)$
$\quad \mu = a + Bw$
$\qquad a \sim \mathcal{N}(100, 5)$
$\qquad w \sim \mathcal{N}(0, 3)$
$\quad \sigma \sim \text{Exp}(1)$

The day of first bloom is modeled as a normal distribution with mean $\mu$ and standard deviation $\sigma$. The mean is a linear model composed of an intercept $a$ plus a spline: the basis matrix $B$ multiplied by the weight vector $w$, which has one element per basis function. Both $a$ and $w$ have relatively weak normal priors, and $\sigma$ has an exponential prior.
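Because the model is generative, it can help to simulate from it before fitting. The sketch below draws one synthetic data set: sample $a$, $w$, and $\sigma$ from their priors, compute $\mu = a + Bw$, then draw $D$. The default prior values follow the PyMC model code later in this post; the function simulate_doy and the toy basis matrix B_toy are hypothetical stand-ins, since the real $B$ is only built below.

```python
import numpy as np


def simulate_doy(B, rng, a_mu=100.0, a_sd=5.0, w_sd=3.0, sigma_rate=1.0):
    """Draw one synthetic data set from the spline model's priors.

    a ~ Normal(a_mu, a_sd), w_j ~ Normal(0, w_sd), sigma ~ Exponential(sigma_rate),
    mu = a + B @ w, D ~ Normal(mu, sigma).
    """
    B = np.asarray(B, dtype=float)
    a = rng.normal(a_mu, a_sd)
    w = rng.normal(0.0, w_sd, size=B.shape[1])  # one weight per basis function
    sigma = rng.exponential(1.0 / sigma_rate)  # NumPy parameterizes by scale = 1 / rate
    mu = a + B @ w
    D = rng.normal(mu, sigma)
    return mu, D


# Hypothetical toy basis: 6 observations, 3 basis functions, rows summing to 1.
B_toy = np.array([
    [1.0, 0.0, 0.0],
    [0.5, 0.5, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.5, 0.5],
    [0.0, 0.0, 1.0],
    [0.0, 0.0, 1.0],
])
mu_sim, D_sim = simulate_doy(B_toy, np.random.default_rng(1))
```

Observations with identical rows in the basis get identical means, and a row that averages two others gets the average of their means. This is the sense in which the basis "shares" the weights across neighboring years.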

Prepare the spline

We can now prepare the spline matrix. First, we must determine the boundaries of the spline's regions, referred to as “knots” because the individual curves are tied together at these points to make a continuous and smooth curve. For this example, we will place 15 knots at evenly spaced quantiles of year, so they are unevenly spaced in time but each region contains roughly the same number of observations.

num_knots = 15
knot_list = np.quantile(d2.year, np.linspace(0, 1, num_knots))
>>> knot_list
array([ 812., 1036., 1174., 1269., 1377., 1454., 1518., 1583., 1650.,
       1714., 1774., 1833., 1893., 1956., 2015.])
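The claim that quantile-based knots give each region about the same amount of data can be checked directly. The sketch below uses hypothetical, unevenly spread years (denser in later centuries, loosely mimicking this data set) rather than the cherry blossom data, places 15 knots at evenly spaced quantiles, and counts the observations between consecutive knots.

```python
import numpy as np

# Hypothetical, unevenly spread "years": observations become denser over time.
years = 800 + 1200 * np.sqrt(np.linspace(0.0, 1.0, 10_000))

num_knots = 15
knots = np.quantile(years, np.linspace(0, 1, num_knots))

# Use the knots as histogram bin edges to count observations per region.
counts, _ = np.histogram(years, bins=knots)
spacing = np.diff(knots)  # width of each region in "years"
```

The counts come out nearly equal, while the knot spacing is wide where data are sparse and narrow where they are dense, the same pattern visible in knot_list above.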

Below is a plot of the data we are modeling, with the knots indicated by vertical gray lines.

(
    gg.ggplot(d2, gg.aes(x="year", y="doy"))
    + gg.geom_point(color="black", alpha=0.4, size=1.3)
    + gg.geom_vline(xintercept=knot_list, color="gray", alpha=0.8)
    + gg.theme(figure_size=(10, 5))
    + gg.labs(x="year", y="day of year", title="Cherry blossom data with spline knots")
)


We can get an idea of what the spline will look like by fitting a LOESS curve (a local polynomial regression).

(
    gg.ggplot(d2, gg.aes(x="year", y="doy"))
    + gg.geom_point(color="black", alpha=0.4, size=1.3)
    + gg.geom_smooth(method="loess", span=0.3, size=1.5, color="blue", linetype="-")
    + gg.geom_vline(xintercept=knot_list, color="gray", alpha=0.8)
    + gg.theme(figure_size=(10, 5))
    + gg.labs(x="year", y="day of year", title="Cherry blossom data with spline knots")
)
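LOESS fits a separate weighted regression around every point, giving nearby observations more weight. The function below is a deliberately simplified sketch of that idea (a single pass of tricube-weighted local linear fits, with no robustness iterations) and is not the implementation plotnine calls; the name loess_sketch is hypothetical. Its span argument plays the same role as span=0.3 above: the fraction of the data used in each local fit.

```python
import numpy as np


def loess_sketch(x, y, span=0.3):
    """Simplified LOESS: a tricube-weighted local linear fit at every x value."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    n = x.size
    k = min(max(int(np.ceil(span * n)), 3), n)  # points in each local neighborhood
    fitted = np.empty(n)
    for j, x0 in enumerate(x):
        dist = np.abs(x - x0)
        idx = np.argsort(dist)[:k]  # the k nearest neighbors of x0
        h = max(dist[idx].max(), 1e-12)  # neighborhood radius
        weights = (1.0 - (dist[idx] / h) ** 3) ** 3  # tricube kernel
        # Weighted least squares for a line centered at x0; the intercept is the fit at x0.
        X = np.column_stack([np.ones(k), x[idx] - x0])
        sw = np.sqrt(weights)
        beta, *_ = np.linalg.lstsq(X * sw[:, None], y[idx] * sw, rcond=None)
        fitted[j] = beta[0]
    return fitted


# Usage on hypothetical data that lie exactly on a line.
x_demo = np.linspace(0, 10, 50)
y_line = 2 * x_demo + 1
smoothed = loess_sketch(x_demo, y_line, span=0.3)
```

A quick sanity check is that, on data already lying on a line, every local fit recovers that line; on curved data, a smaller span follows the data more closely.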


Another way of visualizing what the spline should look like is to plot individual linear models over the data between consecutive knots. The spline will effectively be a compromise between these individual models and a continuous curve.

d2["knot_group"] = [np.where(a <= knot_list)[0][0] for a in d2.year]
d2["knot_group"] = pd.Categorical(d2["knot_group"], ordered=True)

(
    gg.ggplot(d2, gg.aes(x="year", y="doy"))
    + gg.geom_point(color="black", alpha=0.4, size=1.3)
    + gg.geom_smooth(
        gg.aes(group="knot_group"), method="lm", size=1.5, color="red", linetype="-"
    )
    + gg.geom_vline(xintercept=knot_list, color="gray", alpha=0.8)
    + gg.theme(figure_size=(10, 5))
    + gg.labs(x="year", y="day of year", title="Cherry blossom data with spline knots")
)
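The list comprehension above assigns each year the index of the first knot that is greater than or equal to it. The same rule can be vectorized with np.searchsorted (side="left"), and the per-region lines drawn by geom_smooth(method="lm") are ordinary least-squares fits that np.polyfit can reproduce. The helper names knot_groups and per_group_lines are hypothetical, and the doy values in the usage are synthetic.

```python
import numpy as np

knot_list = np.array([812., 1036., 1174., 1269., 1377., 1454., 1518., 1583., 1650.,
                      1714., 1774., 1833., 1893., 1956., 2015.])


def knot_groups(years, knots):
    """Index of the first knot >= each year (same rule as np.where(a <= knots)[0][0])."""
    return np.searchsorted(knots, years, side="left")


def per_group_lines(years, doy, groups):
    """Fit an ordinary least-squares line (slope, intercept) within each group."""
    fits = {}
    for g in np.unique(groups):
        mask = groups == g
        if mask.sum() >= 2:  # a line needs at least two points
            fits[int(g)] = np.polyfit(years[mask], doy[mask], deg=1)
    return fits


# Usage with hypothetical, perfectly linear synthetic data.
years = np.arange(812, 2016)
doy = 100 + 0.01 * (years - 812)
groups = knot_groups(years, knot_list)
fits = per_group_lines(years, doy, groups)
```

One consequence of this rule is that the first knot (812, the earliest year) forms a group containing only that year, so there is no line to fit for it; the helper above skips groups with fewer than two points.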


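Finally, we can use ‘patsy’ to create the matrix $B$ that will be the B-spline basis for the regression. The degree is set to 3 to create a cubic B-spline. Only the 13 interior knots (knot_list[1:-1]) are passed, because the first and last knots coincide with the range of the data; with include_intercept=True this gives 13 + 3 + 1 = 17 basis functions.

B = dmatrix(
    "bs(year, knots=knots, degree=3, include_intercept=True) - 1",
    {"year": d2.year.values, "knots": knot_list[1:-1]},
)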
>>> B
DesignMatrix with shape (827, 17)
    ['bs(year, knots=knots, degree=3, include_intercept=True)[0]',
     'bs(year, knots=knots, degree=3, include_intercept=True)[1]',
     'bs(year, knots=knots, degree=3, include_intercept=True)[2]',
     'bs(year, knots=knots, degree=3, include_intercept=True)[3]',
     'bs(year, knots=knots, degree=3, include_intercept=True)[4]',
     'bs(year, knots=knots, degree=3, include_intercept=True)[5]',
     'bs(year, knots=knots, degree=3, include_intercept=True)[6]',
     'bs(year, knots=knots, degree=3, include_intercept=True)[7]',
     'bs(year, knots=knots, degree=3, include_intercept=True)[8]',
     'bs(year, knots=knots, degree=3, include_intercept=True)[9]',
     'bs(year, knots=knots, degree=3, include_intercept=True)[10]',
     'bs(year, knots=knots, degree=3, include_intercept=True)[11]',
     'bs(year, knots=knots, degree=3, include_intercept=True)[12]',
     'bs(year, knots=knots, degree=3, include_intercept=True)[13]',
     'bs(year, knots=knots, degree=3, include_intercept=True)[14]',
     'bs(year, knots=knots, degree=3, include_intercept=True)[15]',
     'bs(year, knots=knots, degree=3, include_intercept=True)[16]']
    'bs(year, knots=knots, degree=3, include_intercept=True)' (columns 0:17)
  (to view full data, use np.asarray(this_obj))
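To demystify what patsy's bs() returns, below is a sketch of the Cox–de Boor recursion that defines B-spline basis functions. It is my own minimal NumPy version, not patsy's code, and it assumes the same setup as above: boundary knots at the minimum and maximum of the data, each repeated degree + 1 times, and all basis functions kept (as with include_intercept=True). The function name bspline_basis is hypothetical.

```python
import numpy as np


def bspline_basis(x, interior_knots, degree=3):
    """B-spline basis via the Cox-de Boor recursion, clamped at the data range."""
    x = np.asarray(x, dtype=float)
    lo, hi = x.min(), x.max()
    # Full knot vector: each boundary knot repeated degree + 1 times.
    t = np.concatenate(
        [np.repeat(lo, degree + 1), np.asarray(interior_knots, dtype=float), np.repeat(hi, degree + 1)]
    )
    # Degree 0: indicator of each half-open knot span [t_i, t_{i+1}).
    basis = np.zeros((x.size, len(t) - 1))
    for i in range(len(t) - 1):
        if t[i] < t[i + 1]:
            basis[:, i] = (t[i] <= x) & (x < t[i + 1])
    # Assign the right boundary point to the last non-empty span.
    last_span = np.nonzero(t[:-1] < t[1:])[0].max()
    basis[x == hi, last_span] = 1.0
    # Raise the degree one step at a time; 0/0 terms from repeated knots count as 0.
    for k in range(1, degree + 1):
        new = np.zeros((x.size, len(t) - k - 1))
        for i in range(len(t) - k - 1):
            left_den = t[i + k] - t[i]
            right_den = t[i + k + 1] - t[i + 1]
            if left_den > 0:
                new[:, i] += (x - t[i]) / left_den * basis[:, i]
            if right_den > 0:
                new[:, i] += (t[i + k + 1] - x) / right_den * basis[:, i + 1]
        basis = new
    return basis  # shape: (len(x), len(interior_knots) + degree + 1)


knot_list = np.array([812., 1036., 1174., 1269., 1377., 1454., 1518., 1583., 1650.,
                      1714., 1774., 1833., 1893., 1956., 2015.])
years = np.linspace(812, 2015, 200)
B_sketch = bspline_basis(years, knot_list[1:-1], degree=3)
```

With the 13 interior knots from knot_list, this gives 17 columns, the same count as the design matrix above. Two properties of the basis show up here: each row sums to 1, and at any year at most degree + 1 = 4 basis functions are nonzero, which is why each element of $w$ only affects a local stretch of the curve.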

The B-spline basis is plotted below, showing the “domain” of each basis function. The height of a curve at a given year indicates how strongly the corresponding weight in $w$ (one per basis function) contributes to $\mu$ at that year. Neighboring basis functions overlap around each knot, which is what produces a smooth transition from one region to the next.

spline_df = (
    pd.DataFrame(B)
    .assign(year=d2.year.values)
    .melt("year", var_name="spline_i", value_name="value")
    # Treat the basis index as a discrete variable for coloring.
    .astype({"spline_i": "category"})
)

(
    gg.ggplot(spline_df, gg.aes(x="year", y="value"))
    + gg.geom_line(gg.aes(group="spline_i", color="spline_i"))
    + gg.scale_color_discrete(guide=gg.guide_legend(ncol=2))
    + gg.labs(x="year", y="basis", color="spline idx")
)



Now the model can be built using PyMC. A graphical diagram shows the organization of the model parameters.

with pm.Model(rng_seeder=RANDOM_SEED) as m4_7:
    a = pm.Normal("a", 100, 5)
    w = pm.Normal("w", mu=0, sigma=3, shape=B.shape[1])
    # Linear model: intercept plus the B-spline basis weighted by w.
    mu = pm.Deterministic("mu", a + pm.math.dot(np.asarray(B, order="F"), w.T))
    sigma = pm.Exponential("sigma", 1)
    D = pm.Normal("D", mu, sigma, observed=d2.doy)

pm.model_to_graphviz(m4_7)
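For reference, the quantity the sampler explores is the log of the unnormalized posterior density. Below is a sketch of that density for this model written directly in NumPy, using the same priors as the PyMC code above. It is illustrative only: PyMC builds this automatically, and because it samples $\sigma$ on the log scale it also adds a Jacobian term that this sketch omits. The function names are hypothetical.

```python
import numpy as np


def normal_logpdf(x, mu, sd):
    """Log density of Normal(mu, sd) evaluated at x."""
    return -0.5 * np.log(2 * np.pi) - np.log(sd) - 0.5 * ((x - mu) / sd) ** 2


def log_posterior(a, w, sigma, B, D):
    """Unnormalized log posterior: log-likelihood plus log-priors."""
    if sigma <= 0:
        return -np.inf  # outside the support of the Exponential prior
    mu = a + B @ w
    log_lik = normal_logpdf(D, mu, sigma).sum()
    log_prior = (
        normal_logpdf(a, 100.0, 5.0)
        + normal_logpdf(w, 0.0, 3.0).sum()
        - sigma  # log density of Exponential(1) at sigma > 0
    )
    return log_lik + log_prior


# One-observation example with a single basis function.
B_one = np.array([[1.0]])
D_one = np.array([100.0])
value = log_posterior(100.0, np.array([0.0]), 1.0, B_one, D_one)
```

At $a = 100$, $w = 0$, $\sigma = 1$, and $D = 100$, the value works out to $-\tfrac{3}{2}\log(2\pi) - \log 5 - \log 3 - 1$.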


Two chains of 2,000 posterior draws each (after 2,000 tuning steps) are taken, and then the posterior predictions are calculated.

with m4_7:
    trace_m4_7 = pm.sample(2000, tune=2000, chains=2, return_inferencedata=True)
    _ = pm.sample_posterior_predictive(trace_m4_7, extend_inferencedata=True)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [a, w, sigma]

100.00% [8000/8000 00:41<00:00 Sampling 2 chains, 0 divergences]
Sampling 2 chains for 2_000 tune and 2_000 draw iterations (4_000 + 4_000 draws total) took 57 seconds.

100.00% [4000/4000 00:00<00:00]


Now we can analyze the draws from the posterior of the model.

Fit parameters

Below is a table summarizing the posterior distributions of the model parameters. The posteriors of $a$ and $\sigma$ are quite narrow, while those for $w$ are wider. This is likely because all of the data points inform $a$ and $\sigma$, whereas each element of $w$ is informed only by the observations where its basis function is nonzero. (It could be interesting to model these hierarchically, allowing information to be shared and adding regularization across the spline.) The effective sample sizes and $\widehat{R}$ values all look good, indicating that the chains converged and sampled well from the posterior distribution.

az.summary(trace_m4_7, var_names=["a", "w", "sigma"])
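$\widehat{R}$ compares the variance between chains with the variance within chains; values near 1 mean the chains agree. Below is a sketch of the classic Gelman–Rubin version for a single parameter. ArviZ reports a more robust rank-normalized, split-chain variant, so its numbers will differ somewhat; the function name classic_rhat and the simulated chains are hypothetical.

```python
import numpy as np


def classic_rhat(chains):
    """Gelman-Rubin potential scale reduction factor for draws shaped (chains, draws)."""
    chains = np.asarray(chains, dtype=float)
    m, n = chains.shape
    within = chains.var(axis=1, ddof=1).mean()  # W: mean within-chain variance
    between = n * chains.mean(axis=1).var(ddof=1)  # B: between-chain variance
    var_hat = (n - 1) / n * within + between / n  # pooled estimate of posterior variance
    return float(np.sqrt(var_hat / within))


rng = np.random.default_rng(0)
mixed = rng.normal(size=(2, 2000))  # two chains from the same distribution
stuck = mixed + np.array([[0.0], [3.0]])  # second chain shifted by 3
```

Two chains drawn from the same distribution give a value very close to 1, while shifting one chain by 3 units pushes it well above 2.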

The trace plots (MCMC samples) of the parameters again indicate that they were estimated well.

az.plot_trace(trace_m4_7, var_names=["a", "w", "sigma"])


The forest plot shows that the posterior distributions of the $w$ values are wide, though some lie mostly away from 0, suggesting a non-null effect (association) for those regions of the spline.

az.plot_forest(trace_m4_7, var_names=["w"], combined=True);
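Whether a weight lies mostly away from 0 can also be checked numerically by asking whether its highest density interval (HDI) excludes 0. Below is a sketch of the shortest-interval HDI computation applied to draws shaped (chains, draws, parameters), which is the shape of trace_m4_7.posterior["w"].values. ArviZ provides this as az.hdi with the same default 94% probability; the function names here are hypothetical and the draws are simulated.

```python
import numpy as np


def hdi(samples, prob=0.94):
    """Shortest interval containing `prob` of the samples."""
    x = np.sort(np.asarray(samples, dtype=float).ravel())
    n = x.size
    k = int(np.floor(prob * n))  # index offset spanned by the interval
    widths = x[k:] - x[: n - k]  # width of every candidate interval
    i = int(np.argmin(widths))
    return x[i], x[i + k]


def excludes_zero(draws, prob=0.94):
    """For draws shaped (chains, draws, k), flag parameters whose HDI excludes 0."""
    draws = np.asarray(draws, dtype=float)
    flat = draws.reshape(-1, draws.shape[-1])  # pool the chains
    flags = []
    for j in range(flat.shape[1]):
        lo, hi = hdi(flat[:, j], prob)
        flags.append(bool(lo > 0 or hi < 0))
    return np.array(flags)


# Simulated draws for three parameters centered at 0, 5, and -5.
rng = np.random.default_rng(0)
draws = rng.normal(loc=[0.0, 5.0, -5.0], scale=1.0, size=(2, 1000, 3))
flags = excludes_zero(draws)
```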


Another way to visualize the fitted spline is to multiply the basis matrix by the fitted weights. The knot boundaries are shown in gray again, but now each basis function is scaled by the posterior mean of its weight in $w$ (the rainbow-colored curves). The dot product of $B$ and $w$, which is the actual computation in the linear model, is shown in blue.

wp = trace_m4_7.posterior["w"].values.mean(axis=(0, 1))

spline_df = (
    pd.DataFrame(B * wp.T)
    .assign(year=d2.year.values)
    .melt("year", var_name="spline_i", value_name="value")
    # Treat the basis index as a discrete variable for coloring.
    .astype({"spline_i": "category"})
)

spline_df_merged = (
    pd.DataFrame(np.dot(np.asarray(B, order="F"), wp.T))
    .assign(year=d2.year.values)
    .melt("year", var_name="spline_i", value_name="value")
)

(
    gg.ggplot(spline_df, gg.aes(x="year", y="value"))
    + gg.geom_vline(xintercept=knot_list, color="gray", alpha=0.5)
    + gg.geom_line(data=spline_df_merged, linetype="-", color="blue", size=2, alpha=0.7)
    + gg.geom_line(gg.aes(group="spline_i", color="spline_i"), alpha=0.7, size=1)
    + gg.scale_color_discrete(guide=gg.guide_legend(ncol=2), color_space="husl")
    + gg.theme(figure_size=(10, 5))
    + gg.labs(x="year", y="basis", title="Fit spline", color="spline idx")
)
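The relationship between the rainbow-colored curves and the blue curve is plain matrix arithmetic: scaling each column of $B$ by its weight and summing across columns gives the dot product $Bw$. The small arrays below are hypothetical stand-ins for B and wp.

```python
import numpy as np

rng = np.random.default_rng(0)
B_demo = rng.uniform(size=(5, 4))  # hypothetical stand-in for the basis matrix B
w_demo = rng.normal(size=4)  # hypothetical stand-in for the posterior mean wp

weighted_basis = B_demo * w_demo  # column j scaled by w_j (one colored curve per column)
spline_curve = B_demo @ w_demo  # dot product used in the linear model (the blue curve)
```

Summing weighted_basis across its columns reproduces spline_curve exactly, so the blue curve is literally the sum of the colored curves at each year.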


Model predictions

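Lastly, we can visualize the model's predictions: the posterior mean of $\mu$ for each year (dark red line) and its 94% highest density interval (shaded band), both taken from the ‘ArviZ’ summary table.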

post_pred = az.summary(trace_m4_7, var_names=["mu"]).reset_index(drop=True)
d2_post = d2.copy().reset_index(drop=True)
d2_post["pred_mean"] = post_pred["mean"]
d2_post["pred_hdi_lower"] = post_pred["hdi_3%"]
d2_post["pred_hdi_upper"] = post_pred["hdi_97%"]

(
    gg.ggplot(d2_post, gg.aes(x="year"))
    + gg.geom_ribbon(
        gg.aes(ymin="pred_hdi_lower", ymax="pred_hdi_upper"), alpha=0.3, fill="tomato"
    )
    + gg.geom_line(gg.aes(y="pred_mean"), color="firebrick", alpha=1, size=2)
    + gg.geom_point(gg.aes(y="doy"), color="black", alpha=0.4, size=1.3)
    + gg.geom_vline(xintercept=knot_list, color="gray", alpha=0.8)
    + gg.theme(figure_size=(10, 5))
    + gg.labs(
        y="day of year",
        title="Cherry blossom data with posterior predictions",
    )
)



It was pointed out to me by a reader that doy was not the number of days of bloom, but the day of the year with the first bloom. I fixed this in the text and plots. I also took this opportunity to fix the embarrassingly large number of typos and update the code to use PyMC v4.
