Background

Categorical cross-entropy is a common loss for multi-class classification. It compares the predicted class probabilities with the true labels and computes the loss. Keras with the TensorFlow backend supports categorical cross-entropy and a variant of it, sparse categorical cross-entropy. Before Keras-MXNet v2.2.2, we supported only the former. We added sparse categorical cross-entropy in Keras-MXNet v2.2.2 and a new multi-hot sparse categorical cross-entropy in v2.2.4. In this document, we review how these losses are implemented.

Categorical Cross Entropy

When the number of classes is larger than 2, cross-entropy multiplies each true label by the log of the corresponding predicted probability, sums over the classes, and negates the result. Sparse categorical cross-entropy and multi-hot sparse categorical cross-entropy use the same equation and should produce the same output. The difference is that each variant covers a subset of use cases, so its implementation can be specialized to speed up the calculation.
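Written out for a single sample, the standard formula is shown below. The symbols are notation introduced here for illustration: C is the number of classes, y_c is 1 if the sample belongs to class c and 0 otherwise (exactly one y_c is 1 for single-label data, several can be 1 for multi-hot data), and p_c is the predicted probability of class c.

```latex
\ell = -\sum_{c=1}^{C} y_c \log(p_c)
```

Applying this to every row of y_true and y_pred and summing over the class axis gives one loss value per sample.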

The following pseudo code implements this equation in the MXNet backend, with axis set to the class axis:

loss = - mx.sym.sum(y_true * mx.sym.log(y_pred), axis=axis)

Both y_true and y_pred have shape (num_samples, num_classes). For example, for a 3-class classification problem:

y_true = [[0, 0, 1],
          [1, 0, 0],
          ...
          [0, 0, 1]]
where len(y_true) = num_samples, len(y_true[0]) = num_classes

y_pred holds the predicted probability of each class for every sample:

y_pred = [[0.1, 0.1, 0.8],
          [0.5, 0.2, 0.3],
          ...
          [0.0, 0.2, 0.8]]
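To make the shapes concrete, here is a minimal pure-Python sketch of the same per-sample computation, using plain lists instead of MXNet symbols and three rows in the format above. The clipping constant EPSILON is an assumption added for illustration: a dense implementation evaluates log on every entry, including the 0.0 in the last row, and log(0) is undefined (in floating point, 0 * -inf gives NaN), so some clipping is needed.

```python
import math

EPSILON = 1e-7  # assumed clipping value so that log never sees 0 or 1


def categorical_crossentropy(y_true, y_pred, epsilon=EPSILON):
    """Per-sample loss: -sum over classes of y_true[c] * log(y_pred[c])."""
    losses = []
    for true_row, pred_row in zip(y_true, y_pred):
        # log is evaluated for every class, even where y_true is 0
        loss = -sum(t * math.log(min(max(p, epsilon), 1.0 - epsilon))
                    for t, p in zip(true_row, pred_row))
        losses.append(loss)
    return losses


y_true = [[0, 0, 1],
          [1, 0, 0],
          [0, 0, 1]]
y_pred = [[0.1, 0.1, 0.8],
          [0.5, 0.2, 0.3],
          [0.0, 0.2, 0.8]]
losses = categorical_crossentropy(y_true, y_pred)
```

Each loss reduces to -log of the probability at the true class, e.g. -log(0.8), about 0.223, for the first row; every other term is multiplied by 0 but is still computed.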

Sparse Categorical Cross Entropy

Definition

The only difference between sparse categorical cross-entropy and categorical cross-entropy is the format of the true labels. In a single-label, multi-class classification problem, the classes are mutually exclusive: each data entry belongs to exactly one class. So instead of a one-hot row, we can represent each entry of y_true by its integer class index.

For example, consider y_true with 3 samples that belong to classes 2, 0, and 2, respectively:

y_true = [[0, 0, 1], 
          [1, 0, 0], 
          [0, 0, 1]]

This becomes:

y_true_sparse = [2, 0, 2]
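Converting one-hot rows to class indices amounts to finding the position of the single 1 in each row. A minimal pure-Python sketch, where one_hot_to_sparse is a hypothetical helper name used only for illustration:

```python
def one_hot_to_sparse(y_true_one_hot):
    """Return the class index of the single 1 in each one-hot row."""
    sparse = []
    for row in y_true_one_hot:
        # a valid one-hot row sorts to [0, ..., 0, 1]
        if sorted(row) != [0] * (len(row) - 1) + [1]:
            raise ValueError("expected exactly one 1 per row, got %r" % (row,))
        sparse.append(row.index(1))
    return sparse


print(one_hot_to_sparse([[0, 0, 1], [1, 0, 0], [0, 0, 1]]))  # [2, 0, 2]
```

For the three rows above this returns one integer per sample instead of num_classes values per sample.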

Improvements

This saves memory when the labels are sparse, that is, when the number of classes is very large. How can we also improve speed?

When labels are mutually exclusive, the normal cross-entropy calculation applies many log, multiply, and sum operations to y_pred probabilities whose contribution is eventually multiplied by 0. We don't actually need those operations.

Instead of applying log to all of y_pred and then multiplying by y_true, we can directly pick the probability from y_pred at the index given by y_true and apply log only to that value. No multiply or sum operation is needed.

The pseudo code is:

# y_true holds one class index per sample; pick selects y_pred[i, y_true[i]] along axis
loss = mx.sym.pick(y_pred, y_true, axis=axis, keepdims=True)
loss = - mx.sym.log(loss)
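The pick-based computation can be sketched in pure Python as well. Here pick_along_classes is a hypothetical helper that mimics what mx.sym.pick does along the class axis (one value per row); it is not the MXNet function, and the epsilon clipping is an illustrative assumption.

```python
import math


def pick_along_classes(y_pred, indices):
    """Hypothetical helper: y_pred[i][indices[i]] for every row i."""
    return [row[idx] for row, idx in zip(y_pred, indices)]


def sparse_categorical_crossentropy(y_true_sparse, y_pred, epsilon=1e-7):
    """Per-sample loss with integer labels: one lookup and one log per row."""
    picked = pick_along_classes(y_pred, y_true_sparse)
    # clip (assumed) so that log never sees 0
    return [-math.log(min(max(p, epsilon), 1.0 - epsilon)) for p in picked]


y_pred = [[0.1, 0.1, 0.8],
          [0.5, 0.2, 0.3],
          [0.0, 0.2, 0.8]]
losses = sparse_categorical_crossentropy([2, 0, 2], y_pred)
```

On these rows it yields the same values as the dense formula, -log(0.8), -log(0.5), and -log(0.8), while computing only one log per sample.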

Result

This gives roughly a 2x speedup when the number of classes is 1000. The implementation can be found at https://github.com/awslabs/keras-apache-mxnet/pull/145

Multi-hot Sparse Categorical Cross Entropy

Definition

For multi-label, multi-class problems, each data entry can belong to several classes. In practice, the total number of classes can be very large while the number of classes each entry belongs to is relatively small: for example, 3000 classes in total, with each entry belonging to several but no more than 5 classes.

So y_true in traditional cross-entropy becomes:

y_true = [[0, 1, 1, ..., 0],
          [1, 1, 0, ..., 0],
          [1, 1, 1, ..., 0],
          ...
          [0, 0, 0, ..., 1]]
where len(y_true) = num_samples, len(y_true[0]) = num_classes

Similar to sparse categorical cross-entropy, we can instead record y_true as the indices of its multi-hot entries, that is, the classes each sample belongs to. Note that the number of indices per sample can differ, for example from 1 to 5:

y_true = [[1, 2],
          [0, 1],
          [0, 1, 2],
          ...
          [999]]
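Extracting the index lists from multi-hot rows can be sketched in pure Python as below. The helper multi_hot_to_indices is hypothetical, and the example uses only 4 classes (instead of 3000) so that each row fits on one line.

```python
def multi_hot_to_indices(y_true_multi_hot):
    """Return, for each row, the indices of the classes set to 1."""
    return [[c for c, v in enumerate(row) if v == 1] for row in y_true_multi_hot]


# a 4-class stand-in for the 3000-class matrix above
y_true_multi_hot = [[0, 1, 1, 0],
                    [1, 1, 0, 0],
                    [1, 1, 1, 0],
                    [0, 0, 0, 1]]
y_true_indices = multi_hot_to_indices(y_true_multi_hot)
print(y_true_indices)  # [[1, 2], [0, 1], [0, 1, 2], [3]]
```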

Implementation

Process input

The first step is to make the label lengths consistent. We pad the labels with -1 so that padding can be distinguished from actual labels, whose values range from 0 to num_classes - 1. The padded label matrix has one column per allowed label (5 in this example), so it is still much smaller than the original (num_samples, num_classes) matrix. It becomes:

padded_y_true = [[-1, -1, -1, 1, 2],
                 [-1, -1, -1, 0, 1],
                 [-1, -1, 0, 1, 2],
                 ...
                 [-1, -1, -1, -1, 999]]
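The padding step can be sketched as follows; pad_labels and its max_labels parameter are hypothetical names. It pads on the left to reproduce the layout of padded_y_true above. Since padded entries are identified by their -1 value and masked out later, the side that gets padded should not affect the loss.

```python
PAD = -1  # real class indices are 0..num_classes-1, so -1 cannot collide


def pad_labels(y_true_indices, max_labels, pad=PAD):
    """Left-pad every index list with `pad` up to max_labels entries."""
    padded = []
    for labels in y_true_indices:
        if len(labels) > max_labels:
            raise ValueError("a sample has more than max_labels labels")
        padded.append([pad] * (max_labels - len(labels)) + list(labels))
    return padded


padded_y_true = pad_labels([[1, 2], [0, 1], [0, 1, 2], [999]], max_labels=5)
# [[-1, -1, -1, 1, 2], [-1, -1, -1, 0, 1], [-1, -1, 0, 1, 2], [-1, -1, -1, -1, 999]]
```

Applied to the four rows shown above (without the elided middle rows), the result matches padded_y_true.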

Speed up using for loop and take operator

Similar to how we used pick in sparse categorical cross-entropy, we want to take the prediction probabilities directly at the label indices to avoid unnecessary log/multiply/sum operations.

However, pick takes only one index per row, so we have to use take, which can take multiple indices. We then loop over y_pred and, for each row, take the indices given by the matching row of y_true.
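The difference between the two lookups can be illustrated with plain lists. pick_rows and take_rows are hypothetical helpers: pick_rows mirrors pick (one index per row), while take_rows shows the per-row, multi-index result that the loop over take is meant to produce. Neither is the MXNet operator itself.

```python
def pick_rows(matrix, indices):
    """One index per row -> one value per row (what pick does)."""
    return [row[i] for row, i in zip(matrix, indices)]


def take_rows(matrix, index_lists):
    """Several indices per row -> several values per row
    (the result the per-row loop over take is meant to build)."""
    return [[row[i] for i in idx] for row, idx in zip(matrix, index_lists)]


y_pred = [[0.1, 0.1, 0.8],
          [0.5, 0.2, 0.3]]
print(pick_rows(y_pred, [2, 0]))            # [0.8, 0.5]
print(take_rows(y_pred, [[1, 2], [0, 1]]))  # [[0.1, 0.8], [0.5, 0.2]]
```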

To loop over a symbol in MXNet, we slice it one row at a time and concatenate the results at the end. The steps become:

  1. loop over y_pred and take indices according to y_true
  2. concatenate the looped outputs
  3. mask out the values taken for the padded (-1) labels
  4. apply log and sum operators on a much smaller matrix

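The pseudo code becomes:

# y_true: padded labels, shape (num_samples, padded label length);
# num_samples must be known when the symbol is built
outputs = []
for i in range(num_samples):
    # row i of y_pred, reshaped to a 1-D vector of num_classes probabilities
    pred_i = mx.sym.reshape(mx.sym.slice_axis(y_pred, axis=0, begin=i, end=i+1), shape=(-1,))
    true_i = mx.sym.slice_axis(y_true, axis=0, begin=i, end=i+1)
    # take along axis 0 of the 1-D row: out_i has shape (1, padded label length)
    out_i = mx.sym.take(pred_i, true_i)
    outputs.append(out_i)
outputs = mx.sym.concat(*outputs, dim=0)
# mask out padded labels (-1) before summing the log-probabilities
mask = mx.sym.broadcast_greater_equal(y_true, mx.sym.zeros((1, 1)))
outputs = - mx.sym.sum(mask * mx.sym.log(outputs), axis=axis)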
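Putting the steps together, the following is a pure-Python sketch of the whole multi-hot computation on plain lists, assuming the -1 padding described above. Two details are assumptions of this sketch rather than statements about the MXNet code: padded -1 indices are clipped to index 0 (our understanding of take's default clip mode), and probabilities are clipped by epsilon before the log. The mask then zeroes out whatever value was taken for a padded label.

```python
import math


def multi_hot_sparse_crossentropy(padded_y_true, y_pred, epsilon=1e-7):
    """Per-sample loss from left-padded label indices (-1 = padding)."""
    losses = []
    for labels, pred_row in zip(padded_y_true, y_pred):
        # step 1: take probabilities at the label indices; -1 is clipped
        # to index 0 (assumed clip-mode behavior) and masked below
        taken = [pred_row[max(idx, 0)] for idx in labels]
        # step 3: mask is 1.0 for real labels (>= 0) and 0.0 for padding
        mask = [1.0 if idx >= 0 else 0.0 for idx in labels]
        # step 4: log and sum over the short padded row only
        loss = -sum(m * math.log(min(max(p, epsilon), 1.0 - epsilon))
                    for m, p in zip(mask, taken))
        losses.append(loss)
    # step 2 (concatenation) is unnecessary here: each row is finished in the loop
    return losses


y_pred = [[0.1, 0.1, 0.8],
          [0.5, 0.2, 0.3],
          [0.2, 0.3, 0.5]]
padded_y_true = [[-1, -1, -1, 1, 2],
                 [-1, -1, -1, 0, 1],
                 [-1, -1, 0, 1, 2]]
losses = multi_hot_sparse_crossentropy(padded_y_true, y_pred)
```

For the first row only classes 1 and 2 contribute, so the loss is -(log 0.1 + log 0.8), which is what dense categorical cross-entropy gives for the multi-hot row [0, 1, 1].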

Improve using foreach operator

With MXNet's control flow operators, we can further improve the speed by replacing the Python for loop with the foreach operator. Details of the API and design are described in the control flow operators design document listed in the references.

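Because foreach stacks the per-row outputs itself, step 2 above is no longer needed, and the pseudo code becomes:

# y_pred: (num_samples, num_classes); y_true: padded labels (num_samples, padded label length)
data = [y_pred, y_true]

# using control flow ops to iterate over the rows and take the targets (true labels);
# foreach slices each array in data along axis 0, so data[0] is one row of y_pred
# and data[1] is the matching row of y_true
_step = lambda data, _: (mx.sym.take(data[0], data[1]), [])
outputs, _ = mx.symbol.contrib.foreach(_step, data, [])

# calculate loss
# keep only targets >= 0, i.e. remove the padded labels (-1)
mask = mx.sym.broadcast_greater_equal(y_true, mx.sym.zeros((1, 1)))
outputs = - mx.sym.sum(mask * mx.sym.log(outputs), axis=axis)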
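To show what foreach contributes, the following pure-Python emulation mirrors the body(data, states) -> (outputs, new_states) contract used by _step above: every array in data is sliced along axis 0, the body runs once per row, and the per-row outputs are stacked. foreach_emulated and take_clipped are hypothetical stand-ins written for this illustration, not MXNet code, and the clipping behavior of take_clipped is an assumption.

```python
import math


def take_clipped(vector, indices):
    """Look up vector[i] for each i, clipping out-of-range indices (assumed clip mode)."""
    n = len(vector)
    return [vector[min(max(int(i), 0), n - 1)] for i in indices]


def foreach_emulated(body, data, init_states):
    """Slice every array in `data` along axis 0, run body once per slice,
    thread the states through, and stack the per-step outputs."""
    outputs = []
    states = init_states
    for step_slices in zip(*data):
        out, states = body(list(step_slices), states)
        outputs.append(out)
    return outputs, states


# same shape as the step function in the pseudo code: take a row of y_pred at a row of y_true
_step = lambda data, _: (take_clipped(data[0], data[1]), [])

y_pred = [[0.1, 0.1, 0.8],
          [0.5, 0.2, 0.3]]
y_true = [[-1, 1, 2],
          [-1, 0, 1]]

outputs, states = foreach_emulated(_step, [y_pred, y_true], [])
# outputs == [[0.1, 0.1, 0.8], [0.5, 0.5, 0.2]]

# mask out padded labels (-1), then log and sum, as in the pseudo code
mask = [[1.0 if i >= 0 else 0.0 for i in row] for row in y_true]
losses = [-sum(m * math.log(p) for m, p in zip(m_row, o_row))
          for m_row, o_row in zip(mask, outputs)]
```

The stacking inside foreach_emulated plays the role of the explicit concatenation in the for-loop version.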

Result

With 3000 classes and each data entry belonging to at most 5 classes, we achieved around a 3x speedup compared to normal categorical cross-entropy.

You can find the implementation here: https://github.com/awslabs/keras-apache-mxnet/pull/163

and an example here: https://github.com/awslabs/keras-apache-mxnet/blob/master/examples/multi_hot_sparse_categorical_crossentropy.py

Reference

1. Cross Entropy Loss: https://ml-cheatsheet.readthedocs.io/en/latest/loss_functions.html#cross-entropy

2. Sparse Categorical Cross-entropy PR: https://github.com/awslabs/keras-apache-mxnet/pull/145

3. Multi-hot Sparse Categorical Cross-entropy PR: https://github.com/awslabs/keras-apache-mxnet/pull/163

4. Control Flow Operators: https://cwiki.apache.org/confluence/display/MXNET/Optimize+dynamic+neural+network+models+with+control+flow+operators
