
Last Commit
Jan. 23, 2019
Jan. 8, 2019


This codebase implements a method for training a neural network to produce high-resolution (512x512) images WITHOUT using a GAN.

The network's inputs are sparse contour lines and, optionally, a low-resolution (16x16) colormap. Check out the for details.
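The README does not spell out how the 16x16 colormap is produced or how it is combined with the contour lines. A minimal sketch, assuming the colormap is a block-averaged copy of the target image (512 / 16 = 32-pixel blocks) that is upsampled back to 512x512 and stacked with the contour map as extra input channels; `make_network_input` is a hypothetical helper, not part of this codebase:

```python
import numpy as np

def make_network_input(contours, target_rgb, use_colormap=True, low_res=16):
    """Stack a contour map with an optional low-resolution colormap.

    Assumption (not from the original code): the colormap is the target image
    block-averaged to low_res x low_res, nearest-upsampled, and stacked as channels.

    contours:   (H, W) array, 1 where a contour line is present, else 0
    target_rgb: (H, W, 3) float array
    Returns a (C, H, W) float32 array with C = 1 (contours only) or 4.
    """
    h, w = contours.shape
    assert h % low_res == 0 and w % low_res == 0, "size must be divisible by low_res"
    channels = [contours.astype(np.float32)[None]]       # (1, H, W)
    if use_colormap:
        bh, bw = h // low_res, w // low_res               # 512 // 16 = 32-pixel blocks
        # Block-average the target down to (low_res, low_res, 3).
        small = target_rgb.reshape(low_res, bh, low_res, bw, 3).mean(axis=(1, 3))
        # Nearest-neighbour upsample back to (H, W, 3).
        up = np.repeat(np.repeat(small, bh, axis=0), bw, axis=1)
        channels.append(up.transpose(2, 0, 1).astype(np.float32))  # (3, H, W)
    return np.concatenate(channels, axis=0)

rng = np.random.default_rng(0)
contours = (rng.random((512, 512)) > 0.95).astype(np.uint8)
target = rng.random((512, 512, 3))
x = make_network_input(contours, target)
print(x.shape)  # (4, 512, 512)
```

With the colormap enabled the network would see 4 channels (1 contour + 3 color); with `use_colormap=False` it sees only the contour channel.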

The basic idea is to combine a U-Net, a modified perceptual loss (Pearson distance instead of MAE), learned basis functions, and "mean teacher" training to synthesize high-quality images without the usual difficulties of training a GAN.

Video describing the method:

To run it, you need a GPU with 12 GB of memory, PyTorch 0.4, and Python 3.


You will need to update the code in to reflect the path(s) to your dataset:

parser.add_argument('--dataroota', default=[
	# add the path(s) to your dataset here, one string per entry
	], type=str)

By default, the code starts from a network I trained on MS-COCO. If you want to train from scratch, comment out these lines (in

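import torch

# Assumes netGStateDict = autoEncoder.state_dict() was created earlier in the script.
loadedSD = torch.load('./saves/autoEncoder--3.983832822715064.pth')
for k in netGStateDict.keys():
	# copy only parameters whose name and shape match the saved checkpoint
	if k in loadedSD and netGStateDict[k].size() == loadedSD[k].size():
		netGStateDict[k] = loadedSD[k]
		print(k, '... copied')
autoEncoder.load_state_dict(netGStateDict)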

Example image output after training (column order: INPUT, OUTPUT, TARGET):

Technical Details

Learned basis functions

The initial basis functions are obtained by applying SVD to the weights of a pretrained network (ResNet-18 trained on ImageNet). These 'basis functions' are then further tuned, per layer, inside ConvSeluSVD.

Check out the code for ConvSeluSVD inside of for how it is implemented.
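The exact factorization used by ConvSeluSVD is not described here. One common way to get 'basis functions' from a pretrained layer, sketched below under that assumption, is to reshape its filters into rows of k*k values, take the SVD, and keep the top right-singular vectors as shared k x k basis kernels; each filter is then represented by its coefficients on that basis. Random weights stand in for a ResNet-18 layer, and `svd_spatial_basis` is a hypothetical name:

```python
import numpy as np

def svd_spatial_basis(conv_weight, num_basis):
    """Derive shared k x k basis kernels from conv weights via SVD.

    Assumption: this is an illustrative factorization, not ConvSeluSVD itself.

    conv_weight: (out_ch, in_ch, k, k) array, e.g. copied from a pretrained layer
    Returns (basis, coeffs):
      basis:  (num_basis, k, k) orthonormal kernels (top right-singular vectors)
      coeffs: (out_ch, in_ch, num_basis) projection of each filter onto the basis
    """
    out_ch, in_ch, k, _ = conv_weight.shape
    flat = conv_weight.reshape(out_ch * in_ch, k * k)
    # Rows of vt are orthonormal directions in k*k "kernel space",
    # ordered by singular value (most energy first).
    _, _, vt = np.linalg.svd(flat, full_matrices=False)
    basis = vt[:num_basis]                  # (num_basis, k*k)
    coeffs = flat @ basis.T                 # (out_ch*in_ch, num_basis)
    return basis.reshape(num_basis, k, k), coeffs.reshape(out_ch, in_ch, num_basis)

# Random weights stand in for a pretrained 3x3 conv layer.
rng = np.random.default_rng(0)
w = rng.standard_normal((64, 32, 3, 3))
basis, coeffs = svd_spatial_basis(w, 9)
# With all 9 basis kernels, the filters are reconstructed exactly.
recon = np.einsum('oir,rkl->oikl', coeffs, basis)
```

In a learned-basis layer, `basis` (and the per-filter coefficients) would become trainable parameters initialized from these values, which fits the description of tuning the basis functions per layer; whether ConvSeluSVD splits the weights exactly this way is an assumption.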

Modified perceptual loss

Pearson distance is used instead of MSE/MAE. Check out the function pearsonr inside of
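Pearson distance is 1 - r, where r is the Pearson correlation between the two inputs. In a perceptual loss the inputs would be feature activations of the generated and target images (for example, from a pretrained network). The sketch below illustrates the distance only in NumPy and is not the repository's pearsonr; in training it would be computed on framework tensors so gradients flow, and whether it is applied per layer, per channel, or over whole feature maps is not stated here. `pearson_distance` is a hypothetical name:

```python
import numpy as np

def pearson_distance(pred_features, target_features, eps=1e-8):
    """1 - Pearson correlation between two feature tensors, flattened.

    0 = perfectly correlated, ~1 = uncorrelated, 2 = perfectly anti-correlated.
    Unlike MAE/MSE, it ignores a positive global gain and offset.
    """
    x = np.asarray(pred_features, dtype=np.float64).ravel()
    y = np.asarray(target_features, dtype=np.float64).ravel()
    xc = x - x.mean()                       # center both signals
    yc = y - y.mean()
    r = (xc @ yc) / (np.sqrt((xc @ xc) * (yc @ yc)) + eps)
    return 1.0 - r

rng = np.random.default_rng(0)
f = rng.standard_normal((64, 32, 32))       # stand-in for a feature map
print(pearson_distance(f, 3.0 * f + 5.0))   # ~0: same pattern, different scale/offset
```

Because r is invariant to a positive gain and offset, this loss rewards matching the pattern of activations rather than their exact magnitudes, which is one plausible reason to prefer it over MAE.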

Mean Teacher

A running, exponentially weighted average of the last N (5) sets of weights is used to calculate the 'next' set of weights. In

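import numpy as np

# lastNWeights: list of weight snapshots, oldest first; meanWeights: one averaging weight per snapshot
if len(lastNWeights) > meanToStart:
	meanTeacher = np.array(lastNWeights)  # stack snapshots along axis 0
	meanTeacher = np.average(meanTeacher, axis=0, weights=meanWeights).astype(np.float32)
	del lastNWeights[0]  # drop the oldest snapshot so the window slides forward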
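The snippet does not show how meanWeights is built or how the averaged weights are written back into the model. A minimal sketch, assuming exponentially decaying weights that favor the newest snapshot (decay=0.5 is a hypothetical choice) and one flattened weight vector per snapshot; `MeanTeacher` and `exponential_weights` are hypothetical names, and a `deque(maxlen=n)` replaces the manual `del lastNWeights[0]`:

```python
from collections import deque

import numpy as np

def exponential_weights(n, decay=0.5):
    """Normalized weights for n snapshots, oldest first; newest weighted most.

    Assumption: exponential decay; the original meanWeights values are not shown.
    """
    w = decay ** np.arange(n - 1, -1, -1, dtype=np.float64)
    return w / w.sum()

class MeanTeacher:
    """Keep the last n flattened weight snapshots and return their weighted mean."""

    def __init__(self, n=5, decay=0.5):
        self.snapshots = deque(maxlen=n)    # oldest snapshot is dropped automatically
        self.weights = exponential_weights(n, decay)

    def update(self, flat_params):
        self.snapshots.append(np.asarray(flat_params, dtype=np.float32).copy())
        if len(self.snapshots) < self.snapshots.maxlen:
            return None                     # not enough history yet
        stacked = np.stack(list(self.snapshots))   # (n, num_params), oldest first
        return np.average(stacked, axis=0, weights=self.weights).astype(np.float32)

mt = MeanTeacher(n=5, decay=0.5)
outs = [mt.update(np.full(3, float(i))) for i in range(5)]
print(outs[-1])  # weighted toward the newest snapshot (value 4.0)
```

In PyTorch, `torch.nn.utils.parameters_to_vector` and `torch.nn.utils.vector_to_parameters` are one way to move between a model's parameters and such flat vectors. Unlike the original condition on `meanToStart`, this sketch only starts averaging once the window is full.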