mxnet-commits mailing list archives

From GitBox <...@apache.org>
Subject [GitHub] safrooze closed pull request #9093: example/nce-loss fixes and improvements
Date Thu, 01 Jan 1970 00:00:00 GMT
safrooze closed pull request #9093: example/nce-loss fixes and improvements
URL: https://github.com/apache/incubator-mxnet/pull/9093
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

diff --git a/example/nce-loss/README.md b/example/nce-loss/README.md
index 8da444df0d..edcb04a9ee 100644
--- a/example/nce-loss/README.md
+++ b/example/nce-loss/README.md
@@ -1,32 +1,43 @@
-#Examples of NCE Loss
+# Examples of NCE Loss
 
-nce-loss is used to speedup multi-class classification when class num is huge.
+[Noise-contrastive estimation](http://proceedings.mlr.press/v9/gutmann10a/gutmann10a.pdf) loss (nce-loss) is used to speed up multi-class classification when the number of classes is huge.
+
+Examples in this folder use the [text8](http://mattmahoney.net/dc/textdata.html) dataset, which is 100 MB of cleaned-up English Wikipedia XML data. Wikipedia data is multi-licensed under the [Creative Commons Attribution-ShareAlike 3.0 License](https://en.wikipedia.org/wiki/Wikipedia:Text_of_Creative_Commons_Attribution-ShareAlike_3.0_Unported_License) (CC-BY-SA) and the [GNU Free Documentation License](https://en.wikipedia.org/wiki/Wikipedia:Text_of_the_GNU_Free_Documentation_License) (GFDL). For more information on the licensing of Wikipedia data, see [here](https://en.wikipedia.org/wiki/Wikipedia:Database_download).
 
 ## Toy example
 
-* toy_softmax.py: a multi class example using softmax output
-* toy_nce.py: a multi-class example using nce loss
+* toy_softmax.py: a multi-class example using softmax output. Command to start training on CPU:
+```
+python toy_softmax.py
+```
 
-## Word2Vec
+* toy_nce.py: the same example as toy_softmax.py above, except that it uses nce loss. Command to start training on CPU:
+```
+python toy_nce.py
+```
 
-* word2vec.py: a CBOW word2vec example using nce loss
+## Dataset Download
 
-You can run it by
+The dataset used in the following examples is the [text8](http://mattmahoney.net/dc/textdata.html) dataset mentioned above. The example scripts expect the dataset to exist in a folder named 'data'. The included get_text8.sh script downloads the dataset into the correct path. Command to download:
 
 ```
 ./get_text8.sh
+```
+
+## Word2Vec
+
+* word2vec.py: a CBOW word2vec example using nce loss. You need to [download the text8 dataset](#dataset-download) before running this script. Command to start training on CPU (pass -g for training on GPU):
+
+```
 python word2vec.py
 
 ```
 
 ## LSTM
 
-* lstm_word.py: a lstm example use nce loss
-
-You can run it by
+* lstm_word.py: an LSTM example using nce loss. You need to [download the text8 dataset](#dataset-download) before running this script. Pass -h (or --help) to see the command line options. Command to start training on CPU (pass -g for training on GPU):
 
 ```
-./get_text8.sh
 python lstm_word.py
 ```
 
@@ -37,14 +48,12 @@ You can refer to [http://www.jianshu.com/p/e439b43ea464](http://www.jianshu.com/
 
 ## Word2Vec in NCE-loss with Subword Representation
 
+wordvec_subwords.py: reproduces the work Microsoft Research presented at CIKM'14, which is a basis of DSSM ([Deep Semantic Similarity Model](https://www.microsoft.com/en-us/research/project/dssm/)); the related lectures are available [here](https://www.microsoft.com/en-us/research/publication/deep-learning-for-natural-language-processing-theory-and-practice-tutorial/). You need to [download the text8 dataset](#dataset-download) before running this script. Command to start training on CPU (pass -g for training on GPU):
+
 ```
-./get_text8.sh
 python wordvec_subwords.py
 ```
 
-Reproducing the work Microsoft Research presented in CIKM'14, in which it's a basis of DSSM([Deep Semantics Similarity Model](https://www.microsoft.com/en-us/research/project/dssm/)), you can get its lectures [here](https://www.microsoft.com/en-us/research/publication/deep-learning-for-natural-language-processing-theory-and-practice-tutorial/).
-
-
 ### Motivation
 
 The motivation is to design a more robust and scalable word vector system by reducing the size of the lookup table and handling unknown (out-of-vocabulary) words better.
@@ -64,18 +73,16 @@ If you use sub-word sequence and feed into a word2vec training processing, it co
 
 ### Analysis
 
-> Experiment data on MacBook Pro'16 with 4 cpus.
-
-Here we print the training log below, using text8 data, to get some intuitions on its benefits:
+This experiment was performed on a MacBook Pro with 4 CPUs. The training log below, using the text8 data, gives some intuition about its benefits:
 
 *With subword units representation*
 
-It converge much faster.
+The network training converges much faster.
 
 ```
 2016-11-26 19:07:31,742 Start training with [cpu(0), cpu(1), cpu(2), cpu(3)]
 2016-11-26 19:07:31,783 DataIter start.
-2016-11-26 19:07:45,099 Epoch[0] Batch [50]		Speed: 4020.37 samples/sec	Train-nce-auc=0.693178
+2016-11-26 19:07:45,099 Epoch[0] Batch [50]	Speed: 4020.37 samples/sec	Train-nce-auc=0.693178
 2016-11-26 19:07:57,870 Epoch[0] Batch [100]	Speed: 4009.19 samples/sec	Train-nce-auc=0.741482
 2016-11-26 19:08:10,196 Epoch[0] Batch [150]	Speed: 4153.73 samples/sec	Train-nce-auc=0.764026
 2016-11-26 19:08:22,497 Epoch[0] Batch [200]	Speed: 4162.61 samples/sec	Train-nce-auc=0.785248
@@ -93,18 +100,6 @@ It converge much faster.
 2016-11-26 19:10:53,362 Epoch[0] Batch [800]	Speed: 4123.59 samples/sec	Train-nce-auc=0.834170
 2016-11-26 19:11:05,645 Epoch[0] Batch [850]	Speed: 4168.32 samples/sec	Train-nce-auc=0.836135
 2016-11-26 19:11:18,035 Epoch[0] Batch [900]	Speed: 4132.51 samples/sec	Train-nce-auc=0.842253
-2016-11-26 19:11:30,257 Epoch[0] Batch [950]	Speed: 4189.27 samples/sec	Train-nce-auc=0.834119
-2016-11-26 19:11:42,600 Epoch[0] Batch [1000]	Speed: 4148.01 samples/sec	Train-nce-auc=0.828049
-2016-11-26 19:11:54,850 Epoch[0] Batch [1050]	Speed: 4179.55 samples/sec	Train-nce-auc=0.844856
-2016-11-26 19:12:07,052 Epoch[0] Batch [1100]	Speed: 4196.35 samples/sec	Train-nce-auc=0.856587
-2016-11-26 19:12:19,286 Epoch[0] Batch [1150]	Speed: 4185.10 samples/sec	Train-nce-auc=0.845370
-2016-11-26 19:12:31,703 Epoch[0] Batch [1200]	Speed: 4123.25 samples/sec	Train-nce-auc=0.851430
-2016-11-26 19:12:44,177 Epoch[0] Batch [1250]	Speed: 4104.76 samples/sec	Train-nce-auc=0.851357
-2016-11-26 19:12:56,497 Epoch[0] Batch [1300]	Speed: 4155.90 samples/sec	Train-nce-auc=0.854957
-2016-11-26 19:13:08,839 Epoch[0] Batch [1350]	Speed: 4148.39 samples/sec	Train-nce-auc=0.853684
-2016-11-26 19:13:21,052 Epoch[0] Batch [1400]	Speed: 4192.37 samples/sec	Train-nce-auc=0.849442
-2016-11-26 19:13:33,386 Epoch[0] Batch [1450]	Speed: 4151.24 samples/sec	Train-nce-auc=0.853365
-2016-11-26 19:13:45,709 Epoch[0] Batch [1500]	Speed: 4154.65 samples/sec	Train-nce-auc=0.855938
 ```
 
 
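For reference, the per-example computation that `nce_loss` builds symbolically can be sketched in plain Python. This is a minimal sketch, not part of the PR: `nce_loss_forward` is a hypothetical name, it handles a single example with Python lists instead of MXNet symbols, and it computes the binary cross-entropy whose gradient `LogisticRegressionOutput` back-propagates (the MXNet operator itself only outputs the sigmoid in the forward pass).

```python
import math


def nce_loss_forward(hidden, label_ids, label_weight, embed_weight):
    """Score candidate labels for one example (hypothetical helper).

    hidden:       feature vector of length num_hidden
    label_ids:    true word id followed by sampled noise word ids
    label_weight: 1.0 for the true word, 0.0 for noise words
    embed_weight: vocab_size x num_hidden label embedding table
    """
    probs = []
    loss = 0.0
    for wid, target in zip(label_ids, label_weight):
        # Embedding lookup + broadcast_mul + sum over num_hidden = dot product
        score = sum(h * e for h, e in zip(hidden, embed_weight[wid]))
        # LogisticRegressionOutput applies a sigmoid in the forward pass
        p = 1.0 / (1.0 + math.exp(-score))
        probs.append(p)
        # Binary cross-entropy; its gradient w.r.t. score is p - target
        loss -= target * math.log(p) + (1.0 - target) * math.log(1.0 - p)
    return probs, loss


embed = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
probs, loss = nce_loss_forward([2.0, 0.0], [1, 3, 2], [1.0, 0.0, 0.0], embed)
print(probs, loss)
```

Only `num_label` candidates (the true word plus sampled noise words, marked by `label_weight` 1 and 0) are scored per example instead of all `vocab_size` classes, which is where the speedup over a full softmax comes from.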
diff --git a/example/nce-loss/lstm_net.py b/example/nce-loss/lstm_net.py
new file mode 100644
index 0000000000..e67477826a
--- /dev/null
+++ b/example/nce-loss/lstm_net.py
@@ -0,0 +1,104 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=missing-docstring
+from __future__ import print_function
+
+from collections import namedtuple
+
+import mxnet as mx
+from nce import nce_loss
+
+LSTMState = namedtuple("LSTMState", ["c", "h"])
+LSTMParam = namedtuple("LSTMParam", ["i2h_weight", "i2h_bias",
+                                     "h2h_weight", "h2h_bias"])
+
+
+def _lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0.):
+    """LSTM Cell symbol"""
+    if dropout > 0.:
+        indata = mx.sym.Dropout(data=indata, p=dropout)
+    i2h = mx.sym.FullyConnected(data=indata,
+                                weight=param.i2h_weight,
+                                bias=param.i2h_bias,
+                                num_hidden=num_hidden * 4,
+                                name="t%d_l%d_i2h" % (seqidx, layeridx))
+    h2h = mx.sym.FullyConnected(data=prev_state.h,
+                                weight=param.h2h_weight,
+                                bias=param.h2h_bias,
+                                num_hidden=num_hidden * 4,
+                                name="t%d_l%d_h2h" % (seqidx, layeridx))
+    gates = i2h + h2h
+    slice_gates = mx.sym.SliceChannel(gates, num_outputs=4,
+                                      name="t%d_l%d_slice" % (seqidx, layeridx))
+    in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid")
+    in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh")
+    forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid")
+    out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid")
+    next_c = (forget_gate * prev_state.c) + (in_gate * in_transform)
+    next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh")
+    return LSTMState(c=next_c, h=next_h)
+
+
+def get_lstm_net(vocab_size, seq_len, num_lstm_layer, num_hidden):
+    param_cells = []
+    last_states = []
+    for i in range(num_lstm_layer):
+        param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
+                                     i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
+                                     h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
+                                     h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
+        state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
+                          h=mx.sym.Variable("l%d_init_h" % i))
+        last_states.append(state)
+
+    data = mx.sym.Variable('data')
+    label = mx.sym.Variable('label')
+    label_weight = mx.sym.Variable('label_weight')
+    embed_weight = mx.sym.Variable('embed_weight')
+    label_embed_weight = mx.sym.Variable('label_embed_weight')
+    data_embed = mx.sym.Embedding(data=data, input_dim=vocab_size,
+                                  weight=embed_weight,
+                                  output_dim=100, name='data_embed')
+    datavec = mx.sym.SliceChannel(data=data_embed,
+                                  num_outputs=seq_len,
+                                  squeeze_axis=True, name='data_slice')
+    labelvec = mx.sym.SliceChannel(data=label,
+                                   num_outputs=seq_len,
+                                   squeeze_axis=True, name='label_slice')
+    labelweightvec = mx.sym.SliceChannel(data=label_weight,
+                                         num_outputs=seq_len,
+                                         squeeze_axis=True, name='label_weight_slice')
+    probs = []
+    for seqidx in range(seq_len):
+        hidden = datavec[seqidx]
+
+        for i in range(num_lstm_layer):
+            next_state = _lstm(num_hidden, indata=hidden,
+                               prev_state=last_states[i],
+                               param=param_cells[i],
+                               seqidx=seqidx, layeridx=i)
+            hidden = next_state.h
+            last_states[i] = next_state
+
+        probs.append(nce_loss(data=hidden,
+                              label=labelvec[seqidx],
+                              label_weight=labelweightvec[seqidx],
+                              embed_weight=label_embed_weight,
+                              vocab_size=vocab_size,
+                              num_hidden=num_hidden))
+    return mx.sym.Group(probs)
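`_lstm` only builds a symbol graph, so the arithmetic of one cell step is easy to miss. The following is a minimal pure-Python sketch of the same computation for a single example (no batching, no dropout). `lstm_step`, `_matvec` and `_sigmoid` are hypothetical names, and the weight matrices follow the `FullyConnected` convention of one row per output unit (4 * num_hidden rows in total).

```python
import math


def _matvec(w, x):
    # FullyConnected without bias: one row of w per output unit
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def _sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def lstm_step(x, h_prev, c_prev, i2h_w, i2h_b, h2h_w, h2h_b):
    """One LSTM step mirroring _lstm(); weights have 4 * num_hidden rows."""
    n = len(h_prev)
    i2h = [a + b for a, b in zip(_matvec(i2h_w, x), i2h_b)]
    h2h = [a + b for a, b in zip(_matvec(h2h_w, h_prev), h2h_b)]
    gates = [a + b for a, b in zip(i2h, h2h)]
    # SliceChannel(num_outputs=4) splits gates into 4 consecutive chunks
    in_gate = [_sigmoid(g) for g in gates[0:n]]
    in_transform = [math.tanh(g) for g in gates[n:2 * n]]
    forget_gate = [_sigmoid(g) for g in gates[2 * n:3 * n]]
    out_gate = [_sigmoid(g) for g in gates[3 * n:4 * n]]
    c_next = [f * c + i * t
              for f, c, i, t in zip(forget_gate, c_prev, in_gate, in_transform)]
    h_next = [o * math.tanh(c) for o, c in zip(out_gate, c_next)]
    return h_next, c_next


# num_hidden = 1, input size = 1, all weights and biases zero
zeros_w = [[0.0]] * 4
zeros_b = [0.0] * 4
h, c = lstm_step([1.0], [0.0], [2.0], zeros_w, zeros_b, zeros_w, zeros_b)
print(h, c)
```

With all-zero parameters every sigmoid gate is 0.5 and the candidate transform is 0, so the cell state is simply halved. Stacking layers, as `get_lstm_net` does, feeds `h_next` of one layer in as `x` of the next.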
diff --git a/example/nce-loss/lstm_word.py b/example/nce-loss/lstm_word.py
index 23729917d9..063cece0af 100644
--- a/example/nce-loss/lstm_word.py
+++ b/example/nce-loss/lstm_word.py
@@ -15,198 +15,27 @@
 # specific language governing permissions and limitations
 # under the License.
 
-# pylint:skip-file
+# pylint: disable=missing-docstring, deprecated-module
 from __future__ import print_function
+
 import logging
-import sys, random, time, math
-sys.path.insert(0, "../../python")
-import mxnet as mx
-import numpy as np
-from collections import namedtuple
-from nce import *
-from operator import itemgetter
 from optparse import OptionParser
 
-LSTMState = namedtuple("LSTMState", ["c", "h"])
-LSTMParam = namedtuple("LSTMParam", ["i2h_weight", "i2h_bias",
-                                     "h2h_weight", "h2h_bias"])
-LSTMModel = namedtuple("LSTMModel", ["rnn_exec", "symbol",
-                                     "init_states", "last_states",
-                                     "seq_data", "seq_labels", "seq_outputs",
-                                     "param_blocks"])
-
-def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0.):
-    """LSTM Cell symbol"""
-    if dropout > 0.:
-        indata = mx.sym.Dropout(data=indata, p=dropout)
-    i2h = mx.sym.FullyConnected(data=indata,
-                                weight=param.i2h_weight,
-                                bias=param.i2h_bias,
-                                num_hidden=num_hidden * 4,
-                                name="t%d_l%d_i2h" % (seqidx, layeridx))
-    h2h = mx.sym.FullyConnected(data=prev_state.h,
-                                weight=param.h2h_weight,
-                                bias=param.h2h_bias,
-                                num_hidden=num_hidden * 4,
-                                name="t%d_l%d_h2h" % (seqidx, layeridx))
-    gates = i2h + h2h
-    slice_gates = mx.sym.SliceChannel(gates, num_outputs=4,
-                                      name="t%d_l%d_slice" % (seqidx, layeridx))
-    in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid")
-    in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh")
-    forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid")
-    out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid")
-    next_c = (forget_gate * prev_state.c) + (in_gate * in_transform)
-    next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh")
-    return LSTMState(c=next_c, h=next_h)
-
-
-def get_net(vocab_size, seq_len, num_label, num_lstm_layer, num_hidden):
-    param_cells = []
-    last_states = []
-    for i in range(num_lstm_layer):
-        param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
-                                     i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
-                                     h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
-                                     h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
-        state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
-                          h=mx.sym.Variable("l%d_init_h" % i))
-        last_states.append(state)
-
-    data = mx.sym.Variable('data')
-    label = mx.sym.Variable('label')
-    label_weight = mx.sym.Variable('label_weight')
-    embed_weight = mx.sym.Variable('embed_weight')
-    label_embed_weight = mx.sym.Variable('label_embed_weight')
-    data_embed = mx.sym.Embedding(data = data, input_dim = vocab_size,
-                                  weight = embed_weight,
-                                  output_dim = 100, name = 'data_embed')
-    datavec = mx.sym.SliceChannel(data = data_embed,
-                                  num_outputs = seq_len,
-                                  squeeze_axis = True, name = 'data_slice')
-    labelvec = mx.sym.SliceChannel(data = label,
-                                   num_outputs = seq_len,
-                                   squeeze_axis = True, name = 'label_slice')
-    labelweightvec = mx.sym.SliceChannel(data = label_weight,
-                                         num_outputs = seq_len,
-                                         squeeze_axis = True, name = 'label_weight_slice')
-    probs = []
-    for seqidx in range(seq_len):
-        hidden = datavec[seqidx]
-
-        for i in range(num_lstm_layer):
-            next_state = lstm(num_hidden, indata = hidden,
-                              prev_state = last_states[i],
-                              param = param_cells[i],
-                              seqidx = seqidx, layeridx = i)
-            hidden = next_state.h
-            last_states[i] = next_state
-
-        probs.append(nce_loss(data = hidden,
-                              label = labelvec[seqidx],
-                              label_weight = labelweightvec[seqidx],
-                              embed_weight = label_embed_weight,
-                              vocab_size = vocab_size,
-                              num_hidden = 100,
-                              num_label = num_label))
-    return mx.sym.Group(probs)
-
-
-def load_data(name):
-    buf = open(name).read()
-    tks = buf.split(' ')
-    vocab = {}
-    freq = [0]
-    data = []
-    for tk in tks:
-        if len(tk) == 0:
-            continue
-        if tk not in vocab:
-            vocab[tk] = len(vocab) + 1
-            freq.append(0)
-        wid = vocab[tk]
-        data.append(wid)
-        freq[wid] += 1
-    negative = []
-    for i, v in enumerate(freq):
-        if i == 0 or v < 5:
-            continue
-        v = int(math.pow(v * 1.0, 0.75))
-        negative += [i for _ in range(v)]
-    return data, negative, vocab, freq
-
-class SimpleBatch(object):
-    def __init__(self, data_names, data, label_names, label):
-        self.data = data
-        self.label = label
-        self.data_names = data_names
-        self.label_names = label_names
-
-    @property
-    def provide_data(self):
-        return [(n, x.shape) for n, x in zip(self.data_names, self.data)]
-
-    @property
-    def provide_label(self):
-        return [(n, x.shape) for n, x in zip(self.label_names, self.label)]
-
-
-class DataIter(mx.io.DataIter):
-    def __init__(self, name, batch_size, seq_len, num_label, init_states):
-        super(DataIter, self).__init__()
-        self.batch_size = batch_size
-        self.data, self.negative, self.vocab, self.freq = load_data(name)
-        self.vocab_size = 1 + len(self.vocab)
-        print(self.vocab_size)
-        self.seq_len = seq_len
-        self.num_label = num_label
-        self.init_states = init_states
-        self.init_state_names = [x[0] for x in self.init_states]
-        self.init_state_arrays = [mx.nd.zeros(x[1]) for x in init_states]
-        self.provide_data = [('data', (batch_size, seq_len))] + init_states
-        self.provide_label = [('label', (self.batch_size, seq_len, num_label)),
-                              ('label_weight', (self.batch_size, seq_len, num_label))]
-
-    def sample_ne(self):
-        return self.negative[random.randint(0, len(self.negative) - 1)]
-
-    def __iter__(self):
-        print('begin')
-        batch_data = []
-        batch_label = []
-        batch_label_weight = []
-        for i in range(0, len(self.data) - self.seq_len - 1, self.seq_len):
-            data = self.data[i: i+self.seq_len]
-            label = [[self.data[i+k+1]] \
-                     + [self.sample_ne() for _ in range(self.num_label-1)]\
-                     for k in range(self.seq_len)]
-            label_weight = [[1.0] \
-                            + [0.0 for _ in range(self.num_label-1)]\
-                            for k in range(self.seq_len)]
-
-            batch_data.append(data)
-            batch_label.append(label)
-            batch_label_weight.append(label_weight)
-            if len(batch_data) == self.batch_size:
-                data_all = [mx.nd.array(batch_data)] + self.init_state_arrays
-                label_all = [mx.nd.array(batch_label), mx.nd.array(batch_label_weight)]
-                data_names = ['data'] + self.init_state_names
-                label_names = ['label', 'label_weight']
-                batch_data = []
-                batch_label = []
-                batch_label_weight = []
-                yield SimpleBatch(data_names, data_all, label_names, label_all)
+import mxnet as mx
+from nce import NceLSTMAuc
+from text8_data import DataIterLstm
+from lstm_net import get_lstm_net
 
-    def reset(self):
-        pass
 
 if __name__ == '__main__':
     head = '%(asctime)-15s %(message)s'
     logging.basicConfig(level=logging.DEBUG, format=head)
 
     parser = OptionParser()
-    parser.add_option("-g", "--gpu", action = "store_true", dest = "gpu", default = False,
-                      help = "use gpu")
+    parser.add_option("-g", "--gpu", action="store_true", dest="gpu", default=False,
+                      help="use gpu")
+    options, args = parser.parse_args()
+
     batch_size = 1024
     seq_len = 5
     num_label = 6
@@ -217,25 +46,29 @@ def reset(self):
     init_h = [('l%d_init_h'%l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]
     init_states = init_c + init_h
 
+    data_train = DataIterLstm("./data/text8", batch_size, seq_len, num_label, init_states)
 
-    data_train = DataIter("./data/text8", batch_size, seq_len, num_label,
-                          init_states)
+    network = get_lstm_net(data_train.vocab_size, seq_len, num_lstm_layer, num_hidden)
 
-    network = get_net(data_train.vocab_size, seq_len, num_label, num_lstm_layer, num_hidden)
-    options, args = parser.parse_args()
     devs = mx.cpu()
-    if options.gpu == True:
+    if options.gpu:
         devs = mx.gpu()
-    model = mx.model.FeedForward(ctx = devs,
-                                 symbol = network,
-                                 num_epoch = 20,
-                                 learning_rate = 0.3,
-                                 momentum = 0.9,
-                                 wd = 0.0000,
-                                 initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))
 
-    metric = NceLSTMAuc()
-    model.fit(X = data_train,
-              eval_metric = metric,
-              batch_end_callback = mx.callback.Speedometer(batch_size, 50),)
+    model = mx.mod.Module(
+        symbol=network,
+        data_names=[x[0] for x in data_train.provide_data],
+        label_names=[y[0] for y in data_train.provide_label],
+        context=[devs]
+    )
 
+    print("Training on {}".format("GPU" if options.gpu else "CPU"))
+    metric = NceLSTMAuc()
+    model.fit(
+        train_data=data_train,
+        num_epoch=20,
+        optimizer='sgd',
+        optimizer_params={'learning_rate': 0.3, 'momentum': 0.9, 'wd': 0.0000},
+        initializer=mx.init.Xavier(factor_type='in', magnitude=2.34),
+        eval_metric=metric,
+        batch_end_callback=mx.callback.Speedometer(batch_size, 50)
+    )
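The removed `load_data` and `DataIter` code draws noise words from the unigram distribution raised to the 0.75 power and pairs each true next word with `num_label - 1` of them (the new `lstm_word.py` imports `DataIterLstm` from `text8_data.py` instead). A minimal standalone sketch of those two steps follows; `build_negative_table` and `make_labels` are hypothetical names, and `rng.choice` stands in for the original's `random.randint` indexing into the table.

```python
import math
import random


def build_negative_table(freq, min_count=5):
    """Noise-word table: word id w appears int(freq[w] ** 0.75) times."""
    table = []
    for wid, count in enumerate(freq):
        # id 0 is unused by the vocabulary; rare words are never sampled
        if wid == 0 or count < min_count:
            continue
        table += [wid] * int(math.pow(count, 0.75))
    return table


def make_labels(target, table, num_label, rng=random):
    """True word first, then num_label - 1 sampled noise words."""
    label = [target] + [rng.choice(table) for _ in range(num_label - 1)]
    label_weight = [1.0] + [0.0] * (num_label - 1)
    return label, label_weight


table = build_negative_table([0, 10, 3, 100])
label, label_weight = make_labels(7, table, num_label=6, rng=random.Random(0))
print(len(table), label, label_weight)
```

Raising counts to the 0.75 power flattens the distribution: frequent words are still sampled more often, but not in direct proportion to their counts. A sampled noise word can occasionally coincide with the true word; the original code does not filter that case either.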
diff --git a/example/nce-loss/nce.py b/example/nce-loss/nce.py
index 7f57dfdb75..e59220a026 100644
--- a/example/nce-loss/nce.py
+++ b/example/nce-loss/nce.py
@@ -15,48 +15,51 @@
 # specific language governing permissions and limitations
 # under the License.
 
-# pylint:skip-file
-import sys
-sys.path.insert(0, "../../python")
+# pylint: disable=missing-docstring
+from __future__ import print_function
+
+from operator import itemgetter
+
 import mxnet as mx
 import numpy as np
-from operator import itemgetter
 
-def nce_loss(data, label, label_weight, embed_weight, vocab_size, num_hidden, num_label):
-    label_embed = mx.sym.Embedding(data = label, input_dim = vocab_size,
-                                   weight = embed_weight,
-                                   output_dim = num_hidden, name = 'label_embed')
-    data = mx.sym.Reshape(data = data, shape = (-1, 1, num_hidden))
+
+def nce_loss(data, label, label_weight, embed_weight, vocab_size, num_hidden):
+    label_embed = mx.sym.Embedding(data=label, input_dim=vocab_size,
+                                   weight=embed_weight,
+                                   output_dim=num_hidden, name='label_embed')
+    data = mx.sym.Reshape(data=data, shape=(-1, 1, num_hidden))
     pred = mx.sym.broadcast_mul(data, label_embed)
-    pred = mx.sym.sum(data = pred, axis = 2)
-    return mx.sym.LogisticRegressionOutput(data = pred,
-                                           label = label_weight)
+    pred = mx.sym.sum(data=pred, axis=2)
+    return mx.sym.LogisticRegressionOutput(data=pred,
+                                           label=label_weight)
 
 
-def nce_loss_subwords(data, label, label_mask, label_weight, embed_weight, vocab_size, num_hidden, num_label):
+def nce_loss_subwords(
+        data, label, label_mask, label_weight, embed_weight, vocab_size, num_hidden):
     """NCE-Loss layer under subword-units input.
     """
     # get subword-units embedding.
-    label_units_embed = mx.sym.Embedding(data = label,
-                                         input_dim = vocab_size,
-                                         weight = embed_weight,
-                                         output_dim = num_hidden)
+    label_units_embed = mx.sym.Embedding(data=label,
+                                         input_dim=vocab_size,
+                                         weight=embed_weight,
+                                         output_dim=num_hidden)
     # get valid subword-units embedding with the help of label_mask
-    # it's achieve by multiply zeros to useless units in order to handle variable-length input.
-    label_units_embed = mx.sym.broadcast_mul(lhs = label_units_embed,
-                                             rhs = label_mask,
-                                             name = 'label_units_embed')
+    # this is done by multiplying the padding units by zero, which handles variable-length input.
+    label_units_embed = mx.sym.broadcast_mul(lhs=label_units_embed,
+                                             rhs=label_mask,
+                                             name='label_units_embed')
     # sum over them to get label word embedding.
-    label_embed = mx.sym.sum(label_units_embed, axis=2, name = 'label_embed')
+    label_embed = mx.sym.sum(label_units_embed, axis=2, name='label_embed')
 
-    # by boardcast_mul and sum you can get prediction scores in all num_label inputs,
+    # broadcast_mul followed by sum gives prediction scores for all label_embed inputs,
+    # which are easy to feed into LogisticRegressionOutput and make the code more concise.
-    data = mx.sym.Reshape(data = data, shape = (-1, 1, num_hidden))
+    data = mx.sym.Reshape(data=data, shape=(-1, 1, num_hidden))
     pred = mx.sym.broadcast_mul(data, label_embed)
-    pred = mx.sym.sum(data = pred, axis = 2)
+    pred = mx.sym.sum(data=pred, axis=2)
 
-    return mx.sym.LogisticRegressionOutput(data = pred,
-                                           label = label_weight)
+    return mx.sym.LogisticRegressionOutput(data=pred,
+                                           label=label_weight)
 
 
 class NceAccuracy(mx.metric.EvalMetric):
@@ -71,6 +74,7 @@ def update(self, labels, preds):
                 self.sum_metric += 1
             self.num_inst += 1
 
+
 class NceAuc(mx.metric.EvalMetric):
     def __init__(self):
         super(NceAuc, self).__init__('nce-auc')
@@ -82,12 +86,12 @@ def update(self, labels, preds):
         for i in range(preds.shape[0]):
             for j in range(preds.shape[1]):
                 tmp.append((label_weight[i][j], preds[i][j]))
-        tmp = sorted(tmp, key = itemgetter(1), reverse = True)
+        tmp = sorted(tmp, key=itemgetter(1), reverse=True)
         m = 0.0
         n = 0.0
         z = 0.0
         k = 0
-        for a, b in tmp:
+        for a, _ in tmp:
             if a > 0.5:
                 m += 1.0
                 z += len(tmp) - k
@@ -100,6 +104,7 @@ def update(self, labels, preds):
         self.sum_metric += z
         self.num_inst += 1
 
+
 class NceLSTMAuc(mx.metric.EvalMetric):
     def __init__(self):
         super(NceLSTMAuc, self).__init__('nce-lstm-auc')
@@ -115,12 +120,12 @@ def update(self, labels, preds):
         for i in range(preds.shape[0]):
             for j in range(preds.shape[1]):
                 tmp.append((label_weight[i][j], preds[i][j]))
-        tmp = sorted(tmp, key = itemgetter(1), reverse = True)
+        tmp = sorted(tmp, key=itemgetter(1), reverse=True)
         m = 0.0
         n = 0.0
         z = 0.0
         k = 0
-        for a, b in tmp:
+        for a, _ in tmp:
             if a > 0.5:
                 m += 1.0
                 z += len(tmp) - k
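The hunks above show only part of `NceAuc.update` and `NceLSTMAuc.update`: candidates are sorted by score, and each positive at rank `k` adds `len(tmp) - k`. The normalization that follows is not shown in this diff. Assuming it is the usual rank-based (Mann-Whitney) AUC, the full computation can be sketched as below; `rank_auc` is a hypothetical standalone function, not the metric class itself.

```python
def rank_auc(label_weight, preds):
    """Fraction of (positive, negative) pairs in which the positive scores higher."""
    tmp = sorted(zip(label_weight, preds), key=lambda t: t[1], reverse=True)
    m = 0.0  # number of positives (label_weight > 0.5)
    n = 0.0  # number of negatives
    z = 0.0
    for k, (a, _) in enumerate(tmp):
        if a > 0.5:
            m += 1.0
            # items ranked at or below this positive, itself included
            z += len(tmp) - k
        else:
            n += 1.0
    if m == 0 or n == 0:
        return 0.0  # AUC is undefined without both classes
    # remove the positive-positive pairs (and self-pairs) counted above
    z -= m * (m + 1.0) / 2.0
    return z / (m * n)


print(rank_auc([1.0, 0.0], [0.9, 0.1]))  # 1.0
print(rank_auc([1.0, 0.0, 0.0, 1.0], [0.9, 0.8, 0.3, 0.2]))  # 0.5
```

Ties are broken by sort order rather than counted as half a pair, which is a common simplification.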
diff --git a/example/nce-loss/random_data.py b/example/nce-loss/random_data.py
new file mode 100644
index 0000000000..7d88e2931f
--- /dev/null
+++ b/example/nce-loss/random_data.py
@@ -0,0 +1,127 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=missing-docstring
+from __future__ import print_function
+
+import random
+
+import mxnet as mx
+import numpy as np
+
+
+class SimpleBatch(object):
+    def __init__(self, data_names, data, label_names, label):
+        self.data = data
+        self.label = label
+        self.data_names = data_names
+        self.label_names = label_names
+
+    @property
+    def provide_data(self):
+        return [(n, x.shape) for n, x in zip(self.data_names, self.data)]
+
+    @property
+    def provide_label(self):
+        return [(n, x.shape) for n, x in zip(self.label_names, self.label)]
+
+
+class DataIterSoftmax(mx.io.DataIter):
+    def __init__(self, count, batch_size, vocab_size, num_label, feature_size):
+        super(DataIterSoftmax, self).__init__()
+        self.batch_size = batch_size
+        self.count = count
+        self.vocab_size = vocab_size
+        self.num_label = num_label
+        self.feature_size = feature_size
+        self.provide_data = [('data', (batch_size, feature_size))]
+        self.provide_label = [('label', (self.batch_size,))]
+
+    def mock_sample(self):
+        ret = np.zeros(self.feature_size)
+        rn = set()
+        while len(rn) < 3:
+            rn.add(random.randint(0, self.feature_size - 1))
+        s = 0
+        for k in rn:
+            ret[k] = 1.0
+            s *= self.feature_size
+            s += k
+        return ret, s % self.vocab_size
+
+    def __iter__(self):
+        for _ in range(self.count // self.batch_size):
+            data = []
+            label = []
+            for _ in range(self.batch_size):
+                d, l = self.mock_sample()
+                data.append(d)
+                label.append(l)
+            data_all = [mx.nd.array(data)]
+            label_all = [mx.nd.array(label)]
+            data_names = ['data']
+            label_names = ['label']
+            yield SimpleBatch(data_names, data_all, label_names, label_all)
+
+    def reset(self):
+        pass
+
+
+class DataIterNce(mx.io.DataIter):
+    def __init__(self, count, batch_size, vocab_size, num_label, feature_size):
+        super(DataIterNce, self).__init__()
+        self.batch_size = batch_size
+        self.count = count
+        self.vocab_size = vocab_size
+        self.num_label = num_label
+        self.feature_size = feature_size
+        self.provide_data = [('data', (batch_size, feature_size))]
+        self.provide_label = [('label', (self.batch_size, num_label)),
+                              ('label_weight', (self.batch_size, num_label))]
+
+    def mock_sample(self):
+        ret = np.zeros(self.feature_size)
+        rn = set()
+        while len(rn) < 3:
+            rn.add(random.randint(0, self.feature_size - 1))
+        s = 0
+        for k in rn:
+            ret[k] = 1.0
+            s *= self.feature_size
+            s += k
+        la = [s % self.vocab_size] +\
+             [random.randint(0, self.vocab_size - 1) for _ in range(self.num_label - 1)]
+        return ret, la
+
+    def __iter__(self):
+        for _ in range(self.count // self.batch_size):
+            data = []
+            label = []
+            label_weight = []
+            for _ in range(self.batch_size):
+                d, l = self.mock_sample()
+                data.append(d)
+                label.append(l)
+                label_weight.append([1.0] + [0.0 for _ in range(self.num_label - 1)])
+            data_all = [mx.nd.array(data)]
+            label_all = [mx.nd.array(label), mx.nd.array(label_weight)]
+            data_names = ['data']
+            label_names = ['label', 'label_weight']
+            yield SimpleBatch(data_names, data_all, label_names, label_all)
+
+    def reset(self):
+        pass
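In `DataIterNce`, each `label` row holds `num_label` class ids. Column 0 is the true class, computed by treating the three active feature positions as digits of a base-`feature_size` number and reducing it modulo `vocab_size`; the remaining columns are uniformly drawn noise classes. `label_weight` marks the true column with 1.0 and the noise columns with 0.0. The standard-library sketch below rebuilds one such row. The helper name `make_nce_row` is hypothetical, and sorting the active positions is an assumption added for reproducibility (the example iterates over a `set`, whose order is not specified).

```python
import random


def make_nce_row(feature_size, vocab_size, num_label, rng):
    """Hypothetical helper mirroring DataIterNce.mock_sample plus its label weights."""
    # Pick 3 distinct active feature positions (sorted here for reproducibility).
    active = sorted(rng.sample(range(feature_size), 3))
    features = [0.0] * feature_size
    code = 0
    for k in active:
        features[k] = 1.0
        # Treat the active indices as digits of a base-`feature_size` number.
        code = code * feature_size + k
    true_label = code % vocab_size
    # Column 0 is the true class; the rest are uniformly sampled noise classes.
    labels = [true_label] + [rng.randint(0, vocab_size - 1) for _ in range(num_label - 1)]
    weights = [1.0] + [0.0] * (num_label - 1)
    return features, labels, weights


rng = random.Random(0)
features, labels, weights = make_nce_row(feature_size=100, vocab_size=10000, num_label=6, rng=rng)
print(labels, weights)
```

Because noise ids are drawn uniformly, a noise column can occasionally repeat the true class; the example accepts this rather than resampling.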
diff --git a/example/nce-loss/text8_data.py b/example/nce-loss/text8_data.py
new file mode 100644
index 0000000000..6af72e9c8c
--- /dev/null
+++ b/example/nce-loss/text8_data.py
@@ -0,0 +1,362 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=missing-docstring
+from __future__ import print_function
+
+from collections import Counter
+import logging
+import math
+import random
+
+import mxnet as mx
+import numpy as np
+
+
+def _load_data(name):
+    buf = open(name).read()
+    tks = buf.split(' ')
+    vocab = {}
+    freq = [0]
+    data = []
+    for tk in tks:
+        if len(tk) == 0:
+            continue
+        if tk not in vocab:
+            vocab[tk] = len(vocab) + 1
+            freq.append(0)
+        wid = vocab[tk]
+        data.append(wid)
+        freq[wid] += 1
+    negative = []
+    for i, v in enumerate(freq):
+        if i == 0 or v < 5:
+            continue
+        v = int(math.pow(v * 1.0, 0.75))
+        negative += [i for _ in range(v)]
+    return data, negative, vocab, freq
+
+
+class SubwordData(object):
+    def __init__(self, data, units, weights, negative_units, negative_weights, vocab, units_vocab,
+                 freq, max_len):
+        self.data = data
+        self.units = units
+        self.weights = weights
+        self.negative_units = negative_units
+        self.negative_weights = negative_weights
+        self.vocab = vocab
+        self.units_vocab = units_vocab
+        self.freq = freq
+        self.max_len = max_len
+
+
+def _get_subword_units(token, gram):
+    """Return subword-units presentation, given a word/token.
+    """
+    if token == '</s>':  # special token for padding purpose.
+        return [token]
+    t = '#' + token + '#'
+    return [t[i:i + gram] for i in range(0, len(t) - gram + 1)]
+
+
+def _get_subword_representation(wid, vocab_inv, units_vocab, max_len, gram, padding_char):
+    token = vocab_inv[wid]
+    units = [units_vocab[unit] for unit in _get_subword_units(token, gram)]
+    weights = [1] * len(units) + [0] * (max_len - len(units))
+    units = units + [units_vocab[padding_char]] * (max_len - len(units))
+    return units, weights
+
+
+def _prepare_subword_units(tks, gram, padding_char):
+    # statistics on units
+    units_vocab = {padding_char: 0}  # real units get ids starting at 1
+    max_len = 0
+    unit_set = set()
+    logging.info('grams: %d', gram)
+    logging.info('counting max len...')
+    for tk in tks:
+        res = _get_subword_units(tk, gram)
+        unit_set.update(res)
+        if max_len < len(res):
+            max_len = len(res)
+    logging.info('preparing units vocab...')
+    for unit in unit_set:
+        if len(unit) == 0:
+            continue
+        if unit not in units_vocab:
+            units_vocab[unit] = len(units_vocab)
+        # uid = units_vocab[unit]
+    return units_vocab, max_len
+
+
+def _load_data_as_subword_units(name, min_count, gram, max_subwords, padding_char):
+    tks = []
+    logging.info('reading corpus from file...')
+    with open(name, 'rb') as fread:
+        for line in fread:
+            line = line.strip().decode('utf-8')
+            tks.extend(line.split(' '))
+
+    logging.info('Total tokens: %d', len(tks))
+
+    tks = [tk for tk in tks if len(tk) <= max_subwords]
+    c = Counter(tks)
+
+    logging.info('Total vocab: %d', len(c))
+
+    vocab = {}
+    vocab_inv = {}
+    freq = [0]
+    data = []
+    for tk in tks:
+        if len(tk) == 0:
+            continue
+        if tk not in vocab:
+            vocab[tk] = len(vocab) + 1  # id 0 is reserved, matching freq[0]
+            freq.append(0)
+        wid = vocab[tk]
+        vocab_inv[wid] = tk
+        data.append(wid)
+        freq[wid] += 1
+
+    negative = []
+    for i, v in enumerate(freq):
+        if i == 0 or v < min_count:
+            continue
+        v = int(math.pow(v * 1.0, 0.75))  # sample negative w.r.t. its frequency
+        negative += [i for _ in range(v)]
+
+    logging.info('counting subword units...')
+    units_vocab, max_len = _prepare_subword_units(tks, gram, padding_char)
+    logging.info('vocabulary size: %d', len(vocab))
+    logging.info('subword unit size: %d', len(units_vocab))
+
+    logging.info('generating input data...')
+    units = []
+    weights = []
+    for wid in data:
+        word_units, weight = _get_subword_representation(
+            wid, vocab_inv, units_vocab, max_len, gram, padding_char)
+        units.append(word_units)
+        weights.append(weight)
+
+    negative_units = []
+    negative_weights = []
+    for wid in negative:
+        word_units, weight = _get_subword_representation(
+            wid, vocab_inv, units_vocab, max_len, gram, padding_char)
+        negative_units.append(word_units)
+        negative_weights.append(weight)
+
+    return SubwordData(
+        data=data, units=units, weights=weights, negative_units=negative_units,
+        negative_weights=negative_weights, vocab=vocab, units_vocab=units_vocab,
+        freq=freq, max_len=max_len
+    )
+
+
+class SimpleBatch(object):
+    def __init__(self, data_names, data, label_names, label):
+        self.data = data
+        self.label = label
+        self.data_names = data_names
+        self.label_names = label_names
+
+    @property
+    def provide_data(self):
+        return [(n, x.shape) for n, x in zip(self.data_names, self.data)]
+
+    @property
+    def provide_label(self):
+        return [(n, x.shape) for n, x in zip(self.label_names, self.label)]
+
+
+class DataIterWords(mx.io.DataIter):
+    def __init__(self, name, batch_size, num_label):
+        super(DataIterWords, self).__init__()
+        self.batch_size = batch_size
+        self.data, self.negative, self.vocab, self.freq = _load_data(name)
+        self.vocab_size = 1 + len(self.vocab)
+        print("Vocabulary Size: {}".format(self.vocab_size))
+        self.num_label = num_label
+        self.provide_data = [('data', (batch_size, num_label - 1))]
+        self.provide_label = [('label', (self.batch_size, num_label)),
+                              ('label_weight', (self.batch_size, num_label))]
+
+    def sample_ne(self):
+        return self.negative[random.randint(0, len(self.negative) - 1)]
+
+    def __iter__(self):
+        batch_data = []
+        batch_label = []
+        batch_label_weight = []
+        start = random.randint(0, self.num_label - 1)
+        for i in range(start, len(self.data) - self.num_label - start, self.num_label):
+            context = self.data[i: i + self.num_label // 2] \
+                      + self.data[i + 1 + self.num_label // 2: i + self.num_label]
+            target_word = self.data[i + self.num_label // 2]
+            if self.freq[target_word] < 5:
+                continue
+            target = [target_word] + [self.sample_ne() for _ in range(self.num_label - 1)]
+            target_weight = [1.0] + [0.0 for _ in range(self.num_label - 1)]
+            batch_data.append(context)
+            batch_label.append(target)
+            batch_label_weight.append(target_weight)
+            if len(batch_data) == self.batch_size:
+                data_all = [mx.nd.array(batch_data)]
+                label_all = [mx.nd.array(batch_label), mx.nd.array(batch_label_weight)]
+                data_names = ['data']
+                label_names = ['label', 'label_weight']
+                batch_data = []
+                batch_label = []
+                batch_label_weight = []
+                yield SimpleBatch(data_names, data_all, label_names, label_all)
+
+    def reset(self):
+        pass
+
+
+class DataIterLstm(mx.io.DataIter):
+    def __init__(self, name, batch_size, seq_len, num_label, init_states):
+        super(DataIterLstm, self).__init__()
+        self.batch_size = batch_size
+        self.data, self.negative, self.vocab, self.freq = _load_data(name)
+        self.vocab_size = 1 + len(self.vocab)
+        print("Vocabulary Size: {}".format(self.vocab_size))
+        self.seq_len = seq_len
+        self.num_label = num_label
+        self.init_states = init_states
+        self.init_state_names = [x[0] for x in self.init_states]
+        self.init_state_arrays = [mx.nd.zeros(x[1]) for x in init_states]
+        self.provide_data = [('data', (batch_size, seq_len))] + init_states
+        self.provide_label = [('label', (self.batch_size, seq_len, num_label)),
+                              ('label_weight', (self.batch_size, seq_len, num_label))]
+
+    def sample_ne(self):
+        return self.negative[random.randint(0, len(self.negative) - 1)]
+
+    def __iter__(self):
+        batch_data = []
+        batch_label = []
+        batch_label_weight = []
+        for i in range(0, len(self.data) - self.seq_len - 1, self.seq_len):
+            data = self.data[i: i+self.seq_len]
+            label = [[self.data[i + k + 1]]
+                     + [self.sample_ne() for _ in range(self.num_label - 1)]
+                     for k in range(self.seq_len)]
+            label_weight = [[1.0]
+                            + [0.0 for _ in range(self.num_label - 1)]
+                            for _ in range(self.seq_len)]
+
+            batch_data.append(data)
+            batch_label.append(label)
+            batch_label_weight.append(label_weight)
+            if len(batch_data) == self.batch_size:
+                data_all = [mx.nd.array(batch_data)] + self.init_state_arrays
+                label_all = [mx.nd.array(batch_label), mx.nd.array(batch_label_weight)]
+                data_names = ['data'] + self.init_state_names
+                label_names = ['label', 'label_weight']
+                batch_data = []
+                batch_label = []
+                batch_label_weight = []
+                yield SimpleBatch(data_names, data_all, label_names, label_all)
+
+    def reset(self):
+        pass
+
+
+class DataIterSubWords(mx.io.DataIter):
+    def __init__(self, fname, batch_size, num_label, min_count, gram, max_subwords, padding_char):
+        super(DataIterSubWords, self).__init__()
+        self.batch_size = batch_size
+        self.min_count = min_count
+        self.swd = _load_data_as_subword_units(
+            fname,
+            min_count=min_count,
+            gram=gram,
+            max_subwords=max_subwords,
+            padding_char=padding_char)
+        self.vocab_size = len(self.swd.units_vocab)
+        self.num_label = num_label
+        self.provide_data = [('data', (batch_size, num_label - 1, self.swd.max_len)),
+                             ('mask', (batch_size, num_label - 1, self.swd.max_len, 1))]
+        self.provide_label = [('label', (self.batch_size, num_label, self.swd.max_len)),
+                              ('label_weight', (self.batch_size, num_label)),
+                              ('label_mask', (self.batch_size, num_label, self.swd.max_len, 1))]
+
+    def sample_ne(self):
+        # a negative sample.
+        return self.swd.negative_units[random.randint(0, len(self.swd.negative_units) - 1)]
+
+    def sample_ne_indices(self):
+        return [random.randint(0, len(self.swd.negative_units) - 1)
+                for _ in range(self.num_label - 1)]
+
+    def __iter__(self):
+        logging.info('DataIter start.')
+        batch_data = []
+        batch_data_mask = []
+        batch_label = []
+        batch_label_mask = []
+        batch_label_weight = []
+        start = random.randint(0, self.num_label - 1)
+        for i in range(start, len(self.swd.units) - self.num_label - start, self.num_label):
+            context_units = self.swd.units[i: i + self.num_label // 2] + \
+                            self.swd.units[i + 1 + self.num_label // 2: i + self.num_label]
+            context_mask = self.swd.weights[i: i + self.num_label // 2] + \
+                           self.swd.weights[i + 1 + self.num_label // 2: i + self.num_label]
+            target_units = self.swd.units[i + self.num_label // 2]
+            target_word = self.swd.data[i + self.num_label // 2]
+            if self.swd.freq[target_word] < self.min_count:
+                continue
+            indices = self.sample_ne_indices()
+            target = [target_units] + [self.swd.negative_units[j] for j in indices]
+            target_weight = [1.0] + [0.0 for _ in range(self.num_label - 1)]
+            target_mask = [self.swd.weights[i + self.num_label // 2]] +\
+                          [self.swd.negative_weights[j] for j in indices]
+
+            batch_data.append(context_units)
+            batch_data_mask.append(context_mask)
+            batch_label.append(target)
+            batch_label_mask.append(target_mask)
+            batch_label_weight.append(target_weight)
+
+            if len(batch_data) == self.batch_size:
+                # reshape for broadcast_mul
+                batch_data_mask = np.reshape(
+                    batch_data_mask, (self.batch_size, self.num_label - 1, self.swd.max_len, 1))
+                batch_label_mask = np.reshape(
+                    batch_label_mask, (self.batch_size, self.num_label, self.swd.max_len, 1))
+                data_all = [mx.nd.array(batch_data), mx.nd.array(batch_data_mask)]
+                label_all = [
+                    mx.nd.array(batch_label),
+                    mx.nd.array(batch_label_weight),
+                    mx.nd.array(batch_label_mask)
+                ]
+                data_names = ['data', 'mask']
+                label_names = ['label', 'label_weight', 'label_mask']
+                # clean up
+                batch_data = []
+                batch_data_mask = []
+                batch_label = []
+                batch_label_weight = []
+                batch_label_mask = []
+                yield SimpleBatch(data_names, data_all, label_names, label_all)
+
+    def reset(self):
+        pass
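`_get_subword_units` wraps a token in `#` boundary markers and slides a window of width `gram` over it; `_get_subword_representation` then pads every word to `max_len` units with the padding unit and builds a 0/1 weight mask, which `DataIterSubWords` reshapes for `broadcast_mul`. The sketch below restates that logic on one word. The names `subword_units` and `pad_units`, and giving the padding unit id 0, are assumptions made for this illustration.

```python
def subword_units(token, gram):
    """Character n-grams of '#token#', mirroring _get_subword_units."""
    if token == '</s>':  # the padding token is kept whole
        return [token]
    t = '#' + token + '#'
    return [t[i:i + gram] for i in range(len(t) - gram + 1)]


def pad_units(unit_ids, max_len, pad_id):
    """Pad a unit-id list to max_len and build the matching 0/1 mask."""
    mask = [1] * len(unit_ids) + [0] * (max_len - len(unit_ids))
    padded = unit_ids + [pad_id] * (max_len - len(unit_ids))
    return padded, mask


units = subword_units('cat', gram=3)  # ['#ca', 'cat', 'at#']
units_vocab = {'</s>': 0}             # assumption: padding unit gets id 0
for u in units:
    units_vocab.setdefault(u, len(units_vocab))
ids, mask = pad_units([units_vocab[u] for u in units], max_len=5,
                      pad_id=units_vocab['</s>'])
print(ids, mask)  # [1, 2, 3, 0, 0] [1, 1, 1, 0, 0]
```

The mask lets padded positions contribute nothing once unit embeddings are multiplied by it, so words of different lengths can share one fixed-size tensor.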
diff --git a/example/nce-loss/toy_nce.py b/example/nce-loss/toy_nce.py
index 39da7c7790..d71aab925e 100644
--- a/example/nce-loss/toy_nce.py
+++ b/example/nce-loss/toy_nce.py
@@ -15,90 +15,31 @@
 # specific language governing permissions and limitations
 # under the License.
 
-# pylint:skip-file
+# pylint: disable=missing-docstring
+from __future__ import print_function
+
 import logging
-import sys, random, time
-sys.path.insert(0, "../../python")
+
 import mxnet as mx
-import numpy as np
-from collections import namedtuple
-from nce import *
+from nce import nce_loss, NceAccuracy
+from random_data import DataIterNce
+
 
-def get_net(vocab_size, num_label):
+def get_net(num_vocab):
     data = mx.sym.Variable('data')
     label = mx.sym.Variable('label')
     label_weight = mx.sym.Variable('label_weight')
     embed_weight = mx.sym.Variable('embed_weight')
-    pred = mx.sym.FullyConnected(data = data, num_hidden = 100)
-    ret = nce_loss(data = pred,
-                    label = label,
-                    label_weight = label_weight,
-                    embed_weight = embed_weight,
-                    vocab_size = vocab_size,
-                    num_hidden = 100,
-                    num_label = num_label)
+    pred = mx.sym.FullyConnected(data=data, num_hidden=100)
+    ret = nce_loss(
+        data=pred,
+        label=label,
+        label_weight=label_weight,
+        embed_weight=embed_weight,
+        vocab_size=num_vocab,
+        num_hidden=100)
     return ret
 
-class SimpleBatch(object):
-    def __init__(self, data_names, data, label_names, label):
-        self.data = data
-        self.label = label
-        self.data_names = data_names
-        self.label_names = label_names
-
-    @property
-    def provide_data(self):
-        return [(n, x.shape) for n, x in zip(self.data_names, self.data)]
-
-    @property
-    def provide_label(self):
-        return [(n, x.shape) for n, x in zip(self.label_names, self.label)]
-
-
-class DataIter(mx.io.DataIter):
-    def __init__(self, count, batch_size, vocab_size, num_label, feature_size):
-        super(DataIter, self).__init__()
-        self.batch_size = batch_size
-        self.count = count
-        self.vocab_size = vocab_size
-        self.num_label = num_label
-        self.feature_size = feature_size
-        self.provide_data = [('data', (batch_size, feature_size))]
-        self.provide_label = [('label', (self.batch_size, num_label)),
-                              ('label_weight', (self.batch_size, num_label))]
-
-    def mock_sample(self):
-        ret = np.zeros(self.feature_size)
-        rn = set()
-        while len(rn) < 3:
-            rn.add(random.randint(0, self.feature_size - 1))
-        s = 0
-        for k in rn:
-            ret[k] = 1.0
-            s *= self.feature_size
-            s += k
-        la = [s % self.vocab_size] +\
-             [random.randint(0, self.vocab_size - 1) for _ in range(self.num_label - 1)]
-        return ret, la
-
-    def __iter__(self):
-        for _ in range(self.count / self.batch_size):
-            data = []
-            label = []
-            label_weight = []
-            for i in range(self.batch_size):
-                d, l = self.mock_sample()
-                data.append(d)
-                label.append(l)
-                label_weight.append([1.0] + [0.0 for _ in range(self.num_label - 1)])
-            data_all = [mx.nd.array(data)]
-            label_all = [mx.nd.array(label), mx.nd.array(label_weight)]
-            data_names = ['data']
-            label_names = ['label', 'label_weight']
-            yield SimpleBatch(data_names, data_all, label_names, label_all)
-
-    def reset(self):
-        pass
 
 if __name__ == '__main__':
     head = '%(asctime)-15s %(message)s'
@@ -109,21 +50,25 @@ def reset(self):
     feature_size = 100
     num_label = 6
 
-    data_train = DataIter(100000, batch_size, vocab_size, num_label, feature_size)
-    data_test = DataIter(1000, batch_size, vocab_size, num_label, feature_size)
+    data_train = DataIterNce(100000, batch_size, vocab_size, num_label, feature_size)
+    data_test = DataIterNce(1000, batch_size, vocab_size, num_label, feature_size)
 
-    network = get_net(vocab_size, num_label)
-    devs = [mx.cpu()]
-    model = mx.model.FeedForward(ctx = devs,
-                                 symbol = network,
-                                 num_epoch = 20,
-                                 learning_rate = 0.03,
-                                 momentum = 0.9,
-                                 wd = 0.00001,
-                                 initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))
+    network = get_net(vocab_size)
+    model = mx.mod.Module(
+        symbol=network,
+        data_names=[x[0] for x in data_train.provide_data],
+        label_names=[y[0] for y in data_train.provide_label],
+        context=[mx.cpu()]
+    )
 
     metric = NceAccuracy()
-    model.fit(X = data_train, eval_data = data_test,
-              eval_metric = metric,
-              batch_end_callback = mx.callback.Speedometer(batch_size, 50),)
-
+    model.fit(
+        train_data=data_train,
+        eval_data=data_test,
+        num_epoch=20,
+        optimizer='sgd',
+        optimizer_params={'learning_rate': 0.03, 'momentum': 0.9, 'wd': 0.00001},
+        initializer=mx.init.Xavier(factor_type='in', magnitude=2.34),
+        eval_metric=metric,
+        batch_end_callback=mx.callback.Speedometer(batch_size, 50)
+    )
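`toy_nce.py` passes `nce_loss` a `label` matrix of candidate ids and a `label_weight` matrix of 0/1 targets. Conceptually, NCE-style training scores each candidate by the dot product of the hidden vector with that candidate's embedding and applies a binary logistic loss against its weight. The plain-Python sketch below illustrates that idea only; it assumes `nce.py` follows this structure and is not the helper's actual code. The function names are hypothetical.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def nce_logistic_loss(hidden, label_ids, label_weight, embed):
    """Sum of binary logistic losses over one true and several noise candidates."""
    loss = 0.0
    for cid, target in zip(label_ids, label_weight):
        # Score a candidate by the dot product of the hidden vector and its embedding.
        score = sum(h * e for h, e in zip(hidden, embed[cid]))
        # Equivalent to -[t*log(sigmoid(s)) + (1 - t)*log(1 - sigmoid(s))].
        loss += target * softplus(-score) + (1.0 - target) * softplus(score)
    return loss


embed = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [-1.0, 0.0]}
hidden = [2.0, 0.0]
good = nce_logistic_loss(hidden, [0, 2], [1.0, 0.0], embed)  # true class scores high
bad = nce_logistic_loss(hidden, [2, 0], [1.0, 0.0], embed)   # true class scores low
print(good, bad)
```

Only `num_label` embedding rows are touched per sample instead of all `vocab_size` rows, which is where the speedup over a full softmax comes from.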
diff --git a/example/nce-loss/toy_softmax.py b/example/nce-loss/toy_softmax.py
index ff6ff4327c..566296ca8d 100644
--- a/example/nce-loss/toy_softmax.py
+++ b/example/nce-loss/toy_softmax.py
@@ -15,79 +15,23 @@
 # specific language governing permissions and limitations
 # under the License.
 
-# pylint:skip-file
+# pylint: disable=missing-docstring
+from __future__ import print_function
+
 import logging
-import sys, random, time
-sys.path.insert(0, "../../python")
+
 import mxnet as mx
-import numpy as np
-from collections import namedtuple
+from random_data import DataIterSoftmax
 
-ToyModel = namedtuple("ToyModel", ["ex", "symbol", "param_blocks"])
 
-def get_net(vocab_size):
+def get_net(num_labels):
     data = mx.sym.Variable('data')
     label = mx.sym.Variable('label')
-    pred = mx.sym.FullyConnected(data = data, num_hidden = 100)
-    pred = mx.sym.FullyConnected(data = pred, num_hidden = vocab_size)
-    sm = mx.sym.SoftmaxOutput(data = pred, label = label)
+    pred = mx.sym.FullyConnected(data=data, num_hidden=100)
+    pred = mx.sym.FullyConnected(data=pred, num_hidden=num_labels)
+    sm = mx.sym.SoftmaxOutput(data=pred, label=label)
     return sm
 
-class SimpleBatch(object):
-    def __init__(self, data_names, data, label_names, label):
-        self.data = data
-        self.label = label
-        self.data_names = data_names
-        self.label_names = label_names
-
-    @property
-    def provide_data(self):
-        return [(n, x.shape) for n, x in zip(self.data_names, self.data)]
-
-    @property
-    def provide_label(self):
-        return [(n, x.shape) for n, x in zip(self.label_names, self.label)]
-
-
-class DataIter(mx.io.DataIter):
-    def __init__(self, count, batch_size, vocab_size, num_label, feature_size):
-        super(DataIter, self).__init__()
-        self.batch_size = batch_size
-        self.count = count
-        self.vocab_size = vocab_size
-        self.num_label = num_label
-        self.feature_size = feature_size
-        self.provide_data = [('data', (batch_size, feature_size))]
-        self.provide_label = [('label', (self.batch_size,))]
-
-    def mock_sample(self):
-        ret = np.zeros(self.feature_size)
-        rn = set()
-        while len(rn) < 3:
-            rn.add(random.randint(0, self.feature_size - 1))
-        s = 0
-        for k in rn:
-            ret[k] = 1.0
-            s *= self.feature_size
-            s += k
-        return ret, s % self.vocab_size
-
-    def __iter__(self):
-        for _ in range(self.count / self.batch_size):
-            data = []
-            label = []
-            for i in range(self.batch_size):
-                d, l = self.mock_sample()
-                data.append(d)
-                label.append(l)
-            data_all = [mx.nd.array(data)]
-            label_all = [mx.nd.array(label)]
-            data_names = ['data']
-            label_names = ['label']
-            yield SimpleBatch(data_names, data_all, label_names, label_all)
-
-    def reset(self):
-        pass
 
 if __name__ == '__main__':
     head = '%(asctime)-15s %(message)s'
@@ -98,19 +42,24 @@ def reset(self):
     feature_size = 100
     num_label = 6
 
-    data_train = DataIter(100000, batch_size, vocab_size, num_label, feature_size)
-    data_test = DataIter(1000, batch_size, vocab_size, num_label, feature_size)
+    data_train = DataIterSoftmax(100000, batch_size, vocab_size, num_label, feature_size)
+    data_test = DataIterSoftmax(1000, batch_size, vocab_size, num_label, feature_size)
 
     network = get_net(vocab_size)
-    devs = mx.cpu()
-    model = mx.model.FeedForward(ctx = devs,
-                                 symbol = network,
-                                 num_epoch = 20,
-                                 learning_rate = 0.03,
-                                 momentum = 0.9,
-                                 wd = 0.0000,
-                                 initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))
-
-    model.fit(X = data_train, eval_data = data_test,
-              batch_end_callback = mx.callback.Speedometer(batch_size, 50),)
 
+    model = mx.mod.Module(
+        symbol=network,
+        data_names=[x[0] for x in data_train.provide_data],
+        label_names=[y[0] for y in data_train.provide_label],
+        context=[mx.cpu()]
+    )
+
+    model.fit(
+        train_data=data_train,
+        eval_data=data_test,
+        num_epoch=20,
+        optimizer='sgd',
+        optimizer_params={'learning_rate': 0.03, 'momentum': 0.9, 'wd': 0.0000},
+        initializer=mx.init.Xavier(factor_type='in', magnitude=2.34),
+        batch_end_callback=mx.callback.Speedometer(batch_size, 50)
+    )
diff --git a/example/nce-loss/wordvec.py b/example/nce-loss/wordvec.py
index 887d586ff3..783cf22727 100644
--- a/example/nce-loss/wordvec.py
+++ b/example/nce-loss/wordvec.py
@@ -15,154 +15,53 @@
 # specific language governing permissions and limitations
 # under the License.
 
-# pylint:skip-file
+# pylint: disable=missing-docstring, deprecated-module
 from __future__ import print_function
+
 import logging
-import sys, random, time, math
-sys.path.insert(0, "../../python")
-import mxnet as mx
-import numpy as np
-from collections import namedtuple
-from nce import *
-from operator import itemgetter
 from optparse import OptionParser
 
-def get_net(vocab_size, num_input, num_label):
-    data = mx.sym.Variable('data')
-    label = mx.sym.Variable('label')
-    label_weight = mx.sym.Variable('label_weight')
-    embed_weight = mx.sym.Variable('embed_weight')
-    data_embed = mx.sym.Embedding(data = data, input_dim = vocab_size,
-                                  weight = embed_weight,
-                                  output_dim = 100, name = 'data_embed')
-    datavec = mx.sym.SliceChannel(data = data_embed,
-                                     num_outputs = num_input,
-                                     squeeze_axis = 1, name = 'data_slice')
-    pred = datavec[0]
-    for i in range(1, num_input):
-        pred = pred + datavec[i]
-    return nce_loss(data = pred,
-                    label = label,
-                    label_weight = label_weight,
-                    embed_weight = embed_weight,
-                    vocab_size = vocab_size,
-                    num_hidden = 100,
-                    num_label = num_label)
-
-def load_data(name):
-    buf = open(name).read()
-    tks = buf.split(' ')
-    vocab = {}
-    freq = [0]
-    data = []
-    for tk in tks:
-        if len(tk) == 0:
-            continue
-        if tk not in vocab:
-            vocab[tk] = len(vocab) + 1
-            freq.append(0)
-        wid = vocab[tk]
-        data.append(wid)
-        freq[wid] += 1
-    negative = []
-    for i, v in enumerate(freq):
-        if i == 0 or v < 5:
-            continue
-        v = int(math.pow(v * 1.0, 0.75))
-        negative += [i for _ in range(v)]
-    return data, negative, vocab, freq
-
-class SimpleBatch(object):
-    def __init__(self, data_names, data, label_names, label):
-        self.data = data
-        self.label = label
-        self.data_names = data_names
-        self.label_names = label_names
-
-    @property
-    def provide_data(self):
-        return [(n, x.shape) for n, x in zip(self.data_names, self.data)]
-
-    @property
-    def provide_label(self):
-        return [(n, x.shape) for n, x in zip(self.label_names, self.label)]
-
-
-class DataIter(mx.io.DataIter):
-    def __init__(self, name, batch_size, num_label):
-        super(DataIter, self).__init__()
-        self.batch_size = batch_size
-        self.data, self.negative, self.vocab, self.freq = load_data(name)
-        self.vocab_size = 1 + len(self.vocab)
-        print(self.vocab_size)
-        self.num_label = num_label
-        self.provide_data = [('data', (batch_size, num_label - 1))]
-        self.provide_label = [('label', (self.batch_size, num_label)),
-                              ('label_weight', (self.batch_size, num_label))]
-
-    def sample_ne(self):
-        return self.negative[random.randint(0, len(self.negative) - 1)]
-
-    def __iter__(self):
-        print('begin')
-        batch_data = []
-        batch_label = []
-        batch_label_weight = []
-        start = random.randint(0, self.num_label - 1)
-        for i in range(start, len(self.data) - self.num_label - start, self.num_label):
-            context = self.data[i: i + self.num_label / 2] \
-                      + self.data[i + 1 + self.num_label / 2: i + self.num_label]
-            target_word = self.data[i + self.num_label / 2]
-            if self.freq[target_word] < 5:
-                continue
-            target = [target_word] \
-                     + [self.sample_ne() for _ in range(self.num_label - 1)]
-            target_weight = [1.0] + [0.0 for _ in range(self.num_label - 1)]
-            batch_data.append(context)
-            batch_label.append(target)
-            batch_label_weight.append(target_weight)
-            if len(batch_data) == self.batch_size:
-                data_all = [mx.nd.array(batch_data)]
-                label_all = [mx.nd.array(batch_label), mx.nd.array(batch_label_weight)]
-                data_names = ['data']
-                label_names = ['label', 'label_weight']
-                batch_data = []
-                batch_label = []
-                batch_label_weight = []
-                yield SimpleBatch(data_names, data_all, label_names, label_all)
+import mxnet as mx
+from nce import NceAuc
+from text8_data import DataIterWords
+from wordvec_net import get_word_net
 
-    def reset(self):
-        pass
 
 if __name__ == '__main__':
     head = '%(asctime)-15s %(message)s'
     logging.basicConfig(level=logging.DEBUG, format=head)
 
     parser = OptionParser()
-    parser.add_option("-g", "--gpu", action = "store_true", dest = "gpu", default = False,
-                      help = "use gpu")
+    parser.add_option("-g", "--gpu", action="store_true", dest="gpu", default=False,
+                      help="use gpu")
+    options, args = parser.parse_args()
+
     batch_size = 256
     num_label = 5
 
-    data_train = DataIter("./data/text8", batch_size, num_label)
+    data_train = DataIterWords("./data/text8", batch_size, num_label)
 
-    network = get_net(data_train.vocab_size, num_label - 1, num_label)
+    network = get_word_net(data_train.vocab_size, num_label - 1)
 
-    options, args = parser.parse_args()
     devs = mx.cpu()
-    if options.gpu == True:
+    if options.gpu:
         devs = mx.gpu()
-    model = mx.model.FeedForward(ctx = devs,
-                                 symbol = network,
-                                 num_epoch = 20,
-                                 learning_rate = 0.3,
-                                 momentum = 0.9,
-                                 wd = 0.0000,
-                                 initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))
 
+    model = mx.mod.Module(
+        symbol=network,
+        data_names=[x[0] for x in data_train.provide_data],
+        label_names=[y[0] for y in data_train.provide_label],
+        context=[devs]
+    )
 
+    print("Training on {}".format("GPU" if options.gpu else "CPU"))
     metric = NceAuc()
-    model.fit(X = data_train,
-              eval_metric = metric,
-              batch_end_callback = mx.callback.Speedometer(batch_size, 50),)
-
+    model.fit(
+        train_data=data_train,
+        num_epoch=20,
+        optimizer='sgd',
+        optimizer_params={'learning_rate': 0.3, 'momentum': 0.9, 'wd': 0.0000},
+        initializer=mx.init.Xavier(factor_type='in', magnitude=2.34),
+        eval_metric=metric,
+        batch_end_callback=mx.callback.Speedometer(batch_size, 50)
+    )
diff --git a/example/nce-loss/wordvec_net.py b/example/nce-loss/wordvec_net.py
new file mode 100644
index 0000000000..4f52a0b679
--- /dev/null
+++ b/example/nce-loss/wordvec_net.py
@@ -0,0 +1,81 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=missing-docstring
+from __future__ import print_function
+
+import mxnet as mx
+from nce import nce_loss, nce_loss_subwords
+
+
+def get_word_net(vocab_size, num_input):
+    data = mx.sym.Variable('data')
+    label = mx.sym.Variable('label')
+    label_weight = mx.sym.Variable('label_weight')
+    embed_weight = mx.sym.Variable('embed_weight')
+    data_embed = mx.sym.Embedding(data=data, input_dim=vocab_size,
+                                  weight=embed_weight,
+                                  output_dim=100, name='data_embed')
+    datavec = mx.sym.SliceChannel(data=data_embed,
+                                  num_outputs=num_input,
+                                  squeeze_axis=1, name='data_slice')
+    pred = datavec[0]
+    for i in range(1, num_input):
+        pred = pred + datavec[i]
+    return nce_loss(data=pred,
+                    label=label,
+                    label_weight=label_weight,
+                    embed_weight=embed_weight,
+                    vocab_size=vocab_size,
+                    num_hidden=100)
+
+
+def get_subword_net(vocab_size, num_input, embedding_size):
+    data = mx.sym.Variable('data')
+    mask = mx.sym.Variable('mask')  # use mask to handle variable-length input.
+    label = mx.sym.Variable('label')
+    label_mask = mx.sym.Variable('label_mask')
+    label_weight = mx.sym.Variable('label_weight')
+    embed_weight = mx.sym.Variable('embed_weight')
+
+    # Get embedding for one-hot input.
+    # get sub-word units input.
+    unit_embed = mx.sym.Embedding(data=data, input_dim=vocab_size,
+                                  weight=embed_weight,
+                                  output_dim=embedding_size)
+
+    # mask embedding_output to get summation of sub-word units'embedding.
+    unit_embed = mx.sym.broadcast_mul(lhs=unit_embed, rhs=mask, name='data_units_embed')
+
+    # sum over all these words then you get word-embedding.
+    data_embed = mx.sym.sum(unit_embed, axis=2)
+
+    # Slice input equally along specified axis.
+    datavec = mx.sym.SliceChannel(data=data_embed,
+                                  num_outputs=num_input,
+                                  squeeze_axis=1, name='data_slice')
+    pred = datavec[0]
+    for i in range(1, num_input):
+        pred = pred + datavec[i]
+
+    return nce_loss_subwords(data=pred,
+                             label=label,
+                             label_mask=label_mask,
+                             label_weight=label_weight,
+                             embed_weight=embed_weight,
+                             vocab_size=vocab_size,
+                             num_hidden=embedding_size)
diff --git a/example/nce-loss/wordvec_subwords.py b/example/nce-loss/wordvec_subwords.py
index c8d46a1aeb..677f740421 100644
--- a/example/nce-loss/wordvec_subwords.py
+++ b/example/nce-loss/wordvec_subwords.py
@@ -15,22 +15,21 @@
 # specific language governing permissions and limitations
 # under the License.
 
-# pylint:skip-file
+# pylint: disable=missing-docstring, deprecated-module
+from __future__ import print_function
+
 import logging
-import sys, random, time, math
-import mxnet as mx
-import numpy as np
-from nce import *
-from operator import itemgetter
 from optparse import OptionParser
-from collections import Counter
 
-import logging
-head = head = '%(asctime)-15s %(message)s'
-logging.basicConfig(level=logging.INFO, format=head)
+import mxnet as mx
+from nce import NceAuc
+from text8_data import DataIterSubWords
+from wordvec_net import get_subword_net
+
 
+head = '%(asctime)-15s %(message)s'
+logging.basicConfig(level=logging.INFO, format=head)
 
-# ----------------------------------------------------------------------------------------
 EMBEDDING_SIZE = 100
 BATCH_SIZE = 256
 NUM_LABEL = 5
@@ -41,229 +40,6 @@
 PADDING_CHAR = '</s>'
 
 
-# ----------------------------------------------------------------------------------------
-def get_net(vocab_size, num_input, num_label):
-    data = mx.sym.Variable('data')
-    mask = mx.sym.Variable('mask')  # use mask to handle variable-length input.
-    label = mx.sym.Variable('label')
-    label_mask = mx.sym.Variable('label_mask')
-    label_weight = mx.sym.Variable('label_weight')
-    embed_weight = mx.sym.Variable('embed_weight')
-
-    # Get embedding for one-hot input.
-    # get sub-word units input.
-    unit_embed = mx.sym.Embedding(data=data, input_dim=vocab_size,
-                                  weight=embed_weight,
-                                  output_dim=EMBEDDING_SIZE)
-
-    # mask embedding_output to get summation of sub-word units'embedding.
-    unit_embed = mx.sym.broadcast_mul(lhs=unit_embed, rhs=mask, name='data_units_embed')
-
-    # sum over all these words then you get word-embedding.
-    data_embed = mx.sym.sum(unit_embed, axis=2)
-
-    # Slice input equally along specified axis.
-    datavec = mx.sym.SliceChannel(data=data_embed,
-                                  num_outputs=num_input,
-                                  squeeze_axis=1, name='data_slice')
-    pred = datavec[0]
-    for i in range(1, num_input):
-        pred = pred + datavec[i]
-
-    return nce_loss_subwords(data=pred,
-                             label=label,
-                             label_mask=label_mask,
-                             label_weight=label_weight,
-                             embed_weight=embed_weight,
-                             vocab_size=vocab_size,
-                             num_hidden=EMBEDDING_SIZE,
-                             num_label=num_label)
-
-
-def get_subword_units(token, gram=GRAMS):
-    """Return subword-units presentation, given a word/token.
-    """
-    if token == '</s>':  # special token for padding purpose.
-        return [token]
-    t = '#' + token + '#'
-    return [t[i:i + gram] for i in range(0, len(t) - gram + 1)]
-
-
-def get_subword_representation(wid, vocab_inv, units_vocab, max_len):
-    token = vocab_inv[wid]
-    units = [units_vocab[unit] for unit in get_subword_units(token)]
-    weights = [1] * len(units) + [0] * (max_len - len(units))
-    units = units + [units_vocab[PADDING_CHAR]] * (max_len - len(units))
-    return units, weights
-
-
-def prepare_subword_units(tks):
-    # statistics on units
-    units_vocab = {PADDING_CHAR:1}
-    max_len = 0
-    unit_set = set()
-    logging.info('grams: %d', GRAMS)
-    logging.info('counting max len...')
-    for tk in tks:
-        res = get_subword_units(tk)
-        unit_set.update(i for i in res)
-        if max_len < len(res):
-            max_len = len(res)
-    logging.info('preparing units vocab...')
-    for unit in unit_set:
-        if len(unit) == 0:
-            continue
-        if unit not in units_vocab:
-            units_vocab[unit] = len(units_vocab)
-        uid = units_vocab[unit]
-    return units_vocab, max_len
-
-
-def load_data_as_subword_units(name):
-    tks = []
-    fread = open(name, 'r')
-    logging.info('reading corpus from file...')
-    for line in fread:
-        line = line.strip().decode('utf-8')
-        tks.extend(line.split(' '))
-
-    logging.info('Total tokens: %d', len(tks))
-
-    tks = [tk for tk in tks if len(tk) <= MAX_SUBWORDS]
-    c = Counter(tks)
-
-    logging.info('Total vocab: %d', len(c))
-
-    vocab = {}
-    vocab_inv = {}
-    freq = [0]
-    data = []
-    for tk in tks:
-        if len(tk) == 0:
-            continue
-        if tk not in vocab:
-            vocab[tk] = len(vocab)
-            freq.append(0)
-        wid = vocab[tk]
-        vocab_inv[wid] = tk
-        data.append(wid)
-        freq[wid] += 1
-
-    negative = []
-    for i, v in enumerate(freq):
-        if i == 0 or v < MIN_COUNT:
-            continue
-        v = int(math.pow(v * 1.0, 0.75))  # sample negative w.r.t. its frequency
-        negative += [i for _ in range(v)]
-
-    logging.info('counting subword units...')
-    units_vocab, max_len = prepare_subword_units(tks)
-    logging.info('vocabulary size: %d', len(vocab))
-    logging.info('subword unit size: %d', len(units_vocab))
-
-    logging.info('generating input data...')
-    units = []
-    weights = []
-    for wid in data:
-        word_units, weight = get_subword_representation(wid, vocab_inv, units_vocab, max_len)
-        units.append(word_units)
-        weights.append(weight)
-
-    negative_units = []
-    negative_weights = []
-    for wid in negative:
-        word_units, weight = get_subword_representation(wid, vocab_inv, units_vocab, max_len)
-        negative_units.append(word_units)
-        negative_weights.append(weight)
-
-    return data, units, weights, negative_units, negative_weights, vocab, units_vocab, freq, max_len
-
-
-class SimpleBatch(object):
-    def __init__(self, data_names, data, label_names, label):
-        self.data = data
-        self.label = label
-        self.data_names = data_names
-        self.label_names = label_names
-
-    @property
-    def provide_data(self):
-        return [(n, x.shape) for n, x in zip(self.data_names, self.data)]
-
-    @property
-    def provide_label(self):
-        return [(n, x.shape) for n, x in zip(self.label_names, self.label)]
-
-
-class DataIter(mx.io.DataIter):
-    def __init__(self, fname, batch_size, num_label):
-        super(DataIter, self).__init__()
-        self.batch_size = batch_size
-        self.data, self.units, self.weights, self.negative_units, self.negative_weights, \
-        self.vocab, self.units_vocab, self.freq, self.max_len = load_data_as_subword_units(fname)
-        self.vocab_size = len(self.units_vocab)
-        self.num_label = num_label
-        self.provide_data = [('data', (batch_size, num_label - 1, self.max_len)),
-                             ('mask', (batch_size, num_label - 1, self.max_len, 1))]
-        self.provide_label = [('label', (self.batch_size, num_label, self.max_len)),
-                              ('label_weight', (self.batch_size, num_label)),
-                              ('label_mask', (self.batch_size, num_label, self.max_len, 1))]
-
-    def sample_ne(self):
-        # a negative sample.
-        return self.negative_units[random.randint(0, len(self.negative_units) - 1)]
-
-    def sample_ne_indices(self):
-        return [random.randint(0, len(self.negative_units) - 1) for _ in range(self.num_label - 1)]
-
-    def __iter__(self):
-        logging.info('DataIter start.')
-        batch_data = []
-        batch_data_mask = []
-        batch_label = []
-        batch_label_mask = []
-        batch_label_weight = []
-        start = random.randint(0, self.num_label - 1)
-        for i in range(start, len(self.units) - self.num_label - start, self.num_label):
-            context_units = self.units[i: i + self.num_label / 2] + \
-                            self.units[i + 1 + self.num_label / 2: i + self.num_label]
-            context_mask = self.weights[i: i + self.num_label / 2] + \
-                           self.weights[i + 1 + self.num_label / 2: i + self.num_label]
-            target_units = self.units[i + self.num_label / 2]
-            target_word = self.data[i + self.num_label / 2]
-            if self.freq[target_word] < MIN_COUNT:
-                continue
-            indices = self.sample_ne_indices()
-            target = [target_units] + [self.negative_units[i] for i in indices]
-            target_weight = [1.0] + [0.0 for _ in range(self.num_label - 1)]
-            target_mask = [self.weights[i + self.num_label / 2]] + [self.negative_weights[i] for i in indices]
-
-            batch_data.append(context_units)
-            batch_data_mask.append(context_mask)
-            batch_label.append(target)
-            batch_label_mask.append(target_mask)
-            batch_label_weight.append(target_weight)
-
-            if len(batch_data) == self.batch_size:
-                # reshape for broadcast_mul
-                batch_data_mask = np.reshape(batch_data_mask, (batch_size, num_label - 1, self.max_len, 1))
-                batch_label_mask = np.reshape(batch_label_mask, (batch_size, num_label, self.max_len, 1))
-                data_all = [mx.nd.array(batch_data), mx.nd.array(batch_data_mask)]
-                label_all = [mx.nd.array(batch_label), mx.nd.array(batch_label_weight), mx.nd.array(batch_label_mask)]
-                data_names = ['data', 'mask']
-                label_names = ['label', 'label_weight', 'label_mask']
-                # clean up
-                batch_data = []
-                batch_data_mask = []
-                batch_label = []
-                batch_label_weight = []
-                batch_label_mask = []
-                yield SimpleBatch(data_names, data_all, label_names, label_all)
-
-    def reset(self):
-        pass
-
-
 if __name__ == '__main__':
     head = '%(asctime)-15s %(message)s'
     logging.basicConfig(level=logging.DEBUG, format=head)
@@ -271,28 +47,42 @@ def reset(self):
     parser = OptionParser()
     parser.add_option("-g", "--gpu", action="store_true", dest="gpu", default=False,
                       help="use gpu")
+    options, args = parser.parse_args()
 
     batch_size = BATCH_SIZE
     num_label = NUM_LABEL
+    embedding_size = EMBEDDING_SIZE
 
-    data_train = DataIter("./data/text8", batch_size, num_label)
+    data_train = DataIterSubWords(
+        "./data/text8",
+        batch_size=batch_size,
+        num_label=num_label,
+        min_count=MIN_COUNT,
+        gram=GRAMS,
+        max_subwords=MAX_SUBWORDS,
+        padding_char=PADDING_CHAR)
 
-    network = get_net(data_train.vocab_size, num_label - 1, num_label)
+    network = get_subword_net(data_train.vocab_size, num_label - 1, embedding_size)
 
-    options, args = parser.parse_args()
-    # devs = mx.cpu()
-    devs = [mx.cpu(i) for i in range(4)]
-    if options.gpu == True:
+    devs = mx.cpu()
+    if options.gpu:
         devs = mx.gpu()
-    model = mx.model.FeedForward(ctx=devs,
-                                 symbol=network,
-                                 num_epoch=NUM_EPOCH,
-                                 learning_rate=0.3,
-                                 momentum=0.9,
-                                 wd=0.0000,
-                                 initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))
 
+    model = mx.mod.Module(
+        symbol=network,
+        data_names=[x[0] for x in data_train.provide_data],
+        label_names=[y[0] for y in data_train.provide_label],
+        context=[devs]
+    )
+
+    print("Training on {}".format("GPU" if options.gpu else "CPU"))
     metric = NceAuc()
-    model.fit(X=data_train,
-              eval_metric=metric,
-              batch_end_callback=mx.callback.Speedometer(batch_size, 50), )
+    model.fit(
+        train_data=data_train,
+        num_epoch=NUM_EPOCH,
+        optimizer='sgd',
+        optimizer_params={'learning_rate': 0.3, 'momentum': 0.9, 'wd': 0.0000},
+        initializer=mx.init.Xavier(factor_type='in', magnitude=2.34),
+        eval_metric=metric,
+        batch_end_callback=mx.callback.Speedometer(batch_size, 50)
+    )
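
Both removed `__iter__` implementations walk the corpus in steps of `num_label`. In each window, the middle token is the positive target and the other `num_label - 1` tokens are the context. They compute the middle offset as `self.num_label / 2`, which relies on Python 2 integer division. Under Python 3 that expression is a float, and slicing with it raises a `TypeError`. The following is a minimal Python 3 sketch of the same windowing using `//`. The function name `context_windows` is hypothetical and is not taken from the example's `text8_data` module.

```python
def context_windows(data, num_label, start=0):
    """Yield (context, target) pairs in the style of the removed DataIter.__iter__.

    Each window spans num_label tokens: the middle one is the target and
    the remaining num_label - 1 tokens form the context.
    """
    half = num_label // 2  # integer division; `/` would give a float in Python 3
    # Same loop bounds as the original code.
    for i in range(start, len(data) - num_label - start, num_label):
        context = data[i:i + half] + data[i + 1 + half:i + num_label]
        target = data[i + half]
        yield context, target


tokens = list(range(12))
pairs = list(context_windows(tokens, num_label=5))
# pairs == [([0, 1, 3, 4], 2), ([5, 6, 8, 9], 7)]
```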

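The removed `load_data_as_subword_units` builds its negative-sampling table by repeating each word id `int(freq ** 0.75)` times. It skips index 0 and any word below `MIN_COUNT`. Drawing uniformly from that table therefore samples words in proportion to their smoothed unigram frequency. Each training row in `__iter__` pairs one positive target with `num_label - 1` negatives, and `target_weight` is `[1.0]` followed by zeros. The sketch below reproduces that logic with plain lists and a seeded `random.Random`. The helper names and the word counts are hypothetical.

```python
import math
import random


def build_negative_table(freq, min_count):
    """Repeat each word id int(count ** 0.75) times, as the removed code does.

    Index 0 is skipped, mirroring the original loop's `if i == 0 ...` check.
    """
    table = []
    for wid, count in enumerate(freq):
        if wid == 0 or count < min_count:
            continue
        table.extend([wid] * int(math.pow(count, 0.75)))
    return table


def make_row(target, table, num_label, rng):
    """One positive target plus num_label - 1 sampled negatives, with NCE weights."""
    labels = [target] + [rng.choice(table) for _ in range(num_label - 1)]
    weights = [1.0] + [0.0] * (num_label - 1)
    return labels, weights


freq = [0, 10, 100, 3]  # hypothetical counts; id 3 falls below min_count=5
table = build_negative_table(freq, min_count=5)  # five 1s followed by thirty-one 2s
labels, weights = make_row(2, table, num_label=5, rng=random.Random(0))
```

Like the original code, this sketch does not stop a negative from coinciding with the target.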

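The removed `get_subword_units` wraps a token in `#` boundary markers and splits it into overlapping character n-grams. `get_subword_representation` then pads the resulting unit ids to `max_len` with the id of `PADDING_CHAR` and builds a matching 1/0 mask. The value of `GRAMS` is not visible in this hunk, so the sketch assumes `gram=3`. The unit vocabulary and the name `pad_units` are hypothetical.

```python
PADDING_CHAR = '</s>'


def get_subword_units(token, gram=3):
    """Character n-grams of '#token#' (gram=3 is an assumption)."""
    if token == PADDING_CHAR:  # the padding token stays a single unit
        return [token]
    t = '#' + token + '#'
    return [t[i:i + gram] for i in range(0, len(t) - gram + 1)]


def pad_units(token, units_vocab, max_len, gram=3):
    """Map a token to unit ids padded to max_len, plus a 1/0 mask over real units."""
    units = [units_vocab[u] for u in get_subword_units(token, gram)]
    mask = [1] * len(units) + [0] * (max_len - len(units))
    units = units + [units_vocab[PADDING_CHAR]] * (max_len - len(units))
    return units, mask


vocab = {PADDING_CHAR: 0, '#ca': 1, 'cat': 2, 'at#': 3}  # hypothetical unit ids
units, mask = pad_units('cat', vocab, max_len=5)
# units == [1, 2, 3, 0, 0], mask == [1, 1, 1, 0, 0]
```

The mask is what lets padded positions contribute nothing once the unit embeddings are summed.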
 
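In `get_subword_net`, each context word arrives as a padded row of sub-word unit ids. The steps are as follows:

1. The unit embeddings are multiplied by `mask`, whose trailing axis of size 1 broadcasts over the embedding dimension.
2. The result is summed over axis 2 to give one vector per word.
3. The `num_label - 1` context vectors are sliced apart and added together.

The NumPy sketch below reproduces the same shapes and arithmetic with made-up sizes and a random embedding table. It illustrates the computation only and is not the MXNet symbol graph.

```python
import numpy as np

batch, num_input, max_len, emb = 2, 4, 6, 8  # hypothetical sizes
vocab_size = 20

rng = np.random.default_rng(0)
embed_weight = rng.normal(size=(vocab_size, emb))
data = rng.integers(0, vocab_size, size=(batch, num_input, max_len))
mask = (rng.random((batch, num_input, max_len, 1)) > 0.3).astype(embed_weight.dtype)

unit_embed = embed_weight[data]      # Embedding lookup: (batch, num_input, max_len, emb)
unit_embed = unit_embed * mask       # broadcast_mul: zero out padded units
data_embed = unit_embed.sum(axis=2)  # sum units per word: (batch, num_input, emb)
pred = data_embed.sum(axis=1)        # SliceChannel + running sum over context: (batch, emb)
```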

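The `nce` module itself is not part of this section. The sketch below therefore assumes that `nce_loss` does two things:

- It scores each candidate in `label` by the dot product of its embedding (from the shared `embed_weight`) with the summed context vector.
- It feeds those scores to a logistic output with `label_weight` as the 0/1 target.

MXNet's `LogisticRegressionOutput` applies a sigmoid in the forward pass and back-propagates the gradient of binary cross-entropy. The NumPy function below computes that cross-entropy explicitly for illustration. The name `nce_style_loss` is hypothetical.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def nce_style_loss(pred, label, label_weight, embed_weight):
    """Hedged stand-in for nce_loss: logistic loss on dot-product scores.

    pred:         (batch, hidden) summed context vectors
    label:        (batch, num_label) candidate ids, target first
    label_weight: (batch, num_label) 1.0 for the target, 0.0 for negatives
    """
    label_embed = embed_weight[label]                        # (batch, num_label, hidden)
    scores = (pred[:, None, :] * label_embed).sum(axis=2)   # (batch, num_label)
    prob = sigmoid(scores)
    eps = 1e-12  # guard against log(0)
    bce = -(label_weight * np.log(prob + eps)
            + (1.0 - label_weight) * np.log(1.0 - prob + eps))
    return prob, bce.mean()


prob, loss = nce_style_loss(
    pred=np.array([[2.0, 0.0, 0.0]]),
    label=np.array([[0, 1, 2]]),
    label_weight=np.array([[1.0, 0.0, 0.0]]),
    embed_weight=np.eye(3),
)
```

Training pushes the target's probability toward 1 and the sampled negatives' toward 0. This avoids a softmax over the full vocabulary.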
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services
