madlib-dev mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From GitBox <>
Subject [GitHub] [madlib] khannaekta commented on a change in pull request #522: DL: Remove keras dependency
Date Mon, 30 Nov 2020 22:10:23 GMT

khannaekta commented on a change in pull request #522:

File path: src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -312,103 +201,129 @@ class MadlibKerasFitTestCase(unittest.TestCase):
             self.model.to_json(), self.compile_params, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
             self.accessible_gpus_for_seg, previous_weights.tostring(), True, **k)
         image_count = new_state
         self.assertEqual(ending_image_count, image_count)
-        # set_session is only called for the last buffer
-        self.assertEqual(0, self.subject.K.set_session.call_count)
-        # Clear session and sess.close must not get called for the middle buffer
-        self.assertEqual(0, self.subject.K.clear_session.call_count)
-        self.assertTrue('segment_model' not in k['SD'])
-        self.assertTrue('cache_set' not in k['SD'])
-        self.assertTrue(k['SD']['x_train'])
-        self.assertTrue(k['SD']['y_train'])
-    def _test_fit_transition_last_buffer_pass(self, is_platform_pg):
-        #TODO should we mock tensorflow's close_session and keras'
-        # clear_session instead of mocking the function `K.clear_session`
-        self.subject.K.set_session = Mock()
-        self.subject.K.clear_session = Mock()
+        self.assertTrue('sess' not in k['GD'])
+        self.assertTrue('segment_model' not in k['GD'])
+        self.assertTrue('cache_set' not in k['GD'])
+        self.assertTrue(k['GD']['x_train'])
+        self.assertTrue(k['GD']['y_train'])
+    def _test_fit_transition_last_buffer_pass(self, is_platform_pg, **kwargs):
         self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
         starting_image_count = 2*len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
-        # last iteration Call
-        multiplied_weights = mult(self.total_images_per_seg[0],self.model_weights)
-        self.subject.compile_and_set_weights(self.model, self.compile_params,
-                                             '/cpu:0', self.serialized_weights)
-        k = {'SD': {'segment_model' :self.model, 'sess': Mock()}}
         state = starting_image_count
         new_state = self.subject.fit_transition(
             state, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), None, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
-            self.accessible_gpus_for_seg, self.dummy_prev_weights, True, **k)
+            self.accessible_gpus_for_seg, self.dummy_prev_weights, "todo-remove",
+            **kwargs)
         state = np.fromstring(new_state, dtype=np.float32)
         image_count = state[0]
-        weights = np.rint(state[1:]).astype(
+        # We need to assert that the weights should be multiplied by final image count.
+        # So we divide and round the weights so that we can compare them with
+        # the original weights. They will only be equal if the learnt weights
+        # only go up or down by 1 which seems to be the case for our unit test.
+        # There might be a better way to assert this but this is good enough for now.
+        weights = np.rint(state[1:]/self.total_images_per_seg[0]).astype(
+        self.assertTrue(((weights == self.model_weights).all()))

Review comment:
       > I'm not sure how these tests could have passed if we weren't mocking though--it
should have failed if any weight changed by any amount, no matter how small, right?
   The np.rint() in `weights = np.rint(state[1:]).astype(` was actually rounding it,
which resulted in the same value as the initial weights that is why it didn't fail earlier.
   Agreed, that we should mock fit. We did spend sometime trying to mock it, but it wasn't
straight forward. I will look into it a bit more but FWIW, I think we can do it as part of
a separate PR.

This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:

View raw message