Deep Spectral Clustering Learning
PyTorch implementation of Deep Spectral Clustering Learning (DSCL), a state-of-the-art deep metric learning method.
- Python 3.6+
- PyTorch 0.4.0+
Currently, only the fine-tuning method on the CARS dataset is supported.
To use your own dataset, see the class CustomDataset in data_loader.py and datasets.py.
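As a rough illustration of what a custom dataset has to provide, the sketch below indexes a folder laid out as `root/<class_name>/<image>` into `(path, label)` pairs. The folder layout, the helper name `index_image_folder`, and the extension list are assumptions made for this example, not the interface of CustomDataset; check data_loader.py and datasets.py for the fields the loader actually expects.

```python
import os

# Hypothetical helper for illustration; not part of this repository.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")


def index_image_folder(root):
    """Return (samples, class_names) for a root/<class_name>/<image> layout.

    samples is a list of (image_path, integer_label) pairs. Labels follow
    the alphabetical order of the class folder names.
    """
    class_names = sorted(
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    )
    samples = []
    for label, name in enumerate(class_names):
        class_dir = os.path.join(root, name)
        for fname in sorted(os.listdir(class_dir)):
            # Keep image files only; anything else in the folder is ignored.
            if fname.lower().endswith(IMAGE_EXTENSIONS):
                samples.append((os.path.join(class_dir, fname), label))
    return samples, class_names
```

The number of class folders found this way would correspond to the value passed as --label_size during training.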
- To visualize intermediate results and loss plots, run
$ python -m visdom.server
and go to the URL http://localhost:8097
$ python train.py --data_dir=/data_path --width_size=299 --lr=1e-5 --label_size=98 --large_batch_epoch=400 --large_batch_size=100 --small_batch_size=60 --dropout_rate=0.30 --model=inception_crop
$ python test.py --data_dir='/hdd/DeepSpectralClustering/data' --width_size=299 --large_batch_epoch=410 --k=8 --model=inception
- The paper describes two methods (last layer / end-to-end), but I only included the fine-tuning method because of GPU memory limits.
- This code does not include DSCL Normalized Spectral Clustering, a post-processing method that improves the metric scores.
- The loss function is implemented as described in the "implementation detail" part of the paper.
- I used the top@k recall score (Recall@K) for testing; the NMI score with K-means clustering is not included.
- Training of DSCL is very sensitive to batch size, learning rate, image augmentation, and dropout rate. I strongly suggest tuning these hyperparameters carefully.
- I achieved about 80% top@8 recall on the CARS dataset, which is low compared to the 93% top@8 recall reported in the paper.
- The metric scores in the paper can be achieved with proper hyperparameters.
- To prevent training from diverging, I skip the gradient update when the loss exceeds 500M.
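The loss guard from the last note could look roughly like the sketch below. It is not the code in train.py: the function name `guarded_step` is hypothetical, "500M" is read here as 500 million, and the NaN check is an extra precaution that the notes do not mention. The function only relies on the standard PyTorch calls `loss.item()`, `loss.backward()`, `optimizer.zero_grad()`, and `optimizer.step()`.

```python
import math

MAX_LOSS = 5e8  # assumption: "500M" means 500 million


def guarded_step(loss, optimizer, max_loss=MAX_LOSS):
    """Apply one optimizer update unless the loss is NaN or above max_loss.

    loss is expected to be a scalar torch.Tensor and optimizer a
    torch.optim.Optimizer. Returns True if the update was applied.
    """
    optimizer.zero_grad()  # clear gradients left over from the previous batch
    value = loss.item()
    # The NaN check is an addition for safety; the README only describes
    # the upper threshold.
    if math.isnan(value) or value > max_loss:
        return False  # skip backward() and step() for this batch
    loss.backward()
    optimizer.step()
    return True
```

Inside a training loop this would replace the usual backward/step pair, so a batch with an exploding loss leaves the weights unchanged.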
Results on CARS data set
|Top K Recall|R@1|R@2|R@4|R@8|
|---|---|---|---|---|
|Scores In The Paper|67.54|77.77|85.74|90.95|
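For readers unfamiliar with the R@K columns, here is a minimal sketch of how Recall@K is commonly computed in metric learning: every item is used as a query against all other items, and the query counts as a hit if at least one of its K nearest neighbours shares its label. This is a plain-Python illustration under that assumption, using squared Euclidean distance; it is not the code in test.py, and the function name `recall_at_k` is made up for the example.

```python
def recall_at_k(embeddings, labels, ks=(1, 2, 4, 8)):
    """Fraction of queries with a same-label item among their k nearest neighbours.

    embeddings: list of equal-length sequences of floats.
    labels: list of class ids, one per embedding.
    """
    n = len(embeddings)
    hits = {k: 0 for k in ks}
    for i in range(n):
        dists = []
        for j in range(n):
            if j == i:
                continue  # a query never retrieves itself
            d = sum((a - b) ** 2 for a, b in zip(embeddings[i], embeddings[j]))
            dists.append((d, j))
        dists.sort()  # nearest neighbours first
        ranked_labels = [labels[j] for _, j in dists]
        for k in ks:
            if labels[i] in ranked_labels[:k]:
                hits[k] += 1
    return {k: hits[k] / n for k in ks}
```

The double loop is quadratic in the number of items, so for a full test set a vectorized distance matrix would be the practical choice; the loop form is kept here for readability.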
Visualization code (visualizer.py, utils.py) is based on pytorch-CycleGAN-and-pix2pix (https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) by Jun-Yan Zhu.