superset-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From maximebeauche...@apache.org
Subject [incubator-superset] branch master updated: Improve database type inference (#4724)
Date Thu, 28 Jun 2018 04:35:20 GMT
This is an automated email from the ASF dual-hosted git repository.

maximebeauchemin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 777d876  Improve database type inference (#4724)
777d876 is described below

commit 777d876a523590be0e18f573eab02137fa989fa3
Author: Maxime Beauchemin <maximebeauchemin@gmail.com>
AuthorDate: Wed Jun 27 21:35:12 2018 -0700

    Improve database type inference (#4724)
    
    * Improve database type inference
    
    Python's DB-API specification is neither very clear nor consistently
    implemented when it comes to cursor.description, and this PR attempts
    to improve how the data types returned by the cursor are inferred.
    
    This work started with Presto's TIMESTAMP type being mishandled as a
    string, because the database driver (pyhive) returns it as a string. The
    work here fixes this bug and does a better job at inferring MySQL and
    Presto types. It also adds a new method, get_datatype, to db_engine_specs
    so that other database engines can implement it and become more precise
    at type inference as needed.
    
    * Fixing tests
    
    * Addressing comments
    
    * Using infer_objects
    
    * Removing faulty line
    
    * Addressing PrestoSpec redundant method comment
    
    * Fix rebase issue
    
    * Fix tests
---
 superset/dataframe.py         |  78 ++++++++++++++++++++++------
 superset/db_engine_specs.py   |  24 +++++++++
 superset/sql_lab.py           |  44 +---------------
 tests/celery_tests.py         |  51 +------------------
 tests/core_tests.py           |   4 +-
 tests/dataframe_test.py       | 115 ++++++++++++++++++++++++++++++++++++++++++
 tests/db_engine_specs_test.py |  10 +++-
 tests/sqllab_tests.py         |  15 ++++--
 8 files changed, 224 insertions(+), 117 deletions(-)

diff --git a/superset/dataframe.py b/superset/dataframe.py
index 79a2c3d..5fba4ff 100644
--- a/superset/dataframe.py
+++ b/superset/dataframe.py
@@ -13,6 +13,7 @@ from __future__ import print_function
 from __future__ import unicode_literals
 
 from datetime import date, datetime
+import logging
 
 import numpy as np
 import pandas as pd
@@ -26,6 +27,27 @@ INFER_COL_TYPES_THRESHOLD = 95
 INFER_COL_TYPES_SAMPLE_SIZE = 100
 
 
+def dedup(l, suffix='__'):
+    """De-duplicates a list of strings by suffixing a counter
+
+    Always returns the same number of entries as provided, and always returns
+    unique values, even when a suffixed name already appears in the input.
+
+    >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar'])))
+    foo,bar,bar__1,bar__2
+    """
+    new_l = []
+    seen = {}
+    for s in l:
+        base = s
+        while s in seen:
+            seen[base] += 1
+            s = base + suffix + str(seen[base])
+        seen[s] = 0
+        new_l.append(s)
+    return new_l
+
+
 class SupersetDataFrame(object):
     # Mapping numpy dtype.char to generic database types
     type_map = {
@@ -43,19 +65,39 @@ class SupersetDataFrame(object):
         'V': None,   # raw data (void)
     }
 
-    def __init__(self, df):
-        self.__df = df.where((pd.notnull(df)), None)
+    def __init__(self, data, cursor_description, db_engine_spec):
+        # One list of normalized, de-duplicated names is used both for the
+        # DataFrame columns and for the keys of ``_type_dict``.
+        self.column_names = dedup(
+            db_engine_spec.get_normalized_column_names(cursor_description))
+
+        data = data or []
+        self.df = pd.DataFrame(
+            list(data), columns=self.column_names).infer_objects()
+
+        # Best effort: on failure the dict stays empty and ``columns`` falls
+        # back to the types inferred from the DataFrame's dtypes.
+        self._type_dict = {}
+        try:
+            # The driver may not be passing a cursor.description
+            self._type_dict = {
+                col: db_engine_spec.get_datatype(cursor_description[i][1])
+                for i, col in enumerate(self.column_names)
+                if cursor_description
+            }
+        except Exception as e:
+            logging.exception(e)
 
     @property
     def size(self):
-        return len(self.__df.index)
+        return len(self.df.index)
 
     @property
     def data(self):
         # work around for https://github.com/pandas-dev/pandas/issues/18372
         data = [dict((k, _maybe_box_datetimelike(v))
-                for k, v in zip(self.__df.columns, np.atleast_1d(row)))
-                for row in self.__df.values]
+                for k, v in zip(self.df.columns, np.atleast_1d(row)))
+                for row in self.df.values]
         for d in data:
             for k, v in list(d.items()):
                 # if an int is too big for Java Script to handle
@@ -70,7 +112,8 @@ class SupersetDataFrame(object):
         """Given a numpy dtype, Returns a generic database type"""
         if isinstance(dtype, ExtensionDtype):
             return cls.type_map.get(dtype.kind)
-        return cls.type_map.get(dtype.char)
+        elif hasattr(dtype, 'char'):
+            return cls.type_map.get(dtype.char)
 
     @classmethod
     def datetime_conversion_rate(cls, data_series):
@@ -105,7 +148,7 @@ class SupersetDataFrame(object):
         # consider checking for key substring too.
         if cls.is_id(column_name):
             return 'count_distinct'
-        if (issubclass(dtype.type, np.generic) and
+        if (hasattr(dtype, 'type') and issubclass(dtype.type, np.generic) and
                 np.issubdtype(dtype, np.number)):
             return 'sum'
         return None
@@ -116,22 +159,25 @@ class SupersetDataFrame(object):
 
         :return: dict, with the fields name, type, is_date, is_dim and agg.
         """
-        if self.__df.empty:
+        if self.df.empty:
             return None
 
         columns = []
-        sample_size = min(INFER_COL_TYPES_SAMPLE_SIZE, len(self.__df.index))
-        sample = self.__df
+        sample_size = min(INFER_COL_TYPES_SAMPLE_SIZE, len(self.df.index))
+        sample = self.df
         if sample_size:
-            sample = self.__df.sample(sample_size)
-        for col in self.__df.dtypes.keys():
-            col_db_type = self.db_type(self.__df.dtypes[col])
+            sample = self.df.sample(sample_size)
+        for col in self.df.dtypes.keys():
+            col_db_type = (
+                self._type_dict.get(col) or
+                self.db_type(self.df.dtypes[col])
+            )
             column = {
                 'name': col,
-                'agg': self.agg_func(self.__df.dtypes[col], col),
+                'agg': self.agg_func(self.df.dtypes[col], col),
                 'type': col_db_type,
-                'is_date': self.is_date(self.__df.dtypes[col]),
-                'is_dim': self.is_dimension(self.__df.dtypes[col], col),
+                'is_date': self.is_date(self.df.dtypes[col]),
+                'is_dim': self.is_dimension(self.df.dtypes[col], col),
             }
 
             if column['type'] in ('OBJECT', None):
diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py
index 4f6b22e..4181c49 100644
--- a/superset/db_engine_specs.py
+++ b/superset/db_engine_specs.py
@@ -30,6 +30,7 @@ import boto3
 from flask import g
 from flask_babel import lazy_gettext as _
 import pandas
+from past.builtins import basestring
 import sqlalchemy as sqla
 from sqlalchemy import select
 from sqlalchemy.engine import create_engine
@@ -86,6 +87,11 @@ class BaseEngineSpec(object):
         return cls.epoch_to_dttm().replace('{col}', '({col}/1000.0)')
 
     @classmethod
+    def get_datatype(cls, type_code):
+        is_text = isinstance(type_code, basestring) and len(type_code) > 0
+        return type_code.upper() if is_text else None
+
+    @classmethod
     def extra_table_metadata(cls, database, table_name, schema_name):
         """Returns engine-specific table metadata"""
         return {}
@@ -592,6 +598,7 @@ class MySQLEngineSpec(BaseEngineSpec):
               'INTERVAL DAYOFWEEK(DATE_SUB({col}, INTERVAL 1 DAY)) - 1 DAY))',
               'P1W'),
     )
+    type_code_map = {}  # loaded from get_datatype only if needed
 
     @classmethod
     def convert_dttm(cls, target_type, dttm):
@@ -607,6 +614,23 @@ class MySQLEngineSpec(BaseEngineSpec):
         return uri
 
     @classmethod
+    def get_datatype(cls, type_code):
+        if not cls.type_code_map:
+            # Import MySQLdb and build the code -> name map lazily, on first
+            # call, rather than when this module is imported.
+            import MySQLdb
+            field_types = MySQLdb.constants.FIELD_TYPE
+            cls.type_code_map = dict(
+                (getattr(field_types, name), name)
+                for name in dir(field_types)
+                if not name.startswith('_'))
+        datatype = (
+            cls.type_code_map.get(type_code)
+            if isinstance(type_code, int) else type_code)
+        if isinstance(datatype, basestring) and datatype:
+            return datatype
+
+    @classmethod
     def epoch_to_dttm(cls):
         return 'from_unixtime({col})'
 
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index df00a2b..34a9eeb 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -10,8 +10,6 @@ import uuid
 
 from celery.exceptions import SoftTimeLimitExceeded
 from contextlib2 import contextmanager
-import numpy as np
-import pandas as pd
 import sqlalchemy
 from sqlalchemy.orm import sessionmaker
 from sqlalchemy.pool import NullPool
@@ -31,27 +29,6 @@ class SqlLabException(Exception):
     pass
 
 
-def dedup(l, suffix='__'):
-    """De-duplicates a list of string by suffixing a counter
-
-    Always returns the same number of entries as provided, and always returns
-    unique values.
-
-    >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar'])))
-    foo,bar,bar__1,bar__2
-    """
-    new_l = []
-    seen = {}
-    for s in l:
-        if s in seen:
-            seen[s] += 1
-            s += suffix + str(seen[s])
-        else:
-            seen[s] = 0
-        new_l.append(s)
-    return new_l
-
-
 def get_query(query_id, session, retry_count=5):
     """attemps to get the query and retry if it cannot"""
     query = None
@@ -96,24 +73,6 @@ def session_scope(nullpool):
         session.close()
 
 
-def convert_results_to_df(column_names, data):
-    """Convert raw query results to a DataFrame."""
-    column_names = dedup(column_names)
-
-    # check whether the result set has any nested dict columns
-    if data:
-        first_row = data[0]
-        has_dict_col = any([isinstance(c, dict) for c in first_row])
-        df_data = list(data) if has_dict_col else np.array(data, dtype=object)
-    else:
-        df_data = []
-
-    cdf = dataframe.SupersetDataFrame(
-        pd.DataFrame(df_data, columns=column_names))
-
-    return cdf
-
-
 @celery_app.task(bind=True, soft_time_limit=SQLLAB_TIMEOUT)
 def get_sql_results(
     ctask, query_id, rendered_query, return_results=True, store_results=False,
@@ -233,7 +192,6 @@ def execute_sql(
         return handle_error(db_engine_spec.extract_error_message(e))
 
     logging.info('Fetching cursor description')
-    column_names = db_engine_spec.get_normalized_column_names(cursor.description)
 
     if conn is not None:
         conn.commit()
@@ -242,7 +200,7 @@ def execute_sql(
     if query.status == utils.QueryStatus.STOPPED:
         return handle_error('The query has been stopped')
 
-    cdf = convert_results_to_df(column_names, data)
+    cdf = dataframe.SupersetDataFrame(data, cursor.description, db_engine_spec)
 
     query.rows = cdf.size
     query.progress = 100
diff --git a/tests/celery_tests.py b/tests/celery_tests.py
index 39b7749..afaeea9 100644
--- a/tests/celery_tests.py
+++ b/tests/celery_tests.py
@@ -14,7 +14,7 @@ import unittest
 import pandas as pd
 from past.builtins import basestring
 
-from superset import app, cli, dataframe, db, security_manager
+from superset import app, cli, db, security_manager
 from superset.models.helpers import QueryStatus
 from superset.models.sql_lab import Query
 from superset.sql_parse import SupersetQuery
@@ -245,55 +245,6 @@ class CeleryTestCase(SupersetTestCase):
     def dictify_list_of_dicts(cls, l, k):
         return {str(o[k]): cls.de_unicode_dict(o) for o in l}
 
-    def test_get_columns(self):
-        main_db = self.get_main_database(db.session)
-        df = main_db.get_df('SELECT * FROM multiformat_time_series', None)
-        cdf = dataframe.SupersetDataFrame(df)
-
-        # Making ordering non-deterministic
-        cols = self.dictify_list_of_dicts(cdf.columns, 'name')
-
-        if main_db.sqlalchemy_uri.startswith('sqlite'):
-            self.assertEqual(self.dictify_list_of_dicts([
-                {'is_date': True, 'type': 'STRING', 'name': 'ds',
-                    'is_dim': False},
-                {'is_date': True, 'type': 'STRING', 'name': 'ds2',
-                    'is_dim': False},
-                {'agg': 'sum', 'is_date': False, 'type': 'INT',
-                    'name': 'epoch_ms', 'is_dim': False},
-                {'agg': 'sum', 'is_date': False, 'type': 'INT',
-                    'name': 'epoch_s', 'is_dim': False},
-                {'is_date': True, 'type': 'STRING', 'name': 'string0',
-                    'is_dim': False},
-                {'is_date': False, 'type': 'STRING',
-                    'name': 'string1', 'is_dim': True},
-                {'is_date': True, 'type': 'STRING', 'name': 'string2',
-                    'is_dim': False},
-                {'is_date': False, 'type': 'STRING',
-                    'name': 'string3', 'is_dim': True}], 'name'),
-                cols,
-            )
-        else:
-            self.assertEqual(self.dictify_list_of_dicts([
-                {'is_date': True, 'type': 'DATETIME', 'name': 'ds',
-                    'is_dim': False},
-                {'is_date': True, 'type': 'DATETIME',
-                    'name': 'ds2', 'is_dim': False},
-                {'agg': 'sum', 'is_date': False, 'type': 'INT',
-                    'name': 'epoch_ms', 'is_dim': False},
-                {'agg': 'sum', 'is_date': False, 'type': 'INT',
-                    'name': 'epoch_s', 'is_dim': False},
-                {'is_date': True, 'type': 'STRING', 'name': 'string0',
-                    'is_dim': False},
-                {'is_date': False, 'type': 'STRING',
-                    'name': 'string1', 'is_dim': True},
-                {'is_date': True, 'type': 'STRING', 'name': 'string2',
-                    'is_dim': False},
-                {'is_date': False, 'type': 'STRING',
-                    'name': 'string3', 'is_dim': True}], 'name'),
-                cols,
-            )
-
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/tests/core_tests.py b/tests/core_tests.py
index 6a4f153..f1a0179 100644
--- a/tests/core_tests.py
+++ b/tests/core_tests.py
@@ -24,6 +24,7 @@ import sqlalchemy as sqla
 
 from superset import dataframe, db, jinja_context, security_manager, sql_lab, utils
 from superset.connectors.sqla.models import SqlaTable
+from superset.db_engine_specs import BaseEngineSpec
 from superset.models import core as models
 from superset.models.sql_lab import Query
 from superset.views.core import DatabaseView
@@ -626,8 +627,7 @@ class CoreTests(SupersetTestCase):
             (datetime.datetime(2017, 11, 18, 21, 53, 0, 219225, tzinfo=tz),),
             (datetime.datetime(2017, 11, 18, 22, 6, 30, 61810, tzinfo=tz),),
         ]
-        df = dataframe.SupersetDataFrame(pd.DataFrame(data=list(data),
-                                                      columns=['data']))
+        df = dataframe.SupersetDataFrame(list(data), [['data']], BaseEngineSpec)
         data = df.data
         self.assertDictEqual(
             data[0],
diff --git a/tests/dataframe_test.py b/tests/dataframe_test.py
new file mode 100644
index 0000000..b567702
--- /dev/null
+++ b/tests/dataframe_test.py
@@ -0,0 +1,115 @@
+# -*- coding: utf-8 -*-
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from superset.dataframe import dedup, SupersetDataFrame
+from superset.db_engine_specs import BaseEngineSpec
+from .base_tests import SupersetTestCase
+
+
+class SupersetDataFrameTestCase(SupersetTestCase):
+    def test_dedup(self):
+        self.assertEqual(
+            dedup(['foo', 'bar']),
+            ['foo', 'bar'],
+        )
+        self.assertEqual(
+            dedup(['foo', 'bar', 'foo', 'bar']),
+            ['foo', 'bar', 'foo__1', 'bar__1'],
+        )
+        self.assertEqual(
+            dedup(['foo', 'bar', 'bar', 'bar']),
+            ['foo', 'bar', 'bar__1', 'bar__2'],
+        )
+
+    def test_get_columns_basic(self):
+        data = [
+            ('a1', 'b1', 'c1'),
+            ('a2', 'b2', 'c2'),
+        ]
+        cursor_descr = (
+            ('a', 'string'),
+            ('b', 'string'),
+            ('c', 'string'),
+        )
+        cdf = SupersetDataFrame(data, cursor_descr, BaseEngineSpec)
+        self.assertEqual(
+            cdf.columns,
+            [
+                {
+                    'is_date': False,
+                    'type': 'STRING',
+                    'name': 'a',
+                    'is_dim': True,
+                }, {
+                    'is_date': False,
+                    'type': 'STRING',
+                    'name': 'b',
+                    'is_dim': True,
+                }, {
+                    'is_date': False,
+                    'type': 'STRING',
+                    'name': 'c',
+                    'is_dim': True,
+                },
+            ],
+        )
+
+    def test_get_columns_with_int(self):
+        data = [
+            ('a1', 1),
+            ('a2', 2),
+        ]
+        cursor_descr = (
+            ('a', 'string'),
+            ('b', 'int'),
+        )
+        cdf = SupersetDataFrame(data, cursor_descr, BaseEngineSpec)
+        self.assertEqual(
+            cdf.columns,
+            [
+                {
+                    'is_date': False,
+                    'type': 'STRING',
+                    'name': 'a',
+                    'is_dim': True,
+                }, {
+                    'is_date': False,
+                    'type': 'INT',
+                    'name': 'b',
+                    'is_dim': False,
+                    'agg': 'sum',
+                },
+            ],
+        )
+
+    def test_get_columns_type_inference(self):
+        data = [
+            (1.2, 1),
+            (3.14, 2),
+        ]
+        cursor_descr = (
+            ('a', None),
+            ('b', None),
+        )
+        cdf = SupersetDataFrame(data, cursor_descr, BaseEngineSpec)
+        self.assertEqual(
+            cdf.columns,
+            [
+                {
+                    'is_date': False,
+                    'type': 'FLOAT',
+                    'name': 'a',
+                    'is_dim': False,
+                    'agg': 'sum',
+                }, {
+                    'is_date': False,
+                    'type': 'INT',
+                    'name': 'b',
+                    'is_dim': False,
+                    'agg': 'sum',
+                },
+            ],
+        )
diff --git a/tests/db_engine_specs_test.py b/tests/db_engine_specs_test.py
index bdce0b0..447914e 100644
--- a/tests/db_engine_specs_test.py
+++ b/tests/db_engine_specs_test.py
@@ -7,7 +7,9 @@ from __future__ import unicode_literals
 import textwrap
 
 from superset.db_engine_specs import (
-    HiveEngineSpec, MssqlEngineSpec, MySQLEngineSpec)
+    BaseEngineSpec, HiveEngineSpec, MssqlEngineSpec,
+    MySQLEngineSpec, PrestoEngineSpec,
+)
 from superset.models.core import Database
 from .base_tests import SupersetTestCase
 
@@ -193,3 +195,9 @@ class DbEngineSpecsTestCase(SupersetTestCase):
                 FROM
                 table LIMIT 1000"""),
         )
