{ "cells": [ { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "# Kalman Filter\n", "\n", "Kalman filters are linear models for state estimation of dynamic systems [1]. They have been the de facto standard in many robotics and tracking/prediction applications because they are well suited for systems with uncertainty about an observable dynamic process. They use an \"observe, predict, correct\" paradigm to extract information from an otherwise noisy signal. In Pyro, we can build differentiable Kalman filters with learnable parameters using the `pyro.contrib.tracking` [library](http://docs.pyro.ai/en/dev/contrib.tracking.html#module-pyro.contrib.tracking.extended_kalman_filter)." ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "## Dynamic process\n", "\n", "To start, consider this simple motion model:\n", "\n", "$$ X_{k+1} = FX_k + \\mathbf{W}_k $$\n", "$$ \\mathbf{Z}_k = HX_k + \\mathbf{V}_k $$\n", "\n", "where $k$ is the time step, $X_k$ is the state we want to estimate, $\\mathbf{Z}_k$ is the observed value at time step $k$, $F$ is the state-transition matrix, $H$ maps the state to the observation, and $\\mathbf{W}_k$ and $\\mathbf{V}_k$ are independent noise processes (i.e., $\\mathbb{E}[\\mathbf{W}_k \\mathbf{V}_j^T] = 0$ for all $j, k$), which we'll approximate as Gaussians. Note that the state transitions are linear." 
] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "## Kalman Update\n", "At each time step, we perform a prediction for the mean and covariance:\n", "$$ \\hat{X}_k = FX_{k-1}$$\n", "$$\\hat{P}_k = FP_{k-1}F^T + Q$$\n", "\n", "and a correction for the measurement:\n", "\n", "$$ K_k = \\hat{P}_k H^T(H\\hat{P}_k H^T + R)^{-1}$$\n", "$$ X_k = \\hat{X}_k + K_k(z_k - H\\hat{X}_k)$$\n", "$$ P_k = (I-K_k H)\\hat{P}_k$$\n", "\n", "where $X_k$ is the state estimate, $P_k$ is its covariance matrix, hats denote the predicted (pre-measurement) quantities, $K_k$ is the Kalman gain, $z_k$ is the measurement at time $k$, and $Q$ and $R$ are the covariance matrices of the process noise $\\mathbf{W}_k$ and the measurement noise $\\mathbf{V}_k$.\n", "\n", "For an in-depth derivation, see \\[2\\]." ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "## Nonlinear Estimation: Extended Kalman Filter\n", "\n", "What if our system is non-linear, e.g., in GPS navigation? Consider the following non-linear system:\n", "\n", "$$ X_{k+1} = \\mathbf{f}(X_k) + \\mathbf{W}_k $$\n", "$$ \\mathbf{Z}_k = \\mathbf{h}(X_k) + \\mathbf{V}_k $$\n", "\n", "Notice that $\\mathbf{f}$ and $\\mathbf{h}$ are now (smooth) non-linear functions.\n" ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "The Extended Kalman Filter (EKF) attacks this problem by locally linearizing these functions around the current estimate with a first-order [Taylor series expansion](https://en.wikipedia.org/wiki/Taylor_series):\n", "\n", "$$ f(X_k, k) \\approx f(x_k^R, k) + \\mathbf{H}_k(X_k - x_k^R) + \\cdots$$\n", "\n", "where $\\mathbf{H}_k$ is the Jacobian matrix of $f$ at time $k$ evaluated at $x_k^R$, $x_k^R$ is the previous optimal estimate, and we ignore the higher-order terms. 
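
To make the linearization concrete, here is a minimal one-dimensional sketch in plain Python of a single EKF predict/update step. It is an illustration under stated assumptions, not Pyro's implementation: the state is a scalar, the derivatives of `f` and `h` (the 1x1 Jacobians) are approximated with a central finite difference, and the helper names `numerical_derivative` and `ekf_step` are hypothetical. When `f` and `h` are linear, the derivatives are just the constants F and H, so the step reduces to the linear Kalman update given earlier.

```python
import math


def numerical_derivative(fn, x, eps=1e-6):
    # Central finite difference; in 1-D this is the 1x1 Jacobian of fn at x
    return (fn(x + eps) - fn(x - eps)) / (2.0 * eps)


def ekf_step(x, p, z, f, h, q, r):
    # Hypothetical scalar EKF step: x, p are the previous mean and variance,
    # z is the new measurement, q and r are the process/measurement noise variances.
    # Predict: push the mean through f, and the variance through f'(x)
    f_jac = numerical_derivative(f, x)
    x_pred = f(x)
    p_pred = f_jac * p * f_jac + q
    # Correct: linearize h around the predicted mean
    h_jac = numerical_derivative(h, x_pred)
    s = h_jac * p_pred * h_jac + r  # innovation variance
    k = p_pred * h_jac / s  # Kalman gain
    x_new = x_pred + k * (z - h(x_pred))
    p_new = (1.0 - k * h_jac) * p_pred
    return x_new, p_new


# Example: a mildly nonlinear motion model observed through a nonlinear sensor
x_est, p_est = ekf_step(
    x=0.5, p=1.0, z=0.4,
    f=lambda x: x + 0.1 * math.sin(x),
    h=math.atan,
    q=0.01, r=0.05,
)
```

The finite difference is only a stand-in; the design point is that the nonlinear functions themselves propagate the mean, while their derivatives are used only to propagate the variance and form the gain.
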
At each time step, we compute a Jacobian evaluated at the previous estimate (Pyro handles this computation under the hood) and use it to perform a prediction and an update.\n", "\n", "Omitting the derivations, the modified predictions are now:\n", "$$ \\hat{X}_k \\approx \\mathbf{f}(X_{k-1})$$\n", "$$ \\hat{P}_k = \\mathbf{H}_\\mathbf{f}(X_{k-1})P_{k-1}\\mathbf{H}_\\mathbf{f}^T(X_{k-1}) + Q$$\n", "\n", "and the updates are now:\n", "\n", "$$ K_k = \\hat{P}_k \\mathbf{H}_\\mathbf{h}^T(\\hat{X}_k) \\Big(\\mathbf{H}_\\mathbf{h}(\\hat{X}_k)\\hat{P}_k \\mathbf{H}_\\mathbf{h}^T(\\hat{X}_k) + R\\Big)^{-1} $$\n", "$$ X_k \\approx \\hat{X}_k + K_k\\big(z_k - \\mathbf{h}(\\hat{X}_k)\\big)$$\n", "$$ P_k = \\big(I - K_k \\mathbf{H}_\\mathbf{h}(\\hat{X}_k)\\big)\\hat{P}_k$$\n", "\n", "where $\\mathbf{H}_\\mathbf{f}$ and $\\mathbf{H}_\\mathbf{h}$ are the Jacobians of $\\mathbf{f}$ and $\\mathbf{h}$ evaluated at the indicated points.\n", "\n", "In Pyro, all we need to do is create an `EKFState` object and use its `predict` and `update` methods. Pyro will do exact inference to compute the innovations, and we will use SVI to learn MAP estimates of the position and measurement covariances.\n", "\n", "As an example, let's look at an object moving at near-constant velocity in 2-D, in discrete time, over 10 time steps." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "Collapsed": "false" }, "outputs": [], "source": [ "import os\n", "import math\n", "\n", "import torch\n", "import pyro\n", "import pyro.distributions as dist\n", "from pyro.infer.autoguide import AutoDelta\n", "from pyro.optim import Adam\n", "from pyro.infer import SVI, Trace_ELBO, config_enumerate\n", "from pyro.contrib.tracking.extended_kalman_filter import EKFState\n", "from pyro.contrib.tracking.distributions import EKFDistribution\n", "from pyro.contrib.tracking.dynamic_models import NcvContinuous\n", "from pyro.contrib.tracking.measurements import PositionMeasurement\n", "\n", "smoke_test = ('CI' in os.environ)\n", "assert pyro.__version__.startswith('1.3.0')\n", "pyro.enable_validation(True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "Collapsed": "false" }, "outputs": [], "source": [ "dt = 1e-2\n", "num_frames = 10\n", "dim = 4  # state is (x, y, vx, vy)\n", "\n", "# Nearly-constant-velocity (NCV) continuous-time dynamic model;\n", "# 2.0 sets the intensity of the acceleration (process) noise\n", "ncv = NcvContinuous(dim, 2.0)\n", "\n", "# Ground-truth trajectory\n", "xs_truth = torch.zeros(num_frames, dim)\n", "# initial direction of motion (radians)\n", "theta0_truth = 0.0\n", "with torch.no_grad():\n", "    # initial state: at the origin, moving with unit speed along theta0_truth\n", "    xs_truth[0, :] = torch.tensor([0.0, 0.0, math.cos(theta0_truth), math.sin(theta0_truth)])\n", "    for frame_num in range(1, num_frames):\n", "        # sample independent process noise\n", "        dx = pyro.sample('process_noise_{}'.format(frame_num), ncv.process_noise_dist(dt))\n", "        # propagate the previous state through the NCV model and add the noise\n", "        xs_truth[frame_num, :] = ncv(xs_truth[frame_num-1, :], dt=dt) + dx" ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "Next, let's specify the measurements. Notice that we only measure the particle's position, not its velocity." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "Collapsed": "false" }, "outputs": [], "source": [ "# Measurements\n", "measurements = []\n", "mean = torch.zeros(2)\n", "# diagonal covariance: x and y measurement noise are uncorrelated\n", "cov = 1e-5 * torch.eye(2)\n", "with torch.no_grad():\n", "    # sample independent measurement noise for every frame\n", "    dzs = pyro.sample('dzs', dist.MultivariateNormal(mean, cov).expand((num_frames,)))\n", "    # noisy position measurements: true (x, y) positions plus noise\n", "    zs = xs_truth[:, :2] + dzs" ] }, { "cell_type": "markdown", "metadata": { "Collapsed": "false" }, "source": [ "We'll use a [Delta autoguide](http://docs.pyro.ai/en/dev/infer.autoguide.html#autodelta) to learn MAP estimates of the position and measurement covariances. 