madlib-dev mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From GitBox <...@apache.org>
Subject [GitHub] [madlib] khannaekta commented on a change in pull request #522: DL: Remove keras dependency
Date Tue, 01 Dec 2020 00:27:59 GMT

khannaekta commented on a change in pull request #522:
URL: https://github.com/apache/madlib/pull/522#discussion_r532992630



##########
File path: src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
##########
@@ -312,103 +201,129 @@ class MadlibKerasFitTestCase(unittest.TestCase):
             self.model.to_json(), self.compile_params, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
             self.accessible_gpus_for_seg, previous_weights.tostring(), True, **k)
+
         image_count = new_state
         self.assertEqual(ending_image_count, image_count)
-        # set_session is only called for the last buffer
-        self.assertEqual(0, self.subject.K.set_session.call_count)
-        # Clear session and sess.close must not get called for the middle buffer
-        self.assertEqual(0, self.subject.K.clear_session.call_count)
-        self.assertTrue('segment_model' not in k['SD'])
-        self.assertTrue('cache_set' not in k['SD'])
-        self.assertTrue(k['SD']['x_train'])
-        self.assertTrue(k['SD']['y_train'])
-
-    def _test_fit_transition_last_buffer_pass(self, is_platform_pg):
-        #TODO should we mock tensorflow's close_session and keras'
-        # clear_session instead of mocking the function `K.clear_session`
-        self.subject.K.set_session = Mock()
-        self.subject.K.clear_session = Mock()
+        self.assertTrue('sess' not in k['GD'])
+        self.assertTrue('segment_model' not in k['GD'])
+        self.assertTrue('cache_set' not in k['GD'])
+        self.assertTrue(k['GD']['x_train'])
+        self.assertTrue(k['GD']['y_train'])
+
+    def _test_fit_transition_last_buffer_pass(self, is_platform_pg, **kwargs):
         self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
 
         starting_image_count = 2*len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
 
-        # last iteration Call
-
-        multiplied_weights = mult(self.total_images_per_seg[0],self.model_weights)
-
-        self.subject.compile_and_set_weights(self.model, self.compile_params,
-                                             '/cpu:0', self.serialized_weights)
-        k = {'SD': {'segment_model' :self.model, 'sess': Mock()}}
-
         state = starting_image_count
         new_state = self.subject.fit_transition(
             state, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), None, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
-            self.accessible_gpus_for_seg, self.dummy_prev_weights, True, **k)
+            self.accessible_gpus_for_seg, self.dummy_prev_weights, "todo-remove",
+            **kwargs)
 
         state = np.fromstring(new_state, dtype=np.float32)
         image_count = state[0]
-        weights = np.rint(state[1:]).astype(np.int)
+        # We need to assert that the weights should be multiplied by final image count.
+        # So we divide and round the weights so that we can compare them with
+        # the original weights. They will only be equal if the learnt weights
+        # only go up or down by 1 which seems to be the case for our unit test.
+        # There might be a better way to assert this but this is good enough for now.
+        weights = np.rint(state[1:]/self.total_images_per_seg[0]).astype(np.int)
+        self.assertTrue(((weights == self.model_weights).all()))

Review comment:
       Finally got the test to mock the call to `fit`. :)




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



Mime
View raw message