singa-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wang...@apache.org
Subject [10/15] incubator-singa git commit: SINGA-290 Upgrade to Python 3
Date Fri, 04 Aug 2017 08:32:54 GMT
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bd5a8f8d/examples/cifar10/predict.py
----------------------------------------------------------------------
diff --git a/examples/cifar10/predict.py b/examples/cifar10/predict.py
index 123818a..7cab4b9 100644
--- a/examples/cifar10/predict.py
+++ b/examples/cifar10/predict.py
@@ -52,7 +52,7 @@ def predict(net, images, dev, topk=5):
 def load_dataset(filepath):
     print('Loading data file %s' % filepath)
     with open(filepath, 'rb') as fd:
-        cifar10 = pickle.load(fd)
+        cifar10 = pickle.load(fd, encoding='latin1')
     image = cifar10['data'].astype(dtype=np.uint8)
     image = image.reshape((-1, 3, 32, 32))
     label = np.asarray(cifar10['labels'], dtype=np.uint8)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bd5a8f8d/examples/cifar10/train.py
----------------------------------------------------------------------
diff --git a/examples/cifar10/train.py b/examples/cifar10/train.py
index 8204055..9f90e58 100644
--- a/examples/cifar10/train.py
+++ b/examples/cifar10/train.py
@@ -48,7 +48,7 @@ import resnet
 def load_dataset(filepath):
     print('Loading data file %s' % filepath)
     with open(filepath, 'rb') as fd:
-        cifar10 = pickle.load(fd)
+        cifar10 = pickle.load(fd, encoding='latin1')
     image = cifar10['data'].astype(dtype=np.uint8)
     image = image.reshape((-1, 3, 32, 32))
     label = np.asarray(cifar10['labels'], dtype=np.uint8)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bd5a8f8d/python/singa/net.py
----------------------------------------------------------------------
diff --git a/python/singa/net.py b/python/singa/net.py
index c49b9fa..a53fc68 100644
--- a/python/singa/net.py
+++ b/python/singa/net.py
@@ -487,6 +487,7 @@ class FeedForwardNet(object):
                 f = f[0:-4]
             sp = snapshot.Snapshot(f, False, buffer_size)
             params = sp.read()
+        version = __version__
         if 'SINGA_VERSION' in params:
             version = params['SINGA_VERSION']
         for name, val in zip(self.param_names(), self.param_values()):

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bd5a8f8d/python/singa/snapshot.py
----------------------------------------------------------------------
diff --git a/python/singa/snapshot.py b/python/singa/snapshot.py
index 392ab3d..a4ac988 100644
--- a/python/singa/snapshot.py
+++ b/python/singa/snapshot.py
@@ -36,7 +36,6 @@ from builtins import object
 from . import singa_wrap as singa
 from . import tensor
 
-
 class Snapshot(object):
     ''' Class and member functions for singa::Snapshot.
 
@@ -58,7 +57,7 @@ class Snapshot(object):
             param_name (string): name of the parameter
             param_val (Tensor): value tensor of the parameter
         '''
-        self.snapshot.Write(str(param_name).encode(), param_val.singa_tensor)
+        self.snapshot.Write(param_name.encode(), param_val.singa_tensor)
 
     def read(self):
         '''Call read method to load all (param_name, param_val)


Mime
View raw message