mxnet-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From GitBox <>
Subject [GitHub] szha closed pull request #11749: [MXNET-8230] test_operator_gpu.test_rms fails
Date Wed, 18 Jul 2018 03:45:19 GMT
szha closed pull request #11749: [MXNET-8230] test_operator_gpu.test_rms fails

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/tests/python/unittest/ b/tests/python/unittest/
index a5b3d4047df..fdf7d279d9c 100644
--- a/tests/python/unittest/
+++ b/tests/python/unittest/
@@ -835,8 +835,7 @@ def update(self, index, weight, grad, state):
         if self.clip_weights:
              mx.ndarray.clip(weight, -self.clip_weights, self.clip_weights, out=weight)
-@unittest.skip("Test fails intermittently. Temporarily disabled until fixed. Tracked at")
 def test_rms():
     opt1 = PyRMSProp
     opt2 = mx.optimizer.RMSProp
@@ -848,6 +847,9 @@ def test_rms():
     wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}]
     mp_options = [{}, {'multi_precision': False}, {'multi_precision': True}]
     for dtype in [np.float16, np.float32]:
+        # Reduce foating point compare tolerance to avoid flaky test failure.
+        rtol, atol = (1e-1, 1e-1) if dtype is np.float16 else (1e-2, 1e-2)
         for cw_option in cw_options:
             for cg_option in cg_options:
                 for center_option in center_options:
@@ -865,9 +867,9 @@ def test_rms():
                                         ('multi_precision' not in kwarg or
                                             not kwarg['multi_precision'])):
-                                compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype)
+                                compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype, rtol=rtol, atol=atol)
                                 if (default_context() == mx.cpu()):
-                                    compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype, g_stype='row_sparse')
+                                    compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype, g_stype='row_sparse', rtol=rtol, atol=atol)
 class PyFtrl(mx.optimizer.Optimizer):
     """The Ftrl optimizer.


This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:

With regards,
Apache Git Services

View raw message