singa-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wang...@apache.org
Subject [5/6] incubator-singa git commit: convert param for both v3 and v4
Date Thu, 13 Jul 2017 07:01:11 GMT
convert param for both v3 and v4


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/3f39cfa2
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/3f39cfa2
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/3f39cfa2

Branch: refs/heads/master
Commit: 3f39cfa246b477329ce349b8d2bcf10e5353bf37
Parents: 40124db
Author: wangwei <wangwei@comp.nus.edu.sg>
Authored: Fri Jul 7 13:20:03 2017 +0800
Committer: wangwei <wangwei@comp.nus.edu.sg>
Committed: Fri Jul 7 13:20:03 2017 +0800

----------------------------------------------------------------------
 examples/imagenet/inception/inception_v3.py | 456 +++++++++++++----------
 examples/imagenet/inception/inception_v4.py |  87 ++---
 2 files changed, 305 insertions(+), 238 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3f39cfa2/examples/imagenet/inception/inception_v3.py
----------------------------------------------------------------------
diff --git a/examples/imagenet/inception/inception_v3.py b/examples/imagenet/inception/inception_v3.py
index dc7b24b..421343b 100644
--- a/examples/imagenet/inception/inception_v3.py
+++ b/examples/imagenet/inception/inception_v3.py
@@ -28,52 +28,63 @@ from __future__ import division
 from __future__ import print_function
 
 from singa.layer import Conv2D, Activation, MaxPooling2D, AvgPooling2D,\
-        Split, Concat, Dropout, Flatten, Dense, BatchNormalization
+        Split, Concat, Dropout, Flatten, BatchNormalization
 
 from singa import net as ffnet
 
 ffnet.verbose = True
 
+
 def conv2d(net, name, nb_filter, k, s=1, border_mode='SAME', src=None):
     if type(k) is list:
         k = (k[0], k[1])
-    net.add(Conv2D(name, nb_filter, k, s, border_mode=border_mode, use_bias=False), src)
+    net.add(Conv2D(name, nb_filter, k, s, border_mode=border_mode,
+                   use_bias=False), src)
     net.add(BatchNormalization('%s/BatchNorm' % name))
     return net.add(Activation(name+'/relu'))
 
 
-def inception_v3_base(name, sample_shape, final_endpoint='Mixed_6e', aux_name=None, depth_multiplier=1, min_depth=16):
+def inception_v3_base(name, sample_shape, final_endpoint, aux_endpoint,
+                      depth_multiplier=1, min_depth=16):
     """Creates the Inception V3 network up to the given final endpoint.
 
     Args:
-        inputs: a 4-D tensor of size [batch_size, height, width, 3].
+        sample_shape: input image sample shape, 3d tuple
         final_endpoint: specifies the endpoint to construct the network up to.
+        aux_endpoint: for aux loss.
 
     Returns:
         logits: the logits outputs of the model.
         end_points: the set of end_points from the inception model.
 
     Raises:
-        ValueError: if final_endpoint is not set to one of the predefined values,
+        ValueError: if final_endpoint is not set to one of the predefined values
     """
-    endpoints = {}
-    def final_aux_check(block_name, net):
+    V3 = 'InceptionV3'
+    end_points = {}
+    net = ffnet.FeedForwardNet()
+
+    def final_aux_check(block_name):
         if block_name == final_endpoint:
-            return net, endpoints[block_name], endpoints
-        if block_name == aux_name:
-            endpoints[aux_name + '-aux'] = net.add(Split('%s-aux' % aux_name, 2))
+            return True
+        if block_name == aux_endpoint:
+            aux = aux_endpoint + '-aux'
+            end_points[aux] = net.add(Split(aux, 2))
+        return False
 
-    net = ffnet.FeedForwardNet()
-    depth = lambda d: max(int(d * depth_multiplier), min_depth)
-    V3 = 'InceptionV3'
+    def depth(d):
+        return max(int(d * depth_multiplier), min_depth)
 
-    name = V3 + '/Conv2d_1a_3x3'
+    blk = V3 + '/Conv2d_1a_3x3'
     # 299 x 299 x 3
-    net.add(Conv2D(name, depth(32), 3, 2, border_mode='VALID', use_bias=False, input_sample_shape=sample_shape))
-    net.add(BatchNormalization(name + '/BatchNorm'))
-    net.add(Activation(name + '/relu'))
-    # 149 x 149 x 32
+    net.add(Conv2D(blk, depth(32), 3, 2, border_mode='VALID', use_bias=False,
+                   input_sample_shape=sample_shape))
+    net.add(BatchNormalization(blk + '/BatchNorm'))
+    end_points[blk] = net.add(Activation(blk + '/relu'))
+    if final_aux_check(blk):
+        return net, end_points
 
+    # 149 x 149 x 32
     conv2d(net, '%s/Conv2d_2a_3x3' % V3, depth(32), 3, border_mode='VALID')
     # 147 x 147 x 32
     conv2d(net, '%s/Conv2d_2b_3x3' % V3, depth(64), 3)
@@ -85,221 +96,274 @@ def inception_v3_base(name, sample_shape, final_endpoint='Mixed_6e', aux_name=No
     conv2d(net, '%s/Conv2d_4a_3x3' % V3, depth(192), 3, border_mode='VALID')
     # 71 x 71 x 192.
     net.add(MaxPooling2D('%s/MaxPool_5a_3x3' % V3, 3, 2, border_mode='VALID'))
-    # 35 x 35 x 192.
 
+    # 35 x 35 x 192.
+    blk = V3 + '/Mixed_5b'
+    s = net.add(Split('%s/Split' % blk, 4))
+    br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, depth(64), 1, src=s)
+    br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, depth(48), 1, src=s)
+    br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_5x5' % blk, depth(64), 5)
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % blk, depth(64), 1, src=s)
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_3x3' % blk, depth(96), 3)
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_3x3' % blk, depth(96), 3)
+    net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % blk, 3, 1), s)
+    br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % blk, depth(32), 1)
+    end_points[blk] = net.add(Concat('%s/Concat' % blk, 1),
+                              [br0, br1, br2, br3])
+    if final_aux_check(blk):
+        return net, end_points
 
