airflow-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From bo...@apache.org
Subject incubator-airflow git commit: [AIRFLOW-1168] Add closing() to all connections and cursors
Date Fri, 12 May 2017 09:26:44 GMT
Repository: incubator-airflow
Updated Branches:
  refs/heads/master 443e6b295 -> 8aeebd488


[AIRFLOW-1168] Add closing() to all connections and cursors

This prevents connections and cursors from being left open
whenever an exception occurs.

Closes #2269 from NielsZeilemaker/AIRFLOW-1168


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/8aeebd48
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/8aeebd48
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/8aeebd48

Branch: refs/heads/master
Commit: 8aeebd488416bd7618d36c64c49eca58f3f45e0d
Parents: 443e6b2
Author: Niels Zeilemaker <nielszeilemaker@godatadriven.com>
Authored: Fri May 12 11:26:29 2017 +0200
Committer: Bolke de Bruin <bolke@xs4all.nl>
Committed: Fri May 12 11:26:29 2017 +0200

----------------------------------------------------------------------
 airflow/hooks/dbapi_hook.py    | 125 +++++++++++++++++-------------------
 tests/hooks/test_dbapi_hook.py |  76 ++++++++++++++++++++++
 2 files changed, 136 insertions(+), 65 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/8aeebd48/airflow/hooks/dbapi_hook.py
----------------------------------------------------------------------
diff --git a/airflow/hooks/dbapi_hook.py b/airflow/hooks/dbapi_hook.py
index df52e54..75ca409 100644
--- a/airflow/hooks/dbapi_hook.py
+++ b/airflow/hooks/dbapi_hook.py
@@ -15,6 +15,7 @@
 from builtins import str
 from past.builtins import basestring
 from datetime import datetime
+from contextlib import closing
 import numpy
 import logging
 import sys
@@ -87,10 +88,9 @@ class DbApiHook(BaseHook):
         if sys.version_info[0] < 3:
             sql = sql.encode('utf-8')
         import pandas.io.sql as psql
-        conn = self.get_conn()
-        df = psql.read_sql(sql, con=conn, params=parameters)
-        conn.close()
-        return df
+        
+        with closing(self.get_conn()) as conn:        
+            return psql.read_sql(sql, con=conn, params=parameters)
 
     def get_records(self, sql, parameters=None):
         """
@@ -104,16 +104,14 @@ class DbApiHook(BaseHook):
         """
         if sys.version_info[0] < 3:
             sql = sql.encode('utf-8')
-        conn = self.get_conn()
-        cur = self.get_cursor()
-        if parameters is not None:
-            cur.execute(sql, parameters)
-        else:
-            cur.execute(sql)
-        rows = cur.fetchall()
-        cur.close()
-        conn.close()
-        return rows
+            
+        with closing(self.get_conn()) as conn:
+            with closing(conn.cursor()) as cur:
+                if parameters is not None:
+                    cur.execute(sql, parameters)
+                else:
+                    cur.execute(sql)
+                return cur.fetchall()
 
     def get_first(self, sql, parameters=None):
         """
@@ -127,16 +125,14 @@ class DbApiHook(BaseHook):
         """
         if sys.version_info[0] < 3:
             sql = sql.encode('utf-8')
-        conn = self.get_conn()
-        cur = conn.cursor()
-        if parameters is not None:
-            cur.execute(sql, parameters)
-        else:
-            cur.execute(sql)
-        rows = cur.fetchone()
-        cur.close()
-        conn.close()
-        return rows
+        
+        with closing(self.get_conn()) as conn:
+            with closing(conn.cursor()) as cur:
+                if parameters is not None:
+                    cur.execute(sql, parameters)
+                else:
+                    cur.execute(sql)
+                return cur.fetchone()
 
     def run(self, sql, autocommit=False, parameters=None):
         """
@@ -153,25 +149,24 @@ class DbApiHook(BaseHook):
         :param parameters: The parameters to render the SQL query with.
         :type parameters: mapping or iterable
         """
-        conn = self.get_conn()
         if isinstance(sql, basestring):
             sql = [sql]
-
-        if self.supports_autocommit:
-            self.set_autocommit(conn, autocommit)
-
-        cur = conn.cursor()
-        for s in sql:
-            if sys.version_info[0] < 3:
-                s = s.encode('utf-8')
-            logging.info(s)
-            if parameters is not None:
-                cur.execute(s, parameters)
-            else:
-                cur.execute(s)
-        cur.close()
-        conn.commit()
-        conn.close()
+        
+        with closing(self.get_conn()) as conn:
+            if self.supports_autocommit:
+                self.set_autocommit(conn, autocommit)
+            
+            with closing(conn.cursor()) as cur:
+                for s in sql:
+                    if sys.version_info[0] < 3:
+                        s = s.encode('utf-8')
+                    logging.info(s)
+                    if parameters is not None:
+                        cur.execute(s, parameters)
+                    else:
+                        cur.execute(s)
+            
+            conn.commit()
 
     def set_autocommit(self, conn, autocommit):
         conn.autocommit = autocommit
