Ignite is a high-level library to help with training neural networks in PyTorch.
Note: Ignite is currently in alpha, so the code in master is changing rapidly. We hope to stabilise the API as soon as possible and keep the examples up to date.
API documentation, examples and tutorials coming soon.
To install Ignite from source, run the following from the root of the repository:

```bash
python setup.py install
```
The main component of Ignite is the Trainer, an abstraction over your training loop. Getting started with the Trainer is easy: the constructor only requires two things:
- training_data: A collection of training batches that allows repeated iteration (e.g., a list or a DataLoader)
- training_update_function: A function that is passed a batch, passes the data through your model, and updates the model
Optionally, you can also provide validation_data and validation_update_function for evaluating on your validation set.
A typical training_update_function looks something like this:
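```python
from torch.autograd import Variable

model = ...      # your torch.nn.Module
optimizer = ...  # e.g. an optimizer from torch.optim over model.parameters()
criterion = ...  # a loss function, e.g. torch.nn.CrossEntropyLoss()

def training_update_function(batch):
    model.train()          # enable training-mode behaviour (dropout, batch norm)
    optimizer.zero_grad()  # clear gradients left over from the previous step
    x, y = Variable(batch[0]), Variable(batch[1])  # batch is an (inputs, targets) pair
    prediction = model(x)
    loss = criterion(prediction, y)
    loss.backward()        # compute gradients
    optimizer.step()       # update the model parameters
    return loss.data       # stored by the Trainer in training_history
```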
You can then construct your Trainer and train for a given number of epochs (here, 5) as follows:
```python
from ignite.trainer import Trainer

trainer = Trainer(train_dataloader, training_update_function)
trainer.run(max_epochs=5)
```
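To make the division of labour concrete, here is a hypothetical pure-Python sketch of the kind of loop a Trainer runs for you. MiniTrainer, the toy parameter dict, and the toy update function are illustrative assumptions, not Ignite code; the real Trainer also handles validation and events.

```python
class MiniTrainer:
    """Hypothetical stand-in mirroring Trainer(training_data, update_fn).run(max_epochs)."""

    def __init__(self, training_data, training_update_function):
        self.training_data = training_data
        self.training_update_function = training_update_function
        self.training_history = []  # one return value per processed batch

    def run(self, max_epochs=1):
        for _ in range(max_epochs):
            # training_data must support repeated iteration (e.g. a list)
            for batch in self.training_data:
                self.training_history.append(self.training_update_function(batch))


# Toy problem: fit y = w * x by gradient descent on the squared error.
params = {"w": 0.0}
learning_rate = 0.1


def training_update_function(batch):
    x, y = batch
    prediction = params["w"] * x
    loss = (prediction - y) ** 2
    grad = 2 * (prediction - y) * x  # d(loss)/dw
    params["w"] -= learning_rate * grad
    return loss


train_data = [(1.0, 2.0), (2.0, 4.0), (0.5, 1.0)]  # samples of y = 2x
trainer = MiniTrainer(train_data, training_update_function)
trainer.run(max_epochs=5)
print(params["w"], len(trainer.training_history))
```

The loop itself stays generic: everything model-specific lives in the update function, which is why the same Trainer can drive very different models.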
Training & Validation History
The return values of your training and validation update functions are stored in the Trainer's training_history and validation_history members. These can be accessed from event handlers (see below) and used for updating metrics, logging, etc. Importantly, your update functions need not return just the loss: the return value can be of any type (list, tuple, dict, tensors, etc.).
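As an illustration, the hypothetical sketch below has an update function return a dict with both the batch loss and the batch size, and a helper that reads those fields back out of the history. The dict keys, the mean_loss helper, and the plain-list history are assumptions made for the example; parameter updates are omitted for brevity.

```python
def training_update_function(batch):
    predictions, targets = batch
    # Mean squared error of this batch (the parameter update step is omitted)
    loss = sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)
    return {"loss": loss, "batch_size": len(targets)}


def mean_loss(history):
    """Sample-weighted mean loss over a history of dict entries."""
    total = sum(entry["loss"] * entry["batch_size"] for entry in history)
    count = sum(entry["batch_size"] for entry in history)
    return total / count


batches = [([1.0, 2.0], [1.0, 3.0]), ([0.0], [2.0])]
training_history = [training_update_function(b) for b in batches]
print(mean_loss(training_history))
```

Recording the batch size lets a handler weight each batch correctly: a plain average of the two per-batch losses (0.5 and 4.0) would give 2.25 here and overweight the single-sample batch.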
Events & Event Handlers
The Trainer emits events during the training loop, to which users can attach event handlers. The events that are emitted are defined in ignite.trainer.TrainingEvents.
Users can attach multiple handlers to each of these events. Handlers can control aspects of training, such as early stopping or reducing the learning rate, and can also perform tasks such as logging or updating external dashboards like Visdom or TensorBoard (see the examples for more details on using Visdom).
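The dispatch mechanism can be pictured with the hypothetical pure-Python sketch below. The Events enum, the EventEmitter class, and its fire_event method are assumptions for illustration, not Ignite's implementation; only the add_event_handler(event, handler, *args, **kwargs) calling pattern mirrors the usage shown in this README.

```python
from enum import Enum


class Events(Enum):
    # TRAINING_ITERATION_COMPLETED appears in this README; EPOCH_COMPLETED is a
    # hypothetical extra member used only for illustration.
    TRAINING_ITERATION_COMPLETED = "training_iteration_completed"
    EPOCH_COMPLETED = "epoch_completed"


class EventEmitter:
    def __init__(self):
        # One list of (handler, args, kwargs) entries per event.
        self._handlers = {event: [] for event in Events}

    def add_event_handler(self, event, handler, *args, **kwargs):
        # Several handlers may be attached to the same event.
        self._handlers[event].append((handler, args, kwargs))

    def fire_event(self, event):
        # Handlers run in registration order and receive the emitter first,
        # followed by the extra arguments supplied at registration time.
        for handler, args, kwargs in self._handlers[event]:
            handler(self, *args, **kwargs)


calls = []


def log_handler(emitter, prefix):
    calls.append(prefix + ": iteration done")


def lr_handler(emitter, factor=0.5):
    calls.append("scale lr by %s" % factor)


emitter = EventEmitter()
emitter.add_event_handler(Events.TRAINING_ITERATION_COMPLETED, log_handler, "train")
emitter.add_event_handler(Events.TRAINING_ITERATION_COMPLETED, lr_handler, factor=0.1)
emitter.fire_event(Events.TRAINING_ITERATION_COMPLETED)
emitter.fire_event(Events.EPOCH_COMPLETED)  # no handlers attached: nothing happens
print(calls)
```

Passing the emitter itself as the first argument is what gives a handler access to the training state, in the same way that Ignite handlers receive the Trainer.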
Event handlers are any callable whose first argument is an instance of the Trainer. Users can also pass additional positional or keyword arguments to their event handlers. For example, if we want to terminate training once at least 100 iterations have run and the loss has not decreased over the last 5 iterations, we could define the following event handler and attach it to the TRAINING_ITERATION_COMPLETED event:
```python
from ignite.trainer import TrainingEvents

def early_stopping_handler(trainer, min_iterations, lookback=1):
    history = trainer.training_history
    if trainer.current_iterations >= min_iterations and len(history) > lookback:
        # Loss recorded just before the lookback window
        reference_loss = history[-lookback - 1]
        # Stop if none of the last `lookback` losses improved on it
        if not any(x < reference_loss for x in history[-lookback:]):
            trainer.terminate()

min_iterations = 100
trainer.add_event_handler(TrainingEvents.TRAINING_ITERATION_COMPLETED,
                          early_stopping_handler, min_iterations, lookback=5)
```
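Because any callable works as a handler, early stopping can also be written as a stateful object that remembers the best loss seen so far. The sketch below is a hypothetical alternative design: the EarlyStopping class and its patience parameter are assumptions, and FakeTrainer is a minimal stand-in exposing only the attributes used in the example above (training_history, current_iterations, terminate()).

```python
class EarlyStopping:
    """Hypothetical handler: stop once the loss has not beaten its best value
    for `patience` consecutive iterations (after `min_iterations`)."""

    def __init__(self, patience, min_iterations=0):
        self.patience = patience
        self.min_iterations = min_iterations
        self.best = float("inf")
        self.bad_iterations = 0

    def __call__(self, trainer):
        loss = trainer.training_history[-1]
        if loss < self.best:
            self.best = loss
            self.bad_iterations = 0
        else:
            self.bad_iterations += 1
        if (trainer.current_iterations >= self.min_iterations
                and self.bad_iterations >= self.patience):
            trainer.terminate()


class FakeTrainer:
    """Minimal stand-in exposing only what the handler reads."""

    def __init__(self):
        self.training_history = []
        self.current_iterations = 0
        self.should_terminate = False

    def terminate(self):
        self.should_terminate = True


trainer = FakeTrainer()
handler = EarlyStopping(patience=2)
stopped_at = None
for loss in [5.0, 4.0, 3.0, 3.5, 3.2, 2.0]:
    trainer.training_history.append(loss)
    trainer.current_iterations += 1
    handler(trainer)  # what the Trainer would do after each iteration
    if trainer.should_terminate:
        stopped_at = trainer.current_iterations
        break
print(stopped_at, handler.best)
```

Assuming the add_event_handler signature shown above, such an object would be attached just like a function, e.g. trainer.add_event_handler(TrainingEvents.TRAINING_ITERATION_COMPLETED, EarlyStopping(patience=5, min_iterations=100)). Keeping the best loss in the object avoids rescanning the history on every iteration.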
Ignite uses Python's standard-library logging module, so you can integrate Ignite's logs directly into your application logs. To do this, attach a log handler to the ignite logger:
```python
import logging

logger = logging.getLogger('ignite')
logger.addHandler(logging.StreamHandler())
logger.setLevel(logging.INFO)
```
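A slightly fuller standard-library sketch adds a formatter and writes to an in-memory stream so the output can be inspected. It relies only on the 'ignite' logger name given above; the child logger name "ignite.trainer" is a hypothetical example used to show that records from loggers below 'ignite' in the hierarchy propagate to its handlers.

```python
import io
import logging

stream = io.StringIO()  # stands in for sys.stderr or a log file
handler = logging.StreamHandler(stream)
handler.setFormatter(logging.Formatter("%(name)s %(levelname)s: %(message)s"))

logger = logging.getLogger("ignite")
logger.addHandler(handler)
logger.setLevel(logging.INFO)

# Child loggers inherit the INFO level and propagate to the "ignite" handlers.
# "ignite.trainer" is a hypothetical name used only for illustration.
child = logging.getLogger("ignite.trainer")
child.info("epoch 1 complete")
child.debug("filtered out: below INFO")

print(stream.getvalue(), end="")
```

In an application you would typically attach the same handlers and formatter you already use elsewhere, so Ignite's messages appear alongside your own.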
Ignite provides a number of metrics for evaluating the performance of a model. The metrics currently available are:
- binary_accuracy: This takes a History object (such as validation_history) and an optional callable transform, and computes the binary accuracy: each prediction scores 1 if it equals its target and 0 otherwise, and the accuracy is the mean of these scores. This is generally used for binary classification tasks.
- categorical_accuracy: This is the binary_accuracy equivalent for multi-class classification, where the number of classes is greater than 2.
- top_k_categorical_accuracy: This computes the top-k classification accuracy, a popular way of evaluating models on larger datasets with many classes: a prediction counts as correct if the true class is among the k highest-scoring classes. The semantics are similar to categorical_accuracy, except that there is an additional argument for the value of k.
- mean_squared_error: Generally used in regression tasks, this takes a History object and an optional callable transform and computes the mean of the squared differences between the predicted and actual values. Its square root is the root mean squared error (RMSE).
- mean_absolute_error: This is similar to mean_squared_error, but computes the mean of the absolute differences between the predicted and actual values.
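For intuition, the following pure-Python sketch re-implements the five metrics above. It is not Ignite's code: it assumes each History entry, after the optional transform, is a (predictions, targets) pair of equal-length sequences, and the position of the k argument in top_k_categorical_accuracy is a guess.

```python
import math


def _pairs(history, transform):
    # Assumption: transform(entry) yields a (predictions, targets) pair.
    for entry in history:
        y_pred, y = transform(entry)
        yield from zip(y_pred, y)


def binary_accuracy(history, transform=lambda e: e):
    # Each prediction is assumed to already be a 0/1 label.
    scores = [1.0 if p == t else 0.0 for p, t in _pairs(history, transform)]
    return sum(scores) / len(scores)


def categorical_accuracy(history, transform=lambda e: e):
    # Each prediction is a list of per-class scores; the argmax is the label.
    scores = [1.0 if max(range(len(p)), key=p.__getitem__) == t else 0.0
              for p, t in _pairs(history, transform)]
    return sum(scores) / len(scores)


def top_k_categorical_accuracy(history, k=5, transform=lambda e: e):
    hits = []
    for p, t in _pairs(history, transform):
        top_k = sorted(range(len(p)), key=p.__getitem__, reverse=True)[:k]
        hits.append(1.0 if t in top_k else 0.0)
    return sum(hits) / len(hits)


def mean_squared_error(history, transform=lambda e: e):
    errors = [(p - t) ** 2 for p, t in _pairs(history, transform)]
    return sum(errors) / len(errors)


def mean_absolute_error(history, transform=lambda e: e):
    errors = [abs(p - t) for p, t in _pairs(history, transform)]
    return sum(errors) / len(errors)


binary_history = [([1, 0, 1], [1, 1, 1]), ([0], [0])]
class_history = [([[0.1, 0.7, 0.2], [0.5, 0.2, 0.3]], [1, 2])]
regression_history = [([2.0, 5.0], [1.0, 3.0]), ([3.0], [3.0])]

print(binary_accuracy(binary_history))                     # 3 of 4 correct
print(categorical_accuracy(class_history))                 # 2nd sample's argmax is wrong
print(top_k_categorical_accuracy(class_history, k=2))      # true class in top 2 both times
rmse = math.sqrt(mean_squared_error(regression_history))
print(mean_squared_error(regression_history), mean_absolute_error(regression_history), rmse)
```

The second classification sample shows why top-k is more forgiving: its true class (2) is not the argmax, so categorical accuracy counts it as wrong, but it is the second-highest score, so it counts as correct for k=2.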
The examples/ directory currently contains an example of using Ignite to train a digit classifier on MNIST. It covers the following:
- Attaching custom handlers to training events
- Attaching ignite's handlers to training events
- Using handlers to plot training loss and validation accuracy to a Visdom server
How does this compare to Torchnet?
In spirit, Ignite is very similar to torchnet (and was inspired by it).
The main difference from torchnet is the level of abstraction offered to the user. Ignite's higher level of abstraction assumes less about the type of network (or networks) you are training, and requires the user to define the closure that is run in the training and validation loops. In contrast, torchnet creates this closure internally, based on the network and optimizer you pass to it. This higher level of abstraction allows a great deal more flexibility, such as co-training multiple models (e.g. GANs) and computing or tracking multiple losses and metrics in your training loop.
Ignite also allows multiple handlers to be attached to each event, and offers finer-grained events in the loop.
That being said, there are some things from torchnet we really like and would like to port over, such as the integration with Visdom (and possibly add integration with TensorBoard).
As always, PRs are welcome :)
We appreciate all contributions. If you are planning to contribute back bug-fixes, please do so without any further discussion. If you plan to contribute new features, utility functions or extensions, please first open an issue and discuss the feature with us.