# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
"""
This example demonstrates how to use plated ``einsum`` with different backends
to compute logprob, gradients, MAP estimates, posterior samples, and marginals.
The interface for adjoint algorithms requires four steps:
1. Call ``require_backward()`` on all inputs.
2. Call ``x, = einsum(..., backend=...)`` with a nonstandard backend.
3. Call ``x._pyro_backward()` on the einsum output.
4. Retrieve results from ``._pyro_backward_result`` attributes of the inputs.
The results of these computations are returned, but this script does not
make use of them; instead we simply time the operations for profiling.
All profiling is done on jit-compiled functions. We exclude jit compilation
time from profiling results, assuming this can be done once.
You can measure complexity of different einsum problems by specifying
``--equation`` and ``--plates``.
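
A minimal sketch of the four adjoint steps, using the marginal backend (the
equation, plates, and tensor shape below are purely illustrative)::

    import torch
    from pyro.ops.contract import einsum
    from pyro.ops.einsum.adjoint import require_backward

    x = torch.randn(3, 4)               # a log-space factor over dims 'a', 'b'
    require_backward(x)                  # step 1
    y, = einsum('ab->', x, plates='', modulo_total=True,
                backend='pyro.ops.einsum.torch_marginal')  # step 2
    y._pyro_backward()                   # step 3
    marginal = x._pyro_backward_result   # step 4: marginal over 'a' and 'b'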
"""

import argparse
import timeit

import torch
from torch.autograd import grad

from pyro.ops.contract import einsum
from pyro.ops.einsum.adjoint import require_backward
from pyro.util import ignore_jit_warnings

# We will cache jit-compiled versions of each function.
_CACHE = {}


def jit_prob(equation, *operands, **kwargs):
"""
Runs einsum to compute the partition function.
This is cheap but less numerically stable than using the torch_log backend.
"""
key = 'prob', equation, kwargs['plates']
if key not in _CACHE:
# This simply wraps einsum for jit compilation.
def _einsum(*operands):
return einsum(equation, *operands, **kwargs)
_CACHE[key] = torch.jit.trace(_einsum, operands, check_trace=False)
return _CACHE[key](*operands)


def jit_logprob(equation, *operands, **kwargs):
"""
Runs einsum to compute the log partition function.
This simulates evaluating an undirected graphical model.
"""
key = 'logprob', equation, kwargs['plates']
if key not in _CACHE:
# This simply wraps einsum for jit compilation.
def _einsum(*operands):
return einsum(equation, *operands, backend='pyro.ops.einsum.torch_log', **kwargs)
_CACHE[key] = torch.jit.trace(_einsum, operands, check_trace=False)
return _CACHE[key](*operands)


def jit_gradient(equation, *operands, **kwargs):
"""
Runs einsum and calls backward on the partition function.
    This simulates training an undirected graphical model.
"""
key = 'gradient', equation, kwargs['plates']
if key not in _CACHE:
# This wraps einsum for jit compilation, but we will call backward on the result.
def _einsum(*operands):
return einsum(equation, *operands, backend='pyro.ops.einsum.torch_log', **kwargs)
_CACHE[key] = torch.jit.trace(_einsum, operands, check_trace=False)
# Run forward pass.
losses = _CACHE[key](*operands)
# Work around PyTorch 1.0.0 bug https://github.com/pytorch/pytorch/issues/14875
# whereby tuples of length 1 are unwrapped by the jit.
if not isinstance(losses, tuple):
losses = (losses,)
# Run backward pass.
grads = tuple(grad(loss, operands, retain_graph=True, allow_unused=True)
for loss in losses)
return grads


def _jit_adjoint(equation, *operands, **kwargs):
"""
Runs einsum in forward-backward mode using ``pyro.ops.adjoint``.
This simulates serving predictions from an undirected graphical model.
"""
backend = kwargs.pop('backend', 'pyro.ops.einsum.torch_sample')
key = backend, equation, tuple(x.shape for x in operands), kwargs['plates']
if key not in _CACHE:
# This wraps a complete adjoint algorithm call.
@ignore_jit_warnings()
def _forward_backward(*operands):
# First we request backward results on each input operand.
# This is the pyro.ops.adjoint equivalent of torch's .requires_grad_().
for operand in operands:
require_backward(operand)
# Next we run the forward pass.
results = einsum(equation, *operands, backend=backend, **kwargs)
            # Then we run the backward pass.
for result in results:
result._pyro_backward()
# Finally we retrieve results from the ._pyro_backward_result attribute
# that has been set on each input operand. If you only want results on a
# subset of operands, you can call require_backward() on only those.
results = []
for x in operands:
results.append(x._pyro_backward_result)
x._pyro_backward_result = None
return tuple(results)
_CACHE[key] = torch.jit.trace(_forward_backward, operands, check_trace=False)
return _CACHE[key](*operands)


def jit_map(equation, *operands, **kwargs):
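    """
    Runs einsum in forward-backward mode to compute MAP estimates,
    using the ``pyro.ops.einsum.torch_map`` backend.
    """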
return _jit_adjoint(equation, *operands, backend='pyro.ops.einsum.torch_map', **kwargs)


def jit_sample(equation, *operands, **kwargs):
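    """
    Runs einsum in forward-backward mode to draw posterior samples,
    using the ``pyro.ops.einsum.torch_sample`` backend.
    """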
return _jit_adjoint(equation, *operands, backend='pyro.ops.einsum.torch_sample', **kwargs)


def jit_marginal(equation, *operands, **kwargs):
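    """
    Runs einsum in forward-backward mode to compute marginals,
    using the ``pyro.ops.einsum.torch_marginal`` backend.
    """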
return _jit_adjoint(equation, *operands, backend='pyro.ops.einsum.torch_marginal', **kwargs)


def time_fn(fn, equation, *operands, **kwargs):
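    """
    Returns the average time per call of ``fn`` in seconds over ``iters``
    iterations, excluding an initial warm-up call that triggers jit compilation.
    """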
iters = kwargs.pop('iters')
_CACHE.clear() # Avoid memory leaks.
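    # The first call jit-compiles fn; it is excluded from the timed loop below.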
fn(equation, *operands, **kwargs)
time_start = timeit.default_timer()
for i in range(iters):
fn(equation, *operands, **kwargs)
time_end = timeit.default_timer()
return (time_end - time_start) / iters


def main(args):
if args.cuda:
torch.set_default_tensor_type('torch.cuda.FloatTensor')
else:
torch.set_default_tensor_type('torch.FloatTensor')
if args.method == 'all':
for method in ['prob', 'logprob', 'gradient', 'marginal', 'map', 'sample']:
args.method = method
main(args)
return
    print('Plate size  Time per iteration of {} (ms)'.format(args.method))
fn = globals()['jit_{}'.format(args.method)]
equation = args.equation
plates = args.plates
inputs, outputs = equation.split('->')
inputs = inputs.split(',')
# Vary all plate sizes at the same time.
for plate_size in range(8, 1 + args.max_plate_size, 8):
operands = []
for dims in inputs:
shape = torch.Size([plate_size if d in plates else args.dim_size
for d in dims])
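            # Positive random factors in [0.5, 1.5), with gradients enabled
            # for the gradient benchmark.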
operands.append((torch.empty(shape).uniform_() + 0.5).requires_grad_())
time = time_fn(fn, equation, *operands, plates=plates, modulo_total=True,
iters=args.iters)
print('{: <11s} {:0.4g}'.format('{} ** {}'.format(plate_size, len(args.plates)), time * 1e3))


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='plated einsum profiler')
parser.add_argument('-e', '--equation', default='a,abi,bcij,adj,deij->')
parser.add_argument('-p', '--plates', default='ij')
parser.add_argument('-d', '--dim-size', default=32, type=int)
parser.add_argument('-s', '--max-plate-size', default=32, type=int)
parser.add_argument('-n', '--iters', default=10, type=int)
parser.add_argument('--cuda', action='store_true')
parser.add_argument('-m', '--method', default='all',
help='one of: prob, logprob, gradient, marginal, map, sample')
args = parser.parse_args()
main(args)