superset-commits mailing list archives

From b...@apache.org
Subject [incubator-superset] branch master updated: feat: Add `validate_sql_json` endpoint for checking that a given sql query is valid for the chosen database (#7422) (#7462)
Date Mon, 06 May 2019 17:21:15 GMT
This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 2497048  feat: Add `validate_sql_json` endpoint for checking that a given sql query is valid for the chosen database (#7422) (#7462)
2497048 is described below

commit 24970485cf05fdc4afa815c708dd38d95ebe389e
Author: Alex Berghage <aberghage@gmail.com>
AuthorDate: Mon May 6 11:21:02 2019 -0600

    feat: Add `validate_sql_json` endpoint for checking that a given sql query is valid for the chosen database (#7422) (#7462)
    
    merge from lyft-release-sp8 to master
---
 docs/installation.rst                |  40 ++++---
 superset/config.py                   |   7 ++
 superset/sql_validators/__init__.py  |  27 +++++
 superset/sql_validators/base.py      |  66 +++++++++++
 superset/sql_validators/presto_db.py | 186 +++++++++++++++++++++++++++++++
 superset/views/core.py               |  69 +++++++++++-
 tests/base_tests.py                  |  15 +++
 tests/sql_validator_tests.py         | 210 +++++++++++++++++++++++++++++++++++
 8 files changed, 605 insertions(+), 15 deletions(-)
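
Note on the response shape: on success, the new endpoint returns a JSON list with one dictionary per annotation. The sketch below is illustrative only. It copies `SQLValidationAnnotation` from `superset/sql_validators/base.py` as a stand-in, uses the standard-library `json` module in place of the serializer the view calls, and introduces `serialize_annotations`, a hypothetical helper name that does not exist in this commit. The sample values come from the test fixture in `tests/sql_validator_tests.py`.

```python
import json
from typing import Dict, List, Optional


class SQLValidationAnnotation:
    """Stand-in mirroring the class added in superset/sql_validators/base.py."""

    def __init__(
            self,
            message: str,
            line_number: Optional[int],
            start_column: Optional[int],
            end_column: Optional[int],
    ):
        self.message = message
        self.line_number = line_number
        self.start_column = start_column
        self.end_column = end_column

    def to_dict(self) -> Dict:
        return {
            'line_number': self.line_number,
            'start_column': self.start_column,
            'end_column': self.end_column,
            'message': self.message,
        }


def serialize_annotations(annotations: List[SQLValidationAnnotation]) -> str:
    """Hypothetical helper: the view inlines this step as a json.dumps call."""
    return json.dumps([annotation.to_dict() for annotation in annotations])


# Sample values taken from PRESTO_ERROR_TEMPLATE in the test module.
payload = serialize_annotations([
    SQLValidationAnnotation("your query isn't how I like it", 10, 20, 20),
])
print(payload)
```

An empty list means every statement validated cleanly. Failures of the validation process itself are reported separately, under an `error` key.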

diff --git a/docs/installation.rst b/docs/installation.rst
index b7c83e5..c7c24fc 100644
--- a/docs/installation.rst
+++ b/docs/installation.rst
@@ -439,8 +439,8 @@ The connection string for Teradata looks like this ::
 
 Required environment variables: ::
 
-    export ODBCINI=/.../teradata/client/ODBC_64/odbc.ini  
-    export ODBCINST=/.../teradata/client/ODBC_64/odbcinst.ini 
+    export ODBCINI=/.../teradata/client/ODBC_64/odbc.ini
+    export ODBCINST=/.../teradata/client/ODBC_64/odbcinst.ini
 
 See `Teradata SQLAlchemy <https://github.com/Teradata/sqlalchemy-teradata>`_.
 
@@ -811,6 +811,19 @@ in this dictionary are made available for users to use in their SQL.
         'my_crazy_macro': lambda x: x*2,
     }
 
+SQL Lab also includes a live query validation feature with pluggable backends.
+You can configure which validation implementation is used with which database
+engine by adding a block like the following to your config.py:
+
+.. code-block:: python
+     FEATURE_FLAGS = {
+         'SQL_VALIDATORS_BY_ENGINE': {
+             'presto': 'PrestoDBSQLValidator',
+         }
+     }
+
+The available validators and names can be found in `sql_validators/`.
+
 **Scheduling queries**
 
 You can optionally allow your users to schedule queries directly in SQL Lab.
@@ -967,7 +980,7 @@ Note that the above command will install Superset into ``default`` namespace of
 Custom OAuth2 configuration
 ---------------------------
 
-Beyond FAB supported providers (github, twitter, linkedin, google, azure), its easy to connect Superset with other OAuth2 Authorization Server implementations that support "code" authorization.
+Beyond FAB supported providers (github, twitter, linkedin, google, azure), its easy to connect Superset with other OAuth2 Authorization Server implementations that support "code" authorization.
 
 The first step: Configure authorization in Superset ``superset_config.py``.
 
@@ -986,10 +999,10 @@ The first step: Configure authorization in Superset ``superset_config.py``.
                 },
                 'access_token_method':'POST',    # HTTP Method to call access_token_url
                 'access_token_params':{        # Additional parameters for calls to access_token_url
-                    'client_id':'myClientId'     
+                    'client_id':'myClientId'
                 },
-                'access_token_headers':{    # Additional headers for calls to access_token_url
-                    'Authorization': 'Basic Base64EncodedClientIdAndSecret' 
+                'access_token_headers':{    # Additional headers for calls to access_token_url
+                    'Authorization': 'Basic Base64EncodedClientIdAndSecret'
                 },
                 'base_url':'https://myAuthorizationServer/oauth2AuthorizationServer/',
                 'access_token_url':'https://myAuthorizationServer/oauth2AuthorizationServer/token',
@@ -997,25 +1010,25 @@ The first step: Configure authorization in Superset ``superset_config.py``.
             }
         }
     ]
-    
+
     # Will allow user self registration, allowing to create Flask users from Authorized User
     AUTH_USER_REGISTRATION = True
-    
+
     # The default user self registration role
     AUTH_USER_REGISTRATION_ROLE = "Public"
-    
+
 Second step: Create a `CustomSsoSecurityManager` that extends `SupersetSecurityManager` and overrides `oauth_user_info`:
 
 .. code-block:: python
-    
+
     from superset.security import SupersetSecurityManager
-    
+
     class CustomSsoSecurityManager(SupersetSecurityManager):
 
         def oauth_user_info(self, provider, response=None):
             logging.debug("Oauth2 provider: {0}.".format(provider))
             if provider == 'egaSSO':
-                # As example, this line request a GET to base_url + '/' + userDetails with Bearer  Authentication, 
+                # As example, this line request a GET to base_url + '/' + userDetails with Bearer  Authentication,
         # and expects that authorization server checks the token, and response with user details
                 me = self.appbuilder.sm.oauth_remotes[provider].get('userDetails').data
                 logging.debug("user_data: {0}".format(me))
@@ -1027,7 +1040,6 @@ This file must be located at the same directory than ``superset_config.py`` with
 Then we can add this two lines to ``superset_config.py``:
 
 .. code-block:: python
-  
+
   from custom_sso_security_manager import CustomSsoSecurityManager
   CUSTOM_SECURITY_MANAGER = CustomSsoSecurityManager
-
diff --git a/superset/config.py b/superset/config.py
index df14b0f..5a35f0b 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -420,6 +420,9 @@ DEFAULT_DB_ID = None
 # Timeout duration for SQL Lab synchronous queries
 SQLLAB_TIMEOUT = 30
 
+# Timeout duration for SQL Lab query validation
+SQLLAB_VALIDATION_TIMEOUT = 10
+
 # SQLLAB_DEFAULT_DBID
 SQLLAB_DEFAULT_DBID = None
 
@@ -608,6 +611,10 @@ DEFAULT_RELATIVE_END_TIME = 'today'
 # localtime (in the tz where the superset webserver is running)
 IS_EPOCH_S_TRULY_UTC = False
 
+# Configure which SQL validator to use for each engine
+SQL_VALIDATORS_BY_ENGINE = {
+    'presto': 'PrestoDBSQLValidator',
+}
 
 try:
     if CONFIG_PATH_ENV_VAR in os.environ:
diff --git a/superset/sql_validators/__init__.py b/superset/sql_validators/__init__.py
new file mode 100644
index 0000000..367aab6
--- /dev/null
+++ b/superset/sql_validators/__init__.py
@@ -0,0 +1,27 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from typing import Optional
+
+from . import base  # noqa
+from . import presto_db  # noqa
+from .base import SQLValidationAnnotation  # noqa
+
+
+def get_validator_by_name(name: str) -> Optional[base.BaseSQLValidator]:
+    return {
+        'PrestoDBSQLValidator': presto_db.PrestoDBSQLValidator,
+    }.get(name)
diff --git a/superset/sql_validators/base.py b/superset/sql_validators/base.py
new file mode 100644
index 0000000..437001b
--- /dev/null
+++ b/superset/sql_validators/base.py
@@ -0,0 +1,66 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=too-few-public-methods
+
+from typing import (
+    Any,
+    Dict,
+    List,
+    Optional,
+)
+
+
+class SQLValidationAnnotation:
+    """Represents a single annotation (error/warning) in an SQL querytext"""
+    def __init__(
+            self,
+            message: str,
+            line_number: Optional[int],
+            start_column: Optional[int],
+            end_column: Optional[int],
+    ):
+        self.message = message
+        self.line_number = line_number
+        self.start_column = start_column
+        self.end_column = end_column
+
+    def to_dict(self) -> Dict:
+        """Return a dictionary representation of this annotation"""
+        return {
+            'line_number': self.line_number,
+            'start_column': self.start_column,
+            'end_column': self.end_column,
+            'message': self.message,
+        }
+
+
+class BaseSQLValidator:
+    """BaseSQLValidator defines the interface for checking that a given sql
+    query is valid for a given database engine."""
+
+    name = 'BaseSQLValidator'
+
+    @classmethod
+    def validate(
+            cls,
+            sql: str,
+            schema: str,
+            database: Any,
+    ) -> List[SQLValidationAnnotation]:
+        """Check that the given SQL querystring is valid for the given engine"""
+        raise NotImplementedError
diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py
new file mode 100644
index 0000000..87c2d8e
--- /dev/null
+++ b/superset/sql_validators/presto_db.py
@@ -0,0 +1,186 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from contextlib import closing
+import logging
+import time
+from typing import (
+    Any,
+    Dict,
+    List,
+    Optional,
+)
+
+from flask import g
+from pyhive.exc import DatabaseError
+
+from superset import app, security_manager
+from superset.sql_parse import ParsedQuery
+from superset.sql_validators.base import (
+    BaseSQLValidator,
+    SQLValidationAnnotation,
+)
+from superset.utils.core import sources
+
+MAX_ERROR_ROWS = 10
+
+config = app.config
+
+
+class PrestoSQLValidationError(Exception):
+    """Error in the process of asking Presto to validate SQL querytext"""
+
+
+class PrestoDBSQLValidator(BaseSQLValidator):
+    """Validate SQL queries using Presto's built-in EXPLAIN subtype"""
+
+    name = 'PrestoDBSQLValidator'
+
+    @classmethod
+    def validate_statement(
+            cls,
+            statement,
+            database,
+            cursor,
+            user_name,
+    ) -> Optional[SQLValidationAnnotation]:
+        # pylint: disable=too-many-locals
+        db_engine_spec = database.db_engine_spec
+        parsed_query = ParsedQuery(statement)
+        sql = parsed_query.stripped()
+
+        # Hook to allow environment-specific mutation (usually comments) to the SQL
+        # pylint: disable=invalid-name
+        SQL_QUERY_MUTATOR = config.get('SQL_QUERY_MUTATOR')
+        if SQL_QUERY_MUTATOR:
+            sql = SQL_QUERY_MUTATOR(sql, user_name, security_manager, database)
+
+        # Transform the final statement to an explain call before sending it on
+        # to presto to validate
+        sql = f'EXPLAIN (TYPE VALIDATE) {sql}'
+
+        # Invoke the query against presto. NB this deliberately doesn't use the
+        # engine spec's handle_cursor implementation since we don't record
+        # these EXPLAIN queries done in validation as proper Query objects
+        # in the superset ORM.
+        try:
+            db_engine_spec.execute(cursor, sql)
+            polled = cursor.poll()
+            while polled:
+                logging.info('polling presto for validation progress')
+                stats = polled.get('stats', {})
+                if stats:
+                    state = stats.get('state')
+                    if state == 'FINISHED':
+                        break
+                time.sleep(0.2)
+                polled = cursor.poll()
+            db_engine_spec.fetch_data(cursor, MAX_ERROR_ROWS)
+            return None
+        except DatabaseError as db_error:
+            # The pyhive presto client yields EXPLAIN (TYPE VALIDATE) responses
+            # as though they were normal queries. In other words, it doesn't
+            # know that errors here are not exceptional. To map this back to
+            # ordinary control flow, we have to trap the category of exception
+            # raised by the underlying client, match the exception arguments
+            # pyhive provides against the shape of dictionary for a presto query
+            # invalid error, and restructure that error as an annotation we can
+            # return up.
+
+            # Confirm the first element in the DatabaseError constructor is a
+            # dictionary with error information. This is currently provided by
+            # the pyhive client, but may break if their interface changes when
+            # we update at some point in the future.
+            if not db_error.args or not isinstance(db_error.args[0], dict):
+                raise PrestoSQLValidationError(
+                    'The pyhive presto client returned an unhandled '
+                    'database error.',
+                ) from db_error
+            error_args: Dict[str, Any] = db_error.args[0]
+
+            # Confirm the two fields we need to be able to present an annotation
+            # are present in the error response -- a message, and a location.
+            if 'message' not in error_args:
+                raise PrestoSQLValidationError(
+                    'The pyhive presto client did not report an error message',
+                ) from db_error
+            if 'errorLocation' not in error_args:
+                raise PrestoSQLValidationError(
+                    'The pyhive presto client did not report an error location',
+                ) from db_error
+
+            # Pylint is confused about the type of error_args, despite the hints
+            # and checks above.
+            # pylint: disable=invalid-sequence-index
+            message = error_args['message']
+            err_loc = error_args['errorLocation']
+            line_number = err_loc.get('lineNumber', None)
+            start_column = err_loc.get('columnNumber', None)
+            end_column = err_loc.get('columnNumber', None)
+
+            return SQLValidationAnnotation(
+                message=message,
+                line_number=line_number,
+                start_column=start_column,
+                end_column=end_column,
+            )
+        except Exception as e:
+            logging.exception(f'Unexpected error running validation query: {e}')
+            raise e
+
+    @classmethod
+    def validate(
+            cls,
+            sql: str,
+            schema: str,
+            database: Any,
+    ) -> List[SQLValidationAnnotation]:
+        """
+        Presto supports query-validation queries by running them with a
+        prepended explain.
+
+        For example, "SELECT 1 FROM default.mytable" becomes "EXPLAIN (TYPE
+        VALIDATE) SELECT 1 FROM default.mytable".
+        """
+        user_name = g.user.username if g.user else None
+        parsed_query = ParsedQuery(sql)
+        statements = parsed_query.get_statements()
+
+        logging.info(f'Validating {len(statements)} statement(s)')
+        engine = database.get_sqla_engine(
+            schema=schema,
+            nullpool=True,
+            user_name=user_name,
+            source=sources.get('sql_lab', None),
+        )
+        # Sharing a single connection and cursor across the
+        # execution of all statements (if many)
+        annotations: List[SQLValidationAnnotation] = []
+        with closing(engine.raw_connection()) as conn:
+            with closing(conn.cursor()) as cursor:
+                for statement in parsed_query.get_statements():
+                    annotation = cls.validate_statement(
+                        statement,
+                        database,
+                        cursor,
+                        user_name,
+                    )
+                    if annotation:
+                        annotations.append(annotation)
+        logging.debug(f'Validation found {len(annotations)} error(s)')
+
+        return annotations
diff --git a/superset/views/core.py b/superset/views/core.py
index e22acb7..eb25cd0 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -44,7 +44,7 @@ from werkzeug.routing import BaseConverter
 from werkzeug.utils import secure_filename
 
 from superset import (
-    app, appbuilder, cache, conf, db, results_backend,
+    app, appbuilder, cache, conf, db, get_feature_flags, results_backend,
     security_manager, sql_lab, viz)
 from superset.connectors.connector_registry import ConnectorRegistry
 from superset.connectors.sqla.models import AnnotationDatasource, SqlaTable
@@ -56,6 +56,7 @@ import superset.models.core as models
 from superset.models.sql_lab import Query
 from superset.models.user_attributes import UserAttribute
 from superset.sql_parse import ParsedQuery
+from superset.sql_validators import get_validator_by_name
 from superset.utils import core as utils
 from superset.utils import dashboard_import_export
 from superset.utils.dates import now_as_float
@@ -2517,6 +2518,72 @@ class Superset(BaseSupersetView):
         return self.json_response('OK')
 
     @has_access_api
+    @expose('/validate_sql_json/', methods=['POST', 'GET'])
+    @log_this
+    def validate_sql_json(self):
+        """Validates that arbitrary sql is acceptable for the given database.
+        Returns a list of error/warning annotations as json.
+        """
+        sql = request.form.get('sql')
+        database_id = request.form.get('database_id')
+        schema = request.form.get('schema') or None
+        template_params = json.loads(
+            request.form.get('templateParams') or '{}')
+
+        if len(template_params) > 0:
+            # TODO: factor the Database object out of template rendering
+            #       or provide it as mydb so we can render template params
+            #       without having to also persist a Query ORM object.
+            return json_error_response(
+                'SQL validation does not support template parameters',
+                status=400)
+
+        session = db.session()
+        mydb = session.query(models.Database).filter_by(id=database_id).first()
+        if not mydb:
+            return json_error_response(
+                'Database with id {} is missing.'.format(database_id),
+                status=400,
+            )
+
+        spec = mydb.db_engine_spec
+        validators_by_engine = get_feature_flags().get(
+            'SQL_VALIDATORS_BY_ENGINE')
+        if not validators_by_engine or spec.engine not in validators_by_engine:
+            return json_error_response(
+                'no SQL validator is configured for {}'.format(spec.engine),
+                status=400)
+        validator_name = validators_by_engine[spec.engine]
+        validator = get_validator_by_name(validator_name)
+        if not validator:
+            return json_error_response(
+                'No validator named {} found (configured for the {} engine)'
+                .format(validator_name, spec.engine))
+
+        try:
+            timeout = config.get('SQLLAB_VALIDATION_TIMEOUT')
+            timeout_msg = (
+                f'The query exceeded the {timeout} seconds timeout.')
+            with utils.timeout(seconds=timeout,
+                               error_message=timeout_msg):
+                errors = validator.validate(sql, schema, mydb)
+            payload = json.dumps(
+                [err.to_dict() for err in errors],
+                default=utils.pessimistic_json_iso_dttm_ser,
+                ignore_nan=True,
+                encoding=None,
+            )
+            return json_success(payload)
+        except Exception as e:
+            logging.exception(e)
+            msg = _(
+                'Failed to validate your SQL query text. Please check that '
+                f'you have configured the {validator.name} validator '
+                'correctly and that any services it depends on are up. '
+                f'Exception: {e}')
+            return json_error_response(f'{msg}')
+
+    @has_access_api
     @expose('/sql_json/', methods=['POST', 'GET'])
     @log_this
     def sql_json(self):
diff --git a/tests/base_tests.py b/tests/base_tests.py
index 8555915..6de082a 100644
--- a/tests/base_tests.py
+++ b/tests/base_tests.py
@@ -190,6 +190,21 @@ class SupersetTestCase(unittest.TestCase):
             raise Exception('run_sql failed')
         return resp
 
+    def validate_sql(self, sql, client_id=None, user_name=None,
+                     raise_on_error=False):
+        if user_name:
+            self.logout()
+            self.login(username=(user_name if user_name else 'admin'))
+        dbid = get_main_database(db.session).id
+        resp = self.get_json_resp(
+            '/superset/validate_sql_json/',
+            raise_on_error=False,
+            data=dict(database_id=dbid, sql=sql, client_id=client_id),
+        )
+        if raise_on_error and 'error' in resp:
+            raise Exception('validate_sql failed')
+        return resp
+
     @patch.dict('superset._feature_flags', {'FOO': True}, clear=True)
     def test_existing_feature_flags(self):
         self.assertTrue(is_feature_enabled('FOO'))
diff --git a/tests/sql_validator_tests.py b/tests/sql_validator_tests.py
new file mode 100644
index 0000000..0e1310c
--- /dev/null
+++ b/tests/sql_validator_tests.py
@@ -0,0 +1,210 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Unit tests for Sql Lab"""
+import unittest
+from unittest.mock import (
+    MagicMock,
+    patch,
+)
+
+from pyhive.exc import DatabaseError
+
+from superset import app
+from superset.sql_validators import SQLValidationAnnotation
+from superset.sql_validators.base import BaseSQLValidator
+from superset.sql_validators.presto_db import (
+    PrestoDBSQLValidator,
+    PrestoSQLValidationError,
+)
+from .base_tests import SupersetTestCase
+
+PRESTO_TEST_FEATURE_FLAGS = {
+    'SQL_VALIDATORS_BY_ENGINE': {
+        'presto': 'PrestoDBSQLValidator',
+        'sqlite': 'PrestoDBSQLValidator',
+        'postgresql': 'PrestoDBSQLValidator',
+        'mysql': 'PrestoDBSQLValidator',
+    },
+}
+
+
+class SqlValidatorEndpointTests(SupersetTestCase):
+    """Testing for Sql Lab querytext validation endpoint"""
+
+    def tearDown(self):
+        self.logout()
+
+    def test_validate_sql_endpoint_noconfig(self):
+        """Assert that validate_sql_json errors out when no validators are
+        configured for any db"""
+        self.login('admin')
+
+        app.config['SQL_VALIDATORS_BY_ENGINE'] = {}
+
+        resp = self.validate_sql(
+            'SELECT * FROM ab_user',
+            client_id='1',
+            raise_on_error=False,
+        )
+        self.assertIn('error', resp)
+        self.assertIn('no SQL validator is configured', resp['error'])
+
+    @patch('superset.views.core.get_validator_by_name')
+    @patch.dict('superset._feature_flags',
+                PRESTO_TEST_FEATURE_FLAGS,
+                clear=True)
+    def test_validate_sql_endpoint_mocked(self, get_validator_by_name):
+        """Assert that, with a mocked validator, annotations make it back out
+        from the validate_sql_json endpoint as a list of json dictionaries"""
+        self.login('admin')
+
+        validator = MagicMock()
+        get_validator_by_name.return_value = validator
+        validator.validate.return_value = [
+            SQLValidationAnnotation(
+                message="I don't know what I expected, but it wasn't this",
+                line_number=4,
+                start_column=12,
+                end_column=42,
+            ),
+        ]
+
+        resp = self.validate_sql(
+            'SELECT * FROM somewhere_over_the_rainbow',
+            client_id='1',
+            raise_on_error=False,
+        )
+
+        self.assertEqual(1, len(resp))
+        self.assertIn('expected,', resp[0]['message'])
+
+    @patch('superset.views.core.get_validator_by_name')
+    @patch.dict('superset._feature_flags',
+                PRESTO_TEST_FEATURE_FLAGS,
+                clear=True)
+    def test_validate_sql_endpoint_failure(self, get_validator_by_name):
+        """Assert that validate_sql_json errors out when the selected validator
+        raises an unexpected exception"""
+        self.login('admin')
+
+        validator = MagicMock()
+        get_validator_by_name.return_value = validator
+        validator.validate.side_effect = Exception('Kaboom!')
+
+        resp = self.validate_sql(
+            'SELECT * FROM ab_user',
+            client_id='1',
+            raise_on_error=False,
+        )
+        self.assertIn('error', resp)
+        self.assertIn('Kaboom!', resp['error'])
+
+
+class BaseValidatorTests(SupersetTestCase):
+    """Testing for the base sql validator"""
+    def setUp(self):
+        self.validator = BaseSQLValidator
+
+    def test_validator_excepts(self):
+        with self.assertRaises(NotImplementedError):
+            self.validator.validate(None, None, None)
+
+
+class PrestoValidatorTests(SupersetTestCase):
+    """Testing for the prestodb sql validator"""
+    def setUp(self):
+        self.validator = PrestoDBSQLValidator
+        self.database = MagicMock()  # noqa
+        self.database_engine = self.database.get_sqla_engine.return_value
+        self.database_conn = self.database_engine.raw_connection.return_value
+        self.database_cursor = self.database_conn.cursor.return_value
+        self.database_cursor.poll.return_value = None
+
+    def tearDown(self):
+        self.logout()
+
+    PRESTO_ERROR_TEMPLATE = {
+        'errorLocation': {
+            'lineNumber': 10,
+            'columnNumber': 20,
+        },
+        'message': "your query isn't how I like it",
+    }
+
+    @patch('superset.sql_validators.presto_db.g')
+    def test_validator_success(self, flask_g):
+        flask_g.user.username = 'nobody'
+        sql = 'SELECT 1 FROM default.notarealtable'
+        schema = 'default'
+
+        errors = self.validator.validate(sql, schema, self.database)
+
+        self.assertEqual([], errors)
+
+    @patch('superset.sql_validators.presto_db.g')
+    def test_validator_db_error(self, flask_g):
+        flask_g.user.username = 'nobody'
+        sql = 'SELECT 1 FROM default.notarealtable'
+        schema = 'default'
+
+        fetch_fn = self.database.db_engine_spec.fetch_data
+        fetch_fn.side_effect = DatabaseError('dummy db error')
+
+        with self.assertRaises(PrestoSQLValidationError):
+            self.validator.validate(sql, schema, self.database)
+
+    @patch('superset.sql_validators.presto_db.g')
+    def test_validator_unexpected_error(self, flask_g):
+        flask_g.user.username = 'nobody'
+        sql = 'SELECT 1 FROM default.notarealtable'
+        schema = 'default'
+
+        fetch_fn = self.database.db_engine_spec.fetch_data
+        fetch_fn.side_effect = Exception('a mysterious failure')
+
+        with self.assertRaises(Exception):
+            self.validator.validate(sql, schema, self.database)
+
+    @patch('superset.sql_validators.presto_db.g')
+    def test_validator_query_error(self, flask_g):
+        flask_g.user.username = 'nobody'
+        sql = 'SELECT 1 FROM default.notarealtable'
+        schema = 'default'
+
+        fetch_fn = self.database.db_engine_spec.fetch_data
+        fetch_fn.side_effect = DatabaseError(self.PRESTO_ERROR_TEMPLATE)
+
+        errors = self.validator.validate(sql, schema, self.database)
+
+        self.assertEqual(1, len(errors))
+
+    def test_validate_sql_endpoint(self):
+        self.login('admin')
+        # NB this is effectively an integration test -- when there's a default
+        #    validator for sqlite, this test will fail because the validator
+        #    will no longer error out.
+        resp = self.validate_sql(
+            'SELECT * FROM ab_user',
+            client_id='1',
+            raise_on_error=False,
+        )
+        self.assertIn('error', resp)
+        self.assertIn('no SQL validator is configured', resp['error'])
+
+
+if __name__ == '__main__':
+    unittest.main()
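
Note on the error mapping: `PrestoDBSQLValidator.validate_statement` turns the dictionary that pyhive passes to `DatabaseError` into an annotation, and raises `PrestoSQLValidationError` when that dictionary is missing or incomplete. The sketch below isolates that logic so it can be read without Presto or Flask. The function name `annotation_from_error_args` and the plain-dict return value are assumptions made for illustration; the commit does this inline and returns a `SQLValidationAnnotation`.

```python
from typing import Any, Dict, Optional, Tuple


class PrestoSQLValidationError(Exception):
    """Stand-in for the exception defined in superset/sql_validators/presto_db.py."""


def annotation_from_error_args(args: Tuple[Any, ...]) -> Dict[str, Optional[Any]]:
    """Hypothetical helper mirroring the checks in validate_statement."""
    # The first constructor argument must be a dict of error details.
    if not args or not isinstance(args[0], dict):
        raise PrestoSQLValidationError(
            'The pyhive presto client returned an unhandled database error.')
    error_args = args[0]

    # Both a message and a location are needed to build an annotation.
    if 'message' not in error_args:
        raise PrestoSQLValidationError(
            'The pyhive presto client did not report an error message')
    if 'errorLocation' not in error_args:
        raise PrestoSQLValidationError(
            'The pyhive presto client did not report an error location')

    err_loc = error_args['errorLocation']
    column = err_loc.get('columnNumber')
    return {
        'message': error_args['message'],
        'line_number': err_loc.get('lineNumber'),
        # Only one column is read from the error, so start and end coincide.
        'start_column': column,
        'end_column': column,
    }


# Same shape as PRESTO_ERROR_TEMPLATE in the test module.
result = annotation_from_error_args(({
    'errorLocation': {'lineNumber': 10, 'columnNumber': 20},
    'message': "your query isn't how I like it",
},))
print(result)
```

A bare string argument, as in `test_validator_db_error`, fails the first check and raises `PrestoSQLValidationError`.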

