
Last Commit
Jul. 25, 2018
Apr. 4, 2018

Prediction Optimizer (to stabilize GAN training)


This is a PyTorch implementation of the 'prediction method' introduced in the following paper:

  • Abhay Yadav et al., "Stabilizing Adversarial Nets with Prediction Methods", ICLR 2018
  • (To clarify, I am not an author of the paper.)

The authors propose a simple but effective method for stabilizing GAN training. With this Prediction Optimizer, you can easily apply the method to your existing GAN code. This implementation is compatible with most PyTorch optimizers and network structures. (Please let me know if you have any issues using it.)
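For reference, this is the update rule as we read it from the paper. It is written for a generic saddle-point problem min_u max_v L(u, v); in a GAN, u and v would be the parameters of the two networks. The first line is an ordinary gradient descent step. The second extrapolates along that most recent update. The third performs gradient ascent on v against the predicted value of u rather than the current one. The factor s is our assumption about what the `step` argument of lookahead() controls; s = 1 gives the rule as stated in the paper.

```latex
\begin{aligned}
u_{k+1}       &= u_k - \alpha_k \, \nabla_u \mathcal{L}(u_k, v_k) \\
\bar{u}_{k+1} &= u_{k+1} + s \, (u_{k+1} - u_k) \\
v_{k+1}       &= v_k + \beta_k \, \nabla_v \mathcal{L}(\bar{u}_{k+1}, v_k)
\end{aligned}
```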



  • Import
    • from prediction import PredOpt
  • Initialize it just like an optimizer
    • pred = PredOpt(net.parameters())
  • Run the model inside a 'with' block to get outputs computed with the predicted parameters.
    • The 'step' argument controls the lookahead step size (1.0 by default).
    • with pred.lookahead(step=1.0):
          output = net(input)
  • Call pred.step() after each update of the network parameters
    • optim_net.step()
      pred.step()
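The steps above imply some internal bookkeeping. The sketch below shows it as a hypothetical `PredOptSketch` class; this is not the repository's `PredOpt` source. `step()` records how far the parameters moved since the previous call. `lookahead(step=s)` temporarily shifts every parameter by `s` times that movement, then restores the real values when the `with` block exits. The sketch assumes the prediction is a plain linear extrapolation of the last update, which is how we read the paper's prediction step.

```python
import contextlib

import torch


class PredOptSketch:
    """Hypothetical stand-in for PredOpt (not the repository's code).

    step() remembers how far the parameters moved during the latest update;
    lookahead(step=s) temporarily sets each parameter p to p + s * (last move).
    """

    def __init__(self, params):
        self._params = list(params)
        # Parameter values as of the most recent step() call (or construction).
        self._snapshot = [p.detach().clone() for p in self._params]
        # theta_{k+1} - theta_k; zero until the first step() call.
        self._delta = [torch.zeros_like(p) for p in self._params]

    @torch.no_grad()
    def step(self):
        """Call right after the network's own optimizer has updated it."""
        for p, snap, delta in zip(self._params, self._snapshot, self._delta):
            delta.copy_(p - snap)  # movement caused by the latest update
            snap.copy_(p)          # current values become the new reference

    @contextlib.contextmanager
    def lookahead(self, step=1.0):
        """Temporarily move the parameters to their extrapolated values.

        Do not update this network inside its own lookahead block: the real
        values saved on entry are copied back on exit.
        """
        saved = [p.detach().clone() for p in self._params]
        with torch.no_grad():
            for p, delta in zip(self._params, self._delta):
                p.add_(delta, alpha=step)
        try:
            yield self
        finally:
            with torch.no_grad():
                for p, real in zip(self._params, saved):
                    p.copy_(real)


if __name__ == "__main__":
    net = torch.nn.Linear(3, 1)
    opt = torch.optim.SGD(net.parameters(), lr=0.1)
    pred = PredOptSketch(net.parameters())

    loss = net(torch.randn(8, 3)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()   # real update: theta_0 -> theta_1
    pred.step()  # remember theta_1 - theta_0

    real = net.weight.detach().clone()
    with pred.lookahead(step=1.0):
        moved = net.weight.detach().clone()  # 2 * theta_1 - theta_0
    assert torch.equal(net.weight.detach(), real)  # real values restored
    print("lookahead offset:", (moved - real).tolist())
```

Restoring from a saved copy, rather than subtracting the offset again, avoids floating-point drift in the real parameters. The cost is one extra copy of the network's weights for the duration of the block.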


  • You can find sample code in this repository (
  • A sample snippet:
  • import torch.optim as optim
    from prediction import PredOpt
    # ... (define netG, netD, dataloader and a noise batch Z)
    optim_G = optim.Adam(netG.parameters(), lr=0.01)
    optim_D = optim.Adam(netD.parameters(), lr=0.01)
    pred_G = PredOpt(netG.parameters())             # Create a prediction optimizer with the target parameters
    pred_D = PredOpt(netD.parameters())
    for i, data in enumerate(dataloader, 0):
        # (1) Train D with samples from the predicted generator
        with pred_G.lookahead(step=1.0):            # Inside the 'with' block, netG works as a 'predicted' model
            fake_predicted = netG(Z)
            # Compute the D loss and gradients, then update D
            optim_D.step()
            pred_D.step()                           # Call PredOpt.step() after each update of netD
        # (2) Train G
        with pred_D.lookahead(step=1.0):            # 'Predicted D'
            fake = netG(Z)                          # Draw samples from the real (not predicted) generator
            D_outs = netD(fake)
            # Compute the G loss and gradients, then update G
            optim_G.step()
            pred_G.step()                           # Call PredOpt.step() after each update of netG
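The lookahead's stabilizing effect can be seen without a GAN. The example below is our own illustration, not an experiment from this repository. It uses the toy bilinear game min_u max_v u·v, whose saddle point is (0, 0). It runs alternating gradient descent on u and ascent on v, once plainly and once with v updated against the predicted value ū = u_{k+1} + (u_{k+1} − u_k). Note that the training loop in the snippet above predicts both networks, while this toy predicts only u, as in the paper's update rule.

```python
import math


def run(steps: int, lr: float, predict: bool) -> float:
    """Alternating gradient descent (on u) / ascent (on v) for L(u, v) = u * v.

    Starts at (1, 1) and returns the final distance from the saddle point (0, 0).
    """
    u, v = 1.0, 1.0
    for _ in range(steps):
        u_prev = u
        u = u - lr * v                    # descent step on u: dL/du = v
        # With prediction, v reacts to the extrapolated u; otherwise to the new u.
        u_seen = u + (u - u_prev) if predict else u
        v = v + lr * u_seen               # ascent step on v: dL/dv = u
    return math.hypot(u, v)


if __name__ == "__main__":
    print("without prediction:", run(steps=1000, lr=0.1, predict=False))
    print("with prediction:   ", run(steps=1000, lr=0.1, predict=True))
```

For this game, the plain iteration is an area-preserving linear map: it keeps orbiting the saddle point and never approaches it. With prediction, the iteration matrix becomes [[1, -lr], [lr, 1 - 2·lr²]], whose determinant is 1 - lr². Its complex eigenvalues therefore have modulus sqrt(1 - lr²) < 1, so the iterates spiral in toward (0, 0).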

Output samples

More sample images can be found in the repository's issues.

Training w/ large learning rate (0.01)

[Image grid, epoch 25: CIFAR and CelebA samples; left: vanilla DCGAN, right: DCGAN w/ prediction (step=1.0)]

Training w/ medium learning rate (1e-4)

[Image grid, epoch 25: CIFAR and CelebA samples; left: vanilla DCGAN, right: DCGAN w/ prediction (step=1.0)]

Training w/ small learning rate (1e-5)

[Image grid, epoch 25: CIFAR and CelebA samples; left: vanilla DCGAN, right: DCGAN w/ prediction (step=1.0)]

External links


  • : Implemented as an optimizer
  • : Added support for pip install
  • : Added some experimental results