Dirichlet Regression with PyMC
I want to apologize at the top for the generally lackluster appearance and text of this post. It is meant to serve as a quick, simple guide, so I chose to keep it relatively light on text and explanation.
Introduction
Below, I provide a simple example of a Dirichlet regression in PyMC. This form of generalized linear model is appropriate when modeling proportions of multiple groups, that is, when modeling a collection of positive values that must sum to a constant. Some common examples include ratios and percentages.
For this example, I used a simplified version of the case that was the original impetus for me looking into this form of model. I have measured a protein’s expression in two conditions, control and experimental, across $10$ tissues, with $6$ replicates per condition. Therefore, I have $10 \times 6 \times 2 = 120$ measurements. The values are all greater than or equal to $0$, and the values for each replicate sum to $1$ across the tissues.
I want to know if the expression of the protein is different between control and experiment in each tissue.
Because the values are constrained to be $\ge 0$ and to sum to $1$ within each replicate, the likelihood should be a Dirichlet distribution. The exponential function maps the linear combination of variables to the strictly positive concentration parameters of the Dirichlet (equivalently, the model uses a log link).
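To make this structure concrete, the model built later in this post can be sketched as follows. The indices $r$ (replicate) and $t$ (tissue) and the indicator $x_r$ ($0$ for control, $1$ for experimental) are notation introduced here for illustration; the symbols $a$, $b$, $\eta$, and $\mu$ match the model specification.

$$
\begin{array}{rcl}
\eta_{r,t} &=& a_t + b_t \, x_r \\
\mu_{r,t} &=& \exp(\eta_{r,t}) \\
y_r &\sim& \operatorname{Dir}(\mu_{r,1}, \dots, \mu_{r,10}) \\
\mathbb{E}[y_{r,t}] &=& \dfrac{\mu_{r,t}}{\sum_{k} \mu_{r,k}}
\end{array}
$$

Since $\mu_{r,t} > 0$ for any real $\eta_{r,t}$, the exponential always yields valid concentration parameters, and $b_t$ is the shift in the log-concentration of tissue $t$ under the experimental condition.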
Setup
import arviz as az
import janitor  # noqa: F401
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import seaborn as sns

%matplotlib inline
%config InlineBackend.figure_format = 'retina'
sns.set_style("whitegrid")
Data generation
A fake dataset was produced for the situation described above.
The vectors ctrl_tissue_props and expt_tissue_props contain the “true” proportions of protein expression across the ten tissues for the control and experimental conditions. They were randomly generated and are printed at the end of the code block.
N_TISSUES = 10
N_REPS = 6
CONDITIONS = ["C", "E"]

TISSUES = [f"tissue-{i}" for i in range(N_TISSUES)]
REPS = [f"{CONDITIONS[0]}-{i}" for i in range(N_REPS)]
REPS += [f"{CONDITIONS[1]}-{i}" for i in range(N_REPS)]

np.random.seed(909)
ctrl_tissue_props = np.random.beta(2, 2, N_TISSUES)
ctrl_tissue_props = ctrl_tissue_props / np.sum(ctrl_tissue_props)
expt_tissue_props = np.random.beta(2, 2, N_TISSUES)
expt_tissue_props = expt_tissue_props / np.sum(expt_tissue_props)

print("Real proportions for each tissue:")
print(np.vstack([ctrl_tissue_props, expt_tissue_props]).round(3))

Real proportions for each tissue:
[[0.072 0.148 0.137 0.135 0.074 0.118 0.083 0.015 0.12  0.098]
 [0.066 0.104 0.138 0.149 0.062 0.057 0.098 0.109 0.131 0.086]]
Protein expression values were sampled using these proportions, multiplied by 100 to reduce the variability in the sampled values. Recall that the Dirichlet is effectively a multi-class Beta distribution, so the input numbers can be thought of as the observed number of instances for each class. The more observations, the more confidence that the observed frequencies are representative of the true proportions.
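The effect of the multiplier can be checked with a small simulation. This is a minimal sketch, not part of the original analysis: the three-class `props` vector and the helper `dirichlet_sd` are hypothetical and chosen only for illustration. It compares the spread of Dirichlet draws when the same proportions are scaled by 1 and by 100.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical "true" proportions for three classes (not the post's data).
props = np.array([0.2, 0.3, 0.5])


def dirichlet_sd(scale, n_draws=20_000):
    """Empirical per-class standard deviation of Dirichlet(props * scale) draws."""
    draws = rng.dirichlet(props * scale, size=n_draws)
    return draws.std(axis=0)


sd_scale_1 = dirichlet_sd(1)
sd_scale_100 = dirichlet_sd(100)

# Theoretical SD of class i is sqrt(p_i * (1 - p_i) / (scale + 1)),
# because the concentrations props * scale sum to `scale`.
theory_sd_100 = np.sqrt(props * (1 - props) / (100 + 1))

print("SD with scale 1:  ", sd_scale_1.round(3))
print("SD with scale 100:", sd_scale_100.round(3))
print("Theory, scale 100:", theory_sd_100.round(3))
```

The expected proportions are the same in both cases; only the total concentration changes, which is what controls how tightly the draws cluster around those proportions.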
_ctrl_data = np.random.dirichlet(ctrl_tissue_props * 100, N_REPS)
_expt_data = np.random.dirichlet(expt_tissue_props * 100, N_REPS)

expr_data = (
    pd.DataFrame(np.vstack([_ctrl_data, _expt_data]), columns=TISSUES)
    .assign(replicate=REPS)
    .set_index("replicate")
)
expr_data.round(3)
replicate | tissue-0 | tissue-1 | tissue-2 | tissue-3 | tissue-4 | tissue-5 | tissue-6 | tissue-7 | tissue-8 | tissue-9 |
---|---|---|---|---|---|---|---|---|---|---|
C-0 | 0.085 | 0.116 | 0.184 | 0.175 | 0.038 | 0.101 | 0.069 | 0.014 | 0.102 | 0.115 |
C-1 | 0.134 | 0.188 | 0.101 | 0.119 | 0.075 | 0.104 | 0.052 | 0.001 | 0.121 | 0.107 |
C-2 | 0.098 | 0.125 | 0.127 | 0.138 | 0.094 | 0.091 | 0.134 | 0.017 | 0.080 | 0.096 |
C-3 | 0.069 | 0.154 | 0.140 | 0.082 | 0.065 | 0.182 | 0.054 | 0.011 | 0.110 | 0.132 |
C-4 | 0.033 | 0.208 | 0.151 | 0.090 | 0.067 | 0.109 | 0.064 | 0.003 | 0.160 | 0.115 |
C-5 | 0.074 | 0.130 | 0.130 | 0.113 | 0.081 | 0.129 | 0.059 | 0.020 | 0.111 | 0.152 |
E-0 | 0.100 | 0.105 | 0.114 | 0.081 | 0.088 | 0.056 | 0.120 | 0.087 | 0.167 | 0.081 |
E-1 | 0.043 | 0.124 | 0.184 | 0.098 | 0.071 | 0.040 | 0.122 | 0.071 | 0.157 | 0.089 |
E-2 | 0.099 | 0.108 | 0.102 | 0.139 | 0.089 | 0.039 | 0.115 | 0.092 | 0.158 | 0.059 |
E-3 | 0.076 | 0.074 | 0.122 | 0.142 | 0.058 | 0.062 | 0.103 | 0.081 | 0.106 | 0.176 |
E-4 | 0.098 | 0.103 | 0.117 | 0.113 | 0.048 | 0.110 | 0.113 | 0.104 | 0.166 | 0.027 |
E-5 | 0.059 | 0.110 | 0.119 | 0.190 | 0.059 | 0.054 | 0.071 | 0.065 | 0.155 | 0.117 |
sns.heatmap(expr_data, vmin=0, cmap="seismic");
The sum of the values for each replicate should be 1.
expr_data.values.sum(axis=1)
# > array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
Model
Model specification
The model is rather straightforward and immediately recognizable as a generalized linear model. The main attributes are the use of the Dirichlet likelihood and the exponential link function. Note that for the PyMC library, each entry along the first dimension is one “group” of data, and the values within a group (that is, along the last axis) should sum to $1$. In this case, each row is a replicate, and the values of each replicate should sum to $1$ across the tissues.
coords = {"tissue": TISSUES, "replicate": REPS}

intercept = np.ones_like(expr_data)
x_expt_cond = np.vstack([np.zeros((N_REPS, N_TISSUES)), np.ones((N_REPS, N_TISSUES))])

with pm.Model(coords=coords) as dirichlet_reg:
    a = pm.Normal("a", 0, 5, dims=("tissue",))
    b = pm.Normal("b", 0, 2.5, dims=("tissue",))
    eta = pm.Deterministic(
        "eta",
        a[None, :] * intercept + b[None, :] * x_expt_cond,
        dims=("replicate", "tissue"),
    )
    mu = pm.Deterministic("mu", pm.math.exp(eta), dims=("replicate", "tissue"))
    y = pm.Dirichlet("y", mu, observed=expr_data.values, dims=("replicate", "tissue"))

# pm.model_to_graphviz(dirichlet_reg)
dirichlet_reg
$$
\begin{array}{rcl}
a &\sim & \mathcal{N}(0,~5) \\
b &\sim & \mathcal{N}(0,~2.5) \\
\eta &\sim & \operatorname{Deterministic}(f(a, b)) \\
\mu &\sim & \operatorname{Deterministic}(f(\eta)) \\
y &\sim & \operatorname{Dir}(\mu)
\end{array}
$$
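The shapes involved in the linear predictor can be checked outside of PyMC with plain NumPy. The sketch below is an illustration under assumptions: the arrays `a` and `b` hold arbitrary random numbers standing in for the model's random variables, and the per-replicate indicator `is_expt` is a hypothetical alternative to the full design matrix, not something used in the model itself.

```python
import numpy as np

n_reps, n_tissues = 6, 10  # same sizes as in the post

rng = np.random.default_rng(1)
# Arbitrary coefficient values standing in for the random variables a and b.
a = rng.normal(size=n_tissues)
b = rng.normal(size=n_tissues)

# Full-matrix design, as in the model: one row per replicate, one column per tissue.
intercept = np.ones((2 * n_reps, n_tissues))
x_expt_cond = np.vstack([np.zeros((n_reps, n_tissues)), np.ones((n_reps, n_tissues))])
eta_full = a[None, :] * intercept + b[None, :] * x_expt_cond

# Equivalent compact form: a 0/1 indicator per replicate, broadcast across tissues.
is_expt = np.repeat([0.0, 1.0], n_reps)  # shape (12,)
eta_compact = a[None, :] + b[None, :] * is_expt[:, None]

mu = np.exp(eta_full)  # strictly positive concentrations, shape (12, 10)

print(eta_full.shape, np.allclose(eta_full, eta_compact))
```

Control rows of `eta_full` equal `a` and experimental rows equal `a + b`, so `b` for a tissue is the difference in log-concentration between the two conditions.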
Sampling
PyMC does all of the heavy lifting and we just need to press the “Inference Button” with the pm.sample() function.
with dirichlet_reg:
    trace = pm.sample(
        draws=1000, tune=1000, chains=2, cores=2, random_seed=20, target_accept=0.9
    )
    _ = pm.sample_posterior_predictive(trace, random_seed=43, extend_inferencedata=True)

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [a, b]
Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 29 seconds.
Posterior analysis
Recovering known parameters
The table below shows the summaries of the marginal posterior distributions for the variables $a$ and $b$ of the model.
real_a = np.log(ctrl_tissue_props * 100)
real_b = np.log(expt_tissue_props * 100) - real_a

res_summary = (
    az.summary(trace, var_names=["a", "b"], hdi_prob=0.89)
    .assign(real=np.hstack([real_a, real_b]))
    .reset_index()
)
res_summary
| | index | mean | sd | hdi_5.5% | hdi_94.5% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat | real |
---|---|---|---|---|---|---|---|---|---|---|---|
0 | a[tissue-0] | 2.122 | 0.240 | 1.735 | 2.490 | 0.012 | 0.009 | 393.0 | 923.0 | 1.0 | 1.973280 |
1 | a[tissue-1] | 2.782 | 0.223 | 2.432 | 3.142 | 0.013 | 0.009 | 309.0 | 657.0 | 1.0 | 2.691607 |
2 | a[tissue-2] | 2.691 | 0.231 | 2.349 | 3.082 | 0.013 | 0.009 | 334.0 | 566.0 | 1.0 | 2.618213 |
3 | a[tissue-3] | 2.529 | 0.234 | 2.163 | 2.903 | 0.013 | 0.009 | 324.0 | 481.0 | 1.0 | 2.603835 |
4 | a[tissue-4] | 2.009 | 0.247 | 1.648 | 2.435 | 0.013 | 0.009 | 399.0 | 586.0 | 1.0 | 2.006772 |
5 | a[tissue-5] | 2.538 | 0.231 | 2.180 | 2.906 | 0.013 | 0.009 | 322.0 | 641.0 | 1.0 | 2.465910 |
6 | a[tissue-6] | 2.015 | 0.250 | 1.653 | 2.435 | 0.013 | 0.009 | 363.0 | 804.0 | 1.0 | 2.118144 |
7 | a[tissue-7] | 0.159 | 0.324 | -0.334 | 0.690 | 0.012 | 0.008 | 828.0 | 1068.0 | 1.0 | 0.407652 |
8 | a[tissue-8] | 2.497 | 0.230 | 2.123 | 2.843 | 0.013 | 0.009 | 328.0 | 568.0 | 1.0 | 2.484808 |
9 | a[tissue-9] | 2.552 | 0.234 | 2.198 | 2.930 | 0.013 | 0.009 | 333.0 | 688.0 | 1.0 | 2.281616 |
10 | b[tissue-0] | 0.010 | 0.335 | -0.507 | 0.549 | 0.016 | 0.011 | 435.0 | 810.0 | 1.0 | -0.089065 |
11 | b[tissue-1] | -0.351 | 0.313 | -0.841 | 0.161 | 0.016 | 0.011 | 401.0 | 852.0 | 1.0 | -0.353870 |
12 | b[tissue-2] | -0.086 | 0.313 | -0.636 | 0.372 | 0.015 | 0.011 | 413.0 | 680.0 | 1.0 | 0.009744 |
13 | b[tissue-3] | 0.065 | 0.318 | -0.445 | 0.560 | 0.016 | 0.011 | 409.0 | 746.0 | 1.0 | 0.099328 |
14 | b[tissue-4] | 0.009 | 0.334 | -0.535 | 0.528 | 0.015 | 0.011 | 486.0 | 884.0 | 1.0 | -0.184059 |
15 | b[tissue-5] | -0.682 | 0.324 | -1.191 | -0.160 | 0.016 | 0.011 | 433.0 | 743.0 | 1.0 | -0.720852 |
16 | b[tissue-6] | 0.437 | 0.331 | -0.082 | 0.987 | 0.016 | 0.011 | 423.0 | 759.0 | 1.0 | 0.162447 |
17 | b[tissue-7] | 2.045 | 0.389 | 1.443 | 2.676 | 0.015 | 0.010 | 703.0 | 1041.0 | 1.0 | 1.981890 |
18 | b[tissue-8] | 0.298 | 0.306 | -0.189 | 0.795 | 0.016 | 0.011 | 390.0 | 761.0 | 1.0 | 0.085959 |
19 | b[tissue-9] | -0.384 | 0.325 | -0.876 | 0.143 | 0.016 | 0.011 | 423.0 | 797.0 | 1.0 | -0.129034 |
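The values in the real column follow from how the data were generated: the concentrations were the proportions times 100, and the model sets the concentration to $\exp(a)$ for control replicates and $\exp(a + b)$ for experimental replicates. The sketch below recomputes them from the rounded proportions printed in the data generation section, so the results should agree with the table only to about two decimal places. The names `ctrl_props`, `expt_props`, `scale`, and `log_ratio` are introduced here for illustration.

```python
import numpy as np

# "True" proportions as printed (rounded to 3 decimals) in the data generation section.
ctrl_props = np.array([0.072, 0.148, 0.137, 0.135, 0.074, 0.118, 0.083, 0.015, 0.12, 0.098])
expt_props = np.array([0.066, 0.104, 0.138, 0.149, 0.062, 0.057, 0.098, 0.109, 0.131, 0.086])
scale = 100  # multiplier applied to the proportions when sampling the data

# Control concentration: exp(a) = scale * p_ctrl.
real_a = np.log(ctrl_props * scale)
# Experimental concentration: exp(a + b) = scale * p_expt.
real_b = np.log(expt_props * scale) - real_a

# The scale cancels in b, leaving the log-ratio of the two proportions.
log_ratio = np.log(expt_props / ctrl_props)

print(real_a.round(2))
print(real_b.round(2))
```

Read this way, each $b$ is the log of the ratio of experimental to control proportion for that tissue, which is why tissue-7 (0.015 versus 0.109) stands out with a value near $2$.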
The plot below shows the posterior estimates with their 89% HDIs (blue) against the known parameter values (orange).
_, ax = plt.subplots(figsize=(5, 5))
sns.scatterplot(
    data=res_summary,
    y="index",
    x="mean",
    color="tab:blue",
    ax=ax,
    zorder=10,
    label="est.",
)
ax.hlines(
    res_summary["index"],
    xmin=res_summary["hdi_5.5%"],
    xmax=res_summary["hdi_94.5%"],
    color="tab:blue",
    alpha=0.5,
    zorder=5,
)
sns.scatterplot(
    data=res_summary,
    y="index",
    x="real",
    ax=ax,
    color="tab:orange",
    zorder=20,
    label="real",
)
ax.legend(loc="upper left", bbox_to_anchor=(1, 1))
plt.show()
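The parameters $a$ and $b$ live on the log-concentration scale, so it can help to translate them back into expected proportions. The following is a rough sketch that plugs in the rounded posterior means from the summary table; a fuller analysis would apply the same transformation to every posterior draw rather than to the means. The helper `implied_proportions` is a name made up for this example.

```python
import numpy as np

# Posterior means of a and b copied from the summary table (rounded values).
a_mean = np.array([2.122, 2.782, 2.691, 2.529, 2.009, 2.538, 2.015, 0.159, 2.497, 2.552])
b_mean = np.array([0.010, -0.351, -0.086, 0.065, 0.009, -0.682, 0.437, 2.045, 0.298, -0.384])

# "True" proportions as printed in the data generation section.
ctrl_true = np.array([0.072, 0.148, 0.137, 0.135, 0.074, 0.118, 0.083, 0.015, 0.12, 0.098])
expt_true = np.array([0.066, 0.104, 0.138, 0.149, 0.062, 0.057, 0.098, 0.109, 0.131, 0.086])


def implied_proportions(eta):
    """Mean of a Dirichlet with concentrations exp(eta): each one over their sum."""
    mu = np.exp(eta)
    return mu / mu.sum()


ctrl_est = implied_proportions(a_mean)
expt_est = implied_proportions(a_mean + b_mean)

print("control est.:", ctrl_est.round(3))
print("control true:", ctrl_true)
print("expt est.:   ", expt_est.round(3))
print("expt true:   ", expt_true)
```

With only six replicates per condition the plug-in estimates are not exact, but they land within a few percentage points of the proportions used to simulate the data.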
Posterior predictive distribution
post_pred = (
    trace.posterior_predictive["y"]
    .to_dataframe()
    .reset_index()
    .filter_column_isin("replicate", ["C-0", "E-0"])
    .assign(condition=lambda d: [x[0] for x in d["replicate"]])
)

plot_expr_data = (
    expr_data.copy()
    .reset_index()
    .pivot_longer("replicate", names_to="tissue", values_to="expr")
    .assign(condition=lambda d: [x[0] for x in d["replicate"]])
)

violin_pal = {"C": "#cab2d6", "E": "#b2df8a"}
point_pal = {"C": "#6a3d9a", "E": "#33a02c"}

_, ax = plt.subplots(figsize=(5, 7))
sns.violinplot(
    data=post_pred,
    x="y",
    y="tissue",
    hue="condition",
    palette=violin_pal,
    linewidth=0.5,
    ax=ax,
)
sns.stripplot(
    data=plot_expr_data,
    x="expr",
    y="tissue",
    hue="condition",
    palette=point_pal,
    dodge=True,
    ax=ax,
)
ax.legend(loc="upper left", bbox_to_anchor=(1, 1), title="condition")
Session Info
%load_ext watermark
%watermark -d -u -v -iv -b -h -m

Last updated: 2022-11-09

Python implementation: CPython
Python version       : 3.10.6
IPython version      : 8.4.0

Compiler    : Clang 13.0.1
OS          : Darwin
Release     : 21.6.0
Machine     : x86_64
Processor   : i386
CPU cores   : 4
Architecture: 64bit

Hostname: JHCookMac.local

Git branch: sex-diff-expr-better

matplotlib: 3.5.3
pandas    : 1.4.4
numpy     : 1.21.6
arviz     : 0.12.1
pymc      : 4.1.5
janitor   : 0.22.0
seaborn   : 0.11.2