-    m5b = V3 + '/Mixed_5b'
-    s = net.add(Split('%s/Split' % m5b, 4))
-    br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % m5b, depth(64), 1, src=s)
-    br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % m5b, depth(48), 1, src=s)
-    br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_5x5' % m5b, depth(64), 5)
-    br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % m5b, depth(64), 1, src=s)
-    br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_3x3' % m5b, depth(96), 3)
-    br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_3x3' % m5b, depth(96), 3)
-    net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % m5b, 3, 1), s)
-    br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % m5b, depth(32), 1)
-    endpoints[m5b] =net.add(Concat('%s/Concat' % m5b, 1),  [br0, br1, br2, br3])
-    final_aux_check(m5b, net)
     # mixed_1: 35 x 35 x 288.
-    m5c = V3 + '/Mixed_5c'
-    s = net.add(Split('%s/Split' % m5c, 4))
-    br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % m5c, depth(64), 1, src=s)
-    br1 = conv2d(net, '%s/Branch_1/Con2d_0b_1x1' % m5c, depth(48), 1, src=s)
-    br1 = conv2d(net, '%s/Branch_1/Conv_1_0c_5x5' % m5c, depth(64), 5)
-    br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % m5c, depth(64), 1, src=s)
-    br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_3x3' % m5c, depth(96), 3)
-    br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_3x3' % m5c, depth(96), 3)
-    br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % m5c, 3, 1), src=s)
-    br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % m5c, depth(64), 1)
-    endpoints[m5c] = net.add(Concat('%s/Concat' % m5c, 1),  [br0, br1, br2, br3])
-    final_aux_check(m5c, net)
+    blk = V3 + '/Mixed_5c'
+    s = net.add(Split('%s/Split' % blk, 4))
+    br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, depth(64), 1, src=s)
+    br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x1' % blk, depth(48), 1, src=s)
+    br1 = conv2d(net, '%s/Branch_1/Conv_1_0c_5x5' % blk, depth(64), 5)
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % blk, depth(64), 1, src=s)
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_3x3' % blk, depth(96), 3)
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_3x3' % blk, depth(96), 3)
+    br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % blk, 3, 1), src=s)
+    br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % blk, depth(64), 1)
+    end_points[blk] = net.add(Concat('%s/Concat' % blk, 1),
+                              [br0, br1, br2, br3])
+    if final_aux_check(blk):
+        return net, end_points
 
     # mixed_2: 35 x 35 x 288.
-    m5d = V3 + '/Mixed_5d'
-    s = net.add(Split('%s/Split' % m5d, 4))
-    br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % m5d, depth(64), 1, src=s)
-    br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % m5d, depth(48), 1, src=s)
-    br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_5x5' % m5d, depth(64), 5)
-    br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % m5d, depth(64), 1, src=s)
-    br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_3x3' % m5d, depth(96), 3)
-    br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_3x3' % m5d, depth(96), 3)
-    br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % m5d,  3, 1), s)
-    br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % m5d, depth(64), 1)
-    endpoints[m5d] =net.add(Concat('%s/Concat' % m5d, 1),  [br0, br1, br2, br3])
-    final_aux_check(m5d, net)
+    blk = V3 + '/Mixed_5d'
+    s = net.add(Split('%s/Split' % blk, 4))
+    br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, depth(64), 1, src=s)
+    br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, depth(48), 1, src=s)
+    br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_5x5' % blk, depth(64), 5)
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % blk, depth(64), 1, src=s)
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_3x3' % blk, depth(96), 3)
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_3x3' % blk, depth(96), 3)
+    br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % blk,  3, 1), s)
+    br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % blk, depth(64), 1)
+    end_points[blk] = net.add(Concat('%s/Concat' % blk, 1),
+                              [br0, br1, br2, br3])
+    if final_aux_check(blk):
+        return net, end_points
 
     # mixed_3: 17 x 17 x 768.
-    m6a = V3 + '/Mixed_6a'
-    s = net.add(Split('%s/Split' % m6a, 3))
-    br0 = conv2d(net, '%s/Branch_0/Conv2d_1a_1x1' % m6a, depth(384), 3, 2, border_mode='VALID', src=s)
-    br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % m6a, depth(64), 1, src=s)
-    br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_3x3' % m6a, depth(96), 3)
-    br1 = conv2d(net, '%s/Branch_1/Conv2d_1a_1x1' % m6a, depth(96), 3, 2, border_mode='VALID')
-    br2 = net.add(MaxPooling2D('%s/Branch_2/MaxPool_1a_3x3' % m6a, 3, 2, border_mode='VALID'), s)
-    endpoints[m6a] = net.add(Concat('%s/Concat' % m6a, 1),  [br0, br1, br2])
-    final_aux_check(m6a, net)
+    blk = V3 + '/Mixed_6a'
+    s = net.add(Split('%s/Split' % blk, 3))
+    br0 = conv2d(net, '%s/Branch_0/Conv2d_1a_1x1' % blk, depth(384), 3, 2,
+                 border_mode='VALID', src=s)
+    br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, depth(64), 1, src=s)
+    br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_3x3' % blk, depth(96), 3)
+    br1 = conv2d(net, '%s/Branch_1/Conv2d_1a_1x1' % blk, depth(96), 3, 2,
+                 border_mode='VALID')
+    br2 = net.add(MaxPooling2D('%s/Branch_2/MaxPool_1a_3x3' % blk, 3, 2,
+                               border_mode='VALID'), s)
+    end_points[blk] = net.add(Concat('%s/Concat' % blk, 1),  [br0, br1, br2])
+    if final_aux_check(blk):
+        return net, end_points
 
     # mixed4: 17 x 17 x 768.
