Poutine: A Guide to Programming with Effect Handlers in Pyro

Prerequisites

This tutorial is a guide to the API details of Pyro’s effect handling library, Poutine. We recommend readers first orient themselves with the simplified minipyro.py, which contains a minimal, readable implementation of Pyro’s runtime and of the effect handler abstraction described here (Poutine can be viewed as a generalization of it). Pyro’s effect handler library is more general than minipyro’s but also contains more layers of indirection; it helps to read them side by side.

Messenger is the core data structure of the mini-pyro library; trace, replay, block, seed, and PlateMessenger are all subclasses of Messenger. Messengers are stateful context manager objects that are placed on a global stack and send messages (hence the name) up and down the stack at each effectful operation, like a pyro.sample call.

  • trace: trace records the inputs and outputs of any primitive site it encloses, and returns a dictionary containing that data to the user.

  • replay: an effect handler for setting the value at a sample site.

  • block: allows the selective application of effect handlers to different parts of a model. Sites hidden by block will only have the handlers below block on the PYRO_STACK applied, allowing inference or other effectful computations to be nested inside models.

  • seed: is used to fix the RNG state when calling a model.

  • PlateMessenger: This limited implementation of PlateMessenger only implements broadcasting.

A question to keep in mind: how does the Messenger data structure help Pyro implement its core functionality?

[2]:
import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.poutine.runtime import effectful
pyro.set_rng_seed(101)

Introduction

Problem background

Inference in probabilistic programming involves manipulating or transforming probabilistic programs written as generative models. For example, nearly all approximate inference algorithms require computing the unnormalized joint probability of the latent and observed variables under a generative model. Consider the following example model from the introductory inference tutorial:

[2]:
mu = 8.5
def scale(mu):
    weight = pyro.sample("weight", dist.Normal(mu, 1.0))
    return pyro.sample("measurement", dist.Normal(weight, 0.75))

This model defines a joint distribution over "weight" and "measurement":

\[{\sf weight} \sim {\sf Normal}(\mu, 1)\]
\[{\sf measurement} \mid {\sf weight} \sim {\sf Normal}({\sf weight}, 0.75)\]

If we had access to the inputs and outputs of every pyro.sample site, we could compute their log-joint:

logp = dist.Normal(mu, 1.0).log_prob(weight).sum() + dist.Normal(weight, 0.75).log_prob(measurement).sum()
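To make this concrete, here is a quick check (a hedged sketch, plugging in the weight and measurement values we will condition on later in this tutorial):

weight, measurement = torch.tensor(8.23), torch.tensor(9.5)
logp = dist.Normal(mu, 1.0).log_prob(weight).sum() + dist.Normal(weight, 0.75).log_prob(measurement).sum()
print(logp)  # tensor(-3.0203), the same value trace.log_prob_sum() will return below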

However, the scale function above does not expose those intermediate distribution objects, and rewriting it to return them would be intrusive and would violate the separation of model and inference algorithm that a probabilistic programming language like Pyro is designed to enforce.

To resolve this conflict and ease the development of inference algorithms, Pyro provides Poutine, a library of effect handlers, or composable building blocks for examining and modifying the behavior of Pyro programs. Most of Pyro's internals are implemented on top of Poutine.

A preview of the next section:

def make_log_joint(model):
    def _log_joint(cond_data, *args, **kwargs):
        conditioned_model = poutine.condition(model, data=cond_data)
        trace = poutine.trace(conditioned_model).get_trace(*args, **kwargs)
        return trace.log_prob_sum()
    return _log_joint

def make_log_joint_2(model):
    def _log_joint(cond_data, *args, **kwargs):
        with TraceMessenger() as tracer:
            with ConditionMessenger(data=cond_data):
                model(*args, **kwargs)

        trace = tracer.trace
        logp = 0.
        for name, node in trace.nodes.items():
            if node["type"] == "sample":
                if node["is_observed"]:
                    assert node["value"] is cond_data[name]
                logp = logp + node["fn"].log_prob(node["value"]).sum()
        return logp
    return _log_joint

A first look at Poutine: Pyro’s library of algorithmic building blocks

Effect handlers, a common abstraction in the programming languages community, give nonstandard interpretations or side effects to the behavior of particular statements in a programming language, like pyro.sample and pyro.param. For background on effect handlers in programming language research, see the last section of this tutorial.

Rather than going through more definitions, let's look at a first example: we compose two existing effect handlers, poutine.condition (which sets the output values of pyro.sample statements) and poutine.trace (which records the inputs, distributions, and outputs of pyro.sample statements), into a new effect handler that computes the log-joint:

[3]:
def make_log_joint(model):
    def _log_joint(cond_data, *args, **kwargs):
        conditioned_model = poutine.condition(model, data=cond_data)
        trace = poutine.trace(conditioned_model).get_trace(*args, **kwargs)
        return trace.log_prob_sum()
    return _log_joint

scale_log_joint = make_log_joint(scale)
print(scale_log_joint({"measurement": 9.5, "weight": 8.23}, 8.5))
tensor(-3.0203)

This snippet is short, but still somewhat opaque: poutine.condition, poutine.trace, and trace.log_prob_sum remain black boxes. Let's remove a layer of boilerplate from poutine.condition and poutine.trace and explicitly implement what trace.log_prob_sum is doing:

[11]:
from pyro.poutine.trace_messenger import TraceMessenger
from pyro.poutine.condition_messenger import ConditionMessenger

def make_log_joint_2(model):
    def _log_joint(cond_data, *args, **kwargs):
#         conditioned_model = poutine.condition(model, data=cond_data)
#         trace = poutine.trace(conditioned_model).get_trace(*args, **kwargs)
#         return trace.log_prob_sum()
        with TraceMessenger() as tracer:
            with ConditionMessenger(data=cond_data):
                model(*args, **kwargs)

        trace = tracer.trace
        logp = 0.
        for name, node in trace.nodes.items():
            print('name:', name, ', value of node:', node['value'])
            if node["type"] == "sample":
                if node["is_observed"]:
                    assert node["value"] is cond_data[name]
                logp = logp + node["fn"].log_prob(node["value"]).sum()
        return logp
    return _log_joint

scale_log_joint = make_log_joint_2(scale)
print(scale_log_joint({"measurement": 9.5, "weight": 8.23}, 8.5)) # mu=8.5 is the model's input argument
name: weight , value of node: 8.23
name: measurement , value of node: 9.5
tensor(-3.0203)

This makes the mechanics of computing the log-joint a bit clearer:

  • We can see that poutine.trace and poutine.condition are wrappers for context managers that presumably communicate with the model through something inside pyro.sample.

  • We can also see that poutine.trace produces a data structure (a Trace) containing a dictionary whose keys are sample site names and whose values are dictionaries containing the distribution ("fn") and output ("value") at each site, and that the output values at each site are exactly the values specified in data.

  • Finally, TraceMessenger and ConditionMessenger are Pyro effect handlers, or Messengers: stateful context manager objects that are placed on a global stack and send messages (hence the name) up and down the stack at each effectful operation, like a pyro.sample call. A Messenger is placed at the bottom of the stack when its __enter__ method is called, i.e. when it is used in a “with” statement.

We will look at this process in more detail later in this tutorial. For the base Messenger class in mini-pyro, see pyro.contrib.minipyro.

class Messenger:
    def __init__(self, fn=None):
        self.fn = fn

    # Effect handlers push themselves onto the PYRO_STACK.
    # Handlers earlier in the PYRO_STACK are applied first.
    def __enter__(self):
        PYRO_STACK.append(self)

    def __exit__(self, *args, **kwargs):
        assert PYRO_STACK[-1] is self
        PYRO_STACK.pop()

    def process_message(self, msg):
        pass

    def postprocess_message(self, msg):
        pass

    def __call__(self, *args, **kwargs):
        with self:
            return self.fn(*args, **kwargs)

Building new effect handlers with Messenger

Although it is easiest to build new effect handlers by composing the existing ones in pyro.poutine, implementing a new effect handler as a subclass of pyro.poutine.messenger.Messenger is quite straightforward as well. Before diving into the API details, let's look at another example: a version of the log-joint computation that performs the sum as the model executes. We will then review what each part of the example is actually doing.

[5]:
class LogJointMessenger(poutine.messenger.Messenger):

    def __init__(self, cond_data):
        self.data = cond_data

    # __call__ is syntactic sugar for using a Messenger as a higher-order function.
    # Messenger already defines __call__, but we re-define it here
    # for exposition and to change the return value:
    def __call__(self, fn):
        def _fn(*args, **kwargs):
            with self:
                fn(*args, **kwargs)
                return self.logp.clone()
        return _fn

    def __enter__(self):
        self.logp = torch.tensor(0.)
        # All Messenger subclasses must call the base Messenger.__enter__()
        # in their __enter__ methods
        return super().__enter__()

    # __exit__ takes the same arguments in all Python context managers
    def __exit__(self, exc_type, exc_value, traceback):
        self.logp = torch.tensor(0.)
        # All Messenger subclasses must call the base Messenger.__exit__ method
        # in their __exit__ methods.
        return super().__exit__(exc_type, exc_value, traceback)

    # _pyro_sample is called once at each pyro.sample site.
    # It takes a dictionary msg containing the name, distribution,
    # observation or sample value, and other metadata from the sample site.
    def _pyro_sample(self, msg):
        # Any unobserved random variables will trigger this assertion.
        # In the next section, we'll learn how to also handle sampled values.
        assert msg["name"] in self.data
        msg["value"] = self.data[msg["name"]]
        # Since we've observed a value for this site, we set the "is_observed" flag to True
        # This tells any other Messengers not to overwrite msg["value"] with a sample.
        msg["is_observed"] = True
        self.logp = self.logp + (msg["scale"] * msg["fn"].log_prob(msg["value"])).sum()

with LogJointMessenger(cond_data={"measurement": 9.5, "weight": 8.23}) as m:
    scale(8.5)
    print(m.logp.clone())

scale_log_joint = LogJointMessenger(cond_data={"measurement": 9.5, "weight": 8.23})(scale)
print(scale_log_joint(8.5))
tensor(-3.0203)
tensor(-3.0203)

A convenient bit of boilerplate that allows LogJointMessenger to be used as a context manager, decorator, or higher-order function is the following. Most of the existing effect handlers in pyro.poutine, including poutine.trace and poutine.condition which we used earlier, are Messengers wrapped this way in pyro.poutine.handlers.

[6]:
def log_joint(model=None, cond_data=None):
    msngr = LogJointMessenger(cond_data=cond_data)
    return msngr(model) if model is not None else msngr

scale_log_joint = log_joint(scale, cond_data={"measurement": 9.5, "weight": 8.23})
print(scale_log_joint(8.5))
tensor(-3.0203)
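The same wrapper also works as a decorator or a context manager. For example, as a decorator (a hedged sketch; scale_decorated is a hypothetical copy of scale defined only for illustration):

@log_joint(cond_data={"measurement": 9.5, "weight": 8.23})
def scale_decorated(mu):
    weight = pyro.sample("weight", dist.Normal(mu, 1.0))
    return pyro.sample("measurement", dist.Normal(weight, 0.75))

print(scale_decorated(8.5))  # tensor(-3.0203), as before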

The Messenger methods in detail

Our implementation of LogJointMessenger has three main methods: __enter__, __exit__, and _pyro_sample.

__enter__ and __exit__ are special methods shared by all Python context managers. When implementing new Messenger classes, if we override __enter__ and __exit__, we always need to call the base Messenger’s __enter__ and __exit__ methods for the new Messenger to be applied correctly.

The last method, LogJointMessenger._pyro_sample, is called once at each sample site. It reads and modifies a message, which is a dictionary containing the sample site’s name, distribution, sampled or observed value, and other metadata. We will examine the contents of a message in more detail in the next section.

Instead of _pyro_sample, a generic Messenger actually contains two methods for manipulating messages, each called once per operation where side effects are performed:

  1. _process_message modifies a message and sends the result to the Messenger just above on the stack

  2. _postprocess_message modifies a message and sends the result to the next Messenger down on the stack. It is always called after all active Messengers have had their _process_message method applied to the message.

Although custom Messengers can override _process_message and _postprocess_message, it’s convenient to avoid requiring all effect handlers to be aware of all possible effectful operation types. For this reason, by default Messenger._process_message will use msg["type"] to dispatch to a corresponding method Messenger._pyro_<type>, e.g. Messenger._pyro_sample as in LogJointMessenger. Just as exception handling code ignores unhandled exception types, this allows Messengers to simply forward operations they don’t know how to handle up to the next Messenger in the stack:

class Messenger:
    ...
    def _process_message(self, msg):
        method_name = "_pyro_{}".format(msg["type"])  # e.g. _pyro_sample when msg["type"] == "sample"
        if hasattr(self, method_name):
            getattr(self, method_name)(msg)
    ...
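As a small illustration of this dispatch, consider a hypothetical handler (not part of Pyro) that only defines _pyro_param; sample messages simply pass through it untouched:

class ParamCounter(poutine.messenger.Messenger):
    def __enter__(self):
        self.count = 0
        return super().__enter__()

    # Only messages with msg["type"] == "param" are dispatched here;
    # "sample" messages are forwarded to the rest of the stack unchanged.
    def _pyro_param(self, msg):
        self.count += 1

with ParamCounter() as counter:
    scale(8.5)        # scale contains only sample sites
print(counter.count)  # 0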

The global Messenger stack

For an end-to-end implementation of the machinery in this section, see pyro.contrib.minipyro.

The order in which Messengers are applied to an operation like a pyro.sample statement is determined by the order in which their __enter__ methods are called. Messenger.__enter__ appends a Messenger to the end (the bottom) of the global handler stack:

# On entry, a Messenger appends itself to the end of the stack; on exit, it pops itself off.
class Messenger:
    ...
    # __enter__ pushes a Messenger onto the stack
    def __enter__(self):
        ...
        _PYRO_STACK.append(self)
        ...

    # __exit__ removes a Messenger from the stack
    def __exit__(self, ...):
        ...
        assert _PYRO_STACK[-1] is self
        _PYRO_STACK.pop()
        ...

pyro.poutine.runtime.apply_stack then traverses the stack twice at each operation, first from bottom to top to apply each _process_message and then from top to bottom to apply each _postprocess_message:

# Apply _process_message from bottom to top, then _postprocess_message from top to bottom
def apply_stack(msg):  # simplified
    for handler in reversed(_PYRO_STACK):
        handler._process_message(msg)
    ...
    default_process_message(msg)
    ...
    for handler in _PYRO_STACK:
        handler._postprocess_message(msg)
    ...
    return msg
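To see the two traversal orders in action, here is a hedged sketch with a hypothetical handler that just prints whenever it touches a message:

class NoisyMessenger(poutine.messenger.Messenger):
    def __init__(self, label):
        self.label = label
        super().__init__()

    def _process_message(self, msg):
        print("process", self.label, msg["name"])

    def _postprocess_message(self, msg):
        print("postprocess", self.label, msg["name"])

with NoisyMessenger("outer"), NoisyMessenger("inner"):
    scale(8.5)
# For each sample site we expect, in order:
#   process inner, process outer, postprocess outer, postprocess inner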

Rewriting the LogJointMessenger example

The second method _postprocess_message is necessary because some effects can only be applied after all other effect handlers have had a chance to update the message once. In the case of LogJointMessenger, other effects, like enumeration, may modify a sample site’s value or distribution (msg["value"] or msg["fn"]), so we move the log-probability computation to a new method, _pyro_post_sample, which is called by _postprocess_message (via a dispatch mechanism like the one used by _process_message) at each sample site after all active handlers’ _pyro_sample methods have been applied:

[7]:
class LogJointMessenger2(poutine.messenger.Messenger):

    def __init__(self, cond_data):
        self.data = cond_data

    def __call__(self, fn):
        def _fn(*args, **kwargs):
            with self:
                fn(*args, **kwargs)
                return self.logp.clone()
        return _fn

    def __enter__(self):
        self.logp = torch.tensor(0.)
        return super().__enter__()

    def __exit__(self, exc_type, exc_value, traceback):
        self.logp = torch.tensor(0.)
        return super().__exit__(exc_type, exc_value, traceback)

    def _pyro_sample(self, msg):
        if msg["name"] in self.data:
            msg["value"] = self.data[msg["name"]]
            msg["done"] = True

    def _pyro_post_sample(self, msg):
        assert msg["done"]  # the "done" flag asserts that no more modifications to value and fn will be performed.
        self.logp = self.logp + (msg["scale"] * msg["fn"].log_prob(msg["value"])).sum()


with LogJointMessenger2(cond_data={"measurement": 9.5, "weight": 8.23}) as m:
    scale(8.5)
    print(m.logp)
tensor(-3.0203)

The messages (msg) sent by Messengers

As the previous two examples suggested, the actual messages sent up and down the stack are dictionaries with a particular set of keys. Consider the following sample statement:

pyro.sample("x", dist.Bernoulli(0.5), infer={"enumerate": "parallel"}, obs=None)

This sample statement is converted into an initial message before any effects are applied, and each effect handler’s _process_message and _postprocess_message may update fields in place or add new fields. We write out the full initial message here for completeness:

msg = {
    # The following fields contain the name, inputs, distribution function, and output of a sample site.
    # These are generally the only fields you'll need to think about.
    "name": "x",
    "fn": dist.Bernoulli(0.5),
    "value": None,  # msg["value"] 会包含 pyro.sample 的返回值.
    "is_observed": False,  # because obs=None by default; only used by sample sites
    "args": (),  # positional arguments passed to "fn" when it is called; usually empty for sample sites
    "kwargs": {},  # keyword arguments passed to "fn" when it is called; usually empty for sample sites
    # This field typically contains metadata needed or stored by particular inference algorithms
    "infer": {"enumerate": "parallel"},
    # The remaining fields are generally only used by Pyro's internals,
    # or for implementing more advanced effects beyond the scope of this tutorial
    "type": "sample",  # label used by Messenger._process_message to dispatch, in this case to _pyro_sample
    "done": False,
    "stop": False,
    "scale": torch.tensor(1.),  # Multiplicative scale factor that can be applied to each site's log_prob
    "mask": None,
    "continuation": None,
    "cond_indep_stack": (),  # Will contain metadata from each pyro.plate enclosing this sample site.
}

Note that when we use poutine.trace or TraceMessenger as in our first two versions of make_log_joint, the contents of msg are exactly the information stored in the trace for each sample and param site.
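As a quick check (a hedged sketch reusing the scale model and the conditioning data from earlier), a trace node contains the same fields:

tr = poutine.trace(poutine.condition(scale, data={"measurement": 9.5})).get_trace(8.5)
node = tr.nodes["measurement"]
print(node["type"], node["is_observed"], node["value"])      # sample True 9.5
print(all(key in node for key in ("fn", "infer", "scale")))  # True: the same keys as the msg above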

Recap of the worked example

Let's recap how we worked through the example:

  • We started with the question of how to compute the log-joint

  • We then opened the black box, using trace to access the internal sample sites and computing the log-joint by hand

  • Next we wrote an effect handler directly as a Messenger subclass to compute the log-joint

  • Finally, we refined that subclass

[3]:
mu = 8.5
def scale(mu):
    weight = pyro.sample("weight", dist.Normal(mu, 1.0))
    return pyro.sample("measurement", dist.Normal(weight, 0.75))
scale(mu)
[3]:
tensor(6.4981)
[5]:
from pyro.poutine.trace_messenger import TraceMessenger
from pyro.poutine.condition_messenger import ConditionMessenger

def make_log_joint(model):
    def _log_joint(cond_data, *args, **kwargs):
        conditioned_model = poutine.condition(model, data=cond_data)
        trace = poutine.trace(conditioned_model).get_trace(*args, **kwargs)
        return trace.log_prob_sum()
    return _log_joint

def make_log_joint_2(model):
    def _log_joint(cond_data, *args, **kwargs):
        with TraceMessenger() as tracer:
            with ConditionMessenger(data=cond_data):
                model(*args, **kwargs)

        trace = tracer.trace
        logp = 0.
        for name, node in trace.nodes.items():
            if node["type"] == "sample":
                if node["is_observed"]:
                    assert node["value"] is cond_data[name]
                logp = logp + node["fn"].log_prob(node["value"]).sum()
        return logp
    return _log_joint
scale_log_joint = make_log_joint_2(scale)
print(scale_log_joint({"measurement": 9.5, "weight": 8.23}, 8.5)) # mu=8.5 is the model's input argument
tensor(-3.0203)
[6]:
class LogJointMessenger(poutine.messenger.Messenger):

    def __init__(self, cond_data):
        self.data = cond_data

    # __call__ is syntactic sugar for using a Messenger as a higher-order function.
    # Messenger already defines __call__, but we re-define it here for exposition and to change the return value:
    def __call__(self, fn):
        def _fn(*args, **kwargs):
            with self:
                fn(*args, **kwargs)
                return self.logp.clone()
        return _fn

    def __enter__(self):
        self.logp = torch.tensor(0.)
        return super().__enter__()

    def __exit__(self, exc_type, exc_value, traceback):
        self.logp = torch.tensor(0.)
        return super().__exit__(exc_type, exc_value, traceback)

    # _pyro_sample is called once at each pyro.sample site.
    def _pyro_sample(self, msg):
        assert msg["name"] in self.data
        msg["value"] = self.data[msg["name"]]
        # Since we've observed a value for this site, we set the "is_observed" flag to True
        # This tells any other Messengers not to overwrite msg["value"] with a sample.
        msg["is_observed"] = True
        self.logp = self.logp + (msg["scale"] * msg["fn"].log_prob(msg["value"])).sum()

with LogJointMessenger(cond_data={"measurement": 9.5, "weight": 8.23}) as m:
    scale(8.5)
    print(m.logp.clone())

scale_log_joint = LogJointMessenger(cond_data={"measurement": 9.5, "weight": 8.23})(scale)
print(scale_log_joint(8.5))
tensor(-3.0203)
tensor(-3.0203)
[7]:
class LogJointMessenger2(poutine.messenger.Messenger):

    def __init__(self, cond_data):
        self.data = cond_data

    def __call__(self, fn):
        def _fn(*args, **kwargs):
            with self:
                fn(*args, **kwargs)
                return self.logp.clone()
        return _fn

    def __enter__(self):
        self.logp = torch.tensor(0.)
        return super().__enter__()

    def __exit__(self, exc_type, exc_value, traceback):
        self.logp = torch.tensor(0.)
        return super().__exit__(exc_type, exc_value, traceback)

    def _pyro_sample(self, msg):
        if msg["name"] in self.data:
            msg["value"] = self.data[msg["name"]]
            msg["done"] = True

    def _pyro_post_sample(self, msg):
        assert msg["done"]  # the "done" flag asserts that no more modifications to value and fn will be performed.
        self.logp = self.logp + (msg["scale"] * msg["fn"].log_prob(msg["value"])).sum()


with LogJointMessenger2(cond_data={"measurement": 9.5, "weight": 8.23}) as m:
    scale(8.5)
    print(m.logp)
tensor(-3.0203)

A brief introduction to mini-pyro

The core of mini-pyro is:

class Messenger:
    def __init__(self, fn=None):
        self.fn = fn  # the function (e.g. a model) wrapped by this handler

    def __enter__(self):
        PYRO_STACK.append(self) #Effect handlers push themselves onto the PYRO_STACK.
    def __exit__(self, *args, **kwargs):
        assert PYRO_STACK[-1] is self
        PYRO_STACK.pop()

    def process_message(self, msg):
        pass
    def postprocess_message(self, msg):
        pass

    def __call__(self, *args, **kwargs):
        with self:
            return self.fn(*args, **kwargs)

The base Messenger class

trace, replay, block, seed, and PlateMessenger are all subclasses of Messenger.

  • trace: trace records the inputs and outputs of any primitive site it encloses, and returns a dictionary containing that data to the user.

  • replay: an effect handler for setting the value at a sample site.

  • block: allows the selective application of effect handlers to different parts of a model. Sites hidden by block will only have the handlers below block on the PYRO_STACK applied, allowing inference or other effectful computations to be nested inside models.

  • seed: is used to fix the RNG state when calling a model.

  • PlateMessenger: This limited implementation of PlateMessenger only implements broadcasting.

[4]:
import random
import warnings
import weakref
from collections import OrderedDict
import torch
from pyro.distributions import validation_enabled

Pyro keeps track of two kinds of global state: the effect handler stack PYRO_STACK and the trainable parameters in the PARAM_STORE:

[1]:
# The Messenger class is what manipulates the PYRO_STACK
PYRO_STACK = []
PARAM_STORE = {}  # maps name -> (unconstrained_value, constraint)

def get_param_store():
    return PARAM_STORE

# The base effect handler class (called Messenger here for consistency with Pyro).
class Messenger:
    def __init__(self, fn=None):
        self.fn = fn

    # Effect handlers push themselves onto the PYRO_STACK.
    # Handlers earlier in the PYRO_STACK are applied first.
    def __enter__(self):
        PYRO_STACK.append(self)
    def __exit__(self, *args, **kwargs):
        assert PYRO_STACK[-1] is self
        PYRO_STACK.pop()

    def process_message(self, msg):
        pass
    def postprocess_message(self, msg):
        pass

    def __call__(self, *args, **kwargs):
        with self:
            return self.fn(*args, **kwargs)

Here is the first example of a useful effect handler. trace records the inputs and outputs of any primitive site it encloses, and returns a dictionary containing that data to the user.

[20]:
class trace(Messenger):
    def __enter__(self):
        super().__enter__()
        self.trace = OrderedDict()
        return self.trace

    # trace illustrates why we need postprocess_message in addition to process_message:
    # We only want to record a value after all other effects have been applied
    def postprocess_message(self, msg):
        assert msg["type"] != "sample" or msg["name"] not in self.trace, \
            "sample sites must have unique names"
        self.trace[msg["name"]] = msg.copy()

    def get_trace(self, *args, **kwargs):
        self(*args, **kwargs)
        return self.trace
[6]:
import pyro.distributions as dist
from pyro.poutine.trace_messenger import TraceMessenger
mu = 8.5
def scale_obs(mu):
    weight = pyro.sample("weight", dist.Normal(mu, 1.))
    return pyro.sample("measurement", dist.Normal(weight, 0.75), obs=9.5)

with TraceMessenger() as tracer:
    scale_obs(mu)

trace = tracer.trace
logp = 0.
for name, node in trace.nodes.items():
    print(name, node['fn'], node['value'], node['is_observed'])
    if node["type"] == "sample":
        logp = logp + node["fn"].log_prob(node["value"]).sum()
weight Normal(loc: 8.5, scale: 1.0) tensor(7.6848) False
measurement Normal(loc: 7.684762001037598, scale: 0.75) 9.5 True
[7]:
# %psource TraceMessenger
# Return a handler that records the inputs and outputs of primitive calls and their dependencies.

Here is a second example of a useful effect handler: replay sets the value at a sample site. This illustrates why effect handlers are a useful PPL implementation technique: we can compose trace and replay to replace values but preserve distributions, allowing us to compute the joint probability density of samples under a model. See the definition of elbo(…) below for an example of this pattern.

[21]:
class replay(Messenger):
    def __init__(self, fn, guide_trace):
        self.guide_trace = guide_trace
        super().__init__(fn)

    def process_message(self, msg):
        if msg["name"] in self.guide_trace:
            msg["value"] = self.guide_trace[msg["name"]]["value"]

block allows the selective application of effect handlers to different parts of a model. Sites hidden by block will only have the handlers below block on the PYRO_STACK applied, allowing inference or other effectful computations to be nested inside models.

[ ]:
class block(Messenger):
    def __init__(self, fn=None, hide_fn=lambda msg: True):
        self.hide_fn = hide_fn
        super().__init__(fn)

    def process_message(self, msg):
        if self.hide_fn(msg):
            msg["stop"] = True
[ ]:
# seed is used to fix the RNG state when calling a model.
class seed(Messenger):
    def __init__(self, fn=None, rng_seed=None):
        self.rng_seed = rng_seed
        super().__init__(fn)

    def __enter__(self):
        self.old_state = {'torch': torch.get_rng_state(), 'random': random.getstate()}
        torch.manual_seed(self.rng_seed)
        random.seed(self.rng_seed)
        try:
            import numpy as np
            # Save numpy's old RNG state before seeding so that __exit__ can restore it.
            self.old_state['numpy'] = np.random.get_state()
            np.random.seed(self.rng_seed)
        except ImportError:
            pass

    def __exit__(self, type, value, traceback):
        torch.set_rng_state(self.old_state['torch'])
        random.setstate(self.old_state['random'])
        if 'numpy' in self.old_state:
            import numpy as np
            np.random.set_state(self.old_state['numpy'])
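A hedged sketch of seed in action: wrapping an otherwise random function makes its output reproducible.

def noisy_fn():
    return torch.randn(3)

assert torch.equal(seed(noisy_fn, rng_seed=0)(), seed(noisy_fn, rng_seed=0)())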
[8]:
# This limited implementation of PlateMessenger only implements broadcasting.
class PlateMessenger(Messenger):
    def __init__(self, fn, size, dim):
        assert dim < 0
        self.size = size
        self.dim = dim
        super().__init__(fn)

    def process_message(self, msg):
        if msg["type"] == "sample":
            batch_shape = msg["fn"].batch_shape
            if len(batch_shape) < -self.dim or batch_shape[self.dim] != self.size:
                batch_shape = [1] * (-self.dim - len(batch_shape)) + list(batch_shape)
                batch_shape[self.dim] = self.size
                msg["fn"] = msg["fn"].expand(torch.Size(batch_shape))

    def __iter__(self):
        return range(self.size)

Applying the Messengers

apply_stack is called by pyro.sample and pyro.param. It is responsible for applying each Messenger to each effectful operation.

[ ]:
def apply_stack(msg):
    for pointer, handler in enumerate(reversed(PYRO_STACK)):
        handler.process_message(msg)
        # When a Messenger sets the "stop" field of a message,
        # it prevents any Messengers above it on the stack from being applied.
        if msg.get("stop"):
            break
    if msg["value"] is None:
        msg["value"] = msg["fn"](*msg["args"])

    # A Messenger that sets msg["stop"] == True also prevents application
    # of postprocess_message by Messengers above it on the stack
    # via the pointer variable from the process_message loop
    for handler in PYRO_STACK[-pointer-1:]:
        handler.postprocess_message(msg)
    return msg
[ ]:
# sample is an effectful version of Distribution.sample(...)
# When any effect handlers are active, it constructs an initial message and calls apply_stack.
def sample(name, fn, *args, **kwargs):
    obs = kwargs.pop('obs', None)

    # if there are no active Messengers, we just draw a sample and return it as expected:
    if not PYRO_STACK:
        return fn(*args, **kwargs)

    # Otherwise, we initialize a message...
    initial_msg = {
        "type": "sample",
        "name": name,
        "fn": fn,
        "args": args,
        "kwargs": kwargs,
        "value": obs,
    }

    # ...and use apply_stack to send it to the Messengers
    msg = apply_stack(initial_msg)
    return msg["value"]
[ ]:
# param is an effectful version of PARAM_STORE.setdefault that also handles constraints.
# When any effect handlers are active, it constructs an initial message and calls apply_stack.
def param(name, init_value=None, constraint=torch.distributions.constraints.real, event_dim=None):
    if event_dim is not None:
        raise NotImplementedError("minipyro.plate does not support the event_dim arg")

    def fn(init_value, constraint):
        if name in PARAM_STORE:
            unconstrained_value, constraint = PARAM_STORE[name]
        else:
            # Initialize with a constrained value.
            assert init_value is not None
            with torch.no_grad():
                constrained_value = init_value.detach()
                unconstrained_value = torch.distributions.transform_to(constraint).inv(constrained_value)
            unconstrained_value.requires_grad_()
            PARAM_STORE[name] = unconstrained_value, constraint

        # Transform from unconstrained space to constrained space.
        constrained_value = torch.distributions.transform_to(constraint)(unconstrained_value)
        constrained_value.unconstrained = weakref.ref(unconstrained_value)
        return constrained_value

    # if there are no active Messengers, we just draw a sample and return it as expected:
    if not PYRO_STACK:
        return fn(init_value, constraint)

    # Otherwise, we initialize a message...
    initial_msg = {
        "type": "param",
        "name": name,
        "fn": fn,
        "args": (init_value, constraint),
        "value": None,
    }

    # ...and use apply_stack to send it to the Messengers
    msg = apply_stack(initial_msg)
    return msg["value"]
[9]:
# boilerplate to match the syntax of actual pyro.plate:
def plate(name, size, dim=None):
    if dim is None:
        raise NotImplementedError("minipyro.plate requires a dim arg")
    return PlateMessenger(fn=None, size=size, dim=dim)
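A hedged sketch of broadcasting with plate (using the minipyro sample and trace defined above, with a hypothetical one-site model):

def broadcast_model():
    with plate("data", 10, dim=-1):
        return sample("x", dist.Normal(0., 1.))

x = trace(broadcast_model).get_trace()["x"]["value"]
print(x.shape)  # torch.Size([10]): the Normal was expanded to batch_shape (10,)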

Inference and optimization

[ ]:
# This is a thin wrapper around the `torch.optim.Adam` class that
# dynamically generates optimizers for dynamically generated parameters.
# See http://docs.pyro.ai/en/0.3.1/optimization.html
class Adam:
    def __init__(self, optim_args):
        self.optim_args = optim_args
        # Each parameter will get its own optimizer, which we keep track
        # of using this dictionary keyed on parameters.
        self.optim_objs = {}

    def __call__(self, params):
        for param in params:
            # If we've seen this parameter before, use the previously
            # constructed optimizer.
            if param in self.optim_objs:
                optim = self.optim_objs[param]
            # If we've never seen this parameter before, construct
            # an Adam optimizer and keep track of it.
            else:
                optim = torch.optim.Adam([param], **self.optim_args)
                self.optim_objs[param] = optim
            # Take a gradient step for the parameter param.
            optim.step()
[ ]:
# This is a unified interface for stochastic variational inference in Pyro.
# The actual construction of the loss is taken care of by `loss`.
# See http://docs.pyro.ai/en/0.3.1/inference_algos.html
class SVI:
    def __init__(self, model, guide, optim, loss):
        self.model = model
        self.guide = guide
        self.optim = optim
        self.loss = loss

    # This method handles running the model and guide, constructing the loss
    # function, and taking a gradient step.
    def step(self, *args, **kwargs):
        # This wraps both the call to `model` and `guide` in a `trace` so that
        # we can record all the parameters that are encountered. Note that
        # further tracing occurs inside of `loss`.
        with trace() as param_capture:
            # We use block here to allow tracing to record parameters only.
            with block(hide_fn=lambda msg: msg["type"] == "sample"):
                loss = self.loss(self.model, self.guide, *args, **kwargs)
        # Differentiate the loss.
        loss.backward()
        # Grab all the parameters from the trace.
        params = [site["value"].unconstrained()
                  for site in param_capture.values()]
        # Take a step w.r.t. each parameter in params.
        self.optim(params)
        # Zero out the gradients so that they don't accumulate.
        for p in params:
            p.grad = torch.zeros_like(p)
        return loss.item()

# This is a basic implementation of the Evidence Lower Bound, which is the
# fundamental objective in Variational Inference.
# See http://pyro.ai/examples/svi_part_i.html for details.
# This implementation has various limitations (for example it only supports
# random variables with reparameterized samplers), but all the ELBO
# implementations in Pyro share the same basic logic.
def elbo(model, guide, *args, **kwargs):
    # Run the guide with the arguments passed to SVI.step() and trace the execution,
    # i.e. record all the calls to Pyro primitives like sample() and param().
    guide_trace = trace(guide).get_trace(*args, **kwargs)
    # Now run the model with the same arguments and trace the execution. Because
    # model is being run with replay, whenever we encounter a sample site in the
    # model, instead of sampling from the corresponding distribution in the model,
    # we instead reuse the corresponding sample from the guide. In probabilistic
    # terms, this means our loss is constructed as an expectation w.r.t. the joint
    # distribution defined by the guide.
    model_trace = trace(replay(model, guide_trace)).get_trace(*args, **kwargs)
    # We will accumulate the various terms of the ELBO in `elbo`.
    elbo = 0.
    # Loop over all the sample sites in the model and add the corresponding
    # log p(z) term to the ELBO. Note that this will also include any observed
    # data, i.e. sample sites with the keyword `obs=...`.
    for site in model_trace.values():
        if site["type"] == "sample":
            elbo = elbo + site["fn"].log_prob(site["value"]).sum()
    # Loop over all the sample sites in the guide and add the corresponding
    # -log q(z) term to the ELBO.
    for site in guide_trace.values():
        if site["type"] == "sample":
            elbo = elbo - site["fn"].log_prob(site["value"]).sum()
    # Return (-elbo) since by convention we do gradient descent on a loss and
    # the ELBO is a lower bound that needs to be maximized.
    return -elbo


# This is a wrapper for compatibility with full Pyro.
def Trace_ELBO(**kwargs):
    return elbo
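Putting the pieces together, here is a hedged end-to-end sketch (a hypothetical toy model and guide, using the minipyro sample, param, plate, Adam, SVI, and Trace_ELBO defined above) that fits a single location parameter to data; since the model has no latent variables, maximizing the ELBO reduces to maximum likelihood:

def toy_model(data):
    loc = param("loc", torch.tensor(0.))
    with plate("data", len(data), dim=-1):
        sample("obs", dist.Normal(loc, 1.), obs=data)

def toy_guide(data):
    pass  # no latent variables, so the guide has nothing to do

data = torch.randn(100) + 3.0
svi = SVI(toy_model, toy_guide, Adam({"lr": 0.01}), Trace_ELBO())
for step in range(1000):
    svi.step(data)
print(param("loc"))  # should end up close to the data mean, roughly 3.0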
[ ]:
# This is a Jit wrapper around elbo() that (1) delays tracing until the first
# invocation, and (2) registers pyro.param() statements with torch.jit.trace.
# This version does not support variable number of args or non-tensor kwargs.
class JitTrace_ELBO:
    def __init__(self, **kwargs):
        self.ignore_jit_warnings = kwargs.pop("ignore_jit_warnings", False)
        self._compiled = None
        self._param_trace = None

    def __call__(self, model, guide, *args):
        # On first call, initialize params and save their names.
        if self._param_trace is None:
            with block(), trace() as tr, block(hide_fn=lambda m: m["type"] != "param"):
                elbo(model, guide, *args)
            self._param_trace = tr

        # Augment args with reads from the global param store.
        unconstrained_params = tuple(param(name).unconstrained()
                                     for name in self._param_trace)
        params_and_args = unconstrained_params + args

        # On first call, create a compiled elbo.
        if self._compiled is None:

            def compiled(*params_and_args):
                unconstrained_params = params_and_args[:len(self._param_trace)]
                args = params_and_args[len(self._param_trace):]
                for name, unconstrained_param in zip(self._param_trace, unconstrained_params):
                    constrained_param = param(name)  # assume param has been initialized
                    assert constrained_param.unconstrained() is unconstrained_param
                    self._param_trace[name]["value"] = constrained_param
                return replay(elbo, guide_trace=self._param_trace)(model, guide, *args)

            with validation_enabled(False), warnings.catch_warnings():
                if self.ignore_jit_warnings:
                    warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
                self._compiled = torch.jit.trace(compiled, params_and_args, check_trace=False)

        return self._compiled(*params_and_args)

Implementing inference algorithms with existing effect handlers: examples

It turns out that many inference operations, like our first version of make_log_joint above, have strikingly short implementations in terms of existing effect handlers in pyro.poutine.

Example: variational inference with a Monte Carlo ELBO

For example, here is an implementation of variational inference with a Monte Carlo ELBO that uses poutine.trace, poutine.condition, and poutine.replay. This is very similar to the simple ELBO in pyro.contrib.minipyro.

[8]:
def monte_carlo_elbo(model, guide, batch, *args, **kwargs):
    # assuming batch is a dictionary, we use poutine.condition to fix values of observed variables
    conditioned_model = poutine.condition(model, data=batch)

    # we'll approximate the expectation in the ELBO with a single sample:
    # first, we run the guide forward unmodified and record values and distributions
    # at each sample site using poutine.trace
    guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)

    # we use poutine.replay to set the values of latent variables in the model
    # to the values sampled above by our guide, and use poutine.trace
    # to record the distributions that appear at each sample site in the model
    model_trace = poutine.trace(
        poutine.replay(conditioned_model, trace=guide_trace)
    ).get_trace(*args, **kwargs)

    elbo = 0.
    for name, node in model_trace.nodes.items():
        if node["type"] == "sample":
            elbo = elbo + node["fn"].log_prob(node["value"]).sum()
            if not node["is_observed"]:
                elbo = elbo - guide_trace.nodes[name]["fn"].log_prob(node["value"]).sum()
    return -elbo

We use poutine.trace and poutine.block to record pyro.param calls for optimization:

[9]:
def train(model, guide, data):
    optimizer = pyro.optim.Adam({})
    for batch in data:
        # this poutine.trace will record all of the parameters that appear in the model and guide
        # during the execution of monte_carlo_elbo
        with poutine.trace() as param_capture:
            # we use poutine.block here so that only parameters appear in the trace above
            with poutine.block(hide_fn=lambda node: node["type"] != "param"):
                loss = monte_carlo_elbo(model, guide, batch)

        loss.backward()
        params = set(node["value"].unconstrained()
                     for node in param_capture.trace.nodes.values())
        optimizer.step(params)
        pyro.infer.util.zero_grads(params)
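A hedged sketch of the calling convention, with a hypothetical toy model and guide: each element of data is a dictionary mapping observed sample site names to values, which is exactly what poutine.condition expects:

def mc_model():
    z = pyro.sample("z", dist.Normal(0., 1.))
    pyro.sample("obs", dist.Normal(z, 1.))

def mc_guide():
    loc = pyro.param("guide_loc", torch.tensor(0.))
    pyro.sample("z", dist.Normal(loc, 1.))

data = [{"obs": torch.tensor(2.0)} for _ in range(100)]
train(mc_model, mc_guide, data)  # guide_loc should drift toward the posterior mean of z, 1.0 here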

Example: exact inference via sequential enumeration

Here is an example of a very different inference algorithm–exact inference via enumeration–implemented with pyro.poutine. A complete explanation of this algorithm is beyond the scope of this tutorial and may be found in Chapter 3 of the short online book Design and Implementation of Probabilistic Programming Languages. This example uses poutine.queue, itself implemented using poutine.trace, poutine.replay, and poutine.block, to enumerate over possible values of all discrete variables in a model and compute a marginal distribution over all possible return values or the possible values at a particular sample site:

[10]:
def sequential_discrete_marginal(model, data, site_name="_RETURN"):

    from six.moves import queue  # queue data structures
    q = queue.Queue()  # Instantiate a first-in first-out queue
    q.put(poutine.Trace())  # seed the queue with an empty trace

    # as before, we fix the values of observed random variables with poutine.condition
    # assuming data is a dictionary whose keys are names of sample sites in model
    conditioned_model = poutine.condition(model, data=data)

    # we wrap the conditioned model in a poutine.queue,
    # which repeatedly pushes and pops partially completed executions from a Queue()
    # to perform breadth-first enumeration over the set of values of all discrete sample sites in model
    enum_model = poutine.queue(conditioned_model, queue=q)

    # actually perform the enumeration by repeatedly tracing enum_model
    # and accumulate samples and trace log-probabilities for postprocessing
    samples, log_weights = [], []
    while not q.empty():
        trace = poutine.trace(enum_model).get_trace()
        samples.append(trace.nodes[site_name]["value"])
        log_weights.append(trace.log_prob_sum())

    # we take the samples and log-joints and turn them into a histogram:
    samples = torch.stack(samples, 0)
    log_weights = torch.stack(log_weights, 0)
    log_weights = log_weights - dist.util.logsumexp(log_weights, dim=0)
    return dist.Empirical(samples, log_weights)
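A hedged usage sketch with a hypothetical two-coin-style model, computing the posterior over the discrete latent "z" given an observation:

def discrete_model():
    z = pyro.sample("z", dist.Bernoulli(0.3))
    return pyro.sample("obs", dist.Bernoulli(0.9 * z + 0.05 * (1. - z)))

posterior_z = sequential_discrete_marginal(discrete_model,
                                           data={"obs": torch.tensor(1.)},
                                           site_name="z")
print(posterior_z.mean)  # roughly 0.89, i.e. (0.3 * 0.9) / (0.3 * 0.9 + 0.7 * 0.05)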

(Note that sequential_discrete_marginal is very general, but is also quite slow. For high-performance parallel enumeration that applies to a less general class of models, see the enumeration tutorial.)

Example: implementing lazy evaluation with the Messenger API

Now that we’ve learned more about the internals of Messenger, let’s use it to implement a slightly more complicated effect: lazy evaluation. We first define a LazyValue class that we will use to build up a computation graph:

[11]:
class LazyValue:
    def __init__(self, fn, *args, **kwargs):
        self._expr = (fn, args, kwargs)
        self._value = None

    def __str__(self):
        return "({} {})".format(str(self._expr[0]), " ".join(map(str, self._expr[1])))

    def evaluate(self):
        if self._value is None:
            fn, args, kwargs = self._expr
            fn = fn.evaluate() if isinstance(fn, LazyValue) else fn
            args = tuple(arg.evaluate() if isinstance(arg, LazyValue) else arg
                         for arg in args)
            kwargs = {k: v.evaluate() if isinstance(v, LazyValue) else v
                      for k, v in kwargs.items()}
            self._value = fn(*args, **kwargs)
        return self._value
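Before wiring it into an effect handler, here is a quick standalone check of LazyValue (a hedged sketch, independent of Pyro):

lazy = LazyValue(torch.add, torch.tensor(1.), LazyValue(torch.mul, torch.tensor(2.), torch.tensor(3.)))
print(lazy)             # prints the unevaluated expression tree
print(lazy.evaluate())  # tensor(7.)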

With LazyValue, implementing lazy evaluation as a Messenger compatible with other effect handlers is surprisingly easy. We just make each msg["value"] a LazyValue and introduce a new operation type "apply" for deterministic operations:

[12]:
class LazyMessenger(pyro.poutine.messenger.Messenger):
    def _process_message(self, msg):
        if msg["type"] in ("apply", "sample") and not msg["done"]:
            msg["done"] = True
            msg["value"] = LazyValue(msg["fn"], *msg["args"], **msg["kwargs"])

Finally, just like torch.autograd overloads torch tensor operations to record an autograd graph, we need to wrap any operations we’d like to be lazy. We’ll use pyro.poutine.runtime.effectful as a decorator to expose these operations to LazyMessenger. effectful constructs a message much like the one above and sends it up and down the effect handler stack, but allows us to set the type (in this case, to "apply" instead of "sample") so that these operations aren’t mistaken for sample statements by other effect handlers like TraceMessenger:

[13]:
@effectful(type="apply")
def add(x, y):
    return x + y

@effectful(type="apply")
def mul(x, y):
    return x * y

@effectful(type="apply")
def sigmoid(x):
    return torch.sigmoid(x)

@effectful(type="apply")
def normal(loc, scale):
    return dist.Normal(loc, scale)

Applied to another model:

[14]:
def biased_scale(guess):
    weight = pyro.sample("weight", normal(guess, 1.))
    tolerance = pyro.sample("tolerance", normal(0., 0.25))
    return pyro.sample("measurement", normal(add(mul(weight, 0.8), 1.), sigmoid(tolerance)))

with LazyMessenger():
    v = biased_scale(8.5)
    print(v)
    print(v.evaluate())
((<function normal at 0x7fc41cbfdc80> (<function add at 0x7fc41cbf91e0> (<function mul at 0x7fc41cbfda60> ((<function normal at 0x7fc41cbfdc80> 8.5 1.0) ) 0.8) 1.0) (<function sigmoid at 0x7fc41cbfdb70> ((<function normal at 0x7fc41cbfdc80> 0.0 0.25) ))) )
tensor(6.5436)

Together with other effect handlers like TraceMessenger and ConditionMessenger, with which it freely composes, LazyMessenger demonstrates how to use Poutine to quickly and concisely implement state-of-the-art PPL techniques like delayed sampling with Rao-Blackwellization.

References

Algebraic effects and handlers in programming language research

This section contains some references to PL papers for readers interested in this direction.

Algebraic effects and handlers, which date to the early 2000s and remain an active research topic in the programming languages community, are a general abstraction for building modular implementations of nonstandard interpreters of particular statements in a programming language, like pyro.sample and pyro.param. They were originally introduced to address the difficulty of composing nonstandard interpreters implemented with monads and monad transformers.

  • For an accessible introduction to the effect handlers literature, see the excellent review/tutorial paper “Handlers in Action” by Ohad Kammar, Sam Lindley, and Nicolas Oury, and the references therein.

  • Algebraic effect handlers were originally introduced by Gordon Plotkin and Matija Pretnar in the paper “Handlers of Algebraic Effects”.

  • A useful mental model of effect handlers is as exception handlers that are capable of resuming computation in the try block after raising an exception and performing some processing in the except block. This metaphor is explored further in the experimental programming language Eff and its companion paper “Programming with Algebraic Effects and Handlers” by Andrej Bauer and Matija Pretnar.

  • Most effect handlers in Pyro are “linear,” meaning that they only resume once per effectful operation and do not alter the order of execution of the original program. One exception is poutine.queue, which uses an inefficient implementation strategy for multiple resumptions like the one described for delimited continuations in the paper “Capturing the Future by Replaying the Past” by James Koppel, Gabriel Scherer, and Armando Solar-Lezama.

  • More efficient implementation strategies for effect handlers in mainstream programming languages like Python or JavaScript is an area of active research. One promising line of work involves selective continuation-passing style transforms as in the paper “Type-Directed Compilation of Row-Typed Algebraic Effects” by Daan Leijen.