Nov 9, 2025

Wasserstein Auto-Encoder Research Reveals 33% Quality Gains With Hidden Training Constraints

Conor Bronsdon

Head of Developer Awareness

How WAEs Beat VAEs by 33% Yet Hit Memory Limits | Galileo

Generative models face a fundamental trade-off: VAEs train stably but produce blurry samples, while GANs generate sharp images but train unpredictably.

This ICLR 2018 paper from the Max Planck Institute and Google Brain introduced Wasserstein Auto-Encoders (WAEs), which use optimal transport theory to aim for the best of both. On MNIST and CelebA, WAEs achieve 12-33% lower (better) FID than VAEs, and the WAE-MMD variant keeps VAE-like training stability.

The benchmarks also surface practical costs. WAE-MMD's O(m²) penalty grows quadratically with batch size. The regularization coefficient λ changed 10x between datasets. And the choice of kernel turned out to be critical for successful training.

This article breaks down the MNIST (70,000 images) and CelebA (203,000 images) experiments, revealing four training challenges and what the results mean for choosing between WAE variants.

Summary: What WAEs deliver on benchmarks

WAEs minimize the penalized Wasserstein distance between the data distribution and model distribution. The framework was tested on MNIST (70,000 images) and CelebA (203,000 images) using DCGAN-style architectures.

Here's what the paper found:

  • Sample quality improved substantially. On CelebA, WAE-GAN reached an FID of 42 versus the VAE's 63 (lower is better), a 33% reduction. WAE-MMD reached 55, a 12.7% reduction. Both variants doubled the sharpness score: 6×10⁻³ versus the VAE's 3×10⁻³.

  • Training stability depends on variant choice. The paper reports WAE-MMD "has a very stable training much like VAE": there is no discriminator to balance and no mode collapse to worry about. WAE-GAN delivers better quality but needs separately tuned learning rates for the encoder-decoder and the adversary, plus a decay schedule during training.

  • The framework offers architectural flexibility. Unlike VAEs requiring stochastic encoders, WAEs support deterministic encoders. The key difference: WAEs match the aggregated posterior QZ to prior PZ rather than forcing per-sample matching.

  • Extended experiments narrowed the quality gap. After training 3,000+ models on TPU hardware with wider hyperparameter sweeps, bigWAE-GAN reached FID 35 and bigWAE-MMD 37: a gap of 2 FID points, versus 13 (55 vs. 42) in the standard experiments.
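As a quick check of the arithmetic in these bullets, the sketch below recomputes the relative FID reductions from the quoted CelebA scores. It also shows that percentages for the WAE-GAN/WAE-MMD gap depend on the baseline you divide by. Measured against WAE-MMD in both cases, the gap shrinks from about 24% to about 5%. The absolute gap shrinks from 13 to 2 FID points. The numbers are the ones quoted in this article, and nothing here is recomputed from model outputs.

```python
# CelebA FID scores as quoted in this article (lower is better).
FID = {"VAE": 63, "WAE-MMD": 55, "WAE-GAN": 42}
BIG_FID = {"bigWAE-MMD": 37, "bigWAE-GAN": 35}


def relative_reduction(baseline: float, model: float) -> float:
    """Fraction by which `model` lowers FID relative to `baseline`."""
    return (baseline - model) / baseline


gan_gain = relative_reduction(FID["VAE"], FID["WAE-GAN"])  # 21 / 63
mmd_gain = relative_reduction(FID["VAE"], FID["WAE-MMD"])  # 8 / 63

# Gap between the two WAE variants, in absolute FID points.
standard_gap = FID["WAE-MMD"] - FID["WAE-GAN"]
big_gap = BIG_FID["bigWAE-MMD"] - BIG_FID["bigWAE-GAN"]

print(f"WAE-GAN vs VAE: {gan_gain:.1%}")  # 33.3%
print(f"WAE-MMD vs VAE: {mmd_gain:.1%}")  # 12.7%
print(f"Variant gap: {standard_gap} -> {big_gap} FID points")
print(f"Gap relative to WAE-MMD: {standard_gap / FID['WAE-MMD']:.1%} -> "
      f"{big_gap / BIG_FID['bigWAE-MMD']:.1%}")
```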

WAE-MMD's O(m²) penalty adds compute and memory that grow quadratically with batch size. Hyperparameter tuning proved dataset-specific: λ jumped 10x from MNIST to CelebA. Kernel selection was critical: RBF kernels failed to penalize outlying codes, and inverse multiquadratics kernels were needed for successful training.


The four challenges of training Wasserstein auto-encoders

You might think your generative model evaluation is complete after measuring FID scores and reconstruction error. The MNIST and CelebA experiments reveal otherwise.

Four training challenges emerge that affect your ability to scale, tune hyperparameters, and reproduce benchmark results. Each challenge represents a constraint the paper discovered through systematic experimentation.

Challenge #1: O(m²) computational complexity creates memory bottlenecks

Your WAE-MMD trains on small batches. You want larger batch sizes for better optimization, but computational complexity becomes the constraint.

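Here's why. WAE-MMD computes the Maximum Mean Discrepancy (MMD) between encoded codes and prior samples through kernel matrices. That takes on the order of m² kernel evaluations for m samples per minibatch. Algorithm 2 in the paper writes the MMD penalty with minibatch size n, prior samples z_ℓ drawn from PZ, and encoded codes z̃_ℓ:

$$\frac{\lambda}{n(n-1)}\sum_{\ell \neq j} k(z_\ell, z_j) + \frac{\lambda}{n(n-1)}\sum_{\ell \neq j} k(\tilde{z}_\ell, \tilde{z}_j) - \frac{2\lambda}{n^2}\sum_{\ell, j} k(z_\ell, \tilde{z}_j)$$

This unbiased estimator needs a kernel evaluation for every pair of samples; the first two sums skip only the ℓ = j diagonal. A standard VAE's KL term, by contrast, is computed in closed form for each sample independently, so its cost grows linearly with batch size.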
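To make the pairwise structure concrete, here is a minimal NumPy sketch of the MMD penalty above. It assumes the inverse multiquadratics kernel k(x, y) = C/(C + ||x − y||²) with C = 2·dz·σ²_z, which the paper used for WAE-MMD. The function names `imq_kernel` and `mmd_penalty` are hypothetical. The sketch computes the penalty term only: in a real model, it would be added to the reconstruction loss and differentiated with an autodiff framework.

```python
import numpy as np


def imq_kernel(a: np.ndarray, b: np.ndarray, c: float) -> np.ndarray:
    """Inverse multiquadratics kernel C / (C + ||a_i - b_j||^2) for every pair (i, j)."""
    # Broadcasting builds an (n, n, dz) difference tensor: memory is quadratic in n.
    sq_dists = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return c / (c + sq_dists)


def mmd_penalty(z_prior: np.ndarray, z_codes: np.ndarray, lam: float, sigma2_z: float) -> float:
    """lam * unbiased MMD^2 estimate between prior samples and encoded codes, both (n, dz)."""
    n, dz = z_prior.shape
    c = 2.0 * dz * sigma2_z  # kernel scale C = 2 * dz * sigma_z^2
    k_pp = imq_kernel(z_prior, z_prior, c)  # each call produces an n x n matrix
    k_qq = imq_kernel(z_codes, z_codes, c)
    k_pq = imq_kernel(z_prior, z_codes, c)
    # First two sums skip the diagonal (l != j); the cross term averages over all n^2 pairs.
    pp = (k_pp.sum() - np.trace(k_pp)) / (n * (n - 1))
    qq = (k_qq.sum() - np.trace(k_qq)) / (n * (n - 1))
    pq = 2.0 * k_pq.sum() / n**2
    return float(lam * (pp + qq - pq))


rng = np.random.default_rng(0)
n, dz, sigma2_z = 100, 8, 1.0  # MNIST-like latent settings quoted in the article
z_prior = rng.normal(0.0, np.sqrt(sigma2_z), size=(n, dz))
z_matched = rng.normal(0.0, np.sqrt(sigma2_z), size=(n, dz))  # codes that already follow P_Z
z_shifted = z_matched + 3.0  # codes pushed far from P_Z

matched_penalty = mmd_penalty(z_prior, z_matched, lam=10.0, sigma2_z=sigma2_z)
shifted_penalty = mmd_penalty(z_prior, z_shifted, lam=10.0, sigma2_z=sigma2_z)
print(matched_penalty, shifted_penalty)  # the shifted codes get a much larger penalty
```

The matched estimate stays near zero and, because the estimator is unbiased, can even dip slightly below it.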

The paper trained all models (VAE, WAE-MMD, and WAE-GAN) with batch size 100 on both MNIST and CelebA, a setting that keeps the kernel matrices small. Beyond the usual forward and backward passes, WAE-MMD must hold these kernel matrices, and their memory grows quadratically with batch size.

Here is what the numbers show. At batch size 100, each kernel matrix holds 100×100 = 10,000 pairwise evaluations. Double the batch size to 200 and each matrix needs 40,000: 4x the computation for 2x the samples. A VAE's per-sample KL term scales linearly, so 200 samples cost 2x what 100 samples do.
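The scaling argument can be tabulated in a few lines of Python. This sketch only counts entries and float32 bytes for the three n × n kernel matrices WAE-MMD forms each step (prior-prior, code-code, prior-code). It ignores activations, gradients, and any intermediate tensors a framework may allocate, so treat it as a lower bound under those assumptions. The helper name `kernel_matrix_cost` is hypothetical.

```python
def kernel_matrix_cost(batch_size: int, bytes_per_entry: int = 4) -> tuple[int, int]:
    """Pairwise entries and bytes (float32 by default) for ONE batch_size x batch_size matrix."""
    entries = batch_size * batch_size
    return entries, entries * bytes_per_entry


for m in (100, 200, 1_000, 10_000):
    entries, nbytes = kernel_matrix_cost(m)
    total_mib = 3 * nbytes / 2**20  # three kernel matrices per WAE-MMD step
    print(f"m={m:>6}: {entries:>12,} pairs per matrix, {total_mib:10.1f} MiB for three matrices")
```

Each doubling of m multiplies the count by four. At m = 100, the three matrices take only about 0.1 MiB. At m = 10,000, they alone take about 1.1 GiB, so the quadratic term bites as batches grow rather than at the paper's setting.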

The paper concludes by naming Sliced-Wasserstein variants as future work, a direction that could help where the quadratic cost becomes prohibitive.
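For contrast, here is a minimal sketch of the standard sliced 2-Wasserstein estimator, the idea behind the Sliced-Wasserstein variants mentioned here. It is not from the WAE paper. It projects both samples onto L random unit directions, where one-dimensional optimal transport reduces to matching sorted values. The cost is therefore O(L·n·(dz + log n)) rather than O(n²). The function name `sliced_wasserstein_sq` and the equal-sample-size restriction are assumptions of this sketch.

```python
import numpy as np


def sliced_wasserstein_sq(x: np.ndarray, y: np.ndarray, num_projections: int,
                          rng: np.random.Generator) -> float:
    """Monte Carlo estimate of the sliced squared 2-Wasserstein distance between two samples."""
    n, d = x.shape
    if y.shape != (n, d):
        raise ValueError("this sketch assumes samples of equal size and dimension")
    # Random directions, normalized onto the unit sphere: shape (L, d).
    theta = rng.normal(size=(num_projections, d))
    theta /= np.linalg.norm(theta, axis=1, keepdims=True)
    # Project each sample onto every direction and sort along the sample axis: shape (n, L).
    x_proj = np.sort(x @ theta.T, axis=0)
    y_proj = np.sort(y @ theta.T, axis=0)
    # In 1-D, the optimal coupling of equal-size samples pairs values in sorted order.
    return float(np.mean((x_proj - y_proj) ** 2))


rng = np.random.default_rng(0)
codes = rng.normal(size=(1_000, 8))
same = sliced_wasserstein_sq(codes, codes, num_projections=500, rng=rng)
shifted = sliced_wasserstein_sq(codes, codes + 3.0, num_projections=500, rng=rng)
print(same, shifted)  # same is 0.0; shifted is near ||shift||^2 / d = 72 / 8 = 9
```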

Challenge #2: Hyperparameter sensitivity complicates tuning across datasets

Hyperparameters don't transfer between datasets. The settings the paper used for MNIST are not the ones it used for CelebA.

The regularization coefficient λ changes by one to two orders of magnitude:

  • MNIST WAE-MMD: λ=10

  • CelebA WAE-MMD: λ=100 (10x jump)

  • CelebA WAE-GAN: λ=1 (100x smaller than WAE-MMD on the same data, though each λ weights a different penalty)

Prior variance and latent dimensions changed between datasets. MNIST used σ²_z=1 with an 8-dimensional latent space. CelebA required σ²_z=2 with a 64-dimensional latent space. The paper explains why the latent dimension matters: "Choosing dz larger than intrinsic dimensionality of the dataset would force the encoded distribution QZ to live on a manifold in Z. This would make matching QZ to PZ impossible if PZ is Gaussian and may lead to numerical instabilities."

Learning rates don't follow a consistent pattern. MNIST used α=10⁻³ for the encoder-decoder and α=5×10⁻⁴ for the adversary. CelebA WAE-GAN flipped this: α=3×10⁻⁴ for the encoder-decoder and α=10⁻³ for the adversary. CelebA also needed a decay schedule during training: the rates were decreased by a factor of 2 after 30 epochs, by a factor of 5 after 50 epochs, and by a factor of 10 after 100 epochs.
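One way to keep these dataset-specific settings explicit is a small configuration table plus a learning-rate schedule. The values below are the ones quoted in this article. The dictionary layout and the `learning_rate` helper are hypothetical. The schedule assumes the decay factors compound: divide by 2, then by 5 more, then by 10 more. If the paper's factors are meant relative to the base rate instead, adjust the milestones.

```python
# Settings as quoted in this article; not a complete reproduction config.
CONFIGS = {
    ("MNIST", "WAE-MMD"): {"lam": 10.0, "dz": 8, "sigma2_z": 1.0},
    ("CelebA", "WAE-MMD"): {"lam": 100.0, "dz": 64, "sigma2_z": 2.0},
    ("CelebA", "WAE-GAN"): {"lam": 1.0, "dz": 64, "sigma2_z": 2.0,
                            "lr_enc_dec": 3e-4, "lr_adv": 1e-3},
}

# ASSUMPTION: factors compound, so after epoch 100 the rate is base / (2 * 5 * 10).
DECAY_MILESTONES = [(30, 2.0), (50, 5.0), (100, 10.0)]


def learning_rate(base_lr: float, epoch: int) -> float:
    """Piecewise-constant rate: divide by each factor once training reaches its milestone epoch."""
    lr = base_lr
    for milestone, factor in DECAY_MILESTONES:
        if epoch >= milestone:
            lr /= factor
    return lr


gan = CONFIGS[("CelebA", "WAE-GAN")]
for epoch in (0, 30, 50, 100):
    print(epoch, learning_rate(gan["lr_enc_dec"], epoch), learning_rate(gan["lr_adv"], epoch))
```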

Kernel choice matters for WAE-MMD. The paper reports "We tried WAE-MMD with the RBF kernel but observed that it fails to penalize the outliers of QZ because of the quick tail decay," forcing a switch to inverse multiquadratics kernels.

The paper states: "even slight differences between QZ and PZ may affect the quality of samples." Small mismatches degrade sample quality noticeably.

Challenge #3: Kernel choice determines whether outliers get penalized

Your WAE-MMD starts training. Some encoded samples end up far from the prior's support. With an RBF kernel, the penalty on these outliers drops to nearly zero, and the encoder gets almost no signal to pull them back.

The paper hit this exact problem. It reports "We tried WAE-MMD with the RBF kernel but observed that it fails to penalize the outliers of QZ because of the quick tail decay." When encoded codes z̃ = µ_φ(x) land far from the support of PZ during early training, the kernel terms k(z, z̃) = e^(-||z̃-z||²₂/σ²_k) rapidly approach zero, and so do their gradients. Outliers with vanishing gradients get essentially no push toward the prior.

The catch with RBF kernels? They fail exactly when you need them most. Early training produces encoded samples scattered across latent space, and RBF values decay exponentially with squared distance. Distant samples therefore get effectively zero penalty and zero gradient. The encoder then has little signal for pulling those outliers toward the prior.
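The tail-decay problem is easy to see numerically. This sketch evaluates the RBF kernel and the magnitude of its derivative with respect to the distance r = ||z̃ − z|| for increasingly distant codes. It assumes dz = 8, σ²_z = 1, and a bandwidth σ²_k = 2·dz·σ²_z = 16, the matched bandwidth the paper compared against. The helper names are hypothetical.

```python
import numpy as np


def rbf(r: np.ndarray, sigma2_k: float) -> np.ndarray:
    """RBF kernel as a function of distance r: exp(-r^2 / sigma2_k)."""
    return np.exp(-r**2 / sigma2_k)


def rbf_grad_magnitude(r: np.ndarray, sigma2_k: float) -> np.ndarray:
    """|dk/dr| = (2 r / sigma2_k) * exp(-r^2 / sigma2_k)."""
    return 2.0 * r / sigma2_k * np.exp(-r**2 / sigma2_k)


dz, sigma2_z = 8, 1.0
sigma2_k = 2 * dz * sigma2_z  # bandwidth matched to the prior (= 16)
r = np.array([4.0, 8.0, 16.0, 32.0])  # distances of an encoded code from a prior sample
for ri, k, g in zip(r, rbf(r, sigma2_k), rbf_grad_magnitude(r, sigma2_k)):
    print(f"r={ri:5.1f}  k={k:.2e}  |dk/dr|={g:.2e}")
```

At r = 4, roughly the typical distance between two prior samples here, the gradient magnitude is about 0.18. By r = 32 it is below 10⁻²⁷, so an outlier that far away contributes essentially nothing to the update.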

The fix required switching kernels entirely. Inverse multiquadratics k(x,y) = C/(C + ||x - y||²₂) have much heavier tails that maintain gradients even for distant points. The paper used C = 2dz σ²_z—the expected squared distance between two Gaussian vectors drawn from PZ. This matches the kernel scale to the prior's natural spread.
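Two properties of this choice can be checked with a short sketch, assuming the CelebA settings dz = 64 and σ²_z = 2 (so C = 256). First, a Monte Carlo estimate confirms that C matches the expected squared distance between two independent draws from PZ = N(0, σ²_z·I). Second, comparing derivative magnitudes shows the heavier tail. At a typical distance (r = 16), the RBF gradient is actually slightly larger. Far from the prior, however, the inverse multiquadratics gradient decays only polynomially while the RBF gradient vanishes. The helper names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
dz, sigma2_z = 64, 2.0  # CelebA latent settings quoted in the article
c = 2 * dz * sigma2_z  # C = 256

# Monte Carlo check: E||z - z'||^2 for z, z' ~ N(0, sigma2_z * I) equals 2 * dz * sigma2_z.
z1 = rng.normal(0.0, np.sqrt(sigma2_z), size=(20_000, dz))
z2 = rng.normal(0.0, np.sqrt(sigma2_z), size=(20_000, dz))
mean_sq_dist = float(np.mean(np.sum((z1 - z2) ** 2, axis=1)))
print(f"C = {c}, Monte Carlo estimate = {mean_sq_dist:.1f}")


def imq_grad_magnitude(r: np.ndarray, c: float) -> np.ndarray:
    """|d/dr| of C / (C + r^2), i.e. 2 C r / (C + r^2)^2, which decays like 1 / r^3."""
    return 2.0 * c * r / (c + r**2) ** 2


def rbf_grad_magnitude(r: np.ndarray, sigma2_k: float) -> np.ndarray:
    """|d/dr| of exp(-r^2 / sigma2_k), which decays like exp(-r^2 / sigma2_k)."""
    return 2.0 * r / sigma2_k * np.exp(-r**2 / sigma2_k)


r = np.array([16.0, 64.0, 128.0])  # 16 = sqrt(C), a typical distance under the prior
for ri, g_imq, g_rbf in zip(r, imq_grad_magnitude(r, c), rbf_grad_magnitude(r, c)):
    print(f"r={ri:6.1f}  IMQ {g_imq:.2e}  RBF (sigma2_k = C) {g_rbf:.2e}")
```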

The performance difference was dramatic. The paper reports this "significantly improved performance" compared to RBF kernels, even those with matching bandwidth σ²_k = 2dz σ²_z. Kernel choice "proved critical for WAE-MMD's ability to match QZ to PZ during training."

Challenge #4: Sample quality vs training stability trade-offs

Your WAE-GAN achieves FID score 42 on CelebA. Your WAE-MMD gets 55. Both crush VAE's 63. But here's the catch: that extra quality from WAE-GAN comes with adversarial training headaches.

The quality numbers tell the story:

  • WAE-GAN: FID 42, sharpness 6×10⁻³

  • WAE-MMD: FID 55, sharpness 6×10⁻³

  • VAE: FID 63, sharpness 3×10⁻³

The paper reports "in some cases WAE-GAN seems to lead to better matching and generates better samples."

But WAE-GAN trains less predictably. The paper states WAE-GAN "is less stable than WAE-MMD" due to adversarial training. WAE-MMD "has a very stable training much like VAE"—no discriminator networks to balance, no mode collapse to watch for.

Training WAE-GAN needs two separately tuned learning rates. The CelebA experiments used α=3×10⁻⁴ for the encoder-decoder and α=10⁻³ for the adversary. WAE-MMD has no discriminator, so there is nothing to balance.
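To show what the adversary is balancing, here is a sketch of the two latent-space objectives in the style of the paper's Algorithm 1 (WAE-GAN), written as losses to minimize. Under this reading, the discriminator D tries to tell prior samples z from encoded codes z̃. The encoder-decoder minimizes reconstruction cost minus λ·log D(z̃). Here D is represented only by its logits, which are hypothetical inputs; in practice they come from a network. The two losses are minimized by separate optimizers, which for CelebA used the learning rates quoted above.

```python
import numpy as np


def log_sigmoid(x: np.ndarray) -> np.ndarray:
    """Numerically stable log(sigmoid(x)) = -log(1 + exp(-x))."""
    return -np.logaddexp(0.0, -x)


def wae_gan_losses(logits_prior: np.ndarray, logits_codes: np.ndarray,
                   recon_cost: np.ndarray, lam: float) -> tuple[float, float]:
    """Return (discriminator_loss, encoder_decoder_loss), both to be minimized.

    logits_prior: discriminator logits on prior samples z ~ P_Z, shape (n,)
    logits_codes: discriminator logits on encoded codes z~, shape (n,)
    recon_cost:   per-example reconstruction cost c(x, G(z~)), shape (n,)
    """
    # Discriminator ascends lam * [log D(z) + log(1 - D(z~))]; log(1 - sigmoid(a)) = log_sigmoid(-a).
    disc_loss = -lam * np.mean(log_sigmoid(logits_prior) + log_sigmoid(-logits_codes))
    # Encoder-decoder descends recon cost - lam * log D(z~): codes should look like prior samples.
    enc_dec_loss = np.mean(recon_cost) - lam * np.mean(log_sigmoid(logits_codes))
    return float(disc_loss), float(enc_dec_loss)


# Toy check: an undecided discriminator (all logits 0) outputs D = 0.5 everywhere.
zeros = np.zeros(4)
print(wae_gan_losses(zeros, zeros, recon_cost=np.ones(4), lam=1.0))  # (2 ln 2, 1 + ln 2)
```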

Sample quality hinges on distribution matching. The paper states "the quality of samples strongly depends on how accurately QZ matches PZ" and "even slight differences between QZ and PZ may affect the quality of samples."

Here's why: during training, the decoder sees only samples from QZ (the encoded training data). When you generate by sampling z ~ PZ, the decoder encounters a distribution it wasn't trained on. Small mismatches translate to visible quality drops.
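One cheap way to watch for this mismatch, not taken from the paper, is to compare the first two moments of the encoded codes against the prior PZ = N(0, σ²_z·I). Matching moments does not guarantee matching distributions, so treat this only as a sanity check alongside the training penalty. The function name `moment_mismatch` and the synthetic "codes" are assumptions of this sketch.

```python
import numpy as np


def moment_mismatch(codes: np.ndarray, sigma2_z: float) -> tuple[float, float]:
    """Mean-vector norm and Frobenius-norm covariance gap of codes vs. N(0, sigma2_z * I)."""
    mean_gap = float(np.linalg.norm(codes.mean(axis=0)))
    cov = np.cov(codes, rowvar=False)  # (dz, dz) sample covariance
    cov_gap = float(np.linalg.norm(cov - sigma2_z * np.eye(codes.shape[1]), ord="fro"))
    return mean_gap, cov_gap


rng = np.random.default_rng(0)
dz, sigma2_z = 8, 1.0
good_codes = rng.normal(0.0, np.sqrt(sigma2_z), size=(5_000, dz))  # Q_Z close to P_Z
shrunk_codes = 0.5 * good_codes  # Q_Z too concentrated: decoder rarely sees the prior's outer region
print(moment_mismatch(good_codes, sigma2_z), moment_mismatch(shrunk_codes, sigma2_z))
```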

Extended tuning narrows the gap. The paper trained 3,000+ models on 8 TPU-v2 accelerators (100,000 steps each). bigWAE-GAN reached FID 35 and bigWAE-MMD 37: a gap of 2 FID points, down from 13 in the original experiments.

Key takeaways

The experiments reveal specific configurations that work and trade-offs you'll encounter when training WAEs:

  • WAE-GAN delivers a 33% lower FID but is less stable. WAE-MMD achieves a 12.7% reduction with VAE-like stability and no discriminator to balance.

  • Hyperparameters don't transfer between datasets. Regularization λ jumped 10x from MNIST (10) to CelebA (100). On identical data, λ varied 100x between variants, though each variant's λ weights a different penalty.

  • RBF kernels fail; use inverse multiquadratics. The paper reports that the RBF kernel "fails to penalize the outliers of QZ because of the quick tail decay." Switching to inverse multiquadratics "significantly improved performance."

  • Distribution matching determines sample quality. The paper states "even slight differences between QZ and PZ may affect the quality of samples." The decoder trains only on codes from QZ, so mismatches degrade generation.

  • Extended tuning narrows the quality gap. With 3,000+ models trained on TPU hardware, bigWAE-GAN (FID 35) and bigWAE-MMD (FID 37) ended up 2 FID points apart, down from 13.

  • O(m²) complexity makes larger batches costly. The paper used batch size 100 across all experiments, and WAE-MMD's kernel matrices grow quadratically beyond that.

Final thoughts

The ICLR experiments expose a reality check: WAEs deliver measurable improvements with FID scores 12-33% better than VAEs and 2x sharper samples, yet the paper tested only MNIST and CelebA benchmark datasets.

Deployment constraints remain. WAE-MMD's kernel cost grows quadratically with batch size, and the experiments used batch size 100. λ changed 10x across datasets and 100x across variants, with no rule for transferring settings. Kernel selection proves critical: RBF kernels fail to penalize outliers.

The gap between controlled benchmark performance and real-world deployment stays largely unexplored.

Moving forward means understanding that benchmark improvements don't automatically translate to production success. The paper's four challenges make one thing clear: you need systematic evaluation beyond FID scores to validate whether generative models actually work for your use case.

Explore how Galileo helps teams evaluate AI systems systematically, tracking performance across multiple dimensions without relying on scattered benchmarks or manual validation.

Generative models face a fundamental trade-off: VAEs train stably but produce blurry samples, while GANs generate sharp images but train unpredictably.

This ICLR paper from Max Planck Institute and Google Brain introduced Wasserstein Auto-Encoders (WAEs) to solve both problems through optimal transport theory. The experiments on MNIST and CelebA show WAEs achieve 12-33% FID improvements over VAEs while maintaining training stability.

Yet challenges emerge beyond the benchmarks. WAE-MMD's O(m²) complexity limits batch sizes. Hyperparameter λ jumps 10x between datasets. Kernel selection is critical for successful training.

This article breaks down the MNIST (70,000 images) and CelebA (203,000 images) experiments, revealing four training challenges and what the results mean for choosing between WAE variants.

Summary: What WAEs deliver in production

WAEs minimize the penalized Wasserstein distance between the data distribution and model distribution. The framework was tested on MNIST (70,000 images) and CelebA (203,000 images) using DCGAN-style architectures.

Here's what the paper found:

  • Sample quality improved significantly. WAE-GAN achieved FID 42 versus VAE's 63 on CelebA, a 33% improvement. WAE-MMD reached FID 55 (12.7% improvement). Both variants generated 2x sharper samples measured at 6×10⁻³ versus VAE's 3×10⁻³.

  • Training stability depends on variant choice. The paper reports WAE-MMD "has a very stable training much like VAE"—no discriminator balancing, no mode collapse concerns. WAE-GAN delivers better quality but requires careful tuning with precise learning rate ratios throughout training.

  • The framework offers architectural flexibility. Unlike VAEs requiring stochastic encoders, WAEs support deterministic encoders. The key difference: WAEs match the aggregated posterior QZ to prior PZ rather than forcing per-sample matching.

  • Extended experiments narrowed the quality gap. Training 3,000+ models on TPU hardware with wider hyperparameter sweeps, bigWAE-GAN reached FID 35 while bigWAE-MMD achieved 37—just 6% difference versus 24% in standard experiments.

WAE-MMD's O(m²) computational complexity creates memory constraints you can't ignore. Hyperparameter tuning proved dataset-specific; λ jumped 10x from MNIST to CelebA. Kernel selection became critical: RBF kernels failed completely, requiring inverse multiquadratics for successful training.

Learn when to use multi-agent systems, how to design them efficiently, and how to build reliable systems that work in production.

The four challenges of training Wasserstein auto-encoders

You might think your generative model evaluation is complete after measuring FID scores and reconstruction error. The MNIST and CelebA experiments reveal otherwise.

Four training challenges emerge that affect your ability to scale, tune hyperparameters, and reproduce benchmark results. Each challenge represents a constraint the paper discovered through systematic experimentation. Challenge #1: O(m²) computational complexity creates memory bottlenecks

Your WAE-MMD trains on small batches. You want larger batch sizes for better optimization, but computational complexity becomes the constraint.

Here's why. WAE-MMD computes Maximum Mean Discrepancy through kernel matrices requiring m² operations for m encoded samples per minibatch. Algorithm 2 in the paper shows the MMD calculation:

λ/(n(n-1)) Σ k(z_l, z_j) + λ/(n(n-1)) Σ k(z̃_l, z̃_j) - 2λ/n² Σ k(z_l, z̃_j)

This unbiased U-statistic requires kernel evaluations for every sample pair. Standard VAEs compute in O(m) time via element-wise KL-divergence calculations that scale linearly with the batch size.

The paper's experimental setup reflects this constraint. Both MNIST and CelebA experiments used batch size 100 across all models—VAE, WAE-MMD, and WAE-GAN.

The paper used a batch size of 100 for both MNIST and CelebA experiments across all models—VAE, WAE-MMD, and WAE-GAN. GPU memory requirements scale quadratically with batch size for kernel matrix storage beyond standard forward-backward passes.

Here is what the numbers show. At batch size 100, the kernel computation processes 100×100 = 10,000 pairwise evaluations. Double the batch size to 200, and you need 40,000 evaluations—4x the computation for 2x the samples. VAE's element-wise operations scale linearly: 200 samples require 2x the computation of 100 samples.

For high-dimensional latent spaces where O(m²) computations become prohibitive, the paper concludes by identifying the investigation of Sliced-Wasserstein variants as future work.

Challenge #2: Hyperparameter sensitivity complicates tuning across datasets

Hyperparameters don't transfer between datasets. What works for MNIST fails on CelebA.

The regularization coefficient λ jumps by orders of magnitude:

  • MNIST WAE-MMD: λ=10

  • CelebA WAE-MMD: λ=100 (10x jump)

  • CelebA WAE-GAN: λ=1 (100x difference on same data)

Prior variance and latent dimensions changed between datasets. MNIST used σ²_z=1 with 8-dimensional latent space. CelebA required σ²_z=2 with 64-dimensional latent space. The paper explains why: "Choosing dz larger than intrinsic dimensionality of the dataset would force the encoded distribution QZ to live on a manifold in Z. This would make matching QZ to PZ impossible if PZ is Gaussian and may lead to numerical instabilities."

Learning rates don't follow a consistent pattern. MNIST used α=10⁻³ for encoder-decoder and α=5×10⁻⁴ for the adversary. CelebA WAE-GAN flipped this: α=3×10⁻⁴ for encoder-decoder and α=10⁻³ for adversary. CelebA also needed multiple adjustments during training: rates decreased by factor of 2 after 30 epochs, factor of 5 after 50 epochs, and factor of 10 after 100 epochs.

Kernel choice matters for WAE-MMD. The paper reports "We tried WAE-MMD with the RBF kernel but observed that it fails to penalize the outliers of QZ because of the quick tail decay," forcing a switch to inverse multiquadratics kernels.

The paper states: "even slight differences between QZ and PZ may affect the quality of samples." Small mismatches degrade sample quality noticeably.

Challenge #3: Kernel selection failures require architecture changes

Your WAE-MMD starts training. Some encoded samples end up far from the prior's support. RBF kernel penalties drop to zero for these outliers, and the encoder gets no signal to improve.

The paper hit this exact problem. They report "We tried WAE-MMD with the RBF kernel but observed that it fails to penalize the outliers of QZ because of the quick tail decay." When encoded codes z̃ = µ_φ(x) land far from PZ support during early training, the kernel terms k(z, z̃) = e^(-||z̃-z||²₂/σ²_k) rapidly approach zero. No gradient for those outliers means no training progress.

The catch with RBF kernels? They fail exactly when you need them most. Early training produces encoded samples scattered across latent space. RBF values decay exponentially—distant samples get zero penalty and zero gradient. Training stalls because the encoder can't learn how to pull outliers toward the prior.

The fix required switching kernels entirely. Inverse multiquadratics k(x,y) = C/(C + ||x - y||²₂) have much heavier tails that maintain gradients even for distant points. The paper used C = 2dz σ²_z—the expected squared distance between two Gaussian vectors drawn from PZ. This matches the kernel scale to the prior's natural spread.

The performance difference was dramatic. The paper reports this "significantly improved performance" compared to RBF kernels, even those with matching bandwidth σ²_k = 2dz σ²_z. Kernel choice "proved critical for WAE-MMD's ability to match QZ to PZ during training."

Challenge #4: Sample quality vs training stability trade-offs

Your WAE-GAN achieves FID score 42 on CelebA. Your WAE-MMD gets 55. Both crush VAE's 63. But here's the catch: that extra quality from WAE-GAN comes with adversarial training headaches.

The quality numbers tell the story:

  • WAE-GAN: FID 42, sharpness 6×10⁻³

  • WAE-MMD: FID 55, sharpness 6×10⁻³

  • VAE: FID 63, sharpness 3×10⁻³

The paper reports "in some cases WAE-GAN seems to lead to better matching and generates better samples."

But WAE-GAN trains less predictably. The paper states WAE-GAN "is less stable than WAE-MMD" due to adversarial training. WAE-MMD "has a very stable training much like VAE"—no discriminator networks to balance, no mode collapse to watch for.

Training WAE-GAN demands precision. CelebA experiments used α=3×10⁻⁴ for encoder-decoder and α=10⁻³ for the adversary. WAE-MMD requires no discriminator balancing across experiments.

Sample quality hinges on distribution matching. The paper states "the quality of samples strongly depends on how accurately QZ matches PZ" and "even slight differences between QZ and PZ may affect the quality of samples."

Here's why: during training, the decoder sees only samples from QZ (the encoded training data). When you generate by sampling z ~ PZ, the decoder encounters a distribution it wasn't trained on. Small mismatches translate to visible quality drops.

Extended tuning closes the gap. The paper trained 3,000+ models on 8 TPU-v2 accelerators (100,000 steps each). bigWAE-GAN reached FID 35 while bigWAE-MMD hit 37—just 6% difference versus the original 24% gap.

Key takeaways

The experiments reveal specific configurations that work and trade-offs you'll encounter when training WAEs:

  • WAE-GAN delivers 33% better FID but is less stable. WAE-MMD achieves 12.7% improvement with VAE-like stability—no discriminator balancing needed.

  • Hyperparameters don't transfer between datasets. Regularization λ jumped 10x from MNIST (10) to CelebA (100). On identical data, λ varied 100x between variants.

  • RBF kernels fail—use inverse multiquadratics. The paper reports RBF kernels "fail to penalize outliers because of quick tail decay." Inverse multiquadratics "significantly improved performance."

  • Distribution matching determines sample quality. The paper states "even slight differences between QZ and PZ may affect sample quality." The decoder trains only on QZ, so mismatches degrade generation.

  • Extended tuning narrows the quality gap. With 3,000+ models on TPU hardware, bigWAE-GAN (FID 35) versus bigWAE-MMD (FID 37) showed 6% difference—down from 24%.

  • O(m²) complexity limits batch sizes. The paper used batch size 100 across all experiments due to quadratic memory scaling.

Final thoughts

The ICLR experiments expose a reality check: WAEs deliver measurable improvements with FID scores 12-33% better than VAEs and 2x sharper samples, yet the paper tested only MNIST and CelebA benchmark datasets.

Deployment constraints remain. O(m²) complexity constrained batch sizes to 100 in the experiments. Hyperparameters vary 10-100x across datasets with no transfer patterns. Kernel selection proves critical—RBF kernels fail completely.

The gap between controlled benchmark performance and real-world deployment stays largely unexplored.

Moving forward means understanding that benchmark improvements don't automatically translate to production success. The paper's four challenges make one thing clear: you need systematic evaluation beyond FID scores to validate whether generative models actually work for your use case.

Explore how Galileo helps teams evaluate AI systems systematically, tracking performance across multiple dimensions without relying on scattered benchmarks or manual validation.

Generative models face a fundamental trade-off: VAEs train stably but produce blurry samples, while GANs generate sharp images but train unpredictably.

This ICLR paper from Max Planck Institute and Google Brain introduced Wasserstein Auto-Encoders (WAEs) to solve both problems through optimal transport theory. The experiments on MNIST and CelebA show WAEs achieve 12-33% FID improvements over VAEs while maintaining training stability.

Yet challenges emerge beyond the benchmarks. WAE-MMD's O(m²) complexity limits batch sizes. Hyperparameter λ jumps 10x between datasets. Kernel selection is critical for successful training.

This article breaks down the MNIST (70,000 images) and CelebA (203,000 images) experiments, revealing four training challenges and what the results mean for choosing between WAE variants.

Summary: What WAEs deliver in production

WAEs minimize the penalized Wasserstein distance between the data distribution and model distribution. The framework was tested on MNIST (70,000 images) and CelebA (203,000 images) using DCGAN-style architectures.

Here's what the paper found:

  • Sample quality improved significantly. WAE-GAN achieved FID 42 versus VAE's 63 on CelebA, a 33% improvement. WAE-MMD reached FID 55 (12.7% improvement). Both variants generated 2x sharper samples measured at 6×10⁻³ versus VAE's 3×10⁻³.

  • Training stability depends on variant choice. The paper reports WAE-MMD "has a very stable training much like VAE"—no discriminator balancing, no mode collapse concerns. WAE-GAN delivers better quality but requires careful tuning with precise learning rate ratios throughout training.

  • The framework offers architectural flexibility. Unlike VAEs requiring stochastic encoders, WAEs support deterministic encoders. The key difference: WAEs match the aggregated posterior QZ to prior PZ rather than forcing per-sample matching.

  • Extended experiments narrowed the quality gap. Training 3,000+ models on TPU hardware with wider hyperparameter sweeps, bigWAE-GAN reached FID 35 while bigWAE-MMD achieved 37—just 6% difference versus 24% in standard experiments.

WAE-MMD's O(m²) computational complexity creates memory constraints you can't ignore. Hyperparameter tuning proved dataset-specific; λ jumped 10x from MNIST to CelebA. Kernel selection became critical: RBF kernels failed completely, requiring inverse multiquadratics for successful training.

Learn when to use multi-agent systems, how to design them efficiently, and how to build reliable systems that work in production.

The four challenges of training Wasserstein auto-encoders

You might think your generative model evaluation is complete after measuring FID scores and reconstruction error. The MNIST and CelebA experiments reveal otherwise.

Four training challenges emerge that affect your ability to scale, tune hyperparameters, and reproduce benchmark results. Each challenge represents a constraint the paper discovered through systematic experimentation. Challenge #1: O(m²) computational complexity creates memory bottlenecks

Your WAE-MMD trains on small batches. You want larger batch sizes for better optimization, but computational complexity becomes the constraint.

Here's why. WAE-MMD computes Maximum Mean Discrepancy through kernel matrices requiring m² operations for m encoded samples per minibatch. Algorithm 2 in the paper shows the MMD calculation:

λ/(n(n-1)) Σ k(z_l, z_j) + λ/(n(n-1)) Σ k(z̃_l, z̃_j) - 2λ/n² Σ k(z_l, z̃_j)

This unbiased U-statistic requires kernel evaluations for every sample pair. Standard VAEs compute in O(m) time via element-wise KL-divergence calculations that scale linearly with the batch size.

The paper's experimental setup reflects this constraint. Both MNIST and CelebA experiments used batch size 100 across all models—VAE, WAE-MMD, and WAE-GAN.

The paper used a batch size of 100 for both MNIST and CelebA experiments across all models—VAE, WAE-MMD, and WAE-GAN. GPU memory requirements scale quadratically with batch size for kernel matrix storage beyond standard forward-backward passes.

Here is what the numbers show. At batch size 100, the kernel computation processes 100×100 = 10,000 pairwise evaluations. Double the batch size to 200, and you need 40,000 evaluations—4x the computation for 2x the samples. VAE's element-wise operations scale linearly: 200 samples require 2x the computation of 100 samples.

For high-dimensional latent spaces where O(m²) computations become prohibitive, the paper concludes by identifying the investigation of Sliced-Wasserstein variants as future work.

Challenge #2: Hyperparameter sensitivity complicates tuning across datasets

Hyperparameters don't transfer between datasets. What works for MNIST fails on CelebA.

The regularization coefficient λ jumps by orders of magnitude:

  • MNIST WAE-MMD: λ=10

  • CelebA WAE-MMD: λ=100 (10x jump)

  • CelebA WAE-GAN: λ=1 (100x difference on same data)

Prior variance and latent dimensions changed between datasets. MNIST used σ²_z=1 with 8-dimensional latent space. CelebA required σ²_z=2 with 64-dimensional latent space. The paper explains why: "Choosing dz larger than intrinsic dimensionality of the dataset would force the encoded distribution QZ to live on a manifold in Z. This would make matching QZ to PZ impossible if PZ is Gaussian and may lead to numerical instabilities."

Learning rates don't follow a consistent pattern. MNIST used α=10⁻³ for encoder-decoder and α=5×10⁻⁴ for the adversary. CelebA WAE-GAN flipped this: α=3×10⁻⁴ for encoder-decoder and α=10⁻³ for adversary. CelebA also needed multiple adjustments during training: rates decreased by factor of 2 after 30 epochs, factor of 5 after 50 epochs, and factor of 10 after 100 epochs.

Kernel choice matters for WAE-MMD. The paper reports "We tried WAE-MMD with the RBF kernel but observed that it fails to penalize the outliers of QZ because of the quick tail decay," forcing a switch to inverse multiquadratics kernels.

The paper states: "even slight differences between QZ and PZ may affect the quality of samples." Small mismatches degrade sample quality noticeably.

Challenge #3: Kernel selection failures require architecture changes

Your WAE-MMD starts training. Some encoded samples end up far from the prior's support. RBF kernel penalties drop to zero for these outliers, and the encoder gets no signal to improve.

The paper hit this exact problem. They report "We tried WAE-MMD with the RBF kernel but observed that it fails to penalize the outliers of QZ because of the quick tail decay." When encoded codes z̃ = µ_φ(x) land far from PZ support during early training, the kernel terms k(z, z̃) = e^(-||z̃-z||²₂/σ²_k) rapidly approach zero. No gradient for those outliers means no training progress.

The catch with RBF kernels? They fail exactly when you need them most. Early training produces encoded samples scattered across latent space. RBF values decay exponentially—distant samples get zero penalty and zero gradient. Training stalls because the encoder can't learn how to pull outliers toward the prior.

The fix required switching kernels entirely. Inverse multiquadratics k(x,y) = C/(C + ||x - y||²₂) have much heavier tails that maintain gradients even for distant points. The paper used C = 2dz σ²_z—the expected squared distance between two Gaussian vectors drawn from PZ. This matches the kernel scale to the prior's natural spread.

The performance difference was dramatic. The paper reports this "significantly improved performance" compared to RBF kernels, even those with matching bandwidth σ²_k = 2dz σ²_z. Kernel choice "proved critical for WAE-MMD's ability to match QZ to PZ during training."

Challenge #4: Sample quality vs training stability trade-offs

Your WAE-GAN achieves FID score 42 on CelebA. Your WAE-MMD gets 55. Both crush VAE's 63. But here's the catch: that extra quality from WAE-GAN comes with adversarial training headaches.

The quality numbers tell the story:

  • WAE-GAN: FID 42, sharpness 6×10⁻³

  • WAE-MMD: FID 55, sharpness 6×10⁻³

  • VAE: FID 63, sharpness 3×10⁻³

The paper reports "in some cases WAE-GAN seems to lead to better matching and generates better samples."

But WAE-GAN trains less predictably. The paper states WAE-GAN "is less stable than WAE-MMD" due to adversarial training. WAE-MMD "has a very stable training much like VAE"—no discriminator networks to balance, no mode collapse to watch for.

Training WAE-GAN demands precision. CelebA experiments used α=3×10⁻⁴ for encoder-decoder and α=10⁻³ for the adversary. WAE-MMD requires no discriminator balancing across experiments.

Sample quality hinges on distribution matching. The paper states "the quality of samples strongly depends on how accurately QZ matches PZ" and "even slight differences between QZ and PZ may affect the quality of samples."

Here's why: during training, the decoder sees only samples from QZ (the encoded training data). When you generate by sampling z ~ PZ, the decoder encounters a distribution it wasn't trained on. Small mismatches translate to visible quality drops.
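One way to quantify how well QZ matches PZ is the MMD penalty that WAE-MMD itself optimizes. The Python sketch below implements the unbiased estimator quoted from the paper's Algorithm 2, using the inverse multiquadratics kernel. It compares prior samples against two stand-ins for QZ: one matched to the prior and one slightly shifted. The toy Gaussian "codes" replace a real encoder and are purely illustrative. The sketch assumes the MNIST-style setting dz = 8, σ²_z = 1, C = 16, with batch size 100.

```python
import random


def imq_kernel(x: list[float], y: list[float], c: float) -> float:
    """Inverse multiquadratics kernel C / (C + ||x - y||^2)."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return c / (c + sq_dist)


def mmd_unbiased(z_prior: list[list[float]], z_codes: list[list[float]], c: float, lam: float = 1.0) -> float:
    """Unbiased MMD penalty: off-diagonal within-set terms minus twice the cross term."""
    n = len(z_prior)
    if len(z_codes) != n or n < 2:
        raise ValueError("need two equally sized batches with at least 2 samples")
    prior_prior = sum(imq_kernel(z_prior[l], z_prior[j], c) for l in range(n) for j in range(n) if l != j)
    code_code = sum(imq_kernel(z_codes[l], z_codes[j], c) for l in range(n) for j in range(n) if l != j)
    cross = sum(imq_kernel(zp, zc, c) for zp in z_prior for zc in z_codes)
    return lam * (prior_prior + code_code) / (n * (n - 1)) - 2 * lam * cross / n**2


def gaussian_batch(rng: random.Random, n: int, d: int, mean: float) -> list[list[float]]:
    """n samples from N(mean, 1) in every coordinate; a toy stand-in for prior draws or codes."""
    return [[rng.gauss(mean, 1.0) for _ in range(d)] for _ in range(n)]


rng = random.Random(0)
D_Z, N, C = 8, 100, 16.0  # C = 2 * d_z * sigma2_z with sigma2_z = 1
prior = gaussian_batch(rng, N, D_Z, 0.0)
matched = gaussian_batch(rng, N, D_Z, 0.0)  # QZ equal to PZ
shifted = gaussian_batch(rng, N, D_Z, 0.5)  # QZ off by 0.5 in every coordinate

mmd_matched = mmd_unbiased(prior, matched, C)
mmd_shifted = mmd_unbiased(prior, shifted, C)
print(f"matched: {mmd_matched:.4f}  shifted: {mmd_shifted:.4f}")
```

The matched estimate should sit near zero. It can dip slightly below zero, because the estimator is unbiased rather than non-negative. A shift of only 0.5 per coordinate already produces a clearly positive penalty. In a real model, `z_codes` would be the encoder outputs for a minibatch.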

Extended tuning narrows the gap. The paper trained 3,000+ models on 8 TPU-v2 accelerators (100,000 steps each). bigWAE-GAN reached FID 35 while bigWAE-MMD hit 37. That is a 2-point gap, about 5% of the WAE-MMD score, versus 13 points (about 24%) in the original experiments.
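A percentage gap between two FID scores depends on which score you divide by. The sketch below computes both readings for the standard and big runs, using the FIDs quoted above. Measured against the WAE-MMD score, the gap shrinks from about 24% to about 5%.

```python
def fid_gap(fid_gan: float, fid_mmd: float) -> dict[str, float]:
    """Absolute FID gap and the gap as a fraction of each model's FID."""
    diff = fid_mmd - fid_gan
    return {"points": diff, "vs_gan": diff / fid_gan, "vs_mmd": diff / fid_mmd}


for label, gan, mmd in (("standard", 42.0, 55.0), ("big", 35.0, 37.0)):
    g = fid_gap(gan, mmd)
    print(f"{label}: {g['points']:.0f} points, {g['vs_gan']:.1%} of WAE-GAN FID, "
          f"{g['vs_mmd']:.1%} of WAE-MMD FID")
```

This prints "standard: 13 points, 31.0% of WAE-GAN FID, 23.6% of WAE-MMD FID" and "big: 2 points, 5.7% of WAE-GAN FID, 5.4% of WAE-MMD FID". The narrowing holds under either denominator.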

Key takeaways

The experiments reveal specific configurations that work and trade-offs you'll encounter when training WAEs:

  • WAE-GAN delivers a 33% lower FID than the VAE but trains less stably. WAE-MMD achieves a 12.7% improvement with VAE-like stability and no discriminator to balance.

  • Hyperparameters don't transfer between datasets. For WAE-MMD, regularization λ jumped 10x from MNIST (10) to CelebA (100). On CelebA alone, λ varied 100x between variants (WAE-MMD 100, WAE-GAN 1).

  • Prefer inverse multiquadratics over RBF kernels. The paper reports the RBF kernel "fails to penalize the outliers of QZ because of the quick tail decay." Inverse multiquadratics "significantly improved performance."

  • Distribution matching determines sample quality. The paper states "even slight differences between QZ and PZ may affect the quality of samples." The decoder trains only on QZ, so mismatches degrade generation.

  • Extended tuning narrows the quality gap. With 3,000+ models on TPU hardware, bigWAE-GAN (FID 35) and bigWAE-MMD (FID 37) differed by 2 points. That is about 5% of the WAE-MMD score, down from about 24% in the standard experiments.

  • O(m²) complexity limits batch sizes. The paper used batch size 100 for every model, VAE included. For WAE-MMD, doubling the batch quadruples the pairwise kernel evaluations.
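The quadratic growth is easy to check with a little arithmetic. The Python sketch below counts the kernel evaluations in one m × m block of the MMD estimate and compares them with the m per-sample KL terms a VAE computes. It also estimates the memory of a pairwise-difference tensor of shape m × m × dz. That last figure assumes a broadcasting implementation that materializes every pairwise difference in float32, which real implementations may avoid. The value dz = 64 is the CelebA latent size the article reports.

```python
def kernel_evals(m: int) -> int:
    """Pairwise kernel evaluations in one m x m block of the MMD estimate."""
    return m * m


def kl_terms(m: int) -> int:
    """Per-sample KL terms in a VAE minibatch: linear in m."""
    return m


def pairwise_diff_bytes(m: int, d_z: int, bytes_per_float: int = 4) -> int:
    """Memory for an m x m x d_z tensor of pairwise differences (assumes float32 broadcasting)."""
    return m * m * d_z * bytes_per_float


for m in (100, 200, 1000):
    print(m, kernel_evals(m), kl_terms(m), pairwise_diff_bytes(m, d_z=64))
```

Doubling m from 100 to 200 quadruples the kernel evaluations (10,000 to 40,000), while the KL terms only double. At m = 1000, the assumed difference tensor alone would take 256,000,000 bytes (about 256 MB) per block.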

Final thoughts

The ICLR experiments expose a reality check: WAEs deliver measurable improvements with FID scores 12-33% better than VAEs and 2x sharper samples, yet the paper tested only MNIST and CelebA benchmark datasets.

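Deployment constraints remain. Every experiment used batch size 100, and WAE-MMD's O(m²) cost makes larger batches expensive. λ varied 10x across datasets and 100x between variants, with no rule for transferring settings. Kernel selection proves critical: the RBF kernel failed to penalize outliers, while inverse multiquadratics significantly improved performance.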

The gap between controlled benchmark performance and real-world deployment stays largely unexplored.

Moving forward means understanding that benchmark improvements don't automatically translate to production success. These four challenges make one thing clear: you need systematic evaluation beyond FID scores to validate whether generative models actually work for your use case.

Explore how Galileo helps teams evaluate AI systems systematically, tracking performance across multiple dimensions without relying on scattered benchmarks or manual validation.