@@ -202,30 +197,30 @@ class DbApiHook(BaseHook):
             target_fields = "({})".format(target_fields)
         else:
             target_fields = ''
-        conn = self.get_conn()
-        if self.supports_autocommit:
-            self.set_autocommit(conn, False)
-        conn.commit()
-        cur = conn.cursor()
-        i = 0
-        for row in rows:
-            i += 1
-            l = []
-            for cell in row:
-                l.append(self._serialize_cell(cell, conn))
-            values = tuple(l)
-            sql = "INSERT INTO {0} {1} VALUES ({2});".format(
-                table,
-                target_fields,
-                ",".join(values))
-            cur.execute(sql)
-            if commit_every and i % commit_every == 0:
-                conn.commit()
-                logging.info(
-                    "Loaded {i} into {table} rows so far".format(**locals()))
-        conn.commit()
-        cur.close()
-        conn.close()
+            
+        with closing(self.get_conn()) as conn:
+            if self.supports_autocommit:
+                self.set_autocommit(conn, False)
+            
+            conn.commit()
+            
+            with closing(conn.cursor()) as cur:
+                for i, row in enumerate(rows, 1):
+                    l = []
+                    for cell in row:
+                        l.append(self._serialize_cell(cell, conn))
+                    values = tuple(l)
+                    sql = "INSERT INTO {0} {1} VALUES ({2});".format(
+                        table,
+                        target_fields,
+                        ",".join(values))
+                    cur.execute(sql)
+                    if commit_every and i % commit_every == 0:
+                        conn.commit()
+                        logging.info(
+                            "Loaded {i} into {table} rows so far".format(**locals()))
+            
+            conn.commit()
         logging.info(
             "Done loading. Loaded a total of {i} rows".format(**locals()))
 

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/8aeebd48/tests/hooks/test_dbapi_hook.py
----------------------------------------------------------------------
diff --git a/tests/hooks/test_dbapi_hook.py b/tests/hooks/test_dbapi_hook.py
new file mode 100644
index 0000000..f9e7b37
--- /dev/null
+++ b/tests/hooks/test_dbapi_hook.py
@@ -0,0 +1,76 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import mock
+import unittest
+
+from airflow.hooks.dbapi_hook import DbApiHook
+
+
+class TestDbApiHook(unittest.TestCase):
+
+    def setUp(self):
+        super(TestDbApiHook, self).setUp()
+        
+        self.cur = mock.MagicMock()
+        self.conn = conn = mock.MagicMock()
+        self.conn.cursor.return_value = self.cur
+        
+        class TestDBApiHook(DbApiHook):
+            conn_name_attr = 'test_conn_id'
+            
+            def get_conn(self):
+                return conn
+        
+        self.db_hook = TestDBApiHook()
+
+    def test_get_records(self):
+        statement = "SQL"
+        rows = [("hello",),
+                ("world",)]
+        
+        self.cur.fetchall.return_value = rows
+        
+        self.assertEqual(rows, self.db_hook.get_records(statement))
+        
+        self.conn.close.assert_called_once()
+        self.cur.close.assert_called_once()
+        self.cur.execute.assert_called_once_with(statement)
+        
+    def test_get_records_parameters(self):
+        statement = "SQL"
+        parameters = ["X", "Y", "Z"]
+        rows = [("hello",),
+                ("world",)]
+        
+        self.cur.fetchall.return_value = rows
+        
+        
+        self.assertEqual(rows, self.db_hook.get_records(statement, parameters))
+        
+        self.conn.close.assert_called_once()
+        self.cur.close.assert_called_once()
+        self.cur.execute.assert_called_once_with(statement, parameters)
+        
+    def test_get_records_exception(self):
+        statement = "SQL"
+        self.cur.fetchall.side_effect = RuntimeError('Great Problems')
+        
+        with self.assertRaises(RuntimeError):
+            self.db_hook.get_records(statement)
+        
+        self.conn.close.assert_called_once()
+        self.cur.close.assert_called_once()
+        self.cur.execute.assert_called_once_with(statement)
\ No newline at end of file


Mime
View raw message