+
+    def test_get_datatype(self):
+        self.assertEqual('STRING', PrestoEngineSpec.get_datatype('string'))
+        self.assertEqual('TINY', MySQLEngineSpec.get_datatype(1))
+        self.assertEqual('VARCHAR', MySQLEngineSpec.get_datatype(15))
+        self.assertEqual('VARCHAR', BaseEngineSpec.get_datatype('VARCHAR'))
diff --git a/tests/sqllab_tests.py b/tests/sqllab_tests.py
index 49926f8..a3bb564 100644
--- a/tests/sqllab_tests.py
+++ b/tests/sqllab_tests.py
@@ -12,8 +12,9 @@ import unittest
 from flask_appbuilder.security.sqla import models as ab_models
 
 from superset import db, security_manager, utils
+from superset.dataframe import SupersetDataFrame
+from superset.db_engine_specs import BaseEngineSpec
 from superset.models.sql_lab import Query
-from superset.sql_lab import convert_results_to_df
 from .base_tests import SupersetTestCase
 
 
@@ -203,9 +204,13 @@ class SqlLabTests(SupersetTestCase):
             raise_on_error=True)
 
     def test_df_conversion_no_dict(self):
-        cols = ['string_col', 'int_col', 'float_col']
+        cols = [
+            ['string_col', 'string'],
+            ['int_col', 'int'],
+            ['float_col', 'float'],
+        ]
         data = [['a', 4, 4.0]]
