superset-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From maximebeauche...@apache.org
Subject [incubator-superset] branch master updated: Druid refresh metadata performance improvements (#3527)
Date Tue, 26 Sep 2017 01:00:48 GMT
This is an automated email from the ASF dual-hosted git repository.

maximebeauchemin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new cf0b670  Druid refresh metadata performance improvements (#3527)
cf0b670 is described below

commit cf0b670932e11a7d1fbfaaab7d6d6748340d4e80
Author: Jeff Niu <jeffniu22@gmail.com>
AuthorDate: Mon Sep 25 18:00:46 2017 -0700

    Druid refresh metadata performance improvements (#3527)
    
    * parallelized refresh druid metadata
    
    * fixed code style errors
    
    * fixed code for python3
    
    * added option to only scan for new druid datasources
    
    * Increased code coverage
---
 superset/connectors/druid/models.py | 294 ++++++++++++++++++++++--------------
 superset/connectors/druid/views.py  |  43 ++++--
 tests/druid_tests.py                |  14 +-
 3 files changed, 220 insertions(+), 131 deletions(-)

diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py
index cc40b83..89e1ed9 100644
--- a/superset/connectors/druid/models.py
+++ b/superset/connectors/druid/models.py
@@ -5,12 +5,13 @@ import logging
 from copy import deepcopy
 from datetime import datetime, timedelta
 from six import string_types
+from multiprocessing import Pool
 
 import requests
 import sqlalchemy as sa
 from sqlalchemy import (
     Column, Integer, String, ForeignKey, Text, Boolean,
-    DateTime,
+    DateTime, or_, and_,
 )
 from sqlalchemy.orm import backref, relationship
 from dateutil.parser import parse as dparse
@@ -39,6 +40,12 @@ from superset.models.helpers import AuditMixinNullable, QueryResult, set_perm
 DRUID_TZ = conf.get("DRUID_TZ")
 
 
+# Function wrapper because bound methods cannot
+# be passed to processes
+def _fetch_metadata_for(datasource):
+    return datasource.latest_metadata()
+
+
 class JavascriptPostAggregator(Postaggregator):
     def __init__(self, name, field_names, function):
         self.post_aggregator = {
@@ -101,15 +108,99 @@ class DruidCluster(Model, AuditMixinNullable):
         ).format(obj=self)
         return json.loads(requests.get(endpoint).text)['version']
 
-    def refresh_datasources(self, datasource_name=None, merge_flag=False):
+    def refresh_datasources(
+            self,
+            datasource_name=None,
+            merge_flag=True,
+            refreshAll=True):
         """Refresh metadata of all datasources in the cluster
         If ``datasource_name`` is specified, only that datasource is updated
         """
         self.druid_version = self.get_druid_version()
-        for datasource in self.get_datasources():
-            if datasource not in conf.get('DRUID_DATA_SOURCE_BLACKLIST', []):
-                if not datasource_name or datasource_name == datasource:
-                    DruidDatasource.sync_to_db(datasource, self, merge_flag)
+        ds_list = self.get_datasources()
+        blacklist = conf.get('DRUID_DATA_SOURCE_BLACKLIST', [])
+        ds_refresh = []
+        if not datasource_name:
+            ds_refresh = list(filter(lambda ds: ds not in blacklist, ds_list))
+        elif datasource_name not in blacklist and datasource_name in ds_list:
+            ds_refresh.append(datasource_name)
+        else:
+            return
+        self.refresh_async(ds_refresh, merge_flag, refreshAll)
+
+    def refresh_async(self, datasource_names, merge_flag, refreshAll):
+        """
+        Fetches metadata for the specified datasources and
+        merges to the Superset database
+        """
+        session = db.session
+        ds_list = (
+            session.query(DruidDatasource)
+            .filter(or_(DruidDatasource.datasource_name == name
+                    for name in datasource_names))
+        )
+
+        ds_map = {ds.name: ds for ds in ds_list}
+        for ds_name in datasource_names:
+            datasource = ds_map.get(ds_name, None)
+            if not datasource:
+                datasource = DruidDatasource(datasource_name=ds_name)
+                with session.no_autoflush:
+                    session.add(datasource)
+                flasher(
+                    "Adding new datasource [{}]".format(ds_name), 'success')
+                ds_map[ds_name] = datasource
+            elif refreshAll:
+                flasher(
+                    "Refreshing datasource [{}]".format(ds_name), 'info')
+            else:
+                del ds_map[ds_name]
+                continue
+            datasource.cluster = self
+            datasource.merge_flag = merge_flag
+        session.flush()
+
+        # Prepare multiprocess execution
+        pool = Pool()
+        ds_refresh = list(ds_map.values())
+        metadata = pool.map(_fetch_metadata_for, ds_refresh)
+        pool.close()
+        pool.join()
+
+        for i in range(0, len(ds_refresh)):
+            datasource = ds_refresh[i]
+            cols = metadata[i]
+            col_objs_list = (
+                session.query(DruidColumn)
+                .filter(DruidColumn.datasource_name == datasource.datasource_name)
+                .filter(or_(DruidColumn.column_name == col for col in cols))
+            )
+            col_objs = {col.column_name: col for col in col_objs_list}
+            for col in cols:
+                if col == '__time':  # skip the time column
+                    continue
+                col_obj = col_objs.get(col, None)
+                if not col_obj:
+                    col_obj = DruidColumn(
+                        datasource_name=datasource.datasource_name,
+                        column_name=col)
+                    with session.no_autoflush:
+                        session.add(col_obj)
+                datatype = cols[col]['type']
+                if datatype == 'STRING':
+                    col_obj.groupby = True
+                    col_obj.filterable = True
+                if datatype == 'hyperUnique' or datatype == 'thetaSketch':
+                    col_obj.count_distinct = True
+                # Allow sum/min/max for long or double
+                if datatype == 'LONG' or datatype == 'DOUBLE':
+                    col_obj.sum = True
+                    col_obj.min = True
+                    col_obj.max = True
+                col_obj.type = datatype
+                col_obj.datasource = datasource
+            datasource.generate_metrics_for(col_objs_list)
+        session.commit()
 
     @property
     def perm(self):
@@ -160,16 +251,14 @@ class DruidColumn(Model, BaseColumn):
         if self.dimension_spec_json:
             return json.loads(self.dimension_spec_json)
 
-    def generate_metrics(self):
-        """Generate metrics based on the column metadata"""
-        M = DruidMetric  # noqa
-        metrics = []
-        metrics.append(DruidMetric(
+    def get_metrics(self):
+        metrics = {}
+        metrics['count'] = DruidMetric(
             metric_name='count',
             verbose_name='COUNT(*)',
             metric_type='count',
             json=json.dumps({'type': 'count', 'name': 'count'})
-        ))
+        )
         # Somehow we need to reassign this for UDAFs
         if self.type in ('DOUBLE', 'FLOAT'):
             corrected_type = 'DOUBLE'
@@ -179,49 +268,49 @@ class DruidColumn(Model, BaseColumn):
         if self.sum and self.is_num:
             mt = corrected_type.lower() + 'Sum'
             name = 'sum__' + self.column_name
-            metrics.append(DruidMetric(
+            metrics[name] = DruidMetric(
                 metric_name=name,
                 metric_type='sum',
                 verbose_name='SUM({})'.format(self.column_name),
                 json=json.dumps({
                     'type': mt, 'name': name, 'fieldName': self.column_name})
-            ))
+            )
 
         if self.avg and self.is_num:
             mt = corrected_type.lower() + 'Avg'
             name = 'avg__' + self.column_name
-            metrics.append(DruidMetric(
+            metrics[name] = DruidMetric(
                 metric_name=name,
                 metric_type='avg',
                 verbose_name='AVG({})'.format(self.column_name),
                 json=json.dumps({
                     'type': mt, 'name': name, 'fieldName': self.column_name})
-            ))
+            )
 
         if self.min and self.is_num:
             mt = corrected_type.lower() + 'Min'
             name = 'min__' + self.column_name
-            metrics.append(DruidMetric(
+            metrics[name] = DruidMetric(
                 metric_name=name,
                 metric_type='min',
                 verbose_name='MIN({})'.format(self.column_name),
                 json=json.dumps({
                     'type': mt, 'name': name, 'fieldName': self.column_name})
-            ))
+            )
         if self.max and self.is_num:
             mt = corrected_type.lower() + 'Max'
             name = 'max__' + self.column_name
-            metrics.append(DruidMetric(
+            metrics[name] = DruidMetric(
                 metric_name=name,
                 metric_type='max',
                 verbose_name='MAX({})'.format(self.column_name),
                 json=json.dumps({
                     'type': mt, 'name': name, 'fieldName': self.column_name})
-            ))
+            )
         if self.count_distinct:
             name = 'count_distinct__' + self.column_name
             if self.type == 'hyperUnique' or self.type == 'thetaSketch':
-                metrics.append(DruidMetric(
+                metrics[name] = DruidMetric(
                     metric_name=name,
                     verbose_name='COUNT(DISTINCT {})'.format(self.column_name),
                     metric_type=self.type,
@@ -230,10 +319,9 @@ class DruidColumn(Model, BaseColumn):
                         'name': name,
                         'fieldName': self.column_name
                     })
-                ))
+                )
             else:
-                mt = 'count_distinct'
-                metrics.append(DruidMetric(
+                metrics[name] = DruidMetric(
                     metric_name=name,
                     verbose_name='COUNT(DISTINCT {})'.format(self.column_name),
                     metric_type='count_distinct',
@@ -241,22 +329,25 @@ class DruidColumn(Model, BaseColumn):
                         'type': 'cardinality',
                         'name': name,
                         'fieldNames': [self.column_name]})
-                ))
-        session = get_session()
-        new_metrics = []
-        for metric in metrics:
-            m = (
-                session.query(M)
-                .filter(M.metric_name == metric.metric_name)
-                .filter(M.datasource_name == self.datasource_name)
-                .filter(DruidCluster.cluster_name == self.datasource.cluster_name)
-                .first()
-            )
+                )
+        return metrics
+
+    def generate_metrics(self):
+        """Generate metrics based on the column metadata"""
+        metrics = self.get_metrics()
+        dbmetrics = (
+            db.session.query(DruidMetric)
+            .filter(DruidCluster.cluster_name == self.datasource.cluster_name)
+            .filter(DruidMetric.datasource_name == self.datasource_name)
+            .filter(or_(
+                DruidMetric.metric_name == m for m in metrics
+            ))
+        )
+        dbmetrics = {metric.metric_name: metric for metric in dbmetrics}
+        for metric in metrics.values():
             metric.datasource_name = self.datasource_name
