{ "cells": [ { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "# Building guides with EasyGuide\n", "\n", "This tutorial describes the [pyro.contrib.easyguide](http://docs.pyro.ai/en/stable/contrib.easyguide.html) module. It assumes familiarity with [SVI](http://pyro.ai/examples/svi_part_ii.html) and [tensor shapes](http://pyro.ai/examples/tensor_shapes.html)." ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "**Summary:**\n", "\n", "- For simple black-box guides, try using components in [pyro.infer.autoguide](http://docs.pyro.ai/en/stable/infer.autoguide.html).\n", "- For more complex guides, try using components in [pyro.contrib.easyguide](http://docs.pyro.ai/en/stable/contrib.easyguide.html).\n", "- Decorate with `@easy_guide(model)`.\n", "- Select multiple model sites using `group = self.group(match=\"my_regex\")`.\n", "- Guide a group of sites by a single distribution using `group.sample(...)`.\n", "- Inspect concatenated group shape using `group.batch_shape`, `group.event_shape`, etc.\n", "- Use `self.plate(...)` instead of `pyro.plate(...)`.\n", "- To be compatible with subsampling, pass the `event_dim` arg to `pyro.param(...)`.\n", "- To MAP estimate model site \"foo\", use `foo = self.map_estimate(\"foo\")`." 
] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "**Table of contents**\n", "\n", "- [Modeling time series data](#Modeling-time-series-data)\n", "- [Writing a guide without EasyGuide](#Writing-a-guide-without-EasyGuide)\n", "- [Using EasyGuide](#Using-EasyGuide)\n", "- [Amortized guides](#Amortized-guides)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "Collapsed": "false" }, "outputs": [], "source": [ "import os\n", "import torch\n", "import pyro\n", "import pyro.distributions as dist\n", "from pyro.infer import SVI, Trace_ELBO\n", "from pyro.contrib.easyguide import easy_guide\n", "from pyro.optim import Adam\n", "from torch.distributions import constraints\n", "\n", "pyro.enable_validation(True)\n", "smoke_test = ('CI' in os.environ)\n", "assert pyro.__version__.startswith('1.3.0')" ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "## Modeling time series data" ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "Consider a time-series model with a slowly-varying continuous latent state and Bernoulli observations with a logistic link function." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": { "Collapsed": "false" }, "outputs": [], "source": [ "def model(batch, subsample, full_size):\n", "    batch = list(batch)\n", "    num_time_steps = len(batch)\n", "    drift = pyro.sample(\"drift\", dist.LogNormal(-1, 0.5))\n", "    with pyro.plate(\"data\", full_size, subsample=subsample):\n", "        z = 0.\n", "        for t in range(num_time_steps):\n", "            z = pyro.sample(\"state_{}\".format(t), dist.Normal(z, drift))\n", "            batch[t] = pyro.sample(\"obs_{}\".format(t),\n", "                                   dist.Bernoulli(logits=z),\n", "                                   obs=batch[t])\n", "    return torch.stack(batch)" ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "Let's generate some data directly from the model." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "Collapsed": "false" }, "outputs": [], "source": [ "full_size = 100\n", "num_time_steps = 7\n", "pyro.set_rng_seed(123456789)\n", "data = model([None] * num_time_steps, torch.arange(full_size), full_size)\n", "assert data.shape == (num_time_steps, full_size)" ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "## Building guides quickly with EasyGuide" ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "### Writing a guide without EasyGuide\n", "\n", "\n", "Consider a possible guide for this model where we point-estimate the `drift` parameter using a `Delta` distribution, and then model local time series using shared uncertainty but local means, using a `LowRankMultivariateNormal` distribution. There is a single global sample site which we can model with a `param` and `sample` statement. Then we declare a global pair of uncertainty parameters `cov_diag` and `cov_factor`. Next we declare a local `loc` parameter using `pyro.param(..., event_dim=...)` and sample an auxiliary site. Finally we unpack that auxiliary site into one element per time step. The auxiliary-unpacked-to-`Delta`s pattern is quite common." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "Collapsed": "false" }, "outputs": [], "source": [ "rank = 3\n", "\n", "def guide(batch, subsample, full_size):\n", "    num_time_steps, batch_size = batch.shape\n", "\n", "    # MAP estimate the drift.\n", "    drift_loc = pyro.param(\"drift_loc\", lambda: torch.tensor(0.1),\n", "                           constraint=constraints.positive)\n", "    pyro.sample(\"drift\", dist.Delta(drift_loc))\n", "\n", "    # Model local states using shared uncertainty + local mean.\n", "    cov_diag = pyro.param(\"state_cov_diag\",\n", "                          lambda: torch.full((num_time_steps,), 0.01),\n", "                          constraint=constraints.positive)\n", "    cov_factor = pyro.param(\"state_cov_factor\",\n", "                            lambda: torch.randn(num_time_steps, rank) * 0.01)\n", "    with pyro.plate(\"data\", full_size, subsample=subsample):\n", "        # Sample local mean.\n", "        loc = pyro.param(\"state_loc\",\n", "                         lambda: torch.full((full_size, num_time_steps), 0.5),\n", "                         event_dim=1)\n", "        states = pyro.sample(\"states\",\n", "                             dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag),\n", "                             infer={\"is_auxiliary\": True})\n", "        # Unpack the joint states into one sample site per time step.\n", "        for t in range(num_time_steps):\n", "            pyro.sample(\"state_{}\".format(t), dist.Delta(states[:, t]))" ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "Let's train using [SVI](http://docs.pyro.ai/en/stable/inference_algos.html#module-pyro.infer.svi) and [Trace_ELBO](http://docs.pyro.ai/en/stable/inference_algos.html#pyro.infer.trace_elbo.Trace_ELBO), manually batching data into small minibatches." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": { "Collapsed": "false" }, "outputs": [], "source": [ "def train(guide, num_epochs=1 if smoke_test else 101, batch_size=20):\n", "    full_size = data.size(-1)\n", "    pyro.get_param_store().clear()\n", "    pyro.set_rng_seed(123456789)\n", "    svi = SVI(model, guide, Adam({\"lr\": 0.02}), Trace_ELBO())\n", "    for epoch in range(num_epochs):\n", "        pos = 0\n", "        losses = []\n", "        while pos < full_size:\n", "            subsample = torch.arange(pos, pos + batch_size)\n", "            batch = data[:, pos:pos + batch_size]\n", "            pos += batch_size\n", "            losses.append(svi.step(batch, subsample, full_size=full_size))\n", "        epoch_loss = sum(losses) / len(losses)\n", "        if epoch % 10 == 0:\n", "            print(\"epoch {} loss = {}\".format(epoch, epoch_loss / data.numel()))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "Collapsed": "false" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 0 loss = 2.6799370694841658\n", "epoch 10 loss = 1.0252116586480822\n", "epoch 20 loss = 0.9349947195393699\n", "epoch 30 loss = 0.8692359572308405\n", "epoch 40 loss = 0.8368030676501137\n", "epoch 50 loss = 0.8429559614998954\n", "epoch 60 loss = 0.7737532197747913\n", "epoch 70 loss = 0.8165627040011542\n", "epoch 80 loss = 0.7903614648069655\n", "epoch 90 loss = 0.7785837551695961\n", "epoch 100 loss = 0.7461400769267763\n" ] } ], "source": [ "train(guide)" ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "### Using EasyGuide\n", "\n", "Now let's simplify using the `@easy_guide` decorator. Our modifications are:\n", "\n", "1. Decorate with `@easy_guide` and add `self` to args.\n", "2. Replace the `Delta` guide for drift with a simple `map_estimate()`.\n", "3. Select a `group` of model sites and read their concatenated `event_shape`.\n", "4. Replace the auxiliary site and `Delta` slices with a single `group.sample()`." 
] }, { "cell_type": "code", "execution_count": 8, "metadata": { "Collapsed": "false" }, "outputs": [], "source": [ "@easy_guide(model)\n", "def guide(self, batch, subsample, full_size):\n", "    # MAP estimate the drift.\n", "    self.map_estimate(\"drift\")\n", "\n", "    # Model local states using shared uncertainty + local mean.\n", "    group = self.group(match=\"state_[0-9]*\")  # Selects all local variables.\n", "    cov_diag = pyro.param(\"state_cov_diag\",\n", "                          lambda: torch.full(group.event_shape, 0.01),\n", "                          constraint=constraints.positive)\n", "    cov_factor = pyro.param(\"state_cov_factor\",\n", "                            lambda: torch.randn(group.event_shape + (rank,)) * 0.01)\n", "    with self.plate(\"data\", full_size, subsample=subsample):\n", "        # Sample local mean.\n", "        loc = pyro.param(\"state_loc\",\n", "                         lambda: torch.full((full_size,) + group.event_shape, 0.5),\n", "                         event_dim=1)\n", "        # Automatically sample the joint latent, then unpack and replay model sites.\n", "        group.sample(\"states\", dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag))" ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "Note we've used `group.event_shape` to determine the total flattened concatenated shape of all matched sites in the group." 
] }, { "cell_type": "code", "execution_count": 9, "metadata": { "Collapsed": "false" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 0 loss = 1.6414739089523043\n", "epoch 10 loss = 0.9550169031790324\n", "epoch 20 loss = 0.8266084059476851\n", "epoch 30 loss = 0.8526671367372785\n", "epoch 40 loss = 0.7920228289876666\n", "epoch 50 loss = 0.8243447229521614\n", "epoch 60 loss = 0.8051274507556644\n", "epoch 70 loss = 0.7844243283271789\n", "epoch 80 loss = 0.7685391607114247\n", "epoch 90 loss = 0.7792391348736627\n", "epoch 100 loss = 0.7797092771530152\n" ] } ], "source": [ "train(guide)" ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "## Amortized guides\n", "\n", "`EasyGuide` also makes it easy to write amortized guides (guides where we learn a function that predicts latent variables from data, rather than learning one parameter per datapoint). Let's modify the last guide to predict the latent `loc` as an affine function of observed data, rather than memorizing each data point's latent variable. This amortized guide is more useful in practice because it can handle new data." 
] }, { "cell_type": "code", "execution_count": 10, "metadata": { "Collapsed": "false" }, "outputs": [], "source": [ "@easy_guide(model)\n", "def guide(self, batch, subsample, full_size):\n", "    num_time_steps, batch_size = batch.shape\n", "    self.map_estimate(\"drift\")\n", "\n", "    group = self.group(match=\"state_[0-9]*\")\n", "    cov_diag = pyro.param(\"state_cov_diag\",\n", "                          lambda: torch.full(group.event_shape, 0.01),\n", "                          constraint=constraints.positive)\n", "    cov_factor = pyro.param(\"state_cov_factor\",\n", "                            lambda: torch.randn(group.event_shape + (rank,)) * 0.01)\n", "\n", "    # Predict latent propensity as an affine function of observed data.\n", "    if not hasattr(self, \"nn\"):\n", "        self.nn = torch.nn.Linear(group.event_shape.numel(), group.event_shape.numel())\n", "        self.nn.weight.data.fill_(1.0 / num_time_steps)\n", "        self.nn.bias.data.fill_(-0.5)\n", "    pyro.module(\"state_nn\", self.nn)\n", "    with self.plate(\"data\", full_size, subsample=subsample):\n", "        loc = self.nn(batch.t())\n", "        group.sample(\"states\", dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag))" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "Collapsed": "false" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 0 loss = 1.4899632521527155\n", "epoch 10 loss = 0.7560149289710181\n", "epoch 20 loss = 0.7410904037782123\n", "epoch 30 loss = 0.7411658687080656\n", "epoch 40 loss = 0.7323255259820394\n", "epoch 50 loss = 0.7339726304667337\n", "epoch 60 loss = 0.7356956726482937\n", "epoch 70 loss = 0.7183537655728204\n", "epoch 80 loss = 0.7098537818193436\n", "epoch 90 loss = 0.7118427662849427\n", "epoch 100 loss = 0.7383492155926568\n" ] } ], "source": [ "train(guide)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "Collapsed": "false" }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": 
"ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": true } }, "nbformat": 4, "nbformat_minor": 4 }