superset-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From maximebeauche...@apache.org
Subject [incubator-superset] branch master updated: Refactor dataframe and column name mutation logic (#6847)
Date Thu, 21 Feb 2019 07:05:41 GMT
This is an automated email from the ASF dual-hosted git repository.

maximebeauchemin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new f5277fe  Refactor dataframe and column name mutation logic (#6847)
f5277fe is described below

commit f5277fe6843be27e639b954e0a37a680faf22374
Author: Ville Brofeldt <33317356+villebro@users.noreply.github.com>
AuthorDate: Thu Feb 21 09:05:35 2019 +0200

    Refactor dataframe and column name mutation logic (#6847)
    
    * Merge dataframe and column name mutation logic, add flag for disabling column aliases
      and add column name length checking
    
    * Remove custom mutate_label from Oracle spec
    
    * Move hashing from mutate_label() to make_label_compatible()
    
    * Remove empty line
    
    * Make label mutating and truncating more robust
    
    * Rename variables and make proposed changes from review
    
    * Always execute labels_expected codepath
    
    * Fix linting error
    
    * Add comments and fix subquery errors
    
    * Refine column compatibility
    
    * Simplify label assignment
    
    * Add unit tests for BQ and Oracle
    
    * Linting
---
 superset/connectors/sqla/models.py |  87 +++++++++++++-------------
 superset/db_engine_specs.py        | 125 ++++++++++++++++++++-----------------
 tests/db_engine_specs_test.py      |  29 ++++++++-
 3 files changed, 137 insertions(+), 104 deletions(-)

diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index ff2821b..8183d7c 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -116,14 +116,14 @@ class TableColumn(Model, BaseColumn):
     export_parent = 'table'
 
     def get_sqla_col(self, label=None):
-        label = label if label else self.column_name
-        label = self.table.get_label(label)
+        label = label or self.column_name
         if not self.expression:
             db_engine_spec = self.table.database.db_engine_spec
             type_ = db_engine_spec.get_sqla_column_type(self.type)
-            col = column(self.column_name, type_=type_).label(label)
+            col = column(self.column_name, type_=type_)
         else:
-            col = literal_column(self.expression).label(label)
+            col = literal_column(self.expression)
+        col = self.table.make_sqla_column_compatible(col, label)
         return col
 
     @property
@@ -142,13 +142,14 @@ class TableColumn(Model, BaseColumn):
 
     def get_timestamp_expression(self, time_grain):
         """Getting the time component of the query"""
-        label = self.table.get_label(utils.DTTM_ALIAS)
+        label = utils.DTTM_ALIAS
 
         db = self.table.database
         pdf = self.python_date_format
         is_epoch = pdf in ('epoch_s', 'epoch_ms')
         if not self.expression and not time_grain and not is_epoch:
-            return column(self.column_name, type_=DateTime).label(label)
+            sqla_col = column(self.column_name, type_=DateTime)
+            return self.table.make_sqla_column_compatible(sqla_col, label)
         grain = None
         if time_grain:
             grain = db.grains_dict().get(time_grain)
@@ -158,7 +159,8 @@ class TableColumn(Model, BaseColumn):
         expr = db.db_engine_spec.get_time_expr(
             self.expression or self.column_name,
             pdf, time_grain, grain)
-        return literal_column(expr, type_=DateTime).label(label)
+        sqla_col = literal_column(expr, type_=DateTime)
+        return self.table.make_sqla_column_compatible(sqla_col, label)
 
     @classmethod
     def import_obj(cls, i_column):
@@ -218,9 +220,9 @@ class SqlMetric(Model, BaseMetric):
     export_parent = 'table'
 
     def get_sqla_col(self, label=None):
-        label = label if label else self.metric_name
-        label = self.table.get_label(label)
-        return literal_column(self.expression).label(label)
+        label = label or self.metric_name
+        sqla_col = literal_column(self.expression)
+        return self.table.make_sqla_column_compatible(sqla_col, label)
 
     @property
     def perm(self):
@@ -298,20 +300,19 @@ class SqlaTable(Model, BaseDatasource):
         'MAX': sa.func.MAX,
     }
 
-    def get_label(self, label):
-        """Conditionally mutate a label to conform to db engine requirements
-        and store mapping from mutated label to original label
-
-        :param label: original label
-        :return: Either a string or sqlalchemy.sql.elements.quoted_name if required
-        by db engine
+    def make_sqla_column_compatible(self, sqla_col, label=None):
+        """Takes a sql alchemy column object and adds label info if supported by engine.
+        :param sqla_col: sql alchemy column instance
+        :param label: alias/label that column is expected to have
+        :return: either a sql alchemy column or label instance if supported by engine
         """
+        label_expected = label or sqla_col.name
         db_engine_spec = self.database.db_engine_spec
-        sqla_label = db_engine_spec.make_label_compatible(label)
-        mutated_label = str(sqla_label)
-        if label != mutated_label:
-            self.mutated_labels[mutated_label] = label
-        return sqla_label
+        if db_engine_spec.supports_column_aliases:
+            label = db_engine_spec.make_label_compatible(label_expected)
+            sqla_col = sqla_col.label(label)
+        sqla_col._df_label_expected = label_expected
+        return sqla_col
 
     def __repr__(self):
         return self.name
@@ -517,7 +518,6 @@ class SqlaTable(Model, BaseDatasource):
         """
         expression_type = metric.get('expressionType')
         label = utils.get_metric_name(metric)
-        label = self.get_label(label)
 
         if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']:
             column_name = metric.get('column').get('column_name')
@@ -527,15 +527,13 @@ class SqlaTable(Model, BaseDatasource):
             else:
                 sqla_column = column(column_name)
             sqla_metric = self.sqla_aggregations[metric.get('aggregate')](sqla_column)
-            sqla_metric = sqla_metric.label(label)
-            return sqla_metric
         elif expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SQL']:
             sqla_metric = literal_column(metric.get('sqlExpression'))
-            sqla_metric = sqla_metric.label(label)
-            return sqla_metric
         else:
             return None
 
+        return self.make_sqla_column_compatible(sqla_metric, label)
+
     def get_sqla_query(  # sqla
             self,
             groupby, metrics,
@@ -569,9 +567,6 @@ class SqlaTable(Model, BaseDatasource):
         template_processor = self.get_template_processor(**template_kwargs)
         db_engine_spec = self.database.db_engine_spec
 
-        # Initialize empty cache to store mutated labels
-        self.mutated_labels = {}
-
         orderby = orderby or []
 
         # For backward compatibility
@@ -601,8 +596,8 @@ class SqlaTable(Model, BaseDatasource):
         if metrics_exprs:
             main_metric_expr = metrics_exprs[0]
         else:
-            label = self.get_label('ccount')
-            main_metric_expr = literal_column('COUNT(*)').label(label)
+            main_metric_expr, label = literal_column('COUNT(*)'), 'ccount'
+            main_metric_expr = self.make_sqla_column_compatible(main_metric_expr, label)
 
         select_exprs = []
         groupby_exprs_sans_timestamp = OrderedDict()
@@ -613,14 +608,16 @@ class SqlaTable(Model, BaseDatasource):
                 if s in cols:
                     outer = cols[s].get_sqla_col()
                 else:
-                    outer = literal_column(f'({s})').label(self.get_label(s))
+                    outer = literal_column(f'({s})')
+                    outer = self.make_sqla_column_compatible(outer, s)
 
                 groupby_exprs_sans_timestamp[outer.name] = outer
                 select_exprs.append(outer)
         elif columns:
             for s in columns:
                 select_exprs.append(
-                    cols[s].get_sqla_col() if s in cols else literal_column(s))
+                    cols[s].get_sqla_col() if s in cols else
+                    self.make_sqla_column_compatible(literal_column(s)))
             metrics_exprs = []
 
         groupby_exprs_with_timestamp = OrderedDict(groupby_exprs_sans_timestamp.items())
@@ -644,7 +641,7 @@ class SqlaTable(Model, BaseDatasource):
 
         select_exprs += metrics_exprs
 
-        labels_expected = [str(c.name) for c in select_exprs]
+        labels_expected = [c._df_label_expected for c in select_exprs]
 
         select_exprs = db_engine_spec.make_select_compatible(
             groupby_exprs_with_timestamp.values(),
@@ -732,12 +729,12 @@ class SqlaTable(Model, BaseDatasource):
                 # some sql dialects require for order by expressions
                 # to also be in the select clause -- others, e.g. vertica,
                 # require a unique inner alias
-                label = self.get_label('mme_inner__')
-                inner_main_metric_expr = main_metric_expr.label(label)
+                inner_main_metric_expr = self.make_sqla_column_compatible(
+                    main_metric_expr, 'mme_inner__')
                 inner_groupby_exprs = []
                 inner_select_exprs = []
                 for gby_name, gby_obj in groupby_exprs_sans_timestamp.items():
-                    inner = gby_obj.label(gby_name + '__')
+                    inner = self.make_sqla_column_compatible(gby_obj, gby_name + '__')
                     inner_groupby_exprs.append(inner)
                     inner_select_exprs.append(inner)
 
@@ -766,7 +763,7 @@ class SqlaTable(Model, BaseDatasource):
                     # in this case the column name, not the alias, needs to be
                     # conditionally mutated, as it refers to the column alias in
                     # the inner query
-                    col_name = self.get_label(gby_name + '__')
+                    col_name = db_engine_spec.make_label_compatible(gby_name + '__')
                     on_clause.append(gby_obj == column(col_name))
 
                 tbl = tbl.join(subq.alias(), and_(*on_clause))
@@ -841,15 +838,19 @@ class SqlaTable(Model, BaseDatasource):
         status = utils.QueryStatus.SUCCESS
         error_message = None
         df = None
-        db_engine_spec = self.database.db_engine_spec
         try:
             df = self.database.get_df(sql, self.schema)
-            if self.mutated_labels:
-                df = df.rename(index=str, columns=self.mutated_labels)
-            db_engine_spec.mutate_df_columns(df, sql, query_str_ext.labels_expected)
+            labels_expected = query_str_ext.labels_expected
+            if df is not None and not df.empty:
+                if len(df.columns) != len(labels_expected):
+                    raise Exception(f'For {sql}, df.columns: {df.columns}'
+                                    f' differs from {labels_expected}')
+                else:
+                    df.columns = labels_expected
         except Exception as e:
             status = utils.QueryStatus.FAILED
             logging.exception(f'Query {sql} on schema {self.schema} failed')
+            db_engine_spec = self.database.db_engine_spec
             error_message = db_engine_spec.extract_error_message(e)
 
         # if this is a main query with prequeries, combine them together
diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py
index 40d1ea5..0785082 100644
--- a/superset/db_engine_specs.py
+++ b/superset/db_engine_specs.py
@@ -111,8 +111,10 @@ class BaseEngineSpec(object):
     time_secondary_columns = False
     inner_joins = True
     allows_subquery = True
+    supports_column_aliases = True
     force_column_alias_quotes = False
     arraysize = None
+    max_column_name_length = None
 
     @classmethod
     def get_time_expr(cls, expr, pdf, time_grain, grain):
@@ -143,10 +145,6 @@ class BaseEngineSpec(object):
         return select_exprs
 
     @classmethod
-    def mutate_df_columns(cls, df, sql, labels_expected):
-        pass
-
-    @classmethod
     def fetch_data(cls, cursor, limit):
         if cls.arraysize:
             cursor.arraysize = cls.arraysize
@@ -287,6 +285,8 @@ class BaseEngineSpec(object):
                     schema=schema, force=True,
                     cache=db.table_cache_enabled,
                     cache_timeout=db.table_cache_timeout)
+            else:
+                raise Exception(f'Unsupported datasource_type: {datasource_type}')
             all_result_sets += [
                 '{}.{}'.format(schema, t) for t in all_datasource_names]
         return all_result_sets
@@ -418,10 +418,15 @@ class BaseEngineSpec(object):
         force_column_alias_quotes is set to True, return the label as a
         sqlalchemy.sql.elements.quoted_name object to ensure that the select query
         and query results have same case. Otherwise return the mutated label as a
-        regular string.
+        regular string. If maxmimum supported column name length is exceeded,
+        generate a truncated label by calling truncate_label().
         """
-        label = cls.mutate_label(label)
-        return quoted_name(label, True) if cls.force_column_alias_quotes else label
+        label_mutated = cls.mutate_label(label)
+        if cls.max_column_name_length and len(label_mutated) > cls.max_column_name_length:
+            label_mutated = cls.truncate_label(label)
+        if cls.force_column_alias_quotes:
+            label_mutated = quoted_name(label_mutated, True)
+        return label_mutated
 
     @classmethod
     def get_sqla_column_type(cls, type_):
@@ -445,6 +450,19 @@ class BaseEngineSpec(object):
         """
         return label
 
+    @classmethod
+    def truncate_label(cls, label):
+        """
+        In the case that a label exceeds the max length supported by the engine,
+        this method is used to construct a deterministic and unique label based on
+        an md5 hash.
+        """
+        label = hashlib.md5(label.encode('utf-8')).hexdigest()
+        # truncate hash if it exceeds max length
+        if cls.max_column_name_length and len(label) > cls.max_column_name_length:
+            label = label[:cls.max_column_name_length]
+        return label
+
 
 class PostgresBaseEngineSpec(BaseEngineSpec):
     """ Abstract class for Postgres 'like' databases """
@@ -482,6 +500,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
 
 class PostgresEngineSpec(PostgresBaseEngineSpec):
     engine = 'postgresql'
+    max_column_name_length = 63
 
     @classmethod
     def get_table_names(cls, inspector, schema):
@@ -494,6 +513,7 @@ class PostgresEngineSpec(PostgresBaseEngineSpec):
 class SnowflakeEngineSpec(PostgresBaseEngineSpec):
     engine = 'snowflake'
     force_column_alias_quotes = True
+    max_column_name_length = 256
 
     time_grain_functions = {
         None: '{col}',
@@ -531,6 +551,7 @@ class VerticaEngineSpec(PostgresBaseEngineSpec):
 
 class RedshiftEngineSpec(PostgresBaseEngineSpec):
     engine = 'redshift'
+    max_column_name_length = 127
 
     @staticmethod
     def mutate_label(label):
@@ -546,6 +567,7 @@ class OracleEngineSpec(PostgresBaseEngineSpec):
     engine = 'oracle'
     limit_method = LimitMethod.WRAP_SQL
     force_column_alias_quotes = True
+    max_column_name_length = 30
 
     time_grain_functions = {
         None: '{col}',
@@ -565,25 +587,12 @@ class OracleEngineSpec(PostgresBaseEngineSpec):
             """TO_TIMESTAMP('{}', 'YYYY-MM-DD"T"HH24:MI:SS.ff6')"""
         ).format(dttm.isoformat())
 
-    @staticmethod
-    def mutate_label(label):
-        """
-        Oracle 12.1 and earlier support a maximum of 30 byte length object names, which
-        usually means 30 characters.
-        :param str label: Original label which might include unsupported characters
-        :return: String that is supported by the database
-        """
-        if len(label) > 30:
-            hashed_label = hashlib.md5(label.encode('utf-8')).hexdigest()
-            # truncate the hash to first 30 characters
-            return hashed_label[:30]
-        return label
-
 
 class Db2EngineSpec(BaseEngineSpec):
     engine = 'ibm_db_sa'
     limit_method = LimitMethod.WRAP_SQL
     force_column_alias_quotes = True
+    max_column_name_length = 30
 
     time_grain_functions = {
         None: '{col}',
@@ -618,20 +627,6 @@ class Db2EngineSpec(BaseEngineSpec):
     def convert_dttm(cls, target_type, dttm):
         return "'{}'".format(dttm.strftime('%Y-%m-%d-%H.%M.%S'))
 
-    @staticmethod
-    def mutate_label(label):
-        """
-        Db2 for z/OS supports a maximum of 30 byte length object names, which usually
-        means 30 characters.
-        :param str label: Original label which might include unsupported characters
-        :return: String that is supported by the database
-        """
-        if len(label) > 30:
-            hashed_label = hashlib.md5(label.encode('utf-8')).hexdigest()
-            # truncate the hash to first 30 characters
-            return hashed_label[:30]
-        return label
-
 
 class SqliteEngineSpec(BaseEngineSpec):
     engine = 'sqlite'
@@ -668,6 +663,9 @@ class SqliteEngineSpec(BaseEngineSpec):
                 schema=schema, force=True,
                 cache=db.table_cache_enabled,
                 cache_timeout=db.table_cache_timeout)
+        else:
+            raise Exception(f'Unsupported datasource_type: {datasource_type}')
+
         all_result_sets += [
             '{}.{}'.format(schema, t) for t in all_datasource_names]
         return all_result_sets
@@ -687,6 +685,7 @@ class SqliteEngineSpec(BaseEngineSpec):
 
 class MySQLEngineSpec(BaseEngineSpec):
     engine = 'mysql'
+    max_column_name_length = 64
 
     time_grain_functions = {
         None: '{col}',
@@ -1060,6 +1059,7 @@ class HiveEngineSpec(PrestoEngineSpec):
     """Reuses PrestoEngineSpec functionality."""
 
     engine = 'hive'
+    max_column_name_length = 767
 
     # Scoping regex at class level to avoid recompiling
     # 17/02/07 19:36:38 INFO ql.Driver: Total jobs = 5
@@ -1366,6 +1366,7 @@ class MssqlEngineSpec(BaseEngineSpec):
     engine = 'mssql'
     epoch_to_dttm = "dateadd(S, {col}, '1970-01-01')"
     limit_method = LimitMethod.WRAP_SQL
+    max_column_name_length = 128
 
     time_grain_functions = {
         None: '{col}',
@@ -1434,11 +1435,21 @@ class AthenaEngineSpec(BaseEngineSpec):
     def epoch_to_dttm(cls):
         return 'from_unixtime({col})'
 
+    @staticmethod
+    def mutate_label(label):
+        """
+        Athena only supports lowercase column names and aliases.
+        :param str label: Original label which might include uppercase letters
+        :return: String that is supported by the database
+        """
+        return label.lower()
+
 
 class PinotEngineSpec(BaseEngineSpec):
     engine = 'pinot'
     allows_subquery = False
     inner_joins = False
+    supports_column_aliases = False
 
     _time_grain_to_datetimeconvert = {
         'PT1S': '1:SECONDS',
@@ -1481,17 +1492,6 @@ class PinotEngineSpec(BaseEngineSpec):
                 select_sans_groupby.append(s)
         return select_sans_groupby
 
-    @classmethod
-    def mutate_df_columns(cls, df, sql, labels_expected):
-        if df is not None and \
-                not df.empty and \
-                labels_expected is not None:
-            if len(df.columns) != len(labels_expected):
-                raise Exception(f'For {sql}, df.columns: {df.columns}'
-                                f' differs from {labels_expected}')
-            else:
-                df.columns = labels_expected
-
 
 class ClickHouseEngineSpec(BaseEngineSpec):
     """Dialect for ClickHouse analytical DB."""
@@ -1532,6 +1532,7 @@ class BQEngineSpec(BaseEngineSpec):
 
     As contributed by @mxmzdlv on issue #945"""
     engine = 'bigquery'
+    max_column_name_length = 128
 
     """
     https://www.python.org/dev/peps/pep-0249/#arraysize
@@ -1574,28 +1575,33 @@ class BQEngineSpec(BaseEngineSpec):
     @staticmethod
     def mutate_label(label):
         """
-        BigQuery field_name should start with a letter or underscore, contain only
-        alphanumeric characters and be at most 128 characters long. Labels that start
-        with a number are prefixed with an underscore. Any unsupported characters are
-        replaced with underscores and an md5 hash is added to the end of the label to
-        avoid possible collisions. If the resulting label exceeds 128 characters, only
-        the md5 sum is returned.
+        BigQuery field_name should start with a letter or underscore and contain only
+        alphanumeric characters. Labels that start with a number are prefixed with an
+        underscore. Any unsupported characters are replaced with underscores and an
+        md5 hash is added to the end of the label to avoid possible collisions.
         :param str label: the original label which might include unsupported characters
         :return: String that is supported by the database
         """
-        hashed_label = '_' + hashlib.md5(label.encode('utf-8')).hexdigest()
+        label_hashed = '_' + hashlib.md5(label.encode('utf-8')).hexdigest()
 
         # if label starts with number, add underscore as first character
-        mutated_label = '_' + label if re.match(r'^\d', label) else label
+        label_mutated = '_' + label if re.match(r'^\d', label) else label
 
         # replace non-alphanumeric characters with underscores
-        mutated_label = re.sub(r'[^\w]+', '_', mutated_label)
-        if mutated_label != label:
+        label_mutated = re.sub(r'[^\w]+', '_', label_mutated)
+        if label_mutated != label:
             # add md5 hash to label to avoid possible collisions
-            mutated_label += hashed_label
+            label_mutated += label_hashed
+
+        return label_mutated
 
-        # return only hash if length of final label exceeds 128 chars
-        return mutated_label if len(mutated_label) <= 128 else hashed_label
+    @classmethod
+    def truncate_label(cls, label):
+        """BigQuery requires column names start with either a letter or
+        underscore. To make sure this is always the case, an underscore is prefixed
+        to the truncated label.
+        """
+        return '_' + hashlib.md5(label.encode('utf-8')).hexdigest()
 
     @classmethod
     def extra_table_metadata(cls, database, table_name, schema_name):
@@ -1727,6 +1733,7 @@ class TeradataEngineSpec(BaseEngineSpec):
     """Dialect for Teradata DB."""
     engine = 'teradata'
     limit_method = LimitMethod.WRAP_SQL
+    max_column_name_length = 30  # since 14.10 this is 128
 
     time_grain_functions = {
         None: '{col}',
diff --git a/tests/db_engine_specs_test.py b/tests/db_engine_specs_test.py
index c2f1713b..a48012d 100644
--- a/tests/db_engine_specs_test.py
+++ b/tests/db_engine_specs_test.py
@@ -17,11 +17,12 @@
 import inspect
 
 import mock
+from sqlalchemy import column
 
 from superset import db_engine_specs
 from superset.db_engine_specs import (
-    BaseEngineSpec, HiveEngineSpec, MssqlEngineSpec,
-    MySQLEngineSpec, PrestoEngineSpec,
+    BaseEngineSpec, BQEngineSpec, HiveEngineSpec, MssqlEngineSpec,
+    MySQLEngineSpec, OracleEngineSpec, PrestoEngineSpec,
 )
 from superset.models.core import Database
 from .base_tests import SupersetTestCase
@@ -307,3 +308,27 @@ class DbEngineSpecsTestCase(SupersetTestCase):
 
     def test_hive_get_view_names_return_empty_list(self):
         self.assertEquals([], HiveEngineSpec.get_view_names(mock.ANY, mock.ANY))
+
+    def test_bigquery_sqla_column_label(self):
+        label = BQEngineSpec.make_label_compatible(column('Col').name)
+        label_expected = 'Col'
+        self.assertEqual(label, label_expected)
+
+        label = BQEngineSpec.make_label_compatible(column('SUM(x)').name)
+        label_expected = 'SUM_x__5f110b965a993675bc4953bb3e03c4a5'
+        self.assertEqual(label, label_expected)
+
+        label = BQEngineSpec.make_label_compatible(column('SUM[x]').name)
+        label_expected = 'SUM_x__7ebe14a3f9534aeee125449b0bc083a8'
+        self.assertEqual(label, label_expected)
+
+        label = BQEngineSpec.make_label_compatible(column('12345_col').name)
+        label_expected = '_12345_col_8d3906e2ea99332eb185f7f8ecb2ffd6'
+        self.assertEqual(label, label_expected)
+
+    def test_oracle_sqla_column_name_length_exceeded(self):
+        col = column('This_Is_32_Character_Column_Name')
+        label = OracleEngineSpec.make_label_compatible(col.name)
+        self.assertEqual(label.quote, True)
+        label_expected = '3b26974078683be078219674eeb8f5'
+        self.assertEqual(label, label_expected)


Mime
View raw message