
Sampling with Stein Variational Gradient Descent (SVGD)

The SVGD Update Rule

Stein Variational Gradient Descent (Liu & Wang, 2016) is a particle-based sampling algorithm. It starts from samples drawn from an initial reference distribution $q^0$ and iteratively transports them along the SVGD velocity field $\phi$. After enough iterations, the transformed particles form an empirical distribution that closely approximates the target.

Formally, at iteration $l \in \{0, \dots, L-1\}$, the SVGD update rule is:

$$x_i^{l+1} = x_i^{l} + \epsilon \underbrace{ \mathbb{E}_{x_j^l \sim q^l} \left[ \kappa(x_i^l, x_j^l) \nabla_{x_j^l} \log \bar p(x_j^l) + \nabla_{x_j^l} \kappa(x_i^l, x_j^l) \right] }_{\phi(x_i^l)},$$

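where $x_i^l$ is the $i$-th particle at iteration $l$, $q^l$ is the distribution of the particles at iteration $l$, $\epsilon$ is the step size, $\bar p$ is the unnormalized target density, and $\kappa$ is a kernel function, most commonly the RBF kernel. In practice, the expectation is estimated by averaging over the $M$ current particles $\{x_j^l\}_{j=0}^{M-1}$.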
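The update rule translates almost line by line into array code. Below is a minimal NumPy sketch, not the `svgd` package used for the figures: the helper names `svgd_velocity` and `svgd_step` are hypothetical, the expectation over $q^l$ is replaced by the mean over the $M$ current particles, and the kernel is passed in together with its gradient with respect to the second argument. For the usage lines we assume a Gaussian kernel with a fixed bandwidth of 1 and a standard normal target.

```python
import numpy as np


def svgd_velocity(x, score, kernel, grad_kernel_xj):
    """Empirical SVGD velocity phi(x_i^l) for every particle (hypothetical helper).

    x:              (M, D) array holding the particles x_j^l
    score:          callable, (M, D) -> (M, D), gradient of log p_bar
    kernel:         callable on broadcast arrays (..., D), (..., D) -> (...)
    grad_kernel_xj: callable returning the gradient of the kernel with respect
                    to its second argument, shape (..., D)
    """
    xi = x[:, None, :]  # (M, 1, D): the particle being updated
    xj = x[None, :, :]  # (1, M, D): the particles used in the expectation
    k = kernel(xi, xj)  # (M, M), k[i, j] = kappa(x_i, x_j)
    drift = k[:, :, None] * score(x)[None, :, :]  # kappa(x_i, x_j) * score(x_j)
    repulsion = grad_kernel_xj(xi, xj)  # grad_{x_j} kappa(x_i, x_j)
    # Replace the expectation over q^l by the mean over the M particles.
    return (drift + repulsion).mean(axis=1)


def svgd_step(x, score, kernel, grad_kernel_xj, eps):
    """One SVGD iteration: x^{l+1} = x^l + eps * phi(x^l)."""
    return x + eps * svgd_velocity(x, score, kernel, grad_kernel_xj)


# Assumed example: Gaussian kernel with a fixed bandwidth, standard normal target.
SIGMA = 1.0


def gaussian_kernel(a, b):
    return np.exp(-((a - b) ** 2).sum(-1) / (2 * SIGMA**2))


def gaussian_kernel_grad_b(a, b):
    # d/db exp(-||a - b||^2 / (2 sigma^2)) = kappa(a, b) * (a - b) / sigma^2
    return gaussian_kernel(a, b)[..., None] * (a - b) / SIGMA**2


def standard_normal_score(x):
    return -x  # gradient of log p_bar for p_bar(x) = exp(-||x||^2 / 2)


rng = np.random.default_rng(0)
particles = rng.normal(size=(100, 2)) + 3.0  # start away from the mode at 0
for _ in range(200):
    particles = svgd_step(particles, standard_normal_score, gaussian_kernel,
                          gaussian_kernel_grad_b, eps=0.1)
```

Passing the kernel and its gradient as separate callables keeps the sketch generic; any differentiable positive definite kernel can be plugged in without touching `svgd_velocity`.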

Figure 1: Particles’ evolution starting from $q^0$ and following the SVGD velocity field.

Code for Figure 1.
from svgd.sampler import SVGD
from svgd.distributions import TorchDistribution
from svgd.kernels import RBF
from svgd.kernels.parameters import HeuristicKP
from svgd.lrs import ParameterLR
from svgd.callbacks import Logger

import torch
from torch.distributions import MixtureSameFamily, Categorical, Independent, Normal

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

torch.manual_seed(0)

target_distribution = TorchDistribution(
    MixtureSameFamily(
        Categorical(torch.ones(2)),
        Independent(Normal(torch.tensor([[-1.0, 0.0], [1.0, 0.0]]), 1 / 3), 1),
    )
)
initial_distribution = TorchDistribution(Independent(Normal(torch.zeros(2), 1.0), 1))
kernel = RBF(HeuristicKP("median"))
lr = ParameterLR(torch.tensor(0.5))
logger = Logger(log_x=True)
logger.activated = True
svgd = SVGD(
    target_distribution=target_distribution,
    initial_distribution=initial_distribution,
    kernel=kernel,
    lr=lr,
    callbacks=[logger],
)

n_particles = 100
n_steps = 50
x, _, _ = svgd.sample(n_particles=n_particles, n_steps=n_steps)
x = x.detach()

grid = torch.arange(-2, 2, 0.001)
xg, yg = torch.meshgrid(grid, grid, indexing="ij")
grid = torch.cat((xg.reshape(-1)[:, None], yg.reshape(-1)[:, None]), dim=-1)
zg = target_distribution.log_prob(grid).exp().view(xg.shape)

fig, ax = plt.subplots()
ax.pcolormesh(xg, yg, zg, cmap="Oranges")
ax.set_xlim(-2.0, 2.0)
ax.set_ylim(-2.0, 2.0)
ax.axis("off")
scatter = ax.scatter([], [], color="black", alpha=0.6)


def animate(frame):
    scatter.set_offsets(logger.x[frame])
    return (scatter,)


animation = FuncAnimation(
    fig,
    animate,
    len(logger.x),
    interval=1000 / 60,
    blit=False,
)
animation.save("particle_evolution.gif", fps=120)

A key property of SVGD is that the velocity field $\phi$ is chosen as the direction that maximally decreases the KL divergence between the particles’ distribution and the target, within the unit ball of a reproducing kernel Hilbert space. Intuitively, each update moves the particles in the direction that makes their distribution approach the target as quickly as possible.

The RBF Kernel

The RBF kernel is defined as

$$\kappa(x_i^l, x_j^l) = \exp\left( -\frac{1}{2\sigma^2} \|x_i^l - x_j^l\|^2 \right),$$

where $\sigma$ is the kernel bandwidth.

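Since $\nabla_{x_j^l} \kappa(x_i^l, x_j^l) = -\frac{1}{\sigma^2} \kappa(x_i^l, x_j^l) (x_j^l - x_i^l)$, the SVGD update rule with the RBF kernel becomes: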

$$x_i^{l+1} = x_i^{l} + \epsilon\, \mathbb{E}_{x_j^l \sim q^l} \Bigg[ \underbrace{ \kappa(x_i^l, x_j^l) \nabla_{x_j^l} \log \bar p(x_j^l) }_{\text{drift term}} - \underbrace{ \frac{1}{\sigma^2} \kappa(x_i^l, x_j^l) (x_j^l - x_i^l) }_{\text{repulsion term}} \Bigg].$$

In the drift term, the kernel value $\kappa(x_i^l, x_j^l)$ determines how strongly the score $\nabla_{x_j^l} \log \bar p(x_j^l)$ at particle $x_j^l$ influences the update of $x_i^l$. This term moves particles toward regions of high probability. In the repulsion term, the same kernel value controls how strongly $x_i^l$ is pushed away from $x_j^l$: because $(x_j^l - x_i^l)$ enters with a minus sign, the push points in the direction $x_i^l - x_j^l$. This term enforces diversity among the particles, preventing them from collapsing onto a single mode.
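The two terms can be computed separately. The NumPy sketch below uses a fixed bandwidth (an assumption; the helpers `rbf_svgd_terms` and `rbf_svgd_step` are hypothetical and use the particle mean in place of the expectation). It contrasts a single particle, for which the repulsion term is identically zero and SVGD reduces to gradient ascent on $\log \bar p$, with ten particles, which the repulsion keeps apart.

```python
import numpy as np


def rbf_svgd_terms(x, score_x, sigma):
    """Drift and repulsion terms of the RBF SVGD update (hypothetical helper).

    x:       (M, D) particles
    score_x: (M, D) gradient of log p_bar evaluated at the particles
    sigma:   fixed kernel bandwidth
    """
    m = x.shape[0]
    diff = x[None, :, :] - x[:, None, :]  # diff[i, j] = x_j - x_i
    k = np.exp(-(diff**2).sum(-1) / (2 * sigma**2))  # k[i, j] = kappa(x_i, x_j)
    drift = (k @ score_x) / m  # mean_j kappa(x_i, x_j) * score(x_j)
    # mean_j kappa(x_i, x_j) * (x_j - x_i) / sigma^2; it is subtracted in the update
    repulsion = (k[:, :, None] * diff).sum(axis=1) / (m * sigma**2)
    return drift, repulsion


def rbf_svgd_step(x, score, sigma, eps):
    drift, repulsion = rbf_svgd_terms(x, score(x), sigma)
    return x + eps * (drift - repulsion)


def score(x):
    return -x  # assumed standard normal target


one = np.array([[2.0]])                    # a single particle
many = np.linspace(1.5, 2.5, 10)[:, None]  # ten particles
for _ in range(500):
    one = rbf_svgd_step(one, score, sigma=0.5, eps=0.1)
    many = rbf_svgd_step(many, score, sigma=0.5, eps=0.1)

print(one.ravel())              # essentially 0: the particle sits on the mode
print(many.mean(), many.std())  # the ten particles stay spread around the mode
```

With one particle each step is simply $x \leftarrow x + 0.1 \cdot (-x)$, so it converges geometrically to the mode; with several particles the repulsion term balances the drift and the particles settle at distinct positions instead.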

The kernel takes its maximum value of 1 when $x_i^l = x_j^l$ and approaches 0 as $\|x_i^l - x_j^l\|^2 \to \infty$. Intuitively, $\kappa(x_i^l, x_j^l)$ measures the similarity between particles based on their Euclidean distance, with the effective neighborhood size controlled by $\sigma$.
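The NumPy sketch below (the helper name `rbf_kernel` is ours) computes the RBF kernel matrix and makes the neighborhood interpretation concrete: at distances of $\sigma$, $2\sigma$ and $3\sigma$ the kernel equals $e^{-1/2} \approx 0.607$, $e^{-2} \approx 0.135$ and $e^{-9/2} \approx 0.011$ whatever the value of $\sigma$, so particles more than about $3\sigma$ apart barely interact.

```python
import numpy as np


def rbf_kernel(x, y, sigma):
    """RBF kernel matrix K[i, j] = exp(-||x_i - y_j||^2 / (2 sigma^2)) for x: (M, D), y: (N, D)."""
    sq_dist = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
    return np.exp(-sq_dist / (2 * sigma**2))


sigma = 1.5
# 1D points at 0, 1, 2 and 3 bandwidths from the origin.
points = np.array([[0.0], [sigma], [2 * sigma], [3 * sigma]])
k = rbf_kernel(np.zeros((1, 1)), points, sigma)[0]
print(k)  # approximately 1.0, 0.607, 0.135, 0.011 for any sigma

# For a set of particles, K is symmetric with ones on the diagonal (maximum similarity).
rng = np.random.default_rng(0)
x = rng.normal(size=(50, 3))
K = rbf_kernel(x, x, sigma)
```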

Figure 2: The neighborhood of $x_i^l = 0$ as a function of $x_j^l$ in 1D, in terms of the similarity $\kappa(0, x_j^l)$ and the repulsive force $\nabla_{x_j^l} \kappa(0, x_j^l)$, for different values of $\sigma$.

Code for Figure 2.
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

torch.manual_seed(0)

bound = 10.0
x = torch.arange(-bound, bound, 0.1).unsqueeze(-1)

sigma = torch.tensor([0.5, 2.5, 5.0])
gamma = sigma.pow(2).mul(2).pow(-1)

norm = x.pow(2).sum(-1)
k = gamma.unsqueeze(-1).mul(norm).mul(-1).exp()
k_xj = gamma.mul(2).unsqueeze(-1).unsqueeze(-1).mul(k.unsqueeze(-1)).mul(x.mul(-1))

data = pd.DataFrame(
    [
        {
            "sigma": s.item(),
            "x": x[idx, 0].item(),
            "k": k[s_idx, idx].item(),
            "k_xj": k_xj[s_idx, idx, 0].item(),
        }
        for s_idx, s in enumerate(sigma)
        for idx in range(x.shape[0])
    ]
)

r, c = 1, 2
fig, ax = plt.subplots(r, c, figsize=(6.4 * c, 4.8 * r))

if c == 1:
    ax = np.array([ax])

if r == 1:
    ax = np.array([ax])

sns.lineplot(
    data,
    x="x",
    y="k",
    hue="sigma",
    ax=ax[0, 0],
    palette=sns.color_palette("tab10"),
)
ax[0, 0].grid()
ax[0, 0].set_ylabel("$\\kappa(0, x_j^l)$")
ax[0, 0].set_xlabel("$x_j^l$")

sns.lineplot(
    data,
    x="x",
    y="k_xj",
    hue="sigma",
    ax=ax[0, 1],
    palette=sns.color_palette("tab10"),
)
ax[0, 1].grid()
ax[0, 1].set_ylabel("$\\nabla_{x_j^l} \\kappa(0, x_j^l)$")
ax[0, 1].set_xlabel("$x_j^l$")

handles, labels = ax[0, 0].get_legend_handles_labels()
labels = [f"$\\sigma = {label}$" for label in labels]

ax[0, 0].legend_.remove()
ax[0, 1].legend_.remove()

legend = fig.legend(
    handles,
    labels,
    loc="lower center",
    bbox_to_anchor=(0.42, 0.7),
    framealpha=1.0,
)

fig.savefig("bandwidth.svg", bbox_inches="tight")

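Choosing an appropriate value for $\sigma$ is nontrivial: both extremely small and extremely large bandwidths cause the repulsion term to vanish (for small $\sigma$, $\kappa \approx 0$ for every pair of distinct particles; for large $\sigma$, $\kappa \approx 1$ but the factor $1/\sigma^2$ goes to 0). A common heuristic is the median trick, which sets $\sigma = \mathrm{median}\{\|x_i^l - x_j^l\|\}_{i,j=0}^{M-1} / \sqrt{2 \log M}$, where $M$ is the number of particles. A pair of particles at the median distance then has $\kappa = \exp(-\log M) = 1/M$; approximating each of the other kernel values by this value gives, roughly, $\sum_{j \neq i} \kappa(x_i^l, x_j^l) \approx 1$.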
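A minimal NumPy sketch of the median trick follows; `median_bandwidth` is a hypothetical helper name. Following the formula above, the median runs over all $M^2$ ordered pairs, including the zero distances on the diagonal; some implementations use only pairs with $i \neq j$, which changes $\sigma$ slightly.

```python
import numpy as np


def median_bandwidth(x):
    """Median trick (hypothetical helper): sigma = median ||x_i - x_j|| / sqrt(2 log M).

    The median is taken over all M^2 ordered pairs, diagonal included.
    Requires M >= 2 so that log M > 0.
    """
    m = x.shape[0]
    dist = np.sqrt(((x[:, None, :] - x[None, :, :]) ** 2).sum(-1))  # (M, M)
    return np.median(dist) / np.sqrt(2 * np.log(m))


rng = np.random.default_rng(0)
x = rng.normal(size=(100, 2))
sigma = median_bandwidth(x)

# A pair of particles exactly at the median distance gets kernel value
# exp(-median^2 / (2 sigma^2)) = exp(-log M) = 1 / M.
median_dist = np.median(np.sqrt(((x[:, None, :] - x[None, :, :]) ** 2).sum(-1)))
k_at_median = np.exp(-median_dist**2 / (2 * sigma**2))
print(sigma, k_at_median)  # k_at_median equals 1 / 100 up to rounding
```

Because $\sigma$ is recomputed from the current particles, the bandwidth shrinks automatically as the particles contract toward the target and grows when they spread out.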

Derivation of the SVGD Update Rule

Suppose we wish to approximate a target distribution $p$ using particles $x \sim q^0$. The central idea of SVGD is to iteratively transport these particles with a smooth transformation

$$f(x) = x + \epsilon\, \phi(x),$$

where $\phi$ is a velocity field chosen to maximally decrease the KL divergence to the target distribution.

Formally, SVGD seeks the direction

$$\underset{\phi \in \mathcal{F}}{\mathrm{argmax}}\,\left( -\left.\nabla_{\epsilon} D_{\mathrm{KL}}(q^{\epsilon} \,\|\, p)\right|_{\epsilon=0} \right),$$

where $q^{\epsilon}$ denotes the distribution obtained by pushing $q^0$ through $f$, and $\mathcal{F}$ is a suitable function space.

Assuming $f$ is invertible (which holds for sufficiently small $\epsilon$ when $\phi$ is smooth), one can derive the first-order variation of the KL divergence as

$$-\left.\nabla_{\epsilon} D_{\mathrm{KL}}(q^{\epsilon} \,\|\, p)\right|_{\epsilon=0} = \mathbb{E}_{x\sim q^0} \left[ \nabla_{x} \log p(x)^{\top}\,\phi(x) + \operatorname{Tr}\big( \nabla_x \phi(x) \big) \right].$$

The first term attracts particles toward regions of high probability density, while the second encourages volume expansion and prevents particle collapse.
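The first-order variation can be checked by Monte Carlo in 1D, where $\operatorname{Tr}(\nabla_x \phi(x))$ is simply $\phi'(x)$. This sketch uses a hypothetical helper `kl_rate` and an assumed target $p = \mathcal{N}(0, 1)$. If $q^0 = p$, the rate is zero for any smooth $\phi$ (Stein's identity), so no direction improves on the target. If $q^0 = \mathcal{N}(1, 1)$ and $\phi(x) = -1$, pushing $q^0$ through $f$ gives $q^\epsilon = \mathcal{N}(1 - \epsilon, 1)$ with $D_{\mathrm{KL}}(q^\epsilon \,\|\, p) = (1 - \epsilon)^2 / 2$, whose negative derivative at $\epsilon = 0$ is exactly 1.

```python
import numpy as np


def kl_rate(samples, score_p, phi, dphi):
    """Monte Carlo estimate of -d/d(eps) KL(q^eps || p) at eps = 0 in 1D (hypothetical helper):
    E_{x ~ q0}[score_p(x) * phi(x) + phi'(x)], where Tr(grad phi) reduces to phi'."""
    return np.mean(score_p(samples) * phi(samples) + dphi(samples))


def score_p(x):
    return -x  # assumed target p = N(0, 1)


rng = np.random.default_rng(0)
n = 1_000_000

# Case 1: q0 = p. The rate vanishes for any smooth phi; here phi = sin, phi' = cos.
x_p = rng.normal(0.0, 1.0, size=n)
rate_at_target = kl_rate(x_p, score_p, np.sin, np.cos)

# Case 2: q0 = N(1, 1) and phi(x) = -1, i.e. shift every particle left by eps.
x_q = rng.normal(1.0, 1.0, size=n)
rate_shift = kl_rate(x_q, score_p, lambda x: -np.ones_like(x), np.zeros_like)

print(rate_at_target, rate_shift)  # approximately 0 and 1
```

The positive rate in the second case confirms the sign convention: a direction with a positive value of this expectation decreases the KL divergence.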

To obtain a tractable optimization problem, SVGD restricts $\phi$ to a reproducing kernel Hilbert space (RKHS) $\mathcal{H}^D$, the $D$-fold product of the scalar RKHS with kernel $\kappa(x, y)$. Using the reproducing property of the RKHS, the negative KL gradient above can be written as the inner product

$$\left\langle \phi(\cdot),\; \mathbb{E}_{x\sim q^0} \big[ \kappa(x, \cdot)\nabla_{x} \log p(x) + \nabla_{x} \kappa(x, \cdot) \big] \right\rangle_{\mathcal{H}^D}.$$

If we constrain $\|\phi\|_{\mathcal{H}^D} \leq 1$, then by the Cauchy–Schwarz inequality the inner product is maximized when $\phi$ is proportional to its second argument. The optimal velocity field is therefore

$$\phi_{p,q^0}(\cdot) \propto \mathbb{E}_{x\sim q^0} \Big[ \kappa(x, \cdot)\nabla_{x} \log p(x) + \nabla_{x} \kappa(x, \cdot) \Big].$$

Evaluating this field at $x_i^l$, replacing $q^0$ by the current distribution $q^l$, and using the symmetry of $\kappa$ recovers the SVGD update rule given at the beginning.
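A numerical illustration of the optimal direction, under assumptions of our own: a 1D standard normal target, an RBF kernel with $\sigma = 1$, and Monte Carlo samples standing in for $q^0$ (`optimal_velocity` is a hypothetical helper that drops the proportionality constant). When $q^0 = p$ the field vanishes, so the target is a fixed point of SVGD. When $q^0 = \mathcal{N}(1, 1)$, integration by parts gives $\phi(y) = \mathbb{E}_{q^0}[\kappa(x, y)(\nabla_x \log p(x) - \nabla_x \log q^0(x))] = -\mathbb{E}_{q^0}[\kappa(x, y)]$, which is negative everywhere, so every point is pushed left, toward $p$.

```python
import numpy as np


def optimal_velocity(y, samples, score_p, sigma=1.0):
    """Monte Carlo estimate of E_{x ~ q0}[kappa(x, y) score_p(x) + d/dx kappa(x, y)]
    in 1D with an RBF kernel (hypothetical helper, proportionality constant dropped).

    y:       (K,) points where the velocity field is evaluated
    samples: (N,) draws from q0
    """
    diff = samples[None, :] - y[:, None]   # (K, N): x - y
    k = np.exp(-diff**2 / (2 * sigma**2))  # kappa(x, y)
    grad_k_x = -diff / sigma**2 * k        # d/dx kappa(x, y)
    return np.mean(k * score_p(samples)[None, :] + grad_k_x, axis=1)


def score_p(x):
    return -x  # assumed target p = N(0, 1)


rng = np.random.default_rng(0)
ys = np.array([-2.0, 0.0, 1.0, 3.0])
n = 400_000

# q0 = p: the field is (up to Monte Carlo noise) zero at every evaluation point.
phi_at_target = optimal_velocity(ys, rng.normal(0.0, 1.0, size=n), score_p)

# q0 = N(1, 1): phi(y) = -E_q0[kappa(x, y)], which for sigma = 1 equals
# -exp(-(y - 1)^2 / 4) / sqrt(2).
phi_shifted = optimal_velocity(ys, rng.normal(1.0, 1.0, size=n), score_p)
print(phi_at_target, phi_shifted)
```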

For targets known only up to a normalization constant $Z$,

$$p(x) = \frac{\bar p(x)}{Z},$$

the score function satisfies

$$\nabla_x \log p(x) = \nabla_x \log \bar p(x),$$

since $\log p(x) = \log \bar p(x) - \log Z$ and $\log Z$ does not depend on $x$. Consequently, SVGD requires only the score of the unnormalized density and never needs to evaluate $Z$.
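As an end-to-end sketch, the sampler below touches the target only through `score`, the hand-written gradient of an unnormalized log-density, so $Z$ never appears. Assumptions: a 1D target $\mathcal{N}(2, 0.5^2)$ with its normalizing constant dropped, an RBF kernel with the median trick recomputed at every iteration, step size $\epsilon = 0.05$, 1000 iterations, and 100 particles drawn from $q^0 = \mathcal{N}(0, 1)$. The function names are hypothetical, and this is not the `svgd` package used for the figures.

```python
import numpy as np

MU, TAU = 2.0, 0.5  # assumed target N(MU, TAU^2), treated as known only up to Z


def log_p_bar(x):
    # Unnormalized log-density: the constant -log Z = -log(TAU * sqrt(2 pi)) is dropped.
    return -((x - MU) ** 2) / (2 * TAU**2)


def score(x):
    # Gradient of log_p_bar; identical to the score of the normalized density.
    return -(x - MU) / TAU**2


def svgd(x, score, n_steps=1000, eps=0.05):
    """Plain SVGD with an RBF kernel and the median trick (a sketch)."""
    m = x.shape[0]
    for _ in range(n_steps):
        diff = x[None, :, :] - x[:, None, :]  # diff[i, j] = x_j - x_i
        sq = (diff**2).sum(-1)
        sigma2 = np.median(np.sqrt(sq)) ** 2 / (2 * np.log(m))  # median trick, squared
        k = np.exp(-sq / (2 * sigma2))
        drift = (k @ score(x)) / m
        repulsion = (k[:, :, None] * diff).sum(axis=1) / (m * sigma2)
        x = x + eps * (drift - repulsion)
    return x


rng = np.random.default_rng(0)
x0 = rng.normal(size=(100, 1))  # q^0 = N(0, 1)
x = svgd(x0, score)
print(x.mean(), x.std())  # should be close to MU = 2.0 and TAU = 0.5
```

With finitely many particles SVGD only approximates the target, so small deviations of the sample statistics, especially of the spread, are to be expected.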

References
  1. Liu, Q., & Wang, D. (2016). Stein variational gradient descent: A general purpose Bayesian inference algorithm. NeurIPS.