Triplet Loss and Online Triplet Mining in PyTorch (GPU)

PyTorch conversion of the excellent post on the same topic in TensorFlow. It is a straightforward implementation of the triplet loss with online mining of candidate triplets, as used in metric learning.
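For an anchor a, a positive p (same label) and a negative n (different label), the triplet loss is max(d(a, p) - d(a, n) + margin, 0). Online mining chooses these triplets inside each batch instead of precomputing them, and it starts from the matrix of all pairwise embedding distances in the batch. The sketch below shows one way to compute that matrix with a single matrix multiplication, using ||x - y||² = ||x||² - 2x·y + ||y||². The name `pairwise_distances` is hypothetical and not necessarily what triplet_loss.py exposes.

```python
import torch


def pairwise_distances(embeddings: torch.Tensor, squared: bool = False) -> torch.Tensor:
    """Return the (B, B) Euclidean distance matrix between the rows of `embeddings` (B, D).

    Hypothetical helper for illustration only.
    """
    dot = embeddings @ embeddings.t()              # (B, B) dot products x_i . x_j
    sq_norms = torch.diagonal(dot)                 # (B,) squared norms ||x_i||^2
    # ||x_i||^2 - 2 x_i . x_j + ||x_j||^2 for every pair (i, j)
    dist = sq_norms.unsqueeze(0) - 2.0 * dot + sq_norms.unsqueeze(1)
    dist = dist.clamp(min=0.0)                     # drop tiny negatives caused by rounding
    if not squared:
        # sqrt has an infinite gradient at 0, so exact zeros are shifted by a tiny
        # epsilon before the sqrt and set back to zero afterwards.
        zero_mask = (dist == 0.0).float()
        dist = torch.sqrt(dist + zero_mask * 1e-16) * (1.0 - zero_mask)
    return dist
```

Because everything is expressed as batched tensor operations, the same code runs unchanged on the GPU when `embeddings` lives there.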


Copy triplet_loss.py into your project and import either batch_hard_triplet_loss or batch_all_triplet_loss.
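The two functions correspond to two mining strategies. Batch hard keeps one triplet per anchor: the farthest positive and the closest negative in the batch. Batch all keeps every valid (anchor, positive, negative) triplet and, under one common convention, averages the loss over the triplets whose loss is still positive. The following is a minimal sketch of both strategies under those assumptions; `batch_hard_sketch`, `batch_all_sketch` and `_pairwise_dist` are hypothetical names, not the functions in triplet_loss.py, and the sketch assumes every anchor has at least one positive and one negative in the batch.

```python
import torch


def _pairwise_dist(emb: torch.Tensor) -> torch.Tensor:
    # Euclidean distance matrix (B, B); exact zeros are shifted before sqrt to keep gradients finite.
    dot = emb @ emb.t()
    sq = torch.diagonal(dot)
    d2 = (sq.unsqueeze(0) - 2.0 * dot + sq.unsqueeze(1)).clamp(min=0.0)
    zero = (d2 == 0.0).float()
    return torch.sqrt(d2 + zero * 1e-16) * (1.0 - zero)


def batch_hard_sketch(labels: torch.Tensor, emb: torch.Tensor, margin: float) -> torch.Tensor:
    """Hypothetical batch-hard loss: hardest positive and hardest negative per anchor."""
    dist = _pairwise_dist(emb)
    same = labels.unsqueeze(0) == labels.unsqueeze(1)              # (B, B) same-label mask
    eye = torch.eye(len(labels), dtype=torch.bool, device=labels.device)
    pos_mask = (same & ~eye).float()                               # positives exclude the anchor itself
    neg_mask = (~same).float()
    hardest_pos = (dist * pos_mask).max(dim=1).values              # farthest positive per anchor
    # Add the row maximum to non-negatives so they can never win the min below.
    row_max = dist.max(dim=1, keepdim=True).values
    hardest_neg = (dist + row_max * (1.0 - neg_mask)).min(dim=1).values  # closest negative
    return torch.relu(hardest_pos - hardest_neg + margin).mean()


def batch_all_sketch(labels: torch.Tensor, emb: torch.Tensor, margin: float) -> torch.Tensor:
    """Hypothetical batch-all loss: mean over valid triplets with a positive loss."""
    dist = _pairwise_dist(emb)
    same = labels.unsqueeze(0) == labels.unsqueeze(1)
    eye = torch.eye(len(labels), dtype=torch.bool, device=labels.device)
    # loss[a, p, n] = d(a, p) - d(a, n) + margin for every index triple (B, B, B)
    loss = dist.unsqueeze(2) - dist.unsqueeze(1) + margin
    # valid when label[a] == label[p], a != p, and label[a] != label[n]
    valid = (same & ~eye).unsqueeze(2) & (~same).unsqueeze(1)
    loss = torch.relu(loss) * valid.float()
    num_positive = (loss > 1e-16).sum().float()                    # triplets that still violate the margin
    return loss.sum() / (num_positive + 1e-16)
```

Batch hard costs O(B²) memory, while batch all builds a (B, B, B) tensor, so it grows much faster with the batch size. Averaging batch all over positive triplets only keeps the loss from shrinking as more triplets become easy.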

Example usage:

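import torch
from triplet_loss import batch_hard_triplet_loss

model = torch.nn.Linear(16, 8)       # placeholder: any network that maps inputs to embeddings
inputs = torch.randn(32, 16)         # a batch of 32 input samples
labels = torch.randint(0, 5, (32,))  # one of our five class labels per sample

embeddings = model(inputs)           # shape (32, 8): one embedding per sample
loss = batch_hard_triplet_loss(labels, embeddings, margin=0.2)
# and so on: backward pass, optimizer step, next batch, ...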
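The model in the usage example should map a batch of inputs (not the labels) to embeddings. A fixed margin such as 0.2 only makes sense relative to the scale of those embeddings, and one common choice (used, for example, by FaceNet with a margin of 0.2) is to L2-normalize them so that all pairwise distances lie in [0, 2]. A minimal sketch of such a model, where `EmbeddingNet` and its layer sizes are hypothetical:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class EmbeddingNet(nn.Module):
    """Hypothetical stand-in for `model`: maps (B, in_dim) inputs to unit-length (B, embed_dim) embeddings."""

    def __init__(self, in_dim: int = 16, embed_dim: int = 8):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(in_dim, 32), nn.ReLU(), nn.Linear(32, embed_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # L2-normalize each row so every embedding lies on the unit hypersphere.
        return F.normalize(self.body(x), p=2, dim=1)


model = EmbeddingNet()
inputs = torch.randn(32, 16)          # hypothetical batch of 32 feature vectors
labels = torch.randint(0, 5, (32,))   # one of five class labels per sample
embeddings = model(inputs)            # (32, 8), each row has unit norm
```

These `labels` and `embeddings` can then be passed to batch_hard_triplet_loss exactly as in the example above.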


To install the dependencies and run the tests:

pip install -r requirements.txt
python3 -m pytest