-    m6b = V3 + '/Mixed_6b'
-    s = net.add(Split('%s/Split' % m6b, 4))
-    br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % m6b, depth(192), 1, src=s)
-    br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % m6b, depth(128), 1, src=s)
-    br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x7' % m6b, depth(128), [1, 7])
-    br1 = conv2d(net, '%s/Branch_1/Conv2d_0c_7x1' % m6b, depth(192), [7, 1])
-    br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % m6b, depth(128), [1, 1], src=s)
-    br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_7x1' % m6b, depth(128), [7, 1])
-    br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_1x7' % m6b, depth(128), [1, 7])
-    br2 = conv2d(net, '%s/Branch_2/Conv2d_0d_7x1' % m6b, depth(128), [7, 1])
-    br2 = conv2d(net, '%s/Branch_2/Conv2d_0e_1x7' % m6b, depth(192), [1, 7])
-    br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % m6b, 3, 1), s)
-    br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % m6b, depth(192), [1, 1])
-    endpoints[m6b] = net.add(Concat('%s/Concat' % m6b, 1),  [br0, br1, br2, br3])
-    final_aux_check(m6b, net)
+    blk = V3 + '/Mixed_6b'
+    s = net.add(Split('%s/Split' % blk, 4))
+    br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, depth(192), 1, src=s)
+    br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, depth(128), 1, src=s)
+    br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x7' % blk, depth(128), [1, 7])
+    br1 = conv2d(net, '%s/Branch_1/Conv2d_0c_7x1' % blk, depth(192), [7, 1])
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % blk, depth(128), [1, 1],
+                 src=s)
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_7x1' % blk, depth(128), [7, 1])
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_1x7' % blk, depth(128), [1, 7])
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0d_7x1' % blk, depth(128), [7, 1])
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0e_1x7' % blk, depth(192), [1, 7])
+    br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % blk, 3, 1), s)
+    br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % blk, depth(192), [1, 1])
+    end_points[blk] = net.add(Concat('%s/Concat' % blk, 1),
+                              [br0, br1, br2, br3])
+    if final_aux_check(blk):
+        return net, end_points
 
     # mixed_5: 17 x 17 x 768.
-    m6c = V3 + '/Mixed_6c'
-    s = net.add(Split('%s/Split' % m6c, 4))
-    br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % m6c, depth(192), [1, 1], src=s)
-    br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % m6c, depth(160), [1, 1], src=s)
-    br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x7' % m6c, depth(160), [1, 7])
-    br1 = conv2d(net, '%s/Branch_1/Conv2d_0c_7x1' % m6c, depth(192), [7, 1])
-    br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % m6c, depth(160), [1, 1], src=s)
-    br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_7x1' % m6c, depth(160), [7, 1])
-    br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_1x7' % m6c, depth(160), [1, 7])
-    br2 = conv2d(net, '%s/Branch_2/Conv2d_0d_7x1' % m6c, depth(160), [7, 1])
-    br2 = conv2d(net, '%s/Branch_2/Conv2d_0e_1x7' % m6c, depth(192), [1, 7])
-    br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % m6c, 3, 1), s)
-    br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % m6c, depth(192), [1, 1])
-    endpoints[m6c] = net.add(Concat('%s/Concat' % m6c, 1),  [br0, br1, br2, br3])
-    final_aux_check(m6c, net)
+    blk = V3 + '/Mixed_6c'
+    s = net.add(Split('%s/Split' % blk, 4))
+    br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, depth(192), [1, 1],
+                 src=s)
+    br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, depth(160), [1, 1],
+                 src=s)
+    br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x7' % blk, depth(160), [1, 7])
+    br1 = conv2d(net, '%s/Branch_1/Conv2d_0c_7x1' % blk, depth(192), [7, 1])
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % blk, depth(160), [1, 1],
+                 src=s)
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_7x1' % blk, depth(160), [7, 1])
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_1x7' % blk, depth(160), [1, 7])
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0d_7x1' % blk, depth(160), [7, 1])
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0e_1x7' % blk, depth(192), [1, 7])
+    br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % blk, 3, 1), s)
+    br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % blk, depth(192), [1, 1])
+    end_points[blk] = net.add(Concat('%s/Concat' % blk, 1),
+                              [br0, br1, br2, br3])
+    if final_aux_check(blk):
+        return net, end_points
 
     # mixed_6: 17 x 17 x 768.
