mxnet-commits mailing list archives

From GitBox <>
Subject [GitHub] rahul003 commented on a change in pull request #10283: [MXNET-242][Tutorial] Fine-tuning ONNX model in Gluon
Date Wed, 28 Mar 2018 22:42:13 GMT
rahul003 commented on a change in pull request #10283: [MXNET-242][Tutorial] Fine-tuning ONNX model in Gluon

 File path: docs/tutorials/onnx/
 @@ -0,0 +1,441 @@
+# Fine-tuning an ONNX model with MXNet/Gluon
+Fine-tuning is a common practice in transfer learning. One can take advantage of the pre-trained
weights of a network and use them as an initializer for their own task. Indeed, it is often
difficult to gather a dataset large enough to train deep and complex networks such as ResNet152
or VGG16 from scratch. For example, in an image classification task, a network trained on a large
dataset like ImageNet gives a good base from which the weights can be slightly updated, or
fine-tuned, to accurately predict the new classes. We will see in this tutorial that this can be
achieved even with a relatively small number of new training examples.
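The idea can be illustrated without MXNet by a toy NumPy sketch (our own illustration, not the Gluon code used in this tutorial). A random matrix stands in for the pre-trained layers, a freshly initialized output layer is sized for the new classes, and gradient descent updates the pre-trained weights with a smaller learning rate than the new layer. All names, sizes, and learning rates below are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for pre-trained weights (in the tutorial, the VGG16 layers loaded
# from model.onnx play this role; here it is just a random matrix).
pretrained_W = rng.normal(scale=1 / np.sqrt(20), size=(20, 8))

# Toy "new task": 3 classes and only 60 examples.
num_classes = 3
X = rng.normal(size=(60, 20))
y = rng.integers(0, num_classes, size=60)

# Fine-tuning: initialize from the pre-trained weights and add a new,
# freshly initialized output layer sized for the new classes.
W = pretrained_W.copy()
head = rng.normal(scale=0.01, size=(8, num_classes))

def loss_and_grads(W, head):
    """Softmax cross-entropy of a ReLU layer + linear head, with gradients."""
    n = len(y)
    h_pre = X @ W
    h = np.maximum(h_pre, 0.0)
    logits = h @ head
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    probs = np.exp(logits)
    probs /= probs.sum(axis=1, keepdims=True)
    loss = -np.log(probs[np.arange(n), y]).mean()
    dlogits = probs.copy()
    dlogits[np.arange(n), y] -= 1.0
    dlogits /= n
    dhead = h.T @ dlogits
    dh_pre = (dlogits @ head.T) * (h_pre > 0)
    dW = X.T @ dh_pre
    return loss, dW, dhead

# Small updates for the reused weights, larger ones for the new layer.
lr_body, lr_head = 0.01, 0.1
initial_loss = loss_and_grads(W, head)[0]
for step in range(200):
    loss, dW, dhead = loss_and_grads(W, head)
    W -= lr_body * dW
    head -= lr_head * dhead
final_loss = loss_and_grads(W, head)[0]
print("loss before: {:.4f}, after: {:.4f}".format(initial_loss, final_loss))
```

Using a smaller learning rate for the reused weights is one common way to keep them close to their pre-trained values; freezing them entirely is another.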
+[Open Neural Network Exchange (ONNX)]( provides an open source
format for AI models. It defines an extensible computation graph model, as well as definitions
of built-in operators and standard data types.
+In this tutorial we will:
+- learn how to pick a specific layer from a pre-trained .onnx model file
+- learn how to load this model in Gluon and fine-tune it on a different dataset
+## Prerequisites
+To run the tutorial, you will need the following Python modules installed:
+- [MXNet](
+- [onnx](
+- matplotlib
+- wget
+We recommend that you first complete this tutorial:
+- [Inference using an ONNX model on MXNet Gluon](
+import numpy as np
+import mxnet as mx
+from mxnet import gluon, nd, autograd
+from mxnet.gluon.data.vision.datasets import ImageFolderDataset
+from mxnet.gluon.data import DataLoader
+import mxnet.contrib.onnx as onnx_mxnet
+%matplotlib inline
+import matplotlib.pyplot as plt
+import tarfile, os
+import wget
+import json
+import multiprocessing
+### Downloading supporting files
+These are images and a visualization script:
+image_folder = "images"
+utils_file = "" # contains utility functions for plotting visualizations
+images = ['wrench', 'dolphin', 'lotus']
+base_url = "{}?raw=true"
+if not os.path.isdir(image_folder):
+    os.makedirs(image_folder)
+    for image in images:
+        wget.download(base_url.format("{}/{}.jpg".format(image_folder, image)), image_folder)
+if not os.path.isfile(utils_file):
+    wget.download(base_url.format(utils_file))
+from utils import *
+## Downloading a model from the ONNX model zoo
+We download a pre-trained model, in our case the [vgg16](
model, trained on [ImageNet]( from the [ONNX model zoo](
The model comes packaged in a `tar.gz` archive containing a `model.onnx` model file
and some sample input/output data.
+base_url = "" 
+current_model = "vgg16"
+model_folder = "model"
+archive_file = "{}.tar.gz".format(current_model)
+archive_path = os.path.join(model_folder, archive_file)
+url = "{}{}".format(base_url, archive_file)
+onnx_path = os.path.join(model_folder, current_model, 'model.onnx')
+# Create the model folder and download the compressed model archive
+if not os.path.isdir(model_folder):
+    os.makedirs(model_folder)
+if not os.path.isfile(archive_path):
+    print('Downloading the {} model to {}...'.format(current_model, archive_path))
+    wget.download(url, model_folder)
+    print('{} downloaded'.format(current_model))
+# Extract the model
+if not os.path.isdir(os.path.join(model_folder, current_model)):
+    print('Extracting {} in {}...'.format(archive_path, model_folder))
+    tar = tarfile.open(archive_path, "r:gz")
+    tar.extractall(model_folder)
+    tar.close()
+    print('Model extracted.')
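The extraction step above only runs when `model/vgg16` does not exist yet. If the same logic were factored into a reusable helper (for instance so that the Caltech101 archive could share it), it could also reject archive members whose paths would escape the destination folder. A minimal sketch using only the standard library; the name `extract_if_needed` and its arguments are hypothetical, not part of MXNet or of this tutorial. On Python 3.12 and later, `tarfile`'s `filter="data"` option offers a stricter built-in safeguard.

```python
import os
import tarfile

def extract_if_needed(archive_path, dest_folder, expected_subdir):
    """Extract a .tar.gz archive unless dest_folder/expected_subdir exists.

    Returns True if the archive was extracted, False if it was skipped.
    """
    if os.path.isdir(os.path.join(dest_folder, expected_subdir)):
        return False
    dest_root = os.path.realpath(dest_folder)
    with tarfile.open(archive_path, "r:gz") as tar:
        # Refuse members that would be written outside dest_folder,
        # e.g. names such as "../evil.txt" or absolute paths.
        for member in tar.getmembers():
            member_path = os.path.realpath(os.path.join(dest_root, member.name))
            if os.path.commonpath([dest_root, member_path]) != dest_root:
                raise ValueError("Unsafe path in archive: {}".format(member.name))
        tar.extractall(dest_folder)
    return True
```

With the tutorial's variables, the call would look like `extract_if_needed(archive_path, model_folder, current_model)`.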
+## Downloading the Caltech101 dataset
+The [Caltech101 dataset]( consists
of pictures of objects belonging to 101 categories, with about 40 to 800 images per category.
Most categories have about 50 images.
+*L. Fei-Fei, R. Fergus and P. Perona. Learning generative visual models from few training
examples: an incremental Bayesian approach tested on 101 object categories. IEEE CVPR 2004,
Workshop on Generative-Model Based Vision. 2004*
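Caltech101 is stored as one sub-folder per category, which is the layout `ImageFolderDataset` reads (it uses the sub-folder names as class labels). Once the archive has been extracted by the code below, the per-category counts quoted above can be checked with a plain-Python sketch like this one; the helper name `images_per_category` and the list of image extensions are our own assumptions.

```python
import os

# Assumed set of image file extensions (compared case-insensitively).
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

def images_per_category(root):
    """Map each category sub-folder of `root` to its number of image files."""
    counts = {}
    for category in sorted(os.listdir(root)):
        category_dir = os.path.join(root, category)
        if not os.path.isdir(category_dir):
            continue  # skip stray files at the top level
        counts[category] = sum(
            1 for name in os.listdir(category_dir)
            if name.lower().endswith(IMAGE_EXTENSIONS)
        )
    return counts
```

With the tutorial's variables, `images_per_category(os.path.join(data_folder, dataset_name))` would return one count per extracted category folder, and `min`/`max` over its values give the range.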
+data_folder = "data"
+dataset_name = "101_ObjectCategories"
+archive_file = "{}.tar.gz".format(dataset_name)
+archive_path = os.path.join(data_folder, archive_file)
+data_url = ""
+if not os.path.isdir(data_folder):
+    os.makedirs(data_folder)
+if not os.path.isfile(archive_path):
+    print('Downloading {} in {}...'.format(archive_file, data_folder))
+    wget.download("{}{}".format(data_url, archive_file), data_folder)
+    print('Extracting {} in {}...'.format(archive_file, data_folder))
+    tar = tarfile.open(archive_path, "r:gz")
+    tar.extractall(data_folder)
+    tar.close()
+    print('Data extracted.')
 Review comment:
   Should we add this to as another option, so that other
tutorials can also benefit from it?

This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:

With regards,
Apache Git Services
