

Triplet loss in TensorFlow

Author: Olivier Moindrot

This repository contains a triplet loss implementation in TensorFlow with online triplet mining. Please check the blog post for a full description.

The code structure is adapted from code I wrote for CS230, available in this repository under tensorflow/vision. A set of tutorials for this code can be found here.


We recommend using python3 and a virtual environment. Use the built-in venv module, or virtualenv with python3.

python3 -m venv .env
source .env/bin/activate
pip install -r requirements_cpu.txt

If you are using a GPU, you will need tensorflow-gpu instead, so run:

pip install -r requirements_gpu.txt

Triplet loss

Triplet loss on two positive faces (Obama) and one negative face (Macron)

The interesting part, the definition of the triplet loss with online triplet mining, can be found in model/

Everything is explained in the blog post.
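For a single triplet (anchor a, positive p with the same label, negative n with a different label), the standard triplet loss is max(d(a, p) - d(a, n) + margin, 0). The pure-Python sketch below illustrates this with Euclidean distance; the function names are hypothetical and this is not the repository's TensorFlow code. The three calls show an easy, a semi-hard, and a hard triplet.

```python
import math

def euclidean(u, v, squared=False):
    """Euclidean distance between two equal-length vectors (squared if requested)."""
    d2 = sum((x - y) ** 2 for x, y in zip(u, v))
    return d2 if squared else math.sqrt(d2)

def triplet_loss(anchor, positive, negative, margin, squared=False):
    """Hinge loss: zero once the negative is at least `margin` farther than the positive."""
    d_ap = euclidean(anchor, positive, squared)
    d_an = euclidean(anchor, negative, squared)
    return max(d_ap - d_an + margin, 0.0)

# Easy triplet: the negative is far away, so the loss is 0
print(triplet_loss([0.0, 0.0], [1.0, 0.0], [5.0, 0.0], margin=1.0))  # 0.0
# Semi-hard triplet: d_ap < d_an < d_ap + margin
print(triplet_loss([0.0, 0.0], [1.0, 0.0], [1.5, 0.0], margin=1.0))  # 0.5
# Hard triplet: the negative is closer than the positive
print(triplet_loss([0.0, 0.0], [2.0, 0.0], [1.0, 0.0], margin=1.0))  # 2.0
```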

To use the "batch all" version, you can do:

from model.triplet_loss import batch_all_triplet_loss

loss, fraction_positive = batch_all_triplet_loss(labels, embeddings, margin, squared=False)

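Here, fraction_positive is the fraction of valid triplets that have a positive loss, i.e. the hard and semi-hard triplets. It is a useful quantity to plot in TensorBoard to track how many informative triplets remain during training.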
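The "batch all" strategy can be illustrated with explicit loops: enumerate every valid triplet in the batch (a != p, same label for a and p, different label for n), compute its hinge loss, and average over the triplets whose loss is positive. This is a minimal pure-Python sketch, assuming that averaging rule; batch_all_triplet_loss_sketch and pairwise_distances are hypothetical names, not the functions in model/. A TensorFlow implementation would typically vectorize these loops with boolean masks instead.

```python
import math

def pairwise_distances(embeddings, squared=False):
    """Full matrix of Euclidean (or squared Euclidean) distances between embeddings."""
    n = len(embeddings)
    dist = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            d2 = sum((x - y) ** 2 for x, y in zip(embeddings[i], embeddings[j]))
            dist[i][j] = d2 if squared else math.sqrt(d2)
    return dist

def batch_all_triplet_loss_sketch(labels, embeddings, margin, squared=False):
    """Average the hinge loss over valid triplets with positive loss.

    Returns (loss, fraction_positive), where fraction_positive is the share of
    valid triplets whose loss is > 0 (the hard and semi-hard ones).
    """
    dist = pairwise_distances(embeddings, squared)
    n = len(labels)
    total, num_valid, num_positive = 0.0, 0, 0
    for a in range(n):
        for p in range(n):
            # A valid positive shares the anchor's label but is a different sample
            if p == a or labels[p] != labels[a]:
                continue
            for neg in range(n):
                # A valid negative has a different label from the anchor
                if labels[neg] == labels[a]:
                    continue
                num_valid += 1
                loss = dist[a][p] - dist[a][neg] + margin
                # Small threshold so floating-point noise is not counted as positive
                if loss > 1e-16:
                    total += loss
                    num_positive += 1
    if num_positive == 0:
        return 0.0, 0.0
    return total / num_positive, num_positive / num_valid

# Two valid triplets: one easy (loss 0) and one semi-hard (loss 0.5)
loss, fraction_positive = batch_all_triplet_loss_sketch(
    [0, 0, 1], [[0.0, 0.0], [1.0, 0.0], [2.5, 0.0]], margin=1.0)
print(loss, fraction_positive)  # 0.5 0.5
```

Averaging only over positive triplets keeps the loss from shrinking toward zero as easy triplets come to dominate the batch.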

To use the "batch hard" version, you can do:

from model.triplet_loss import batch_hard_triplet_loss

loss = batch_hard_triplet_loss(labels, embeddings, margin, squared=False)
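The "batch hard" strategy keeps only one triplet per anchor: the hardest positive (the farthest sample with the same label) and the hardest negative (the closest sample with a different label), then averages the hinge loss over anchors. The pure-Python sketch below shows that idea; batch_hard_triplet_loss_sketch is a hypothetical name, and skipping anchors that lack a valid positive or negative is an assumption of this sketch, not necessarily what the repository does.

```python
import math

def batch_hard_triplet_loss_sketch(labels, embeddings, margin, squared=False):
    """For each anchor, use its farthest positive and closest negative; average over anchors."""
    def dist(i, j):
        d2 = sum((x - y) ** 2 for x, y in zip(embeddings[i], embeddings[j]))
        return d2 if squared else math.sqrt(d2)

    n = len(labels)
    losses = []
    for a in range(n):
        pos = [dist(a, p) for p in range(n) if p != a and labels[p] == labels[a]]
        neg = [dist(a, k) for k in range(n) if labels[k] != labels[a]]
        if not pos or not neg:
            # Assumption: anchors without a valid positive or negative are skipped
            continue
        hardest_positive = max(pos)  # farthest sample with the same label
        hardest_negative = min(neg)  # closest sample with a different label
        losses.append(max(hardest_positive - hardest_negative + margin, 0.0))
    return sum(losses) / len(losses) if losses else 0.0

# Four 1-D points at x = 0, 1, 2, 4 with labels 0, 0, 1, 1
print(batch_hard_triplet_loss_sketch(
    [0, 0, 1, 1], [[0.0], [1.0], [2.0], [4.0]], margin=0.5))  # 0.5
```

Compared with "batch all", this mines fewer but harder triplets, at the cost of a noisier loss when the batch contains outliers.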

Training on MNIST

To run a new experiment called base_model, do:

python --model_dir experiments/base_model

You will first need to create a configuration file like this one: params.json. This JSON file specifies all the hyperparameters of the model. All the weights and summaries will be saved in model_dir.
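As an illustration of how such a file can be consumed, the sketch below writes an example params.json and loads it into an object whose attributes are the hyperparameters. It assumes a flat JSON object and that the file sits in the experiment directory; the Params class and every key and value in the example dictionary are hypothetical, not the repository's actual schema.

```python
import json
import os
import tempfile

class Params:
    """Expose the hyperparameters in a JSON file as attributes (e.g. params.margin)."""

    def __init__(self, json_path):
        with open(json_path) as f:
            self.__dict__.update(json.load(f))

# Hypothetical hyperparameters: the real params.json may use different keys and values
example = {"learning_rate": 1e-3, "batch_size": 64, "num_epochs": 20,
           "embedding_size": 64, "margin": 0.5, "triplet_strategy": "batch_all",
           "squared": False}

model_dir = tempfile.mkdtemp()  # stands in for experiments/base_model
json_path = os.path.join(model_dir, "params.json")
with open(json_path, "w") as f:
    json.dump(example, f, indent=4)

params = Params(json_path)
print(params.margin, params.triplet_strategy)  # 0.5 batch_all
```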

Once trained, you can visualize the embeddings by running:

python --model_dir experiments/base_model

Then run TensorBoard on the experiment directory:

tensorboard --logdir experiments/base_model

Here is the result (link to gif):

Embeddings of the MNIST test images visualized with t-SNE (perplexity 25)


To run all the tests, run this from the project directory:


To run a specific test:

pytest model/tests/