-    m6d = V3 + '/Mixed_6d'
-    s = net.add(Split('%s/Split' % m6d, 4))
-    br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % m6d, depth(192), [1, 1], src=s)
-    br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % m6d, depth(160), [1, 1], src=s)
-    br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x7' % m6d, depth(160), [1, 7])
-    br1 = conv2d(net, '%s/Branch_1/Conv2d_0c_7x1' % m6d, depth(192), [7, 1])
-    br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % m6d, depth(160), [1, 1], src=s)
-    br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_7x1' % m6d, depth(160), [7, 1])
-    br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_1x7' % m6d, depth(160), [1, 7])
-    br2 = conv2d(net, '%s/Branch_2/Conv2d_0d_7x1' % m6d, depth(160), [7, 1])
-    br2 = conv2d(net, '%s/Branch_2/Conv2d_0e_1x7' % m6d, depth(192), [1, 7])
-    br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % m6d, 3, 1), s)
-    br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % m6d, depth(192), [1, 1])
-    endpoints[m6d] = net.add(Concat('%s/Concat' % m6d, 1),  [br0, br1, br2, br3])
-    final_aux_check(m6d, net)
-
-    m6e = V3 + '/Mixed_6e'
-    s = net.add(Split('%s/Split' % m6e, 4))
-    br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % m6e, depth(192), [1, 1], src=s)
-    br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % m6e, depth(192), [1, 1], src=s)
-    br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x7' % m6e, depth(192), [1, 7])
-    br1 = conv2d(net, '%s/Branch_1/Conv2d_0c_7x1' % m6e, depth(192), [7, 1])
-    br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % m6e, depth(192), [1, 1], src=s)
-    br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_7x1' % m6e, depth(192), [7, 1])
-    br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_1x7' % m6e, depth(192), [1, 7])
-    br2 = conv2d(net, '%s/Branch_2/Conv2d_0d_7x1' % m6e, depth(192), [7, 1])
-    br2 = conv2d(net, '%s/Branch_2/Conv2d_0e_1x7' % m6e, depth(192), [1, 7])
-    br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % m6d, 3, 1), s)
-    br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % m6d, depth(192), [1, 1])
-    endpoints[m6e] = net.add(Concat('%s/Concat' % m6d, 1),  [br0, br1, br2, br3])
-    final_aux_check(m6e, net)
+    blk = V3 + '/Mixed_6d'
+    s = net.add(Split('%s/Split' % blk, 4))
+    br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, depth(192), [1, 1],
+                 src=s)
+    br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, depth(160), [1, 1],
+                 src=s)
+    br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x7' % blk, depth(160), [1, 7])
+    br1 = conv2d(net, '%s/Branch_1/Conv2d_0c_7x1' % blk, depth(192), [7, 1])
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % blk, depth(160), [1, 1],
+                 src=s)
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_7x1' % blk, depth(160), [7, 1])
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_1x7' % blk, depth(160), [1, 7])
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0d_7x1' % blk, depth(160), [7, 1])
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0e_1x7' % blk, depth(192), [1, 7])
+    br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % blk, 3, 1), s)
+    br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % blk, depth(192), [1, 1])
+    end_points[blk] = net.add(Concat('%s/Concat' % blk, 1),
+                              [br0, br1, br2, br3])
+    if final_aux_check(blk):
+        return net, end_points
+
+    blk = V3 + '/Mixed_6e'
+    s = net.add(Split('%s/Split' % blk, 4))
+    br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, depth(192), [1, 1],
+                 src=s)
+    br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, depth(192), [1, 1],
+                 src=s)
+    br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x7' % blk, depth(192), [1, 7])
+    br1 = conv2d(net, '%s/Branch_1/Conv2d_0c_7x1' % blk, depth(192), [7, 1])
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % blk, depth(192), [1, 1],
+                 src=s)
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_7x1' % blk, depth(192), [7, 1])
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_1x7' % blk, depth(192), [1, 7])
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0d_7x1' % blk, depth(192), [7, 1])
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0e_1x7' % blk, depth(192), [1, 7])
+    br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % blk, 3, 1), s)
+    br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % blk, depth(192), [1, 1])
+    end_points[blk] = net.add(Concat('%s/Concat' % blk, 1),
+                              [br0, br1, br2, br3])
+    if final_aux_check(blk):
+        return net, end_points
 
     # mixed_8: 8 x 8 x 1280.
-    m7a = V3 + '/Mixed_7a'
-    s = net.add(Split('%s/Split' % m7a, 3))
-    br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % m7a, depth(192), [1, 1], src=s)
-    br0 = conv2d(net, '%s/Branch_0/Conv2d_1a_3x3' % m7a, depth(320), [3, 3], 2, border_mode='VALID')
-    br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % m7a, depth(192), [1, 1], src=s)
-    br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x7' % m7a, depth(192), [1, 7])
-    br1 = conv2d(net, '%s/Branch_1/Conv2d_0c_7x1' % m7a, depth(192), [7, 1])
-    br1 = conv2d(net, '%s/Branch_1/Conv2d_1a_3x3' % m7a, depth(192), [3, 3], 2, border_mode='VALID')
-    br2 = net.add(MaxPooling2D('%s/Branch_2/MaxPool_1a_3x3' % m7a, 3, 2, border_mode='VALID'), s)
-    endpoints[m7a] = net.add(Concat('%s/Concat' % m7a, 1),  [br0, br1, br2])
-    final_aux_check(m7a, net)
+    blk = V3 + '/Mixed_7a'
+    s = net.add(Split('%s/Split' % blk, 3))
+    br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, depth(192), [1, 1],
+                 src=s)
+    br0 = conv2d(net, '%s/Branch_0/Conv2d_1a_3x3' % blk, depth(320), [3, 3], 2,
+                 border_mode='VALID')
+    br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, depth(192), [1, 1],
+                 src=s)
+    br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x7' % blk, depth(192), [1, 7])
+    br1 = conv2d(net, '%s/Branch_1/Conv2d_0c_7x1' % blk, depth(192), [7, 1])
+    br1 = conv2d(net, '%s/Branch_1/Conv2d_1a_3x3' % blk, depth(192), [3, 3], 2,
+                 border_mode='VALID')
+    br2 = net.add(MaxPooling2D('%s/Branch_2/MaxPool_1a_3x3' % blk, 3, 2,
+                               border_mode='VALID'), s)
+    end_points[blk] = net.add(Concat('%s/Concat' % blk, 1),  [br0, br1, br2])
+    if final_aux_check(blk):
+        return net, end_points
 
     # mixed_9: 8 x 8 x 2048.
