Discrete Latent Variable Models

This tutorial describes Pyro's enumeration strategy for models with discrete latent variables. It assumes the reader is already familiar with the Tensor Shapes Tutorial.

Summary

  • Pyro implements automatic enumeration over discrete latent variables.

  • This strategy can be used alone or inside SVI (via TraceEnum_ELBO), HMC, or NUTS.

  • The standalone infer_discrete can generate samples or MAP estimates.

  • Annotate a sample site with infer={"enumerate": "parallel"} to trigger enumeration.

  • If a sample site determines downstream structure, use {"enumerate": "sequential"} instead.

  • Write your models to allow arbitrarily deep batching on the left, e.g. use broadcasting.

  • Inference cost grows exponentially in treewidth, so try to write models with narrow treewidth.

  • If you have trouble, ask for help on forum.pyro.ai!

[1]:
import os
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
from pyro import poutine
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate, infer_discrete
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.ops.indexing import Vindex

smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.3.0')
pyro.enable_validation()
pyro.set_rng_seed(0)

Overview

Pyro's enumeration strategy encompasses many popular algorithms, including variable elimination, exact message passing, forward-filter-backward-sample, inside-out, Baum-Welch, and many other special-case algorithms.

Aside from enumeration, Pyro implements a number of other inference strategies, including variational inference (SVI) and Monte Carlo (HMC and NUTS).

Enumeration can be used either as a stand-alone strategy via infer_discrete, or as a component of other strategies. Thus enumeration allows Pyro to marginalize out discrete latent variables in HMC and SVI models, and to use variational enumeration of discrete variables in SVI guides.
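
As a hedged sketch of the HMC/NUTS case (not part of this notebook): discrete sites in a model decorated with @config_enumerate are marginalized out automatically, so the sampler only explores the continuous latent variables. Here some_model and data are placeholders.

# Sketch: enumeration inside MCMC; `some_model` and `data` are hypothetical.
from pyro.infer import MCMC, NUTS

nuts_kernel = NUTS(some_model)      # discrete sites are summed out automatically
mcmc = MCMC(nuts_kernel, num_samples=100, warmup_steps=100)
mcmc.run(data)                      # takes the same args as some_model(data)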

Mechanics of enumeration

The core idea of enumeration is to interpret discrete pyro.sample statements as full enumeration rather than random sampling. Other inference algorithms can then sum out the enumerated values. For example, a sample statement might return a tensor of scalar shape under the standard “sample” interpretation (we’ll illustrate with a trivial model and guide):

[17]:
def model():
    z = pyro.sample("z", dist.Categorical(torch.ones(6)))
    print('model z = {}'.format(z))

def guide():
    z = pyro.sample("z", dist.Categorical(torch.ones(6)))
    print('guide z = {}'.format(z))

elbo = Trace_ELBO()
elbo.loss(model, guide)
guide z = 2
model z = 2
[17]:
-0.0

However, under the enumeration interpretation, the same sample site will return a fully enumerated set of values, based on its distribution’s .enumerate_support() method.
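
To see which values a site will enumerate, you can call .enumerate_support() directly (a quick check, not one of the original cells):

print(dist.Categorical(torch.ones(6)).enumerate_support())
# tensor([0, 1, 2, 3, 4, 5])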

[21]:
elbo = TraceEnum_ELBO(max_plate_nesting=0)
elbo.loss(model, config_enumerate(guide, "parallel"));
guide z = tensor([0, 1, 2, 3, 4, 5])
model z = tensor([0, 1, 2, 3, 4, 5])

Note that we’ve used “parallel” enumeration to enumerate along a new tensor dimension. This is cheap and allows Pyro to parallelize computation, but requires downstream program structure to avoid branching on the value of z.

To support dynamic program structure, you can instead use “sequential” enumeration, which runs the entire model-guide pair once per sample value; this is more expensive since it requires running the model multiple times.

[22]:
elbo = TraceEnum_ELBO(max_plate_nesting=0)
elbo.loss(model, config_enumerate(guide, "sequential"));
guide z = 5
model z = 5
guide z = 4
model z = 4
guide z = 3
model z = 3
guide z = 2
model z = 2
guide z = 1
model z = 1
guide z = 0
model z = 0

Parallel enumeration is cheaper but more complex than sequential enumeration, so we’ll focus the rest of this tutorial on the parallel variant. Note that both forms can be interleaved.
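
Concretely, per-site annotations take precedence over the config_enumerate default, so the two strategies can be mixed within one model. A minimal sketch (not from the original notebook):

@config_enumerate(default="parallel")
def mixed_model():
    # x uses the "parallel" default; y's per-site annotation wins.
    x = pyro.sample("x", dist.Categorical(torch.ones(6)))
    y = pyro.sample("y", dist.Categorical(torch.ones(6)),
                    infer={"enumerate": "sequential"})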

Multiple latent variables

We have just seen that a single discrete sample site can be enumerated via nonstandard interpretation. A model with a single discrete latent variable is a mixture model; models with multiple discrete latent variables are more complex, e.g. HMMs, CRFs, DBNs, and other structured models. In models with multiple discrete latent variables, Pyro enumerates each variable in a different tensor dimension (counting from the right; see the Tensor Shapes Tutorial). This allows Pyro to determine the dependency graph among variables and then perform cheap exact inference using variable elimination algorithms.

To understand enumeration dimension allocation, consider the following model, where we collapse variables out of the model rather than enumerate them in the guide.

[26]:
@config_enumerate
def model():
    p = pyro.param("p", torch.randn(3, 3).exp(), constraint=constraints.simplex)
    x = pyro.sample("x", dist.Categorical(p[0]))
    y = pyro.sample("y", dist.Categorical(p[x]))
    z = pyro.sample("z", dist.Categorical(p[y]))
    print('model x.shape = {}'.format(x.shape))
    print('model y.shape = {}'.format(y.shape))
    print('model z.shape = {}'.format(z.shape))
    return x, y, z

def guide():
    pass

pyro.clear_param_store()
elbo = TraceEnum_ELBO(max_plate_nesting=0)
elbo.loss(model, guide);
model x.shape = torch.Size([3])
model y.shape = torch.Size([3, 1])
model z.shape = torch.Size([3, 1, 1])
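
Here x is enumerated in dim -1, y in dim -2, and z in dim -3, so the three log-probability tensors broadcast to a joint of shape (3, 3, 3) that variable elimination can sum out one dimension at a time. The following sketch (not part of the original notebook) verifies this by hand, reusing the imports from the first cell:

p = torch.randn(3, 3).exp()
p = p / p.sum(-1, keepdim=True)                 # normalize rows to simplexes
x = torch.arange(3)                             # enumerated in dim -1
y = torch.arange(3).reshape(3, 1)               # enumerated in dim -2
z = torch.arange(3).reshape(3, 1, 1)            # enumerated in dim -3
joint = (dist.Categorical(p[0]).log_prob(x)     # shape (3,)
         + dist.Categorical(p[x]).log_prob(y)   # shape (3, 3)
         + dist.Categorical(p[y]).log_prob(z))  # shape (3, 3, 1)
assert joint.shape == (3, 3, 3)                 # full joint over (z, y, x)
print(joint.logsumexp(dim=(-3, -2, -1)).exp())  # total probability: 1.0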

Examining discrete latent states

While enumeration in SVI allows fast learning of parameters like p above, it does not give access to predicted values of the discrete latent variables like x, y, z above. We can access these using a standalone infer_discrete handler. In this case the guide was trivial, so we can simply wrap the model in infer_discrete. We need to pass a first_available_dim argument to tell infer_discrete which dimensions are available for enumeration; this is related to the max_plate_nesting arg of TraceEnum_ELBO via

first_available_dim = -1 - max_plate_nesting
[42]:
serving_model = infer_discrete(model, first_available_dim=-1)
x, y, z = serving_model()  # takes the same args as model(), here no args
print("x = {}".format(x))
print("y = {}".format(y))
print("z = {}".format(z))
model x.shape = torch.Size([3])
model y.shape = torch.Size([3, 1])
model z.shape = torch.Size([3, 1, 1])
model x.shape = torch.Size([])
model y.shape = torch.Size([])
model z.shape = torch.Size([])
x = 0
y = 0
z = 0

Notice that under the hood infer_discrete runs the model twice: first in forward-filter mode where sites are enumerated, then in replay-backward-sample mode where sites are sampled.

infer_discrete can also perform MAP inference by passing temperature=0. Note that while infer_discrete produces correct posterior samples, it does not currently produce correct logprobs, and should not be used in other gradient-based inference algorithms.
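
For example, a minimal MAP sketch over the same model (not one of the original cells):

# temperature=0 replaces backward sampling with backward argmax,
# yielding MAP estimates of the discrete latents.
map_model = infer_discrete(model, first_available_dim=-1, temperature=0)
x, y, z = map_model()  # takes the same args as model()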

Indexing with enumerated variables

It can be tricky to use advanced indexing to select an element of a tensor using one or more enumerated variables. For example, suppose a plated random variable z depends on two different random variables:

p = pyro.param("p", torch.randn(5, 4, 3, 2).exp(),
               constraint=constraints.simplex)
x = pyro.sample("x", dist.Categorical(torch.ones(4)))
y = pyro.sample("y", dist.Categorical(torch.ones(3)))
with pyro.plate("z_plate", 5):
    p_xy = p[..., x, y, :]  # Not compatible with enumeration!
    z = pyro.sample("z", dist.Categorical(p_xy))

Due to advanced indexing semantics, the expression p[..., x, y, :] will work correctly without enumeration, but is incorrect when x or y is enumerated. Pyro provides a simple way to index correctly, but first let’s see how to correctly index using PyTorch’s advanced indexing without Pyro:

# Compatible with enumeration, but not recommended:
p_xy = p[torch.arange(5, device=p.device).reshape(5, 1),
         x.unsqueeze(-1),
         y.unsqueeze(-1),
         torch.arange(2, device=p.device)]

Pyro provides a helper Vindex()[] to use enumeration-compatible advanced indexing semantics rather than PyTorch/NumPy semantics. Vindex()[] makes the .__getitem__() operator broadcast like other familiar operators such as + and *. Using Vindex()[] we can write the same expression as if x and y were numbers (i.e. not enumerated):

# Recommended syntax compatible with enumeration:
p_xy = Vindex(p)[..., x, y, :]

Here is a complete example:

[43]:
@config_enumerate
def model():
    p = pyro.param("p", torch.randn(5, 4, 3, 2).exp(), constraint=constraints.simplex)
    x = pyro.sample("x", dist.Categorical(torch.ones(4)))
    y = pyro.sample("y", dist.Categorical(torch.ones(3)))
    with pyro.plate("z_plate", 5):
        p_xy = Vindex(p)[..., x, y, :]
        z = pyro.sample("z", dist.Categorical(p_xy))
    print('   p.shape = {}'.format(p.shape))
    print('   x.shape = {}'.format(x.shape))
    print('   y.shape = {}'.format(y.shape))
    print('p_xy.shape = {}'.format(p_xy.shape))
    print('   z.shape = {}'.format(z.shape))
    return x, y, z

def guide():
    pass

pyro.clear_param_store()
elbo = TraceEnum_ELBO(max_plate_nesting=1)
elbo.loss(model, guide);
   p.shape = torch.Size([5, 4, 3, 2])
   x.shape = torch.Size([4, 1])
   y.shape = torch.Size([3, 1, 1])
p_xy.shape = torch.Size([3, 4, 5, 2])
   z.shape = torch.Size([2, 1, 1, 1])

Plates and enumeration

Pyro plates express conditional independence among random variables. Pyro’s enumeration strategy can take advantage of plates to reduce the high cost (exponential in the size of the plate) of enumerating a cartesian product down to a low cost (linear in the size of the plate) of enumerating conditionally independent random variables in lock-step. This is especially important for e.g. minibatched data.

To illustrate, consider a Gaussian mixture model with shared variance and different means.

[44]:
@config_enumerate
def model(data, num_components=3):
    print('Running model with {} data points'.format(len(data)))
    p = pyro.sample("p", dist.Dirichlet(0.5 * torch.ones(num_components)))
    scale = pyro.sample("scale", dist.LogNormal(0, num_components))
    with pyro.plate("components", num_components):
        loc = pyro.sample("loc", dist.Normal(0, 10))
    with pyro.plate("data", len(data)):
        x = pyro.sample("x", dist.Categorical(p))
        print("x.shape = {}".format(x.shape))
        pyro.sample("obs", dist.Normal(loc[x], scale), obs=data)
        print("dist.Normal(loc[x], scale).batch_shape = {}".format(
            dist.Normal(loc[x], scale).batch_shape))

guide = AutoDiagonalNormal(poutine.block(model, hide=["x", "data"]))

data = torch.randn(10)

pyro.clear_param_store()
elbo = TraceEnum_ELBO(max_plate_nesting=1)
elbo.loss(model, guide, data);
Running model with 10 data points
x.shape = torch.Size([10])
dist.Normal(loc[x], scale).batch_shape = torch.Size([10])
Running model with 10 data points
x.shape = torch.Size([3, 1])
dist.Normal(loc[x], scale).batch_shape = torch.Size([3, 1])

Observe that the model is run twice: first by AutoDiagonalNormal to trace the sample sites, and second by elbo to compute the loss. In the first run, x has the standard interpretation of one sample per datum, hence shape (10,). In the second run, enumeration can use the same three values, of shape (3,1), for all data points, relying on broadcasting at any downstream sample or observe sites that depend on data. For example, in the pyro.sample("obs",...) statement, the distribution has shape (3,1), the data has shape (10,), and the broadcasted log probability tensor has shape (3,10).
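
A quick check of that broadcasting claim (not one of the original cells):

loc = torch.zeros(3, 1)                           # one row per enumerated component
print(dist.Normal(loc, 1.).log_prob(torch.randn(10)).shape)  # torch.Size([3, 10])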

For a more in-depth treatment of enumeration in mixture models, see the Gaussian Mixture Model Tutorial and the HMM Example.

Dependencies among plates

The computational savings of enumerating in vectorized plates come with restrictions on the dependency structure of models. These restrictions are in addition to the usual restrictions of conditional independence. The enumeration restrictions are checked by TraceEnum_ELBO and will result in an error if violated (whereas the usual conditional independence restrictions cannot generally be verified by Pyro). For completeness we list all three restrictions:

Restriction 1: conditional independence

Variables within a plate may not depend on each other (along the plate dimension). This applies to any variable, whether or not it is enumerated. This applies to both sequential plates and vectorized plates. For example the following model is invalid:

def invalid_model():
    x = 0
    for i in pyro.plate("invalid", 10):
        x = pyro.sample("x_{}".format(i), dist.Normal(x, 1.))
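
To work around this restriction, move the dependent samples out of the plate; a minimal sketch (not from the original tutorial) uses a plain Python loop, since these x's genuinely depend on each other:

def valid_model():
    x = 0.
    for i in range(10):  # plain loop: no independence is declared
        x = pyro.sample("x_{}".format(i), dist.Normal(x, 1.))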

Restriction 2: no downstream coupling

No variable outside of a vectorized plate can depend on an enumerated variable inside of that plate. This would violate Pyro’s exponential speedup assumption. For example the following model is invalid:

@config_enumerate
def invalid_model(data):
    with pyro.plate("plate", 10):  # <--- invalid vectorized plate
        x = pyro.sample("x", dist.Bernoulli(0.5))
    assert x.shape == (10,)
    pyro.sample("obs", dist.Normal(x.sum(), 1.), obs=data)

To work around this restriction, you can convert the vectorized plate to a sequential plate:

@config_enumerate
def valid_model(data):
    x = []
    for i in pyro.plate("plate", 10):  # <--- valid sequential plate
        x.append(pyro.sample("x_{}".format(i), dist.Bernoulli(0.5)))
    assert len(x) == 10
    pyro.sample("obs", dist.Normal(sum(x), 1.), obs=data)

Restriction 3: single path leaving each plate

The final restriction is subtle, but is required to enable Pyro’s exponential speedup:

For any enumerated variable x, the set of all enumerated variables on which x depends must be linearly orderable in their vectorized plate nesting.

This requirement only applies when there are at least two plates and at least three variables in different plate contexts. The simplest counterexample is a Boltzmann machine:

@config_enumerate
def invalid_model(data):
    plate_1 = pyro.plate("plate_1", 10, dim=-1)  # vectorized
    plate_2 = pyro.plate("plate_2", 10, dim=-2)  # vectorized
    with plate_1:
        x = pyro.sample("x", dist.Bernoulli(0.5))
    with plate_2:
        y = pyro.sample("y", dist.Bernoulli(0.5))
    with plate_1, plate_2:
        z = pyro.sample("z", dist.Bernoulli((1. + x + y) / 4.))
        ...

Here we see that the variable z depends on variable x (which is in plate_1 but not plate_2) and depends on variable y (which is in plate_2 but not plate_1). This model is invalid because there is no way to linearly order x and y such that one’s plate nesting is less than the other.

To work around this restriction, you can convert one of the plates to a sequential plate:

@config_enumerate
def valid_model(data):
    plate_1 = pyro.plate("plate_1", 10, dim=-1)  # vectorized
    plate_2 = pyro.plate("plate_2", 10)          # sequential
    with plate_1:
        x = pyro.sample("x", dist.Bernoulli(0.5))
    for i in plate_2:
        y = pyro.sample("y_{}".format(i), dist.Bernoulli(0.5))
        with plate_1:
            z = pyro.sample("z_{}".format(i), dist.Bernoulli((1. + x + y) / 4.))
            ...

but beware that this increases the computational complexity, which may be exponential in the size of the sequential plate.

Time series example

Consider a discrete HMM with latent states \(x_t\) and observations \(y_t\). Suppose we want to learn the transition and emission probabilities.

[45]:
data_dim = 4
num_steps = 10
data = dist.Categorical(torch.ones(num_steps, data_dim)).sample()

def hmm_model(data, data_dim, hidden_dim=10):
    print('Running for {} time steps'.format(len(data)))
    # Sample global matrices wrt a Jeffreys prior.
    with pyro.plate("hidden_state", hidden_dim):
        transition = pyro.sample("transition", dist.Dirichlet(0.5 * torch.ones(hidden_dim)))
        emission = pyro.sample("emission", dist.Dirichlet(0.5 * torch.ones(data_dim)))

    x = 0  # initial state
    for t, y in enumerate(data):
        x = pyro.sample("x_{}".format(t), dist.Categorical(transition[x]),
                        infer={"enumerate": "parallel"})
        pyro.sample("y_{}".format(t), dist.Categorical(emission[x]), obs=y)
        print("x_{}.shape = {}".format(t, x.shape))

Note that transition has shape (hidden_dim, hidden_dim): the Dirichlet's event shape is (hidden_dim,) and the "hidden_state" plate adds a batch dimension of size hidden_dim on the left, so transition[x] selects the outgoing distribution of state x. Similarly, emission has shape (hidden_dim, data_dim).

We can learn the global parameters using SVI with an autoguide.

[46]:
hmm_guide = AutoDiagonalNormal(poutine.block(hmm_model, expose=["transition", "emission"]))

pyro.clear_param_store()
elbo = TraceEnum_ELBO(max_plate_nesting=1)
elbo.loss(hmm_model, hmm_guide, data, data_dim=data_dim);
Running for 10 time steps
x_0.shape = torch.Size([])
x_1.shape = torch.Size([])
x_2.shape = torch.Size([])
x_3.shape = torch.Size([])
x_4.shape = torch.Size([])
x_5.shape = torch.Size([])
x_6.shape = torch.Size([])
x_7.shape = torch.Size([])
x_8.shape = torch.Size([])
x_9.shape = torch.Size([])
Running for 10 time steps
x_0.shape = torch.Size([10, 1])
x_1.shape = torch.Size([10, 1, 1])
x_2.shape = torch.Size([10, 1, 1, 1])
x_3.shape = torch.Size([10, 1, 1, 1, 1])
x_4.shape = torch.Size([10, 1, 1, 1, 1, 1])
x_5.shape = torch.Size([10, 1, 1, 1, 1, 1, 1])
x_6.shape = torch.Size([10, 1, 1, 1, 1, 1, 1, 1])
x_7.shape = torch.Size([10, 1, 1, 1, 1, 1, 1, 1, 1])
x_8.shape = torch.Size([10, 1, 1, 1, 1, 1, 1, 1, 1, 1])
x_9.shape = torch.Size([10, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

Notice that the model was run twice here: first it was run without enumeration by AutoDiagonalNormal, so that the autoguide can record all sample sites; then it was run a second time by TraceEnum_ELBO with enumeration enabled. We see in the first run that samples have the standard interpretation, whereas in the second run samples have the enumeration interpretation.

For more complex examples, including minibatching and multiple plates, see the HMM tutorial.

How to enumerate more than 25 variables

PyTorch tensors have a dimension limit of 25 in CUDA and 64 in CPU. By default Pyro enumerates each sample site in a new dimension. If you need more sample sites, you can annotate your model with pyro.markov to tell Pyro when it is safe to recycle tensor dimensions. Let’s see how that works with the HMM model from above. The only change we need is to annotate the for loop with pyro.markov, informing Pyro that the variables in each step of the loop depend only on variables outside of the loop and variables at this step and the previous step of the loop:

- for t, y in enumerate(data):
+ for t, y in pyro.markov(enumerate(data)):
[47]:
def hmm_model(data, data_dim, hidden_dim=10):
    with pyro.plate("hidden_state", hidden_dim):
        transition = pyro.sample("transition", dist.Dirichlet(0.5 * torch.ones(hidden_dim)))
        emission = pyro.sample("emission", dist.Dirichlet(0.5 * torch.ones(data_dim)))

    x = 0  # initial state
    for t, y in pyro.markov(enumerate(data)):
        x = pyro.sample("x_{}".format(t), dist.Categorical(transition[x]),
                        infer={"enumerate": "parallel"})
        pyro.sample("y_{}".format(t), dist.Categorical(emission[x]), obs=y)
        print("x_{}.shape = {}".format(t, x.shape))

# We'll reuse the same guide and elbo.
elbo.loss(hmm_model, hmm_guide, data, data_dim=data_dim);
x_0.shape = torch.Size([10, 1])
x_1.shape = torch.Size([10, 1, 1])
x_2.shape = torch.Size([10, 1])
x_3.shape = torch.Size([10, 1, 1])
x_4.shape = torch.Size([10, 1])
x_5.shape = torch.Size([10, 1, 1])
x_6.shape = torch.Size([10, 1])
x_7.shape = torch.Size([10, 1, 1])
x_8.shape = torch.Size([10, 1])
x_9.shape = torch.Size([10, 1, 1])

Notice that this model now only needs three tensor dimensions: one for the plate, one for even states, and one for odd states. For more complex examples, see the Dynamic Bayes Net model in the HMM example.
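
pyro.markov also accepts a history argument controlling how far back dependencies may reach. A hedged sketch (not in the original notebook), assuming second-order dependencies:

# history=2 would let x_t depend on both x_{t-1} and x_{t-2},
# at the cost of one extra recycled tensor dimension.
for t, y in pyro.markov(enumerate(data), history=2):
    ...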