PyTorch implementation of the NIPS 2017 paper *Pose Guided Person Image Generation*.
- Python 3.6
The key task is to transfer the appearance of a person in a given image to a desired target pose while keeping the important appearance details intact. A two-stage approach is proposed, with each stage focusing on one aspect: the first stage performs pose integration to generate a coarse image, and the second stage uses a variant of a conditional DCGAN to fill in finer appearance details.
The generation framework uses the pose information explicitly and consists of two key stages: pose integration and image refinement. The generator architecture in both stages is inspired by U-Net. In the first stage, the condition image and the target pose are fed into the network to generate a coarse image of the person in the target pose. The second stage then refines this blurry result by training a generator in an adversarial way. The architectures of the generator and discriminator are shown below:
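The two-stage pipeline described above can be sketched as follows. This is a minimal illustration, not the repo's actual architecture: the layer sizes, class names, and the 18-channel pose-heatmap input are assumptions, and the real generators are full U-Nets.

```python
import torch
import torch.nn as nn

# Hypothetical sketch of the two-stage pipeline; layer sizes and
# class names are illustrative, not the repo's exact architecture.

class CoarseGenerator(nn.Module):
    """Stage 1: condition image + target pose heatmaps -> coarse image."""
    def __init__(self, img_ch=3, pose_ch=18):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(img_ch + pose_ch, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, img_ch, 3, padding=1), nn.Tanh(),
        )

    def forward(self, img, pose):
        # Condition image and target pose are concatenated channel-wise
        return self.net(torch.cat([img, pose], dim=1))

class RefinementGenerator(nn.Module):
    """Stage 2: condition image + coarse result -> refined image."""
    def __init__(self, img_ch=3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(2 * img_ch, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, img_ch, 3, padding=1), nn.Tanh(),
        )

    def forward(self, img, coarse):
        # Stage 2 predicts a difference map added to the coarse output
        return coarse + self.net(torch.cat([img, coarse], dim=1))

cond = torch.randn(1, 3, 64, 64)    # condition image
pose = torch.randn(1, 18, 64, 64)   # target pose as keypoint heatmaps
coarse = CoarseGenerator()(cond, pose)
refined = RefinementGenerator()(cond, coarse)
print(refined.shape)  # torch.Size([1, 3, 64, 64])
```

In the full model the refined output is scored by the discriminator against real target images, which is what trains the second stage adversarially.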
Clone the source code:
```shell
git clone https://github.com/harshitbansal05/Pose-Guided-Image-Generation/
cd Pose-Guided-Image-Generation
```
- Run `./run_prepare_data.sh`. It creates a `data` folder in the root directory and downloads the data from the author's website. It then extracts the zip file and processes the images for the train and test data sets.
- It creates the folders `DF_train_data` and `DF_test_data` in the directory `data/DF_img_pose` for the train and test data sets respectively, printing the count of both sets at the end.
- Run `./run_train_model.sh`. It performs the data preparation step if it hasn't been performed already. The variable `gpu` in the script must be set to -1 if training is carried out on the CPU. Finally, training begins.
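The `gpu = -1` convention above typically maps to a device choice like the following sketch; the variable name mirrors the script, while the rest is an assumption about how the flag is consumed.

```python
import torch

# Hypothetical illustration: a `gpu` value of -1 selects the CPU,
# any other value selects the corresponding CUDA device index.
gpu = -1  # set to a GPU index (e.g. 0) to train on CUDA
device = torch.device("cpu") if gpu == -1 else torch.device(f"cuda:{gpu}")

# Models and batches are then moved to the chosen device before training
model = torch.nn.Linear(4, 2).to(device)
print(device)  # cpu when gpu == -1
```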
- It parses the namespaces present in the file `config.py` and those passed explicitly. It also imports the models from the file `models.py` and the PyTorch `DataLoader`.
- It trains the two generators and the discriminator for the specified number of epochs, printing the loss at each step. It also periodically saves the generator and discriminator models in the directory specified by the argument `checkpoint_dir` in the file `config.py`.
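The periodic-saving step can be sketched as below. The helper name, filename pattern, and checkpoint layout are assumptions for illustration, not the repo's exact behavior; `torch.save` with each module's `state_dict()` is the standard PyTorch mechanism.

```python
import os
import torch

def save_checkpoint(checkpoint_dir, step, g1, g2, d):
    # Hypothetical checkpoint helper: persists both generators and the
    # discriminator so training can resume from `step`.
    os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(
        {
            "step": step,
            "g1": g1.state_dict(),
            "g2": g2.state_dict(),
            "d": d.state_dict(),
        },
        os.path.join(checkpoint_dir, f"ckpt_{step}.pt"),
    )

# Tiny stand-ins for the two generators and the discriminator
g1 = torch.nn.Linear(4, 4)
g2 = torch.nn.Linear(4, 4)
d = torch.nn.Linear(4, 1)
save_checkpoint("checkpoints", 100, g1, g2, d)
```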
The DeepFashion (In-shop Clothes Retrieval Benchmark) dataset consists of 52,712 in-shop clothes images and 200,000 cross-pose/scale pairs. All images are high-resolution, 256×256. The entire dataset can be downloaded from here.
Suggestions and pull requests are actively welcome.