{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "nbsphinx": "hidden" }, "outputs": [ { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 1;\n", " var nbb_unformatted_code = \"%load_ext nb_black\\nimport os\\n\\nPAPER_ROOT = os.path.expanduser(os.environ.get(\\\"PAPER_ROOT\\\", \\\".\\\"))\";\n", " var nbb_formatted_code = \"%load_ext nb_black\\nimport os\\n\\nPAPER_ROOT = os.path.expanduser(os.environ.get(\\\"PAPER_ROOT\\\", \\\".\\\"))\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%load_ext nb_black\n", "import os\n", "\n", "PAPER_ROOT = os.path.expanduser(os.environ.get(\"PAPER_ROOT\", \".\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Comparison of XSMC and PSMC" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup code" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "code_folding": [] }, "outputs": [ { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 2;\n", " var nbb_unformatted_code = \"import xsmc\\nimport xsmc.sampler\\nfrom xsmc import Segmentation\\nfrom xsmc.supporting.plotting import *\\nfrom xsmc.supporting.kde_ne import kde_ne\\nimport matplotlib.pyplot as plt\\nimport numpy as np\\nimport msprime as msp\\nfrom scipy.interpolate import PPoly\\nimport tskit\\nfrom concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor\\nimport logging\\nimport os.path\\n\\nlogging.getLogger(\\\"xsmc\\\").setLevel(logging.INFO)\";\n", " var nbb_formatted_code = \"import xsmc\\nimport xsmc.sampler\\nfrom xsmc import 
Segmentation\\nfrom xsmc.supporting.plotting import *\\nfrom xsmc.supporting.kde_ne import kde_ne\\nimport matplotlib.pyplot as plt\\nimport numpy as np\\nimport msprime as msp\\nfrom scipy.interpolate import PPoly\\nimport tskit\\nfrom concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor\\nimport logging\\nimport os.path\\n\\nlogging.getLogger(\\\"xsmc\\\").setLevel(logging.INFO)\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import xsmc\n", "import xsmc.sampler\n", "from xsmc import Segmentation\n", "from xsmc.supporting.plotting import *\n", "from xsmc.supporting.kde_ne import kde_ne\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import msprime as msp\n", "from scipy.interpolate import PPoly\n", "import tskit\n", "from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor\n", "import logging\n", "import os.path\n", "\n", "logging.getLogger(\"xsmc\").setLevel(logging.INFO)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 3;\n", " var nbb_unformatted_code = \"# Ensure reproducibility in posterior sampling\\nnp.random.seed(1)\\n\\n\\ndef seed():\\n return np.random.randint(1, np.iinfo(np.int32).max)\";\n", " var nbb_formatted_code = \"# Ensure reproducibility in posterior sampling\\nnp.random.seed(1)\\n\\n\\ndef seed():\\n return np.random.randint(1, np.iinfo(np.int32).max)\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == 
nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Ensure reproducibility in posterior sampling\n", "np.random.seed(1)\n", "\n", "\n", "def seed():\n", " return np.random.randint(1, np.iinfo(np.int32).max)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `mspsmc` package (installed in the next cell) runs Li & Durbin's original PSMC program on tree sequence data. It calls a compiled `psmc` binary, whose location is given by the `PSMC_PATH` environment variable set below." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting git+https://github.com/terhorst/mspsmc@e583e196f\n", " Cloning https://github.com/terhorst/mspsmc (to revision e583e196f) to /tmp/pip-req-build-z7s9zpy2\n", "\u001b[33m WARNING: Did not find branch or tag 'e583e196f', assuming revision or ref.\u001b[0m\n", "Requirement already satisfied (use --upgrade to upgrade): mspsmc==0.1.0 from git+https://github.com/terhorst/mspsmc@e583e196f in /home/terhorst/opt/py37/lib/python3.7/site-packages\n", "Requirement already satisfied: scipy in /home/terhorst/opt/py37/lib/python3.7/site-packages (from mspsmc==0.1.0) (1.4.1)\n", "Requirement already satisfied: numpy in /home/terhorst/opt/py37/lib/python3.7/site-packages (from mspsmc==0.1.0) (1.18.2)\n", "Requirement already satisfied: tskit in /home/terhorst/opt/py37/lib/python3.7/site-packages (from mspsmc==0.1.0) (0.3.1)\n", "Requirement already satisfied: jsonschema in /home/terhorst/opt/py37/lib/python3.7/site-packages (from tskit->mspsmc==0.1.0) (3.1.1)\n", "Requirement already satisfied: h5py in /home/terhorst/opt/py37/lib/python3.7/site-packages (from tskit->mspsmc==0.1.0) (2.10.0)\n", "Requirement already satisfied: attrs>=19.1.0 in /home/terhorst/opt/py37/lib/python3.7/site-packages (from tskit->mspsmc==0.1.0) (19.3.0)\n", "Requirement already 
satisfied: svgwrite in /home/terhorst/opt/py37/lib/python3.7/site-packages (from tskit->mspsmc==0.1.0) (1.3.1)\n", "Requirement already satisfied: importlib-metadata in /home/terhorst/opt/py37/lib/python3.7/site-packages (from jsonschema->tskit->mspsmc==0.1.0) (0.23)\n", "Requirement already satisfied: pyrsistent>=0.14.0 in /home/terhorst/opt/py37/lib/python3.7/site-packages (from jsonschema->tskit->mspsmc==0.1.0) (0.15.4)\n", "Requirement already satisfied: setuptools in /home/terhorst/opt/py37/lib/python3.7/site-packages (from jsonschema->tskit->mspsmc==0.1.0) (49.3.2)\n", "Requirement already satisfied: six>=1.11.0 in /home/terhorst/opt/py37/lib/python3.7/site-packages (from jsonschema->tskit->mspsmc==0.1.0) (1.14.0)\n", "Requirement already satisfied: pyparsing>=2.0.1 in /home/terhorst/opt/py37/lib/python3.7/site-packages (from svgwrite->tskit->mspsmc==0.1.0) (2.4.2)\n", "Requirement already satisfied: zipp>=0.5 in /home/terhorst/opt/py37/lib/python3.7/site-packages (from importlib-metadata->jsonschema->tskit->mspsmc==0.1.0) (0.6.0)\n", "Requirement already satisfied: more-itertools in /home/terhorst/opt/py37/lib/python3.7/site-packages (from zipp>=0.5->importlib-metadata->jsonschema->tskit->mspsmc==0.1.0) (7.2.0)\n", "Building wheels for collected packages: mspsmc\n", " Building wheel for mspsmc (setup.py) ... 
\u001b[?25ldone\n", "\u001b[?25h Created wheel for mspsmc: filename=mspsmc-0.1.0-py2.py3-none-any.whl size=5757 sha256=6961418aa633d0180cec574751c509435246e3a7d5541bd1d739840903c4ee9e\n", " Stored in directory: /tmp/pip-ephem-wheel-cache-5a52eh0a/wheels/af/38/87/3eee2ca03eb580f4df9e22215d1e0ff5a8914f16a45cf26f93\n", "Successfully built mspsmc\n", "\u001b[33mWARNING: You are using pip version 20.2.3; however, version 20.2.4 is available.\n", "You should consider upgrading via the '/home/terhorst/opt/py37/bin/python3 -m pip install --upgrade pip' command.\u001b[0m\n" ] }, { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 4;\n", " var nbb_unformatted_code = \"!pip install git+https://github.com/terhorst/mspsmc@e583e196f\";\n", " var nbb_formatted_code = \"!pip install git+https://github.com/terhorst/mspsmc@e583e196f\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "!pip install git+https://github.com/terhorst/mspsmc@e583e196f" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "code_folding": [] }, "outputs": [ { "data": { "text/plain": [ "'0.6.5-r67'" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" }, { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 5;\n", " var nbb_unformatted_code = \"# psmc code\\nimport os\\n\\nos.environ[\\\"PSMC_PATH\\\"] = \\\"/scratch/psmc/psmc\\\" # update as needed if running locally\\nimport mspsmc\\n\\n\\ndef run_psmc(reps, rho_over_theta=1.0):\\n def f(data, *args):\\n return mspsmc.msPSMC([(data, (0, 1))]).estimate(*args)\\n\\n with 
ThreadPoolExecutor() as p:\\n futs = [p.submit(f, data, \\\"-r\\\", 1.0 / rho_over_theta) for data in reps]\\n res = [f.result() for f in futs]\\n rescaled = []\\n for r in res:\\n # See Appendix I of https://github.com/lh3/psmc/blob/master/README\\n N0 = r.theta / (4 * mu) / 100\\n rescaled.append(r.Ne.rescale(2 * N0))\\n return rescaled\\n\\n\\nmspsmc.__psmc__version__\";\n", " var nbb_formatted_code = \"# psmc code\\nimport os\\n\\nos.environ[\\\"PSMC_PATH\\\"] = \\\"/scratch/psmc/psmc\\\" # update as needed if running locally\\nimport mspsmc\\n\\n\\ndef run_psmc(reps, rho_over_theta=1.0):\\n def f(data, *args):\\n return mspsmc.msPSMC([(data, (0, 1))]).estimate(*args)\\n\\n with ThreadPoolExecutor() as p:\\n futs = [p.submit(f, data, \\\"-r\\\", 1.0 / rho_over_theta) for data in reps]\\n res = [f.result() for f in futs]\\n rescaled = []\\n for r in res:\\n # See Appendix I of https://github.com/lh3/psmc/blob/master/README\\n N0 = r.theta / (4 * mu) / 100\\n rescaled.append(r.Ne.rescale(2 * N0))\\n return rescaled\\n\\n\\nmspsmc.__psmc__version__\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# psmc code\n", "import os\n", "\n", "os.environ[\"PSMC_PATH\"] = \"/scratch/psmc/psmc\" # update as needed if running locally\n", "import mspsmc\n", "\n", "\n", "def run_psmc(reps, rho_over_theta=1.0):\n", " def f(data, *args):\n", " return mspsmc.msPSMC([(data, (0, 1))]).estimate(*args)\n", "\n", " with ThreadPoolExecutor() as p:\n", " futs = [p.submit(f, data, \"-r\", 1.0 / rho_over_theta) for data in reps]\n", " res = [f.result() for f in futs]\n", " rescaled = []\n", " for r in res:\n", " # See 
Appendix I of https://github.com/lh3/psmc/blob/master/README\n", " N0 = r.theta / (4 * mu) / 100\n", " rescaled.append(r.Ne.rescale(2 * N0))\n", " return rescaled\n", "\n", "\n", "mspsmc.__psmc__version__" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "code_folding": [] }, "outputs": [ { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 6;\n", " var nbb_unformatted_code = \"# other supporting functions\\nL = int(5e7) # length of simulated chromosome\\nmu = 1.4e-8 # mutation rate/bp/gen\\nM = 25 # number of replicates\\n\\n\\ndef parallel_sample(reps, j=100, k=int(L / 50_000), rho_over_theta=1.0):\\n xs = [\\n xsmc.XSMC(data, focal=0, panel=[1], rho_over_theta=rho_over_theta)\\n for data in reps\\n ]\\n with ThreadPoolExecutor() as p:\\n futs = [\\n p.submit(x.sample_heights, j=j, k=k, seed=seed()) for i, x in enumerate(xs)\\n ]\\n return np.array(\\n [f.result() * 2 * x.theta / (4 * mu) for f, x in zip(futs, xs)]\\n ) # rescale each sampled path by 2N0 so that segment heights are in generations\\n\\n\\ndef parallel_sample0(reps, j=100, k=int(L / 50_000), rho_over_theta=1.0):\\n xs = [\\n xsmc.XSMC(data, focal=0, panel=[1], rho_over_theta=rho_over_theta)\\n for data in reps\\n ]\\n with ProcessPoolExecutor() as p:\\n futs = [\\n p.submit(x.sample, k=j, seed=seed(), prime=True) for i, x in enumerate(xs)\\n ]\\n paths = [\\n [p.rescale(2 * x.theta / (4 * mu)) for p in f.result()]\\n for f, x in zip(futs, xs)\\n ] # rescale each sampled path by 2N0 so that segment heights are in generations\\n return np.array([[p(np.random.uniform(0, L, k)) for p in path] for path in paths])\\n\\n\\ndef sim_data(de, **kwargs):\\n d = dict(\\n sample_size=2,\\n recombination_rate=1.4e-8,\\n mutation_rate=mu,\\n length=L,\\n demographic_events=de,\\n )\\n\\n d.update(kwargs)\\n with ThreadPoolExecutor() as p:\\n futs = [p.submit(msp.simulate, **d, random_seed=seed()) for i in range(M)]\\n return [f.result() for f in 
futs]\\n\\n\\ndef summarize_lines(xys, x0):\\n \\\"summarize a collection of lines by plotting their median and IQR\\\"\\n y0 = []\\n for x, y in xys:\\n f = interp1d(\\n x, y, bounds_error=False\\n ) # interpolate linearly to a common set of points\\n y0.append(f(x0))\\n return np.nanquantile(y0, [0.5, 0.25, 0.75], axis=0) # median, q25, q75\\n\\n\\ndef plot_summary(ax, lines, x, label=None, **kwargs):\\n all_x = np.concatenate([l[0] for l in lines]).reshape(-1)\\n m, q25, q75 = summarize_lines(lines, x)\\n ax.plot(x, m / 2, label=label, **kwargs)\\n ax.fill_between(x, q25 / 2, q75 / 2, **kwargs, alpha=0.5)\\n\\n\\ndef plot_combined(lines_psmc, lines_xsmc, truth, ax=None):\\n if ax is None:\\n ax = plt.gca()\\n x = np.geomspace(1e2, 1e6, 200)\\n for lines, label, color in zip(\\n [lines_xsmc, lines_psmc], [\\\"XSMC\\\", \\\"PSMC\\\"], [\\\"tab:blue\\\", \\\"tab:red\\\"]\\n ):\\n # for i, (x, y) in enumerate(lines):\\n # ax.plot(\\n # x,\\n # y,\\n # color=color,\\n # label=label if i == 0 else None,\\n # alpha=5.0 / len(lines),\\n # )\\n plot_summary(ax, lines, x, label=label, color=color)\\n ax.plot(\\n *truth, \\\"--\\\", color=\\\"darkgrey\\\", label=\\\"Truth\\\", drawstyle=\\\"steps-post\\\", zorder=0\\n )\\n ax.set_xscale(\\\"log\\\")\\n ax.set_yscale(\\\"log\\\")\\n ax.set_xlim(1e2, 1e5)\\n ax.set_ylim(1e3, 1e6)\";\n", " var nbb_formatted_code = \"# other supporting functions\\nL = int(5e7) # length of simulated chromosome\\nmu = 1.4e-8 # mutation rate/bp/gen\\nM = 25 # number of replicates\\n\\n\\ndef parallel_sample(reps, j=100, k=int(L / 50_000), rho_over_theta=1.0):\\n xs = [\\n xsmc.XSMC(data, focal=0, panel=[1], rho_over_theta=rho_over_theta)\\n for data in reps\\n ]\\n with ThreadPoolExecutor() as p:\\n futs = [\\n p.submit(x.sample_heights, j=j, k=k, seed=seed()) for i, x in enumerate(xs)\\n ]\\n return np.array(\\n [f.result() * 2 * x.theta / (4 * mu) for f, x in zip(futs, xs)]\\n ) # rescale each sampled path by 2N0 so that segment heights are in 
generations\\n\\n\\ndef parallel_sample0(reps, j=100, k=int(L / 50_000), rho_over_theta=1.0):\\n xs = [\\n xsmc.XSMC(data, focal=0, panel=[1], rho_over_theta=rho_over_theta)\\n for data in reps\\n ]\\n with ProcessPoolExecutor() as p:\\n futs = [\\n p.submit(x.sample, k=j, seed=seed(), prime=True) for i, x in enumerate(xs)\\n ]\\n paths = [\\n [p.rescale(2 * x.theta / (4 * mu)) for p in f.result()]\\n for f, x in zip(futs, xs)\\n ] # rescale each sampled path by 2N0 so that segment heights are in generations\\n return np.array([[p(np.random.uniform(0, L, k)) for p in path] for path in paths])\\n\\n\\ndef sim_data(de, **kwargs):\\n d = dict(\\n sample_size=2,\\n recombination_rate=1.4e-8,\\n mutation_rate=mu,\\n length=L,\\n demographic_events=de,\\n )\\n\\n d.update(kwargs)\\n with ThreadPoolExecutor() as p:\\n futs = [p.submit(msp.simulate, **d, random_seed=seed()) for i in range(M)]\\n return [f.result() for f in futs]\\n\\n\\ndef summarize_lines(xys, x0):\\n \\\"summarize a collection of lines by plotting their median and IQR\\\"\\n y0 = []\\n for x, y in xys:\\n f = interp1d(\\n x, y, bounds_error=False\\n ) # interpolate linearly to a common set of points\\n y0.append(f(x0))\\n return np.nanquantile(y0, [0.5, 0.25, 0.75], axis=0) # median, q25, q75\\n\\n\\ndef plot_summary(ax, lines, x, label=None, **kwargs):\\n all_x = np.concatenate([l[0] for l in lines]).reshape(-1)\\n m, q25, q75 = summarize_lines(lines, x)\\n ax.plot(x, m / 2, label=label, **kwargs)\\n ax.fill_between(x, q25 / 2, q75 / 2, **kwargs, alpha=0.5)\\n\\n\\ndef plot_combined(lines_psmc, lines_xsmc, truth, ax=None):\\n if ax is None:\\n ax = plt.gca()\\n x = np.geomspace(1e2, 1e6, 200)\\n for lines, label, color in zip(\\n [lines_xsmc, lines_psmc], [\\\"XSMC\\\", \\\"PSMC\\\"], [\\\"tab:blue\\\", \\\"tab:red\\\"]\\n ):\\n # for i, (x, y) in enumerate(lines):\\n # ax.plot(\\n # x,\\n # y,\\n # color=color,\\n # label=label if i == 0 else None,\\n # alpha=5.0 / len(lines),\\n # )\\n 
plot_summary(ax, lines, x, label=label, color=color)\\n ax.plot(\\n *truth, \\\"--\\\", color=\\\"darkgrey\\\", label=\\\"Truth\\\", drawstyle=\\\"steps-post\\\", zorder=0\\n )\\n ax.set_xscale(\\\"log\\\")\\n ax.set_yscale(\\\"log\\\")\\n ax.set_xlim(1e2, 1e5)\\n ax.set_ylim(1e3, 1e6)\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"# other supporting functions\n",
"import warnings\n",
"\n",
"# imported explicitly instead of relying on a star import to provide it\n",
"from scipy.interpolate import interp1d\n",
"\n",
"L = int(5e7)  # length of simulated chromosome\n",
"mu = 1.4e-8  # mutation rate/bp/gen\n",
"M = 25  # number of replicates\n",
"\n",
"\n",
"def parallel_sample(reps, j=100, k=int(L / 50_000), rho_over_theta=1.0):\n",
"    \"\"\"Sample posterior coalescence heights with XSMC for every replicate in `reps`.\n",
"\n",
"    `j` and `k` are forwarded to `XSMC.sample_heights`. Each result is multiplied\n",
"    by 2 * N0 = 2 * theta / (4 * mu) so that heights are in generations.\n",
"    \"\"\"\n",
"    xs = [\n",
"        xsmc.XSMC(data, focal=0, panel=[1], rho_over_theta=rho_over_theta)\n",
"        for data in reps\n",
"    ]\n",
"    with ThreadPoolExecutor() as pool:\n",
"        futs = [pool.submit(x.sample_heights, j=j, k=k, seed=seed()) for x in xs]\n",
"    # rescale each sampled path by 2N0 so that segment heights are in generations\n",
"    return np.array([f.result() * 2 * x.theta / (4 * mu) for f, x in zip(futs, xs)])\n",
"\n",
"\n",
"def parallel_sample0(reps, j=100, k=int(L / 50_000), rho_over_theta=1.0):\n",
"    \"\"\"Alternative to `parallel_sample`: draw `j` posterior paths per replicate with\n",
"    `XSMC.sample`, then evaluate every path at `k` uniformly random positions in [0, L).\n",
"    \"\"\"\n",
"    xs = [\n",
"        xsmc.XSMC(data, focal=0, panel=[1], rho_over_theta=rho_over_theta)\n",
"        for data in reps\n",
"    ]\n",
"    with ProcessPoolExecutor() as pool:\n",
"        futs = [pool.submit(x.sample, k=j, seed=seed(), prime=True) for x in xs]\n",
"        # rescale each sampled path by 2N0 so that segment heights are in generations\n",
"        paths = [\n",
"            [path.rescale(2 * x.theta / (4 * mu)) for path in f.result()]\n",
"            for f, x in zip(futs, xs)\n",
"        ]\n",
"    return np.array(\n",
"        [[path(np.random.uniform(0, L, k)) for path in rep_paths] for rep_paths in paths]\n",
"    )\n",
"\n",
"\n",
"def sim_data(de, **kwargs):\n",
"    \"\"\"Simulate `M` independent msprime replicates, each with two haploid samples.\n",
"\n",
"    `de` is the list of demographic events; entries in `kwargs` override the\n",
"    defaults below (e.g. `recombination_rate=1e-9`).\n",
"    \"\"\"\n",
"    d = dict(\n",
"        sample_size=2,\n",
"        recombination_rate=1.4e-8,\n",
"        mutation_rate=mu,\n",
"        length=L,\n",
"        demographic_events=de,\n",
"    )\n",
"    d.update(kwargs)\n",
"    with ThreadPoolExecutor() as pool:\n",
"        futs = [pool.submit(msp.simulate, **d, random_seed=seed()) for _ in range(M)]\n",
"    return [f.result() for f in futs]\n",
"\n",
"\n",
"def summarize_lines(xys, x0):\n",
"    \"\"\"Summarize a collection of (x, y) lines by their pointwise median and IQR.\n",
"\n",
"    Each line is linearly interpolated onto the common grid `x0`; grid points\n",
"    outside a line's x-range become NaN and are ignored by `np.nanquantile`.\n",
"    Returns an array of shape (3, len(x0)): median, 25th and 75th percentiles.\n",
"    \"\"\"\n",
"    y0 = [interp1d(x, y, bounds_error=False)(x0) for x, y in xys]\n",
"    with warnings.catch_warnings():\n",
"        # grid points outside every line's range are all-NaN by design\n",
"        warnings.simplefilter(\"ignore\", RuntimeWarning)\n",
"        return np.nanquantile(y0, [0.5, 0.25, 0.75], axis=0)  # median, q25, q75\n",
"\n",
"\n",
"def plot_summary(ax, lines, x, label=None, **kwargs):\n",
"    \"\"\"Plot the median (line) and IQR (shaded band) of `lines` on `ax`.\n",
"\n",
"    y values are halved before plotting, matching the `y / 2` used by the\n",
"    other plotting cells in this notebook.\n",
"    \"\"\"\n",
"    m, q25, q75 = summarize_lines(lines, x)\n",
"    ax.plot(x, m / 2, label=label, **kwargs)\n",
"    ax.fill_between(x, q25 / 2, q75 / 2, **kwargs, alpha=0.5)\n",
"\n",
"\n",
"def plot_combined(lines_psmc, lines_xsmc, truth, ax=None):\n",
"    \"\"\"Overlay XSMC and PSMC summaries and the true history (`truth`, an (x, y)\n",
"    pair drawn as a dashed step function) on `ax` (default: current axes).\n",
"    \"\"\"\n",
"    if ax is None:\n",
"        ax = plt.gca()\n",
"    x = np.geomspace(1e2, 1e6, 200)\n",
"    for lines, label, color in zip(\n",
"        [lines_xsmc, lines_psmc], [\"XSMC\", \"PSMC\"], [\"tab:blue\", \"tab:red\"]\n",
"    ):\n",
"        plot_summary(ax, lines, x, label=label, color=color)\n",
"    ax.plot(\n",
"        *truth, \"--\", color=\"darkgrey\", label=\"Truth\", drawstyle=\"steps-post\", zorder=0\n",
"    )\n",
"    ax.set_xscale(\"log\")\n",
"    ax.set_yscale(\"log\")\n",
"    ax.set_xlim(1e2, 1e5)\n",
"    ax.set_ylim(1e3, 1e6)"
 ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var 
nbb_cell_id = 7;\n", " var nbb_unformatted_code = \"from collections import Counter\\nfrom scipy.signal import convolve\\nfrom xsmc.supporting.kde_ne import *\\n\\n\\ndef parallel_kde(sampled_heights, **kwargs):\\n with ProcessPoolExecutor() as p:\\n futs = [p.submit(kde_ne, h.reshape(-1), **kwargs) for h in sampled_heights]\\n return [(f.result()[0], f.result()[1]) for f in futs]\";\n", " var nbb_formatted_code = \"from collections import Counter\\nfrom scipy.signal import convolve\\nfrom xsmc.supporting.kde_ne import *\\n\\n\\ndef parallel_kde(sampled_heights, **kwargs):\\n with ProcessPoolExecutor() as p:\\n futs = [p.submit(kde_ne, h.reshape(-1), **kwargs) for h in sampled_heights]\\n return [(f.result()[0], f.result()[1]) for f in futs]\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"from collections import Counter\n",
"from scipy.signal import convolve\n",
"from xsmc.supporting.kde_ne import *\n",
"\n",
"\n",
"def parallel_kde(sampled_heights, **kwargs):\n",
"    \"\"\"Fit `kde_ne` to each replicate's flattened sampled heights, in parallel.\n",
"\n",
"    Extra keyword arguments are forwarded to `kde_ne`. Returns one `(x, y)` pair\n",
"    per replicate, taken from the first two items of each `kde_ne` result.\n",
"    \"\"\"\n",
"    with ProcessPoolExecutor() as pool:\n",
"        futs = [pool.submit(kde_ne, h.reshape(-1), **kwargs) for h in sampled_heights]\n",
"        results = [f.result() for f in futs]  # fetch each result once\n",
"    return [(res[0], res[1]) for res in results]"
 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Constant effective population size\n", "The simplest case. First we check the estimator on \"perfect\" data, that is i.i.d. 
samples from the true distribution:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 8;\n", " var nbb_unformatted_code = \"de = [msp.PopulationParametersChange(time=0, initial_size=1e4)]\";\n", " var nbb_formatted_code = \"de = [msp.PopulationParametersChange(time=0, initial_size=1e4)]\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "de = [msp.PopulationParametersChange(time=0, initial_size=1e4)]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Perfect data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Verify the estimator on \"perfect\" data:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(1000.0, 100000.0)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" }, { "name": "stderr", "output_type": "stream", "text": [ "2020-10-22 17:02:57,282 WARNING matplotlib.font_manager MainThread findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans.\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 9;\n", " var nbb_unformatted_code = \"true_data = np.array(\\n [\\n next(sim.trees()).get_time(2)\\n for sim in msp.simulate(\\n num_replicates=10000, demographic_events=de, sample_size=2\\n )\\n ]\\n)\\nx, y = kde_ne(true_data)\\nplt.plot(x, y / 2, label=\\\"Fitted\\\")\\nplt.plot(plt.xlim(), [1e4] * 2, \\\"--\\\", color=\\\"darkgrey\\\", label=\\\"Truth\\\")\\nplt.xscale(\\\"log\\\")\\nplt.yscale(\\\"log\\\")\\nplt.legend()\\nplt.ylim(1e3, 1e5)\";\n", " var nbb_formatted_code = \"true_data = np.array(\\n [\\n next(sim.trees()).get_time(2)\\n for sim in msp.simulate(\\n num_replicates=10000, demographic_events=de, sample_size=2\\n )\\n ]\\n)\\nx, y = kde_ne(true_data)\\nplt.plot(x, y / 2, label=\\\"Fitted\\\")\\nplt.plot(plt.xlim(), [1e4] * 2, \\\"--\\\", color=\\\"darkgrey\\\", label=\\\"Truth\\\")\\nplt.xscale(\\\"log\\\")\\nplt.yscale(\\\"log\\\")\\nplt.legend()\\nplt.ylim(1e3, 1e5)\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "true_data = np.array(\n", " [\n", " next(sim.trees()).get_time(2)\n", " for sim in msp.simulate(\n", " num_replicates=10000, demographic_events=de, sample_size=2\n", " )\n", " ]\n", ")\n", "x, y = kde_ne(true_data)\n", "plt.plot(x, y / 2, label=\"Fitted\")\n", "plt.plot(plt.xlim(), [1e4] * 2, \"--\", color=\"darkgrey\", label=\"Truth\")\n", "plt.xscale(\"log\")\n", "plt.yscale(\"log\")\n", "plt.legend()\n", "plt.ylim(1e3, 1e5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": 
[ "### XSMC \n", "Next we simulate data and sample from the posterior instead." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "25" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" }, { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 10;\n", " var nbb_unformatted_code = \"data = sim_data(de)\\nlen(data)\";\n", " var nbb_formatted_code = \"data = sim_data(de)\\nlen(data)\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "data = sim_data(de)\n", "len(data)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sampled_heights = parallel_sample(data)\n", "sampled_heights.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "lines_xsmc = parallel_kde(sampled_heights)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Diagnostic\n", "q = np.linspace(0, 1.0, 100)\n", "plt.plot(np.quantile(sampled_heights.reshape(-1), q), np.quantile(true_data, q))\n", "plt.plot(plt.xlim(), plt.xlim())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### PSMC" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "psmc_out = run_psmc(data)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x_psmc = np.geomspace(1e2, 1e5, 100)\n", "lines_psmc = [(x_psmc, r(x_psmc)) for r in psmc_out]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 
Combined plot for paper" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"truth = ([1e2, 1e6], [1e4, 1e4])\n",
"fig, axs = plt.subplots(ncols=3, figsize=(12, 4.5), sharex=True, sharey=True, dpi=150)\n",
"plot_combined(lines_psmc, lines_xsmc, truth, axs[0])\n",
"axs[0].set_xlim(1e2, 1e5)\n",
"axs[0].set_ylim(1e3, 1e6)\n",
"axs[0].set_title(\"Constant\")\n",
"# invisible axes spanning all panels; used only to place shared axis labels\n",
"common_ax = fig.add_subplot(111, frameon=False)\n",
"common_ax.tick_params(labelcolor=\"none\", top=False, bottom=False, left=False, right=False)\n",
"common_ax.set_xlabel(\"Generations\")\n",
"common_ax.set_ylabel(\"$N_e$\")\n",
"fig.tight_layout(pad=1.5)"
 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Recent growth" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Perfect data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "de = [\n", " msp.PopulationParametersChange(time=0, initial_size=1e6),\n", " msp.PopulationParametersChange(time=1e3, initial_size=5e3),\n", " msp.PopulationParametersChange(time=2e3, initial_size=2e4),\n", "]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "true_data = np.array(\n", " [\n", " next(sim.trees()).get_time(2)\n", " for sim in msp.simulate(\n", " num_replicates=10000, demographic_events=de, sample_size=2, Ne=1\n", " )\n", " ]\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"x, y = kde_ne(true_data)\n",
"plt.plot(x[::50], y[::50] / 2, label=\"Fitted\")  # downsample the curves to make plotting faster\n",
"truth = ([0, 1e3, 2e3, 1e5], [1e6, 5e3, 2e4, 2e4])\n",
"\n",
"\n",
"plt.plot(\n",
"    *truth, \"--\", color=\"darkgrey\", label=\"Truth\", drawstyle=\"steps-post\",\n",
")\n",
"plt.xscale(\"log\")\n",
"plt.yscale(\"log\")\n",
"plt.legend()\n",
"plt.xlim(1e2, 1e5)"
 ] }, { "cell_type": "markdown", "metadata": {}, "source": [
"### XSMC" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = sim_data(de)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sampled_heights = parallel_sample(data)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "lines_xsmc = parallel_kde(sampled_heights)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### PSMC" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "psmc_out = run_psmc(data)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "lines_psmc = [(x_psmc, r(x_psmc)) for r in psmc_out]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Combined plot for paper" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plot_combined(lines_psmc, lines_xsmc, truth, axs[1])\n", "axs[1].set_title(\"Growth\")\n", "fig" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Zigzag" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import stdpopsim\n", "\n", "# Population size at time 0; also passed to plot_de when drawing the true curve below\n", "zigzag_n0 = 14312\n", "\n", "species = stdpopsim.get_species(\"HomSap\")\n", "model = species.get_demographic_model(\"Zigzag_1S14\")\n", "de = [\n", " msp.PopulationParametersChange(time=0, initial_size=zigzag_n0)\n", "] + model.demographic_events" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Perfect data\n", "\n", "Best-case reference: `kde_ne` applied to 10,000 independently simulated pairwise coalescence times, which are observed directly here rather than inferred." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "true_data = np.array(\n", " [\n", " next(sim.trees()).get_time(2)\n", " for sim in msp.simulate(\n", " num_replicates=10000, demographic_events=de, sample_size=2,\n", " )\n", " ]\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x, y = kde_ne(true_data)\n", "plt.plot(x[::50], y[::50] / 2, label=\"Perfect data\") # downsample the curves to
make plotting faster\n", "\n", "f = plot_de(de, zigzag_n0)\n", "x_zz = np.geomspace(1e2, 1e6, 1000)\n", "truth = (x_zz, f(x_zz))\n", "\n", "\n", "plt.plot(\n", " *truth, \"--\", color=\"darkgrey\", label=\"Truth\", drawstyle=\"steps-post\",\n", ")\n", "plt.xscale(\"log\")\n", "plt.yscale(\"log\")\n", "plt.legend()\n", "plt.xlim(1e2, 1e5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### XSMC" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Recombination rate shared by the simulation and by the XSMC/PSMC inference below\n", "zigzag_recomb_rate = 1e-9\n", "data = sim_data(de, recombination_rate=zigzag_recomb_rate)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sampled_heights = parallel_sample(data, rho_over_theta=zigzag_recomb_rate / mu)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "lines_xsmc = parallel_kde(sampled_heights)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### PSMC" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "psmc_out = run_psmc(data, rho_over_theta=zigzag_recomb_rate / mu)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "lines_psmc = [(x_psmc, r(x_psmc)) for r in psmc_out]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Combined plot for paper" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/terhorst/opt/py37/lib/python3.7/site-packages/numpy/lib/nanfunctions.py:1392: RuntimeWarning: All-NaN slice encountered\n", " overwrite_input, interpolation)\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" }, { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 38;\n", " var nbb_unformatted_code = \"plot_combined(lines_psmc, lines_xsmc, truth, axs[2])\\naxs[2].set_title(\\\"Zigzag\\\")\\nfig\";\n", " var nbb_formatted_code = \"plot_combined(lines_psmc, lines_xsmc, truth, axs[2])\\naxs[2].set_title(\\\"Zigzag\\\")\\nfig\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plot_combined(lines_psmc, lines_xsmc, truth, axs[2])\n", "axs[2].set_title(\"Zigzag\")\n", "fig" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" }, { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 42;\n", " var nbb_unformatted_code = \"axs[0].legend()\";\n", " var nbb_formatted_code = \"axs[0].legend()\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "axs[0].legend()" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "data": { "application/javascript": [ "\n", " setTimeout(function() {\n", " var nbb_cell_id = 40;\n", " var
nbb_unformatted_code = \"fig.savefig(os.path.join(PAPER_ROOT, \\\"figures\\\", \\\"xsmc_psmc.pdf\\\"))\";\n", " var nbb_formatted_code = \"fig.savefig(os.path.join(PAPER_ROOT, \\\"figures\\\", \\\"xsmc_psmc.pdf\\\"))\";\n", " var nbb_cells = Jupyter.notebook.get_cells();\n", " for (var i = 0; i < nbb_cells.length; ++i) {\n", " if (nbb_cells[i].input_prompt_number == nbb_cell_id) {\n", " if (nbb_cells[i].get_text() == nbb_unformatted_code) {\n", " nbb_cells[i].set_text(nbb_formatted_code);\n", " }\n", " break;\n", " }\n", " }\n", " }, 500);\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "figures_dir = os.path.join(PAPER_ROOT, \"figures\")\n", "# savefig raises if the target directory is missing, so create it first\n", "os.makedirs(figures_dir, exist_ok=True)\n", "fig.savefig(os.path.join(figures_dir, \"xsmc_psmc.pdf\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Additional diagnostics\n", "\n", "Distribution of log coalescence times for the zigzag model: XSMC sampled heights versus the true simulated values (`true_data`)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fig_diag, ax = plt.subplots()\n", "bins = np.linspace(4, 12, 32)  # bin edges in natural-log units\n", "for d, label in [\n", "    (np.array(sampled_heights)[..., 0], \"XSMC\"),\n", "    (true_data, \"Truth\"),\n", "]:\n", "    ax.hist(np.log(d).reshape(-1), bins=bins, density=True, alpha=0.5, label=label)\n", "ax.set(xlabel=\"log(time)\", ylabel=\"density\")\n", "ax.legend();" ] } ], "metadata": { "celltoolbar": "Edit Metadata", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 2 }