-    m7b = V3 + '/Mixed_7b'
-    s = net.add(Split('%s/Split' % m7b, 4))
-    br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % m7b, depth(320), [1, 1], src=s)
-    br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % m7b, depth(384), [1, 1], src=s)
-    s1 = net.add(Split('%s/Branch_1/Split1' % m7b, 2))
-    br11 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x3' % m7b, depth(384), [1, 3], src=s1)
-    br12 = conv2d(net, '%s/Branch_1/Conv2d_0b_3x1' % m7b, depth(384), [3, 1], src=s1)
-    br1 = net.add(Concat('%s/Branch_1/Concat1' % m7b, 1),  [br11, br12])
-    br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % m7b, depth(448), [1, 1], src=s)
-    br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_3x3' % m7b, depth(384), [3, 3])
-    s2 = net.add(Split('%s/Branch_2/Split2' % m7b, 2))
-    br21 = conv2d(net, '%s/Branch_2/Conv2d_0c_1x3' % m7b, depth(384), [1, 3], src=s2)
-    br22 = conv2d(net, '%s/Branch_2/Conv2d_0d_3x1' % m7b, depth(384), [3, 1], src=s2)
-    br2 = net.add(Concat('%s/Branch_2/Concat2' % m7b, 1),  [br21, br22])
-    br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % m7b, 3, 1), src=s)
-    br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % m7b, depth(192), [1, 1])
-    endpoints[m7b] = net.add(Concat('%s/Concat' % m7b, 1),  [br0, br1, br2, br3])
-    final_aux_check(m7b, net)
+    blk = V3 + '/Mixed_7b'
+    s = net.add(Split('%s/Split' % blk, 4))
+    br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, depth(320), 1, src=s)
+    br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, depth(384), 1, src=s)
+    s1 = net.add(Split('%s/Branch_1/Split1' % blk, 2))
+    br11 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x3' % blk, depth(384), [1, 3],
+                  src=s1)
+    br12 = conv2d(net, '%s/Branch_1/Conv2d_0b_3x1' % blk, depth(384), [3, 1],
+                  src=s1)
+    br1 = net.add(Concat('%s/Branch_1/Concat1' % blk, 1),  [br11, br12])
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % blk, depth(448), 1, src=s)
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_3x3' % blk, depth(384), 3)
+    s2 = net.add(Split('%s/Branch_2/Split2' % blk, 2))
+    br21 = conv2d(net, '%s/Branch_2/Conv2d_0c_1x3' % blk, depth(384), [1, 3],
+                  src=s2)
+    br22 = conv2d(net, '%s/Branch_2/Conv2d_0d_3x1' % blk, depth(384), [3, 1],
+                  src=s2)
+    br2 = net.add(Concat('%s/Branch_2/Concat2' % blk, 1),  [br21, br22])
+    br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % blk, 3, 1), src=s)
+    br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % blk, depth(192), [1, 1])
+    end_points[blk] = net.add(Concat('%s/Concat' % blk, 1),
+                              [br0, br1, br2, br3])
+    if final_aux_check(blk):
+        return net, end_points
 
     # mixed_10: 8 x 8 x 2048.
-    m7c = V3 + '/Mixed_7c'
-    s = net.add(Split('%s/Split' % m7c, 4))
-    br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % m7c, depth(320), [1, 1], src=s)
-    br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % m7c, depth(384), [1, 1], src=s)
-    s1 = net.add(Split('%s/Branch_1/Split1' % m7c, 2))
-    br11 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x3' % m7c, depth(384), [1, 3], src=s1)
-    br12 = conv2d(net, '%s/Branch_1/Conv2d_0b_3x1' % m7c, depth(384), [3, 1], src=s1)
-    br1 = net.add(Concat('%s/Branch_1/Concat1' % m7c, 1),  [br11, br12])
-    br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % m7c, depth(448), [1, 1], src=s)
-    br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_3x3' % m7c, depth(384), [3, 3])
-    s2 = net.add(Split('%s/Branch_2/Split2' % m7c, 2))
-    br21 = conv2d(net, '%s/Branch_2/Conv2d_0c_1x3' % m7c, depth(384), [1, 3], src=s2)
-    br22 = conv2d(net, '%s/Branch_2/Conv2d_0d_3x1' % m7c, depth(384), [3, 1], src=s2)
-    br2 = net.add(Concat('%s/Branch_2/Concat2' % m7c, 1),  [br21, br22])
-    br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % m7c, 3, 1), src=s)
-    br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % m7c, depth(192), [1, 1])
-    endpoints[m7c] = net.add(Concat('%s/Concat' % m7c, 1),  [br0, br1, br2, br3])
-    final_aux_check(m7c, net)
-    return net, endpoints[m7c], endpoints
-
-
-def create_net(num_classes=1001, sample_shape=(3, 299, 299), is_training=True, dropout_keep_prob=0.8, create_aux_logits=True):
+    blk = V3 + '/Mixed_7c'
+    s = net.add(Split('%s/Split' % blk, 4))
+    br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, depth(320), 1, src=s)
+    br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, depth(384), 1, src=s)
+    s1 = net.add(Split('%s/Branch_1/Split1' % blk, 2))
+    br11 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x3' % blk, depth(384), [1, 3],
+                  src=s1)
+    br12 = conv2d(net, '%s/Branch_1/Conv2d_0c_3x1' % blk, depth(384), [3, 1],
+                  src=s1)
+    br1 = net.add(Concat('%s/Branch_1/Concat1' % blk, 1),  [br11, br12])
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % blk, depth(448), [1, 1],
+                 src=s)
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_3x3' % blk, depth(384), [3, 3])
+    s2 = net.add(Split('%s/Branch_2/Split2' % blk, 2))
+    br21 = conv2d(net, '%s/Branch_2/Conv2d_0c_1x3' % blk, depth(384), [1, 3],
+                  src=s2)
+    br22 = conv2d(net, '%s/Branch_2/Conv2d_0d_3x1' % blk, depth(384), [3, 1],
+                  src=s2)
+    br2 = net.add(Concat('%s/Branch_2/Concat2' % blk, 1),  [br21, br22])
+    br3 = net.add(AvgPooling2D('%s/Branch_3/AvgPool_0a_3x3' % blk, 3, 1), src=s)
+    br3 = conv2d(net, '%s/Branch_3/Conv2d_0b_1x1' % blk, depth(192), [1, 1])
+    end_points[blk] = net.add(Concat('%s/Concat' % blk, 1),
+                              [br0, br1, br2, br3])
+    assert final_endpoint == blk, \
+        'final_enpoint = %s is not in the net' % final_endpoint
+    return net, end_points
+
+
+def create_net(num_classes=1001, sample_shape=(3, 299, 299),
+               final_endpoint='InceptionV3/Mixed_7c',
+               aux_endpoint='InceptionV3/Mixed_6e',
+               dropout_keep_prob=0.8):
     """Creates the Inception V4 model.
 
     Args:
         num_classes: number of predicted classes.
-        is_training: whether is training or not.
         dropout_keep_prob: float, the fraction to keep before final layer.
-        reuse: whether or not the network and its variables should be reused. To be
-        able to reuse 'scope' must be given.
-        create_aux_logits: Whether to include the auxiliary logits.
+        final_endpoint: 'InceptionV3/Mixed_7d',
+        aux_endpoint:
 
     Returns:
         logits: the logits outputs of the model.
         end_points: the set of end_points from the inception model.
     """
     name = 'InceptionV3'
