Triplet Loss and Online Triplet Mining in PyTorch (GPU)
A PyTorch conversion of the excellent post on the same topic in TensorFlow. It is simply an implementation of triplet loss with online mining of candidate triplets, as used in semi-supervised learning.
Copy the triplet_loss.py file into your project and import from it:

    import torch
    from triplet_loss import batch_hard_triplet_loss

    # a batch of 32 labels drawn from our five classes
    labels = torch.randint(0, 5, (32,))
    # inputs: the batch of samples these labels belong to
    embeddings = model(inputs)
    loss = batch_hard_triplet_loss(labels, embeddings, margin=0.2)
    loss.backward()
    # and so on
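For orientation, here is a minimal sketch of what batch-hard online mining does. It is not the code shipped in triplet_loss.py: the function name `batch_hard_triplet_loss_sketch` is hypothetical, and both the Euclidean distance (via `torch.cdist`) and the mean over anchors are assumptions made for illustration. The argument order mirrors the usage shown above. For every anchor in the batch it picks the farthest same-label sample and the closest different-label sample, then applies a hinge with the margin, all on whatever device the inputs live on.

```python
import torch


def batch_hard_triplet_loss_sketch(labels, embeddings, margin=0.2):
    """Illustrative batch-hard triplet loss (assumed form, not the repo's code)."""
    # Pairwise Euclidean distances between all embeddings, shape (B, B).
    dist = torch.cdist(embeddings, embeddings, p=2)

    # same[i, j] is True when samples i and j share a label.
    same = labels.unsqueeze(0) == labels.unsqueeze(1)
    eye = torch.eye(labels.size(0), dtype=torch.bool, device=labels.device)

    # Hardest positive: the farthest sample with the same label, excluding
    # the anchor itself. Masked-out entries contribute a distance of 0.
    pos_mask = same & ~eye
    hardest_pos = (dist * pos_mask).max(dim=1).values

    # Hardest negative: the closest sample with a different label.
    # Same-label entries are set to +inf so they never win the min.
    hardest_neg = dist.masked_fill(same, float("inf")).min(dim=1).values

    # Hinge on the margin, averaged over all anchors in the batch.
    return torch.relu(hardest_pos - hardest_neg + margin).mean()


if __name__ == "__main__":
    labels = torch.tensor([0, 0, 1, 1])
    embeddings = torch.tensor([[0.0, 0.0], [1.0, 0.0], [3.0, 0.0], [5.0, 0.0]])
    print(batch_hard_triplet_loss_sketch(labels, embeddings, margin=0.2))
```

Because the mining is done with masks on the full distance matrix, nothing leaves the GPU when the inputs are CUDA tensors. An anchor with no positive in the batch gets a hardest-positive distance of 0 in this sketch, which is one of several reasonable conventions.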
Run the tests with:

    pip install -r requirements.txt
    python3 -m pytest test_triplet_loss.py
- Triplet Loss and Online Triplet Mining in TensorFlow
- FaceNet paper
- adambielski's nice implementation (unfortunately it context-switches between CPU and GPU)