{ "cells": [ { "cell_type": "markdown", "metadata": { "Collapsed": "false", "toc-hr-collapsed": false }, "source": [ "# SVI Part II: 条件独立性,子采样和 Amortization\n", "\n", "\n", "为了对大数据使用变分推断,我们需要使用条件独立, 子采样和 Amortization等技术。" ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "> **The Goal: Scaling SVI to Large Datasets**\n", "\n", "一般情况下 SVI 过程中每次更新的计算复杂度是正比于样本数,所以我们需要是用 mini-batch 的办法减少复杂度。对于具备 $N$ 个样本的模型而言, running the `model` and `guide` and constructing the ELBO involves evaluating log pdf's 的计算复杂度随着正比于样本数 $N$。 这种情况在样本很多时会是一个问题, 幸运的是,目标函数 ELBO 天然的支持子采样 provided that 我们的 `model/guide` 具有一些我们可以利用的条件独立性结构. 例如, 在 observations 在给定潜变量下条件独立时, ELBO 目标函数 $\\mathbb{E}_{q_{\\phi}({\\bf z})} \\left [ \\log p_{\\theta}({\\bf x}, {\\bf z}) - \\log q_{\\phi}({\\bf z})\n", "\\right]$ 中的相应对数似然有如下近似: \n", "\n", "$$ E[\\log p({\\bf x}| {\\bf z})] \\approx \\frac{1}{N}\\sum_{i=1}^N \\log p({\\bf x}_i | {\\bf z}) \\approx \\frac{1}{M}\n", "\\sum_{i\\in{\\mathcal{I}_M}} \\log p({\\bf x}_i | {\\bf z}) $$\n", "\n", "where $\\mathcal{I}_M$ is a mini-batch of indices of size $M$ with $M `Pyro.plate`: 从 Sequential `plate` 到 Vectorized `plate` \n", "\n", "让我们回到 [previous tutorial](svi_part_i.ipynb) 中使用的例子。为了方便起见,让我们在这里回顾 `model` 的主要逻辑:\n", "\n", "```python\n", "def model(data):\n", " # sample f from the beta prior\n", " f = pyro.sample(\"latent_fairness\", dist.Beta(alpha0, beta0))\n", " # loop over the observed data using pyro.sample with the obs keyword argument\n", " for i in range(len(data)):\n", " pyro.sample(\"obs_{}\".format(i), dist.Bernoulli(f), obs=data[i])\n", "``` " ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "\n", "对于该模型,给定潜变量 `latent_fairness` ,观测样本是条件独立的. 在 Pyro 中声明这种独立性的方法基本上就是使用 Pyro 的 `plate` 替代 Python 内置函数 `range`。 \n", "\n", "```python\n", "# 我们通过 plate 来声明给定潜变量,观测样本之间的条件独立性。\n", "def model(data):\n", " # sample f from the beta prior\n", " f = pyro.sample(\"latent_fairness\", dist.Beta(alpha0, beta0))\n", " # loop over the observed data [WE ONLY CHANGE THE NEXT LINE]\n", " for i in pyro.plate(\"data_loop\", len(data)): \n", " pyro.sample(\"obs_{}\".format(i), dist.Bernoulli(f), obs=data[i])\n", "```\n", "\n", "我们看到 `pyro.plate` 与 `range` 非常相似,但有一个关键区别:每次 `plate` 的调用需要额外指定 **a unique name.**" ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "到目前为止,一切都很好。Pyro现在可以利用给定潜在随机变量下观测值的条件独立性。 But how does this actually work? 基本上,`pyro.plate`是使用上下文管理器(context manager)实现的。 At every execution of the body of the `for` loop we enter a new (conditional) independence context which is then exited at the end of the `for` loop body. 换句话说就是:\n", "\n", "- because each observed `pyro.sample` statement occurs within a different execution of the body of the `for` loop, Pyro marks 每个观测都是独立的。\n", "- 并且这种独立性准确来说是条件独立性 _given_ `latent_fairness` because `latent_fairness` is sampled _outside_ of the context of `data_loop`." ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "在继续之前,让我们提一下使用 sequential `plate` 时要避免的一些陷阱。考虑上述代码片段的以下变体:\n", "\n", "```python\n", "# WARNING 不要这样做!\n", "my_reified_list = list(pyro.plate(\"data_loop\", len(data)))\n", "for i in my_reified_list: \n", " pyro.sample(\"obs_{}\".format(i), dist.Bernoulli(f), obs=data[i])\n", "```\n", "\n", "这将无法得到想要的效果, since `list()` will enter and exit the `data_loop` context completely before a single `pyro.sample` statement is called. 
, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "### Vectorized `plate`\n", "\n", "Conceptually vectorized `plate` is the same as sequential `plate` except that it is a vectorized operation (as `torch.arange` is to `range`). As such it has the potential to yield large speed-ups compared to the explicit `for` loop that appears with sequential `plate`. Let's see how this looks for our running example. First we need `data` in the form of a tensor:\n", "\n", "```python\n", "data = torch.zeros(10)\n", "data[0:6] = torch.ones(6)  # 6 heads and 4 tails\n", "```\n", "\n", "Then we mark the conditional independence as follows:\n", "\n", "```python\n", "# the vectorized plate can help speed up the downstream computations\n", "with plate('observe_data'):\n", "    pyro.sample('obs', dist.Bernoulli(f), obs=data)\n", "```" ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "Let's compare this to the analogous sequential `plate` usage point-by-point:\n", "\n", "- both patterns require the user to specify a unique name for the `plate`;\n", "- note that this code snippet only introduces a single observed random variable (namely `obs`), since the entire tensor is considered at once;\n", "- since there is no need for an iterator in this case, there is no need to specify the length of the tensor(s) involved in the `plate` context.\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "Hint for the section below:\n", "\n", "```python\n", "for i in pyro.plate(\"data_loop\", len(data)): \n", "    pyro.sample(\"obs_{}\".format(i), dist.Bernoulli(f), obs=data[i])\n", "\n", "for i in pyro.plate(\"data_loop\", len(data), subsample_size=5):\n", "    pyro.sample(\"obs_{}\".format(i), dist.Bernoulli(f), obs=data[i]) \n", "    \n", "with plate('observe_data'):\n", "    pyro.sample('obs', dist.Bernoulli(f), obs=data) \n", "    \n", "with plate('observe_data', size=10, subsample_size=5) as ind:\n", "    pyro.sample('obs', dist.Bernoulli(f), obs=data.index_select(0, ind)) \n", "    \n", "``` " ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false", "toc-hr-collapsed": true }, "source": [ "## Subsampling\n", "\n", "For large datasets each training step can only use a mini-batch of samples; this is subsampling.\n", "\n", "We now know how to mark conditional independencies in Pyro. This is useful in its own right (see the [dependency tracking section](svi_part_iii.ipynb) in SVI Part III), but we would also like to do subsampling so that we can do SVI on large datasets. Depending on the structure of the `model` and `guide`, Pyro supports several ways of doing subsampling. Let's go through these one by one." ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "### Automatic subsampling with `plate`\n", "\n", "\n", "Let's look at the simplest case first, in which we get subsampling for free with one or two additional arguments to `plate`:\n", "\n", "```python\n", "for i in pyro.plate(\"data_loop\", len(data), subsample_size=5):\n", "    pyro.sample(\"obs_{}\".format(i), dist.Bernoulli(f), obs=data[i])\n", "``` " ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "That's all there is to it: we just use the argument `subsample_size`. Whenever we run `model()` we now only evaluate the log-likelihood of 5 randomly chosen datapoints in `data`; in addition, the log-likelihood will be automatically scaled by the appropriate factor of $\\tfrac{10}{5} = 2$. What about vectorized `plate`? The usage is entirely analogous:\n", "\n", "```python\n", "with plate('observe_data', size=10, subsample_size=5) as ind:\n", "    pyro.sample('obs', dist.Bernoulli(f), obs=data.index_select(0, ind))\n", "```" ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "Importantly, `plate` now returns a tensor of indices `ind`, which in this case will be of length 5. Note that in addition to the argument `subsample_size` we also pass the argument `size`, so that `plate` knows the full size of the tensor `data` and can compute the correct scaling factor. Just like for sequential `plate`, the user is responsible for selecting the correct datapoints using the indices provided by `plate`.\n", "\n", "Finally, note that if the data are on the GPU, the user must pass the `device` argument to `plate`." ] }
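, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "As a quick sanity check on this scaling, we can inspect the scale factor that `plate` attaches to the observation site. The following is a minimal sketch, with the fairness `f` fixed at `0.5` just so that the snippet is self-contained; it relies on the `scale` entry that Pyro records for each sample site in a trace:\n", "\n", "```python\n", "import torch\n", "import pyro\n", "import pyro.distributions as dist\n", "import pyro.poutine as poutine\n", "\n", "data = torch.zeros(10)\n", "data[0:6] = torch.ones(6)\n", "\n", "def model(data):\n", "    f = torch.tensor(0.5)  # fixed fairness, just for this check\n", "    with pyro.plate('observe_data', size=10, subsample_size=5) as ind:\n", "        pyro.sample('obs', dist.Bernoulli(f), obs=data.index_select(0, ind))\n", "\n", "trace = poutine.trace(model).get_trace(data)\n", "print(trace.nodes['obs']['scale'])  # 2.0, i.e. 10 / 5\n", "```" ] }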
\n", "\n", "最后, 请注意,如果数据在GPU上,则用户必须将 `device` 参数传递给 `plate`。" ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false", "toc-hr-collapsed": true }, "source": [ "### 自定义子采样 with `plate`\n", "\n", "每次 `model()` 运行的时候,`plate` 都会进行新的子采样。由于这种子采样是 stateless,因此可能会导致一些问题:对于足够大的数据集,即使经过大量的迭代,也存在不可忽略的可能性,即从未选择某些数据点。为了避免这种情况,用户可以通过 `plate` 的参数 `subsample` 来控制子采样的过程。 See [the docs](http://docs.pyro.ai/en/dev/primitives.html#pyro.plate) for details.\n", "\n", " 思考:观测数据中有缺失值的时候,模型分布中的不同变量会有数量不同的观测值,此时子采样如何进行?同一个模型分布中,不同的 `plate` 可以有不同的子采样大小吗? \n", "\n" ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false", "toc-hr-collapsed": true }, "source": [ "### 不同分布下子采样" ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "#### 仅有局部变量时的子采样\n", "\n", "\n", "\n", "我们考虑具备如下联合概率密度,也就是只有局部随机变量的 `model`:\n", "\n", "$$ p({\\bf x}, {\\bf z}) = \\prod_{i=1}^N p({\\bf x}_i | {\\bf z}_i) p({\\bf z}_i) $$\n", "\n", "For a model with this dependency structure the scale factor introduced by subsampling scales all the terms in the ELBO by the same amount. 例如,vanilla VAE 就是这种情况。 这就解释了为什么对于VAE,用户可以完全控制子采样并将 mini-batches 直接传递给 `model` 和 `guide`; `plate` is still used, but `subsample_size` and `subsample` are not. To see how this looks in detail, see the [VAE tutorial](vae.ipynb)." ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "#### 同时存在局部和全局变量时的子采样\n", "\n", "在掷硬币的例子中,因为唯一要抽样的是观测变量, 所以 `plate` 只出现在 `model` 而没有出现在 `guide` 中。让我们看一个更复杂的例子,也就是 `plate` 出现在 `model` 而没有出现在 `guide`中. To make things simple let's keep the discussion somewhat abstract and avoid writing a complete model and guide. " ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "考虑一个具备如下分布的 `model`:\n", "\n", "$$ p({\\bf x}, {\\bf z}, \\beta) = p(\\beta) \n", "\\prod_{i=1}^N p({\\bf x}_i | {\\bf z}_i) p({\\bf z}_i | \\beta) $$\n", "\n", "这里有 $N$ 个观测变量 $\\{ {\\bf x}_i \\}$ 和 $N$ 个局部潜变量 \n", "$\\{ {\\bf z}_i \\}$,还有一个全局潜变量 $\\beta$。 我们的 `gude` 将被分解为\n", "\n", "$$ q({\\bf z}, \\beta) = q(\\beta) \\prod_{i=1}^N q({\\bf z}_i | \\beta, \\lambda_i) $$" ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "这里我们显式的引入来 $N$ 个局部变分参数 $\\{\\lambda_i \\}$, while the other variational parameters are left implicit. 模型分布都指导分布都具备条件独立性结构,具体来说就是:\n", "\n", "- 在模型分布中, 给定局部变量 $\\{ {\\bf z}_i \\}$ 观测变量 $\\{ {\\bf x}_i \\}$ 是条件独立的. 另外,给定 $\\beta$ 潜变量 $\\{\\bf {z}_i \\}$ 是条件独立的. \n", "- 在指导分布中, 给定局部变分参数 $\\{\\lambda_i \\}$ 和全部变量 $\\beta$ 潜变量 $\\{\\bf {z}_i \\}$ 是条件独立的. " ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "为了在 Pyro 中标记这些条件独立性和进行子采样, 我们需要在 `model` and `guide` 中都使用 `plate`. Let's sketch out the basic logic using sequential `plate` (a more complete piece of code would include `pyro.param` statements, etc.). 首先定义模型分布:\n", "\n", "```python\n", "def model(data):\n", " beta = pyro.sample(\"beta\", ...) # sample the global RV\n", " for i in pyro.plate(\"locals\", len(data)):\n", " z_i = pyro.sample(\"z_{}\".format(i), ...)\n", " # compute the parameter used to define the observation \n", " # likelihood using the local random variable\n", " theta_i = compute_something(z_i) \n", " pyro.sample(\"obs_{}\".format(i), dist.MyDist(theta_i), obs=data[i])\n", "```" ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "对比前面掷硬币的例子,这里 `pyro.sample` 同时出现在 `plate` loop 的里面和外面. 接下来是指导分布:\n", "\n", "```python\n", "def guide(data):\n", " beta = pyro.sample(\"beta\", ...) 
, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "### More about `plate`\n", "\n", "**Tensor shapes and vectorized `plate`**\n", "\n", "The usage of `pyro.plate` in this tutorial was limited to relatively simple cases. For example, none of the `plate`s were nested inside of other `plate`s. In order to make full use of `plate`, the user must be careful to use Pyro's tensor shape semantics. For a discussion see the [tensor shapes tutorial](tensor_shapes.ipynb)." ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "## Amortization\n", "\n", "**Where the variational autoencoder (VAE) comes from**\n", "\n", "Let's again consider a `model()` with global and local latent random variables, together with a `guide()` with local variational parameters:\n", "\n", "$$ p({\\bf x}, {\\bf z}, \\beta) = p(\\beta) \n", "\\prod_{i=1}^N p({\\bf x}_i | {\\bf z}_i) p({\\bf z}_i | \\beta) \\qquad \\qquad\n", "q({\\bf z}, \\beta) = q(\\beta) \\prod_{i=1}^N q({\\bf z}_i | \\beta, \\lambda_i) $$" ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "For small to medium-sized $N$, using local variational parameters like this can be a good approach. If $N$ is large, however, the fact that the space we're doing optimization over grows with $N$ can be a real problem. One way to avoid this growth with the size of the dataset is *amortization*.\n", "\n", "This works as follows. Instead of introducing local variational parameters $\\lambda_i$, we learn a single parametric function $f(\\cdot)$ and work with a variational distribution of the form\n", "\n", "$$q(\\beta) \\prod_{i=1}^N q({\\bf z}_i | \\beta, f({\\bf x}_i))$$" ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "\n", "The function $f(\\cdot)$, which maps a given observation to a set of local variational parameters tailored to that datapoint, needs to be rich enough to approximate the posterior accurately, but it allows us to handle large datasets without having to introduce a separate variational parameter for each datapoint.\n", "This approach has other benefits too: for example, during learning $f(\\cdot)$ effectively allows us to share statistical power among different datapoints. This is precisely the approach used in the [VAE](vae.ipynb). A minimal sketch of an amortized guide follows." ] }
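, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "Concretely, amortization replaces the per-datapoint `pyro.param` statements of the previous sketch with a single neural network that maps each observation to its local variational parameters. The two-layer network below and the `Normal` families are illustrative assumptions, not part of the abstract model above:\n", "\n", "```python\n", "import torch\n", "import torch.nn as nn\n", "import pyro\n", "import pyro.distributions as dist\n", "\n", "# f(.): maps a datapoint to its local variational parameters (loc, log_scale)\n", "encoder = nn.Sequential(nn.Linear(1, 32), nn.Softplus(), nn.Linear(32, 2))\n", "\n", "def guide(data):\n", "    pyro.module(\"encoder\", encoder)  # register the network's parameters with Pyro\n", "    beta_loc = pyro.param(\"beta_loc\", torch.tensor(0.))\n", "    beta = pyro.sample(\"beta\", dist.Normal(beta_loc, 1.))\n", "    with pyro.plate(\"locals\", len(data)):\n", "        out = encoder(data.unsqueeze(-1))  # lambda_i = f(x_i)\n", "        loc, log_scale = out[:, 0], out[:, 1]\n", "        pyro.sample(\"z\", dist.Normal(beta + loc, log_scale.exp()))\n", "```\n", "\n", "The number of parameters being optimized is now fixed by the size of the network rather than growing with $N$." ] }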
\n", " \n", " \n", " \n", " \n", " \n", "
\n", " \n", " \n", " \n", "
\n", "
\n", " Figure VAE: (Left) 模型分布\n", " (Right) 指导分布\n", "
\n", "
" ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "### 完整代码" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "Collapsed": "false" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[epoch 000] average training loss: 190.9630\n", "[epoch 000] average test loss: 155.7649\n", "[epoch 001] average training loss: 146.2289\n", "[epoch 002] average training loss: 132.9159\n", "[epoch 003] average training loss: 124.6936\n", "[epoch 004] average training loss: 119.5353\n" ] } ], "source": [ "import os, torch, pyro\n", "import numpy as np\n", "import torchvision.datasets as dset\n", "import torch.nn as nn\n", "import torchvision.transforms as transforms\n", "import pyro.distributions as dist\n", "import pyro.contrib.examples.util # patches torchvision\n", "from pyro.infer import SVI, Trace_ELBO\n", "from pyro.optim import Adam\n", "assert pyro.__version__.startswith('1.3.0')\n", "pyro.enable_validation(True)\n", "pyro.distributions.enable_validation(False)\n", "pyro.set_rng_seed(0)\n", "\n", "class Decoder(nn.Module): # 用于构建模型分布的 decoder\n", " def __init__(self, z_dim, hidden_dim):\n", " super().__init__()\n", " self.fc1 = nn.Linear(z_dim, hidden_dim)\n", " self.fc21 = nn.Linear(hidden_dim, 784)\n", " self.softplus = nn.Softplus()\n", " self.sigmoid = nn.Sigmoid()\n", "\n", " def forward(self, z):\n", " hidden = self.softplus(self.fc1(z))\n", " loc_img = self.sigmoid(self.fc21(hidden))\n", " return loc_img\n", "\n", "class Encoder(nn.Module): # 用于构建指导分布的 encoder\n", " def __init__(self, z_dim, hidden_dim):\n", " super().__init__()\n", " self.fc1 = nn.Linear(784, hidden_dim)\n", " self.fc21 = nn.Linear(hidden_dim, z_dim)\n", " self.fc22 = nn.Linear(hidden_dim, z_dim)\n", " self.softplus = nn.Softplus()\n", "\n", " def forward(self, x):\n", " x = x.reshape(-1, 784)\n", " hidden = self.softplus(self.fc1(x))\n", " z_loc = self.fc21(hidden)\n", " z_scale = torch.exp(self.fc22(hidden))\n", " return z_loc, z_scale\n", "\n", "class VAE(nn.Module):\n", " def __init__(self, z_dim=50, hidden_dim=400, use_cuda=False):\n", " super().__init__()\n", " self.encoder = Encoder(z_dim, hidden_dim)\n", " self.decoder = Decoder(z_dim, hidden_dim)\n", " if use_cuda:\n", " self.cuda()\n", " self.use_cuda = use_cuda\n", " self.z_dim = z_dim\n", "\n", " def model(self, x): # 模型分布 p(x|z)p(z)\n", " pyro.module(\"decoder\", self.decoder)\n", " with pyro.plate(\"data\", x.shape[0]):\n", " z_loc = x.new_zeros(torch.Size((x.shape[0], self.z_dim)))\n", " z_scale = x.new_ones(torch.Size((x.shape[0], self.z_dim)))\n", " z = pyro.sample(\"latent\", dist.Normal(z_loc, z_scale).to_event(1))\n", " loc_img = self.decoder.forward(z)\n", " pyro.sample(\"obs\", dist.Bernoulli(loc_img).to_event(1), obs=x.reshape(-1, 784))\n", "\n", " def guide(self, x): # 指导分布 q(z|x)\n", " pyro.module(\"encoder\", self.encoder)\n", " with pyro.plate(\"data\", x.shape[0]):\n", " z_loc, z_scale = self.encoder.forward(x)\n", " pyro.sample(\"latent\", dist.Normal(z_loc, z_scale).to_event(1))\n", "\n", " def reconstruct_img(self, x):\n", " z_loc, z_scale = self.encoder(x) \n", " z = dist.Normal(z_loc, z_scale).sample()\n", " loc_img = self.decoder(z) # 注意在图像空间中我们没有抽样\n", " return loc_img\n", "\n", "def setup_data_loaders(batch_size=128, use_cuda=False):\n", " root = './data'\n", " download = False\n", " trans = transforms.ToTensor()\n", " train_set = dset.MNIST(root=root, train=True, transform=trans,\n", " download=download)\n", " test_set = dset.MNIST(root=root, train=False, 
transform=trans)\n", "    kwargs = {'num_workers': 1, 'pin_memory': use_cuda}\n", "    train_loader = torch.utils.data.DataLoader(dataset=train_set,\n", "        batch_size=batch_size, shuffle=True, **kwargs)\n", "    test_loader = torch.utils.data.DataLoader(dataset=test_set,\n", "        batch_size=batch_size, shuffle=False, **kwargs)\n", "    return train_loader, test_loader\n", "\n", "def train(svi, train_loader, use_cuda=False):\n", "    epoch_loss = 0.\n", "    for x, _ in train_loader:\n", "        if use_cuda:\n", "            x = x.cuda()\n", "        epoch_loss += svi.step(x)\n", "    normalizer_train = len(train_loader.dataset)\n", "    total_epoch_loss_train = epoch_loss / normalizer_train\n", "    return total_epoch_loss_train\n", "\n", "def evaluate(svi, test_loader, use_cuda=False):\n", "    test_loss = 0.\n", "    for x, _ in test_loader:\n", "        if use_cuda:\n", "            x = x.cuda()\n", "        test_loss += svi.evaluate_loss(x)\n", "    normalizer_test = len(test_loader.dataset)\n", "    total_epoch_loss_test = test_loss / normalizer_test\n", "    return total_epoch_loss_test\n", "\n", "# train the model\n", "LEARNING_RATE = 1.0e-3\n", "USE_CUDA = False\n", "NUM_EPOCHS = 5\n", "TEST_FREQUENCY = 5\n", "train_loader, test_loader = setup_data_loaders(batch_size=256, use_cuda=USE_CUDA)\n", "pyro.clear_param_store()\n", "vae = VAE(use_cuda=USE_CUDA)\n", "adam_args = {\"lr\": LEARNING_RATE}\n", "optimizer = Adam(adam_args)\n", "svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())\n", "\n", "train_elbo = []\n", "test_elbo = []\n", "for epoch in range(NUM_EPOCHS):\n", "    total_epoch_loss_train = train(svi, train_loader, use_cuda=USE_CUDA)\n", "    train_elbo.append(-total_epoch_loss_train)\n", "    print(\"[epoch %03d] average training loss: %.4f\" % (epoch, total_epoch_loss_train))\n", "    if epoch % TEST_FREQUENCY == 0:\n", "        # report test diagnostics\n", "        total_epoch_loss_test = evaluate(svi, test_loader, use_cuda=USE_CUDA)\n", "        test_elbo.append(-total_epoch_loss_test)\n", "        print(\"[epoch %03d] average test loss: %.4f\" % (epoch, total_epoch_loss_test))" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "Collapsed": "false" }, "outputs": [ { "data": { "text/plain": [ "torch.Size([256, 784])" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "vae.reconstruct_img(x).shape" ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "### Conditional independence, subsampling and amortization in the VAE" ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "**Conditional independence**: the `plate` context is used to declare that the datapoints are conditionally independent given the latent code." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "Collapsed": "false" }, "outputs": [ { "data": { "text/plain": [ "    \u001b[0;32mdef\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# the model distribution p(x|z)p(z)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m        \u001b[0mpyro\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"decoder\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdecoder\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m        \u001b[0;32mwith\u001b[0m \u001b[0mpyro\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"data\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", 
"\u001b[0;34m\u001b[0m \u001b[0mz_loc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew_zeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mz_dim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mz_scale\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew_ones\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mz_dim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mz\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpyro\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"latent\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdist\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mNormal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz_loc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mz_scale\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_event\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mloc_img\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdecoder\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mpyro\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"obs\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdist\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mBernoulli\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloc_img\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_event\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mobs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m784\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%psource vae.model\n", "# 条件独立" ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "You’ll need to ensure that batch_shape is carefully controlled by either trimming it down with `.to_event(n)` or by declaring dimensions as independent via `pyro.plate`." ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ " 子采样了吗? 
\n", "\n", "没有进行子采样,该程序通过 `setup_data_loaders` 控制了每次输入的样本数是 256。(上文提到“对于VAE,用户可以完全控制子采样并将 mini-batches 直接传递给 `model` 和 `guide`; `plate` is still used, but `subsample_size` and `subsample` are not.”)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "Collapsed": "false" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Input shape is torch.Size([256, 1, 28, 28])\n", "encoder shape is torch.Size([256, 50])\n", "decoder shape is torch.Size([256, 784])\n" ] } ], "source": [ "train_loader, test_loader = setup_data_loaders(batch_size=256, use_cuda=USE_CUDA)\n", "for x, _ in train_loader:\n", " break\n", "print('Input shape is ', x.shape)\n", "z = vae.encoder(x)[0]\n", "print('encoder shape is ', z.shape)\n", "x_ = vae.decoder.forward(z)\n", "print('decoder shape is ', x_.shape)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "Collapsed": "false" }, "outputs": [ { "data": { "text/plain": [ " \u001b[0;32mdef\u001b[0m \u001b[0mguide\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# 指导分布 q(z|x)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mpyro\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"encoder\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mencoder\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mpyro\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"data\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mz_loc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mz_scale\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mencoder\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mpyro\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"latent\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdist\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mNormal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz_loc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mz_scale\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_event\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%psource vae.guide\n", "# 可以看出这里没有进行子采样" ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ " Amortization \n", "\n", "VAE 是只有局部随机变量的 `model`:\n", "\n", "$$ \n", "p({\\bf x}, {\\bf z}) = \n", "\\prod_{i=1}^N p_\\theta({\\bf x}_i | {\\bf z}_i) p({\\bf z}_i) \\qquad \\qquad\n", "q({\\bf z}) = \\prod_{i=1}^N q({\\bf z}_i |\\lambda_i) \n", "$$\n", "\n", "where amortization $\\lambda_i = (\\mu_\\phi(x_i), \\Sigma_\\phi(x_i))$" ] }, { "cell_type": "raw", "metadata": { "Collapsed": "false" }, "source": [ "
\n", " \n", " \n", " \n", " \n", " \n", "
\n", " \n", " \n", " \n", "
\n", "
\n", " Figure VAE: (Left) 模型分布\n", " (Right) 指导分布\n", "
\n", "
" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "Collapsed": "false" }, "outputs": [ { "data": { "text/plain": [ " \u001b[0;32mdef\u001b[0m \u001b[0mguide\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# 指导分布 q(z|x)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mpyro\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"encoder\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mencoder\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mpyro\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"data\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mz_loc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mz_scale\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mencoder\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mpyro\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"latent\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdist\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mNormal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz_loc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mz_scale\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_event\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%psource vae.guide\n", "# 使用解码器进行" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "Collapsed": "false" }, "outputs": [ { "data": { "text/plain": [ " \u001b[0;32mdef\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# 模型分布 p(x|z)p(z)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mpyro\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"decoder\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdecoder\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mpyro\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"data\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mz_loc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew_zeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mz_dim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mz_scale\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew_ones\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mz_dim\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mz\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpyro\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"latent\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdist\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mNormal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz_loc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mz_scale\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_event\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mloc_img\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdecoder\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mpyro\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"obs\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdist\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mBernoulli\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloc_img\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_event\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mobs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m784\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%psource vae.model" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "Collapsed": "false" }, "outputs": [ { "data": { "text/plain": [ " \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mz\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mhidden\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msoftplus\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfc1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mloc_img\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msigmoid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfc21\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhidden\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m 
\u001b[0mloc_img\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%psource vae.decoder.forward" ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "## References\n", "\n", "[1] `Stochastic Variational Inference`,\n", "Matthew D. Hoffman, David M. Blei, Chong Wang, John Paisley\n", "\n", "[2] `Auto-Encoding Variational Bayes`,\n", "Diederik P Kingma, Max Welling" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": true } }, "nbformat": 4, "nbformat_minor": 4 }