-    if is_training and create_aux_logits:
-        aux_name = name + '/Mixed_6e'
-    else:
-        aux_name = None
-    net, last_layer, end_points = inception_v3_base(name, sample_shape, aux_name=aux_name)
+    net, end_points = inception_v3_base(name, sample_shape, final_endpoint,
+                                        aux_endpoint)
     # Auxiliary Head logits
-    if aux_name is not None:
+    if aux_endpoint is not None:
         # 8 x 8 x 1280
-        aux_logits = end_points[aux_name + '-aux']
-        net.add(AvgPooling2D('%s/AuxLogits/AvgPool_1a_5x5' % name, 5, stride=3, border_mode='VALID'), aux_logits)
-        t = conv2d(net, '%s/AuxLogits/Conv2d_1b_1x1' % name, 128, 1)
-        conv2d(net, '%s/AuxLogits/Conv2d_2a' % name, 768, t.get_output_sample_shape()[1:3], border_mode='VALID')
-        net.add(Flatten('%s/AuxLogits/flat' % name))
-        end_points['AuxLogits'] = net.add(Dense('%s/AuxLogits/Aux_logits' % name, num_classes))
+        aux_logits = end_points[aux_endpoint + '-aux']
+        blk = name + '/AuxLogits'
+        net.add(AvgPooling2D('%s/AvgPool_1a_5x5' % blk, 5, stride=3,
+                             border_mode='VALID'), aux_logits)
+        t = conv2d(net, '%s/Conv2d_1b_1x1' % blk, 128, 1)
+        s = t.get_output_sample_shape()[1:3]
+        conv2d(net, '%s/Conv2d_2a_%dx%d' % (blk, s[0], s[1]), 768, s,
+               border_mode='VALID')
+        net.add(Conv2D('%s/Conv2d_2b_1x1' % blk, num_classes, 1))
+        net.add(Flatten('%s/flat' % blk))
 
     # Final pooling and prediction
     # 8 x 8 x 2048
-    net.add(AvgPooling2D('%s/Logits/AvgPool_1a' % name, last_layer.get_output_sample_shape()[1:3], 1, border_mode='VALID'), last_layer)
+    blk = name + '/Logits'
+    last_layer = end_points[final_endpoint]
+    net.add(AvgPooling2D('%s/AvgPool_1a' % blk,
+                         last_layer.get_output_sample_shape()[1:3], 1,
+                         border_mode='VALID'), last_layer)
     # 1 x 1 x 2048
-    net.add(Dropout('%s/Logits/Dropout_1b' % name, 1 - dropout_keep_prob))
-    net.add(Flatten('%s/Logits/PreLogitsFlatten' % name))
+    net.add(Dropout('%s/Dropout_1b' % blk, 1 - dropout_keep_prob))
+    net.add(Conv2D('%s/Conv2d_1c_1x1' % blk, num_classes, 1))
+    end_points[blk] = net.add(Flatten('%s/flat' % blk))
     # 2048
-    end_points['Logits'] = net.add(Dense('%s/Logits/Logits' % name, num_classes))
     return net, end_points
 
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/3f39cfa2/examples/imagenet/inception/inception_v4.py
----------------------------------------------------------------------
diff --git a/examples/imagenet/inception/inception_v4.py b/examples/imagenet/inception/inception_v4.py
index b641706..9c5883f 100644
--- a/examples/imagenet/inception/inception_v4.py
+++ b/examples/imagenet/inception/inception_v4.py
@@ -39,9 +39,10 @@ from singa import net as ffnet
 
 ffnet.verbose = True
 
+
 def conv2d(net, name, nb_filter, k, s=1, border_mode='SAME', src=None):
     net.add(Conv2D(name, nb_filter, k, s, border_mode=border_mode,
-        use_bias=False), src)
+                   use_bias=False), src)
     net.add(BatchNormalization('%s/BatchNorm' % name))
     return net.add(Activation(name+'/relu'))
 
@@ -66,13 +67,13 @@ def block_reduction_a(blk, net):
     # By default use stride=1 and SAME padding
     s = net.add(Split('%s/Split' % blk, 3))
     br0 = conv2d(net, '%s/Branch_0/Conv2d_1a_3x3' % blk, 384, 3, 2,
-            border_mode='VALID', src=s)
+                 border_mode='VALID', src=s)
     br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, 192, 1, src=s)
     br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_3x3' % blk, 224, 3)
     br1 = conv2d(net, '%s/Branch_1/Conv2d_1a_3x3' % blk, 256, 3, 2,
-            border_mode='VALID')
+                 border_mode='VALID')
     br2 = net.add(MaxPooling2D('%s/Branch_2/MaxPool_1a_3x3' % blk, 3, 2,
-        border_mode='VALID'), s)
+                               border_mode='VALID'), s)
     return net.add(Concat('%s/Concat' % blk, 1), [br0, br1, br2])
 
 