-            if not m:
-                new_metrics.append(metric)
-                session.add(metric)
-                session.flush()
+            if not dbmetrics.get(metric.metric_name, None):
+                db.session.add(metric)
 
     @classmethod
     def import_obj(cls, i_column):
@@ -474,6 +565,7 @@ class DruidDatasource(Model, BaseDatasource):
 
     def latest_metadata(self):
         """Returns segment metadata from the latest segment"""
+        logging.info("Syncing datasource [{}]".format(self.datasource_name))
         client = self.cluster.get_pydruid_client()
         results = client.time_boundary(datasource=self.datasource_name)
         if not results:
@@ -485,31 +577,33 @@ class DruidDatasource(Model, BaseDatasource):
         # realtime segments, which triggered a bug (fixed in druid 0.8.2).
         # https://groups.google.com/forum/#!topic/druid-user/gVCqqspHqOQ
         lbound = (max_time - timedelta(days=7)).isoformat()
-        rbound = max_time.isoformat()
         if not self.version_higher(self.cluster.druid_version, '0.8.2'):
             rbound = (max_time - timedelta(1)).isoformat()
+        else:
+            rbound = max_time.isoformat()
         segment_metadata = None
         try:
             segment_metadata = client.segment_metadata(
                 datasource=self.datasource_name,
                 intervals=lbound + '/' + rbound,
                 merge=self.merge_flag,
-                analysisTypes=conf.get('DRUID_ANALYSIS_TYPES'))
+                analysisTypes=[])
         except Exception as e:
             logging.warning("Failed first attempt to get latest segment")
             logging.exception(e)
         if not segment_metadata:
             # if no segments in the past 7 days, look at all segments
             lbound = datetime(1901, 1, 1).isoformat()[:10]
-            rbound = datetime(2050, 1, 1).isoformat()[:10]
             if not self.version_higher(self.cluster.druid_version, '0.8.2'):
                 rbound = datetime.now().isoformat()
+            else:
+                rbound = datetime(2050, 1, 1).isoformat()[:10]
             try:
                 segment_metadata = client.segment_metadata(
                     datasource=self.datasource_name,
                     intervals=lbound + '/' + rbound,
                     merge=self.merge_flag,
-                    analysisTypes=conf.get('DRUID_ANALYSIS_TYPES'))
+                    analysisTypes=[])
             except Exception as e:
                 logging.warning("Failed 2nd attempt to get latest segment")
                 logging.exception(e)
@@ -517,17 +611,37 @@ class DruidDatasource(Model, BaseDatasource):
             return segment_metadata[-1]['columns']
 
     def generate_metrics(self):
-        for col in self.columns:
-            col.generate_metrics()
+        self.generate_metrics_for(self.columns)
+
+    def generate_metrics_for(self, columns):
+        metrics = {}
+        for col in columns:
+            metrics.update(col.get_metrics())
+        dbmetrics = (
+            db.session.query(DruidMetric)
+            .filter(DruidCluster.cluster_name == self.cluster_name)
+            .filter(DruidMetric.datasource_name == self.datasource_name)
+            .filter(or_(DruidMetric.metric_name == m for m in metrics))
+        )
+        dbmetrics = {metric.metric_name: metric for metric in dbmetrics}
+        for metric in metrics.values():
+            metric.datasource_name = self.datasource_name
+            if not dbmetrics.get(metric.metric_name, None):
+                with db.session.no_autoflush:
+                    db.session.add(metric)
 
     @classmethod
-    def sync_to_db_from_config(cls, druid_config, user, cluster):
+    def sync_to_db_from_config(
+            cls,
+            druid_config,
+            user,
+            cluster,
+            refresh=True):
         """Merges the ds config from druid_config into one stored in the db."""
-        session = db.session()
+        session = db.session
         datasource = (
             session.query(cls)
-            .filter_by(
-                datasource_name=druid_config['name'])
+            .filter_by(datasource_name=druid_config['name'])
             .first()
         )
         # Create a new datasource.
@@ -540,16 +654,18 @@ class DruidDatasource(Model, BaseDatasource):
                 created_by_fk=user.id,
             )
             session.add(datasource)
+        elif not refresh:
+            return
 
         dimensions = druid_config['dimensions']
+        col_objs = (
+            session.query(DruidColumn)
+            .filter(DruidColumn.datasource_name == druid_config['name'])
+            .filter(or_(DruidColumn.column_name == dim for dim in dimensions))
+        )
+        col_objs = {col.column_name: col for col in col_objs}
         for dim in dimensions:
-            col_obj = (
-                session.query(DruidColumn)
-                .filter_by(
-                    datasource_name=druid_config['name'],
-                    column_name=dim)
-                .first()
-            )
+            col_obj = col_objs.get(dim, None)
             if not col_obj:
                 col_obj = DruidColumn(
                     datasource_name=druid_config['name'],
@@ -562,6 +678,13 @@ class DruidDatasource(Model, BaseDatasource):
                 )
                 session.add(col_obj)
         # Import Druid metrics
+        metric_objs = (
+            session.query(DruidMetric)
+            .filter(DruidMetric.datasource_name == druid_config['name'])
+            .filter(or_(DruidMetric.metric_name == spec['name']
+                    for spec in druid_config["metrics_spec"]))
+        )
+        metric_objs = {metric.metric_name: metric for metric in metric_objs}
         for metric_spec in druid_config["metrics_spec"]:
             metric_name = metric_spec["name"]
             metric_type = metric_spec["type"]
@@ -575,12 +698,7 @@ class DruidDatasource(Model, BaseDatasource):
                     "fieldName": metric_name,
                 })
 
-            metric_obj = (
-                session.query(DruidMetric)
-                .filter_by(
-                    datasource_name=druid_config['name'],
-                    metric_name=metric_name)
-            ).first()
+            metric_obj = metric_objs.get(metric_name, None)
             if not metric_obj:
                 metric_obj = DruidMetric(
                     metric_name=metric_name,
@@ -595,58 +713,6 @@ class DruidDatasource(Model, BaseDatasource):
                 session.add(metric_obj)
         session.commit()
 
-    @classmethod
-    def sync_to_db(cls, name, cluster, merge):
-        """Fetches metadata for that datasource and merges the Superset db"""
-        logging.info("Syncing Druid datasource [{}]".format(name))
-        session = get_session()
-        datasource = session.query(cls).filter_by(datasource_name=name).first()
-        if not datasource:
-            datasource = cls(datasource_name=name)
-            session.add(datasource)
-            flasher("Adding new datasource [{}]".format(name), "success")
-        else:
-            flasher("Refreshing datasource [{}]".format(name), "info")
-        session.flush()
-        datasource.cluster = cluster
-        datasource.merge_flag = merge
-        session.flush()
-
-        cols = datasource.latest_metadata()
-        if not cols:
-            logging.error("Failed at fetching the latest segment")
-            return
-        for col in cols:
-            # Skip the time column
-            if col == "__time":
-                continue
-            col_obj = (
-                session
-                .query(DruidColumn)
-                .filter_by(datasource_name=name, column_name=col)
-                .first()
-            )
-            datatype = cols[col]['type']
-            if not col_obj:
-                col_obj = DruidColumn(datasource_name=name, column_name=col)
-                session.add(col_obj)
-            if datatype == "STRING":
-                col_obj.groupby = True
-                col_obj.filterable = True
-            if datatype == "hyperUnique" or datatype == "thetaSketch":
-                col_obj.count_distinct = True
-            # If long or double, allow sum/min/max
-            if datatype == "LONG" or datatype == "DOUBLE":
-                col_obj.sum = True
-                col_obj.min = True
-                col_obj.max = True
-            if col_obj:
-                col_obj.type = cols[col]['type']
-            session.flush()
-            col_obj.datasource = datasource
-            col_obj.generate_metrics()
-            session.flush()
-
     @staticmethod
     def time_offset(granularity):
         if granularity == 'week_ending_saturday':
diff --git a/superset/connectors/druid/views.py b/superset/connectors/druid/views.py
index f64b6c1..42fbdbb 100644
--- a/superset/connectors/druid/views.py
+++ b/superset/connectors/druid/views.py
@@ -235,17 +235,17 @@ class DruidDatasourceModelView(DatasourceModelView, DeleteMixin):  # noqa
     }
 
     def pre_add(self, datasource):
-        number_of_existing_datasources = db.session.query(
-            sqla.func.count('*')).filter(
-            models.DruidDatasource.datasource_name ==
-                datasource.datasource_name,
-            models.DruidDatasource.cluster_name == datasource.cluster.id
-        ).scalar()
-
-        # table object is already added to the session
-        if number_of_existing_datasources > 1:
-            raise Exception(get_datasource_exist_error_mgs(
-                datasource.full_name))
+        with db.session.no_autoflush:
+            query = (
+                db.session.query(models.DruidDatasource)
+                .filter(models.DruidDatasource.datasource_name ==
+                        datasource.datasource_name,
+                        models.DruidDatasource.cluster_name ==
+                        datasource.cluster.id)
+            )
+            if db.session.query(query.exists()).scalar():
+                raise Exception(get_datasource_exist_error_mgs(
+                    datasource.full_name))
 
     def post_add(self, datasource):
         datasource.generate_metrics()
