This document proposes a new MXNet module that supports importing and exporting models in other formats, such as ONNX and CoreML.
Objectives
This document focuses on the ONNX format as a starting point, but the proposed design should be generic enough to extend to other formats later.
ONNX is an intermediate representation for describing a neural network's computation graph and weights. With AWS, Microsoft, and Facebook defining and promoting ONNX, major deep learning frameworks such as MXNet, PyTorch, Caffe2, and CNTK are building native support for model import and export.
This document also defines the use cases for ONNX in MXNet. ONNX import is already implemented, but the code lives in an external repository (onnx-mxnet) under the onnx organization, which is controlled by Facebook.
Use cases:
1) Import an ONNX model into the MXNet symbolic interface for inference.
2) Import an ONNX model into MXNet Gluon for inference.
3) Export an MXNet symbolic model to ONNX.
4) Export an MXNet Gluon model to ONNX.
Future use cases (out of scope for this design):
1) Hyperparameter tuning of a previously trained ONNX model in MXNet.
2) Transfer learning with MXNet.
Usage example and Method Parameters
To implement import/export in MXNet, I propose exposing an MXNet Python module that performs the model conversion, either:
- as `serde` (serialization/deserialization; the name is borrowed from the Apache Hive project), called as mx.serde.import / mx.serde.export
- OR under contrib, which MXNet uses for experimental features that can later be moved out of contrib. This module qualifies as experimental because operator coverage still has gaps (see Appendix):
mx.contrib.serde.import / mx.contrib.serde.export
The serialization/deserialization module should provide the following methods, supporting multiple formats:
Import a model into MXNet:
sym, params = mx.serde.import(input_file, input_format='onnx', output_format='gluon')
- input_file : input model file (e.g., a protobuf model file for ONNX)
- input_format : (optional) 'onnx' or 'coreml'
- output_format : (optional) 'gluon' or 'symbolic'; defaults to 'gluon'.
Note: Gluon currently does not provide an easy way to import a pre-trained model, although a workaround exists.
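A minimal sketch of how `serde` could validate its arguments, apply the documented default (output_format='gluon'), and dispatch to a per-format converter. Because `import` is a reserved word in Python, a function literally named `import` cannot be defined or called with dot syntax, so the sketch uses the hypothetical name `import_model`. The registry `_IMPORTERS`, the `register_importer` decorator, and the stub converter are also hypothetical placeholders, not MXNet APIs.

```python
from typing import Callable, Dict, Tuple

# Hypothetical registry: (input_format, output_format) -> converter function.
_IMPORTERS: Dict[Tuple[str, str], Callable[[str], tuple]] = {}

SUPPORTED_INPUTS = ("onnx", "coreml")
SUPPORTED_OUTPUTS = ("gluon", "symbolic")


def register_importer(input_format: str, output_format: str):
    """Decorator that registers a converter for one (input, output) format pair."""
    def decorator(fn):
        _IMPORTERS[(input_format, output_format)] = fn
        return fn
    return decorator


def import_model(input_file: str, input_format: str = "onnx",
                 output_format: str = "gluon"):
    """Hypothetical stand-in for mx.serde.import: validate, then dispatch."""
    if input_format not in SUPPORTED_INPUTS:
        raise ValueError(f"unsupported input_format: {input_format!r}")
    if output_format not in SUPPORTED_OUTPUTS:
        raise ValueError(f"unsupported output_format: {output_format!r}")
    try:
        converter = _IMPORTERS[(input_format, output_format)]
    except KeyError:
        raise NotImplementedError(
            f"no converter registered for {input_format} -> {output_format}"
        ) from None
    return converter(input_file)


# Stub converter for illustration only; a real one would parse the protobuf file.
@register_importer("onnx", "symbolic")
def _onnx_to_symbolic(input_file):
    return "sym from " + input_file, {}
```

With only the ONNX-to-symbolic converter registered, import_model("model.onnx", output_format="symbolic") returns the stub's (sym, params) pair, while a call using the default Gluon output raises NotImplementedError. A registry keeps new formats (e.g. CoreML) additive: adding one means registering converters rather than editing the dispatch logic.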
Export an MXNet model to the specified output format:
1) mx.serde.export(sym, params, input_format='symbolic', output_format='onnx', filename_prefix="model_name")
2) mx.serde.export(module, input_format='symbolic', output_format='onnx', filename_prefix="model_name")
3) mx.serde.export(gluon_model, input_format='gluon', output_format='onnx', filename_prefix="model_name")
- sym : model definition (Symbol)
- module : MXNet Module object
- gluon_model : model definition (HybridBlock)
- params : weights
- input_format : 'symbolic' or 'gluon'
- output_format : 'onnx' or 'coreml'
- filename_prefix : prefix for the saved model file(s). For example, for ONNX a binary protobuf is written to a file named <filename_prefix>.onnx.
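The three export call forms can be reduced to a single (sym, params) pair before anything is written. The sketch below is hypothetical (`export_model`, `_normalize`, `output_path` are illustrative names). It assumes duck typing for the module form: the object exposes `symbol` and `get_params()` returning (arg_params, aux_params), as MXNet's Module does. Merging the two dicts into one params dict is also an assumption. The Gluon branch is left unimplemented, and the ".mlmodel" extension for CoreML is an assumption; only ".onnx" is stated above.

```python
from typing import Dict, Tuple

# Assumed file extensions per output format (only ".onnx" is given in the proposal).
_EXTENSIONS = {"onnx": ".onnx", "coreml": ".mlmodel"}


def _normalize(model, params=None, input_format="symbolic") -> Tuple[object, Dict]:
    """Reduce the three call forms to a (sym, params) pair."""
    if input_format == "gluon":
        # Form 3: a HybridBlock would first be traced into a symbol; not sketched here.
        raise NotImplementedError("gluon export is not covered by this sketch")
    if params is not None:
        # Form 1: export(sym, params, ...)
        return model, dict(params)
    # Form 2: export(module, ...); get_params() returns (arg_params, aux_params).
    arg_params, aux_params = model.get_params()
    merged = dict(arg_params)
    merged.update(aux_params)
    return model.symbol, merged


def output_path(filename_prefix: str, output_format: str) -> str:
    """Derive the file to write, e.g. 'model_name' -> 'model_name.onnx'."""
    if output_format not in _EXTENSIONS:
        raise ValueError(f"unsupported output_format: {output_format!r}")
    return filename_prefix + _EXTENSIONS[output_format]


def export_model(model, params=None, input_format="symbolic",
                 output_format="onnx", filename_prefix="model_name"):
    """Hypothetical stand-in for mx.serde.export; returns what would be written."""
    sym, all_params = _normalize(model, params, input_format)
    path = output_path(filename_prefix, output_format)
    # A real implementation would serialize (sym, all_params) to `path` here.
    return path, sym, all_params
```

Normalizing first means each output-format writer only ever sees one input shape, so supporting the module and Gluon forms does not multiply the number of writers.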
In addition to the core format-conversion module (`serde`), I think wrappers in the existing Gluon and Module namespaces would be more user friendly.
For example,
- Gluon
mx.gluon.import(input_file, input_format='onnx')
mx.gluon.export(output_format='onnx', filename_prefix="model_name")
Any pre- or post-processing logic specific to Gluon models can live in this Gluon wrapper, which internally calls the `serde` module APIs.
For example, `mx.gluon.import(input_file, input_format='onnx')` internally calls `mx.serde.import(input_file, input_format='onnx', output_format='gluon')`.
- Symbolic interface
sym, params = mx.mod.Module.import(input_file, input_format='onnx')
mx.mod.Module.export(output_format='onnx', filename_prefix="model_name")
The export call saves the file "model_name.onnx" directly to disk.
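The namespace wrappers only pin the argument their namespace implies (the output format) and forward everything else to `serde`. A minimal sketch with a stand-in `serde` object that records calls instead of converting. `FakeSerde`, `gluon_import`, and `module_import` are hypothetical names, again avoiding the reserved word `import`.

```python
class FakeSerde:
    """Stand-in for mx.serde that records calls instead of converting models."""

    def __init__(self):
        self.calls = []

    def import_model(self, input_file, input_format="onnx", output_format="gluon"):
        self.calls.append(("import", input_file, input_format, output_format))
        return "sym", {}


serde = FakeSerde()


def _wrap_import(output_format):
    """Build a namespace wrapper that pins output_format and delegates to serde."""
    def wrapper(input_file, input_format="onnx"):
        # Gluon- or Module-specific pre/post-processing would go here.
        return serde.import_model(input_file, input_format=input_format,
                                  output_format=output_format)
    return wrapper


gluon_import = _wrap_import("gluon")      # plays the role of mx.gluon.import
module_import = _wrap_import("symbolic")  # plays the role of mx.mod.Module.import
```

Building both wrappers from one factory keeps them from drifting apart: the only difference between the Gluon and Module entry points is the format they request from `serde`.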
Implementation approaches
There are two approaches to importing and exporting ONNX models.
Through MXNet's symbolic operators
Implement conversion at the MXNet layer: parse the ONNX model (in protobuf format), translate its nodes into MXNet symbolic operators, and build the MXNet model directly. Conversely, an MXNet model can be converted to ONNX at this layer.
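At this layer, the core of import is a per-operator translation table from ONNX node types to MXNet operator names and attributes. The sketch below represents nodes as plain dicts rather than ONNX protobuf objects and covers only a handful of operators (ONNX Relu, Sigmoid, Tanh, Add, Flatten) to show the pattern. The names `ONNX_TO_MXNET` and `translate` are hypothetical. Unsupported operators are collected rather than failing immediately, so coverage gaps like those listed in the Appendix can be reported in one pass.

```python
from typing import Callable, Dict, List, Tuple

# A node is a plain dict: {"op_type": str, "inputs": [...], "outputs": [...], "attrs": {...}}
Node = Dict
Converter = Callable[[Dict], Tuple[str, Dict]]


def _activation(act_type: str) -> Converter:
    """ONNX elementwise activations map onto MXNet's Activation operator."""
    return lambda attrs: ("Activation", {"act_type": act_type})


def _flatten(attrs: Dict) -> Tuple[str, Dict]:
    # MXNet Flatten keeps axis 0 and collapses the rest, i.e. ONNX Flatten with axis=1.
    if attrs.get("axis", 1) != 1:
        raise NotImplementedError("Flatten with axis != 1")
    return "Flatten", {}


# ONNX op_type -> function returning (mxnet_op_name, mxnet_attrs)
ONNX_TO_MXNET: Dict[str, Converter] = {
    "Relu": _activation("relu"),
    "Sigmoid": _activation("sigmoid"),
    "Tanh": _activation("tanh"),
    "Add": lambda attrs: ("broadcast_add", {}),
    "Flatten": _flatten,
}


def translate(nodes: List[Node]) -> Tuple[List[Dict], List[str]]:
    """Translate nodes in order; return (mxnet_nodes, unsupported_op_types)."""
    converted, unsupported = [], []
    for node in nodes:
        fn = ONNX_TO_MXNET.get(node["op_type"])
        if fn is None:
            unsupported.append(node["op_type"])
            continue
        op, attrs = fn(node.get("attrs", {}))
        converted.append({"op": op, "attrs": attrs,
                          "inputs": node["inputs"], "outputs": node["outputs"]})
    return converted, unsupported
```

For example, an ONNX graph of Add followed by Relu translates to broadcast_add followed by Activation(act_type='relu'). Export would use the same structure with a reverse table.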
Pros:
- Uses stable APIs that users already rely on.
- Broader operator coverage: ~70% of ONNX operators map 1:1 to MXNet operators, versus ~32% for nnvm/top today (see Appendix).
Cons:
- If MXNet moves to the nnvm/tvm backend, this may have to be reimplemented at the nnvm/tvm layer. In that case, conversion for the existing symbolic operators would be needed anyway for backward compatibility, and it could be reused for ONNX-MXNet conversion as well.
- MXNet's API has some legacy issues that the nnvm/top operators are meant to fix. #issue lists some of these issues and the plan to fix them.
Internal API design
The entire implementation will live in the MXNet repository.
def import(..., input_format='onnx'):
    # returns mxnet graph with parameters
    ...
    return sym, params
def export(..., output_format='onnx'):
    # returns ONNX protobuf object
    ...
    return onnx_proto
Through NNVM/top operators
The DMLC community has released the nnvm/tvm compiler and an intermediate representation for models.
Pros:
- Less engineering work if MXNet moves to nnvm/tvm.
- nnvm/tvm would become a hub for converting between different formats.
- nnvm/top operators are closer to parity with MXNet's Gluon APIs, which could be useful if Gluon becomes the only interface MXNet supports.
- Going through this layer automatically enables more hardware backends for MXNet, including OpenCL, Metal, Raspberry Pi, and web browsers.
- nnvm/top operators are a cleaner set: they do not suffer from legacy issues and strictly follow NumPy and Gluon conventions.
Cons:
- Does not support all operators in the MXNet symbolic API or ONNX; only ~32% of ONNX operators have a 1:1 mapping (see Appendix).
- Apache MXNet does not currently use the nnvm/tvm backend, so users would need to install the nnvm/tvm package separately for now.
Internal API design
The implementation will live in the nnvm repository.
The wrapper API (`serde`) will live in the MXNet repository and internally call nnvm package methods.
import nnvm.frontend
def import(..., input_format='onnx'):
    # convert from onnx to nnvm graph
    nnvm_graph, params = nnvm.frontend.from_onnx(...)  # Exists
    # convert from nnvm graph to mxnet graph
    mxnet_graph, params = nnvm.frontend.to_mxnet(...)  # Need to implement
    return mxnet_graph, params
def export(..., output_format='onnx'):
    # convert from mxnet to nnvm graph
    nnvm_graph, params = nnvm.frontend.from_mxnet(...)  # Exists
    # convert from nnvm graph to onnx proto format
    onnx_proto = nnvm.frontend.to_onnx(...)  # Need to implement
    return onnx_proto
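Routing through nnvm turns it into a conversion hub: each format needs converters to and from the hub graph rather than one to every other format. A small sketch of that routing, assuming a hypothetical registry `AVAILABLE` in which the converters marked "Need to implement" in the pseudocode (nnvm.frontend.to_mxnet, nnvm.frontend.to_onnx) are simply absent. A breadth-first search (`find_route`, a hypothetical helper) finds the chain of converters for a requested (source, target) pair, or reports that none exists yet.

```python
from collections import deque
from typing import Dict, List, Optional, Tuple

# Hypothetical registry of available converters, keyed by (source, target).
# Only the converters marked "Exists" in the pseudocode are registered.
AVAILABLE: Dict[Tuple[str, str], str] = {
    ("onnx", "nnvm"): "nnvm.frontend.from_onnx",
    ("mxnet", "nnvm"): "nnvm.frontend.from_mxnet",
}


def find_route(src: str, dst: str,
               edges: Dict[Tuple[str, str], str]) -> Optional[List[str]]:
    """Return the shortest chain of converter names from src to dst, or None."""
    queue = deque([(src, [])])
    seen = {src}
    while queue:
        fmt, path = queue.popleft()
        if fmt == dst:
            return path
        for (a, b), name in edges.items():
            if a == fmt and b not in seen:
                seen.add(b)
                queue.append((b, path + [name]))
    return None
```

With only the existing converters, find_route("onnx", "mxnet", AVAILABLE) returns None. Once the two missing converters are registered, ONNX to MXNet resolves to from_onnx followed by to_mxnet, and MXNet to ONNX resolves to from_mxnet followed by to_onnx. With N formats, a hub needs about 2N converters instead of up to N(N-1) pairwise ones.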
Suggested approach:
As a middle ground between the two implementation choices, I propose taking the first approach and implementing MXNet->ONNX conversion for export. Users who want to take advantage of the NNVM/TVM optimized engine can then use the import functionality provided by the NNVM/TVM package.
NVIDIA has recently worked on an MXNet->ONNX exporter, available in the mxnet_to_onnx GitHub repository, which is also based on the first approach. An issue has already been created to contribute this functionality to MXNet. Although it currently supports only file-to-file conversion (sym, params -> protobuf), it can be extended to in-memory model conversion (module -> protobuf).
How users will use the package in training/inference
Whichever approach we take is an implementation detail; it does not change how these APIs are delivered to users.
The general structure of ONNX export after training would look something like this:
import mxnet as mx
data_iter = data_process(mnist)
sym = construct_model_def(...)
# create module for training
mod = mx.mod.Module(symbol=sym, ...)
# train model
mod.fit(data_iter, optimizer='...', optimizer_params='...', ...)
# get parameters of trained model; get_params() returns (arg_params, aux_params)
arg_params, aux_params = mod.get_params()
# assumption: export takes a single dict holding both argument and auxiliary params
params = dict(arg_params)
params.update(aux_params)
# save into different format using sym/params
# This will internally call serde package API `serde.export`
mx.mod.Module.export(sym, params, output_format='onnx', filename_prefix="model")
# OR save into different format using the trained module directly
mod.export(output_format='onnx', filename_prefix="model")
The general structure of ONNX import would look something like this:
import mxnet as mx
from collections import namedtuple
# Module.forward expects a batch object with a `data` list
Batch = namedtuple('Batch', ['data'])
# Import model
# This will internally call serde package API `serde.import`
sym, params = mx.mod.Module.import(input_file, input_format='onnx')
# create module for inference
mod = mx.mod.Module(symbol=sym, data_names=..., context=mx.cpu(), label_names=None)
mod.bind(for_training=False, data_shapes=..., label_shapes=None)
# set parameters
mod.set_params(arg_params=params, aux_params=None, allow_missing=True)
# forward on the provided data batch
mod.forward(Batch([data_batch]))
# retrieve the prediction
output = mod.get_outputs()[0]
Testing
Integrate MXNet's import/export (backend/frontend) functionality with the ONNX test infrastructure, which provides both per-operator unit tests and model-level tests. This would simplify our testing procedure. See the test backend added for the onnx-mxnet package for reference.
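Per-operator tests are easiest to maintain when they are generated from a single table of cases, in the spirit of the ONNX backend test suite. A minimal sketch using `unittest`, where `export_op` and `import_op` are hypothetical stand-ins for the real converters (here just inverse name mappings). Each generated test checks that one operator survives an export/import round trip; a real suite would also compare numerical outputs.

```python
import unittest

# Hypothetical name mappings standing in for the real export/import converters.
MXNET_TO_ONNX = {"broadcast_add": "Add", "Flatten": "Flatten", "sigmoid": "Sigmoid"}
ONNX_TO_MXNET = {onnx_op: mx_op for mx_op, onnx_op in MXNET_TO_ONNX.items()}


def export_op(mx_op: str) -> str:
    return MXNET_TO_ONNX[mx_op]


def import_op(onnx_op: str) -> str:
    return ONNX_TO_MXNET[onnx_op]


class RoundTripTest(unittest.TestCase):
    """Test methods are attached below, one per operator in MXNET_TO_ONNX."""


def _make_test(op):
    def test(self):
        self.assertEqual(import_op(export_op(op)), op)
    return test


# Generate one test method per operator so each failure is reported individually.
for _op in MXNET_TO_ONNX:
    setattr(RoundTripTest, f"test_roundtrip_{_op}", _make_test(_op))
```

Adding an operator to the mapping automatically adds a test for it, which keeps test coverage in step with the operator coverage tracked in the Appendix.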
References
- ONNX operators
- NNVM/top operators
- MXNet operators
- onnx-mxnet wiki
- PyTorch export method, Caffe2 export method
Appendix
ONNX vs. MXNet vs. NNVM/top operator availability:
~70% of ONNX operators have a 1:1 mapping to MXNet operators, versus ~32% to NNVM/top operators.
ONNX | MXNet | NNVM | comments |
add/ add_n | with/without broadcasting | ||
argmin | |||
cast | cast | ||
concatenate | |||
identity? | |||
deconvolution | |||
with/without broadcasting | |||
LeakyReLU | |||
RNN | |||
linalg_gemm | FC, transpose combination | ||
RNN | |||
LeakyReLU | |||
L2Normalization, Pooling? - partial support | |||
linalg_gemm2 without alpha | |||
maximum | |||
RoiPooling | |||
mean | |||
minimum | |||
with/without broadcasting | |||
LeakyReLU | |||
pad | |||
pow | |||
RNN | |||
random_normal | |||
random_normal | |||
random_uniform | |||
random_uniform | |||
reciprocal | |||
max | |||
mean | |||
min | |||
prod | |||
sum | |||
relu | |||
reshape | |||
sigmoid | |||
multiple mxnet slice operators | |||
softmax | |||
Activation-softrelu | |||
implement formula using sign | |||
split | |||
sqrt | |||
broadcast_sub | with/without broadcasting | ||
ElementWiseSum | |||
tanh | |||
tile | |||
transpose | |||
slice | |||
embedding | |||
fullyconnected | |||
identity | |||