[AAAI20] TensorFlow implementation of the Collaborative Sampling in Generative Adversarial Networks

vita-epfl/collaborative-gan-sampling

Collaborative Sampling in Generative Adversarial Networks

This repository provides a TensorFlow implementation of Collaborative Sampling in Generative Adversarial Networks.


Overview

Once GAN training completes, we use both the generator and the discriminator to produce samples collaboratively. Our sampling scheme consists of one sample proposal step and multiple sample refinement steps. (I) The fixed generator proposes samples. (II) Subsequently, the discriminator provides gradients, with respect to activation maps of the proposed samples, back to a particular layer of the generator. Gradient-based updates of the activation maps are performed repeatedly until the samples are classified as real by the discriminator.
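The proposal and refinement steps above can be illustrated with a small NumPy sketch. It is not the code from this repository: the toy generator (`generator_head`, `generator_tail`), the logistic `discriminator`, the random weights, and the parameters `step_size`, `max_steps` and `threshold` are all hypothetical names chosen for illustration. The update rule used here (gradient ascent on log D with respect to the hidden activations) is an assumption about one reasonable choice of objective.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy models (assumptions, not the repository's networks).
W1 = rng.normal(size=(8, 2))    # generator layer 1: latent -> activation map
W2 = rng.normal(size=(2, 8))    # generator layer 2: activation map -> sample
w_d = np.array([1.0, -1.0])     # discriminator weights
b_d = -1.0                      # discriminator bias


def generator_head(z):
    """Layers of the generator up to the layer being refined."""
    return np.tanh(W1 @ z)


def generator_tail(h):
    """Remaining generator layers, mapping activations to a sample."""
    return W2 @ h


def discriminator(x):
    """Probability that x is real (logistic model)."""
    return 1.0 / (1.0 + np.exp(-(w_d @ x + b_d)))


def collaborative_sample(z, step_size=0.1, max_steps=200, threshold=0.5):
    # (I) Proposal: the fixed generator produces an activation map.
    h = generator_head(z)
    n_updates = 0
    # (II) Refinement: update the activation map with discriminator
    # gradients until the sample is classified as real.
    while n_updates < max_steps:
        d = discriminator(generator_tail(h))
        if d > threshold:
            break
        # d(log D)/dh = (1 - D(x)) * W2^T w_d for this toy model.
        grad_h = (1.0 - d) * (W2.T @ w_d)
        h = h + step_size * grad_h
        n_updates += 1
    return generator_tail(h), n_updates


z = rng.normal(size=2)
x_proposed = generator_tail(generator_head(z))
x_refined, n_updates = collaborative_sample(z)
print(discriminator(x_proposed), discriminator(x_refined), n_updates)
```

Note that only the activations `h` are changed during refinement; the generator and discriminator weights stay fixed, which matches the description of sampling after training completes.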


GANs for modelling an imbalanced mixture of 8 Gaussians. Vanilla GANs are prone to mode collapse. Accept-reject sampling algorithms, including Discriminator Rejection Sampling (DRS) and the Metropolis-Hastings method (MH-GAN), suffer from severe distribution bias due to the mismatch between distribution supports. Our collaborative sampling scheme applied to early-terminated GANs recovers all modes without compromising sample quality, significantly outperforming the baseline methods.

Panels: Real; GAN (1K Iter); GAN (9K Iter); DRS (at 1K Iter); MH-GAN (at 1K Iter); Refine (at 1K Iter); Collab (at 1K Iter).
Metric plots: Quality; Diversity; Overall.

DCGAN for modelling human faces on the CelebA dataset. (Top) Samples from standard sampling. (Middle) Samples from our collaborative sampling method. (Bottom) The difference between the top and the middle row.

Panels: Cifar10; CelebA.

CycleGAN for unpaired image-to-image translation. (Top) Samples from standard sampling. (Middle) Samples from our collaborative sampling method. (Bottom) The difference between the top and the middle row.


Dependencies:

  • tensorflow==1.13.0
  • CUDA==10.0
  • pillow
  • scipy==1.2
  • matplotlib
  • requests
  • tqdm

Citation:

If you use this code for your research, please cite our paper.

@inproceedings{liu2019collaborative,
  title={Collaborative Sampling in Generative Adversarial Networks},
  author={Liu, Yuejiang and Kothari, Parth and Alahi, Alexandre},
  booktitle={Thirty-Fourth AAAI Conference on Artificial Intelligence},
  year={2020}
}

Acknowledgements

The baseline implementation is based on this repository.