Generating Atari images using Generative Adversarial Networks.
A PyTorch implementation of image generation using Generative Adversarial Networks. The dataset is generated on the fly using OpenAI's Gym framework.
Clone the project
> git clone git@github.com:satwikkansal/atari_gan.git
> cd atari_gan
Install the dependencies first (create a virtualenv if you like)
> pip3 install -r requirements.txt
To train the model, run the command
> python train.py
with an optional --gpu flag to enable CUDA computations and speed up training, and a --restore flag that accepts the path to a directory containing saved models (saved in the model/ directory by default) to restore the models from an earlier saved state.
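For illustration, the two flags described above could be parsed along these lines. This is a minimal sketch, not the project's actual train.py: the function name parse_args, the help strings, and the default of None for --restore are assumptions made for the example.

```python
import argparse


def parse_args(argv=None):
    """Parse the training flags described above (hypothetical sketch)."""
    parser = argparse.ArgumentParser(description="Train the Atari GAN")
    # --gpu is a boolean switch: present means "use CUDA", absent means CPU.
    parser.add_argument("--gpu", action="store_true",
                        help="enable CUDA computations")
    # --restore takes a directory path; None (an assumed default) means
    # "start training from scratch".
    parser.add_argument("--restore", metavar="DIR", default=None,
                        help="directory containing previously saved models")
    return parser.parse_args(argv)


# Equivalent to running: python train.py --gpu --restore model/
args = parse_args(["--gpu", "--restore", "model/"])
print(args.gpu, args.restore)  # prints: True model/
```

Using store_true for --gpu keeps the default run on the CPU, so the script still works on machines without CUDA.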
To visualize the results, run the TensorBoard server using the following command:
> tensorboard --logdir runs --host localhost
Once it is running, visit http://localhost:6006 in your browser to view the Discriminator loss, the Generator loss, and the images produced by the Generator network during training. Your dashboard will look something like this:
If you'd like to only generate images using Gym's Atari games, have a look at the generate_data.py script. To use it, simply run
> python generate_data.py
This generates 10000 images and stores them in the data/ directory. If you want to manually provide these parameters, you can use the
If you want to tweak more of the parameters involved in training, have a look at the constants defined at the top of each file in this project.
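Generating the dataset on the fly can be pictured as stepping an environment with random actions and keeping each observed frame. The sketch below shows that general idea only and is not the contents of generate_data.py: collect_frames, StubEnv, and StubActionSpace are hypothetical names. The stub mimics the classic Gym interface (reset() returns an observation, step() returns a 4-tuple) so the example runs without Gym installed. Newer Gym and Gymnasium releases return extra values from reset() and step(), so the unpacking would need adjusting there.

```python
import random


class StubActionSpace:
    """Hypothetical stand-in for a Gym discrete action space."""

    def __init__(self, n):
        self.n = n

    def sample(self):
        return random.randrange(self.n)


class StubEnv:
    """Hypothetical stand-in for a Gym Atari environment (classic API)."""

    def __init__(self, episode_length=5, shape=(4, 4)):
        self.episode_length = episode_length
        self.shape = shape
        self.steps = 0
        self.action_space = StubActionSpace(4)

    def _frame(self):
        # A fake grayscale frame: nested lists of pixel values in [0, 255].
        height, width = self.shape
        return [[random.randint(0, 255) for _ in range(width)]
                for _ in range(height)]

    def reset(self):
        self.steps = 0
        return self._frame()

    def step(self, action):
        self.steps += 1
        done = self.steps >= self.episode_length
        return self._frame(), 0.0, done, {}


def collect_frames(env, n_frames):
    """Play random actions and return the first n_frames observations."""
    frames = []
    obs = env.reset()
    while len(frames) < n_frames:
        frames.append(obs)
        obs, _reward, done, _info = env.step(env.action_space.sample())
        if done:
            # Start a new episode once the current one ends.
            obs = env.reset()
    return frames


frames = collect_frames(StubEnv(), n_frames=12)
print(len(frames))  # prints: 12
```

With an older Gym release installed, an environment created with gym.make could be passed to collect_frames in place of StubEnv, since the function relies only on reset(), step(), and action_space.sample().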
All patches welcome!
Here's a pure TensorFlow implementation of Generative Adversarial Networks for the MNIST dataset: https://gist.github.com/satwikkansal/201bda0c08b5e12c44c6b07e6db8853e
MIT License - see the LICENSE file for details.