superset-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From maximebeauche...@apache.org
Subject [incubator-superset] branch master updated: [bugfix] temporal columns with expression fail (#4890)
Date Fri, 27 Apr 2018 04:14:00 GMT
This is an automated email from the ASF dual-hosted git repository.

maximebeauchemin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 3f48c00  [bugfix] temporal columns with expression fail (#4890)
3f48c00 is described below

commit 3f48c005df985888f19d46d00cac0f35ec98466d
Author: Maxime Beauchemin <maximebeauchemin@gmail.com>
AuthorDate: Thu Apr 26 21:13:52 2018 -0700

    [bugfix] temporal columns with expression fail (#4890)
    
    * [bugfix] temporal columns with expression fail
    
    error msg: "local variable 'literal' referenced before assignment"
    
    The error occurs only when using a temporal column defined as a SQL
    expression.
    
    Also noticed that the example slices were using `granularity` where they
    should have been using `granularity_sqla`. Fixed that here.
    
    * Add tests
---
 superset/connectors/base/models.py |  5 ++++
 superset/connectors/sqla/models.py | 28 +++++++++--------
 superset/data/__init__.py          | 22 +++++++-------
 tests/model_tests.py               | 61 ++++++++++++++++++++++++++++++++++++++
 4 files changed, 92 insertions(+), 24 deletions(-)

diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py
index 8e4a2a2..9f9522d 100644
--- a/superset/connectors/base/models.py
+++ b/superset/connectors/base/models.py
@@ -241,6 +241,11 @@ class BaseDatasource(AuditMixinNullable, ImportMixin):
     def default_query(qry):
         return qry
 
+    def get_column(self, column_name):
+        for col in self.columns:
+            if col.column_name == column_name:
+                return col
+
 
 class BaseColumn(AuditMixinNullable, ImportMixin):
     """Interface for column"""
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index c65df02..56a1751 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -117,22 +117,24 @@ class TableColumn(Model, BaseColumn):
 
     def get_timestamp_expression(self, time_grain):
         """Getting the time component of the query"""
+        pdf = self.python_date_format
+        is_epoch = pdf in ('epoch_s', 'epoch_ms')
+        if not self.expression and not time_grain and not is_epoch:
+            return column(self.column_name, type_=DateTime).label(DTTM_ALIAS)
+
         expr = self.expression or self.column_name
-        if not self.expression and not time_grain:
-            return column(expr, type_=DateTime).label(DTTM_ALIAS)
+        if is_epoch:
+            # if epoch, translate to DATE using db specific conf
+            db_spec = self.table.database.db_engine_spec
+            if pdf == 'epoch_s':
+                expr = db_spec.epoch_to_dttm().format(col=expr)
+            elif pdf == 'epoch_ms':
+                expr = db_spec.epoch_ms_to_dttm().format(col=expr)
         if time_grain:
-            pdf = self.python_date_format
-            if pdf in ('epoch_s', 'epoch_ms'):
-                # if epoch, translate to DATE using db specific conf
-                db_spec = self.table.database.db_engine_spec
-                if pdf == 'epoch_s':
-                    expr = db_spec.epoch_to_dttm().format(col=expr)
-                elif pdf == 'epoch_ms':
-                    expr = db_spec.epoch_ms_to_dttm().format(col=expr)
             grain = self.table.database.grains_dict().get(time_grain)
-            literal = grain.function if grain else '{col}'
-            literal = expr.format(col=expr)
-        return literal_column(literal, type_=DateTime).label(DTTM_ALIAS)
+            if grain:
+                expr = grain.function.format(col=expr)
+        return literal_column(expr, type_=DateTime).label(DTTM_ALIAS)
 
     @classmethod
     def import_obj(cls, i_column):
diff --git a/superset/data/__init__.py b/superset/data/__init__.py
index 160ed64..8ad6c11 100644
--- a/superset/data/__init__.py
+++ b/superset/data/__init__.py
@@ -188,7 +188,7 @@ def load_world_bank_health_n_pop():
         "compare_lag": "10",
         "compare_suffix": "o10Y",
         "limit": "25",
-        "granularity": "year",
+        "granularity_sqla": "year",
         "groupby": [],
         "metric": 'sum__SP_POP_TOTL',
         "metrics": ["sum__SP_POP_TOTL"],
@@ -593,7 +593,7 @@ def load_birth_names():
         "compare_lag": "10",
         "compare_suffix": "o10Y",
         "limit": "25",
-        "granularity": "ds",
+        "granularity_sqla": "ds",
         "groupby": [],
         "metric": 'sum__num',
         "metrics": ["sum__num"],
@@ -642,7 +642,7 @@ def load_birth_names():
             datasource_id=tbl.id,
             params=get_slice_json(
                 defaults,
-                viz_type="big_number", granularity="ds",
+                viz_type="big_number", granularity_sqla="ds",
                 compare_lag="5", compare_suffix="over 5Y")),
         Slice(
             slice_name="Genders",
@@ -675,7 +675,7 @@ def load_birth_names():
             params=get_slice_json(
                 defaults,
                 viz_type="line", groupby=['name'],
-                granularity='ds', rich_tooltip=True, show_legend=True)),
+                granularity_sqla='ds', rich_tooltip=True, show_legend=True)),
         Slice(
             slice_name="Average and Sum Trends",
             viz_type='dual_line',
@@ -684,7 +684,7 @@ def load_birth_names():
             params=get_slice_json(
                 defaults,
                 viz_type="dual_line", metric='avg__num', metric_2='sum__num',
-                granularity='ds')),
+                granularity_sqla='ds')),
         Slice(
             slice_name="Title",
             viz_type='markup',
@@ -729,7 +729,7 @@ def load_birth_names():
             datasource_id=tbl.id,
             params=get_slice_json(
                 defaults,
-                viz_type="big_number_total", granularity="ds",
+                viz_type="big_number_total", granularity_sqla="ds",
                 filters=[{
                     'col': 'gender',
                     'op': 'in',
@@ -876,7 +876,7 @@ def load_unicode_test_data():
     tbl = obj
 
     slice_data = {
-        "granularity": "dttm",
+        "granularity_sqla": "dttm",
         "groupby": [],
         "metric": 'sum__value',
         "row_limit": config.get("ROW_LIMIT"),
@@ -954,7 +954,7 @@ def load_random_time_series_data():
     tbl = obj
 
     slice_data = {
-        "granularity": "day",
+        "granularity_sqla": "day",
         "row_limit": config.get("ROW_LIMIT"),
         "since": "1 year ago",
         "until": "now",
@@ -1017,7 +1017,7 @@ def load_country_map_data():
     tbl = obj
 
     slice_data = {
-        "granularity": "",
+        "granularity_sqla": "",
         "since": "",
         "until": "",
         "where": "",
@@ -1092,7 +1092,7 @@ def load_long_lat_data():
     tbl = obj
 
     slice_data = {
-        "granularity": "day",
+        "granularity_sqla": "day",
         "since": "2014-01-01",
         "until": "now",
         "where": "",
@@ -1172,7 +1172,7 @@ def load_multiformat_time_series_data():
         slice_data = {
             "metric": 'count',
             "granularity_sqla": col.column_name,
-            "granularity": "day",
+            "granularity_sqla": "day",
             "row_limit": config.get("ROW_LIMIT"),
             "since": "1 year ago",
             "until": "now",
diff --git a/tests/model_tests.py b/tests/model_tests.py
index cdd4c83..8af104f 100644
--- a/tests/model_tests.py
+++ b/tests/model_tests.py
@@ -105,3 +105,64 @@ class DatabaseModelTestCase(SupersetTestCase):
         self.assertEquals(d.get('day').function, 'DATE({col})')
         self.assertEquals(d.get('P1D').function, 'DATE({col})')
         self.assertEquals(d.get('Time Column').function, '{col}')
+
+
+class SqlaTableModelTestCase(SupersetTestCase):
+
+    def test_get_timestamp_expression(self):
+        tbl = self.get_table_by_name('birth_names')
+        ds_col = tbl.get_column('ds')
+        sqla_literal = ds_col.get_timestamp_expression(None)
+        self.assertEquals(str(sqla_literal.compile()), 'ds')
+
+        sqla_literal = ds_col.get_timestamp_expression('P1D')
+        compiled = '{}'.format(sqla_literal.compile())
+        if tbl.database.backend == 'mysql':
+            self.assertEquals(compiled, 'DATE(ds)')
+
+        ds_col.expression = 'DATE_ADD(ds, 1)'
+        sqla_literal = ds_col.get_timestamp_expression('P1D')
+        compiled = '{}'.format(sqla_literal.compile())
+        if tbl.database.backend == 'mysql':
+            self.assertEquals(compiled, 'DATE(DATE_ADD(ds, 1))')
+
+    def test_get_timestamp_expression_epoch(self):
+        tbl = self.get_table_by_name('birth_names')
+        ds_col = tbl.get_column('ds')
+
+        ds_col.expression = None
+        ds_col.python_date_format = 'epoch_s'
+        sqla_literal = ds_col.get_timestamp_expression(None)
+        compiled = '{}'.format(sqla_literal.compile())
+        if tbl.database.backend == 'mysql':
+            self.assertEquals(compiled, 'from_unixtime(ds)')
+
+        ds_col.python_date_format = 'epoch_s'
+        sqla_literal = ds_col.get_timestamp_expression('P1D')
+        compiled = '{}'.format(sqla_literal.compile())
+        if tbl.database.backend == 'mysql':
+            self.assertEquals(compiled, 'DATE(from_unixtime(ds))')
+
+        ds_col.expression = 'DATE_ADD(ds, 1)'
+        sqla_literal = ds_col.get_timestamp_expression('P1D')
+        compiled = '{}'.format(sqla_literal.compile())
+        if tbl.database.backend == 'mysql':
+            self.assertEquals(compiled, 'DATE(from_unixtime(DATE_ADD(ds, 1)))')
+
+    def test_get_timestamp_expression_backward(self):
+        tbl = self.get_table_by_name('birth_names')
+        ds_col = tbl.get_column('ds')
+
+        ds_col.expression = None
+        ds_col.python_date_format = None
+        sqla_literal = ds_col.get_timestamp_expression('day')
+        compiled = '{}'.format(sqla_literal.compile())
+        if tbl.database.backend == 'mysql':
+            self.assertEquals(compiled, 'DATE(ds)')
+
+        ds_col.expression = None
+        ds_col.python_date_format = None
+        sqla_literal = ds_col.get_timestamp_expression('Time Column')
+        compiled = '{}'.format(sqla_literal.compile())
+        if tbl.database.backend == 'mysql':
+            self.assertEquals(compiled, 'ds')

-- 
To stop receiving notification emails like this one, please contact
maximebeauchemin@apache.org.

Mime
View raw message