mxnet-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From zhash...@apache.org
Subject [incubator-mxnet] branch master updated: [MXNET-338] Fix symbol boolean evaluation (#10618)
Date Fri, 20 Apr 2018 04:41:57 GMT
This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 6fb7c85  [MXNET-338] Fix symbol boolean evaluation (#10618)
6fb7c85 is described below

commit 6fb7c8518651843f2c18e452f122aa2291981f32
Author: Aston Zhang <22279212+astonzhang@users.noreply.github.com>
AuthorDate: Thu Apr 19 21:41:51 2018 -0700

    [MXNET-338] Fix symbol boolean evaluation (#10618)
    
    * Fix symbol boolean evaluation
    
    * nonzeros
    
    * NotImplementedForSymbol
---
 python/mxnet/symbol/symbol.py        | 5 +++++
 tests/python/unittest/test_symbol.py | 7 ++++++-
 2 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 915dfc9..1ab7cf8 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -103,6 +103,11 @@ class Symbol(SymbolBase):
         else:
             raise TypeError('type %s not supported' % str(type(other)))
 
+    def __bool__(self):
+        raise NotImplementedForSymbol(self.__bool__, 'bool')
+
+    __nonzero__ = __bool__
+
     def __iadd__(self, other):
         raise NotImplementedForSymbol(self.__iadd__, '+=', other, 1)
 
diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py
index 11c2eba..387428a 100644
--- a/tests/python/unittest/test_symbol.py
+++ b/tests/python/unittest/test_symbol.py
@@ -20,7 +20,8 @@ import os
 import re
 import mxnet as mx
 import numpy as np
-from common import models
+from common import assertRaises, models
+from mxnet.base import NotImplementedForSymbol
 from mxnet.test_utils import discard_stderr
 import pickle as pkl
 
@@ -31,6 +32,10 @@ def test_symbol_basic():
         m.list_arguments()
         m.list_outputs()
 
+def test_symbol_bool():
+    x = mx.symbol.Variable('x')
+    assertRaises(NotImplementedForSymbol, bool, x)
+
 def test_symbol_compose():
     data = mx.symbol.Variable('data')
     net1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=10)

-- 
To stop receiving notification emails like this one, please contact
zhasheng@apache.org.

Mime
View raw message