# Dirichlet Regression with PyMC


I want to apologize at the top for the generally lackluster appearance and text of this post. It is meant to serve as a quick, simple guide, so I chose to keep it relatively light on text and explanation.

## Introduction

Below, I provide a simple example of a Dirichlet regression in PyMC. This form of generalized linear model is appropriate when modeling proportions of multiple groups, that is, when modeling a collection of positive values that must sum to a constant. Some common examples include ratios and percentages.

For this example, I used a simplified case that was the original impetus for me looking into this form of model. I have measured a protein’s expression in two groups, a control and an experimental, across $10$ tissues. I have measured the expression in $6$ replicates for each condition across all $10$ tissues. Therefore, I have $10 \times 6 \times 2$ measurements. The values are all greater than or equal to $0$, and the values for each replicate sum to $1$.

I want to know if the expression of the protein is different between control and experiment in each tissue.

Because the values are constrained to be $\ge 0$ and to sum to $1$ within each replicate, the likelihood should be a Dirichlet distribution. The exponential function is the appropriate link between the linear combination of predictors and the strictly positive concentration parameters of the likelihood.
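
As a quick, hypothetical sketch of the link function (plain NumPy, not the PyMC model below): the exponential maps any real-valued linear predictor to a strictly positive concentration vector, which is a valid Dirichlet parameter.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical linear predictor values; any real numbers are allowed.
eta = np.array([-1.0, 0.5, 2.0, 0.0])

# The exponential link maps the unbounded predictor to the strictly
# positive concentration parameters required by the Dirichlet.
alpha = np.exp(eta)
print(alpha.min() > 0)  # True

# Draws from the resulting Dirichlet are non-negative and sum to 1.
sample = rng.dirichlet(alpha)
print(np.isclose(sample.sum(), 1.0))  # True
```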

## Setup

```python
import arviz as az
import janitor  # noqa: F401
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import seaborn as sns

%matplotlib inline
%config InlineBackend.figure_format = 'retina'
sns.set_style("whitegrid")
```


## Data generation

A fake dataset was produced for the situation described above. The vectors `ctrl_tissue_props` and `expt_tissue_props` contain the “true” proportions of protein expression across the ten tissues for the control and experimental conditions. These were randomly generated; the resulting values are printed at the end of the code block.

```python
N_TISSUES = 10
N_REPS = 6
CONDITIONS = ["C", "E"]

TISSUES = [f"tissue-{i}" for i in range(N_TISSUES)]
REPS = [f"{CONDITIONS[0]}-{i}" for i in range(N_REPS)]
REPS += [f"{CONDITIONS[1]}-{i}" for i in range(N_REPS)]

np.random.seed(909)
ctrl_tissue_props = np.random.beta(2, 2, N_TISSUES)
ctrl_tissue_props = ctrl_tissue_props / np.sum(ctrl_tissue_props)
expt_tissue_props = np.random.beta(2, 2, N_TISSUES)
expt_tissue_props = expt_tissue_props / np.sum(expt_tissue_props)

print("Real proportions for each tissue:")
print(np.vstack([ctrl_tissue_props, expt_tissue_props]).round(3))
```

```text
Real proportions for each tissue:
[[0.072 0.148 0.137 0.135 0.074 0.118 0.083 0.015 0.12  0.098]
 [0.066 0.104 0.138 0.149 0.062 0.057 0.098 0.109 0.131 0.086]]
```


Protein expression values were sampled using these proportions, multiplied by 100 to reduce the variability in the sampled values. Recall that the Dirichlet is effectively a multi-class Beta distribution, so the input numbers can be thought of as the observed number of instances for each class. The more observations, the more confidence that the observed frequencies are representative of the true proportions.
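
To illustrate the effect of that scaling (a hypothetical check with toy proportions, independent of the model): multiplying a fixed proportion vector by a larger constant concentrates Dirichlet samples more tightly around the true proportions.

```python
import numpy as np

rng = np.random.default_rng(909)
props = np.array([0.2, 0.3, 0.5])

# Draw 1,000 samples at two total-concentration scales.
spread_10 = rng.dirichlet(props * 10, size=1000).std(axis=0)
spread_1000 = rng.dirichlet(props * 1000, size=1000).std(axis=0)

# The larger total concentration yields much tighter samples.
print(spread_1000 < spread_10)  # [ True  True  True]
```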

```python
_ctrl_data = np.random.dirichlet(ctrl_tissue_props * 100, N_REPS)
_expt_data = np.random.dirichlet(expt_tissue_props * 100, N_REPS)

expr_data = (
    pd.DataFrame(np.vstack([_ctrl_data, _expt_data]), columns=TISSUES)
    .assign(replicate=REPS)
    .set_index("replicate")
)
expr_data.round(3)
```


| replicate | tissue-0 | tissue-1 | tissue-2 | tissue-3 | tissue-4 | tissue-5 | tissue-6 | tissue-7 | tissue-8 | tissue-9 |
|-----------|----------|----------|----------|----------|----------|----------|----------|----------|----------|----------|
| C-0 | 0.085 | 0.116 | 0.184 | 0.175 | 0.038 | 0.101 | 0.069 | 0.014 | 0.102 | 0.115 |
| C-1 | 0.134 | 0.188 | 0.101 | 0.119 | 0.075 | 0.104 | 0.052 | 0.001 | 0.121 | 0.107 |
| C-2 | 0.098 | 0.125 | 0.127 | 0.138 | 0.094 | 0.091 | 0.134 | 0.017 | 0.080 | 0.096 |
| C-3 | 0.069 | 0.154 | 0.140 | 0.082 | 0.065 | 0.182 | 0.054 | 0.011 | 0.110 | 0.132 |
| C-4 | 0.033 | 0.208 | 0.151 | 0.090 | 0.067 | 0.109 | 0.064 | 0.003 | 0.160 | 0.115 |
| C-5 | 0.074 | 0.130 | 0.130 | 0.113 | 0.081 | 0.129 | 0.059 | 0.020 | 0.111 | 0.152 |
| E-0 | 0.100 | 0.105 | 0.114 | 0.081 | 0.088 | 0.056 | 0.120 | 0.087 | 0.167 | 0.081 |
| E-1 | 0.043 | 0.124 | 0.184 | 0.098 | 0.071 | 0.040 | 0.122 | 0.071 | 0.157 | 0.089 |
| E-2 | 0.099 | 0.108 | 0.102 | 0.139 | 0.089 | 0.039 | 0.115 | 0.092 | 0.158 | 0.059 |
| E-3 | 0.076 | 0.074 | 0.122 | 0.142 | 0.058 | 0.062 | 0.103 | 0.081 | 0.106 | 0.176 |
| E-4 | 0.098 | 0.103 | 0.117 | 0.113 | 0.048 | 0.110 | 0.113 | 0.104 | 0.166 | 0.027 |
| E-5 | 0.059 | 0.110 | 0.119 | 0.190 | 0.059 | 0.054 | 0.071 | 0.065 | 0.155 | 0.117 |
```python
sns.heatmap(expr_data, vmin=0, cmap="seismic");
```


The sum of the values for each replicate should be 1.

```python
expr_data.values.sum(axis=1)
# > array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
```


## Model

### Model specification

The model is rather straightforward and immediately recognizable as a generalized linear model. Its main features are the Dirichlet likelihood and the exponential link function. Note that for PyMC’s Dirichlet distribution, the first dimension indexes the “groups” of data while the values must sum to $1$ along the last dimension. In this case, each replicate is one group, so the values within each row sum to $1$.
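
As a quick sanity check on that data layout (a toy array, not the real data): each row of the observed matrix should lie on the simplex, while the leading axis simply indexes independent observations.

```python
import numpy as np

# Toy observed matrix: rows are replicates, columns are tissues.
obs = np.array([
    [0.2, 0.3, 0.5],
    [0.1, 0.6, 0.3],
])

# Each replicate (row) must sum to 1 along the last axis.
print(np.allclose(obs.sum(axis=1), 1.0))  # True
```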

```python
coords = {"tissue": TISSUES, "replicate": REPS}

intercept = np.ones_like(expr_data)
x_expt_cond = np.vstack([np.zeros((N_REPS, N_TISSUES)), np.ones((N_REPS, N_TISSUES))])

with pm.Model(coords=coords) as dirichlet_reg:
    a = pm.Normal("a", 0, 5, dims=("tissue",))
    b = pm.Normal("b", 0, 2.5, dims=("tissue",))
    eta = pm.Deterministic(
        "eta",
        a[None, :] * intercept + b[None, :] * x_expt_cond,
        dims=("replicate", "tissue"),
    )
    mu = pm.Deterministic("mu", pm.math.exp(eta), dims=("replicate", "tissue"))
    y = pm.Dirichlet("y", mu, observed=expr_data.values, dims=("replicate", "tissue"))

# pm.model_to_graphviz(dirichlet_reg)
dirichlet_reg
```


$$\begin{array}{rcl} a &\sim & \mathcal{N}(0,~5) \\ b &\sim & \mathcal{N}(0,~2.5) \\ \eta &\sim & \operatorname{Deterministic}(f(a, b)) \\ \mu &\sim & \operatorname{Deterministic}(f(\eta)) \\ y &\sim & \operatorname{Dir}(\mu) \end{array}$$

### Sampling

PyMC does all of the heavy lifting; we just need to press the “Inference Button” by calling `pm.sample()`.

```python
with dirichlet_reg:
    trace = pm.sample(
        draws=1000, tune=1000, chains=2, cores=2, random_seed=20, target_accept=0.9
    )
    _ = pm.sample_posterior_predictive(trace, random_seed=43, extend_inferencedata=True)
```

```text
Auto-assigning NUTS sampler...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [a, b]

100.00% [4000/4000 00:16<00:00 Sampling 2 chains, 0 divergences]
Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 29 seconds.

100.00% [2000/2000 00:00<00:00]
```

## Posterior analysis

### Recovering known parameters

The table below shows the summaries of the marginal posterior distributions for the variables $a$ and $b$ of the model.
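
The “true” values used for comparison follow from the link function: the control replicates were simulated with concentrations $100\,p^{C}$ and the experimental replicates with $100\,p^{E}$, while the model sets $\mu = \exp(a)$ for the control condition and $\mu = \exp(a + b)$ for the experimental condition. Equating the two gives

$$a = \log(100\, p^{C}), \qquad a + b = \log(100\, p^{E}) \implies b = \log(100\, p^{E}) - \log(100\, p^{C}).$$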

```python
real_a = np.log(ctrl_tissue_props * 100)
real_b = np.log(expt_tissue_props * 100) - real_a

res_summary = (
    az.summary(trace, var_names=["a", "b"], hdi_prob=0.89)
    .assign(real=np.hstack([real_a, real_b]))
    .reset_index()
)
res_summary
```


|    | index | mean | sd | hdi_5.5% | hdi_94.5% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat | real |
|----|-------|------|----|----------|-----------|-----------|---------|----------|----------|-------|------|
| 0  | a[tissue-0] | 2.122 | 0.240 | 1.735 | 2.490 | 0.012 | 0.009 | 393.0 | 923.0 | 1.0 | 1.973280 |
| 1  | a[tissue-1] | 2.782 | 0.223 | 2.432 | 3.142 | 0.013 | 0.009 | 309.0 | 657.0 | 1.0 | 2.691607 |
| 2  | a[tissue-2] | 2.691 | 0.231 | 2.349 | 3.082 | 0.013 | 0.009 | 334.0 | 566.0 | 1.0 | 2.618213 |
| 3  | a[tissue-3] | 2.529 | 0.234 | 2.163 | 2.903 | 0.013 | 0.009 | 324.0 | 481.0 | 1.0 | 2.603835 |
| 4  | a[tissue-4] | 2.009 | 0.247 | 1.648 | 2.435 | 0.013 | 0.009 | 399.0 | 586.0 | 1.0 | 2.006772 |
| 5  | a[tissue-5] | 2.538 | 0.231 | 2.180 | 2.906 | 0.013 | 0.009 | 322.0 | 641.0 | 1.0 | 2.465910 |
| 6  | a[tissue-6] | 2.015 | 0.250 | 1.653 | 2.435 | 0.013 | 0.009 | 363.0 | 804.0 | 1.0 | 2.118144 |
| 7  | a[tissue-7] | 0.159 | 0.324 | -0.334 | 0.690 | 0.012 | 0.008 | 828.0 | 1068.0 | 1.0 | 0.407652 |
| 8  | a[tissue-8] | 2.497 | 0.230 | 2.123 | 2.843 | 0.013 | 0.009 | 328.0 | 568.0 | 1.0 | 2.484808 |
| 9  | a[tissue-9] | 2.552 | 0.234 | 2.198 | 2.930 | 0.013 | 0.009 | 333.0 | 688.0 | 1.0 | 2.281616 |
| 10 | b[tissue-0] | 0.010 | 0.335 | -0.507 | 0.549 | 0.016 | 0.011 | 435.0 | 810.0 | 1.0 | -0.089065 |
| 11 | b[tissue-1] | -0.351 | 0.313 | -0.841 | 0.161 | 0.016 | 0.011 | 401.0 | 852.0 | 1.0 | -0.353870 |
| 12 | b[tissue-2] | -0.086 | 0.313 | -0.636 | 0.372 | 0.015 | 0.011 | 413.0 | 680.0 | 1.0 | 0.009744 |
| 13 | b[tissue-3] | 0.065 | 0.318 | -0.445 | 0.560 | 0.016 | 0.011 | 409.0 | 746.0 | 1.0 | 0.099328 |
| 14 | b[tissue-4] | 0.009 | 0.334 | -0.535 | 0.528 | 0.015 | 0.011 | 486.0 | 884.0 | 1.0 | -0.184059 |
| 15 | b[tissue-5] | -0.682 | 0.324 | -1.191 | -0.160 | 0.016 | 0.011 | 433.0 | 743.0 | 1.0 | -0.720852 |
| 16 | b[tissue-6] | 0.437 | 0.331 | -0.082 | 0.987 | 0.016 | 0.011 | 423.0 | 759.0 | 1.0 | 0.162447 |
| 17 | b[tissue-7] | 2.045 | 0.389 | 1.443 | 2.676 | 0.015 | 0.010 | 703.0 | 1041.0 | 1.0 | 1.981890 |
| 18 | b[tissue-8] | 0.298 | 0.306 | -0.189 | 0.795 | 0.016 | 0.011 | 390.0 | 761.0 | 1.0 | 0.085959 |
| 19 | b[tissue-9] | -0.384 | 0.325 | -0.876 | 0.143 | 0.016 | 0.011 | 423.0 | 797.0 | 1.0 | -0.129034 |

The plot below shows the posterior estimates (blue) against the known proportions (orange).

```python
_, ax = plt.subplots(figsize=(5, 5))
sns.scatterplot(
    data=res_summary,
    y="index",
    x="mean",
    color="tab:blue",
    ax=ax,
    zorder=10,
    label="est.",
)
ax.hlines(
    res_summary["index"],
    xmin=res_summary["hdi_5.5%"],
    xmax=res_summary["hdi_94.5%"],
    color="tab:blue",
    alpha=0.5,
    zorder=5,
)
sns.scatterplot(
    data=res_summary,
    y="index",
    x="real",
    ax=ax,
    color="tab:orange",
    zorder=20,
    label="real",
)
ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
plt.show()
```


### Posterior predictive distribution

```python
post_pred = (
    trace.posterior_predictive["y"]
    .to_dataframe()
    .reset_index()
    .filter_column_isin("replicate", ["C-0", "E-0"])
    .assign(condition=lambda d: [x[0] for x in d["replicate"]])
)

plot_expr_data = (
    expr_data.copy()
    .reset_index()
    .pivot_longer("replicate", names_to="tissue", values_to="expr")
    .assign(condition=lambda d: [x[0] for x in d["replicate"]])
)
```

```python
violin_pal = {"C": "#cab2d6", "E": "#b2df8a"}
point_pal = {"C": "#6a3d9a", "E": "#33a02c"}

_, ax = plt.subplots(figsize=(5, 7))
sns.violinplot(
    data=post_pred,
    x="y",
    y="tissue",
    hue="condition",
    palette=violin_pal,
    linewidth=0.5,
    ax=ax,
)
sns.stripplot(
    data=plot_expr_data,
    x="expr",
    y="tissue",
    hue="condition",
    palette=point_pal,
    dodge=True,
    ax=ax,
)
ax.legend(loc="upper left", bbox_to_anchor=(1, 1), title="condition")
```



## Session Info

```python
%load_ext watermark
%watermark -d -u -v -iv -b -h -m
```

```text
Last updated: 2022-11-09

Python implementation: CPython
Python version       : 3.10.6
IPython version      : 8.4.0

Compiler    : Clang 13.0.1
OS          : Darwin
Release     : 21.6.0
Machine     : x86_64
Processor   : i386
CPU cores   : 4
Architecture: 64bit

Hostname: JHCookMac.local

Git branch: sex-diff-expr-better

matplotlib: 3.5.3
pandas    : 1.4.4
numpy     : 1.21.6
arviz     : 0.12.1
pymc      : 4.1.5
janitor   : 0.22.0
seaborn   : 0.11.2
```

