Background

This is an extended design discussion of callback design for the training fit loop in the Gluon Fit API. Callbacks are a powerful tool: we can use them to deliver useful features during training, such as metric updates, validation, logging, and saving a model periodically.

In this document, we discuss what functionality we should provide, how to manage shared/independent states in callbacks, how to pass arguments to callbacks, and which features should be implemented as callbacks.

Note: callbacks in the Gluon Fit-API/Estimator design are called event handlers to avoid confusion with the modules under mx.callback (which are used only with MXNet Module).

Events to be triggered

Currently supported events; each of the following is triggered at a different stage of the training loop.

  • train begin
  • epoch begin
  • batch begin
  • batch end
  • epoch end
  • train end

Events that could be added later

  • forward begin
  • forward end
  • backward begin
  • backward end

Callback implementations

Let's start with an example to discuss the ways to implement callbacks. Let's implement a stopping criterion that stops training after a certain number of batches or epochs; it tells the for loop whether to stop training.

stop = StopTrainingHandlerV2(max_batch=100, max_epoch=10)

for epoch in range(20):
    for batch in range(25):
        print('epoch: ', epoch)
        print('batch: ', batch)
        batch_result = stop.batch_end()
        if batch_result:
            break
    epoch_result = stop.epoch_end()
    if epoch_result:
        break

Multi-inheritance

One base class per event type; each base class keeps only the state specific to that event.
Rule of thumb: state common to multiple base classes should be managed by the subclass that inherits from those base classes.

class BatchEnd(object):
    def __init__(self, max_batch=None):
        self.batch_idx = 0
        self.total_batch = 0
        self.max_batch = max_batch

    def batch_end(self, batch_result={}):
        self.batch_idx += 1
        self.total_batch += 1


class EpochEnd(object):
    def __init__(self, max_epoch=None):
        self.epoch = 0
        self.max_epoch = max_epoch

    def epoch_end(self, epoch_result={}):
        self.epoch += 1


class StopTrainingHandler(BatchEnd, EpochEnd):
    def __init__(self, max_batch, max_epoch):
        super().__init__(max_batch)
        super(BatchEnd, self).__init__(max_epoch)
        self.stop_training = False

    def batch_end(self, batch_result={}):
        super(StopTrainingHandler, self).batch_end(batch_result)
        if self.total_batch == self.max_batch:
            self.stop_training = True
        return self.stop_training

    def epoch_end(self, epoch_result={}):
        super(StopTrainingHandler, self).epoch_end(epoch_result)
        # reset batch index at the end of each epoch
        self.batch_idx = 0
        if self.epoch == self.max_epoch:
            self.stop_training = True
        return self.stop_training

Method override

One base class with all event methods

class EventHandler(object):
    """Base class for event handlers.

    :py:class:`EventHandler` can perform user-defined functions at
    different stages of training: train begin, epoch begin, batch begin,
    batch end, epoch end, train end.

    Parameters
    ----------
    estimator : Estimator
        The :py:class:`Estimator` to get training statistics from
    """

    def __init__(self):
        self._estimator = None

    def train_begin(self, *args, **kwargs):
        pass

    def epoch_begin(self, *args, **kwargs):
        pass

    def batch_begin(self, *args, **kwargs):
        pass

    def batch_end(self, batch_id, batch_results=None, *args, **kwargs):
        return False

    def epoch_end(self, epoch, epoch_results=None, *args, **kwargs):
        return False

    def train_end(self, *args, **kwargs):
        pass


class StopTrainingHandlerV2(EventHandler):
    def __init__(self, max_batch, max_epoch):
        super(StopTrainingHandlerV2, self).__init__()
        self.batch_idx = 0
        self.epoch = 0
        self.total_batch = 0
        self.max_epoch = max_epoch
        self.max_batch = max_batch
        self.stop_training = False

    def batch_end(self, batch_result={}, *args, **kwargs):
        self.batch_idx += 1
        self.total_batch += 1
        if self.total_batch == self.max_batch:
            self.stop_training = True
        return self.stop_training

    def epoch_end(self, epoch_result={}, *args, **kwargs):
        self.epoch += 1
        # reset batch index at the end of each epoch
        self.batch_idx = 0
        if self.epoch == self.max_epoch:
            self.stop_training = True
        return self.stop_training
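
With this approach, the fit loop can still avoid calling empty event methods by checking whether a handler's class actually overrides the base EventHandler method. A minimal sketch, assuming the classes defined above:

# Sketch: only dispatch to handlers that override batch_end.
handlers = [EventHandler(), StopTrainingHandlerV2(max_batch=100, max_epoch=10)]
batch_end_handlers = [h for h in handlers
                      if type(h).batch_end is not EventHandler.batch_end]
# batch_end_handlers now contains only the StopTrainingHandlerV2 instance.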

Conclusion:

  1. Base classes should not keep any state; each concrete child class maintains all of its own state. It is very common that state initialized in batch begin is used in batch end or epoch end, so such state should be managed by a child class that inherits BatchBegin, BatchEnd, and EpochEnd.
  2. There is no efficiency difference in avoiding empty method calls; both approaches can do that. The key idea is to categorize all callbacks into 6 lists, depending on whether a handler overrides its parent class's method (a sketch of the isinstance-based categorization follows the example code below).
    1. Multi-inheritance handlers can be categorized using isinstance(handler, TrainBegin), etc.
    2. Method-override handlers can be categorized by checking handler.__class__.train_begin == EventHandler.train_begin
  3. The only difference is that methods can be changed and modified at run time.
  4. We will go with multi-inheritance, but keep all states in each specific event handler class. For example:
class TrainBegin(object):
    def train_begin(self, estimator, *args, **kwargs):
        pass


class TrainEnd(object):
    def train_end(self, estimator, *args, **kwargs):
        pass


class EpochBegin(object):
    def epoch_begin(self, estimator, *args, **kwargs):
        pass


class EpochEnd(object):
    def epoch_end(self, estimator, *args, **kwargs):
        return False


class BatchBegin(object):
    def batch_begin(self, estimator, *args, **kwargs):
        pass


class BatchEnd(object):
    def batch_end(self, estimator, *args, **kwargs):
        return False

from mxnet.metric import Loss  # metric wrapper for loss values


class MetricHandler(EpochBegin, BatchEnd):
    def __init__(self, train_metrics):
        self.train_metrics = train_metrics
        # order to be called among all callbacks;
        # metrics need to be calculated before other callbacks can access them
        self.rank = 1

    def epoch_begin(self, estimator, *args, **kwargs):
        for metric in self.train_metrics:
            metric.reset()

    def batch_end(self, estimator, *args, **kwargs):
        pred = kwargs['pred']
        label = kwargs['label']
        loss = kwargs['loss']
        for metric in self.train_metrics:
            if isinstance(metric, Loss):
                # metric wrapper for loss values
                metric.update(0, loss)
            else:
                metric.update(label, pred)
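
With the multi-inheritance approach, the fit loop can bucket event handlers into per-event lists using isinstance checks, so only handlers registered for a given event are called. A minimal sketch, assuming the mix-in classes above:

# Sketch: categorize event handlers into 6 per-event lists with isinstance.
# `handlers` is the user-provided list of event handler instances.
def categorize_handlers(handlers):
    train_begin = [h for h in handlers if isinstance(h, TrainBegin)]
    epoch_begin = [h for h in handlers if isinstance(h, EpochBegin)]
    batch_begin = [h for h in handlers if isinstance(h, BatchBegin)]
    batch_end = [h for h in handlers if isinstance(h, BatchEnd)]
    epoch_end = [h for h in handlers if isinstance(h, EpochEnd)]
    train_end = [h for h in handlers if isinstance(h, TrainEnd)]
    return (train_begin, epoch_begin, batch_begin,
            batch_end, epoch_end, train_end)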

Bookkeeping training states

There are many internal training states we want to keep track of during training. The fit loop can bookkeep some of these internal states, and each callback can also bookkeep them. We need a criterion to decide which of them should manage each state.

There are mainly two design considerations:

  1. Shared states between callbacks, all managed by the estimator/fit loop:
    1. Keeping everything in the fit loop makes it hard to add new callbacks: users need to know which variables are available and when they are created/updated, so that the values they read are correct and up to date.
    2. A custom callback could change the epoch number and cause bugs in other callbacks.
  2. Independent states, where states in each callback are isolated from the others:
    1. Keeping everything in each callback creates duplicate states for those commonly used by different callbacks. Having each callback update its own epoch/batch counters can introduce errors and bugs, not to mention loss and metric values.


Fit loop has the following structure:

for epoch in range(max_epochs):
    for i, batch in enumerate(train_data):
        ...

Here the epoch number and batch index are naturally managed: the epoch won't exceed the max epoch, and the batch index is automatically reset to 0 after each epoch. They should be managed in the fit loop and passed to callbacks.

Which states are OK to keep as copies in each callback, and which states should be managed by the fit loop?

What's already available in Estimator/Fit Loop:

  • Fit loop: data loader
  • Fit loop: data, label for each batch
  • Fit loop: output of net
  • Fit loop: current epoch
  • Fit loop: max epoch
  • Fit loop: batch index
  • Estimator: trainer
  • Estimator: net
  • Estimator: loss functions and training metrics

Common internal states most callbacks need to access:

  • learning rate
  • batch size
  • current epoch
  • max epoch
  • current batch index
  • max batch to train
  • training loss values
  • training metric values
  • validation loss values
  • validation metric values

Callback-specific states:

  • best validation loss (for monitor)
  • last time a model was saved
  • training time
  • total steps/batches trained

Conclusion

We categorize states into 3 types, and each type should be managed differently:

  1. Attributes of the Estimator; these should naturally be managed by the estimator itself:
    1. net
    2. trainer
    3. max epochs to train
  2. Dynamic counter-like states (current epoch, batch index, total batches) should be managed by each callback:
    1. Duplicates of these states bring little to no memory/performance cost.
    2. Bugs in updating these states are contained in a specific callback and won't cause other callbacks to fail.
    3. Each callback only manages the states it uses.
  3. Loss/metric values that many callbacks need access to; these are shared between callbacks (see the sketch after this list):
    1. Create a metric wrapper (mx.metric.Loss) to record and update loss values.
    2. Make copies of loss/metric objects to record validation results.
    3. Share these loss/metric objects with callbacks by passing their references to each callback during initialization.
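
A hedged sketch of what the convenience helper in 3.1-3.3 could look like; the helper name prepare_loss_and_metrics and its return structure are illustrative assumptions, not a fixed API:

import copy
from mxnet.metric import Loss

def prepare_loss_and_metrics(loss_fns, train_metrics):
    # Wrap each loss function in mx.metric.Loss so callbacks can read running loss values.
    train_loss_metrics = [Loss(name=getattr(l, 'name', 'loss')) for l in loss_fns]
    # Copies of the training loss/metric objects record validation results separately.
    val_loss_metrics = [copy.deepcopy(m) for m in train_loss_metrics]
    val_metrics = [copy.deepcopy(m) for m in train_metrics]
    return train_loss_metrics, val_loss_metrics, val_metrics

The references returned here would then be handed to callbacks at construction time, e.g. MetricHandler(train_metrics=train_loss_metrics + train_metrics).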

Passing arguments for callbacks

Based on the conclusion above, we need to decide how to pass those states to callbacks.

For states managed by each callback, the callback can access and update them itself, as in StopTrainingHandlerV2 above. Here we discuss how to pass external arguments (those managed by the fit loop).
There are a few options:

  1. Inject the estimator instance into callbacks
    1. Example: https://github.com/apache/incubator-mxnet/pull/14629/files#diff-7f58d4d4cb6c2e6088afa89097fbb7e3R250
    2. Tightly couples callbacks with the estimator class.
    3. Callbacks can access anything from the estimator and inject anything into the estimator.
    4. Callbacks can do anything.
  2. Define common arguments plus **kwargs and pass everything (Keras and Fast.ai)
    1. Pass the epoch/batch number as args, and loss/metric values as kwargs.
    2. What about batch size info and the batch data from the data loader?
    3. Users still need to know what is available in **kwargs and how to access it.
  3. Register the arguments to pass to each callback at call time
    1. Each callback's *args and **kwargs can be registered as in PyTorch Ignite: https://github.com/pytorch/ignite/blob/master/ignite/engine/engine.py#L122
    2. How do we register args that are not available beforehand? Some args are only available after the fit loop has started.

Conclusion:

We can pass some of these states through kwargs during callback calls. We can also inject the estimator into each callback so callbacks have access to everything managed by the estimator. We need to draw a line on which states will be bookkept where. Based on the conclusion on the different kinds of states (a sketch of the resulting fit loop follows this list):

  1. net, trainer, and max epochs can be passed to callbacks as attributes of the estimator. These states (or references to them) are unlikely to change. We will pass a weak reference to the estimator to each callback.
  2. The output of the net, loss values, and the data and label of each batch will be passed to callbacks through kwargs. This information is generated by the fit loop, is different for each batch, and is only used at batch end.
  3. The epoch number, batch number, and other callback-specific states should be managed by each callback itself. We don't pass them during callback calls.
  4. Loss and metric values will be passed to callbacks during initialization; the fit loop passes a reference to each loss/metric instance to the callbacks. The estimator will provide convenience methods to help users create metric wrappers for losses and make copies of loss/metric objects for validation results.
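
A hedged sketch of how the fit loop could wire these pieces together, assuming the mix-in classes and MetricHandler above; attribute names such as estimator.net and estimator.loss, and the kwargs names pred/label/loss, are illustrative assumptions:

import weakref

def fit_loop(estimator, train_data, max_epochs, handlers):
    # Pass a weak reference to the estimator to each handler once, up front.
    for handler in handlers:
        handler.estimator = weakref.proxy(estimator)

    # Categorize handlers by event (see categorize_handlers above).
    epoch_begin_handlers = [h for h in handlers if isinstance(h, EpochBegin)]
    batch_end_handlers = [h for h in handlers if isinstance(h, BatchEnd)]
    epoch_end_handlers = [h for h in handlers if isinstance(h, EpochEnd)]

    for epoch in range(max_epochs):
        for handler in epoch_begin_handlers:
            handler.epoch_begin(estimator)
        for i, batch in enumerate(train_data):
            data, label = batch
            pred = estimator.net(data)           # forward pass
            loss = estimator.loss(pred, label)   # assumes a single loss function
            # ... backward pass and trainer.step() elided ...
            for handler in batch_end_handlers:
                handler.batch_end(estimator, pred=pred, label=label, loss=loss)
        for handler in epoch_end_handlers:
            handler.epoch_end(estimator)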


What should become a callback

The following are provided as callbacks for now:

  • Logging
  • Checkpoint
  • EarlyStopping

Here we discuss whether we should implement other features as callbacks and what additional requirements they need to meet in order to become callbacks.

  • Validation (needs additional loss and metric objects to record values; other callbacks will access them; see the sketch after this list)
    • on epoch begin: reset loss and metrics
    • on epoch end: run the validation logic
    • input: validation data
    • output: validation loss and metrics
    • the user creates copies of the validation metrics
  • Metrics (needs to happen before other callbacks access loss/metric values)
    • on epoch begin: reset loss and metrics
    • on batch end: update loss and metrics
    • input: loss, prediction, label
    • output: updated loss and metric values
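
To illustrate, a hedged sketch of validation as an event handler, built on the mix-in classes above; the attribute names (val_data, val_metrics, estimator.net) and the simplified validation loop are assumptions:

class ValidationHandler(EpochBegin, EpochEnd):
    def __init__(self, val_data, val_metrics):
        self.val_data = val_data
        self.val_metrics = val_metrics  # copies of the training metric objects
        # run after MetricHandler (rank 1) but before logging handlers
        self.rank = 2

    def epoch_begin(self, estimator, *args, **kwargs):
        for metric in self.val_metrics:
            metric.reset()

    def epoch_end(self, estimator, *args, **kwargs):
        # run the validation loop and update the validation metrics
        # (updating validation loss wrappers with computed loss values is elided)
        for data, label in self.val_data:
            pred = estimator.net(data)
            for metric in self.val_metrics:
                metric.update(label, pred)
        return False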

Conclusion:

  1. Make Metric and Validation callbacks and ensure they are called before other callbacks.
    1. We can give each callback a rank and sort callbacks by it before calling them (see the sketch after this list).
  2. We will write convenience methods to help users create copies of losses and metrics and manage them.
  3. Provide metrics, validation, and logging as default callbacks.
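
A minimal sketch of the ranking idea in 1.1; the rank attribute matches the one set in MetricHandler above, and the default value used for unranked handlers is an assumption:

# Sort handlers so metric/validation handlers run before handlers
# that consume their values (e.g. logging, checkpointing).
handlers = sorted(handlers, key=lambda h: getattr(h, 'rank', 10))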


Other considerations:

  1. Attaching and detaching a callback (removing a callback when a certain condition is met)
  2. Combine train_loss and train_metrics into a single attribute and update them depending on type:
    1. isinstance(metric, mx.metric.Loss): metric.update(0, loss_value)
    2. isinstance(metric, mx.metric.EvalMetric): metric.update(label, pred)
  3. Change the convenience method that prepares losses into a callback
  4. Change event handlers into true callables (e.g. class EpochBegin with a __call__ method)
    1. this would cause confusion with multi-inheritance

References:

  1. Gluon hook: https://github.com/apache/incubator-mxnet/pull/10989
  2. PyTorch Ignite: https://github.com/pytorch/ignite/blob/master/ignite/handlers/checkpoint.py
  3. Keras Callbacks: https://github.com/keras-team/keras/blob/master/keras/callbacks.py
  4. Fast.ai: https://github.com/fastai/fastai/tree/master/fastai/callbacks