@@ -97,17 +98,17 @@ def block_inception_b(blk, net):
 def block_reduction_b(blk, net):
     """Builds Reduction-B block for Inception v4 network."""
     # By default use stride=1 and SAME padding
-    s = net.add(Split('%s/Split' % blk , 3))
+    s = net.add(Split('%s/Split' % blk, 3))
     br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, 192, 1, src=s)
     br0 = conv2d(net, '%s/Branch_0/Conv2d_1a_3x3' % blk, 192, 3, 2,
-            border_mode='VALID')
+                 border_mode='VALID')
     br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, 256, 1, src=s)
     br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x7' % blk, 256, (1, 7))
     br1 = conv2d(net, '%s/Branch_1/Conv2d_0c_7x1' % blk, 320, (7, 1))
     br1 = conv2d(net, '%s/Branch_1/Conv2d_1a_3x3' % blk, 320, 3, 2,
-            border_mode='VALID')
+                 border_mode='VALID')
     br2 = net.add(MaxPooling2D('%s/Branch_2/MaxPool_1a_3x3' % blk, 3, 2,
-        border_mode='VALID'), s)
+                               border_mode='VALID'), s)
     return net.add(Concat('%s/Concat' % blk, 1), [br0, br1, br2])
 
 
@@ -123,9 +124,9 @@ def block_inception_c(blk, net):
     br11 = conv2d(net, '%s/Branch_1/Conv2d_0c_3x1' % blk, 256, (3, 1), src=br1)
     br1 = net.add(Concat('%s/Branch_1/Concat' % blk, 1), [br10, br11])
 
-    br2 =conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % blk, 384, 1, src=s)
-    br2 =conv2d(net, '%s/Branch_2/Conv2d_0b_3x1' % blk, 448, (3, 1))
-    br2 =conv2d(net, '%s/Branch_2/Conv2d_0c_1x3' % blk, 512, (1, 3))
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0a_1x1' % blk, 384, 1, src=s)
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0b_3x1' % blk, 448, (3, 1))
+    br2 = conv2d(net, '%s/Branch_2/Conv2d_0c_1x3' % blk, 512, (1, 3))
     br2 = net.add(Split('%s/Branch_2/Split' % blk, 2))
     br20 = conv2d(net, '%s/Branch_2/Conv2d_0d_1x3' % blk, 256, (1, 3), src=br2)
     br21 = conv2d(net, '%s/Branch_2/Conv2d_0e_3x1' % blk, 256, (3, 1), src=br2)
@@ -136,8 +137,8 @@ def block_inception_c(blk, net):
     return net.add(Concat('%s/Concat' % blk, 1), [br0, br1, br2, br3])
 
 
-def inception_v4_base(name, sample_shape, final_endpoint='Inception/Mixed_7d',
-        aux_endpoint='Inception/Mixed_6e'):
+def inception_v4_base(sample_shape, final_endpoint='Inception/Mixed_7d',
+                      aux_endpoint='Inception/Mixed_6e'):
     """Creates the Inception V4 network up to the given final endpoint.
 
     Endpoint name list: 'InceptionV4/' +
@@ -148,17 +149,18 @@ def inception_v4_base(name, sample_shape, final_endpoint='Inception/Mixed_7d',
         'Mixed_7d']
 
     Args:
-        inputs: a 4-D tensor of size [batch_size, height, width, 3].
+        sample_shape: input image sample shape, 3d tuple
         final_endpoint: specifies the endpoint to construct the network up to.
         aux_endpoint: for aux loss.
 
     Returns:
         the neural net
-        the last layer
         the set of end_points from the inception model.
     """
+    name = 'InceptionV4'
     end_points = {}
     net = ffnet.FeedForwardNet()
+
     def final_aux_check(block_name):
         if block_name == final_endpoint:
             return True
@@ -167,36 +169,34 @@ def inception_v4_base(name, sample_shape, final_endpoint='Inception/Mixed_7d',
             end_points[aux] = net.add(Split(aux, 2))
         return False
 
-
     # 299 x 299 x 3
-    blk = name + 'Conv2d_1a_3x3'
+    blk = name + '/Conv2d_1a_3x3'
     net.add(Conv2D(blk, 32, 3, 2, border_mode='VALID', use_bias=False,
-        input_sample_shape=sample_shape))
+                   input_sample_shape=sample_shape))
     net.add(BatchNormalization('%s/BatchNorm' % blk))
     end_points[blk] = net.add(Activation('%s/relu' % blk))
     if final_aux_check(blk):
         return net, end_points
 
     # 149 x 149 x 32
-    blk = name + 'Conv2d_2a_3x3'
-    end_points[blk] = conv2d(net, '%s/Conv2d_2a_3x3' % blk, 32, 3,
-            border_mode='VALID')
+    blk = name + '/Conv2d_2a_3x3'
+    end_points[blk] = conv2d(net, blk, 32, 3, border_mode='VALID')
     if final_aux_check(blk):
         return net, end_points
 
     # 147 x 147 x 32
-    blk = name + 'Conv2d_2b_3x3'
-    end_points[blk] = conv2d(net, '%s/Conv2d_2b_3x3' % blk, 64, 3)
+    blk = name + '/Conv2d_2b_3x3'
+    end_points[blk] = conv2d(net, blk, 64, 3)
     if final_aux_check(blk):
         return net, end_points
 
     # 147 x 147 x 64
     blk = name + '/Mixed_3a'
     s = net.add(Split('%s/Split' % blk, 2))
