Note
In Pyro, distributions carry a notion of shape, and the shape of a stochastic function is not the same as the shape of the data! Crucially, you should think of all of your data as samples of one single random vector.
Tensor shapes in Pyro
This tutorial introduces Pyro's organization of tensor dimensions. Before starting, you should familiarize yourself with PyTorch broadcasting semantics; see PyTorch broadcasting semantics.
Parallel enumeration is the hardest material in this section; how exactly should it be understood?
After reading this tutorial, you will know how to:

- Set pyro.enable_validation(True) while learning a model or debugging.
- Broadcast tensors by aligning on the rightmost dimensions: torch.ones(3,4,5) + torch.ones(5).
- Remember that sample shape == batch_shape + event_shape.
- Remember that for a sample x, model.log_prob(x).shape == batch_shape.
- Use .expand() to draw a batch of samples, or rely on plate to expand automatically.
- Use my_dist.to_event(n) to declare the rightmost n dimensions as dependent, thereby reshaping a stochastic function.
- Use with pyro.plate('name', size): to declare a dimension as conditionally independent. Every dimension must be declared either dependent or conditionally independent.
- Try to support batching on the left. This lets Pyro auto-parallelize:
  - use negative indices like x.sum(-1) rather than x.sum(2);
  - use ellipsis notation like pixel = image[..., i, j];
  - if i, j are enumerated, use Vindex: pixel = Vindex(image)[..., i, j] (a sketch follows this list).
- When debugging, examine all shapes in a trace using Trace.format_shapes().
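As a quick sketch of the Vindex trick in the last bullet (the tensor x and indices i, j below are made up for illustration; Vindex lives in pyro.ops.indexing):

from pyro.ops.indexing import Vindex

x = torch.randn(3, 4, 5)
i = torch.tensor([[0], [2]])   # shape (2, 1), e.g. an enumeration dim
j = torch.tensor([0, 1, 4])    # shape (3,)
xij = Vindex(x)[i, :, j]       # generalizes x[i, :, j] to tensor-valued i, j
assert xij.shape == (2, 3, 4)  # broadcast_shape(i.shape, j.shape) + (x.size(1),)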
[1]:
import os, torch, pyro
from torch.distributions import constraints
from pyro.distributions import Bernoulli, Categorical, MultivariateNormal, Normal
from pyro.distributions.util import broadcast_shape
from pyro.infer import Trace_ELBO, TraceEnum_ELBO, config_enumerate
import pyro.poutine as poutine
from pyro.optim import Adam
smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.3.0')
pyro.enable_validation(True) # <---- This is always a good idea!
# We'll use this helper to check our models are correct.
def test_model(model, guide, loss):
    pyro.clear_param_store()
    loss.loss(model, guide)
Shapes of distributions
A Tensor in PyTorch has a single shape attribute, .shape, but a Pyro Distribution has two shape attributes: .batch_shape and .event_shape. Together they define the shape of a single sample:
x = d.sample()
assert x.shape == d.batch_shape + d.event_shape  # Tuple concatenation, not elementwise addition.
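Note that + above concatenates the two torch.Size objects (which behave like tuples); for example:

assert torch.Size([3]) + torch.Size([4]) == torch.Size([3, 4])  # not (7,)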
Mathematically, an event of a random variable \(X \sim p(x; \theta)\) is an outcome \(X = x\), so the shape of one event is event_shape. When we want a batch of random variables \(X\) of the same family but with different parameters \(\theta\) (a "batch r.v."), batch_shape indexes over those differently parameterized variables. A single sample of the whole batch r.v. therefore has shape batch_shape concatenated with event_shape.
Indices over .batch_shape denote conditional independence between random variables, whereas indices over .event_shape denote dependence among the components of a random vector (i.e. one draw from the distribution). Because each random-vector sample corresponds to a single probability value, the .log_prob() method produces one value per event of shape .event_shape, and so the output of .log_prob() has shape .batch_shape:
# One probability value per sample; batch_shape is the shape of the batch r.v.
assert d.log_prob(x).shape == d.batch_shape
Note that the Distribution.sample() method also takes a sample_shape argument, which indexes over iid draws of the batch r.v., so that
x2 = d.sample(sample_shape)
assert x2.shape == sample_shape + batch_shape + event_shape
In other words, every distribution (i.e. stochastic function) in Pyro has the two shapes batch_shape and event_shape, and sampling can prepend a further sample_shape.
For example, suppose each student in a class is described by a random vector of height and weight, so event_shape == (2,). A school has several classes, each with its own height-and-weight distribution parameters, so batch_shape == (number of classes,). Finally, drawing repeated iid observations from this school-wide population adds a sample_shape of draws.
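Continuing this example, here is a small sketch with made-up numbers: 2 measurements (height, weight) per student, 3 classes with different parameters, and 5 iid draws:

d = Normal(torch.zeros(3, 2), 1.).to_event(1)
assert d.batch_shape == (3,)           # 3 classes
assert d.event_shape == (2,)           # height, weight
x = d.sample(sample_shape=(5,))
assert x.shape == (5, 3, 2)            # sample_shape + batch_shape + event_shape
assert d.log_prob(x).shape == (5, 3)   # sample_shape + batch_shape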
In summary

      |      iid     | independent | dependent
------+--------------+-------------+------------
shape = sample_shape + batch_shape + event_shape
Simple examples
The simplest distribution shape is a single univariate distribution.
[5]:
d = Bernoulli(0.5)
assert d.batch_shape == ()
assert d.event_shape == ()
x = d.sample()
assert x.shape == ()
assert d.log_prob(x).shape == ()
Distributions can be batched by passing in batched parameters (yielding a batch r.v.).
[6]:
d = Bernoulli(0.5 * torch.ones(3,4))
assert d.batch_shape == (3, 4)
assert d.event_shape == ()
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3, 4)
Another way to batch distributions is via the .expand() method (the rightmost dimensions of the requested shape must match the distribution's existing shape).
[2]:
d = Bernoulli(torch.tensor([0.1, 0.2, 0.3, 0.4])).expand([3, 4])
assert d.batch_shape == (3, 4)
assert d.event_shape == ()
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3, 4)
Multivariate distributions have a nonempty .event_shape. For these distributions, the shapes of .sample() and .log_prob(x) differ:
[8]:
d = MultivariateNormal(torch.zeros(3), torch.eye(3, 3))
assert d.batch_shape == ()
assert d.event_shape == (3,)
x = d.sample()
assert x.shape == (3,) # == batch_shape + event_shape
assert d.log_prob(x).shape == () # == batch_shape
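Batching combines with a nonempty event shape as you would expect; for instance, a small sketch extending the example above:

d = MultivariateNormal(torch.zeros(3), torch.eye(3, 3)).expand([4])
assert d.batch_shape == (4,)
assert d.event_shape == (3,)
x = d.sample()
assert x.shape == (4, 3)            # batch_shape + event_shape
assert d.log_prob(x).shape == (4,)  # batch_shape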
A preview of the next section:
x = pyro.sample("x", dist.Normal(0, 1).expand([10]).to_event(1))

with pyro.plate("x_plate", 10):
    x = pyro.sample("x", dist.Normal(0, 1))
Reshaping event_shape
In Pyro you can turn a univariate distribution into a multivariate one (i.e. stack random variables into a random vector) with .to_event(n), where n is the number of dimensions, counting from the right, to declare as dependent (i.e. as components of a random vector). Thus to_event(0) leaves the distribution unchanged, while to_event(1) moves the last dimension of batch_shape into event_shape.
[24]:
d = Bernoulli(0.5 * torch.ones(3,4)).to_event(1)
assert d.batch_shape == (3,)
assert d.event_shape == (4,)
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3,)
When working with Pyro programs, keep in mind that samples have shape batch_shape + event_shape, whereas .log_prob(x) values have shape batch_shape. You will need to ensure that batch_shape is carefully controlled, either by trimming it down with .to_event(n) or by declaring dimensions as independent via pyro.plate.
In practice, it is always safe to assume dependence. In Pyro we often declare some dimensions as dependent even when they are in fact independent, e.g.
x = pyro.sample("x", dist.Normal(0, 1).expand([10]).to_event(1))
assert x.shape == (10,)
This is useful for two reasons: first (easy construction of multivariate distributions), it allows us to easily swap in a MultivariateNormal distribution later; second (simpler code), it avoids the need for a plate (see below), as in

with pyro.plate("x_plate", 10):
    x = pyro.sample("x", dist.Normal(0, 1))  # .expand([10]) is automatic
assert x.shape == (10,)
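Here is a sketch of the swap promised by the first reason: the .to_event(1) version and a genuinely correlated MultivariateNormal produce samples of identical shape, so downstream code need not change (this uses the names imported at the top of this tutorial):

x1 = pyro.sample("x1", Normal(0., 1.).expand([10]).to_event(1))
x2 = pyro.sample("x2", MultivariateNormal(torch.zeros(10), torch.eye(10)))
assert x1.shape == x2.shape == (10,)  # both: batch_shape == (), event_shape == (10,)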
The difference between these two versions is that the second version, with plate, informs Pyro that it can exploit conditional independence when estimating gradients, whereas in the first version Pyro must assume the dimensions are dependent (even though the normals are in fact conditionally independent). This is analogous to d-separation in graphical models: it is always safe to add edges and assume variables may be dependent (i.e. to widen the model class), but it is unsafe to assume independence when variables are actually dependent (i.e. to narrow the model class so that the true model lies outside of the class, as in mean field).
In practice, Pyro's SVI inference algorithm uses the reparameterization trick for Normal distributions, so both gradient estimators have the same performance.
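You can check whether a distribution supports reparameterized sampling via its has_rsample attribute (standard torch.distributions behavior):

assert Normal(0., 1.).has_rsample       # reparameterized gradients available
assert not Bernoulli(0.5).has_rsample   # discrete: enumerate or use score-function gradients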
A preview of the next section:
x_axis = pyro.plate("x_axis", 3, dim=-2)
y_axis = pyro.plate("y_axis", 2, dim=-3)
with x_axis:
    x = pyro.sample("x", Normal(0, 1))
with y_axis:
    y = pyro.sample("y", Normal(0, 1))
with x_axis, y_axis:
    xy = pyro.sample("xy", Normal(0, 1))
    z = pyro.sample("z", Normal(0, 1).expand([5]).to_event(1))
assert x.shape == (3, 1)        # batch_shape == (3,1)   event_shape == ()
assert y.shape == (2, 1, 1)     # batch_shape == (2,1,1) event_shape == ()
assert xy.shape == (2, 3, 1)    # batch_shape == (2,3,1) event_shape == ()
assert z.shape == (2, 3, 1, 5)  # batch_shape == (2,3,1) event_shape == (5,)
Reshaping batch_shape
Declare conditional independence with plate to reshape batch_shape.
Pyro models can use the context manager pyro.plate to declare that particular batch dimensions are independent. Inference algorithms can then exploit this independence, e.g. to construct lower-variance gradient estimators or to enumerate in linear space rather than exponential space. An example of an independent dimension is the index over data in a minibatch: each datum should be independent of all others.
The simplest way to declare independence is:
with pyro.plate("my_plate"):
# within this context, batch dimension -1 is independent
We recommend always providing a size argument to aid in debugging shapes:
with pyro.plate("my_plate", len(my_data)):
    # within this context, batch dimension -1 is independent
Starting with Pyro 0.2 you can additionally nest plates, e.g. if you have per-pixel independence:
with pyro.plate("x_axis", 320):
    # within this context, batch dimension -1 is independent
    with pyro.plate("y_axis", 200):
        # within this context, batch dimensions -2 and -1 are independent
Note that we always count from the right, using negative indices like -2, -1.
Finally, if you want to mix and match plates, e.g. for noise that depends only on x, noise that depends only on y, and noise that depends on both, you can declare multiple plates and use them as reusable context managers. In this case Pyro cannot automatically allocate a dimension, so you need to provide a dim argument (again counting from the right):
x_axis = pyro.plate("x_axis", 3, dim=-2)
y_axis = pyro.plate("y_axis", 2, dim=-3)
with x_axis:
    # within this context, batch dimension -2 is independent
with y_axis:
    # within this context, batch dimension -3 is independent
with x_axis, y_axis:
    # within this context, batch dimensions -3 and -2 are independent
Let's take a closer look at batch sizes within plates.
[73]:
def model1():
    a = pyro.sample("a", Normal(0, 1))
    b = pyro.sample("b", Normal(torch.zeros(2), 1).to_event(1))
    with pyro.plate("c_plate", 2):
        c = pyro.sample("c", Normal(torch.zeros(2), 1))
    with pyro.plate("d_plate", 3):
        d = pyro.sample("d", Normal(torch.zeros(3,4,5), 1).to_event(2))
    assert a.shape == ()       # batch_shape == ()     event_shape == ()
    assert b.shape == (2,)     # batch_shape == ()     event_shape == (2,)
    assert c.shape == (2,)     # batch_shape == (2,)   event_shape == ()
    assert d.shape == (3,4,5)  # batch_shape == (3,)   event_shape == (4,5)

    x_axis = pyro.plate("x_axis", 3, dim=-2)
    y_axis = pyro.plate("y_axis", 2, dim=-3)
    with x_axis:
        x = pyro.sample("x", Normal(0, 1))
    with y_axis:
        y = pyro.sample("y", Normal(0, 1))
    with x_axis, y_axis:
        xy = pyro.sample("xy", Normal(0, 1))
        z = pyro.sample("z", Normal(0, 1).expand([5]).to_event(1))
    assert x.shape == (3, 1)        # batch_shape == (3,1)   event_shape == ()
    assert y.shape == (2, 1, 1)     # batch_shape == (2,1,1) event_shape == ()
    assert xy.shape == (2, 3, 1)    # batch_shape == (2,3,1) event_shape == ()
    assert z.shape == (2, 3, 1, 5)  # batch_shape == (2,3,1) event_shape == (5,)

test_model(model1, model1, Trace_ELBO())
It is helpful to visualize the .shapes of each sample site by aligning them at the boundary between batch_shape and event_shape: dimensions to the right will be summed out in .log_prob() and dimensions to the left will remain.
batch dims | event dims
-----------+-----------
| a = sample("a", Normal(0, 1))
|2 b = sample("b", Normal(zeros(2), 1)
| .to_event(1))
| with plate("c", 2):
2| c = sample("c", Normal(zeros(2), 1))
| with plate("d", 3):
3|4 5 d = sample("d", Normal(zeros(3,4,5), 1)
| .to_event(2))
|
| x_axis = plate("x", 3, dim=-2)
| y_axis = plate("y", 2, dim=-3)
| with x_axis:
3 1| x = sample("x", Normal(0, 1))
| with y_axis:
2 1 1| y = sample("y", Normal(0, 1))
| with x_axis, y_axis:
2 3 1| xy = sample("xy", Normal(0, 1))
2 3 1|5 z = sample("z", Normal(0, 1).expand([5])
| .to_event(1))
To examine the shapes of sample sites in a program automatically, you can trace the program and use Trace.format_shapes() to print three shapes for each sample site:

- the distribution shape (both site["fn"].batch_shape and site["fn"].event_shape),
- the value shape (site["value"].shape),
- and, if log probabilities have been computed, the log_prob shape (site["log_prob"].shape):
[74]:
trace = poutine.trace(model1).get_trace()
trace.compute_log_prob() # optional, but allows printing of log_prob shapes
print(trace.format_shapes())
Trace Shapes:
Param Sites:
Sample Sites:
a dist |
value |
log_prob |
b dist | 2
value | 2
log_prob |
c_plate dist |
value 2 |
log_prob |
c dist 2 |
value 2 |
log_prob 2 |
d_plate dist |
value 3 |
log_prob |
d dist 3 | 4 5
value 3 | 4 5
log_prob 3 |
x_axis dist |
value 3 |
log_prob |
y_axis dist |
value 2 |
log_prob |
x dist 3 1 |
value 3 1 |
log_prob 3 1 |
y dist 2 1 1 |
value 2 1 1 |
log_prob 2 1 1 |
xy dist 2 3 1 |
value 2 3 1 |
log_prob 2 3 1 |
z dist 2 3 1 | 5
value 2 3 1 | 5
log_prob 2 3 1 |
Subsampling tensors inside a plate
One of the main uses of plate is to subsample data. Because data within a plate are conditionally independent, the expected value of the loss computed on a subsample does not change. To subsample data, you need to inform Pyro of both the original data size and the subsample size; Pyro will then choose a random subset of the data and yield the set of indices.
[3]:
data = torch.arange(100.)

def model2():
    mean = pyro.param("mean", torch.zeros(len(data)))
    with pyro.plate("data", len(data), subsample_size=10) as ind:
        assert len(ind) == 10    # ind is a LongTensor that indexes the subsample.
        batch = data[ind]        # Select a minibatch of data.
        mean_batch = mean[ind]   # Take care to select the relevant per-datum parameters.
        # Do stuff with batch:
        x = pyro.sample("x", Normal(mean_batch, 1), obs=batch)
        assert len(x) == 10

test_model(model2, guide=lambda: None, loss=Trace_ELBO())
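As a sanity check (a sketch; the scale field is an internal detail of Pyro's trace data structure), sites inside the subsampled plate carry a scale factor of size / subsample_size, which is what keeps the subsampled loss an unbiased estimate:

trace = poutine.trace(model2).get_trace()
assert trace.nodes["x"]["scale"] == len(data) / 10  # == 10.0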
Broadcasting and parallel computation
Broadcasting to allow parallel enumeration
Pyro 0.2 introduced the ability to enumerate discrete latent variables in parallel. This can significantly reduce the variance of gradient estimators when learning a posterior via SVI.
To use parallel enumeration, Pyro needs to allocate tensor dimensions that it can use for enumeration. To avoid conflicting with other dimensions that we want to use for plates, we need to declare a budget of the maximum number of tensor dimensions we'll use. This budget is called max_plate_nesting and is an argument to SVI (the argument is simply passed through to TraceEnum_ELBO). Usually Pyro can determine this budget on its own (it runs the (model, guide) pair once and records what happens), but in the case of dynamic model structure you may need to declare max_plate_nesting manually.
To understand max_plate_nesting
and how Pyro allocates dimensions for enumeration, let’s revisit model1()
from above. This time we’ll map out three types of dimensions: enumeration dimensions on the left (Pyro takes control of these), batch dimensions in the middle, and event dimensions on the right.
max_plate_nesting = 3
|<--->|
enumeration|batch|event
-----------+-----+-----
|. . .| a = sample("a", Normal(0, 1))
|. . .|2 b = sample("b", Normal(zeros(2), 1)
| | .to_event(1))
| | with plate("c", 2):
|. . 2| c = sample("c", Normal(zeros(2), 1))
| | with plate("d", 3):
|. . 3|4 5 d = sample("d", Normal(zeros(3,4,5), 1)
| | .to_event(2))
| |
| | x_axis = plate("x", 3, dim=-2)
| | y_axis = plate("y", 2, dim=-3)
| | with x_axis:
|. 3 1| x = sample("x", Normal(0, 1))
| | with y_axis:
|2 1 1| y = sample("y", Normal(0, 1))
| | with x_axis, y_axis:
|2 3 1| xy = sample("xy", Normal(0, 1))
|2 3 1|5 z = sample("z", Normal(0, 1).expand([5])
| | .to_event(1))
Note that it is safe to overprovision max_plate_nesting=4
but we cannot underprovision max_plate_nesting=2
(or Pyro will error). Let’s see how this works in practice.
[4]:
@config_enumerate
def model3():
    p = pyro.param("p", torch.arange(6.) / 6)
    locs = pyro.param("locs", torch.tensor([-1., 1.]))
    a = pyro.sample("a", Categorical(torch.ones(6) / 6))
    b = pyro.sample("b", Bernoulli(p[a]))  # Note this depends on a.
    with pyro.plate("c_plate", 4):
        c = pyro.sample("c", Bernoulli(0.3))
        with pyro.plate("d_plate", 5):
            d = pyro.sample("d", Bernoulli(0.4))
            e_loc = locs[d.long()].unsqueeze(-1)
            e_scale = torch.arange(1., 8.)
            e = pyro.sample("e", Normal(e_loc, e_scale)
                            .to_event(1))  # Note this depends on d.

    #                   enumerated|batch|event dims
    assert a.shape == (         6, 1, 1   )  # Six enumerated values of the Categorical.
    assert b.shape == (      2, 1, 1, 1   )  # Two enumerated Bernoullis, unexpanded.
    assert c.shape == (   2, 1, 1, 1, 1   )  # Only two Bernoullis, unexpanded.
    assert d.shape == (2, 1, 1, 1, 1, 1   )  # Only two Bernoullis, unexpanded.

    assert e.shape == (2, 1, 1, 1, 5, 4, 7)  # This is sampled and depends on d.
    assert e_loc.shape   == (2, 1, 1, 1, 1, 1, 1,)
    assert e_scale.shape == (                  7,)

test_model(model3, model3, TraceEnum_ELBO(max_plate_nesting=2))
Let’s take a closer look at those dimensions. First note that Pyro allocates enumeration dims starting from the right at max_plate_nesting
: Pyro allocates dim -3 to enumerate a
, then dim -4 to enumerate b
, then dim -5 to enumerate c
, and finally dim -6 to enumerate d
. Next note that samples only have extent (size > 1) in the new enumeration dimension. This helps keep tensors small and computation cheap. (Note that the log_prob
shape will be broadcast up to contain both the enumeration shape and the batch shape, so e.g. trace.nodes['d']['log_prob'].shape == (2, 1, 1, 1, 5, 4).)
We can draw a similar map of the tensor dimensions:
max_plate_nesting = 2
|<->|
enumeration batch event
------------|---|-----
6|1 1| a = pyro.sample("a", Categorical(torch.ones(6) / 6))
2 1|1 1| b = pyro.sample("b", Bernoulli(p[a]))
| | with pyro.plate("c_plate", 4):
2 1 1|1 1| c = pyro.sample("c", Bernoulli(0.3))
| | with pyro.plate("d_plate", 5):
2 1 1 1|1 1| d = pyro.sample("d", Bernoulli(0.4))
2 1 1 1|1 1|1 e_loc = locs[d.long()].unsqueeze(-1)
| |7 e_scale = torch.arange(1., 8.)
2 1 1 1|5 4|7 e = pyro.sample("e", Normal(e_loc, e_scale)
| | .to_event(1))
To automatically examine this model with enumeration semantics, we can create an enumerated trace and then use Trace.format_shapes():
[11]:
trace = poutine.trace(poutine.enum(model3, first_available_dim=-3)).get_trace()
trace.compute_log_prob() # optional, but allows printing of log_prob shapes
print(trace.format_shapes())
Trace Shapes:
Param Sites:
p 6
locs 2
Sample Sites:
a dist |
value 6 1 1 |
log_prob 6 1 1 |
b dist 6 1 1 |
value 2 1 1 1 |
log_prob 2 6 1 1 |
c_plate dist |
value 4 |
log_prob |
c dist 4 |
value 2 1 1 1 1 |
log_prob 2 1 1 1 4 |
d_plate dist |
value 5 |
log_prob |
d dist 5 4 |
value 2 1 1 1 1 1 |
log_prob 2 1 1 1 5 4 |
e dist 2 1 1 1 5 4 | 7
value 2 1 1 1 5 4 | 7
log_prob 2 1 1 1 5 4 |
Writing parallelizable code
It can be tricky to write Pyro models that correctly handle parallelized sample sites. Two tricks help: broadcasting and ellipsis slicing. Let's look at a contrived model to see how these work in practice. Our goal is to write a model that works both with and without enumeration.
[79]:
width, height = 8, 10
sparse_pixels = torch.LongTensor([[3, 2], [3, 5], [3, 9], [7, 1]])
enumerated = None  # set to either True or False below

def fun(observe):
    p_x = pyro.param("p_x", torch.tensor(0.1), constraint=constraints.unit_interval)
    p_y = pyro.param("p_y", torch.tensor(0.1), constraint=constraints.unit_interval)
    x_axis = pyro.plate('x_axis', width, dim=-2)
    y_axis = pyro.plate('y_axis', height, dim=-1)

    # Note that the shapes of these sites depend on whether Pyro is enumerating.
    with x_axis:
        x_active = pyro.sample("x_active", Bernoulli(p_x))
    with y_axis:
        y_active = pyro.sample("y_active", Bernoulli(p_y))
    if enumerated:
        assert x_active.shape == (2, 1, 1)
        assert y_active.shape == (2, 1, 1, 1)
    else:
        assert x_active.shape == (width, 1)
        assert y_active.shape == (height,)

    # The first trick is to broadcast. This works with or without enumeration.
    p = 0.1 + 0.5 * x_active * y_active
    if enumerated:
        assert p.shape == (2, 2, 1, 1)
    else:
        assert p.shape == (width, height)
    dense_pixels = p.new_zeros(broadcast_shape(p.shape, (width, height)))

    # The second trick is to index using ellipsis slicing.
    # This allows Pyro to add arbitrary dimensions on the left.
    for x, y in sparse_pixels:
        dense_pixels[..., x, y] = 1
    if enumerated:
        assert dense_pixels.shape == (2, 2, width, height)
    else:
        assert dense_pixels.shape == (width, height)

    with x_axis, y_axis:
        if observe:
            pyro.sample("pixels", Bernoulli(p), obs=dense_pixels)

def model4():
    fun(observe=True)

def guide4():
    fun(observe=False)

# Test without enumeration.
enumerated = False
test_model(model4, guide4, Trace_ELBO())

# Test with enumeration.
enumerated = True
test_model(model4, config_enumerate(guide4, "parallel"),
           TraceEnum_ELBO(max_plate_nesting=2))
Automatic broadcasting inside pyro.plate
In all of our model/guide definitions we have relied on pyro.plate to automatically expand sample shapes to satisfy the batch-shape constraints at each pyro.sample statement. Note that this broadcasting is equivalent to hand-annotated .expand() statements.
We will demonstrate this using model4 from the previous section. Note the following changes to the earlier code:
- For the purpose of this example, we will only consider "parallel" enumeration, but broadcasting should work as expected without enumeration or with "sequential" enumeration.
- We have separated out the sampling function, which returns the tensors corresponding to the active pixels. Modularizing the model code into components is a common practice and helps with the maintainability of large models.
- We would also like to use the pyro.plate construct to parallelize the ELBO estimator over num_particles. This is done by wrapping the contents of the model/guide inside an outermost pyro.plate context.
[78]:
num_particles = 100  # Number of samples for the ELBO estimator
width = 8
height = 10
sparse_pixels = torch.LongTensor([[3, 2], [3, 5], [3, 9], [7, 1]])

def sample_pixel_locations_no_broadcasting(p_x, p_y, x_axis, y_axis):
    with x_axis:
        x_active = pyro.sample("x_active", Bernoulli(p_x).expand([num_particles, width, 1]))
    with y_axis:
        y_active = pyro.sample("y_active", Bernoulli(p_y).expand([num_particles, 1, height]))
    return x_active, y_active

def sample_pixel_locations_full_broadcasting(p_x, p_y, x_axis, y_axis):
    with x_axis:
        x_active = pyro.sample("x_active", Bernoulli(p_x))
    with y_axis:
        y_active = pyro.sample("y_active", Bernoulli(p_y))
    return x_active, y_active

def sample_pixel_locations_partial_broadcasting(p_x, p_y, x_axis, y_axis):
    with x_axis:
        x_active = pyro.sample("x_active", Bernoulli(p_x).expand([width, 1]))
    with y_axis:
        y_active = pyro.sample("y_active", Bernoulli(p_y).expand([height]))
    return x_active, y_active

def fun(observe, sample_fn):
    p_x = pyro.param("p_x", torch.tensor(0.1), constraint=constraints.unit_interval)
    p_y = pyro.param("p_y", torch.tensor(0.1), constraint=constraints.unit_interval)
    x_axis = pyro.plate('x_axis', width, dim=-2)
    y_axis = pyro.plate('y_axis', height, dim=-1)

    with pyro.plate("num_particles", 100, dim=-3):
        x_active, y_active = sample_fn(p_x, p_y, x_axis, y_axis)
        # Indices corresponding to "parallel" enumeration are appended
        # to the left of the "num_particles" plate dim.
        assert x_active.shape == (2, 1, 1, 1)
        assert y_active.shape == (2, 1, 1, 1, 1)
        p = 0.1 + 0.5 * x_active * y_active
        assert p.shape == (2, 2, 1, 1, 1)

        dense_pixels = p.new_zeros(broadcast_shape(p.shape, (width, height)))
        for x, y in sparse_pixels:
            dense_pixels[..., x, y] = 1
        assert dense_pixels.shape == (2, 2, 1, width, height)

        with x_axis, y_axis:
            if observe:
                pyro.sample("pixels", Bernoulli(p), obs=dense_pixels)

def test_model_with_sample_fn(sample_fn):
    def model():
        fun(observe=True, sample_fn=sample_fn)

    @config_enumerate
    def guide():
        fun(observe=False, sample_fn=sample_fn)

    test_model(model, guide, TraceEnum_ELBO(max_plate_nesting=3))

test_model_with_sample_fn(sample_pixel_locations_no_broadcasting)
test_model_with_sample_fn(sample_pixel_locations_full_broadcasting)
test_model_with_sample_fn(sample_pixel_locations_partial_broadcasting)
In the first sampling function, we had to do some manual bookkeeping and expand the Bernoulli distribution's batch shape to account for the conditionally independent dimensions added by the pyro.plate contexts. In particular, note how sample_pixel_locations needs knowledge of num_particles, width and height, and accesses these variables from the global scope, which is not ideal.
Two points are worth noting for implicit broadcasting:

- The second argument to pyro.plate, i.e. the optional size argument, needs to be provided for implicit broadcasting, so that Pyro can infer the batch shape requirement for each of the sample sites.
- The existing batch_shape of the sample site must be broadcastable with the size of the pyro.plate contexts. In our particular example, Bernoulli(p_x) has an empty batch shape, which is universally broadcastable (see the sketch after this list).
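A minimal sketch of the second point (the plate name and size here are hypothetical): inside a plate with an explicit size, a distribution with empty batch_shape is expanded automatically, exactly as if we had written .expand() by hand:

def tiny_model():
    with pyro.plate("data", 10):
        x = pyro.sample("x", Bernoulli(0.5))  # empty batch_shape broadcasts to (10,)
        assert x.shape == (10,)

tiny_model()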
Note how simple it is to achieve parallelization via tensorized operations using pyro.plate! pyro.plate also helps with code modularization, because model components can be written agnostic of the plate contexts in which they may subsequently be embedded.