madlib-dev mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From GitBox <...@apache.org>
Subject [GitHub] [madlib] khannaekta commented on a change in pull request #522: DL: Remove keras dependency
Date Tue, 01 Dec 2020 01:23:34 GMT

khannaekta commented on a change in pull request #522:
URL: https://github.com/apache/madlib/pull/522#discussion_r533011189



##########
File path: src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
##########
@@ -569,6 +466,198 @@ class MadlibKerasFitTestCase(unittest.TestCase):
     def test_fit_transition_last_buffer_pass_gpdb(self):
         self._test_fit_transition_last_buffer_pass(False)
 
+    ############### GRAPH AND SESSION TESTS ################################
+    def test_fit_eval_2_iterations_mcf_null_gpdb(self):
+        kwargs = {'GD': {}}
+        GD = kwargs['GD']
+
+        ######################### fit for 2 iterations ##########
+        # iteration 1
+        first_iter_keras_sess = self._run_fit_iteration(**kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        first_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        # iteration 2 (last iteration)
+        last_iter_keras_sess = self._run_fit_iteration(**kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertEquals(first_iter_keras_sess, last_iter_keras_sess)
+        self.assertEquals(first_iter_tf_graph, last_iter_tf_graph)
+
+        ###################### eval transition for last iteration ###########
+        self._run_eval_iteration(True, last_iter_keras_sess, last_iter_tf_graph, **kwargs)
+        eval_last_iter_keras_sess = self.subject.K.get_session()
+        eval_last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertNotEquals(eval_last_iter_keras_sess, last_iter_keras_sess)
+        self.assertNotEquals(eval_last_iter_tf_graph, last_iter_tf_graph)
+        self._assert_gd_cleared(GD)
+
+    def test_fit_eval_2_iterations_mcf_1_gpdb(self):
+        kwargs = {'GD': {}}
+        GD = kwargs['GD']
+
+        ######################### fit + eval for 2 iterations ##########
+        # iteration 1 fit
+        first_iter_keras_sess = self._run_fit_iteration(**kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        first_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        # iteration 1 eval
+        self._run_eval_iteration(False, first_iter_keras_sess, first_iter_tf_graph, **kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        eval_first_iter_keras_sess = self.subject.K.get_session()
+        eval_first_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertEquals(eval_first_iter_keras_sess, first_iter_keras_sess)
+        self.assertEquals(eval_first_iter_tf_graph, first_iter_tf_graph)
+
+        # iteration 2 fit (last iteration)
+        last_iter_keras_sess = self._run_fit_iteration(**kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertEquals(first_iter_keras_sess, last_iter_keras_sess)
+        self.assertEquals(first_iter_tf_graph, last_iter_tf_graph)
+
+        # iteration 2 eval (last iteration)
+        self._run_eval_iteration(True, last_iter_keras_sess, last_iter_tf_graph, **kwargs)
+
+        eval_last_iter_keras_sess = self.subject.K.get_session()
+        eval_last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertNotEquals(eval_last_iter_keras_sess, last_iter_keras_sess)
+        self.assertNotEquals(eval_last_iter_tf_graph, last_iter_tf_graph)
+        self._assert_gd_cleared(GD)
+
+    def test_fit_multiple_2_iterations(self):
+        kwargs = {'GD': {}}
+        GD = kwargs['GD']
+
+        ############ fit multiple for 2 iterations ##########
+        # iteration 1
+        # first_iter_tf_graph is used to assert that calling fit_multiple clears the tf session
+        # and graph at the last buffer.
+        # It is fetched prior to calling the fit_transition(from fit_multiple) as when we create
+        # a session inside fit_transition, instead of creating a new graph it will use first_iter_tf_graph.
+        # This enables us to do the not equals assert.
+        first_iter_tf_graph = self.subject.tf.get_default_graph()
+        first_iter_keras_sess = self._run_fit_multiple_iteration(**kwargs)
+        self._assert_gd_cleared(GD)
+
+        # iteration 2 (last iteration)
+        last_iter_tf_graph = self.subject.tf.get_default_graph()
+        last_iter_keras_sess = self._run_fit_multiple_iteration(**kwargs)
+        self._assert_gd_cleared(GD)
+
+        self.assertNotEquals(first_iter_keras_sess, last_iter_keras_sess)
+        self.assertNotEquals(first_iter_tf_graph, last_iter_tf_graph)
+
+    def test_eval_multiple_any_iteration(self):
+        # This test tests 2 things:
+        # 1. Calling eval_transition from fit_multiple
+        # 2. Calling eval_transition from evaluate directly
+        kwargs = {'GD': {}}
+        GD = kwargs['GD']
+
+        # eval_iter_tf_graph1 is used to assert that calling eval clears the tf session and graph
+        # It is fetched prior to calling the eval_transition as when we create a session inside
+        # eval_transition, instead of creating a new graph it will use eval_iter_tf_graph1.
+        # This enables us to do the not equals assert.
+        eval_iter_tf_graph1 = self.subject.tf.get_default_graph()
+        eval_iter_keras_sess1 = self._run_eval_iteration(True, None, None, True, **kwargs)
+        eval_iter_keras_sess2 = self.subject.K.get_session()
+        eval_iter_tf_graph2 = self.subject.tf.get_default_graph()
+
+        self.assertNotEquals(eval_iter_keras_sess1, eval_iter_keras_sess2)
+        self.assertNotEquals(eval_iter_tf_graph1, eval_iter_tf_graph2)
+        self._assert_gd_cleared(GD)
+
+    def _run_eval_iteration(self, final_iteration, prev_keras_sess, prev_tf_graph, called_from_fit_multiple=False, **kwargs):
+        self._test_internal_keras_eval_transition_first_buffer(final_iteration,
+                                                               **kwargs)
+        self._assert_gd_is_valid(kwargs['GD'])
+        self._assert_keras_session_same_as_gd_session(kwargs['GD'])
+
+        eval_first_buffer_keras_sess = kwargs['GD']['sess']
+        self.assertFalse(eval_first_buffer_keras_sess._closed)
+        eval_first_buffer_tf_graph = self.subject.tf.get_default_graph()
+
+        if not called_from_fit_multiple:
+            self.assertEquals(eval_first_buffer_keras_sess, prev_keras_sess)
+            self.assertEquals(eval_first_buffer_tf_graph, prev_tf_graph)
+
+        self._test_internal_keras_eval_transition_middle_buffer(final_iteration,
+                                                                **kwargs )
+        self._assert_gd_is_valid(kwargs['GD'])
+        self._assert_keras_session_same_as_gd_session(kwargs['GD'])
+        self.assertFalse(eval_first_buffer_keras_sess._closed)
+
+        self._test_internal_keras_eval_transition_last_buffer(final_iteration,
+                                                              **kwargs)
+        if final_iteration:
+            self._assert_gd_cleared(kwargs['GD'])
+            self.assertTrue(eval_first_buffer_keras_sess._closed)
+        else:
+            self._assert_gd_is_valid(kwargs['GD'])
+            self.assertFalse(eval_first_buffer_keras_sess._closed)
+        return eval_first_buffer_keras_sess
+
+    def _run_fit_iteration(self, **kwargs):
+        self._test_fit_transition_first_buffer_pass(**kwargs)
+        gd_first_buffer = kwargs['GD']
+        self._assert_gd_is_valid(gd_first_buffer)
+        iter_sess = gd_first_buffer['sess']
+        self.assertFalse(iter_sess._closed)
+        self._assert_keras_session_same_as_gd_session(gd_first_buffer)
+
+        self._test_fit_transition_middle_buffer_pass(**kwargs)
+        gd_middle_buffer = kwargs['GD']
+        self._assert_gd_is_valid(gd_middle_buffer)
+        self.assertFalse(iter_sess._closed)
+
+        self._test_fit_transition_last_buffer_pass(**kwargs)
+        gd_last_buffer = kwargs['GD']
+        self._assert_gd_is_valid(gd_last_buffer)
+        self.assertFalse(iter_sess._closed)
+        return iter_sess
+
+    def _run_fit_multiple_iteration(self, **kwargs):
+        self._test_fit_transition_multiple_model_no_cache_first_buffer_pass(**kwargs)
+        self._assert_gd_is_valid(kwargs['GD'])
+        self._assert_keras_session_same_as_gd_session(kwargs['GD'])
+        iter_sess = kwargs['GD']['sess']
+        self.assertFalse(iter_sess._closed)
+
+        self._test_fit_transition_multiple_model_no_cache_middle_buffer_pass(**kwargs)
+        self._assert_gd_is_valid(kwargs['GD'])
+        self._assert_keras_session_same_as_gd_session(kwargs['GD'])
+        self.assertFalse(iter_sess._closed)
+
+        self._test_fit_transition_multiple_model_no_cache_last_buffer_pass(**kwargs)
+        self._assert_gd_cleared(kwargs['GD'])
+        self.assertTrue(iter_sess._closed)
+        return iter_sess

Review comment:
       Agreed, we should re-evaluate this when we move to 2.x.
   Currently we hop the model and then call evaluate (this behavior will change as part of
https://github.com/apache/madlib/pull/525). Evaluate fetches the session/model from GD but
reads the weights from the output table, so it might fail because the weights and the model
would not belong to the same model.
   We can probably come back to this after merging the Model Hopper Refactor PR (https://github.com/apache/madlib/pull/525).
   




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



Mime
View raw message