跟踪未知数量的对象¶
Tracking an Unknown Number of Objects
SVI 可以用于学习混合模型的 components 和 assignments, pyro.contrib.tracking
提供了一种更加高效的办法来估计 assignments. 本 notebook 演示了如何在SVI中使用 MarginalAssignmentPersistent
。
[1]:
import math
import os
import torch
from torch.distributions import constraints
from matplotlib import pyplot
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.contrib.tracking.assignment import MarginalAssignmentPersistent
from pyro.distributions.util import gather
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.optim import Adam
%matplotlib inline
assert pyro.__version__.startswith('1.3.0')
pyro.enable_validation(True)
smoke_test = ('CI' in os.environ)
让我们考虑一个具有确定性动力学的模型,比如说周期已知但相位和幅度未知的正弦曲线。
[2]:
def get_dynamics(num_frames):
time = torch.arange(float(num_frames)) / 4
return torch.stack([time.cos(), time.sin()], -1)
直接把对应的生成模型定义出来很棘手, 所以我们把数据生成过程 generate_data()
和推断会用到的因子图 model()
分开定义。
[5]:
def generate_data(args):
# Object model.
num_objects = int(round(args.expected_num_objects)) # Deterministic.
states = dist.Normal(0., 1.).sample((num_objects, 2))
# Detection model.
emitted = dist.Bernoulli(args.emission_prob).sample((args.num_frames, num_objects))
num_spurious = dist.Poisson(args.expected_num_spurious).sample((args.num_frames,))
max_num_detections = int((num_spurious + emitted.sum(-1)).max())
observations = torch.zeros(args.num_frames, max_num_detections, 1+1) # position+confidence
positions = get_dynamics(args.num_frames).mm(states.t())
noisy_positions = dist.Normal(positions, args.emission_noise_scale).sample()
for t in range(args.num_frames):
j = 0
for i, e in enumerate(emitted[t]):
if e:
observations[t, j, 0] = noisy_positions[t, i]
observations[t, j, 1] = 1
j += 1
n = int(num_spurious[t])
if n:
observations[t, j:j+n, 0] = dist.Normal(0., 1.).sample((n,))
observations[t, j:j+n, 1] = 1
return states, positions, observations
[4]:
def model(args, observations):
with pyro.plate("objects", args.max_num_objects):
exists = pyro.sample("exists",
dist.Bernoulli(args.expected_num_objects / args.max_num_objects))
with poutine.mask(mask=exists.bool()):
states = pyro.sample("states", dist.Normal(0., 1.).expand([2]).to_event(1))
positions = get_dynamics(args.num_frames).mm(states.t())
with pyro.plate("detections", observations.shape[1]):
with pyro.plate("time", args.num_frames):
# The combinatorial part of the log prob is approximated to allow independence.
is_observed = (observations[..., -1] > 0)
with poutine.mask(mask=is_observed):
assign = pyro.sample("assign",
dist.Categorical(torch.ones(args.max_num_objects + 1)))
is_spurious = (assign == args.max_num_objects)
is_real = is_observed & ~is_spurious
num_observed = is_observed.float().sum(-1, True)
pyro.sample("is_real",
dist.Bernoulli(args.expected_num_objects / num_observed),
obs=is_real.float())
pyro.sample("is_spurious",
dist.Bernoulli(args.expected_num_spurious / num_observed),
obs=is_spurious.float())
# The remaining continuous part is exact.
observed_positions = observations[..., 0]
with poutine.mask(mask=is_real):
bogus_position = positions.new_zeros(args.num_frames, 1)
augmented_positions = torch.cat([positions, bogus_position], -1)
predicted_positions = gather(augmented_positions, assign, -1)
pyro.sample("real_observations",
dist.Normal(predicted_positions, args.emission_noise_scale),
obs=observed_positions)
with poutine.mask(mask=is_spurious):
pyro.sample("spurious_observations", dist.Normal(0., 1.),
obs=observed_positions)
This guide uses a smart assignment solver but a naive state estimator. A smarter implementation would use message passing also for state estimation, e.g. a Kalman filter-smoother.
[5]:
def guide(args, observations):
# Initialize states randomly from the prior.
states_loc = pyro.param("states_loc", lambda: torch.randn(args.max_num_objects, 2))
states_scale = pyro.param("states_scale",
lambda: torch.ones(states_loc.shape) * args.emission_noise_scale,
constraint=constraints.positive)
positions = get_dynamics(args.num_frames).mm(states_loc.t())
# Solve soft assignment problem.
real_dist = dist.Normal(positions.unsqueeze(-2), args.emission_noise_scale)
spurious_dist = dist.Normal(0., 1.)
is_observed = (observations[..., -1] > 0)
observed_positions = observations[..., 0].unsqueeze(-1)
assign_logits = (real_dist.log_prob(observed_positions) -
spurious_dist.log_prob(observed_positions) +
math.log(args.expected_num_objects * args.emission_prob /
args.expected_num_spurious))
assign_logits[~is_observed] = -float('inf')
exists_logits = torch.empty(args.max_num_objects).fill_(
math.log(args.max_num_objects / args.expected_num_objects))
assignment = MarginalAssignmentPersistent(exists_logits, assign_logits)
with pyro.plate("objects", args.max_num_objects):
exists = pyro.sample("exists", assignment.exists_dist, infer={"enumerate": "parallel"})
with poutine.mask(mask=exists.bool()):
pyro.sample("states", dist.Normal(states_loc, states_scale).to_event(1))
with pyro.plate("detections", observations.shape[1]):
with poutine.mask(mask=is_observed):
with pyro.plate("time", args.num_frames):
assign = pyro.sample("assign", assignment.assign_dist, infer={"enumerate": "parallel"})
return assignment
We’ll define a global config object to make it easy to port code to argparse
.
[6]:
args = type('Args', (object,), {}) # A fake ArgumentParser.parse_args() result.
args.num_frames = 5
args.max_num_objects = 3
args.expected_num_objects = 2.
args.expected_num_spurious = 1.
args.emission_prob = 0.8
args.emission_noise_scale = 0.1
assert args.max_num_objects >= args.expected_num_objects
Generate data¶
[7]:
pyro.set_rng_seed(0)
true_states, true_positions, observations = generate_data(args)
true_num_objects = len(true_states)
max_num_detections = observations.shape[1]
assert true_states.shape == (true_num_objects, 2)
assert true_positions.shape == (args.num_frames, true_num_objects)
assert observations.shape == (args.num_frames, max_num_detections, 1+1)
print("generated {:d} detections from {:d} objects".format(
(observations[..., -1] > 0).long().sum(), true_num_objects))
generated 16 detections from 2 objects
Train¶
[8]:
def plot_solution(message=''):
assignment = guide(args, observations)
states_loc = pyro.param("states_loc")
positions = get_dynamics(args.num_frames).mm(states_loc.t())
pyplot.figure(figsize=(12,6)).patch.set_color('white')
pyplot.plot(true_positions.numpy(), 'k--')
is_observed = (observations[..., -1] > 0)
pos = observations[..., 0]
time = torch.arange(float(args.num_frames)).unsqueeze(-1).expand_as(pos)
pyplot.scatter(time[is_observed].view(-1).numpy(),
pos[is_observed].view(-1).numpy(), color='k', marker='+',
label='observation')
for i in range(args.max_num_objects):
p_exist = assignment.exists_dist.probs[i].item()
position = positions[:, i].detach().numpy()
pyplot.plot(position, alpha=p_exist, color='C0')
pyplot.title('Truth, observations, and predicted tracks ' + message)
pyplot.plot([], 'k--', label='truth')
pyplot.plot([], color='C0', label='prediction')
pyplot.legend(loc='best')
pyplot.xlabel('time step')
pyplot.ylabel('position')
pyplot.tight_layout()
[9]:
pyro.set_rng_seed(1)
pyro.clear_param_store()
plot_solution('(before training)')
[10]:
infer = SVI(model, guide, Adam({"lr": 0.01}), TraceEnum_ELBO(max_plate_nesting=2))
losses = []
for epoch in range(101 if not smoke_test else 2):
loss = infer.step(args, observations)
if epoch % 10 == 0:
print("epoch {: >4d} loss = {}".format(epoch, loss))
losses.append(loss)
epoch 0 loss = 89.270072937
epoch 10 loss = 85.940826416
epoch 20 loss = 86.1014556885
epoch 30 loss = 83.8865127563
epoch 40 loss = 85.354347229
epoch 50 loss = 82.01512146
epoch 60 loss = 78.1765365601
epoch 70 loss = 78.0290603638
epoch 80 loss = 74.915725708
epoch 90 loss = 74.3280792236
epoch 100 loss = 74.1109313965
[11]:
pyplot.plot(losses);
[12]:
plot_solution('(after training)')
[ ]: