superset-commits mailing list archives

From: maximebeauche...@apache.org
Subject: [incubator-superset] branch master updated: Fixed finding postaggregations (#4017)
Date: Thu, 07 Dec 2017 05:55:45 GMT
This is an automated email from the ASF dual-hosted git repository.

maximebeauchemin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new cb7c5aa  Fixed finding postaggregations (#4017)
cb7c5aa is described below

commit cb7c5aa70c3729d8f1fe0c310d30e620f9e9a581
Author: Jeff Niu <jeffniu22@gmail.com>
AuthorDate: Wed Dec 6 21:55:43 2017 -0800

    Fixed finding postaggregations (#4017)
---
 superset/connectors/druid/models.py | 174 ++++++++++++++--------
 tests/druid_func_tests.py           | 284 ++++++++++++++++++++++++++++++++++++
 tests/druid_tests.py                |  87 +----------
 3 files changed, 397 insertions(+), 148 deletions(-)

diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py
index bf7e176..acb1951 100644
--- a/superset/connectors/druid/models.py
+++ b/superset/connectors/druid/models.py
@@ -786,73 +786,123 @@ class DruidDatasource(Model, BaseDatasource):
         return granularity
 
     @staticmethod
-    def _metrics_and_post_aggs(metrics, metrics_dict):
-        all_metrics = []
-        post_aggs = {}
-
-        def recursive_get_fields(_conf):
-            _type = _conf.get('type')
-            _field = _conf.get('field')
-            _fields = _conf.get('fields')
-
-            field_names = []
-            if _type in ['fieldAccess', 'hyperUniqueCardinality',
-                         'quantile', 'quantiles']:
-                field_names.append(_conf.get('fieldName', ''))
+    def get_post_agg(mconf):
+        """
+        For a metric specified as `postagg` returns the
+        kind of post aggregation for pydruid.
+        """
+        if mconf.get('type') == 'javascript':
+            return JavascriptPostAggregator(
+                name=mconf.get('name', ''),
+                field_names=mconf.get('fieldNames', []),
+                function=mconf.get('function', ''))
+        elif mconf.get('type') == 'quantile':
+            return Quantile(
+                mconf.get('name', ''),
+                mconf.get('probability', ''),
+            )
+        elif mconf.get('type') == 'quantiles':
+            return Quantiles(
+                mconf.get('name', ''),
+                mconf.get('probabilities', ''),
+            )
+        elif mconf.get('type') == 'fieldAccess':
+            return Field(mconf.get('name'))
+        elif mconf.get('type') == 'constant':
+            return Const(
+                mconf.get('value'),
+                output_name=mconf.get('name', ''),
+            )
+        elif mconf.get('type') == 'hyperUniqueCardinality':
+            return HyperUniqueCardinality(
+                mconf.get('name'),
+            )
+        elif mconf.get('type') == 'arithmetic':
+            return Postaggregator(
+                mconf.get('fn', '/'),
+                mconf.get('fields', []),
+                mconf.get('name', ''))
+        else:
+            return CustomPostAggregator(
+                mconf.get('name', ''),
+                mconf)
 
-            if _field:
-                field_names += recursive_get_fields(_field)
+    @staticmethod
+    def find_postaggs_for(postagg_names, metrics_dict):
+        """Return a list of metrics that are post aggregations"""
+        postagg_metrics = [
+            metrics_dict[name] for name in postagg_names
+            if metrics_dict[name].metric_type == 'postagg'
+        ]
+        # Remove post aggregations that were found
+        for postagg in postagg_metrics:
+            postagg_names.remove(postagg.metric_name)
+        return postagg_metrics
 
-            if _fields:
-                for _f in _fields:
-                    field_names += recursive_get_fields(_f)
+    @staticmethod
+    def recursive_get_fields(_conf):
+        _type = _conf.get('type')
+        _field = _conf.get('field')
+        _fields = _conf.get('fields')
+        field_names = []
+        if _type in ['fieldAccess', 'hyperUniqueCardinality',
+                     'quantile', 'quantiles']:
+            field_names.append(_conf.get('fieldName', ''))
+        if _field:
+            field_names += DruidDatasource.recursive_get_fields(_field)
+        if _fields:
+            for _f in _fields:
+                field_names += DruidDatasource.recursive_get_fields(_f)
+        return list(set(field_names))
 
-            return list(set(field_names))
+    @staticmethod
+    def resolve_postagg(postagg, post_aggs, agg_names, visited_postaggs, metrics_dict):
+        mconf = postagg.json_obj
+        required_fields = set(
+            DruidDatasource.recursive_get_fields(mconf)
+            + mconf.get('fieldNames', []))
+        # Check if the fields are already in aggs
+        # or are previous postaggs
+        required_fields = set([
+            field for field in required_fields
+            if field not in visited_postaggs and field not in agg_names
+        ])
+        # First try to find postaggs that match
+        if len(required_fields) > 0:
+            missing_postaggs = DruidDatasource.find_postaggs_for(
+                required_fields, metrics_dict)
+            for missing_metric in required_fields:
+                agg_names.add(missing_metric)
+            for missing_postagg in missing_postaggs:
+                # Add to visited first to avoid infinite recursion
+                # if post aggregations are cyclically dependent
+                visited_postaggs.add(missing_postagg.metric_name)
+            for missing_postagg in missing_postaggs:
+                DruidDatasource.resolve_postagg(
+                    missing_postagg, post_aggs, agg_names, visited_postaggs, metrics_dict)
+        post_aggs[postagg.metric_name] = DruidDatasource.get_post_agg(postagg.json_obj)
 
+    @staticmethod
+    def metrics_and_post_aggs(metrics, metrics_dict):
+        # Separate metrics into those that are aggregations
+        # and those that are post aggregations
+        agg_names = set()
+        postagg_names = []
         for metric_name in metrics:
-            metric = metrics_dict[metric_name]
-            if metric.metric_type != 'postagg':
-                all_metrics.append(metric_name)
+            if metrics_dict[metric_name].metric_type != 'postagg':
+                agg_names.add(metric_name)
             else:
-                mconf = metric.json_obj
-                all_metrics += recursive_get_fields(mconf)
-                all_metrics += mconf.get('fieldNames', [])
-                if mconf.get('type') == 'javascript':
-                    post_aggs[metric_name] = JavascriptPostAggregator(
-                        name=mconf.get('name', ''),
-                        field_names=mconf.get('fieldNames', []),
-                        function=mconf.get('function', ''))
-                elif mconf.get('type') == 'quantile':
-                    post_aggs[metric_name] = Quantile(
-                        mconf.get('name', ''),
-                        mconf.get('probability', ''),
-                    )
-                elif mconf.get('type') == 'quantiles':
-                    post_aggs[metric_name] = Quantiles(
-                        mconf.get('name', ''),
-                        mconf.get('probabilities', ''),
-                    )
-                elif mconf.get('type') == 'fieldAccess':
-                    post_aggs[metric_name] = Field(mconf.get('name'))
-                elif mconf.get('type') == 'constant':
-                    post_aggs[metric_name] = Const(
-                        mconf.get('value'),
-                        output_name=mconf.get('name', ''),
-                    )
-                elif mconf.get('type') == 'hyperUniqueCardinality':
-                    post_aggs[metric_name] = HyperUniqueCardinality(
-                        mconf.get('name'),
-                    )
-                elif mconf.get('type') == 'arithmetic':
-                    post_aggs[metric_name] = Postaggregator(
-                        mconf.get('fn', '/'),
-                        mconf.get('fields', []),
-                        mconf.get('name', ''))
-                else:
-                    post_aggs[metric_name] = CustomPostAggregator(
-                        mconf.get('name', ''),
-                        mconf)
-        return all_metrics, post_aggs
+                postagg_names.append(metric_name)
+        # Create the post aggregations, maintain order since postaggs
+        # may depend on previous ones
+        post_aggs = OrderedDict()
+        visited_postaggs = set()
+        for postagg_name in postagg_names:
+            postagg = metrics_dict[postagg_name]
+            visited_postaggs.add(postagg_name)
+            DruidDatasource.resolve_postagg(
+                postagg, post_aggs, agg_names, visited_postaggs, metrics_dict)
+        return list(agg_names), post_aggs
 
     def values_for_column(self,
                           column_name,
@@ -940,7 +990,7 @@ class DruidDatasource(Model, BaseDatasource):
 
         columns_dict = {c.column_name: c for c in self.columns}
 
-        all_metrics, post_aggs = self._metrics_and_post_aggs(
+        all_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
             metrics,
             metrics_dict)
 
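The key idea in resolve_postagg above: each post-aggregation is added to the visited set before its dependencies are resolved, so the recursion terminates even when post-aggregations reference one another cyclically. Below is a minimal, self-contained sketch of that visited-set pattern; the METRICS dict and resolve() helper are hypothetical simplifications for illustration, not Superset's API.

    from collections import OrderedDict

    # Hypothetical, simplified metric registry: each entry lists the fields it
    # depends on. 'pa_sum' and 'pa_ratio' reference each other, forming a cycle.
    METRICS = {
        'm1': {'type': 'metric', 'fields': []},
        'm2': {'type': 'metric', 'fields': []},
        'pa_sum': {'type': 'postagg', 'fields': ['m1', 'pa_ratio']},
        'pa_ratio': {'type': 'postagg', 'fields': ['pa_sum', 'm2']},
    }

    def resolve(name, post_aggs, agg_names, visited):
        """Register `name` as a post-agg after resolving what it depends on."""
        required = [
            f for f in METRICS[name]['fields']
            if f not in visited and f not in agg_names
        ]
        for field in required:
            if METRICS[field]['type'] == 'postagg':
                visited.add(field)       # mark before recursing to break cycles
                resolve(field, post_aggs, agg_names, visited)
            else:
                agg_names.add(field)     # plain aggregation, just request it
        post_aggs[name] = METRICS[name]  # dependencies get inserted first

    agg_names, post_aggs = set(), OrderedDict()
    resolve('pa_ratio', post_aggs, agg_names, visited={'pa_ratio'})
    print(sorted(agg_names))  # ['m1', 'm2']
    print(list(post_aggs))    # ['pa_sum', 'pa_ratio'] -- dependency first

In the actual change, each resolved post-aggregation is additionally mapped to a pydruid post-aggregator via get_post_agg, and the aggregation names collected along the way become the list returned by metrics_and_post_aggs.
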
diff --git a/tests/druid_func_tests.py b/tests/druid_func_tests.py
index ba1f497..4c047df 100644
--- a/tests/druid_func_tests.py
+++ b/tests/druid_func_tests.py
@@ -2,12 +2,26 @@ import json
 import unittest
 
 from mock import Mock
+import pydruid.utils.postaggregator as postaggs
 
+import superset.connectors.druid.models as models
 from superset.connectors.druid.models import (
     DruidColumn, DruidDatasource, DruidMetric,
 )
 
 
+def mock_metric(metric_name, is_postagg=False):
+    metric = Mock()
+    metric.metric_name = metric_name
+    metric.metric_type = 'postagg' if is_postagg else 'metric'
+    return metric
+
+
+def emplace(metrics_dict, metric_name, is_postagg=False):
+    metrics_dict[metric_name] = mock_metric(metric_name, is_postagg)
+
+
+# Unit tests that can be run without initializing base tests
 class DruidFuncTestCase(unittest.TestCase):
 
     def test_get_filters_ignores_invalid_filter_objects(self):
@@ -271,3 +285,273 @@ class DruidFuncTestCase(unittest.TestCase):
         called_args = client.groupby.call_args_list[0][1]
         self.assertIn('dimensions', called_args)
         self.assertEqual(['col1', 'col2'], called_args['dimensions'])
+
+    def test_get_post_agg_returns_correct_agg_type(self):
+        get_post_agg = DruidDatasource.get_post_agg
+        # javascript PostAggregators
+        function = 'function(field1, field2) { return field1 + field2; }'
+        conf = {
+            'type': 'javascript',
+            'name': 'postagg_name',
+            'fieldNames': ['field1', 'field2'],
+            'function': function,
+        }
+        postagg = get_post_agg(conf)
+        self.assertTrue(isinstance(postagg, models.JavascriptPostAggregator))
+        self.assertEqual(postagg.name, 'postagg_name')
+        self.assertEqual(postagg.post_aggregator['type'], 'javascript')
+        self.assertEqual(postagg.post_aggregator['fieldNames'], ['field1', 'field2'])
+        self.assertEqual(postagg.post_aggregator['name'], 'postagg_name')
+        self.assertEqual(postagg.post_aggregator['function'], function)
+        # Quantile
+        conf = {
+            'type': 'quantile',
+            'name': 'postagg_name',
+            'probability': '0.5',
+        }
+        postagg = get_post_agg(conf)
+        self.assertTrue(isinstance(postagg, postaggs.Quantile))
+        self.assertEqual(postagg.name, 'postagg_name')
+        self.assertEqual(postagg.post_aggregator['probability'], '0.5')
+        # Quantiles
+        conf = {
+            'type': 'quantiles',
+            'name': 'postagg_name',
+            'probabilities': '0.4,0.5,0.6',
+        }
+        postagg = get_post_agg(conf)
+        self.assertTrue(isinstance(postagg, postaggs.Quantiles))
+        self.assertEqual(postagg.name, 'postagg_name')
+        self.assertEqual(postagg.post_aggregator['probabilities'], '0.4,0.5,0.6')
+        # FieldAccess
+        conf = {
+            'type': 'fieldAccess',
+            'name': 'field_name',
+        }
+        postagg = get_post_agg(conf)
+        self.assertTrue(isinstance(postagg, postaggs.Field))
+        self.assertEqual(postagg.name, 'field_name')
+        # constant
+        conf = {
+            'type': 'constant',
+            'value': 1234,
+            'name': 'postagg_name',
+        }
+        postagg = get_post_agg(conf)
+        self.assertTrue(isinstance(postagg, postaggs.Const))
+        self.assertEqual(postagg.name, 'postagg_name')
+        self.assertEqual(postagg.post_aggregator['value'], 1234)
+        # hyperUniqueCardinality
+        conf = {
+            'type': 'hyperUniqueCardinality',
+            'name': 'unique_name',
+        }
+        postagg = get_post_agg(conf)
+        self.assertTrue(isinstance(postagg, postaggs.HyperUniqueCardinality))
+        self.assertEqual(postagg.name, 'unique_name')
+        # arithmetic
+        conf = {
+            'type': 'arithmetic',
+            'fn': '+',
+            'fields': ['field1', 'field2'],
+            'name': 'postagg_name',
+        }
+        postagg = get_post_agg(conf)
+        self.assertTrue(isinstance(postagg, postaggs.Postaggregator))
+        self.assertEqual(postagg.name, 'postagg_name')
+        self.assertEqual(postagg.post_aggregator['fn'], '+')
+        self.assertEqual(postagg.post_aggregator['fields'], ['field1', 'field2'])
+        # custom post aggregator
+        conf = {
+            'type': 'custom',
+            'name': 'custom_name',
+            'stuff': 'more_stuff',
+        }
+        postagg = get_post_agg(conf)
+        self.assertTrue(isinstance(postagg, models.CustomPostAggregator))
+        self.assertEqual(postagg.name, 'custom_name')
+        self.assertEqual(postagg.post_aggregator['stuff'], 'more_stuff')
+
+    def test_find_postaggs_for_returns_postaggs_and_removes(self):
+        find_postaggs_for = DruidDatasource.find_postaggs_for
+        postagg_names = set(['pa2', 'pa3', 'pa4', 'm1', 'm2', 'm3', 'm4'])
+
+        metrics = {}
+        for i in range(1, 6):
+            emplace(metrics, 'pa' + str(i), True)
+            emplace(metrics, 'm' + str(i), False)
+        postagg_list = find_postaggs_for(postagg_names, metrics)
+        self.assertEqual(3, len(postagg_list))
+        self.assertEqual(4, len(postagg_names))
+        expected_metrics = ['m1', 'm2', 'm3', 'm4']
+        expected_postaggs = set(['pa2', 'pa3', 'pa4'])
+        for postagg in postagg_list:
+            expected_postaggs.remove(postagg.metric_name)
+        for metric in expected_metrics:
+            postagg_names.remove(metric)
+        self.assertEqual(0, len(expected_postaggs))
+        self.assertEqual(0, len(postagg_names))
+
+    def test_recursive_get_fields(self):
+        conf = {
+            'type': 'quantile',
+            'fieldName': 'f1',
+            'field': {
+                'type': 'custom',
+                'fields': [{
+                    'type': 'fieldAccess',
+                    'fieldName': 'f2',
+                }, {
+                    'type': 'fieldAccess',
+                    'fieldName': 'f3',
+                }, {
+                    'type': 'quantiles',
+                    'fieldName': 'f4',
+                    'field': {
+                        'type': 'custom',
+                    },
+                }, {
+                    'type': 'custom',
+                    'fields': [{
+                        'type': 'fieldAccess',
+                        'fieldName': 'f5',
+                    }, {
+                        'type': 'fieldAccess',
+                        'fieldName': 'f2',
+                        'fields': [{
+                            'type': 'fieldAccess',
+                            'fieldName': 'f3',
+                        }, {
+                            'type': 'fieldIgnoreMe',
+                            'fieldName': 'f6',
+                        }],
+                    }],
+                }],
+            },
+        }
+        fields = DruidDatasource.recursive_get_fields(conf)
+        expected = set(['f1', 'f2', 'f3', 'f4', 'f5'])
+        self.assertEqual(5, len(fields))
+        for field in fields:
+            expected.remove(field)
+        self.assertEqual(0, len(expected))
+
+    def test_metrics_and_post_aggs_tree(self):
+        metrics = ['A', 'B', 'm1', 'm2']
+        metrics_dict = {}
+        for i in range(ord('A'), ord('K') + 1):
+            emplace(metrics_dict, chr(i), True)
+        for i in range(1, 10):
+            emplace(metrics_dict, 'm' + str(i), False)
+
+        def depends_on(index, fields):
+            dependents = fields if isinstance(fields, list) else [fields]
+            metrics_dict[index].json_obj = {'fieldNames': dependents}
+
+        depends_on('A', ['m1', 'D', 'C'])
+        depends_on('B', ['B', 'C', 'E', 'F', 'm3'])
+        depends_on('C', ['H', 'I'])
+        depends_on('D', ['m2', 'm5', 'G', 'C'])
+        depends_on('E', ['H', 'I', 'J'])
+        depends_on('F', ['J', 'm5'])
+        depends_on('G', ['m4', 'm7', 'm6', 'A'])
+        depends_on('H', ['A', 'm4', 'I'])
+        depends_on('I', ['H', 'K'])
+        depends_on('J', 'K')
+        depends_on('K', ['m8', 'm9'])
+        all_metrics, postaggs = DruidDatasource.metrics_and_post_aggs(
+            metrics, metrics_dict)
+        expected_metrics = set(all_metrics)
+        self.assertEqual(9, len(all_metrics))
+        for i in range(1, 10):
+            expected_metrics.remove('m' + str(i))
+        self.assertEqual(0, len(expected_metrics))
+        self.assertEqual(11, len(postaggs))
+        for i in range(ord('A'), ord('K') + 1):
+            del postaggs[chr(i)]
+        self.assertEqual(0, len(postaggs))
+
+    def test_metrics_and_post_aggs(self):
+        """
+        Test generation of metrics and post-aggregations from an initial list
+        of superset metrics (which may include the results of either). This
+        primarily tests that specifying a post-aggregator metric will also
+        require the raw aggregation of the associated druid metric column.
+        """
+        metrics_dict = {
+            'unused_count': DruidMetric(
+                metric_name='unused_count',
+                verbose_name='COUNT(*)',
+                metric_type='count',
+                json=json.dumps({'type': 'count', 'name': 'unused_count'}),
+            ),
+            'some_sum': DruidMetric(
+                metric_name='some_sum',
+                verbose_name='SUM(*)',
+                metric_type='sum',
+                json=json.dumps({'type': 'sum', 'name': 'sum'}),
+            ),
+            'a_histogram': DruidMetric(
+                metric_name='a_histogram',
+                verbose_name='APPROXIMATE_HISTOGRAM(*)',
+                metric_type='approxHistogramFold',
+                json=json.dumps(
+                    {'type': 'approxHistogramFold', 'name': 'a_histogram'},
+                ),
+            ),
+            'aCustomMetric': DruidMetric(
+                metric_name='aCustomMetric',
+                verbose_name='MY_AWESOME_METRIC(*)',
+                metric_type='aCustomType',
+                json=json.dumps(
+                    {'type': 'customMetric', 'name': 'aCustomMetric'},
+                ),
+            ),
+            'quantile_p95': DruidMetric(
+                metric_name='quantile_p95',
+                verbose_name='P95(*)',
+                metric_type='postagg',
+                json=json.dumps({
+                    'type': 'quantile',
+                    'probability': 0.95,
+                    'name': 'p95',
+                    'fieldName': 'a_histogram',
+                }),
+            ),
+            'aCustomPostAgg': DruidMetric(
+                metric_name='aCustomPostAgg',
+                verbose_name='CUSTOM_POST_AGG(*)',
+                metric_type='postagg',
+                json=json.dumps({
+                    'type': 'customPostAgg',
+                    'name': 'aCustomPostAgg',
+                    'field': {
+                        'type': 'fieldAccess',
+                        'fieldName': 'aCustomMetric',
+                    },
+                }),
+            ),
+        }
+
+        metrics = ['some_sum']
+        all_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
+            metrics, metrics_dict)
+
+        assert all_metrics == ['some_sum']
+        assert post_aggs == {}
+
+        metrics = ['quantile_p95']
+        all_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
+            metrics, metrics_dict)
+
+        result_postaggs = set(['quantile_p95'])
+        assert all_metrics == ['a_histogram']
+        assert set(post_aggs.keys()) == result_postaggs
+
+        metrics = ['aCustomPostAgg']
+        all_metrics, post_aggs = DruidDatasource.metrics_and_post_aggs(
+            metrics, metrics_dict)
+
+        result_postaggs = set(['aCustomPostAgg'])
+        assert all_metrics == ['aCustomMetric']
+        assert set(post_aggs.keys()) == result_postaggs
diff --git a/tests/druid_tests.py b/tests/druid_tests.py
index c9dce33..c280da7 100644
--- a/tests/druid_tests.py
+++ b/tests/druid_tests.py
@@ -12,7 +12,7 @@ from mock import Mock, patch
 
 from superset import db, security, sm
 from superset.connectors.druid.models import (
-    DruidCluster, DruidDatasource, DruidMetric,
+    DruidCluster, DruidDatasource,
 )
 from .base_tests import SupersetTestCase
 
@@ -328,91 +328,6 @@ class DruidTests(SupersetTestCase):
             permission=permission, view_menu=view_menu).first()
         assert pv is not None
 
-    def test_metrics_and_post_aggs(self):
-        """
-        Test generation of metrics and post-aggregations from an initial list
-        of superset metrics (which may include the results of either). This
-        primarily tests that specifying a post-aggregator metric will also
-        require the raw aggregation of the associated druid metric column.
-        """
-        metrics_dict = {
-            'unused_count': DruidMetric(
-                metric_name='unused_count',
-                verbose_name='COUNT(*)',
-                metric_type='count',
-                json=json.dumps({'type': 'count', 'name': 'unused_count'}),
-            ),
-            'some_sum': DruidMetric(
-                metric_name='some_sum',
-                verbose_name='SUM(*)',
-                metric_type='sum',
-                json=json.dumps({'type': 'sum', 'name': 'sum'}),
-            ),
-            'a_histogram': DruidMetric(
-                metric_name='a_histogram',
-                verbose_name='APPROXIMATE_HISTOGRAM(*)',
-                metric_type='approxHistogramFold',
-                json=json.dumps(
-                    {'type': 'approxHistogramFold', 'name': 'a_histogram'},
-                ),
-            ),
-            'aCustomMetric': DruidMetric(
-                metric_name='aCustomMetric',
-                verbose_name='MY_AWESOME_METRIC(*)',
-                metric_type='aCustomType',
-                json=json.dumps(
-                    {'type': 'customMetric', 'name': 'aCustomMetric'},
-                ),
-            ),
-            'quantile_p95': DruidMetric(
-                metric_name='quantile_p95',
-                verbose_name='P95(*)',
-                metric_type='postagg',
-                json=json.dumps({
-                    'type': 'quantile',
-                    'probability': 0.95,
-                    'name': 'p95',
-                    'fieldName': 'a_histogram',
-                }),
-            ),
-            'aCustomPostAgg': DruidMetric(
-                metric_name='aCustomPostAgg',
-                verbose_name='CUSTOM_POST_AGG(*)',
-                metric_type='postagg',
-                json=json.dumps({
-                    'type': 'customPostAgg',
-                    'name': 'aCustomPostAgg',
-                    'field': {
-                        'type': 'fieldAccess',
-                        'fieldName': 'aCustomMetric',
-                    },
-                }),
-            ),
-        }
-
-        metrics = ['some_sum']
-        all_metrics, post_aggs = DruidDatasource._metrics_and_post_aggs(
-            metrics, metrics_dict)
-
-        assert all_metrics == ['some_sum']
-        assert post_aggs == {}
-
-        metrics = ['quantile_p95']
-        all_metrics, post_aggs = DruidDatasource._metrics_and_post_aggs(
-            metrics, metrics_dict)
-
-        result_postaggs = set(['quantile_p95'])
-        assert all_metrics == ['a_histogram']
-        assert set(post_aggs.keys()) == result_postaggs
-
-        metrics = ['aCustomPostAgg']
-        all_metrics, post_aggs = DruidDatasource._metrics_and_post_aggs(
-            metrics, metrics_dict)
-
-        result_postaggs = set(['aCustomPostAgg'])
-        assert all_metrics == ['aCustomMetric']
-        assert set(post_aggs.keys()) == result_postaggs
-
 
 if __name__ == '__main__':
     unittest.main()

-- 
To stop receiving notification emails like this one, please contact
"commits@superset.apache.org" <commits@superset.apache.org>.