-    br0 = net.add(MaxPooling2D('%s/Branch_0/MaxPool_0a_3x3' % blk, 3,
-        2, border_mode='VALID'), s)
+    br0 = net.add(MaxPooling2D('%s/Branch_0/MaxPool_0a_3x3' % blk, 3, 2,
+                               border_mode='VALID'), s)
     br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_3x3' % blk, 96, 3, 2,
-            border_mode='VALID', src=s)
+                 border_mode='VALID', src=s)
     end_points[blk] = net.add(Concat('%s/Concat' % blk, 1), [br0, br1])
     if final_aux_check(blk):
         return net, end_points
@@ -206,12 +206,12 @@ def inception_v4_base(name, sample_shape, final_endpoint='Inception/Mixed_7d',
     s = net.add(Split('%s/Split' % blk, 2))
     br0 = conv2d(net, '%s/Branch_0/Conv2d_0a_1x1' % blk, 64, 1, src=s)
     br0 = conv2d(net, '%s/Branch_0/Conv2d_1a_3x3' % blk, 96, 3,
-            border_mode='VALID')
+                 border_mode='VALID')
     br1 = conv2d(net, '%s/Branch_1/Conv2d_0a_1x1' % blk, 64, 1, src=s)
     br1 = conv2d(net, '%s/Branch_1/Conv2d_0b_1x7' % blk, 64, (1, 7))
     br1 = conv2d(net, '%s/Branch_1/Conv2d_0c_7x1' % blk, 64, (7, 1))
     br1 = conv2d(net, '%s/Branch_1/Conv2d_1a_3x3' % blk, 96, 3,
-            border_mode='VALID')
+                 border_mode='VALID')
     end_points[blk] = net.add(Concat('%s/Concat' % blk, 1), [br0, br1])
     if final_aux_check(blk):
         return net, end_points
@@ -220,9 +220,9 @@ def inception_v4_base(name, sample_shape, final_endpoint='Inception/Mixed_7d',
     blk = name + '/Mixed_5a'
     s = net.add(Split('%s/Split' % blk, 2))
     br0 = conv2d(net, '%s/Branch_0/Conv2d_1a_3x3' % blk, 192, 3, 2,
-            border_mode='VALID', src=s)
-    br1 = net.add(MaxPooling2D('%s/Branch_1/MaxPool_1a_3x3' % blk, 3,
-        2, border_mode='VALID'), s)
+                 border_mode='VALID', src=s)
+    br1 = net.add(MaxPooling2D('%s/Branch_1/MaxPool_1a_3x3' % blk, 3, 2,
+                               border_mode='VALID'), s)
     end_points[blk] = net.add(Concat('%s/Concat' % blk, 1), [br0, br1])
     if final_aux_check(blk):
         return net, end_points
@@ -265,12 +265,13 @@ def inception_v4_base(name, sample_shape, final_endpoint='Inception/Mixed_7d',
         if final_aux_check(blk):
             return net, end_points
 
-    return net, end_points
+    assert final_endpoint == blk, \
+        'final_enpoint = %s is not in the net' % final_endpoint
 
 
 def create_net(num_classes=1001, sample_shape=(3, 299, 299), is_training=True,
-        dropout_keep_prob=0.8, final_endpoint='InceptionV4/Mixed_7d',
-        aux_endpoint='InceptionV4/Mixed_6e'):
+               dropout_keep_prob=0.8, final_endpoint='InceptionV4/Mixed_7d',
+               aux_endpoint='InceptionV4/Mixed_6e'):
     """Creates the Inception V4 model.
 
     Args:
@@ -285,28 +286,30 @@ def create_net(num_classes=1001, sample_shape=(3, 299, 299), is_training=True,
     """
     end_points = {}
     name = 'InceptionV4'
-    net, end_points = inception_v4_base(name, sample_shape,
-            final_endpoint=final_endpoint, aux_endpoint=aux_endpoint)
+    net, end_points = inception_v4_base(sample_shape,
+                                        final_endpoint=final_endpoint,
+                                        aux_endpoint=aux_endpoint)
     # Auxiliary Head logits
     if aux_endpoint is not None:
         # 17 x 17 x 1024
         aux_logits = end_points[aux_endpoint + '-aux']
-        blk = 'AuxLogits'
+        blk = name + '/AuxLogits'
         net.add(AvgPooling2D('%s/AvgPool_1a_5x5' % blk, 5, stride=3,
-            border_mode='VALID'), aux_logits)
+                             border_mode='VALID'), aux_logits)
         t = conv2d(net, '%s/Conv2d_1b_1x1' % blk, 128, 1)
         conv2d(net, '%s/Conv2d_2a' % blk, 768,
-                t.get_output_sample_shape()[1:3], border_mode='VALID')
+               t.get_output_sample_shape()[1:3], border_mode='VALID')
         net.add(Flatten('%s/flat' % blk))
         end_points[blk] = net.add(Dense('%s/Aux_logits' % blk, num_classes))
 
     # Final pooling and prediction
     # 8 x 8 x 1536
-    blk = 'Logits'
+    blk = name + '/Logits'
     last_layer = end_points[final_endpoint]
     net.add(AvgPooling2D('%s/AvgPool_1a' % blk,
-        last_layer.get_output_sample_shape()[1:3], border_mode='VALID'),
-        last_layer)
+                         last_layer.get_output_sample_shape()[1:3],
+                         border_mode='VALID'),
+            last_layer)
     # 1 x 1 x 1536
     net.add(Dropout('%s/Dropout_1b' % blk, 1 - dropout_keep_prob))
     net.add(Flatten('%s/PreLogitsFlatten' % blk))



Mime
View raw message