Reparameterization Trick & Gumbel-Softmax: A Deep Dive
Make sense of the reparameterization trick and Gumbel-Softmax: why gradients can flow through sampling, how temperature trades bias for variance, and the practical pitfalls of training discrete latent variables end-to-end.
The moment your model contains a sampling step, training hits a hard wall: how do gradients flow through a random node?
The reparameterization trick has a clean answer: rewrite $z\sim p_\theta(z)$ as $z=g_\theta(\epsilon)$, isolating the randomness in a parameter-free noise variable $\epsilon$, so backprop can flow through $g_\theta$. The trouble starts with discrete variables: operations like $\arg\max$ are not differentiable. Gumbel-Softmax (a.k.a. the Concrete distribution) replaces the discrete sample with a tempered softmax over Gumbel-perturbed logits, giving you a smooth, differentiable surrogate that you can train end-to-end.
This post walks through the derivations, the intuition, the implementation details, the bias-variance trade-off behind the temperature, and the most common training pitfalls.
What You Will Learn#
- Why gradients cannot flow through $z\sim\mathcal N(\mu,\sigma^2)$ , but flow freely through $z=\mu+\sigma\epsilon$ .
- Where the Gumbel distribution comes from and why “logits + Gumbel noise + argmax = softmax sampling” is exact.
- How the Gumbel-Softmax temperature $\tau$ trades bias for variance, plus annealing in practice.
- The Straight-Through estimator (ST-GS): hard forward, soft backward.
- Variance comparison vs. REINFORCE / score-function estimators (typically 1–3 orders of magnitude lower).
- Full PyTorch implementations for both continuous and discrete VAEs, with the most common training pitfalls.
Prerequisites#
- Expectations, densities, change-of-variables.
- Basic PyTorch and the chain rule under autograd.
- Familiarity with the variational autoencoder (see the companion post VAE: From Intuition to Implementation).
Why we need reparameterization#
Many training objectives are expectations under a parameterized distribution:
$$\mathcal L(\theta) \;=\; \mathbb E_{z\sim q_\theta(z)}\,[\,f(z)\,].$$
In a VAE, $f$ is the reconstruction log-likelihood; in RL, it’s a return; in discrete structure learning, it’s the downstream loss. The trouble: $\theta$ can appear both inside $f$ and inside the distribution being sampled. Naively “sample $z$, evaluate $f(z)$, call .backward()” breaks because the computation graph is severed at the sampling step.
A blunt PyTorch example: draw $z$ with torch.normal(mu, sigma), compute a loss from it, and call .backward().
torch.normal(mu, sigma) is a stochastic op: there is no smooth functional relationship between its output $z$ and its inputs $\mu,\sigma$. Change $\mu$ a tiny bit and the distribution shifts, but a particular drawn $z$ does not move smoothly with it, so there is no derivative to define.
Core problem: training requires $\nabla_\theta \mathbb E_{q_\theta}[f(z)]$, but the sampling step kills backprop.
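The severed graph can be seen directly in autograd. The sketch below is an illustration, not the post's original snippet: it uses torch.distributions.Normal as a stand-in for torch.normal, because its sample() and rsample() methods put the non-reparameterized and reparameterized draws side by side. The parameter values are arbitrary.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)

# Illustrative parameters (assumed values, not from the post).
mu = torch.tensor(0.5, requires_grad=True)
sigma = torch.tensor(1.2, requires_grad=True)

# Non-reparameterized draw: sample() runs without gradient tracking,
# so the result has no history connecting it to mu or sigma.
z_detached = Normal(mu, sigma).sample()
print("sample() tracks gradients:", z_detached.requires_grad)

# Reparameterized draw: rsample() computes mu + sigma * eps internally,
# so the result is a differentiable function of mu and sigma.
z_path = Normal(mu, sigma).rsample()
print("rsample() tracks gradients:", z_path.requires_grad)

z_path.backward()
print("dz/dmu:", mu.grad.item())        # derivative of mu + sigma * eps w.r.t. mu
print("dz/dsigma:", sigma.grad.item())  # equals the drawn eps
```

Calling .backward() on anything computed only from z_detached would raise an error, because that tensor does not require gradients.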
Two families of fixes exist:
- Score-function / REINFORCE: rewrite the gradient as $\mathbb E[f(z)\nabla_\theta\log q_\theta(z)]$. Universal, works on any random variable, but the variance is enormous.
- Reparameterization: pull the randomness out of the parameters, replace it with parameter-free noise, and keep the graph differentiable. Low variance and end-to-end trainable, but requires a distribution that admits such a rewrite.
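To make the contrast between the two families concrete, here is a small pure-Python sketch on a toy problem chosen for illustration (it is not from the post): $f(z)=z^2$ with $z\sim\mathcal N(\mu,\sigma^2)$, where the true gradient is $\partial_\mu\mathbb E[z^2]=2\mu$. Both estimators are unbiased; the function name and the parameter values are assumptions.

```python
import random

def gradient_samples(mu, sigma, n, seed=0):
    """Single-sample estimates of d/dmu E[z^2] for z ~ N(mu, sigma^2)."""
    rng = random.Random(seed)
    score, pathwise = [], []
    for _ in range(n):
        eps = rng.gauss(0.0, 1.0)
        z = mu + sigma * eps
        # Score-function: f(z) * d/dmu log N(z; mu, sigma^2) = z^2 * (z - mu) / sigma^2
        score.append(z * z * (z - mu) / sigma**2)
        # Reparameterization: d/dmu (mu + sigma * eps)^2 = 2 * z
        pathwise.append(2.0 * z)
    return score, pathwise

def mean_and_variance(xs):
    m = sum(xs) / len(xs)
    return m, sum((x - m) ** 2 for x in xs) / (len(xs) - 1)

score, pathwise = gradient_samples(mu=1.0, sigma=1.0, n=50_000)
score_mean, score_var = mean_and_variance(score)
path_mean, path_var = mean_and_variance(pathwise)

print(f"true gradient        : {2.0:.3f}")
print(f"score-function       : mean={score_mean:.3f}  var={score_var:.2f}")
print(f"reparameterization   : mean={path_mean:.3f}  var={path_var:.2f}")
```

On this toy problem both means land near 2, while the per-sample variance of the score-function estimate is several times larger. The size of the gap depends on $f$ and on the distribution, so the numbers here should not be read as general.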
Reparameterizing continuous distributions#
The general form#
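Suppose a sample from $q_\theta$ can be written as a deterministic, differentiable transform of noise whose distribution does not depend on $\theta$:
$$z \;=\; g_\theta(\epsilon),\qquad \epsilon \sim p(\epsilon).$$
The objective then becomes an expectation over $\epsilon$ alone:
$$\mathcal L(\theta) \;=\; \mathbb E_{\epsilon\sim p(\epsilon)}\,[\,f(g_\theta(\epsilon))\,].$$
Because $p(\epsilon)$ is free of $\theta$, the gradient moves inside the expectation:
$$\nabla_\theta\,\mathcal L(\theta) \;=\; \mathbb E_{\epsilon\sim p(\epsilon)}\,[\,\nabla_\theta\, f(g_\theta(\epsilon))\,].$$
A Monte Carlo estimate is just: draw one (or a few) $\epsilon$, run autograd through $f(g_\theta(\epsilon))$.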

Side-by-side: on the left, $z\sim\mathcal N(\mu,\sigma^2)$ is a stochastic node and gradients cannot flow through it; on the right, $z=\mu+\sigma\epsilon$ is a deterministic function of $\mu,\sigma$ and gradients flow through the green path back into the encoder.
The Gaussian case#
For a diagonal Gaussian, the transform is
$$z \;=\; \mu \;+\; \sigma \odot \epsilon,\qquad \epsilon \sim \mathcal N(0,I).$$
Two checks: (i) $\mathbb E[z]=\mu$ and $\mathrm{Var}(z)=\sigma^2$; (ii) since affine transforms preserve normality, $z$ is distributed exactly as $\mathcal N(\mu,\sigma^2)$. Crucially, $\partial z/\partial\mu=1$ and $\partial z/\partial\sigma=\epsilon$ are both trivial expressions, so gradients pass through unobstructed.
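In PyTorch, predict $\log\sigma^2$ rather than $\sigma$ for numerical stability, and recover $\sigma=\exp(\tfrac12\log\sigma^2)$ when sampling. The network output is then unconstrained and $\sigma$ is positive by construction.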
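A minimal sketch of the Gaussian reparameterization with a log-variance parameterization might look like this. The function name reparameterize and the tensor shapes are illustrative assumptions.

```python
import torch

def reparameterize(mu, logvar):
    """Draw z = mu + sigma * eps, with sigma = exp(0.5 * logvar)."""
    std = torch.exp(0.5 * logvar)   # always positive, whatever the network outputs
    eps = torch.randn_like(std)     # parameter-free noise, eps ~ N(0, I)
    return mu + std * eps

torch.manual_seed(0)
# Stand-ins for encoder outputs: a batch of 4 latents with 3 dimensions each.
mu = torch.zeros(4, 3, requires_grad=True)
logvar = torch.zeros(4, 3, requires_grad=True)

z = reparameterize(mu, logvar)
z.sum().backward()

print(z.shape)
print(mu.grad)       # all ones, since dz/dmu = 1
print(logvar.grad)   # 0.5 * sigma * eps, by the chain rule through exp
```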
In the VAE#
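The VAE maximizes the evidence lower bound (ELBO):
$$ \mathcal L_{\text{ELBO}}(\theta,\phi;x) \;=\; \mathbb E_{q_\phi(z|x)}\!\bigl[\log p_\theta(x|z)\bigr] \;-\;\mathrm{KL}\!\bigl(q_\phi(z|x)\,\Vert\,p(z)\bigr). $$
Reparameterizing $z=\mu_\phi+\sigma_\phi\epsilon$ turns the encoder gradient of the reconstruction term into an ordinary expectation over noise:
$$ \nabla_\phi\,\mathbb E_{q_\phi(z|x)}[\log p_\theta(x|z)] \;=\;\mathbb E_{\epsilon\sim\mathcal N(0,I)}\!\bigl[\,\nabla_\phi\log p_\theta(x\mid \mu_\phi+\sigma_\phi\epsilon)\,\bigr]. $$
For a diagonal Gaussian posterior and a standard normal prior, the KL term has the closed form
$$\mathrm{KL}=\tfrac12\sum_j\!\bigl(\mu_j^2+\sigma_j^2-1-\log\sigma_j^2\bigr),$$
so no Monte Carlo is needed there.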
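As a sanity check on the closed-form KL, the following pure-Python sketch compares it with a Monte Carlo estimate of $\mathbb E_q[\log q(z)-\log p(z)]$ in one dimension. The function names and the values of $\mu$ and $\sigma$ are arbitrary choices for illustration.

```python
import math
import random

def kl_closed_form(mu, sigma):
    """KL( N(mu, sigma^2) || N(0, 1) ) for a single latent dimension."""
    return 0.5 * (mu**2 + sigma**2 - 1.0 - math.log(sigma**2))

def kl_monte_carlo(mu, sigma, n=200_000, seed=0):
    """Estimate the same KL as the average of log q(z) - log p(z), z ~ q."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        eps = rng.gauss(0.0, 1.0)
        z = mu + sigma * eps
        # The 0.5 * log(2 * pi) terms of the two log-densities cancel.
        log_q = -0.5 * eps * eps - math.log(sigma)
        log_p = -0.5 * z * z
        total += log_q - log_p
    return total / n

exact = kl_closed_form(0.7, 0.5)
estimate = kl_monte_carlo(0.7, 0.5)
print(f"closed form : {exact:.4f}")
print(f"Monte Carlo : {estimate:.4f}")
```

The two numbers agree up to sampling noise, and the closed form achieves that with no sampling at all, which is why it is preferred inside the training loop.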
Which continuous distributions are “naturally” reparameterizable?#
Anything that admits a location-scale form, or a base-noise + differentiable transform:
| Distribution | Reparameterization |
|---|---|
| $\mathcal N(\mu,\sigma^2)$ | $z=\mu+\sigma\epsilon$ , $\epsilon\sim\mathcal N(0,1)$ |
| $\mathrm{Logistic}(\mu,s)$ | $z=\mu+s\,\log\frac{u}{1-u}$ , $u\sim\mathrm U(0,1)$ |
| $\mathrm{Laplace}(\mu,b)$ | $z=\mu-b\,\mathrm{sign}(u)\log(1-2\lvert u\rvert)$ , $u\sim\mathrm U(-\tfrac12,\tfrac12)$ |
| $\mathrm{Exp}(\lambda)$ | $z=-\tfrac1\lambda\log(1-u)$ , $u\sim\mathrm U(0,1)$ |
| $\mathrm{Gumbel}(\mu,\beta)$ | $z=\mu-\beta\log(-\log u)$ , $u\sim\mathrm U(0,1)$ |
Counter-examples: Gamma, Beta, Dirichlet, and Student-t (in its degrees of freedom) have shape parameters that no simple location-scale transform can absorb. They need implicit reparameterization gradients or related pathwise techniques (Figurnov et al., NeurIPS 2018; see Recent work).
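The non-Gaussian rows of the table translate directly into inverse-CDF samplers. The sketch below uses only the standard library; the helper names, the clamp away from zero, and the parameter values are assumptions made for illustration. Each sampler is checked against the known mean of its distribution.

```python
import math
import random

EULER_GAMMA = 0.5772156649015329

def open_unit(rng):
    """Uniform draw clamped away from 0 so that every log below stays finite."""
    return max(rng.random(), 1e-12)

def sample_logistic(mu, s, rng):
    u = open_unit(rng)
    return mu + s * math.log(u / (1.0 - u))

def sample_laplace(mu, b, rng):
    u = open_unit(rng) - 0.5   # shifted to lie in (-1/2, 1/2)
    return mu - b * math.copysign(1.0, u) * math.log(1.0 - 2.0 * abs(u))

def sample_exponential(lam, rng):
    u = open_unit(rng)
    return -math.log(1.0 - u) / lam

def sample_gumbel(mu, beta, rng):
    u = open_unit(rng)
    return mu - beta * math.log(-math.log(u))

def sample_mean(draw, n=100_000, seed=0):
    rng = random.Random(seed)
    return sum(draw(rng) for _ in range(n)) / n

logistic_mean = sample_mean(lambda r: sample_logistic(1.0, 0.5, r))
laplace_mean = sample_mean(lambda r: sample_laplace(-2.0, 1.0, r))
exponential_mean = sample_mean(lambda r: sample_exponential(2.0, r))
gumbel_mean = sample_mean(lambda r: sample_gumbel(0.0, 1.0, r))

print(f"Logistic(1, 0.5) mean ~ {logistic_mean:.3f}   (expected 1.0)")
print(f"Laplace(-2, 1)   mean ~ {laplace_mean:.3f}  (expected -2.0)")
print(f"Exp(2)           mean ~ {exponential_mean:.3f}   (expected 0.5)")
print(f"Gumbel(0, 1)     mean ~ {gumbel_mean:.3f}   (expected {EULER_GAMMA:.3f})")
```

In every sampler the parameters enter through a differentiable expression of $u$, which is what makes the same formulas usable as reparameterizations in an autograd framework.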
Reparameterizing discrete distributions: the Gumbel-Max trick#
The difficulty#
Take a $K$-class categorical $\mathrm{Cat}(\pi_1,\dots,\pi_K)$ with $\pi_i=\mathrm{softmax}(\alpha)_i=\frac{\exp(\alpha_i)}{\sum_j\exp(\alpha_j)}$. The naive “sample” is: compute $\pi$, draw a class index $k$ from a multinomial. That step is completely non-differentiable: its output is a one-hot vector, with no notion of smooth variation.
The score-function estimator still applies,
$$ \nabla_\alpha\,\mathbb E_{k\sim\mathrm{Cat}(\pi)}[f(k)] \;=\;\mathbb E_{k}\bigl[f(k)\,\nabla_\alpha\log\pi_k\bigr], $$
universal but with variance large enough to wreck training stability. Can we, as in the Gaussian case, factor sampling into “noise + deterministic transform”? The answer is the Gumbel-Max trick.
A quick tour of the Gumbel distribution#
The standard Gumbel(0,1) distribution has CDF and PDF
$$F(g)=\exp(-e^{-g}),\qquad f(g)=\exp\!\bigl(-(g+e^{-g})\bigr).$$
Its mode is at $0$ and its mean is the Euler–Mascheroni constant $\gamma\approx 0.5772$. Crucially, it is the limiting distribution of suitably rescaled maxima of iid samples from light-tailed distributions such as the exponential or the Gaussian. Gumbel-Max exploits exactly this property.
Inverting the CDF gives a one-line sampler:
$$u\sim\mathrm U(0,1) \;\Rightarrow\; g=-\log(-\log u)\sim\mathrm{Gumbel}(0,1).$$
Left: the Gumbel(0,1) PDF; middle: inverse-CDF sampling visualized; right: 20k samples drawn via $-\log(-\log u)$, with the empirical histogram matching the analytic PDF.
The Gumbel-Max trick#
Let $g_1,\dots,g_K$ be iid Gumbel(0,1) draws and define
$$k^\star \;=\; \arg\max_i\,(\alpha_i + g_i).$$
Then $k^\star\sim\mathrm{Cat}(\mathrm{softmax}(\alpha))$. In words, adding Gumbel noise to logits and taking argmax is equivalent to sampling from the softmax distribution.
Proof (probability of class 1)#
By symmetry it suffices to consider class 1: $\Pr[k^\star=1] = \Pr\bigl[\,\forall i\neq 1:\;\alpha_1+g_1>\alpha_i+g_i\,\bigr] = \Pr\bigl[\,\forall i\neq 1:\;g_i<\alpha_1-\alpha_i+g_1\,\bigr]$.
Conditioning on $g_1=t$ and using the independence of the remaining $g_i$, with $S:=\sum_{i\neq 1}e^{\alpha_i-\alpha_1}$,
$$ \Pr[k^\star=1\mid g_1=t]=\prod_{i\neq 1}F(\alpha_1-\alpha_i+t) =\prod_{i\neq 1}\exp\!\bigl(-e^{-(\alpha_1-\alpha_i+t)}\bigr) =\exp\!\Bigl(-e^{-t}\!\!\sum_{i\neq 1}e^{\alpha_i-\alpha_1}\Bigr) =\exp(-S e^{-t}). $$
Integrating against the density of $g_1$,
$$ \Pr[k^\star=1]=\int_{-\infty}^{\infty} e^{-(t+e^{-t})}\cdot\exp(-S e^{-t})\,dt =\int_{-\infty}^{\infty} e^{-t}\exp\!\bigl(-(1+S)e^{-t}\bigr)\,dt. $$
Substituting $u=(1+S)e^{-t}$, $du=-(1+S)e^{-t}dt$ gives $\frac{1}{1+S}\int_0^\infty e^{-u}du=\frac{1}{1+S}$. Since $1+S=\sum_i e^{\alpha_i-\alpha_1}$,
$$\Pr[k^\star=1]=\frac{1}{1+S}=\frac{e^{\alpha_1}}{\sum_i e^{\alpha_i}}=\mathrm{softmax}(\alpha)_1. \quad\blacksquare$$
Left: target categorical distribution; middle: a single draw — adding Gumbel noise to logits and taking the argmax selects class c2; right: 8000 draws — the empirical frequencies match the target softmax probabilities, confirming exact equivalence in expectation.
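The equivalence can also be checked numerically. This pure-Python sketch draws many Gumbel-Max samples and compares the class frequencies with the softmax probabilities; the logits, the sample count, and the helper names are illustrative assumptions.

```python
import math
import random
from collections import Counter

def softmax(logits):
    top = max(logits)   # subtract the max for numerical stability
    exps = [math.exp(a - top) for a in logits]
    total = sum(exps)
    return [e / total for e in exps]

def gumbel_noise(rng):
    u = max(rng.random(), 1e-12)   # avoid log(0)
    return -math.log(-math.log(u))

def gumbel_max_sample(logits, rng):
    """Return argmax_i (alpha_i + g_i) with fresh iid Gumbel(0,1) noise."""
    perturbed = [a + gumbel_noise(rng) for a in logits]
    return max(range(len(logits)), key=perturbed.__getitem__)

def empirical_frequencies(logits, n, seed=0):
    rng = random.Random(seed)
    counts = Counter(gumbel_max_sample(logits, rng) for _ in range(n))
    return [counts[k] / n for k in range(len(logits))]

logits = [1.0, 0.0, -1.0, 2.0]
target = softmax(logits)
freqs = empirical_frequencies(logits, n=100_000)

for k, (p, f) in enumerate(zip(target, freqs)):
    print(f"class {k}: softmax={p:.4f}  gumbel-max frequency={f:.4f}")
```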
Why this matters#
- Efficient: $K$ uniforms + one argmax. No explicit normalization, no cumulative-distribution lookup.
- Numerically stable: addition followed by argmax is invariant to constant shifts in the logits.
- Most importantly: all the randomness now lives in $g$; the logits $\alpha$ enter through the deterministic $\alpha+g$. We are one step away from “differentiable”: we just need to soften the $\arg\max$.
Gumbel-Softmax: softening the argmax#
Definition#
Replace the $\arg\max$ in Gumbel-Max with a softmax at temperature $\tau>0$:
$$ y_i \;=\; \frac{\exp\!\bigl((\alpha_i + g_i)/\tau\bigr)}{\sum_{j=1}^K\exp\!\bigl((\alpha_j + g_j)/\tau\bigr)}, \qquad g_i\overset{iid}{\sim}\mathrm{Gumbel}(0,1). $$
Here $y\in\Delta^{K-1}$, the $(K-1)$-simplex, is a continuous vector. The two limits give intuition:
- $\tau\to 0^+$: softmax degenerates into the indicator of the argmax, and $y$ becomes one-hot, a true discrete sample;
- $\tau\to\infty$: $(\alpha_i+g_i)/\tau\to 0$, and $y$ becomes the uniform vector $(1/K,\dots,1/K)$.
Because $g$ is independent of $\alpha$ and softmax is everywhere differentiable, $y$ is differentiable in $\alpha$ (and therefore in any upstream parameter $\theta$ ). This is the heart of Gumbel-Softmax: a smooth, differentiable proxy for the discrete one-hot.
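The two temperature limits are easy to see numerically. In the pure-Python sketch below, one fixed draw of Gumbel noise is pushed through the tempered softmax at three temperatures; the logits, the temperatures, and the function names are assumptions chosen for illustration.

```python
import math
import random

def sample_gumbel(rng):
    u = max(rng.random(), 1e-12)   # avoid log(0)
    return -math.log(-math.log(u))

def gumbel_softmax(logits, noise, tau):
    """Tempered softmax over perturbed logits, computed stably."""
    scaled = [(a + g) / tau for a, g in zip(logits, noise)]
    top = max(scaled)
    exps = [math.exp(s - top) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

rng = random.Random(0)
logits = [1.0, 0.0, -1.0, 2.0]
noise = [sample_gumbel(rng) for _ in logits]   # shared across temperatures

# The Gumbel-Max sample that the relaxation should approach as tau -> 0.
perturbed = [a + g for a, g in zip(logits, noise)]
hard_index = max(range(len(logits)), key=perturbed.__getitem__)

y_cold = gumbel_softmax(logits, noise, tau=0.01)
y_warm = gumbel_softmax(logits, noise, tau=1.0)
y_hot = gumbel_softmax(logits, noise, tau=1000.0)

print("Gumbel-Max index:", hard_index)
for name, y in [("tau=0.01", y_cold), ("tau=1", y_warm), ("tau=1000", y_hot)]:
    print(name, [round(v, 3) for v in y])
```

Keeping the noise fixed isolates the effect of $\tau$: the largest entry always sits at the Gumbel-Max index, and only its sharpness changes.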
The temperature bias-variance trade-off#

Four temperatures: in each panel the light-blue bars show the target softmax, while the orange polylines are five independent Gumbel-Softmax draws.
- $\tau=5$ : very smooth, near-uniform — low gradient variance but highly biased away from a true discrete sample;
- $\tau=1$ : balanced;
- $\tau=0.5$ : starts looking one-hot;
- $\tau=0.1$ : nearly hard one-hot — low bias but high variance (and softmax internals are prone to over/underflow).
You can read this off the gradient estimator directly: smaller $\tau$ makes $y$ closer to the true one-hot $z$, reducing bias; but the softmax Jacobian $\partial y/\partial\alpha$ scales like $1/\tau$, so gradient variance grows roughly like $\tau^{-2}$.
Practical annealing:
- Start at $\tau\in[1.0, 2.0]$, end at $\tau\in[0.1, 0.5]$.
- Exponential schedule $\tau_t=\max(\tau_{\min},\tau_0\,e^{-rt})$, with $\tau$ recomputed every $\sim 1000$ steps rather than at every step.
- Don’t drive $\tau\to 0$ — softmax numerics blow up and gradient variance dominates. Let the model “see” the soft distribution early, then “harden” it later.
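The annealing recipe above can be sketched as a small helper. The function name, the argument names, and the default values are illustrative assumptions; the defaults are picked to fall inside the ranges listed above.

```python
import math

def anneal_tau(step, tau0=1.0, tau_min=0.5, rate=1e-5, update_every=1000):
    """Exponentially decayed temperature, held constant between updates.

    tau = max(tau_min, tau0 * exp(-rate * t)), where t is the most recent
    multiple of `update_every` that is <= step.
    """
    last_update = (step // update_every) * update_every
    return max(tau_min, tau0 * math.exp(-rate * last_update))

for step in (0, 999, 1000, 20_000, 50_000, 100_000):
    print(f"step {step:>7}: tau = {anneal_tau(step):.4f}")
```

With these defaults the floor is reached after roughly $\ln 2 / r \approx 69{,}000$ steps, so the decay rate should be matched to the length of the training run.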
Straight-Through Gumbel-Softmax (ST-GS)#
$$ \boxed{ \;\;y_{\text{hard}}=\mathrm{onehot}(\arg\max_i y_i), \qquad \tilde y \;=\; y_{\text{hard}}\;-\;\mathrm{stop\_grad}(y_{\text{soft}})\;+\;y_{\text{soft}}.\;\; } $$Forward: $\tilde y=y_{\text{hard}}$ (strict one-hot). Backward: $\partial\tilde y/\partial\alpha=\partial y_{\text{soft}}/\partial\alpha$ (the soft Jacobian). It’s a biased but low-variance estimator: hard sample for the loss, soft sample for the gradient.
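In PyTorch, .detach() plays the role of $\mathrm{stop\_grad}$, so the whole construction is y_hard - y_soft.detach() + y_soft.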
PyTorch ships torch.nn.functional.gumbel_softmax(logits, tau, hard) with the same semantics; rolling your own is just to make every step explicit.
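A hand-rolled version that spells out each step might look like the following. It is a sketch under assumptions: the function names, the clamp constant, and the toy reward vector are illustrative, and the built-in torch.nn.functional.gumbel_softmax is the version to prefer in practice.

```python
import torch
import torch.nn.functional as F

def sample_gumbel(shape, eps=1e-10):
    """g = -log(-log u), with u clamped away from 0 to avoid log(0)."""
    u = torch.rand(shape).clamp_min(eps)
    return -torch.log(-torch.log(u))

def st_gumbel_softmax(logits, tau=1.0):
    """Straight-through Gumbel-Softmax: one-hot forward, soft backward."""
    y_soft = F.softmax((logits + sample_gumbel(logits.shape)) / tau, dim=-1)
    index = y_soft.argmax(dim=-1)
    y_hard = F.one_hot(index, num_classes=logits.shape[-1]).to(y_soft.dtype)
    # Forward value equals y_hard; the gradient is that of y_soft.
    return y_hard - y_soft.detach() + y_soft

torch.manual_seed(0)
logits = torch.randn(5, 4, requires_grad=True)   # 5 categorical variables, 4 classes
reward = torch.arange(4, dtype=torch.float32)    # toy per-class reward

y = st_gumbel_softmax(logits, tau=0.5)
loss = (y * reward).sum()
loss.backward()

print(y)             # rows are one-hot up to floating-point rounding
print(logits.grad)   # gradients reach the logits through the soft path
```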
Comparison with REINFORCE: orders-of-magnitude variance gap#
Score-function / REINFORCE estimator#
The score-function identity reads
$$ \nabla_\theta\,\mathbb E_{z\sim q_\theta}[f(z)] =\mathbb E_{z\sim q_\theta}\!\bigl[f(z)\,\nabla_\theta\log q_\theta(z)\bigr]. $$
Its strength: it makes no assumption about differentiability of $z$ in $\theta$, so discrete variables, control flow, and even black-box external calls all work. Its weakness: very high variance. Intuitively, the estimator only sees the scalar value $f(z)$ and never $\nabla_z f$, so the direction of improvement has to emerge from averaging many noisy products of reward and score.
Standard variance reductions:
- Baselines / control variates: replace $f(z)$ with $f(z)-b$ for some $z$-independent $b$ (e.g. a running average).
- Rao–Blackwellization, antithetic sampling, importance sampling, etc.
Even with all these tricks, REINFORCE typically remains 1–3 orders of magnitude noisier than reparameterization.
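The effect of a baseline can be illustrated on a toy categorical problem. The pure-Python sketch below estimates $\nabla_{\alpha_0}\mathbb E_k[r_k]$ with and without a baseline; the logits, the reward values, and the choice of the exact expected reward as the baseline are assumptions made for illustration (in practice a running average would stand in for it).

```python
import math
import random

def softmax(logits):
    top = max(logits)
    exps = [math.exp(a - top) for a in logits]
    total = sum(exps)
    return [e / total for e in exps]

def reinforce_samples(logits, rewards, n, baseline=0.0, seed=0):
    """Single-sample score-function estimates of d/d(alpha_0) E[r_k]."""
    rng = random.Random(seed)
    probs = softmax(logits)
    draws = rng.choices(list(range(len(logits))), weights=probs, k=n)
    # d log pi_k / d alpha_0 = 1[k == 0] - pi_0
    return [(rewards[k] - baseline) * ((1.0 if k == 0 else 0.0) - probs[0])
            for k in draws]

def mean_and_variance(xs):
    m = sum(xs) / len(xs)
    return m, sum((x - m) ** 2 for x in xs) / (len(xs) - 1)

logits = [0.5, 0.0, -0.5, 1.0]
rewards = [10.0, 11.0, 9.5, 10.5]   # large common offset, small differences
probs = softmax(logits)
expected_reward = sum(p * r for p, r in zip(probs, rewards))
true_grad = probs[0] * (rewards[0] - expected_reward)

plain_mean, plain_var = mean_and_variance(
    reinforce_samples(logits, rewards, n=200_000))
base_mean, base_var = mean_and_variance(
    reinforce_samples(logits, rewards, n=200_000, baseline=expected_reward))

print(f"true gradient    : {true_grad:.4f}")
print(f"no baseline      : mean={plain_mean:.4f}  var={plain_var:.3f}")
print(f"with baseline    : mean={base_mean:.4f}  var={base_var:.4f}")
```

Both estimators target the same gradient, because the baseline term has zero expectation. The variance drop is large here because the rewards share a big common offset, which is exactly the situation baselines are designed for.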
Empirical comparison#
We estimate $\nabla_{\alpha_0}\,\mathbb E_z[r^\top z]$ on a synthetic 8-class categorical (with a fixed reward vector $r$). Each curve aggregates 200 trials.

Left pipeline: logits → +Gumbel noise → /τ → softmax → $y_{\text{soft}}$; optional argmax → $y_{\text{hard}}$, combined via STE for hard-forward / soft-backward. Right curves: the $x$-axis is the number of MC samples per gradient estimate, the $y$-axis is the variance of the estimate (log-log). Gumbel-Softmax with $\tau=0.5$ is consistently more than an order of magnitude lower than REINFORCE; both decay as $1/n$ (slope $-1$).
This is exactly why end-to-end discrete training only became practical once reparameterization-based estimators were available.
Full PyTorch: continuous and discrete VAEs#
Continuous VAE (reparameterization)#
Implementation notes:
- Use binary_cross_entropy_with_logits instead of BCELoss + sigmoid to avoid numerical issues in saturation regions.
- Compute the KL term in closed form (rather than via sampling) to reduce gradient variance.
- Optionally anneal KL ($\beta$ in $\beta$ -VAE from 0 → 1) to mitigate posterior collapse.
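The notes above can be assembled into a compact sketch of a continuous VAE. Everything specific here is an assumption for illustration: the class name GaussianVAE, the layer sizes, the 784-dimensional input, the Adam learning rate, and the random batch standing in for real data.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GaussianVAE(nn.Module):
    def __init__(self, x_dim=784, h_dim=256, z_dim=16):
        super().__init__()
        self.enc = nn.Linear(x_dim, h_dim)
        self.enc_mu = nn.Linear(h_dim, z_dim)
        self.enc_logvar = nn.Linear(h_dim, z_dim)
        self.dec = nn.Sequential(
            nn.Linear(z_dim, h_dim), nn.ReLU(), nn.Linear(h_dim, x_dim))

    def forward(self, x):
        h = F.relu(self.enc(x))
        mu, logvar = self.enc_mu(h), self.enc_logvar(h)
        # Reparameterization: z = mu + sigma * eps, sigma = exp(0.5 * logvar)
        z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
        return self.dec(z), mu, logvar   # decoder returns logits, not probabilities

def negative_elbo(x, x_logits, mu, logvar, beta=1.0):
    """Per-example negative ELBO with a closed-form Gaussian KL."""
    recon = F.binary_cross_entropy_with_logits(x_logits, x, reduction="sum")
    kl = 0.5 * torch.sum(mu.pow(2) + logvar.exp() - 1.0 - logvar)
    return (recon + beta * kl) / x.shape[0]

torch.manual_seed(0)
model = GaussianVAE()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.rand(8, 784)   # stand-in batch with pixel intensities in [0, 1]
x_logits, mu, logvar = model(x)
loss = negative_elbo(x, x_logits, mu, logvar, beta=1.0)

optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"negative ELBO per example: {loss.item():.2f}")
```

For KL annealing, beta would be ramped from 0 to 1 over the early part of training instead of being fixed at 1.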
Discrete-latent VAE (categorical + Gumbel-Softmax)#
Implementation notes:
- Treat each latent as a $K$-way categorical; logits have shape [B, n_cat, K], softmax over the last dimension.
- KL has the closed form $\mathrm{KL}(q\|\mathrm{Uniform})=\log K - H(q)$, so no sampling is required.
- hard=True enables ST-GS (forces a true discrete forward pass); hard=False keeps gradients smoother.
- Anneal: tau0=1.0, tau_min=0.5, r=1e-5, with tau recomputed every ~1k steps.
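A compact sketch of the discrete-latent VAE described by these notes could look as follows. The class and function names, the layer sizes, the choice of 20 latents with 10 classes each, and the random stand-in batch are all illustrative assumptions; sampling relies on the built-in torch.nn.functional.gumbel_softmax.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class CategoricalVAE(nn.Module):
    def __init__(self, x_dim=784, h_dim=256, n_cat=20, k=10):
        super().__init__()
        self.n_cat, self.k = n_cat, k
        self.enc = nn.Sequential(
            nn.Linear(x_dim, h_dim), nn.ReLU(), nn.Linear(h_dim, n_cat * k))
        self.dec = nn.Sequential(
            nn.Linear(n_cat * k, h_dim), nn.ReLU(), nn.Linear(h_dim, x_dim))

    def forward(self, x, tau=1.0, hard=False):
        logits = self.enc(x).view(-1, self.n_cat, self.k)   # [B, n_cat, K]
        # Relaxed (or straight-through, if hard=True) sample per latent.
        y = F.gumbel_softmax(logits, tau=tau, hard=hard, dim=-1)
        return self.dec(y.flatten(start_dim=1)), logits

def categorical_kl(logits):
    """Sum over batch and latents of KL(q || Uniform) = log K - H(q)."""
    log_q = F.log_softmax(logits, dim=-1)
    return torch.sum(log_q.exp() * (log_q + math.log(logits.shape[-1])))

def categorical_negative_elbo(x, x_logits, logits):
    recon = F.binary_cross_entropy_with_logits(x_logits, x, reduction="sum")
    return (recon + categorical_kl(logits)) / x.shape[0]

torch.manual_seed(0)
model = CategoricalVAE()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.rand(8, 784)   # stand-in batch with pixel intensities in [0, 1]
x_logits, logits = model(x, tau=1.0, hard=True)
loss = categorical_negative_elbo(x, x_logits, logits)

optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"negative ELBO per example: {loss.item():.2f}")
```

During training, tau would be supplied by an annealing schedule rather than fixed at 1.0, and hard can be switched on late in training or only at evaluation time.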
Common training pitfalls#
| Symptom | Likely cause | Fix |
|---|---|---|
| NaN loss | $\tau$ annealed too low; softmax over/underflow; or $u=0$ in $\log(-\log u)$ | clamp $u\in[\epsilon, 1-\epsilon]$ ; keep $\tau_{\min}\ge 0.1$ |
| Gradient variance explodes | One MC sample + tiny $\tau$ | More MC samples per batch; slow the $\tau$ schedule |
| Discrete structure not actually “hard” | Using soft outputs at eval time | Set hard=True at inference, or argmax; switch to ST-GS late in training |
| Posterior collapse (continuous VAE) | KL too strong from the start | Anneal $\beta$ from 0 → 1; use free-bits |
| Discrete VAE doesn’t learn structure | Decoder too powerful, KL too weak | Reduce decoder capacity; lower KL weight, then raise it back |
| Gradients insensitive to logits | $\tau$ too large, output near-uniform | Lower $\tau$ or raise the learning rate |
| Forward / backward mismatch | ST-GS line written incorrectly | Re-check y_hard - y_soft.detach() + y_soft |
Recent work#
- Implicit Reparameterization Gradients (Figurnov et al., NeurIPS 2018) — uses the implicit function theorem to make Gamma, Beta, Dirichlet, Student-t reparameterizable, with low bias and clean derivations.
- REBAR / RELAX (Tucker et al., NeurIPS 2017; Grathwohl et al., ICLR 2018) — combine Gumbel-Softmax with REINFORCE and learn a control variate via a neural network, yielding unbiased and lower-variance discrete gradient estimators.
- Hard Concrete gates (Louizos et al., ICLR 2018) — stretch and clip Concrete/Gumbel-Softmax samples to $[0,1]$ , producing differentiable gates with literal zeros for $L_0$ regularization and sparsification.
- Top-$k$ Gumbel and Plackett–Luce — extend Gumbel-Max to without-replacement sampling of $k$ classes; useful for sparse attention and routing.
- Permutation-equivariant relaxations (Mena et al., ICLR 2018) — Sinkhorn operator + Gumbel noise to differentiate through permutation matrices.
Summary#
- In continuous settings, reparameterization rewrites a random variable $z$ as a deterministic, differentiable transform $g_\theta(\epsilon)$ of a fixed noise — gradients flow through the deterministic path back to the parameters. This is what makes VAEs trainable with SGD.
- In discrete settings, Gumbel-Max gives the exact equivalence “logits + Gumbel noise + argmax = softmax sample”; Gumbel-Softmax then softens the argmax into a temperature-scaled softmax, making the entire sampling step differentiable.
- The temperature $\tau$ is the central knob: small $\tau$ → low bias / high variance; large $\tau$ → the opposite. Annealing (large → small) is the key to stable training.
- When a strict discrete forward pass is required, use ST-GS: hard forward, soft backward — the most common engineering compromise.
- Versus REINFORCE, reparameterization-based estimators give 1–3 orders of magnitude lower gradient variance — without that, end-to-end discrete training would not be practical.
References#
- E. Jang, S. Gu, B. Poole. “Categorical Reparameterization with Gumbel-Softmax.” ICLR, 2017.
- C. Maddison, A. Mnih, Y. W. Teh. “The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables.” ICLR, 2017.
- D. P. Kingma, M. Welling. “Auto-Encoding Variational Bayes.” ICLR, 2014.
- M. Figurnov, S. Mohamed, A. Mnih. “Implicit Reparameterization Gradients.” NeurIPS, 2018.
- G. Tucker et al. “REBAR: Low-variance, unbiased gradient estimates for discrete latent variable models.” NeurIPS, 2017.
- C. Louizos, M. Welling, D. P. Kingma. “Learning Sparse Neural Networks through $L_0$ Regularization.” ICLR, 2018.
- W. Kool, H. van Hoof, M. Welling. “Stochastic Beams and Where to Find Them: The Gumbel-Top-$k$ Trick for Sampling Sequences Without Replacement.” ICML, 2019.