
Dirichlet Regression with PyMC

[This article was first published on Posts | Joshua Cook, and kindly contributed to R-bloggers].
Want to share your content on R-bloggers? click here if you have a blog, or here if you don't.

I want to apologize up front for the generally lackluster appearance and text in this post. It is meant to serve as a quick, simple guide, so I chose to keep it relatively light on text and explanation.

Introduction

Below, I provide a simple example of a Dirichlet regression in PyMC. This form of generalized linear model is appropriate when modeling proportions across multiple groups, that is, when modeling a collection of positive values that must sum to a constant. Common examples include fractions of a whole and percentages.
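As a quick illustration of the kind of data this model expects (made-up numbers, not part of the original analysis), the NumPy sketch below rescales rows of percentages so that each row sums to 1, then checks that draws from a Dirichlet distribution satisfy the same constraint.

```python
import numpy as np

# Hypothetical compositions: each row is one sample, recorded as percentages.
percentages = np.array([
    [50.0, 30.0, 20.0],
    [10.0, 60.0, 30.0],
])

# Divide each row by its total so the parts sum to 1 (i.e., lie on the simplex).
proportions = percentages / percentages.sum(axis=1, keepdims=True)
print(proportions)

# Dirichlet draws share the same constraint:
# every component is positive and each row sums to 1.
rng = np.random.default_rng(1)
draws = rng.dirichlet([2.0, 3.0, 5.0], size=4)
print(draws.sum(axis=1))
```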

For this example, I used a simplified version of the case that originally prompted me to look into this form of model. I have measured a protein’s expression in two conditions, a control and an experimental condition, across $10$ tissues, with $6$ replicates for each condition in all $10$ tissues. Therefore, I have $10 \times 6 \times 2$ measurements. All of the values are non-negative, and the values for each replicate sum to $1$ across the tissues.

I want to know whether the expression of the protein differs between the control and experimental conditions in each tissue.

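Because the values are non-negative and sum to $1$ within each replicate (across tissues), the likelihood should be a Dirichlet distribution. Strictly speaking, the Dirichlet density is defined only for values greater than $0$, so exact zeros would need special handling. The Dirichlet's concentration parameters must be positive, so the exponential function is used to map the linear combination of variables onto them (equivalently, a log link).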
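To make the link concrete, here is a small NumPy sketch with a hypothetical linear predictor `eta` (the helper name `expected_proportions` is illustrative only). Exponentiating $\eta$ gives the Dirichlet concentrations $\alpha$, and the expected proportion of category $t$ is $\alpha_t / \sum_k \alpha_k$, which is the softmax of $\eta$.

```python
import numpy as np

def expected_proportions(eta):
    """Map a linear predictor to Dirichlet concentrations and expected proportions.

    eta: array whose last axis indexes the categories (e.g., tissues).
    """
    alpha = np.exp(eta)  # inverse of the log link: concentrations are always > 0
    alpha0 = alpha.sum(axis=-1, keepdims=True)  # total concentration (precision)
    return alpha, alpha / alpha0  # E[y_t] = alpha_t / alpha0, i.e. softmax(eta)

# Hypothetical linear predictor for 4 categories.
eta = np.array([2.0, 2.5, 0.5, 1.0])
alpha, mean_props = expected_proportions(eta)

# Adding the same constant to every eta_t rescales alpha (the precision)
# but leaves the expected proportions unchanged.
alpha_shift, mean_props_shift = expected_proportions(eta + 1.0)
print(mean_props)
print(np.allclose(mean_props, mean_props_shift))
```

Only differences between the $\eta_t$ move the expected proportions; the overall level of $\eta$ sets the total concentration, i.e., how tightly the observations cluster around those proportions.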

Setup

import arviz as az
import janitor # noqa: F401
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import seaborn as sns

%matplotlib inline
%config InlineBackend.figure_format = 'retina'
sns.set_style("whitegrid")

Data generation

A fake dataset was produced for the situation described above. The vectors ctrl_tissue_props and expt_tissue_props contain the “true” proportions of protein expression across the ten tissues for the control and experimental conditions. They were generated randomly (Beta(2, 2) draws normalized to sum to 1) and are printed at the end of the code block.

N_TISSUES = 10
N_REPS = 6
CONDITIONS = ["C", "E"]
TISSUES = [f"tissue-{i}" for i in range(N_TISSUES)]
REPS = [f"{CONDITIONS[0]}-{i}" for i in range(N_REPS)]
REPS += [f"{CONDITIONS[1]}-{i}" for i in range(N_REPS)]
np.random.seed(909)
ctrl_tissue_props = np.random.beta(2, 2, N_TISSUES)
ctrl_tissue_props = ctrl_tissue_props / np.sum(ctrl_tissue_props)
expt_tissue_props = np.random.beta(2, 2, N_TISSUES)
expt_tissue_props = expt_tissue_props / np.sum(expt_tissue_props)
print("Real proportions for each tissue:")
print(np.vstack([ctrl_tissue_props, expt_tissue_props]).round(3))

Real proportions for each tissue:
[[0.072 0.148 0.137 0.135 0.074 0.118 0.083 0.015 0.12 0.098]
[0.066 0.104 0.138 0.149 0.062 0.057 0.098 0.109 0.131 0.086]]

Protein expression values were then sampled from Dirichlet distributions whose concentration parameters are these proportions multiplied by 100; the larger concentrations reduce the variability of the sampled values. Recall that the Dirichlet is the multivariate generalization of the Beta distribution, so the concentration parameters can be thought of as (pseudo-)counts of observed instances of each class. The more observations, the more confident we can be that the observed frequencies are representative of the true proportions.
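The effect of the ×100 scaling can be checked directly. For concentrations $\alpha = k\,p$ with $\sum_t p_t = 1$, the variance of each component is $p_t(1 - p_t)/(k + 1)$, so a larger $k$ gives samples more tightly clustered around the true proportions. The sketch below (NumPy only, with hypothetical proportions and a hypothetical helper `dirichlet_variance`) compares this formula with simulated draws for $k = 10$ and $k = 100$.

```python
import numpy as np

def dirichlet_variance(alpha):
    """Analytical variance of each component of a Dirichlet(alpha) variable."""
    alpha = np.asarray(alpha, dtype=float)
    alpha0 = alpha.sum()
    return alpha * (alpha0 - alpha) / (alpha0**2 * (alpha0 + 1))

# Hypothetical "true" proportions (they sum to 1).
props = np.array([0.1, 0.3, 0.6])
rng = np.random.default_rng(42)

for scale in (10, 100):
    alpha = props * scale  # concentration = proportions x scale factor
    draws = rng.dirichlet(alpha, size=20_000)
    # With alpha = scale * props, the formula reduces to p(1 - p) / (scale + 1).
    print(scale, dirichlet_variance(alpha).round(5), draws.var(axis=0).round(5))
```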

_ctrl_data = np.random.dirichlet(ctrl_tissue_props * 100, N_REPS)
_expt_data = np.random.dirichlet(expt_tissue_props * 100, N_REPS)
expr_data = (
    pd.DataFrame(np.vstack([_ctrl_data, _expt_data]), columns=TISSUES)
    .assign(replicate=REPS)
    .set_index("replicate")
)
expr_data.round(3)
tissue-0 tissue-1 tissue-2 tissue-3 tissue-4 tissue-5 tissue-6 tissue-7 tissue-8 tissue-9
replicate
C-0 0.085 0.116 0.184 0.175 0.038 0.101 0.069 0.014 0.102 0.115
C-1 0.134 0.188 0.101 0.119 0.075 0.104 0.052 0.001 0.121 0.107
C-2 0.098 0.125 0.127 0.138 0.094 0.091 0.134 0.017 0.080 0.096
C-3 0.069 0.154 0.140 0.082 0.065 0.182 0.054 0.011 0.110 0.132
C-4 0.033 0.208 0.151 0.090 0.067 0.109 0.064 0.003 0.160 0.115
C-5 0.074 0.130 0.130 0.113 0.081 0.129 0.059 0.020 0.111 0.152
E-0 0.100 0.105 0.114 0.081 0.088 0.056 0.120 0.087 0.167 0.081
E-1 0.043 0.124 0.184 0.098 0.071 0.040 0.122 0.071 0.157 0.089
E-2 0.099 0.108 0.102 0.139 0.089 0.039 0.115 0.092 0.158 0.059
E-3 0.076 0.074 0.122 0.142 0.058 0.062 0.103 0.081 0.106 0.176
E-4 0.098 0.103 0.117 0.113 0.048 0.110 0.113 0.104 0.166 0.027
E-5 0.059 0.110 0.119 0.190 0.059 0.054 0.071 0.065 0.155 0.117
sns.heatmap(expr_data, vmin=0, cmap="seismic");

The sum of the values for each replicate should be 1.

expr_data.values.sum(axis=1)
# > array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

Model

Model specification

The model is rather straightforward and immediately recognizable as a generalized linear model. Its main attributes are the Dirichlet likelihood and the exponential link function. Note that in PyMC, the Dirichlet treats its last dimension as the categories: the values must sum to $1$ along that axis, and each index of the leading (first) dimension is a separate observation. Here, each row is a replicate, so the values of each replicate sum to $1$ across the tissues.
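The same shape convention can be mirrored in plain NumPy before looking at the PyMC code. The sketch below uses hypothetical coefficient values (it is not the PyMC model itself) and builds the replicate-by-tissue linear predictor with broadcasting instead of the explicit intercept and x_expt_cond matrices, then draws one composition per row. Each row sums to 1 because the Dirichlet normalizes over the last (tissue) axis.

```python
import numpy as np

# Same layout as the post: 6 control + 6 experimental replicates, 10 tissues.
n_reps, n_tissues = 6, 10
x_expt = np.repeat([0.0, 1.0], n_reps)  # 1 for experimental replicates, shape (12,)

# Hypothetical per-tissue coefficients (one value per tissue).
rng = np.random.default_rng(0)
a = rng.normal(2.0, 0.5, size=n_tissues)
b = rng.normal(0.0, 0.5, size=n_tissues)

# Broadcasting: (1, 10) + (12, 1) * (1, 10) -> (12, 10); rows are replicates.
eta = a[None, :] + x_expt[:, None] * b[None, :]
alpha = np.exp(eta)

# One Dirichlet draw per row; the normalization runs along the last axis.
y = np.stack([rng.dirichlet(row) for row in alpha])
print(eta.shape, y.shape)
print(y.sum(axis=-1).round(6))
```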

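coords = {"tissue": TISSUES, "replicate": REPS}
# Design matrices: a matrix of ones for the intercept and a 0/1 indicator for the experimental condition.
intercept = np.ones_like(expr_data)
x_expt_cond = np.vstack([np.zeros((N_REPS, N_TISSUES)), np.ones((N_REPS, N_TISSUES))])

with pm.Model(coords=coords) as dirichlet_reg:
    # Per-tissue intercept (control level) and effect of the experimental condition.
    a = pm.Normal("a", 0, 5, dims=("tissue",))
    b = pm.Normal("b", 0, 2.5, dims=("tissue",))
    # Linear predictor with shape (replicate, tissue).
    eta = pm.Deterministic(
        "eta",
        a[None, :] * intercept + b[None, :] * x_expt_cond,
        dims=("replicate", "tissue"),
    )
    # Exponential link: the Dirichlet concentrations must be positive.
    mu = pm.Deterministic("mu", pm.math.exp(eta), dims=("replicate", "tissue"))
    # Dirichlet likelihood; each row (replicate) sums to 1 across tissues.
    y = pm.Dirichlet("y", mu, observed=expr_data.values, dims=("replicate", "tissue"))

# pm.model_to_graphviz(dirichlet_reg)
dirichlet_reg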

$$ \begin{array}{rcl} a &\sim & \mathcal{N}(0,~5) \\
b &\sim & \mathcal{N}(0,~2.5) \\
\eta &\sim & \operatorname{Deterministic}(f(a, b)) \\
\mu &\sim & \operatorname{Deterministic}(f(\eta)) \\
y &\sim & \operatorname{Dir}(\mu) \end{array} $$
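Written out with indices, the model above reads as follows. Here $r$ indexes replicates, $t$ indexes tissues, $x_r$ is $1$ for experimental replicates and $0$ for controls (the rows of x_expt_cond), and the second argument of each normal is its standard deviation, as in PyMC.

$$ \begin{array}{rcl} a_t &\sim& \mathcal{N}(0,~5) \\
b_t &\sim& \mathcal{N}(0,~2.5) \\
\eta_{r,t} &=& a_t + b_t\, x_r \\
\mu_{r,t} &=& \exp(\eta_{r,t}) \\
(y_{r,1}, \ldots, y_{r,10}) &\sim& \operatorname{Dir}(\mu_{r,1}, \ldots, \mu_{r,10}) \end{array} $$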

Sampling

PyMC does all of the heavy lifting and we just need to press the “Inference Button” with the pm.sample() function.

with dirichlet_reg:
    trace = pm.sample(
        draws=1000, tune=1000, chains=2, cores=2, random_seed=20, target_accept=0.9
    )
    _ = pm.sample_posterior_predictive(trace, random_seed=43, extend_inferencedata=True)

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [a, b]
100.00% [4000/4000 00:16<00:00 Sampling 2 chains, 0 divergences]
Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 29 seconds.
100.00% [2000/2000 00:00<00:00]

Posterior analysis

Recovering known parameters

The table below summarizes the marginal posterior distributions of the model's variables $a$ and $b$ (mean, standard deviation, and 89% HDI), alongside the true values used to simulate the data (the real column).
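The true values in the real column come from inverting the link function. The data were simulated with concentrations $100\,p^{C}_t$ for the controls and $100\,p^{E}_t$ for the experimental condition ($p^{C}$ and $p^{E}$ being ctrl_tissue_props and expt_tissue_props), while the model sets the concentrations to $\exp(a_t)$ for controls and $\exp(a_t + b_t)$ for experimental replicates, so:

$$ \begin{array}{rcl} \exp(a_t) = 100\,p^{C}_t &\Rightarrow& a_t = \log(100\,p^{C}_t) \\
\exp(a_t + b_t) = 100\,p^{E}_t &\Rightarrow& b_t = \log(100\,p^{E}_t) - a_t = \log\left(p^{E}_t / p^{C}_t\right) \end{array} $$

In this simulation, $b_t$ is therefore the log ratio of the experimental to the control proportion; the largest true $b_t$ belongs to tissue-7, whose proportion rises from about 0.015 to about 0.109.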

real_a = np.log(ctrl_tissue_props * 100)
real_b = np.log(expt_tissue_props * 100) - real_a
res_summary = (
az.summary(trace, var_names=["a", "b"], hdi_prob=0.89)
.assign(real=np.hstack([real_a, real_b]))
.reset_index()
)
res_summary
index mean sd hdi_5.5% hdi_94.5% mcse_mean mcse_sd ess_bulk ess_tail r_hat real
0 a[tissue-0] 2.122 0.240 1.735 2.490 0.012 0.009 393.0 923.0 1.0 1.973280
1 a[tissue-1] 2.782 0.223 2.432 3.142 0.013 0.009 309.0 657.0 1.0 2.691607
2 a[tissue-2] 2.691 0.231 2.349 3.082 0.013 0.009 334.0 566.0 1.0 2.618213
3 a[tissue-3] 2.529 0.234 2.163 2.903 0.013 0.009 324.0 481.0 1.0 2.603835
4 a[tissue-4] 2.009 0.247 1.648 2.435 0.013 0.009 399.0 586.0 1.0 2.006772
5 a[tissue-5] 2.538 0.231 2.180 2.906 0.013 0.009 322.0 641.0 1.0 2.465910
6 a[tissue-6] 2.015 0.250 1.653 2.435 0.013 0.009 363.0 804.0 1.0 2.118144
7 a[tissue-7] 0.159 0.324 -0.334 0.690 0.012 0.008 828.0 1068.0 1.0 0.407652
8 a[tissue-8] 2.497 0.230 2.123 2.843 0.013 0.009 328.0 568.0 1.0 2.484808
9 a[tissue-9] 2.552 0.234 2.198 2.930 0.013 0.009 333.0 688.0 1.0 2.281616
10 b[tissue-0] 0.010 0.335 -0.507 0.549 0.016 0.011 435.0 810.0 1.0 -0.089065
11 b[tissue-1] -0.351 0.313 -0.841 0.161 0.016 0.011 401.0 852.0 1.0 -0.353870
12 b[tissue-2] -0.086 0.313 -0.636 0.372 0.015 0.011 413.0 680.0 1.0 0.009744
13 b[tissue-3] 0.065 0.318 -0.445 0.560 0.016 0.011 409.0 746.0 1.0 0.099328
14 b[tissue-4] 0.009 0.334 -0.535 0.528 0.015 0.011 486.0 884.0 1.0 -0.184059
15 b[tissue-5] -0.682 0.324 -1.191 -0.160 0.016 0.011 433.0 743.0 1.0 -0.720852
16 b[tissue-6] 0.437 0.331 -0.082 0.987 0.016 0.011 423.0 759.0 1.0 0.162447
17 b[tissue-7] 2.045 0.389 1.443 2.676 0.015 0.010 703.0 1041.0 1.0 1.981890
18 b[tissue-8] 0.298 0.306 -0.189 0.795 0.016 0.011 390.0 761.0 1.0 0.085959
19 b[tissue-9] -0.384 0.325 -0.876 0.143 0.016 0.011 423.0 797.0 1.0 -0.129034
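Because $b_t$ lives on the log-concentration scale, and a shift shared by all tissues changes only the Dirichlet precision, $b_t$ on its own is not the change in a tissue's share of expression. One way to ask the question from the introduction on the proportion scale (an addition here, not part of the original analysis) is to push each posterior draw through the softmax, since $E[y_t] = \exp(\eta_t) / \sum_k \exp(\eta_k)$, and take the difference between conditions. The helper proportion_contrast below is hypothetical; the sanity check uses made-up three-tissue proportions and should recover $p^{E} - p^{C}$.

```python
import numpy as np

def proportion_contrast(a_draws, b_draws):
    """Difference in expected tissue proportions (experimental - control).

    a_draws, b_draws: arrays of shape (n_draws, n_tissues), e.g. flattened
    posterior samples of `a` and `b`. Returns an array of the same shape.
    """
    def softmax(eta):
        z = np.exp(eta - eta.max(axis=-1, keepdims=True))  # subtract max for stability
        return z / z.sum(axis=-1, keepdims=True)

    ctrl = softmax(a_draws)            # E[y | control]      = softmax(a)
    expt = softmax(a_draws + b_draws)  # E[y | experimental] = softmax(a + b)
    return expt - ctrl

# Sanity check with "true" values built the same way as real_a / real_b:
p_ctrl = np.array([0.2, 0.5, 0.3])
p_expt = np.array([0.4, 0.4, 0.2])
true_a = np.log(p_ctrl * 100)
true_b = np.log(p_expt * 100) - true_a
diff = proportion_contrast(true_a[None, :], true_b[None, :])
print(diff.round(3))
```

With real posterior samples, the inputs could be taken from trace.posterior["a"] and trace.posterior["b"] after stacking the chain and draw dimensions (for example `.stack(sample=("chain", "draw")).transpose("sample", "tissue").values`); summarizing each column of the result gives a per-tissue contrast on the proportion scale.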

The plot below shows the posterior estimates (blue points, with 89% HDIs as lines) against the known parameter values (orange).

_, ax = plt.subplots(figsize=(5, 5))
sns.scatterplot(
    data=res_summary,
    y="index",
    x="mean",
    color="tab:blue",
    ax=ax,
    zorder=10,
    label="est.",
)
ax.hlines(
    res_summary["index"],
    xmin=res_summary["hdi_5.5%"],
    xmax=res_summary["hdi_94.5%"],
    color="tab:blue",
    alpha=0.5,
    zorder=5,
)
sns.scatterplot(
    data=res_summary,
    y="index",
    x="real",
    ax=ax,
    color="tab:orange",
    zorder=20,
    label="real",
)
ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
plt.show()

Posterior predictive distribution

post_pred = (
    trace.posterior_predictive["y"]
    .to_dataframe()
    .reset_index()
    # One replicate per condition suffices: all replicates in a condition share the same mu.
    .filter_column_isin("replicate", ["C-0", "E-0"])
    # The condition label is the first character of the replicate name ("C" or "E").
    .assign(condition=lambda d: [x[0] for x in d["replicate"]])
)
plot_expr_data = (
    expr_data.copy()
    .reset_index()
    .pivot_longer("replicate", names_to="tissue", values_to="expr")
    .assign(condition=lambda d: [x[0] for x in d["replicate"]])
)
violin_pal = {"C": "#cab2d6", "E": "#b2df8a"}
point_pal = {"C": "#6a3d9a", "E": "#33a02c"}
_, ax = plt.subplots(figsize=(5, 7))
sns.violinplot(
    data=post_pred,
    x="y",
    y="tissue",
    hue="condition",
    palette=violin_pal,
    linewidth=0.5,
    ax=ax,
)
sns.stripplot(
    data=plot_expr_data,
    x="expr",
    y="tissue",
    hue="condition",
    palette=point_pal,
    dodge=True,
    ax=ax,
)
ax.legend(loc="upper left", bbox_to_anchor=(1, 1), title="condition")
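Conceptually, the posterior predictive draws plotted above come from pushing each posterior sample of $a$ and $b$ back through the model and drawing a new composition from the Dirichlet. The NumPy sketch below spells out that idea with hypothetical posterior draws and a hypothetical helper posterior_predictive; it illustrates the concept and is not PyMC's implementation of pm.sample_posterior_predictive.

```python
import numpy as np

def posterior_predictive(a_draws, b_draws, x, rng):
    """Simulate one replicated dataset per posterior draw.

    a_draws, b_draws: arrays of shape (n_draws, n_tissues).
    x: 0/1 indicator per replicate (1 = experimental), shape (n_reps,).
    Returns an array of shape (n_draws, n_reps, n_tissues).
    """
    n_draws, n_tissues = a_draws.shape
    out = np.empty((n_draws, len(x), n_tissues))
    for d in range(n_draws):
        # Linear predictor for every replicate, then the exponential link.
        alpha = np.exp(a_draws[d][None, :] + x[:, None] * b_draws[d][None, :])
        for r in range(len(x)):
            out[d, r] = rng.dirichlet(alpha[r])
    return out

# Hypothetical posterior draws for 3 tissues, and one replicate per condition.
rng = np.random.default_rng(7)
a_draws = rng.normal(3.0, 0.1, size=(500, 3))
b_draws = rng.normal([0.5, 0.0, -0.5], 0.1, size=(500, 3))
y_rep = posterior_predictive(a_draws, b_draws, np.array([0.0, 1.0]), rng)
print(y_rep.shape)
print(y_rep.mean(axis=0).round(3))  # average predicted composition per condition
```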



Session Info

%load_ext watermark
%watermark -d -u -v -iv -b -h -m

Last updated: 2022-11-09
Python implementation: CPython
Python version : 3.10.6
IPython version : 8.4.0
Compiler : Clang 13.0.1
OS : Darwin
Release : 21.6.0
Machine : x86_64
Processor : i386
CPU cores : 4
Architecture: 64bit
Hostname: JHCookMac.local
Git branch: sex-diff-expr-better
matplotlib: 3.5.3
pandas : 1.4.4
numpy : 1.21.6
arviz : 0.12.1
pymc : 4.1.5
janitor : 0.22.0
seaborn : 0.11.2
To leave a comment for the author, please follow the link and comment on their blog: Posts | Joshua Cook.