@@ -273,14 +273,14 @@ class Druid(BaseSupersetView):
 
     @has_access
     @expose("/refresh_datasources/")
-    def refresh_datasources(self):
+    def refresh_datasources(self, refreshAll=True):
         """endpoint that refreshes druid datasources metadata"""
         session = db.session()
         DruidCluster = ConnectorRegistry.sources['druid'].cluster_class
         for cluster in session.query(DruidCluster).all():
             cluster_name = cluster.cluster_name
             try:
-                cluster.refresh_datasources()
+                cluster.refresh_datasources(refreshAll=refreshAll)
             except Exception as e:
                 flash(
                     "Error while processing cluster '{}'\n{}".format(
@@ -296,9 +296,26 @@ class Druid(BaseSupersetView):
         session.commit()
         return redirect("/druiddatasourcemodelview/list/")
 
+    @has_access
+    @expose("/scan_new_datasources/")
+    def scan_new_datasources(self):
+        """
+        Calling this endpoint will cause a scan for new
+        datasources only and add them.
+        """
+        return self.refresh_datasources(refreshAll=False)
+
 appbuilder.add_view_no_menu(Druid)
 
 appbuilder.add_link(
+    "Scan New Datasources",
+    label=__("Scan New Datasources"),
+    href='/druid/scan_new_datasources/',
+    category='Sources',
+    category_label=__("Sources"),
+    category_icon='fa-database',
+    icon="fa-refresh")
+appbuilder.add_link(
     "Refresh Druid Metadata",
     label=__("Refresh Druid Metadata"),
     href='/druid/refresh_datasources/',
diff --git a/tests/druid_tests.py b/tests/druid_tests.py
index 637afe9..c506ebf 100644
--- a/tests/druid_tests.py
+++ b/tests/druid_tests.py
@@ -16,6 +16,9 @@ from superset.connectors.druid.models import PyDruid, Quantile, Postaggregator
 
 from .base_tests import SupersetTestCase
 
+class PickableMock(Mock):
+    def __reduce__(self):
+        return (Mock, ())
 
 SEGMENT_METADATA = [{
   "id": "some_id",
@@ -98,8 +101,8 @@ class DruidTests(SupersetTestCase):
             metadata_last_refreshed=datetime.now())
 
         db.session.add(cluster)
-        cluster.get_datasources = Mock(return_value=['test_datasource'])
-        cluster.get_druid_version = Mock(return_value='0.9.1')
+        cluster.get_datasources = PickableMock(return_value=['test_datasource'])
+        cluster.get_druid_version = PickableMock(return_value='0.9.1')
         cluster.refresh_datasources()
         cluster.refresh_datasources(merge_flag=True)
         datasource_id = cluster.datasources[0].id
@@ -303,11 +306,14 @@ class DruidTests(SupersetTestCase):
             metadata_last_refreshed=datetime.now())
 
         db.session.add(cluster)
-        cluster.get_datasources = Mock(return_value=['test_datasource'])
-        cluster.get_druid_version = Mock(return_value='0.9.1')
+        cluster.get_datasources = PickableMock(return_value=['test_datasource'])
+        cluster.get_druid_version = PickableMock(return_value='0.9.1')
 
         cluster.refresh_datasources()
         datasource_id = cluster.datasources[0].id
+        cluster.datasources[0].merge_flag = True
+        metadata = cluster.datasources[0].latest_metadata()
+        self.assertEqual(len(metadata), 4)
         db.session.commit()
 
         view_menu_name = cluster.datasources[0].get_perm()

-- 
To stop receiving notification emails like this one, please contact
"commits@superset.apache.org" <commits@superset.apache.org>.

Mime
View raw message