-        cdf = convert_results_to_df(cols, data)
+        cdf = SupersetDataFrame(data, cols, BaseEngineSpec)
 
         self.assertEquals(len(data), cdf.size)
         self.assertEquals(len(cols), len(cdf.columns))
@@ -213,7 +218,7 @@ class SqlLabTests(SupersetTestCase):
     def test_df_conversion_tuple(self):
         cols = ['string_col', 'int_col', 'list_col', 'float_col']
         data = [(u'Text', 111, [123], 1.0)]
-        cdf = convert_results_to_df(cols, data)
+        cdf = SupersetDataFrame(data, cols, BaseEngineSpec)
 
         self.assertEquals(len(data), cdf.size)
         self.assertEquals(len(cols), len(cdf.columns))
@@ -221,7 +226,7 @@ class SqlLabTests(SupersetTestCase):
     def test_df_conversion_dict(self):
         cols = ['string_col', 'dict_col', 'int_col']
         data = [['a', {'c1': 1, 'c2': 2, 'c3': 3}, 4]]
-        cdf = convert_results_to_df(cols, data)
+        cdf = SupersetDataFrame(data, cols, BaseEngineSpec)
 
         self.assertEquals(len(data), cdf.size)
         self.assertEquals(len(cols), len(cdf.columns))


Mime
View raw message