superset-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From maximebeauche...@apache.org
Subject [incubator-superset] branch master updated: [sql lab] improve table name detection in free form SQL (#6793)
Date Tue, 05 Feb 2019 00:03:29 GMT
This is an automated email from the ASF dual-hosted git repository.

maximebeauchemin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 5a40f71  [sql lab] improve table name detection in free form SQL (#6793)
5a40f71 is described below

commit 5a40f7171079280c1c7d452e6f1344156b24a409
Author: Maxime Beauchemin <maximebeauchemin@gmail.com>
AuthorDate: Mon Feb 4 16:03:23 2019 -0800

    [sql lab] improve table name detection in free form SQL (#6793)
    
    * [sql lab] improve table name detection in free form SQL
    
    * flake
    
    * Addressing comments
---
 superset/sql_parse.py    | 67 +++++++++++++++++++++---------------------------
 tests/sql_parse_tests.py | 42 +++++++++++++++++++++++++++++-
 2 files changed, 70 insertions(+), 39 deletions(-)

diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index 241917a..d1ad23d 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -23,7 +23,10 @@ from sqlparse.tokens import Keyword, Name
 
 RESULT_OPERATIONS = {'UNION', 'INTERSECT', 'EXCEPT', 'SELECT'}
 ON_KEYWORD = 'ON'
-PRECEDES_TABLE_NAME = {'FROM', 'JOIN', 'DESC', 'DESCRIBE', 'WITH'}
+PRECEDES_TABLE_NAME = {
+    'FROM', 'JOIN', 'DESCRIBE', 'WITH', 'LEFT JOIN', 'RIGHT JOIN',
+}
+CTE_PREFIX = 'CTE__'
 
 
 class ParsedQuery(object):
@@ -72,13 +75,6 @@ class ParsedQuery(object):
         return statements
 
     @staticmethod
-    def __precedes_table_name(token_value):
-        for keyword in PRECEDES_TABLE_NAME:
-            if keyword in token_value:
-                return True
-        return False
-
-    @staticmethod
     def __get_full_name(identifier):
         if len(identifier.tokens) > 1 and identifier.tokens[1].value == '.':
             return '{}.{}'.format(identifier.tokens[0].value,
@@ -86,20 +82,15 @@ class ParsedQuery(object):
         return identifier.get_real_name()
 
     @staticmethod
-    def __is_result_operation(keyword):
-        for operation in RESULT_OPERATIONS:
-            if operation in keyword.upper():
-                return True
-        return False
-
-    @staticmethod
     def __is_identifier(token):
         return isinstance(token, (IdentifierList, Identifier))
 
     def __process_identifier(self, identifier):
         # exclude subselects
-        if '(' not in '{}'.format(identifier):
-            self._table_names.add(self.__get_full_name(identifier))
+        if '(' not in str(identifier):
+            table_name = self.__get_full_name(identifier)
+            if not table_name.startswith(CTE_PREFIX):
+                self._table_names.add(self.__get_full_name(identifier))
             return
 
         # store aliases
@@ -129,39 +120,39 @@ class ParsedQuery(object):
         exec_sql += f'CREATE TABLE {table_name} AS \n{sql}'
         return exec_sql
 
-    def __extract_from_token(self, token):
+    def __extract_from_token(self, token, depth=0):
         if not hasattr(token, 'tokens'):
             return
 
         table_name_preceding_token = False
 
         for item in token.tokens:
+            logging.debug(('  ' * depth) + str(item.ttype) + str(item.value))
             if item.is_group and not self.__is_identifier(item):
-                self.__extract_from_token(item)
+                self.__extract_from_token(item, depth=depth + 1)
+
+            if (
+                    item.ttype in Keyword and (
+                        item.normalized in PRECEDES_TABLE_NAME or
+                        item.normalized.endswith(' JOIN')
+                    )):
+                table_name_preceding_token = True
+                continue
 
             if item.ttype in Keyword:
-                if self.__precedes_table_name(item.value.upper()):
-                    table_name_preceding_token = True
-                    continue
-
-            if not table_name_preceding_token:
+                table_name_preceding_token = False
                 continue
 
-            if item.ttype in Keyword or item.value == ',':
-                if (self.__is_result_operation(item.value) or
-                        item.value.upper() == ON_KEYWORD):
-                    table_name_preceding_token = False
-                    continue
-                # FROM clause is over
-                break
-
-            if isinstance(item, Identifier):
-                self.__process_identifier(item)
-
-            if isinstance(item, IdentifierList):
-                for token in item.tokens:
-                    if self.__is_identifier(token):
+            if table_name_preceding_token:
+                if isinstance(item, Identifier):
+                    self.__process_identifier(item)
+                elif isinstance(item, IdentifierList):
+                    for token in item.get_identifiers():
                         self.__process_identifier(token)
+            elif isinstance(item, IdentifierList):
+                for token in item.tokens:
+                    if not self.__is_identifier(token):
+                        self.__extract_from_token(item, depth=depth + 1)
 
     def _get_limit_from_token(self, token):
         if token.ttype == sqlparse.tokens.Literal.Number.Integer:
diff --git a/tests/sql_parse_tests.py b/tests/sql_parse_tests.py
index 9247780..e821fce 100644
--- a/tests/sql_parse_tests.py
+++ b/tests/sql_parse_tests.py
@@ -167,7 +167,6 @@ class SupersetTestCase(unittest.TestCase):
     # DESCRIBE | DESC qualifiedName
     def test_describe(self):
         self.assertEquals({'t1'}, self.extract_tables('DESCRIBE t1'))
-        self.assertEquals({'t1'}, self.extract_tables('DESC t1'))
 
     # SHOW PARTITIONS FROM qualifiedName (WHERE booleanExpression)?
     # (ORDER BY sortItem (',' sortItem)*)? (LIMIT limit=(INTEGER_VALUE | ALL))?
@@ -349,6 +348,32 @@ class SupersetTestCase(unittest.TestCase):
             {'table_a', 'table_b', 'table_c'},
             self.extract_tables(query))
 
+    def test_mixed_from_clause(self):
+        query = """SELECT *
+            FROM table_a AS a, (select * from table_b) AS b, table_c as c
+            WHERE a.id = b.id and b.id = c.id"""
+        self.assertEquals(
+            {'table_a', 'table_b', 'table_c'},
+            self.extract_tables(query))
+
+    def test_nested_selects(self):
+        query = """
+            select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
+            from INFORMATION_SCHEMA.COLUMNS
+            WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
+        """
+        self.assertEquals(
+            {'INFORMATION_SCHEMA.COLUMNS'},
+            self.extract_tables(query))
+        query = """
+            select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
+            from INFORMATION_SCHEMA.COLUMNS
+            WHERE TABLE_NAME="bi_achivement_daily"),0x7e)));
+        """
+        self.assertEquals(
+            {'INFORMATION_SCHEMA.COLUMNS'},
+            self.extract_tables(query))
+
     def test_complex_extract_tables3(self):
         query = """SELECT somecol AS somecol
             FROM
@@ -386,6 +411,21 @@ class SupersetTestCase(unittest.TestCase):
             {'a', 'b', 'c', 'd', 'e', 'f'},
             self.extract_tables(query))
 
+    def test_complex_cte_with_prefix(self):
+        query = """
+        WITH CTE__test (SalesPersonID, SalesOrderID, SalesYear)
+        AS (
+            SELECT SalesPersonID, SalesOrderID, YEAR(OrderDate) AS SalesYear
+            FROM SalesOrderHeader
+            WHERE SalesPersonID IS NOT NULL
+        )
+        SELECT SalesPersonID, COUNT(SalesOrderID) AS TotalSales, SalesYear
+        FROM CTE__test
+        GROUP BY SalesYear, SalesPersonID
+        ORDER BY SalesPersonID, SalesYear;
+        """
+        self.assertEquals({'SalesOrderHeader'}, self.extract_tables(query))
+
     def test_basic_breakdown_statements(self):
         multi_sql = """
         SELECT * FROM ab_user;


Mime
View raw message