Writing guides using EasyGuide

This tutorial describes the pyro.contrib.easyguide module. Prerequisites are the SVI and tensor shapes tutorials.

Summary:

  • For simple black-box guides, try using components in pyro.infer.autoguide.

  • For more complex guides, try using components in pyro.contrib.easyguide.

  • Decorate with @easy_guide(model).

  • Select multiple model sites using group = self.group(match="my_regex").

  • Guide a group of sites by a single distribution using group.sample(...).

  • Inspect concatenated group shape using group.batch_shape, group.event_shape, etc.

  • Use self.plate(...) instead of pyro.plate(...).

  • To be compatible with subsampling, pass the event_dim arg to pyro.param(...).

  • To MAP estimate model site "foo", use foo = self.map_estimate("foo").
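
Putting these bullets together, a minimal EasyGuide skeleton might look as follows. This is a sketch only, not part of the worked example below: it assumes the imports from the next cell, a model with a global site "foo", and latent sites matching "my_regex" inside a subsampled "data" plate.

[ ]:
@easy_guide(model)
def sketch_guide(self, batch, subsample, full_size):
    # MAP-estimate a hypothetical global site "foo".
    foo = self.map_estimate("foo")
    # Select all model sites whose names match a regex.
    group = self.group(match="my_regex")
    with self.plate("data", full_size, subsample=subsample):
        # event_dim=1 keeps the trailing event dim intact under subsampling.
        loc = pyro.param("sketch_loc",
                         lambda: torch.zeros((full_size,) + group.event_shape),
                         event_dim=1)
        # Guide all matched sites jointly with a single distribution.
        group.sample("aux", dist.Normal(loc, 1.0).to_event(1))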

[1]:
import os
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.contrib.easyguide import easy_guide
from pyro.optim import Adam
from torch.distributions import constraints

pyro.enable_validation(True)
smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.3.0')

Modeling time series data

Consider a time-series model with a slowly-varying continuous latent state and Bernoulli observations with a logistic link function.

[4]:
def model(batch, subsample, full_size):
    batch = list(batch)
    num_time_steps = len(batch)
    # Global drift scale of the latent random walk.
    drift = pyro.sample("drift", dist.LogNormal(-1, 0.5))
    with pyro.plate("data", full_size, subsample=subsample):
        z = 0.
        for t in range(num_time_steps):
            # Slowly-varying latent state, observed through a logistic link.
            z = pyro.sample("state_{}".format(t), dist.Normal(z, drift))
            batch[t] = pyro.sample("obs_{}".format(t),
                                   dist.Bernoulli(logits=z),
                                   obs=batch[t])
    return torch.stack(batch)

Let's generate some data directly from the model.

[3]:
full_size = 100
num_time_steps = 7
pyro.set_rng_seed(123456789)
data = model([None] * num_time_steps, torch.arange(full_size), full_size)
assert data.shape == (num_time_steps, full_size)

Writing a guide without EasyGuide

Consider a possible guide for this model: we point-estimate the drift parameter using a Delta distribution, then model the local time series using shared uncertainty but local means, via a LowRankMultivariateNormal distribution. There is a single global sample site, which we handle with a param and a sample statement. We then declare a global pair of uncertainty parameters cov_diag and cov_factor, and a local loc parameter using pyro.param(..., event_dim=...). Next we draw an auxiliary sample site for the joint states, and finally unpack that auxiliary site into one Delta site per time step. This auxiliary-unpacked-to-Deltas pattern is quite common.

[5]:
rank = 3

def guide(batch, subsample, full_size):
    num_time_steps, batch_size = batch.shape

    # MAP estimate the drift.
    drift_loc = pyro.param("drift_loc", lambda: torch.tensor(0.1),
                           constraint=constraints.positive)
    pyro.sample("drift", dist.Delta(drift_loc))

    # Model local states using shared uncertainty + local mean.
    cov_diag = pyro.param("state_cov_diag",
                          lambda: torch.full((num_time_steps,), 0.01),
                          constraint=constraints.positive)
    cov_factor = pyro.param("state_cov_factor",
                            lambda: torch.randn(num_time_steps, rank) * 0.01)
    with pyro.plate("data", full_size, subsample=subsample):
        # Sample local mean.
        loc = pyro.param("state_loc",
                         lambda: torch.full((full_size, num_time_steps), 0.5),
                         event_dim=1)
        states = pyro.sample("states",
                             dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag),
                             infer={"is_auxiliary": True})
        # Unpack the joint states into one sample site per time step.
        for t in range(num_time_steps):
            pyro.sample("state_{}".format(t), dist.Delta(states[:, t]))

Let’s train using SVI and Trace_ELBO, manually batching data into small minibatches.

[6]:
def train(guide, num_epochs=1 if smoke_test else 101, batch_size=20):
    full_size = data.size(-1)
    pyro.get_param_store().clear()
    pyro.set_rng_seed(123456789)
    svi = SVI(model, guide, Adam({"lr": 0.02}), Trace_ELBO())
    for epoch in range(num_epochs):
        pos = 0
        losses = []
        while pos < full_size:
            subsample = torch.arange(pos, pos + batch_size)
            batch = data[:, pos:pos + batch_size]
            pos += batch_size
            losses.append(svi.step(batch, subsample, full_size=full_size))
        epoch_loss = sum(losses) / len(losses)
        if epoch % 10 == 0:
            print("epoch {} loss = {}".format(epoch, epoch_loss / data.numel()))
[7]:
train(guide)
epoch 0 loss = 2.6799370694841658
epoch 10 loss = 1.0252116586480822
epoch 20 loss = 0.9349947195393699
epoch 30 loss = 0.8692359572308405
epoch 40 loss = 0.8368030676501137
epoch 50 loss = 0.8429559614998954
epoch 60 loss = 0.7737532197747913
epoch 70 loss = 0.8165627040011542
epoch 80 loss = 0.7903614648069655
epoch 90 loss = 0.7785837551695961
epoch 100 loss = 0.7461400769267763

Using EasyGuide

Now let’s simplify using the @easy_guide decorator. Our modifications are:

  1. Decorate with @easy_guide and add self to args.

  2. Replace the Delta guide for drift with a simple map_estimate().

  3. Select a group of model sites and read their concatenated event_shape.

  4. Replace the auxiliary site and Delta slices with a single group.sample().

[8]:
@easy_guide(model)
def guide(self, batch, subsample, full_size):
    # MAP estimate the drift.
    self.map_estimate("drift")

    # Model local states using shared uncertainty + local mean.
    group = self.group(match="state_[0-9]*")  # Selects all local variables.
    cov_diag = pyro.param("state_cov_diag",
                          lambda: torch.full(group.event_shape, 0.01),
                          constraint=constraints.positive)
    cov_factor = pyro.param("state_cov_factor",
                            lambda: torch.randn(group.event_shape + (rank,)) * 0.01)
    with self.plate("data", full_size, subsample=subsample):
        # Sample local mean.
        loc = pyro.param("state_loc",
                         lambda: torch.full((full_size,) + group.event_shape, 0.5),
                         event_dim=1)
        # Automatically sample the joint latent, then unpack and replay model sites.
        group.sample("states", dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag))

Note we’ve used group.event_shape to determine the total flattened concatenated shape of all matched sites in the group.
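For this model the group matches the seven scalar sites state_0, …, state_6, which concatenate to event_shape (7,). As a quick sanity check, we could print that shape from inside a throwaway guide (a sketch only; shape_probe and probe_loc are hypothetical names, not part of the tutorial proper):

[ ]:
@easy_guide(model)
def shape_probe(self, batch, subsample, full_size):
    group = self.group(match="state_[0-9]*")
    print(group.event_shape)  # torch.Size([7]): seven scalar sites, concatenated
    with self.plate("data", full_size, subsample=subsample):
        loc = pyro.param("probe_loc",
                         lambda: torch.zeros((full_size,) + group.event_shape),
                         event_dim=1)
        group.sample("states", dist.Normal(loc, 1.0).to_event(1))

shape_probe(data[:, :20], torch.arange(20), full_size)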

[9]:
train(guide)
epoch 0 loss = 1.6414739089523043
epoch 10 loss = 0.9550169031790324
epoch 20 loss = 0.8266084059476851
epoch 30 loss = 0.8526671367372785
epoch 40 loss = 0.7920228289876666
epoch 50 loss = 0.8243447229521614
epoch 60 loss = 0.8051274507556644
epoch 70 loss = 0.7844243283271789
epoch 80 loss = 0.7685391607114247
epoch 90 loss = 0.7792391348736627
epoch 100 loss = 0.7797092771530152

Amortized guides

EasyGuide also makes it easy to write amortized guides (guides where we learn a function that predicts latent variables from data, rather than learning one parameter per datapoint). Let’s modify the last guide to predict the latent loc as an affine function of observed data, rather than memorizing each data point’s latent variable. This amortized guide is more useful in practice because it can handle new data.

[10]:
@easy_guide(model)
def guide(self, batch, subsample, full_size):
    num_time_steps, batch_size = batch.shape
    self.map_estimate("drift")

    group = self.group(match="state_[0-9]*")
    cov_diag = pyro.param("state_cov_diag",
                          lambda: torch.full(group.event_shape, 0.01),
                          constraint=constraints.positive)
    cov_factor = pyro.param("state_cov_factor",
                            lambda: torch.randn(group.event_shape + (rank,)) * 0.01)

    # Predict latent propensity as an affine function of observed data.
    if not hasattr(self, "nn"):
        self.nn = torch.nn.Linear(group.event_shape.numel(), group.event_shape.numel())
        self.nn.weight.data.fill_(1.0 / num_time_steps)
        self.nn.bias.data.fill_(-0.5)
    pyro.module("state_nn", self.nn)
    with self.plate("data", full_size, subsample=subsample):
        loc = self.nn(batch.t())
        group.sample("states", dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag))
[11]:
train(guide)
epoch 0 loss = 1.4899632521527155
epoch 10 loss = 0.7560149289710181
epoch 20 loss = 0.7410904037782123
epoch 30 loss = 0.7411658687080656
epoch 40 loss = 0.7323255259820394
epoch 50 loss = 0.7339726304667337
epoch 60 loss = 0.7356956726482937
epoch 70 loss = 0.7183537655728204
epoch 80 loss = 0.7098537818193436
epoch 90 loss = 0.7118427662849427
epoch 100 loss = 0.7383492155926568
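
Because loc is computed from the batch itself rather than looked up per datapoint, the trained guide can be applied to data it has never seen. As a sketch (new_data here is a fresh simulation from the model, not part of the original tutorial):

[ ]:
pyro.set_rng_seed(0)
new_data = model([None] * num_time_steps, torch.arange(full_size), full_size)
with torch.no_grad():
    guide(new_data[:, :20], torch.arange(20), full_size)  # samples latent states for unseen data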