madlib-dev mailing list archives

From GitBox
Subject [GitHub] [madlib] reductionista commented on a change in pull request #522: DL: Remove keras dependency
Date Mon, 30 Nov 2020 20:54:19 GMT

reductionista commented on a change in pull request #522:

File path: src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -312,103 +201,129 @@ class MadlibKerasFitTestCase(unittest.TestCase):
             self.model.to_json(), self.compile_params, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
             self.accessible_gpus_for_seg, previous_weights.tostring(), True, **k)
         image_count = new_state
         self.assertEqual(ending_image_count, image_count)
-        # set_session is only called for the last buffer
-        self.assertEqual(0, self.subject.K.set_session.call_count)
-        # Clear session and sess.close must not get called for the middle buffer
-        self.assertEqual(0, self.subject.K.clear_session.call_count)
-        self.assertTrue('segment_model' not in k['SD'])
-        self.assertTrue('cache_set' not in k['SD'])
-        self.assertTrue(k['SD']['x_train'])
-        self.assertTrue(k['SD']['y_train'])
-    def _test_fit_transition_last_buffer_pass(self, is_platform_pg):
-        #TODO should we mock tensorflow's close_session and keras'
-        # clear_session instead of mocking the function `K.clear_session`
-        self.subject.K.set_session = Mock()
-        self.subject.K.clear_session = Mock()
+        self.assertTrue('sess' not in k['GD'])
+        self.assertTrue('segment_model' not in k['GD'])
+        self.assertTrue('cache_set' not in k['GD'])
+        self.assertTrue(k['GD']['x_train'])
+        self.assertTrue(k['GD']['y_train'])
+    def _test_fit_transition_last_buffer_pass(self, is_platform_pg, **kwargs):
         self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
         starting_image_count = 2*len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
-        # last iteration Call
-        multiplied_weights = mult(self.total_images_per_seg[0],self.model_weights)
-        self.subject.compile_and_set_weights(self.model, self.compile_params,
-                                             '/cpu:0', self.serialized_weights)
-        k = {'SD': {'segment_model' :self.model, 'sess': Mock()}}
         state = starting_image_count
         new_state = self.subject.fit_transition(
             state, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), None, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
-            self.accessible_gpus_for_seg, self.dummy_prev_weights, True, **k)
+            self.accessible_gpus_for_seg, self.dummy_prev_weights, "todo-remove",
+            **kwargs)
         state = np.fromstring(new_state, dtype=np.float32)
         image_count = state[0]
-        weights = np.rint(state[1:]).astype(
+        # We need to assert that the weights should be multiplied by final image count.
+        # So we divide and round the weights so that we can compare them with
+        # the original weights. They will only be equal if the learnt weights
+        # only go up or down by 1 which seems to be the case for our unit test.
+        # There might be a better way to assert this but this is good enough for now.
+        weights = np.rint(state[1:]/self.total_images_per_seg[0]).astype(
+        self.assertTrue(((weights == self.model_weights).all()))

Review comment:
Maybe "no added benefit" is the wrong way to say it: there might be a situation where
we'd want to know whether the latest version of keras has changed enough that it's no longer
compatible with our code, but that's what the dev-check tests and perf tests are supposed to
be for. The goal of a unit test should be to test only a single function in MADlib (possibly
including some small helper functions), but not the other functions it calls, especially if
they are outside of MADlib. (That way, if the unit test fails, we know exactly where the
problem is.)
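
A minimal sketch of that idea, not taken from the PR: the function
`count_images_and_fit` below is a hypothetical stand-in for a MADlib function,
and the external model is replaced with a `Mock`, so the test can only fail
because of the function's own logic. It assumes Python 3's `unittest.mock`;
the actual test suite may import `Mock` differently.

```python
import unittest
from unittest.mock import Mock


def count_images_and_fit(model, x_batch, y_batch, image_count):
    """Hypothetical function under test (not part of MADlib).

    Delegates training to an external model object and returns the
    updated running image count.
    """
    model.fit(x_batch, y_batch)
    return image_count + len(x_batch)


class CountImagesAndFitTestCase(unittest.TestCase):
    def test_updates_count_and_delegates_fit(self):
        # The Mock stands in for the external (e.g. keras) model, so no
        # code outside the function under test is executed.
        model = Mock()
        x_batch = [[0.1], [0.2], [0.3]]
        y_batch = [0, 1, 0]

        new_count = count_images_and_fit(model, x_batch, y_batch, 6)

        # Assert only on our own logic: the counter and the delegation.
        self.assertEqual(9, new_count)
        model.fit.assert_called_once_with(x_batch, y_batch)
```

Saved as a module, this can be run with `python -m unittest`. Whether the real
library still behaves as the mock assumes is left to the dev-check and perf
tests mentioned above.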
