madlib-dev mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From GitBox <...@apache.org>
Subject [GitHub] [madlib] reductionista commented on a change in pull request #522: DL: Remove keras dependency
Date Tue, 01 Dec 2020 02:24:14 GMT

reductionista commented on a change in pull request #522:
URL: https://github.com/apache/madlib/pull/522#discussion_r533030552



##########
File path: src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
##########
@@ -569,6 +466,198 @@ class MadlibKerasFitTestCase(unittest.TestCase):
     def test_fit_transition_last_buffer_pass_gpdb(self):
         self._test_fit_transition_last_buffer_pass(False)
 
+    ############### GRAPH AND SESSION TESTS ################################
+    def test_fit_eval_2_iterations_mcf_null_gpdb(self):
+        kwargs = {'GD': {}}
+        GD = kwargs['GD']
+
+        ######################### fit for 2 iterations ##########
+        # iteration 1
+        first_iter_keras_sess = self._run_fit_iteration(**kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        first_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        # iteration 2 (last iteration)
+        last_iter_keras_sess = self._run_fit_iteration(**kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertEquals(first_iter_keras_sess, last_iter_keras_sess)
+        self.assertEquals(first_iter_tf_graph, last_iter_tf_graph)
+
+        ###################### eval transition for last iteration ###########
+        self._run_eval_iteration(True, last_iter_keras_sess, last_iter_tf_graph, **kwargs)
+        eval_last_iter_keras_sess = self.subject.K.get_session()
+        eval_last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertNotEquals(eval_last_iter_keras_sess, last_iter_keras_sess)
+        self.assertNotEquals(eval_last_iter_tf_graph, last_iter_tf_graph)
+        self._assert_gd_cleared(GD)
+
+    def test_fit_eval_2_iterations_mcf_1_gpdb(self):
+        kwargs = {'GD': {}}
+        GD = kwargs['GD']
+
+        ######################### fit + eval for 2 iterations ##########
+        # iteration 1 fit
+        first_iter_keras_sess = self._run_fit_iteration(**kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        first_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        # iteration 1 eval
+        self._run_eval_iteration(False, first_iter_keras_sess, first_iter_tf_graph, **kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        eval_first_iter_keras_sess = self.subject.K.get_session()
+        eval_first_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertEquals(eval_first_iter_keras_sess, first_iter_keras_sess)
+        self.assertEquals(eval_first_iter_tf_graph, first_iter_tf_graph)
+
+        # iteration 2 fit (last iteration)
+        last_iter_keras_sess = self._run_fit_iteration(**kwargs)
+        self._assert_keras_session_same_as_gd_session(GD)
+
+        last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertEquals(first_iter_keras_sess, last_iter_keras_sess)
+        self.assertEquals(first_iter_tf_graph, last_iter_tf_graph)
+
+        # iteration 2 eval (last iteration)
+        self._run_eval_iteration(True, last_iter_keras_sess, last_iter_tf_graph, **kwargs)
+
+        eval_last_iter_keras_sess = self.subject.K.get_session()
+        eval_last_iter_tf_graph = self.subject.tf.get_default_graph()
+
+        self.assertNotEquals(eval_last_iter_keras_sess, last_iter_keras_sess)
+        self.assertNotEquals(eval_last_iter_tf_graph, last_iter_tf_graph)
+        self._assert_gd_cleared(GD)
+
+    def test_fit_multiple_2_iterations(self):
+        kwargs = {'GD': {}}
+        GD = kwargs['GD']
+
+        ############ fit multiple for 2 iterations ##########
+        # iteration 1
+        # first_iter_tf_graph is used to assert that calling fit_multiple clears the tf session
+        # and graph at the last buffer.
+        # It is fetched prior to calling the fit_transition(from fit_multiple) as when we create
+        # a session inside fit_transition, instead of creating a new graph it will use first_iter_tf_graph.
+        # This enables us to do the not equals assert.
+        first_iter_tf_graph = self.subject.tf.get_default_graph()
+        first_iter_keras_sess = self._run_fit_multiple_iteration(**kwargs)
+        self._assert_gd_cleared(GD)
+
+        # iteration 2 (last iteration)
+        last_iter_tf_graph = self.subject.tf.get_default_graph()
+        last_iter_keras_sess = self._run_fit_multiple_iteration(**kwargs)
+        self._assert_gd_cleared(GD)
+
+        self.assertNotEquals(first_iter_keras_sess, last_iter_keras_sess)
+        self.assertNotEquals(first_iter_tf_graph, last_iter_tf_graph)
+
+    def test_eval_multiple_any_iteration(self):
+        # This test tests 2 things:
+        # 1. Calling eval_transition from fit_multiple
+        # 2. Calling eval_transition from evaluate directly
+        kwargs = {'GD': {}}
+        GD = kwargs['GD']
+
+        # eval_iter_tf_graph1 is used to assert that calling eval clears the tf session and graph
+        # It is fetched prior to calling the eval_transition as when we create a session inside
+        # eval_transition, instead of creating a new graph it will use eval_iter_tf_graph1.
+        # This enables us to do the not equals assert.
+        eval_iter_tf_graph1 = self.subject.tf.get_default_graph()
+        eval_iter_keras_sess1 = self._run_eval_iteration(True, None, None, True, **kwargs)
+        eval_iter_keras_sess2 = self.subject.K.get_session()
+        eval_iter_tf_graph2 = self.subject.tf.get_default_graph()
+
+        self.assertNotEquals(eval_iter_keras_sess1, eval_iter_keras_sess2)
+        self.assertNotEquals(eval_iter_tf_graph1, eval_iter_tf_graph2)
+        self._assert_gd_cleared(GD)
+
+    def _run_eval_iteration(self, final_iteration, prev_keras_sess, prev_tf_graph, called_from_fit_multiple=False, **kwargs):
+        self._test_internal_keras_eval_transition_first_buffer(final_iteration,
+                                                               **kwargs)
+        self._assert_gd_is_valid(kwargs['GD'])
+        self._assert_keras_session_same_as_gd_session(kwargs['GD'])
+
+        eval_first_buffer_keras_sess = kwargs['GD']['sess']
+        self.assertFalse(eval_first_buffer_keras_sess._closed)
+        eval_first_buffer_tf_graph = self.subject.tf.get_default_graph()
+
+        if not called_from_fit_multiple:
+            self.assertEquals(eval_first_buffer_keras_sess, prev_keras_sess)
+            self.assertEquals(eval_first_buffer_tf_graph, prev_tf_graph)
+
+        self._test_internal_keras_eval_transition_middle_buffer(final_iteration,
+                                                                **kwargs )
+        self._assert_gd_is_valid(kwargs['GD'])
+        self._assert_keras_session_same_as_gd_session(kwargs['GD'])
+        self.assertFalse(eval_first_buffer_keras_sess._closed)
+
+        self._test_internal_keras_eval_transition_last_buffer(final_iteration,
+                                                              **kwargs)
+        if final_iteration:
+            self._assert_gd_cleared(kwargs['GD'])
+            self.assertTrue(eval_first_buffer_keras_sess._closed)
+        else:
+            self._assert_gd_is_valid(kwargs['GD'])
+            self.assertFalse(eval_first_buffer_keras_sess._closed)
+        return eval_first_buffer_keras_sess
+
+    def _run_fit_iteration(self, **kwargs):
+        self._test_fit_transition_first_buffer_pass(**kwargs)
+        gd_first_buffer = kwargs['GD']
+        self._assert_gd_is_valid(gd_first_buffer)
+        iter_sess = gd_first_buffer['sess']
+        self.assertFalse(iter_sess._closed)
+        self._assert_keras_session_same_as_gd_session(gd_first_buffer)
+
+        self._test_fit_transition_middle_buffer_pass(**kwargs)
+        gd_middle_buffer = kwargs['GD']
+        self._assert_gd_is_valid(gd_middle_buffer)
+        self.assertFalse(iter_sess._closed)
+
+        self._test_fit_transition_last_buffer_pass(**kwargs)
+        gd_last_buffer = kwargs['GD']
+        self._assert_gd_is_valid(gd_last_buffer)
+        self.assertFalse(iter_sess._closed)
+        return iter_sess
+
+    def _run_fit_multiple_iteration(self, **kwargs):
+        self._test_fit_transition_multiple_model_no_cache_first_buffer_pass(**kwargs)
+        self._assert_gd_is_valid(kwargs['GD'])
+        self._assert_keras_session_same_as_gd_session(kwargs['GD'])
+        iter_sess = kwargs['GD']['sess']
+        self.assertFalse(iter_sess._closed)
+
+        self._test_fit_transition_multiple_model_no_cache_middle_buffer_pass(**kwargs)
+        self._assert_gd_is_valid(kwargs['GD'])
+        self._assert_keras_session_same_as_gd_session(kwargs['GD'])
+        self.assertFalse(iter_sess._closed)
+
+        self._test_fit_transition_multiple_model_no_cache_last_buffer_pass(**kwargs)
+        self._assert_gd_cleared(kwargs['GD'])
+        self.assertTrue(iter_sess._closed)
+        return iter_sess

Review comment:
   Oh, interesting.  Then yes, let's come back to it once the Model Hopper Refactor is merged.
   The hop in that branch always happens before the UDA runs rather than after.  Whatever
   weights get returned will always match the temporary model_output table until the next time
   run_training() gets called.  So I think it should eliminate this obstacle.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



Mime
View raw message