Bayesian Regression - Introduction (Part 1)¶
Regression is one of the most common and basic supervised learning tasks in machine learning. Suppose we are given a dataset \(\mathcal{D}\) of the form

\[\mathcal{D} = \{ (X_i, y_i) \} \qquad \text{for } i = 1, 2, \dots, N\]

The goal of linear regression is to fit a function of the following form to the data:

\[y = w X + b + \epsilon\]

where \(w\) and \(b\) are learnable parameters and \(\epsilon\) represents observation noise. Specifically, \(w\) is a matrix of weights and \(b\) is a bias vector.

In this tutorial, we will first implement linear regression in PyTorch and learn point estimates for the parameters \(w\) and \(b\). Then we will see how to incorporate uncertainty into our estimates by using Pyro to implement Bayesian regression. Additionally, we will learn how to use Pyro's utility functions to do predictions and to serve our model using TorchScript.
[1]:
# After working through this tutorial, you will understand the complete program below
import os, torch, pyro
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import pyro.distributions as dist
from torch import nn
from functools import partial
from pyro.nn import PyroModule, PyroSample
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal
pyro.set_rng_seed(1)
pyro.enable_validation(True)
%matplotlib inline
plt.style.use('default')
def get_data():
    DATA_URL = "https://d2hg8soec8ck9v.cloudfront.net/datasets/rugged_data.csv"
    data = pd.read_csv(DATA_URL, encoding="ISO-8859-1")
    df = data[["cont_africa", "rugged", "rgdppc_2000"]]
    df = df[np.isfinite(df.rgdppc_2000)]
    df["rgdppc_2000"] = np.log(df["rgdppc_2000"])
    df["cont_africa_x_rugged"] = df["cont_africa"] * df["rugged"]
    data = torch.tensor(df[["cont_africa", "rugged", "cont_africa_x_rugged", "rgdppc_2000"]].values,
                        dtype=torch.float)
    x_data, y_data = data[:, :-1], data[:, -1]
    return x_data, y_data
x_data, y_data = get_data()
num_iterations = 1000
class BayesianRegression(PyroModule):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = PyroModule[nn.Linear](in_features, out_features)
        self.linear.weight = PyroSample(dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))
        self.linear.bias = PyroSample(dist.Normal(0., 10.).expand([out_features]).to_event(1))

    def forward(self, x, y=None):
        sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
        mean = self.linear(x).squeeze(-1)
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
        return mean
model = BayesianRegression(3, 1)
guide = AutoDiagonalNormal(model)
adam = pyro.optim.Adam({"lr": 0.03})
svi = SVI(model, guide, adam, loss=Trace_ELBO())
pyro.clear_param_store()
for j in range(num_iterations):
    loss = svi.step(x_data, y_data)
    if j % 200 == 0:
        print("[iteration %04d] loss: %.4f" % (j + 1, loss / len(x_data)))
guide.requires_grad_(False)
for name, value in pyro.get_param_store().items():
    print(name, pyro.param(name))
[iteration 0001] loss: 4.6074
[iteration 0201] loss: 2.5499
[iteration 0401] loss: 1.4601
[iteration 0601] loss: 1.4725
[iteration 0801] loss: 1.4677
AutoDiagonalNormal.loc Parameter containing:
tensor([-2.2916, -1.8635, -0.1926, 0.3305, 9.1682])
AutoDiagonalNormal.scale tensor([0.0559, 0.1428, 0.0459, 0.0847, 0.0635])
Setup¶
First, let's import the modules we need.
[1]:
%reset -s -f
[2]:
import os
from functools import partial
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import pyro
import pyro.distributions as dist
# for CI testing
smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.3.0')
pyro.enable_validation(True)
pyro.set_rng_seed(1)
# Set matplotlib settings
%matplotlib inline
plt.style.use('default')
Dataset¶
The following example is adapted from [1]. We would like to explore the relationship between the topographic heterogeneity of a nation, as measured by the Terrain Ruggedness Index (the variable rugged in the dataset), and its GDP per capita. In particular, it was noted by the authors in [2] that terrain ruggedness or bad geography is related to poorer economic performance outside of Africa, but that rugged terrains have had a reverse effect on income for African nations. Let us look at the data and investigate this relationship. We will be focusing on three features from the dataset:
- rugged: quantifies the Terrain Ruggedness Index
- cont_africa: whether the given nation is in Africa
- rgdppc_2000: Real GDP per capita for the year 2000
The response variable GDP is highly skewed, so we will log-transform it.
[3]:
DATA_URL = "https://d2hg8soec8ck9v.cloudfront.net/datasets/rugged_data.csv"
data = pd.read_csv(DATA_URL, encoding="ISO-8859-1")
df = data[["cont_africa", "rugged", "rgdppc_2000"]]
df = df[np.isfinite(df.rgdppc_2000)]
df["rgdppc_2000"] = np.log(df["rgdppc_2000"])
[4]:
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6), sharey=True)
african_nations = df[df["cont_africa"] == 1]
non_african_nations = df[df["cont_africa"] == 0]
sns.scatterplot(non_african_nations["rugged"],
                non_african_nations["rgdppc_2000"],
                ax=ax[0])
ax[0].set(xlabel="Terrain Ruggedness Index",
          ylabel="log GDP (2000)",
          title="Non African Nations")
sns.scatterplot(african_nations["rugged"],
                african_nations["rgdppc_2000"],
                ax=ax[1])
ax[1].set(xlabel="Terrain Ruggedness Index",
          ylabel="log GDP (2000)",
          title="African Nations");
Linear Regression¶
We would like to predict the log of a nation's GDP per capita from two features in the dataset: whether the nation is in Africa, and its Terrain Ruggedness Index.
We will create a trivial class called PyroModule[nn.Linear] that subclasses PyroModule and torch.nn.Linear. PyroModule is very similar to PyTorch's nn.Module, but additionally supports Pyro primitives as attributes that can be modified by Pyro's effect handlers (see the next section on how we can have module attributes that are pyro.sample primitives). Some general notes:

- Learnable parameters in PyTorch modules are instances of nn.Parameter, in this case the weight and bias parameters of the nn.Linear class. When declared inside a PyroModule as attributes, these are automatically registered in Pyro's param store. While this model does not require us to constrain the values of these parameters during optimization, this can also be easily achieved in PyroModule using the PyroParam statement (a short sketch follows this list).
- Note that while the forward method of PyroModule[nn.Linear] is inherited from nn.Linear, it can also be easily overridden; e.g. in the case of logistic regression, we apply a sigmoid transformation to the linear predictor (see the sketch after the next cell).
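As a minimal, hypothetical sketch of the PyroParam statement mentioned in the first note (the scale attribute and its initial value are invented purely for illustration):

import torch
from torch.distributions import constraints
from pyro.nn import PyroModule, PyroParam

class ConstrainedModule(PyroModule):
    def __init__(self):
        super().__init__()
        # Registered in Pyro's param store; the constraint keeps the parameter
        # positive throughout optimization.
        self.scale = PyroParam(torch.tensor(1.0), constraint=constraints.positive)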
[5]:
from torch import nn
from pyro.nn import PyroModule
assert issubclass(PyroModule[nn.Linear], nn.Linear)
assert issubclass(PyroModule[nn.Linear], PyroModule)
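And a hedged sketch of the forward override mentioned above: a hypothetical LogisticRegression module (not used elsewhere in this tutorial) that reuses PyroModule[nn.Linear] but squashes the linear predictor through a sigmoid:

import torch
from torch import nn
from pyro.nn import PyroModule

class LogisticRegression(PyroModule):
    def __init__(self, in_features):
        super().__init__()
        self.linear = PyroModule[nn.Linear](in_features, 1)

    def forward(self, x):
        # Overrides the inherited linear forward: map predictions into (0, 1).
        return torch.sigmoid(self.linear(x)).squeeze(-1)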
Training with PyTorch Optimizers¶
Note that in addition to the two features rugged and cont_africa, we also include an interaction term in our model, which lets us separately model the effect of ruggedness on GDP for nations within and outside Africa.

We use the mean squared error (MSE) as our loss and Adam as our optimizer from the torch.optim module. We would like to optimize the parameters of our model, namely the weight and bias parameters of the network, which correspond to our regression coefficients and the intercept.
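For reference (a standard definition, not spelled out in the original), with reduction='sum' as used in the next cell, the loss being minimized is

\[\mathcal{L}(w, b) = \sum_{i=1}^{N} \left(y_i - (w X_i + b)\right)^2\]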
[ ]:
# Dataset: Add a feature to capture the interaction between "cont_africa" and "rugged"
df["cont_africa_x_rugged"] = df["cont_africa"] * df["rugged"]
data = torch.tensor(df[["cont_africa", "rugged", "cont_africa_x_rugged", "rgdppc_2000"]].values,
                    dtype=torch.float)
x_data, y_data = data[:, :-1], data[:, -1]

# Regression model
linear_reg_model = PyroModule[nn.Linear](3, 1)

# Define loss and optimize
loss_fn = torch.nn.MSELoss(reduction='sum')
optim = torch.optim.Adam(linear_reg_model.parameters(), lr=0.05)
num_iterations = 1500 if not smoke_test else 2

def train():
    # run the model forward on the data
    y_pred = linear_reg_model(x_data).squeeze(-1)
    # calculate the mse loss
    loss = loss_fn(y_pred, y_data)
    # initialize gradients to zero
    optim.zero_grad()
    # backpropagate
    loss.backward()
    # take a gradient step
    optim.step()
    return loss

for j in range(num_iterations):
    loss = train()
    if (j + 1) % 50 == 0:
        print("[iteration %04d] loss: %.4f" % (j + 1, loss.item()))

# Inspect learned parameters
print("Learned parameters:")
for name, param in linear_reg_model.named_parameters():
    print(name, param.data.numpy())
Plotting the Regression Fit¶
Let us plot the regression fit for our model, separately for countries outside and within Africa.
[ ]:
fit = df.copy()
fit["mean"] = linear_reg_model(x_data).detach().cpu().numpy()

fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6), sharey=True)
african_nations = fit[fit["cont_africa"] == 1]
non_african_nations = fit[fit["cont_africa"] == 0]
fig.suptitle("Regression Fit", fontsize=16)
ax[0].plot(non_african_nations["rugged"], non_african_nations["rgdppc_2000"], "o")
ax[0].plot(non_african_nations["rugged"], non_african_nations["mean"], linewidth=2)
ax[0].set(xlabel="Terrain Ruggedness Index",
          ylabel="log GDP (2000)",
          title="Non African Nations")
ax[1].plot(african_nations["rugged"], african_nations["rgdppc_2000"], "o")
ax[1].plot(african_nations["rugged"], african_nations["mean"], linewidth=2)
ax[1].set(xlabel="Terrain Ruggedness Index",
          ylabel="log GDP (2000)",
          title="African Nations");
We notice that terrain ruggedness has an inverse relationship with GDP for non-African nations, but it positively affects GDP for African nations. It is however unclear how robust this trend is. In particular, we would like to understand how the regression fit would vary due to parameter uncertainty. To address this, we will build a simple Bayesian model for linear regression. Bayesian modeling offers a systematic framework for reasoning about model uncertainty. Instead of just learning point estimates, we are going to learn a distribution over parameters that are consistent with the observed data.
Bayesian Regression with Pyro's Stochastic Variational Inference (SVI)¶
Model¶
In order to make our linear regression Bayesian, we need to put priors on the parameters \(w\) and \(b\). These are distributions that represent our prior belief about reasonable values for \(w\) and \(b\) (before observing any data).
Making a Bayesian model for linear regression is very intuitive using PyroModule as earlier. Note the following:

- The BayesianRegression module internally uses the same PyroModule[nn.Linear] module. However, note that we replace the weight and the bias of this module with PyroSample statements. These statements allow us to place a prior over the weight and bias parameters, instead of treating them as fixed learnable parameters. For the bias component, we set a reasonably wide prior since it is likely to be substantially above 0.
- The BayesianRegression.forward method specifies the generative process. We generate the mean value of the response by calling the linear module (which, as you saw, samples the weight and bias parameters from the prior and returns a value for the mean response). Finally we use the obs argument to the pyro.sample statement to condition on the observed data y_data with a learned observation noise sigma. The model returns the regression line given by the variable mean.
[ ]:
from pyro.nn import PyroSample
class BayesianRegression(PyroModule):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = PyroModule[nn.Linear](in_features, out_features)
        self.linear.weight = PyroSample(dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))
        self.linear.bias = PyroSample(dist.Normal(0., 10.).expand([out_features]).to_event(1))

    def forward(self, x, y=None):
        sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
        mean = self.linear(x).squeeze(-1)
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
        return mean
Using an AutoGuide¶
In order to do inference, i.e. learn the posterior distribution over our unobserved parameters, we will use Stochastic Variational Inference (SVI). The guide determines a family of distributions, and SVI aims to find an approximate posterior distribution from this family that has the lowest KL divergence from the true posterior.
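For reference, a standard identity (not spelled out in the original) connecting the two quantities: for latent variables \(z\) and observations \(x\),

\[\log p(x) = \mathrm{ELBO}(\phi) + \mathrm{KL}\left(q_\phi(z)\,\|\,p(z \mid x)\right), \qquad \mathrm{ELBO}(\phi) = \mathbb{E}_{q_\phi(z)}\left[\log p(x, z) - \log q_\phi(z)\right].\]

Since \(\log p(x)\) does not depend on the variational parameters \(\phi\), maximizing the ELBO minimizes the KL divergence to the true posterior.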
Users can write arbitrarily flexible custom guides in Pyro, but in this tutorial, we will restrict ourselves to Pyro’s autoguide library. In the next tutorial, we will explore how to write guides by hand.
To begin with, we will use the AutoDiagonalNormal guide that models the distribution of unobserved parameters in the model as a Gaussian with diagonal covariance, i.e. it assumes that there is no correlation amongst the latent variables (quite a strong modeling assumption, as we shall see in Part II). Under the hood, this defines a guide that uses a Normal distribution with learnable parameters corresponding to each sample statement in the model. E.g. in our case, this distribution should have a size of (5,), corresponding to the 3 regression coefficients for each of the terms, and 1 component contributed each by the intercept term and sigma in the model.
Autoguide also supports learning MAP estimates with AutoDelta or composing guides with AutoGuideList (see the docs for more information, and the brief sketch after the next cell).
[ ]:
from pyro.infer.autoguide import AutoDiagonalNormal
model = BayesianRegression(3, 1)
guide = AutoDiagonalNormal(model)
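A brief, hedged sketch of those alternatives (the class names are real members of pyro.infer.autoguide; the way we partition the latent sites below is invented purely for illustration):

from pyro import poutine
from pyro.infer.autoguide import AutoDelta, AutoGuideList

# MAP inference: every latent site collapses to a learned point estimate.
map_guide = AutoDelta(model)

# Composing guides: different autoguides for disjoint sets of latent sites,
# e.g. a diagonal Normal for the weights/bias and a point estimate for sigma.
guide_list = AutoGuideList(model)
guide_list.append(AutoDiagonalNormal(poutine.block(model, hide=["sigma"])))
guide_list.append(AutoDelta(poutine.block(model, expose=["sigma"])))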
Optimizing the Evidence Lower Bound¶
We will use stochastic variational inference (SVI) for doing inference (for an introduction to SVI, see SVI Part I). Just like in the non-Bayesian linear regression model, each iteration of our training loop will take a gradient step, with the difference that in this case, we will use the Evidence Lower Bound (ELBO) objective instead of the MSE loss, by constructing a Trace_ELBO object that we pass to SVI.
[ ]:
from pyro.infer import SVI, Trace_ELBO
adam = pyro.optim.Adam({"lr": 0.03})
svi = SVI(model, guide, adam, loss=Trace_ELBO())
Note that we use the Adam optimizer from Pyro's optim module and not the torch.optim module as earlier. Here Adam is a thin wrapper around torch.optim.Adam (see here for a discussion). Optimizers in pyro.optim are used to optimize and update parameter values in Pyro's parameter store. In particular, you will notice that we do not need to pass in learnable parameters to the optimizer, since that is determined by the guide code and happens behind the scenes within the SVI class automatically. To take an ELBO gradient step we simply call the step method of SVI. The data argument we pass to SVI.step will be passed to both model() and guide(). The complete training loop is as follows:
[ ]:
pyro.clear_param_store()
for j in range(num_iterations):
    # calculate the loss and take a gradient step
    loss = svi.step(x_data, y_data)
    if j % 100 == 0:
        print("[iteration %04d] loss: %.4f" % (j + 1, loss / len(data)))
We can examine the optimized parameter values by fetching from Pyro’s param store.
[ ]:
guide.requires_grad_(False)

for name, value in pyro.get_param_store().items():
    print(name, pyro.param(name))
As you can see, instead of just point estimates, we now have uncertainty estimates (AutoDiagonalNormal.scale) for our learned parameters. Note that Autoguide packs the latent variables into a single tensor, in this case, one entry per variable sampled in our model. Both the loc and scale parameters have size (5,), one for each of the latent variables in the model, as we had remarked earlier.
To look at the distribution of the latent parameters more clearly, we can make use of the AutoDiagonalNormal.quantiles method, which will unpack the latent samples from the autoguide and automatically constrain them to the site's support (e.g. the variable sigma must lie in (0, 10)). We see that the median values for the parameters are quite close to the maximum likelihood point estimates we obtained from our first model.
[ ]:
guide.quantiles([0.25, 0.5, 0.75])
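A small optional addition (not in the original tutorial): autoguides such as AutoDiagonalNormal also expose a median() method, which returns the posterior median of every latent site as a dict, convenient when a single summary value per parameter suffices.

# Posterior medians keyed by site name, e.g. "linear.weight", "linear.bias", "sigma".
medians = guide.median()
for site, value in medians.items():
    print(site, value)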
Model Evaluation¶
To evaluate our model, we’ll generate some predictive samples and look at the posteriors. For this we will make use of the Predictive utility class.
- We generate 800 samples from our trained model. Internally, this is done by first generating samples for the unobserved sites in the guide, and then running the model forward by conditioning the sites to values sampled from the guide. Refer to the Model Serving section for insight on how the Predictive class works.
- Note that in return_sites, we specify both the outcome ("obs" site) as well as the return value of the model ("_RETURN"), which captures the regression line. Additionally, we would also like to capture the regression coefficients (given by "linear.weight") for further analysis.
- The remaining code is simply used to plot the 90% CI for the two variables from our model.
[ ]:
from pyro.infer import Predictive
def summary(samples):
    site_stats = {}
    for k, v in samples.items():
        site_stats[k] = {
            "mean": torch.mean(v, 0),
            "std": torch.std(v, 0),
            "5%": v.kthvalue(int(len(v) * 0.05), dim=0)[0],
            "95%": v.kthvalue(int(len(v) * 0.95), dim=0)[0],
        }
    return site_stats

predictive = Predictive(model, guide=guide, num_samples=800,
                        return_sites=("linear.weight", "obs", "_RETURN"))
samples = predictive(x_data)
pred_summary = summary(samples)
[ ]:
mu = pred_summary["_RETURN"]
y = pred_summary["obs"]
predictions = pd.DataFrame({
    "cont_africa": x_data[:, 0],
    "rugged": x_data[:, 1],
    "mu_mean": mu["mean"],
    "mu_perc_5": mu["5%"],
    "mu_perc_95": mu["95%"],
    "y_mean": y["mean"],
    "y_perc_5": y["5%"],
    "y_perc_95": y["95%"],
    "true_gdp": y_data,
})
[ ]:
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6), sharey=True)
african_nations = predictions[predictions["cont_africa"] == 1]
non_african_nations = predictions[predictions["cont_africa"] == 0]
african_nations = african_nations.sort_values(by=["rugged"])
non_african_nations = non_african_nations.sort_values(by=["rugged"])
fig.suptitle("Regression line 90% CI", fontsize=16)
ax[0].plot(non_african_nations["rugged"],
           non_african_nations["mu_mean"])
ax[0].fill_between(non_african_nations["rugged"],
                   non_african_nations["mu_perc_5"],
                   non_african_nations["mu_perc_95"],
                   alpha=0.5)
ax[0].plot(non_african_nations["rugged"],
           non_african_nations["true_gdp"],
           "o")
ax[0].set(xlabel="Terrain Ruggedness Index",
          ylabel="log GDP (2000)",
          title="Non African Nations")
ax[1].plot(african_nations["rugged"],
           african_nations["mu_mean"])
ax[1].fill_between(african_nations["rugged"],
                   african_nations["mu_perc_5"],
                   african_nations["mu_perc_95"],
                   alpha=0.5)
ax[1].plot(african_nations["rugged"],
           african_nations["true_gdp"],
           "o")
ax[1].set(xlabel="Terrain Ruggedness Index",
          ylabel="log GDP (2000)",
          title="African Nations");
The above figure shows the uncertainty in our estimate of the regression line, and the 90% CI around the mean. We can also see that most of the data points actually lie outside the 90% CI, and this is expected because we have not plotted the outcome variable, which will be affected by sigma! Let us do so next.
[ ]:
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6), sharey=True)
fig.suptitle("Posterior predictive distribution with 90% CI", fontsize=16)
ax[0].plot(non_african_nations["rugged"],
           non_african_nations["y_mean"])
ax[0].fill_between(non_african_nations["rugged"],
                   non_african_nations["y_perc_5"],
                   non_african_nations["y_perc_95"],
                   alpha=0.5)
ax[0].plot(non_african_nations["rugged"],
           non_african_nations["true_gdp"],
           "o")
ax[0].set(xlabel="Terrain Ruggedness Index",
          ylabel="log GDP (2000)",
          title="Non African Nations")
ax[1].plot(african_nations["rugged"],
           african_nations["y_mean"])
ax[1].fill_between(african_nations["rugged"],
                   african_nations["y_perc_5"],
                   african_nations["y_perc_95"],
                   alpha=0.5)
ax[1].plot(african_nations["rugged"],
           african_nations["true_gdp"],
           "o")
ax[1].set(xlabel="Terrain Ruggedness Index",
          ylabel="log GDP (2000)",
          title="African Nations");
We observe that the outcome from our model and the 90% CI accounts for the majority of the data points that we observe in practice. It is usually a good idea to do such posterior predictive checks to see if our model gives valid predictions.
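As a quick, hypothetical way to quantify this check (not part of the original tutorial), we can compute the empirical coverage of the 90% interval, which should be roughly 0.9 for a well-calibrated model:

# Fraction of observed log-GDP values inside the 90% posterior predictive CI.
y_stats = pred_summary["obs"]
inside = ((y_data >= y_stats["5%"]) & (y_data <= y_stats["95%"])).float().mean()
print("Empirical coverage of the 90%% CI: %.2f" % inside.item())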
Finally, let us revisit our earlier question of how robust the relationship between terrain ruggedness and GDP is against any uncertainty in the parameter estimates from our model. For this, we plot the distribution of the slope of the log GDP given terrain ruggedness for nations within and outside Africa. As can be seen below, the probability mass for African nations is largely concentrated in the positive region and vice-versa for other nations, lending further credence to the original hypothesis.
[ ]:
weight = samples["linear.weight"]
weight = weight.reshape(weight.shape[0], 3)
gamma_within_africa = weight[:, 1] + weight[:, 2]
gamma_outside_africa = weight[:, 1]
fig = plt.figure(figsize=(10, 6))
sns.distplot(gamma_within_africa, kde_kws={"label": "African nations"},)
sns.distplot(gamma_outside_africa, kde_kws={"label": "Non-African nations"})
fig.suptitle("Density of Slope : log(GDP) vs. Terrain Ruggedness");
Model Serving via TorchScript¶
Finally, note that the model, guide and the Predictive utility class are all torch.nn.Module instances, and can be serialized as TorchScript.

Here, we show how we can serve a Pyro model as a torch.jit.ScriptModule, which can be run separately as a C++ program without a Python runtime.

To do so, we will rewrite our own simple version of the Predictive utility class using Pyro's effect handling library. This uses:
- the trace poutine to capture the execution trace from running the model/guide code.
- the replay poutine to condition the sites in the model to values sampled from the guide trace.
[ ]:
from collections import defaultdict
from pyro import poutine
from pyro.poutine.util import prune_subsample_sites
import warnings
class Predict(torch.nn.Module):
    def __init__(self, model, guide):
        super().__init__()
        self.model = model
        self.guide = guide

    def forward(self, *args, **kwargs):
        samples = {}
        guide_trace = poutine.trace(self.guide).get_trace(*args, **kwargs)
        model_trace = poutine.trace(poutine.replay(self.model, guide_trace)).get_trace(*args, **kwargs)
        for site in prune_subsample_sites(model_trace).stochastic_nodes:
            samples[site] = model_trace.nodes[site]['value']
        return tuple(v for _, v in sorted(samples.items()))

predict_fn = Predict(model, guide)
predict_module = torch.jit.trace_module(predict_fn, {"forward": (x_data,)}, check_trace=False)
We use torch.jit.trace_module to trace the forward method of this module and save it using torch.jit.save. This saved model reg_predict.pt can be loaded with PyTorch's C++ API using torch::jit::load(filename), or using the Python API as we do below.
[ ]:
torch.jit.save(predict_module, '/tmp/reg_predict.pt')
pred_loaded = torch.jit.load('/tmp/reg_predict.pt')
pred_loaded(x_data)
Let us check that our Predict module was indeed serialized correctly, by generating samples from the loaded module and regenerating the previous plot.
[ ]:
weight = []
for _ in range(800):
    # index = 1 corresponds to "linear.weight"
    weight.append(pred_loaded(x_data)[1])
weight = torch.stack(weight).detach()
weight = weight.reshape(weight.shape[0], 3)
gamma_within_africa = weight[:, 1] + weight[:, 2]
gamma_outside_africa = weight[:, 1]
fig = plt.figure(figsize=(10, 6))
sns.distplot(gamma_within_africa, kde_kws={"label": "African nations"},)
sns.distplot(gamma_outside_africa, kde_kws={"label": "Non-African nations"})
fig.suptitle("Loaded TorchScript Module : log(GDP) vs. Terrain Ruggedness");
In the next tutorial (Part II), we'll look at how to write guides for variational inference, as well as compare the results with inference via HMC.
References¶
[1] McElreath, R., Statistical Rethinking, Chapter 7, 2016.
[2] Nunn, N. & Puga, D., "Ruggedness: The Blessing of Bad Geography in Africa", Review of Economics and Statistics 94(1), Feb. 2012.