Note

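Distributions in Pyro have a notion of shape, and the shape of a distribution (a stochastic function) is not the shape of the data! Just as important: you should think of all of your data as samples of a single random vector.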


Shapes of stochastic functions in Pyro

This tutorial introduces Pyro's organization of tensor dimensions. Before starting, you should be familiar with PyTorch broadcasting semantics.

Parallel enumeration is the hardest part of this tutorial to understand, so pay particular attention to it.

Key takeaways:

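  • Set pyro.enable_validation(True) when training models and debugging.

  • Tensors broadcast by aligning their rightmost dimensions: torch.ones(3,4,5) + torch.ones(5).

  • A sample has shape batch_shape + event_shape.

  • For a sample \(x\), my_dist.log_prob(x).shape == batch_shape.

  • Use my_dist.expand() to batch a distribution, or let plate expand it automatically.

  • Use my_dist.to_event(n) to declare the rightmost n batch dimensions as dependent, which moves them from batch_shape into event_shape.

  • Use with pyro.plate('name', size): to declare a dimension as conditionally independent.

  • Every dimension must be declared either dependent or conditionally independent.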

  • Try to support batching on the left. This lets Pyro auto-parallelize.

    • use negative indices like x.sum(-1) rather than x.sum(2)

    • use ellipsis notation like pixel = image[..., i, j]

    • use Vindex if i,j are enumerated, pixel = Vindex(image)[..., i, j]

  • When debugging, examine all shapes in a trace using Trace.format_shapes().
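The broadcasting rule in the takeaways above can be checked directly in PyTorch. A minimal sketch, assuming only that a recent PyTorch (1.8+, for torch.broadcast_shapes) is installed: shapes are aligned on the right, size-1 dims stretch, and mismatched trailing dims raise an error.

```python
import torch

# Shapes are compared from the right; missing leading dims count as size 1.
a = torch.ones(3, 4, 5) + torch.ones(5)
assert a.shape == (3, 4, 5)

# Size-1 dims are stretched: (2, 1, 1) and (3, 1) -> (2, 3, 1).
b = torch.ones(2, 1, 1) + torch.ones(3, 1)
assert b.shape == (2, 3, 1)

# torch.broadcast_shapes computes the result shape without allocating tensors.
assert torch.broadcast_shapes((2, 1, 1), (3, 1)) == (2, 3, 1)

# Incompatible trailing dims (4 vs 5) cannot be broadcast.
try:
    torch.ones(3, 4) + torch.ones(5)
    broadcast_ok = True
except RuntimeError:
    broadcast_ok = False
assert not broadcast_ok
```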

[1]:
import os, torch, pyro
from torch.distributions import constraints
from pyro.distributions import Bernoulli, Categorical, MultivariateNormal, Normal
from pyro.distributions.util import broadcast_shape
from pyro.infer import Trace_ELBO, TraceEnum_ELBO, config_enumerate
import pyro.poutine as poutine
from pyro.optim import Adam

smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.3.0')
pyro.enable_validation(True)    # <---- This is always a good idea!

# We'll use this helper to check our models are correct.
def test_model(model, guide, loss):
    pyro.clear_param_store()
    loss.loss(model, guide)

Distribution shapes: batch_shape and event_shape

PyTorch Tensors have a single .shape attribute, but Pyro Distributions have two shape attributes with special meaning: .batch_shape and .event_shape. Together they define the shape of a sample:

x = d.sample()
assert x.shape == d.batch_shape + d.event_shape
# tuple concatenation, not element-wise addition of the dimensions

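Mathematically, an event of a random variable \(X \sim p(x; \theta)\) is \(X = x\), and the shape of \(x\) is event_shape. We often need a batch of random variables \(X\) of the same family but with different parameters \(\theta\) (call this a batched r.v.); batch_shape is the shape of that collection. A sample of a batched r.v. therefore has shape batch_shape concatenated with event_shape.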

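Indices over .batch_shape denote conditionally independent random variables, whereas indices over .event_shape denote dependent components of a random vector (i.e. one draw from a distribution). Because each draw of a random vector corresponds to a single probability value, .log_prob() produces one value for each event of shape .event_shape, so the output of .log_prob() has shape .batch_shape: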

# one log-probability per event; batch_shape indexes the batched r.v.
assert d.log_prob(x).shape == d.batch_shape

Note that the Distribution.sample() method also takes a sample_shape argument, which indexes independent and identically distributed (iid) draws of the batched r.v., so

x2 = d.sample(sample_shape)
assert x2.shape == sample_shape + d.batch_shape + d.event_shape

In other words, every Pyro distribution (stochastic function) has two shapes, batch_shape and event_shape, and sampling can add a third, sample_shape.

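For example, suppose the height and weight of each student in a class form a random vector, so event_shape = (2,). A school has several classes, each with its own height/weight distribution parameters, so batch_shape = (number of classes,). Finally, treating the school's students as drawn from some population, we can draw sample_shape iid samples from it.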
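The school example can be sketched with torch.distributions.MultivariateNormal, whose batch_shape / event_shape semantics Pyro's distributions share. The number of classes, the mean vectors and the shared identity covariance below are made-up illustration values, not data from this tutorial.

```python
import torch
from torch.distributions import MultivariateNormal

num_classes = 3  # hypothetical number of classes in the school
# One (height, weight) mean vector per class; the numbers are invented.
loc = torch.tensor([[170., 65.], [165., 60.], [175., 70.]])
cov = torch.eye(2)  # shared identity covariance, purely an assumption
d = MultivariateNormal(loc, covariance_matrix=cov)

assert d.batch_shape == (num_classes,)  # one distribution per class
assert d.event_shape == (2,)            # (height, weight) is one event

sample_shape = torch.Size([100])        # 100 iid draws of the whole batch
x = d.sample(sample_shape)
assert x.shape == sample_shape + d.batch_shape + d.event_shape  # (100, 3, 2)
assert d.log_prob(x).shape == sample_shape + d.batch_shape      # (100, 3)
```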

In summary

      |      iid     | independent | dependent
------+--------------+-------------+------------
shape = sample_shape + batch_shape + event_shape

Examples

The simplest distribution shape is a single univariate distribution.

[5]:
d = Bernoulli(0.5)
assert d.batch_shape == ()
assert d.event_shape == ()
x = d.sample()
assert x.shape == ()
assert d.log_prob(x).shape == ()

Distributions can be batched by passing in batched parameters (giving a batched r.v.).

[6]:
d = Bernoulli(0.5 * torch.ones(3,4))
assert d.batch_shape == (3, 4)
assert d.event_shape == ()
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3, 4)

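Another way to batch distributions is via the .expand() method. It follows broadcasting rules: the existing batch_shape is aligned with the new shape from the right, and each existing dimension must either equal the new size or be 1.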

[2]:
d = Bernoulli(torch.tensor([0.1, 0.2, 0.3, 0.4])).expand([3, 4])
assert d.batch_shape == (3, 4)
assert d.event_shape == ()
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3, 4)
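The broadcasting requirement of .expand() can be seen in a short sketch. It uses torch.distributions.Bernoulli; pyro.distributions.Bernoulli subclasses it, so it is assumed to behave the same way here.

```python
import torch
from torch.distributions import Bernoulli

# (4,) -> (3, 4): right-aligned dims match, a new dim is added on the left.
d1 = Bernoulli(torch.tensor([0.1, 0.2, 0.3, 0.4])).expand([3, 4])
assert d1.batch_shape == (3, 4)

# (3, 1) -> (3, 4): a size-1 dim is stretched.
d2 = Bernoulli(torch.full((3, 1), 0.5)).expand([3, 4])
assert d2.batch_shape == (3, 4)

# (3,) -> (3, 4) fails: the rightmost dims (3 vs 4) are not broadcast-compatible.
try:
    Bernoulli(torch.full((3,), 0.5)).expand([3, 4])
    expand_failed = False
except RuntimeError:
    expand_failed = True
assert expand_failed
```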

Multivariate distributions have a nonempty .event_shape. For these distributions, the shapes of .sample() and .log_prob(x) differ:

[8]:
d = MultivariateNormal(torch.zeros(3), torch.eye(3, 3))
assert d.batch_shape == ()
assert d.event_shape == (3,)
x = d.sample()
assert x.shape == (3,)            # == batch_shape + event_shape
assert d.log_prob(x).shape == ()  # == batch_shape
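To see where the difference comes from, compare a batched MultivariateNormal with a batch of univariate Normals that has the same parameter shape. This is a sketch with torch.distributions (assumed to match the Pyro wrappers): both yield samples of shape (2, 3), but the MVN returns one log-density per 3-vector. With an identity covariance the components are independent, so the MVN log-density equals the univariate log-densities summed over the event dim.

```python
import torch
from torch.distributions import MultivariateNormal, Normal

loc = torch.zeros(2, 3)

mvn = MultivariateNormal(loc, torch.eye(3))  # batch_shape (2,), event_shape (3,)
uni = Normal(loc, 1.0)                       # batch_shape (2, 3), event_shape ()

x = mvn.sample()
assert x.shape == (2, 3)                     # same sample shape for both
mvn_lp = mvn.log_prob(x)
uni_lp = uni.log_prob(x)
assert mvn_lp.shape == (2,)                  # one value per 3-vector
assert uni_lp.shape == (2, 3)                # one value per scalar

# Identity covariance: summing the univariate log-probs over the event dim agrees.
assert torch.allclose(mvn_lp, uni_lp.sum(-1), atol=1e-5)
```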

Preview of the next subsection:

x = pyro.sample("x", dist.Normal(0, 1).expand([10]).to_event(1))

with pyro.plate("x_plate", 10):
    x = pyro.sample("x", dist.Normal(0, 1))

Reshaping distributions: changing event_shape

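In Pyro you can treat a univariate distribution as multivariate by calling .to_event(n) (i.e. stacking scalar random variables into a random vector), where n is the number of batch dimensions, counting from the right, to declare as dependent (i.e. as components of one random vector). Thus to_event(0) leaves the distribution unchanged, while to_event(1) moves the rightmost dimension of the original batch_shape into event_shape.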

[24]:
d = Bernoulli(0.5 * torch.ones(3,4)).to_event(1)
assert d.batch_shape == (3,)
assert d.event_shape == (4,)
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3,)

When working with Pyro programs, keep in mind that samples have shape batch_shape + event_shape, whereas .log_prob(x) values have shape batch_shape. You need to ensure that batch_shape is carefully controlled, either by trimming it down with .to_event(n) or by declaring dimensions as independent via pyro.plate.

In practice, it is always safe to assume dependence. In Pyro we often declare some dimensions as dependent even though they are in fact independent, e.g.

x = pyro.sample("x", dist.Normal(0, 1).expand([10]).to_event(1))
assert x.shape == (10,)

This is useful for two reasons:

  • It allows us to easily swap in a MultivariateNormal distribution later.

  • It simplifies the code a bit, since we don't need a plate (see below), as in

with pyro.plate("x_plate", 10):
    x = pyro.sample("x", dist.Normal(0, 1))  # .expand([10]) is automatic
    assert x.shape == (10,)

The difference between these two versions is that the second version, with plate, informs Pyro that it can make use of conditional independence information when estimating gradients, whereas in the first version Pyro must assume the components are dependent (even though the normals are in fact conditionally independent). This is analogous to \(d\)-separation in graphical models: it is always safe to add edges and assume variables may be dependent (i.e. to widen the model class), but it is unsafe to assume independence when variables are actually dependent (i.e. narrowing the model class so the true model lies outside of the class, as in mean field).

In practice, Pyro's SVI inference algorithm uses reparameterized gradient estimators for Normal distributions, so both gradient estimators have the same performance.

Preview of the next section:

x_axis = pyro.plate("x_axis", 3, dim=-2)
y_axis = pyro.plate("y_axis", 2, dim=-3)
with x_axis:
    x = pyro.sample("x", Normal(0, 1))
with y_axis:
    y = pyro.sample("y", Normal(0, 1))
with x_axis, y_axis:
    xy = pyro.sample("xy", Normal(0, 1))
    z = pyro.sample("z", Normal(0, 1).expand([5]).to_event(1))
assert x.shape == (3, 1)        # batch_shape == (3,1)     event_shape == ()
assert y.shape == (2, 1, 1)     # batch_shape == (2,1,1)   event_shape == ()
assert xy.shape == (2, 3, 1)    # batch_shape == (2,3,1)   event_shape == ()
assert z.shape == (2, 3, 1, 5)  # batch_shape == (2,3,1)   event_shape == (5,)

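Declaring independent dims with plate

pyro.plate declares batch dimensions as conditionally independent, and in doing so determines the batch_shape of the sample sites inside it.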

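Pyro models can use the context manager pyro.plate to declare that certain batch dimensions are independent. Inference algorithms can then take advantage of this independence, e.g. to construct lower-variance gradient estimators or to enumerate in linear space rather than exponential space. An example of an independent dimension is the index over data in a minibatch: each datum should be independent of all others.

The simplest way to declare a dimension as independent is to declare the rightmost batch dimension -1 as independent via a simple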

with pyro.plate("my_plate"):
    # within this context, batch dimension -1 is independent
    ...

We recommend always providing an optional size argument to aid in debugging shapes:

with pyro.plate("my_plate", len(my_data)):
    # within this context, batch dimension -1 is independent
    ...

Starting with Pyro 0.2, you can additionally nest plates, e.g. if you have per-pixel independence:

with pyro.plate("x_axis", 320):
    # within this context, batch dimension -1 is independent
    with pyro.plate("y_axis", 200):
        # within this context, batch dimensions -2 and -1 are independent
        ...

Note that we always count from the right by using negative indices like -2, -1.

Finally if you want to mix and match plates for e.g. noise that depends only on x, some noise that depends only on y, and some noise that depends on both, you can declare multiple plates and use them as reusable context managers. In this case Pyro cannot automatically allocate a dimension, so you need to provide a dim argument (again counting from the right):

x_axis = pyro.plate("x_axis", 3, dim=-2)
y_axis = pyro.plate("y_axis", 2, dim=-3)
with x_axis:
    # within this context, batch dimension -2 is independent
    ...
with y_axis:
    # within this context, batch dimension -3 is independent
    ...
with x_axis, y_axis:
    # within this context, batch dimensions -3 and -2 are independent
    ...

Let’s take a closer look at batch sizes within plates.

[73]:
def model1():
    a = pyro.sample("a", Normal(0, 1))
    b = pyro.sample("b", Normal(torch.zeros(2), 1).to_event(1))
    with pyro.plate("c_plate", 2):
        c = pyro.sample("c", Normal(torch.zeros(2), 1))
    with pyro.plate("d_plate", 3):
        d = pyro.sample("d", Normal(torch.zeros(3,4,5), 1).to_event(2))
    assert a.shape == ()       # batch_shape == ()     event_shape == ()
    assert b.shape == (2,)     # batch_shape == ()     event_shape == (2,)
    assert c.shape == (2,)     # batch_shape == (2,)   event_shape == ()
    assert d.shape == (3,4,5)  # batch_shape == (3,)   event_shape == (4,5)

    x_axis = pyro.plate("x_axis", 3, dim=-2)
    y_axis = pyro.plate("y_axis", 2, dim=-3)
    with x_axis:
        x = pyro.sample("x", Normal(0, 1))
    with y_axis:
        y = pyro.sample("y", Normal(0, 1))
    with x_axis, y_axis:
        xy = pyro.sample("xy", Normal(0, 1))
        z = pyro.sample("z", Normal(0, 1).expand([5]).to_event(1))
    assert x.shape == (3, 1)        # batch_shape == (3,1)     event_shape == ()
    assert y.shape == (2, 1, 1)     # batch_shape == (2,1,1)   event_shape == ()
    assert xy.shape == (2, 3, 1)    # batch_shape == (2,3,1)   event_shape == ()
    assert z.shape == (2, 3, 1, 5)  # batch_shape == (2,3,1)   event_shape == (5,)

test_model(model1, model1, Trace_ELBO())

It is helpful to visualize the .shapes of each sample site by aligning them at the boundary between batch_shape and event_shape: dimensions to the right will be summed out in .log_prob() and dimensions to the left will remain.

batch dims | event dims
-----------+-----------
           |        a = sample("a", Normal(0, 1))
           |2       b = sample("b", Normal(zeros(2), 1)
           |                        .to_event(1))
           |        with plate("c", 2):
          2|            c = sample("c", Normal(zeros(2), 1))
           |        with plate("d", 3):
          3|4 5         d = sample("d", Normal(zeros(3,4,5), 1)
           |                       .to_event(2))
           |
           |        x_axis = plate("x", 3, dim=-2)
           |        y_axis = plate("y", 2, dim=-3)
           |        with x_axis:
        3 1|            x = sample("x", Normal(0, 1))
           |        with y_axis:
      2 1 1|            y = sample("y", Normal(0, 1))
           |        with x_axis, y_axis:
      2 3 1|            xy = sample("xy", Normal(0, 1))
      2 3 1|5           z = sample("z", Normal(0, 1).expand([5])
           |                       .to_event(1))

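To examine the shapes of sample sites in a program automatically, you can trace the program and use the Trace.format_shapes() method, which prints three shapes for each sample site:

  • the distribution shape (both site["fn"].batch_shape and site["fn"].event_shape),

  • the value shape (site["value"].shape),

  • if log probability has been computed, the log_prob shape (site["log_prob"].shape):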

[74]:
trace = poutine.trace(model1).get_trace()
trace.compute_log_prob()  # optional, but allows printing of log_prob shapes
print(trace.format_shapes())
Trace Shapes:
 Param Sites:
Sample Sites:
       a dist       |
        value       |
     log_prob       |
       b dist       | 2
        value       | 2
     log_prob       |
 c_plate dist       |
        value     2 |
     log_prob       |
       c dist     2 |
        value     2 |
     log_prob     2 |
 d_plate dist       |
        value     3 |
     log_prob       |
       d dist     3 | 4 5
        value     3 | 4 5
     log_prob     3 |
  x_axis dist       |
        value     3 |
     log_prob       |
  y_axis dist       |
        value     2 |
     log_prob       |
       x dist   3 1 |
        value   3 1 |
     log_prob   3 1 |
       y dist 2 1 1 |
        value 2 1 1 |
     log_prob 2 1 1 |
      xy dist 2 3 1 |
        value 2 3 1 |
     log_prob 2 3 1 |
       z dist 2 3 1 | 5
        value 2 3 1 | 5
     log_prob 2 3 1 |

Subsampling tensors inside a plate

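One of the main uses of plate is to subsample data. This is possible within a plate because the data are conditionally independent, so the loss computed on a subsample, once Pyro rescales it, is still an unbiased estimate of the full-data loss. To subsample data, you need to inform Pyro of both the original data size and the subsample size; Pyro will then choose a random subset of the data and yield the set of indices.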

[3]:
data = torch.arange(100.)

def model2():
    mean = pyro.param("mean", torch.zeros(len(data)))
    with pyro.plate("data", len(data), subsample_size=10) as ind:
        assert len(ind) == 10    # ind is a LongTensor that indexes the subsample.
        batch = data[ind]        # Select a minibatch of data.
        mean_batch = mean[ind]   # Take care to select the relevant per-datum parameters.
        # Do stuff with batch:
        x = pyro.sample("x", Normal(mean_batch, 1), obs=batch)
        assert len(x) == 10

test_model(model2, guide=lambda: None, loss=Trace_ELBO())

Broadcasting to allow parallel enumeration

Pyro 0.2 introduced the ability to enumerate discrete latent variables in parallel. This can significantly reduce the variance of gradient estimators when learning a posterior via SVI.

To use parallel enumeration, Pyro needs to allocate tensor dimensions that it can use for enumeration. To avoid conflicting with other dimensions that we want to use for plates, we need to declare a budget of the maximum number of tensor dimensions we'll use. This budget is called max_plate_nesting and is an argument to SVI (the argument is simply passed through to TraceEnum_ELBO). Usually Pyro can determine this budget on its own (it runs the (model, guide) pair once and records what happens), but in case of dynamic model structure you may need to declare max_plate_nesting manually.

To understand max_plate_nesting and how Pyro allocates dimensions for enumeration, let’s revisit model1() from above. This time we’ll map out three types of dimensions: enumeration dimensions on the left (Pyro takes control of these), batch dimensions in the middle, and event dimensions on the right.

      max_plate_nesting = 3
           |<--->|
enumeration|batch|event
-----------+-----+-----
           |. . .|      a = sample("a", Normal(0, 1))
           |. . .|2     b = sample("b", Normal(zeros(2), 1)
           |     |                      .to_event(1))
           |     |      with plate("c", 2):
           |. . 2|          c = sample("c", Normal(zeros(2), 1))
           |     |      with plate("d", 3):
           |. . 3|4 5       d = sample("d", Normal(zeros(3,4,5), 1)
           |     |                     .to_event(2))
           |     |
           |     |      x_axis = plate("x", 3, dim=-2)
           |     |      y_axis = plate("y", 2, dim=-3)
           |     |      with x_axis:
           |. 3 1|          x = sample("x", Normal(0, 1))
           |     |      with y_axis:
           |2 1 1|          y = sample("y", Normal(0, 1))
           |     |      with x_axis, y_axis:
           |2 3 1|          xy = sample("xy", Normal(0, 1))
           |2 3 1|5         z = sample("z", Normal(0, 1).expand([5])
           |     |                     .to_event(1))

Note that it is safe to overprovision max_plate_nesting=4 but we cannot underprovision max_plate_nesting=2 (or Pyro will error). Let’s see how this works in practice.

[4]:
@config_enumerate
def model3():
    p = pyro.param("p", torch.arange(6.) / 6)
    locs = pyro.param("locs", torch.tensor([-1., 1.]))

    a = pyro.sample("a", Categorical(torch.ones(6) / 6))
    b = pyro.sample("b", Bernoulli(p[a]))  # Note this depends on a.
    with pyro.plate("c_plate", 4):
        c = pyro.sample("c", Bernoulli(0.3))
        with pyro.plate("d_plate", 5):
            d = pyro.sample("d", Bernoulli(0.4))
            e_loc = locs[d.long()].unsqueeze(-1)
            e_scale = torch.arange(1., 8.)
            e = pyro.sample("e", Normal(e_loc, e_scale)
                            .to_event(1))  # Note this depends on d.

    #                   enumerated|batch|event dims
    assert a.shape == (         6, 1, 1   )  # Six enumerated values of the Categorical.
    assert b.shape == (      2, 1, 1, 1   )  # Two enumerated Bernoullis, unexpanded.
    assert c.shape == (   2, 1, 1, 1, 1   )  # Only two Bernoullis, unexpanded.
    assert d.shape == (2, 1, 1, 1, 1, 1   )  # Only two Bernoullis, unexpanded.
    assert e.shape == (2, 1, 1, 1, 5, 4, 7)  # This is sampled and depends on d.

    assert e_loc.shape   == (2, 1, 1, 1, 1, 1, 1,)
    assert e_scale.shape == (                  7,)

test_model(model3, model3, TraceEnum_ELBO(max_plate_nesting=2))

Let’s take a closer look at those dimensions. First note that Pyro allocates enumeration dims starting immediately to the left of the max_plate_nesting batch dims: Pyro allocates dim -3 to enumerate a, then dim -4 to enumerate b, then dim -5 to enumerate c, and finally dim -6 to enumerate d. Next note that samples only have extent (size > 1) in the new enumeration dimension. This helps keep tensors small and computation cheap. (Note that the log_prob shape will be broadcast up to contain both the enumeration shape and the batch shape, so e.g. trace.nodes['d']['log_prob'].shape == (2, 1, 1, 1, 5, 4).)

We can draw a similar map of the tensor dimensions:

     max_plate_nesting = 2
            |<->|
enumeration batch event
------------|---|-----
           6|1 1|     a = pyro.sample("a", Categorical(torch.ones(6) / 6))
         2 1|1 1|     b = pyro.sample("b", Bernoulli(p[a]))
            |   |     with pyro.plate("c_plate", 4):
       2 1 1|1 1|         c = pyro.sample("c", Bernoulli(0.3))
            |   |         with pyro.plate("d_plate", 5):
     2 1 1 1|1 1|             d = pyro.sample("d", Bernoulli(0.4))
     2 1 1 1|1 1|1            e_loc = locs[d.long()].unsqueeze(-1)
            |   |7            e_scale = torch.arange(1., 8.)
     2 1 1 1|5 4|7            e = pyro.sample("e", Normal(e_loc, e_scale)
            |   |                             .to_event(1))

To automatically examine this model with enumeration semantics, we can create an enumerated trace and then use Trace.format_shapes():

[11]:
trace = poutine.trace(poutine.enum(model3, first_available_dim=-3)).get_trace()
trace.compute_log_prob()  # optional, but allows printing of log_prob shapes
print(trace.format_shapes())
Trace Shapes:
 Param Sites:
            p             6
         locs             2
Sample Sites:
       a dist             |
        value       6 1 1 |
     log_prob       6 1 1 |
       b dist       6 1 1 |
        value     2 1 1 1 |
     log_prob     2 6 1 1 |
 c_plate dist             |
        value           4 |
     log_prob             |
       c dist           4 |
        value   2 1 1 1 1 |
     log_prob   2 1 1 1 4 |
 d_plate dist             |
        value           5 |
     log_prob             |
       d dist         5 4 |
        value 2 1 1 1 1 1 |
     log_prob 2 1 1 1 5 4 |
       e dist 2 1 1 1 5 4 | 7
        value 2 1 1 1 5 4 | 7
     log_prob 2 1 1 1 5 4 |

Writing parallelizable code

It can be tricky to write Pyro models that correctly handle parallelized sample sites. Two tricks help: broadcasting and ellipsis slicing. Let's look at a contrived model to see how these work in practice. Our goal is to write a model that works both with and without enumeration.

[79]:
width, height = 8, 10
sparse_pixels = torch.LongTensor([[3, 2], [3, 5], [3, 9], [7, 1]])
enumerated = None  # set to either True or False below

def fun(observe):
    p_x = pyro.param("p_x", torch.tensor(0.1), constraint=constraints.unit_interval)
    p_y = pyro.param("p_y", torch.tensor(0.1), constraint=constraints.unit_interval)
    x_axis = pyro.plate('x_axis', width, dim=-2)
    y_axis = pyro.plate('y_axis', height, dim=-1)

    # Note that the shapes of these sites depend on whether Pyro is enumerating.
    with x_axis:
        x_active = pyro.sample("x_active", Bernoulli(p_x))
    with y_axis:
        y_active = pyro.sample("y_active", Bernoulli(p_y))
    if enumerated:
        assert x_active.shape  == (2, 1, 1)
        assert y_active.shape  == (2, 1, 1, 1)
    else:
        assert x_active.shape  == (width, 1)
        assert y_active.shape  == (height,)

    # The first trick is to broadcast. This works with or without enumeration.
    p = 0.1 + 0.5 * x_active * y_active
    if enumerated:
        assert p.shape == (2, 2, 1, 1)
    else:
        assert p.shape == (width, height)
    dense_pixels = p.new_zeros(broadcast_shape(p.shape, (width, height)))

    # The second trick is to index using ellipsis slicing.
    # This allows Pyro to add arbitrary dimensions on the left.
    for x, y in sparse_pixels:
        dense_pixels[..., x, y] = 1
    if enumerated:
        assert dense_pixels.shape == (2, 2, width, height)
    else:
        assert dense_pixels.shape == (width, height)

    with x_axis, y_axis:
        if observe:
            pyro.sample("pixels", Bernoulli(p), obs=dense_pixels)

def model4():
    fun(observe=True)

def guide4():
    fun(observe=False)

# Test without enumeration.
enumerated = False
test_model(model4, guide4, Trace_ELBO())

# Test with enumeration.
enumerated = True
test_model(model4, config_enumerate(guide4, "parallel"),
           TraceEnum_ELBO(max_plate_nesting=2))
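Why negative dims and ellipsis slicing matter can be shown without Pyro at all. A torch-only sketch: the extra leading dims of the hypothetical batched tensor stand in for the enumeration dims Pyro adds on the left. A positional dim such as sum(1) silently reduces the wrong axis once those dims appear, while sum(-1) and [..., x, y] keep working.

```python
import torch

width, height = 8, 10
image = torch.rand(width, height)
batched = torch.rand(2, 2, width, height)  # two extra dims added on the left

def row_sums_fragile(img):
    return img.sum(1)      # positional dim: wrong once dims are added on the left

def row_sums(img):
    return img.sum(-1)     # counts from the right: batch-agnostic

def pixel(img, x, y):
    return img[..., x, y]  # ellipsis keeps any leading dims

assert row_sums(image).shape == (width,)
assert row_sums(batched).shape == (2, 2, width)
assert row_sums_fragile(image).shape == (width,)
assert row_sums_fragile(batched).shape == (2, width, height)  # summed the wrong dim
assert pixel(image, 3, 2).shape == ()
assert pixel(batched, 3, 2).shape == (2, 2)
```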

Automatic broadcasting inside pyro.plate

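In all of the model/guide specifications above, we have relied on pyro.plate to automatically expand sample shapes to satisfy the constraints on batch shape enforced by pyro.sample statements. However, this broadcasting is equivalent to hand-annotated .expand() statements.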

We will demonstrate this using model4 from the previous section. Note the following changes to the earlier code:

  • For the purpose of this example, we will only consider “parallel” enumeration, but broadcasting should work as expected without enumeration or with “sequential” enumeration.

  • We have separated out the sampling function which returns the tensors corresponding to the active pixels. Modularizing the model code into components is a common practice, and helps with maintainability of large models.

  • We would also like to use the pyro.plate construct to parallelize the ELBO estimator over num_particles. This is done by wrapping the contents of model/guide inside an outermost pyro.plate context.

[78]:
num_particles = 100  # Number of samples for the ELBO estimator
width = 8
height = 10
sparse_pixels = torch.LongTensor([[3, 2], [3, 5], [3, 9], [7, 1]])

def sample_pixel_locations_no_broadcasting(p_x, p_y, x_axis, y_axis):
    with x_axis:
        x_active = pyro.sample("x_active", Bernoulli(p_x).expand([num_particles, width, 1]))
    with y_axis:
        y_active = pyro.sample("y_active", Bernoulli(p_y).expand([num_particles, 1, height]))
    return x_active, y_active

def sample_pixel_locations_full_broadcasting(p_x, p_y, x_axis, y_axis):
    with x_axis:
        x_active = pyro.sample("x_active", Bernoulli(p_x))
    with y_axis:
        y_active = pyro.sample("y_active", Bernoulli(p_y))
    return x_active, y_active

def sample_pixel_locations_partial_broadcasting(p_x, p_y, x_axis, y_axis):
    with x_axis:
        x_active = pyro.sample("x_active", Bernoulli(p_x).expand([width, 1]))
    with y_axis:
        y_active = pyro.sample("y_active", Bernoulli(p_y).expand([height]))
    return x_active, y_active

def fun(observe, sample_fn):
    p_x = pyro.param("p_x", torch.tensor(0.1), constraint=constraints.unit_interval)
    p_y = pyro.param("p_y", torch.tensor(0.1), constraint=constraints.unit_interval)
    x_axis = pyro.plate('x_axis', width, dim=-2)
    y_axis = pyro.plate('y_axis', height, dim=-1)

    with pyro.plate("num_particles", num_particles, dim=-3):
        x_active, y_active = sample_fn(p_x, p_y, x_axis, y_axis)
        # Indices corresponding to "parallel" enumeration are appended
        # to the left of the "num_particles" plate dim.
        assert x_active.shape  == (2, 1, 1, 1)
        assert y_active.shape  == (2, 1, 1, 1, 1)
        p = 0.1 + 0.5 * x_active * y_active
        assert p.shape == (2, 2, 1, 1, 1)

        dense_pixels = p.new_zeros(broadcast_shape(p.shape, (width, height)))
        for x, y in sparse_pixels:
            dense_pixels[..., x, y] = 1
        assert dense_pixels.shape == (2, 2, 1, width, height)

        with x_axis, y_axis:
            if observe:
                pyro.sample("pixels", Bernoulli(p), obs=dense_pixels)

def test_model_with_sample_fn(sample_fn):
    def model():
        fun(observe=True, sample_fn=sample_fn)

    @config_enumerate
    def guide():
        fun(observe=False, sample_fn=sample_fn)

    test_model(model, guide, TraceEnum_ELBO(max_plate_nesting=3))

test_model_with_sample_fn(sample_pixel_locations_no_broadcasting)
test_model_with_sample_fn(sample_pixel_locations_full_broadcasting)
test_model_with_sample_fn(sample_pixel_locations_partial_broadcasting)

In the first sampling function, we had to do some manual book-keeping and expand the Bernoulli distribution's batch shape to account for the conditionally independent dimensions added by the pyro.plate contexts. In particular, note how sample_pixel_locations_no_broadcasting needs knowledge of num_particles, width and height and accesses these variables from the global scope, which is not ideal. The second and third sampling functions leave all or part of this expansion to pyro.plate. For automatic broadcasting within pyro.plate to work, note the following:

  • The second argument to pyro.plate, i.e. the optional size argument, needs to be provided for implicit broadcasting, so that it can infer the batch shape requirement for each of the sample sites.

  • The existing batch_shape of the sample site must be broadcastable with the size of the pyro.plate contexts. In our particular example, Bernoulli(p_x) has an empty batch shape which is universally broadcastable.

Note how simple it is to achieve parallelization via tensorized operations using pyro.plate! pyro.plate also helps with code modularization, because model components can be written agnostic of the plate contexts in which they may subsequently be embedded.