Comparison of XSMC and PSMC¶
Setup code¶
[2]:
import xsmc
import xsmc.sampler
from xsmc import Segmentation
from xsmc.supporting.plotting import *
from xsmc.supporting.kde_ne import kde_ne
import matplotlib.pyplot as plt
import numpy as np
import msprime as msp
from scipy.interpolate import PPoly
import tskit
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
import logging
import os.path
logging.getLogger("xsmc").setLevel(logging.INFO)
[3]:
# Ensure reproducibility in posterior sampling
np.random.seed(1)
def seed():
return np.random.randint(1, np.iinfo(np.int32).max)
The xsmc.supporting.psmc
module runs Li & Durbin’s original PSMC method on tree sequence data.
[4]:
!pip install git+https://github.com/terhorst/mspsmc@e583e196f
Collecting git+https://github.com/terhorst/mspsmc@e583e196f
Cloning https://github.com/terhorst/mspsmc (to revision e583e196f) to /tmp/pip-req-build-z7s9zpy2
WARNING: Did not find branch or tag 'e583e196f', assuming revision or ref.
Requirement already satisfied (use --upgrade to upgrade): mspsmc==0.1.0 from git+https://github.com/terhorst/mspsmc@e583e196f in /home/terhorst/opt/py37/lib/python3.7/site-packages
Requirement already satisfied: scipy in /home/terhorst/opt/py37/lib/python3.7/site-packages (from mspsmc==0.1.0) (1.4.1)
Requirement already satisfied: numpy in /home/terhorst/opt/py37/lib/python3.7/site-packages (from mspsmc==0.1.0) (1.18.2)
Requirement already satisfied: tskit in /home/terhorst/opt/py37/lib/python3.7/site-packages (from mspsmc==0.1.0) (0.3.1)
Requirement already satisfied: jsonschema in /home/terhorst/opt/py37/lib/python3.7/site-packages (from tskit->mspsmc==0.1.0) (3.1.1)
Requirement already satisfied: h5py in /home/terhorst/opt/py37/lib/python3.7/site-packages (from tskit->mspsmc==0.1.0) (2.10.0)
Requirement already satisfied: attrs>=19.1.0 in /home/terhorst/opt/py37/lib/python3.7/site-packages (from tskit->mspsmc==0.1.0) (19.3.0)
Requirement already satisfied: svgwrite in /home/terhorst/opt/py37/lib/python3.7/site-packages (from tskit->mspsmc==0.1.0) (1.3.1)
Requirement already satisfied: importlib-metadata in /home/terhorst/opt/py37/lib/python3.7/site-packages (from jsonschema->tskit->mspsmc==0.1.0) (0.23)
Requirement already satisfied: pyrsistent>=0.14.0 in /home/terhorst/opt/py37/lib/python3.7/site-packages (from jsonschema->tskit->mspsmc==0.1.0) (0.15.4)
Requirement already satisfied: setuptools in /home/terhorst/opt/py37/lib/python3.7/site-packages (from jsonschema->tskit->mspsmc==0.1.0) (49.3.2)
Requirement already satisfied: six>=1.11.0 in /home/terhorst/opt/py37/lib/python3.7/site-packages (from jsonschema->tskit->mspsmc==0.1.0) (1.14.0)
Requirement already satisfied: pyparsing>=2.0.1 in /home/terhorst/opt/py37/lib/python3.7/site-packages (from svgwrite->tskit->mspsmc==0.1.0) (2.4.2)
Requirement already satisfied: zipp>=0.5 in /home/terhorst/opt/py37/lib/python3.7/site-packages (from importlib-metadata->jsonschema->tskit->mspsmc==0.1.0) (0.6.0)
Requirement already satisfied: more-itertools in /home/terhorst/opt/py37/lib/python3.7/site-packages (from zipp>=0.5->importlib-metadata->jsonschema->tskit->mspsmc==0.1.0) (7.2.0)
Building wheels for collected packages: mspsmc
Building wheel for mspsmc (setup.py) ... done
Created wheel for mspsmc: filename=mspsmc-0.1.0-py2.py3-none-any.whl size=5757 sha256=6961418aa633d0180cec574751c509435246e3a7d5541bd1d739840903c4ee9e
Stored in directory: /tmp/pip-ephem-wheel-cache-5a52eh0a/wheels/af/38/87/3eee2ca03eb580f4df9e22215d1e0ff5a8914f16a45cf26f93
Successfully built mspsmc
WARNING: You are using pip version 20.2.3; however, version 20.2.4 is available.
You should consider upgrading via the '/home/terhorst/opt/py37/bin/python3 -m pip install --upgrade pip' command.
[5]:
# psmc code
import os
os.environ["PSMC_PATH"] = "/scratch/psmc/psmc" # update as needed if running locally
import mspsmc
def run_psmc(reps, rho_over_theta=1.0):
def f(data, *args):
return mspsmc.msPSMC([(data, (0, 1))]).estimate(*args)
with ThreadPoolExecutor() as p:
futs = [p.submit(f, data, "-r", 1.0 / rho_over_theta) for data in reps]
res = [f.result() for f in futs]
rescaled = []
for r in res:
# See Appendix I of https://github.com/lh3/psmc/blob/master/README
N0 = r.theta / (4 * mu) / 100
rescaled.append(r.Ne.rescale(2 * N0))
return rescaled
mspsmc.__psmc__version__
[5]:
'0.6.5-r67'
[6]:
# other supporting functions
L = int(5e7) # length of simulated chromosome
mu = 1.4e-8 # mutation rate/bp/gen
M = 25 # number of replicates
def parallel_sample(reps, j=100, k=int(L / 50_000), rho_over_theta=1.0):
xs = [
xsmc.XSMC(data, focal=0, panel=[1], rho_over_theta=rho_over_theta)
for data in reps
]
with ThreadPoolExecutor() as p:
futs = [
p.submit(x.sample_heights, j=j, k=k, seed=seed()) for i, x in enumerate(xs)
]
return np.array(
[f.result() * 2 * x.theta / (4 * mu) for f, x in zip(futs, xs)]
) # rescale each sampled path by 2N0 so that segment heights are in generations
def parallel_sample0(reps, j=100, k=int(L / 50_000), rho_over_theta=1.0):
xs = [
xsmc.XSMC(data, focal=0, panel=[1], rho_over_theta=rho_over_theta)
for data in reps
]
with ProcessPoolExecutor() as p:
futs = [
p.submit(x.sample, k=j, seed=seed(), prime=True) for i, x in enumerate(xs)
]
paths = [
[p.rescale(2 * x.theta / (4 * mu)) for p in f.result()]
for f, x in zip(futs, xs)
] # rescale each sampled path by 2N0 so that segment heights are in generations
return np.array([[p(np.random.uniform(0, L, k)) for p in path] for path in paths])
def sim_data(de, **kwargs):
d = dict(
sample_size=2,
recombination_rate=1.4e-8,
mutation_rate=mu,
length=L,
demographic_events=de,
)
d.update(kwargs)
with ThreadPoolExecutor() as p:
futs = [p.submit(msp.simulate, **d, random_seed=seed()) for i in range(M)]
return [f.result() for f in futs]
def summarize_lines(xys, x0):
"summarize a collection of lines by plotting their median and IQR"
y0 = []
for x, y in xys:
f = interp1d(
x, y, bounds_error=False
) # interpolate linearly to a common set of points
y0.append(f(x0))
return np.nanquantile(y0, [0.5, 0.25, 0.75], axis=0) # median, q25, q75
def plot_summary(ax, lines, x, label=None, **kwargs):
all_x = np.concatenate([l[0] for l in lines]).reshape(-1)
m, q25, q75 = summarize_lines(lines, x)
ax.plot(x, m / 2, label=label, **kwargs)
ax.fill_between(x, q25 / 2, q75 / 2, **kwargs, alpha=0.5)
def plot_combined(lines_psmc, lines_xsmc, truth, ax=None):
if ax is None:
ax = plt.gca()
x = np.geomspace(1e2, 1e6, 200)
for lines, label, color in zip(
[lines_xsmc, lines_psmc], ["XSMC", "PSMC"], ["tab:blue", "tab:red"]
):
# for i, (x, y) in enumerate(lines):
# ax.plot(
# x,
# y,
# color=color,
# label=label if i == 0 else None,
# alpha=5.0 / len(lines),
# )
plot_summary(ax, lines, x, label=label, color=color)
ax.plot(
*truth, "--", color="darkgrey", label="Truth", drawstyle="steps-post", zorder=0
)
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlim(1e2, 1e5)
ax.set_ylim(1e3, 1e6)
[7]:
from collections import Counter
from scipy.signal import convolve
from xsmc.supporting.kde_ne import *
def parallel_kde(sampled_heights, **kwargs):
with ProcessPoolExecutor() as p:
futs = [p.submit(kde_ne, h.reshape(-1), **kwargs) for h in sampled_heights]
return [(f.result()[0], f.result()[1]) for f in futs]
Constant effective population size¶
The simplest case. First we check the estimator on “perfect” data, that is i.i.d. samples from the true distribution:
[8]:
de = [msp.PopulationParametersChange(time=0, initial_size=1e4)]
Perfect data¶
Verify the estimator on “perfect” data:
[9]:
true_data = np.array(
[
next(sim.trees()).get_time(2)
for sim in msp.simulate(
num_replicates=10000, demographic_events=de, sample_size=2
)
]
)
x, y = kde_ne(true_data)
plt.plot(x, y / 2, label="Fitted")
plt.plot(plt.xlim(), [1e4] * 2, "--", color="darkgrey", label="Truth")
plt.xscale("log")
plt.yscale("log")
plt.legend()
plt.ylim(1e3, 1e5)
[9]:
(1000.0, 100000.0)
2020-10-22 17:02:57,282 WARNING matplotlib.font_manager MainThread findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans.
XSMC¶
Next we simulate data and sample from the posterior instead.
[10]:
data = sim_data(de)
len(data)
[10]:
25
[ ]:
sampled_heights = parallel_sample(data)
sampled_heights.shape
[ ]:
lines_xsmc = parallel_kde(sampled_heights)
[ ]:
# Diagnostic
q = np.linspace(0, 1.0, 100)
plt.plot(np.quantile(sampled_heights.reshape(-1), q), np.quantile(true_data, q))
plt.plot(plt.xlim(), plt.xlim())
PSMC¶
[ ]:
psmc_out = run_psmc(data)
[ ]:
x_psmc = np.geomspace(1e2, 1e5, 100)
lines_psmc = [(x_psmc, r(x_psmc)) for r in psmc_out]
Combined plot for paper¶
[ ]:
truth = ([1e2, 1e6], [1e4, 1e4])
fig, axs = plt.subplots(ncols=3, figsize=(12, 4.5), sharex=True, sharey=True, dpi=150)
plot_combined(lines_psmc, lines_xsmc, truth, axs[0])
# a.legend()
axs[0].set_xlim(1e2, 1e5)
axs[0].set_ylim(1e3, 1e6)
axs[0].set_title("Constant")
fig.add_subplot(111, frameon=False)
plt.tick_params(labelcolor="none", top=False, bottom=False, left=False, right=False)
# plt.xlabel("common X")
# plt.ylabel("common Y")
plt.xlabel("Generations")
plt.ylabel("$N_e$")
plt.tight_layout(pad=1.5)
Recent growth¶
Perfect data¶
[ ]:
de = [
msp.PopulationParametersChange(time=0, initial_size=1e6),
msp.PopulationParametersChange(time=1e3, initial_size=5e3),
msp.PopulationParametersChange(time=2e3, initial_size=2e4),
]
[ ]:
true_data = np.array(
[
next(sim.trees()).get_time(2)
for sim in msp.simulate(
num_replicates=10000, demographic_events=de, sample_size=2, Ne=1
)
]
)
[ ]:
x, y = kde_ne(true_data)
plt.plot(x[::50], y[::50] / 2) # downsample the curves to make plotting faster
truth = ([0, 1e3, 2e3, 1e5], [1e6, 5e3, 2e4, 2e4])
plt.plot(
*truth, "--", color="darkgrey", label="Truth", drawstyle="steps-post",
)
plt.xscale("log")
plt.yscale("log")
plt.legend()
plt.xlim(1e2, 1e5)
XSMC¶
[ ]:
data = sim_data(de)
[ ]:
sampled_heights = parallel_sample(data)
[ ]:
lines_xsmc = parallel_kde(sampled_heights)
Combined plot for paper¶
[ ]:
plot_combined(lines_psmc, lines_xsmc, truth, axs[1])
axs[1].set_title("Growth")
fig
Zigzag¶
[ ]:
import stdpopsim
species = stdpopsim.get_species("HomSap")
model = species.get_demographic_model("Zigzag_1S14")
de = [
msp.PopulationParametersChange(time=0, initial_size=14312)
] + model.demographic_events
Perfect data¶
[ ]:
true_data = np.array(
[
next(sim.trees()).get_time(2)
for sim in msp.simulate(
num_replicates=10000, demographic_events=de, sample_size=2,
)
]
)
[ ]:
x, y = kde_ne(true_data)
plt.plot(x[::50], y[::50] / 2) # downsample the curves to make plotting faster
f = plot_de(de, 14312)
x_zz = np.geomspace(1e2, 1e6, 1000)
truth = (x_zz, f(x_zz))
plt.plot(
*truth, "--", color="darkgrey", label="Truth", drawstyle="steps-post",
)
plt.xscale("log")
plt.yscale("log")
plt.legend()
plt.xlim(1e2, 1e5)
XSMC¶
[ ]:
data = sim_data(de, recombination_rate=1e-9)
[ ]:
sampled_heights = parallel_sample(data, rho_over_theta=1e-9 / mu)
[ ]:
lines_xsmc = parallel_kde(sampled_heights)
PSMC¶
[ ]:
psmc_out = run_psmc(data, rho_over_theta=1e-9 / mu)
[ ]:
lines_psmc = [(x_psmc, r(x_psmc)) for r in psmc_out]
Combined plot for paper¶
[38]:
plot_combined(lines_psmc, lines_xsmc, truth, axs[2])
axs[2].set_title("Zigzag")
fig
/home/terhorst/opt/py37/lib/python3.7/site-packages/numpy/lib/nanfunctions.py:1392: RuntimeWarning: All-NaN slice encountered
overwrite_input, interpolation)
[38]:
[42]:
axs[0].legend()
[42]:
<matplotlib.legend.Legend at 0x7f1da126bcd0>
[40]:
fig.savefig(os.path.join(PAPER_ROOT, "figures", "xsmc_psmc.pdf"))
Additional diagnostics¶
[ ]:
b = np.linspace(4, 12, 32)
for d in np.array(sampled_heights)[..., 0], true_data:
plt.hist(np.log(d).reshape(-1), bins=b, density=True, alpha=0.5)