superset-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From ville...@apache.org
Subject [incubator-superset] branch master updated: Add support for period character in table names (#7453)
Date Sun, 26 May 2019 03:13:33 GMT
This is an automated email from the ASF dual-hosted git repository.

villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new f7d3413  Add support for period character in table names (#7453)
f7d3413 is described below

commit f7d3413a501d8b643318fe7c0641eba608a079f5
Author: Ville Brofeldt <33317356+villebro@users.noreply.github.com>
AuthorDate: Sun May 26 06:13:16 2019 +0300

    Add support for period character in table names (#7453)
    
    * Move schema name handling in table names from frontend to backend
    
    * Rename all_schema_names to get_all_schema_names
    
    * Fix js errors
    
    * Fix additional js linting errors
    
    * Refactor datasource getters and fix linting errors
    
    * Update js unit tests
    
    * Add python unit test for get_table_names method
    
    * Add python unit test for get_table_names method
    
    * Fix js linting error
---
 .../javascripts/components/TableSelector_spec.jsx  | 11 +--
 .../assets/spec/javascripts/sqllab/fixtures.js     |  6 +-
 .../src/SqlLab/components/SqlEditorLeftBar.jsx     | 15 ++--
 superset/assets/src/components/TableSelector.jsx   |  9 +--
 superset/cli.py                                    |  4 +-
 superset/db_engine_specs.py                        | 84 +++++++++++-----------
 superset/models/core.py                            | 57 ++++++---------
 superset/security.py                               |  6 +-
 superset/utils/core.py                             |  7 +-
 superset/views/core.py                             | 67 +++++++++--------
 tests/db_engine_specs_test.py                      | 19 +++++
 11 files changed, 148 insertions(+), 137 deletions(-)

diff --git a/superset/assets/spec/javascripts/components/TableSelector_spec.jsx b/superset/assets/spec/javascripts/components/TableSelector_spec.jsx
index 70e2cca..1366592 100644
--- a/superset/assets/spec/javascripts/components/TableSelector_spec.jsx
+++ b/superset/assets/spec/javascripts/components/TableSelector_spec.jsx
@@ -208,19 +208,20 @@ describe('TableSelector', () => {
 
     it('test 1', () => {
       wrapper.instance().changeTable({
-        value: 'birth_names',
+        value: { schema: 'main', table: 'birth_names' },
         label: 'birth_names',
       });
       expect(wrapper.state().tableName).toBe('birth_names');
     });
 
-    it('test 2', () => {
+    it('should call onTableChange with schema from table object', () => {
+      wrapper.setProps({ schema: null });
       wrapper.instance().changeTable({
-        value: 'main.my_table',
-        label: 'my_table',
+        value: { schema: 'other_schema', table: 'my_table' },
+        label: 'other_schema.my_table',
       });
       expect(mockedProps.onTableChange.getCall(0).args[0]).toBe('my_table');
-      expect(mockedProps.onTableChange.getCall(0).args[1]).toBe('main');
+      expect(mockedProps.onTableChange.getCall(0).args[1]).toBe('other_schema');
     });
   });
 
diff --git a/superset/assets/spec/javascripts/sqllab/fixtures.js b/superset/assets/spec/javascripts/sqllab/fixtures.js
index 6471be1..f43f43f 100644
--- a/superset/assets/spec/javascripts/sqllab/fixtures.js
+++ b/superset/assets/spec/javascripts/sqllab/fixtures.js
@@ -329,15 +329,15 @@ export const databases = {
 export const tables = {
   options: [
     {
-      value: 'birth_names',
+      value: { schema: 'main', table: 'birth_names' },
       label: 'birth_names',
     },
     {
-      value: 'energy_usage',
+      value: { schema: 'main', table: 'energy_usage' },
       label: 'energy_usage',
     },
     {
-      value: 'wb_health_population',
+      value: { schema: 'main', table: 'wb_health_population' },
       label: 'wb_health_population',
     },
   ],
diff --git a/superset/assets/src/SqlLab/components/SqlEditorLeftBar.jsx b/superset/assets/src/SqlLab/components/SqlEditorLeftBar.jsx
index 9d0796c..43ea487 100644
--- a/superset/assets/src/SqlLab/components/SqlEditorLeftBar.jsx
+++ b/superset/assets/src/SqlLab/components/SqlEditorLeftBar.jsx
@@ -83,17 +83,10 @@ export default class SqlEditorLeftBar extends React.PureComponent {
       this.setState({ tableName: '' });
       return;
     }
-    const namePieces = tableOpt.value.split('.');
-    let tableName = namePieces[0];
-    let schemaName = this.props.queryEditor.schema;
-    if (namePieces.length === 1) {
-      this.setState({ tableName });
-    } else {
-      schemaName = namePieces[0];
-      tableName = namePieces[1];
-      this.setState({ tableName });
-      this.props.actions.queryEditorSetSchema(this.props.queryEditor, schemaName);
-    }
+    const schemaName = tableOpt.value.schema;
+    const tableName = tableOpt.value.table;
+    this.setState({ tableName });
+    this.props.actions.queryEditorSetSchema(this.props.queryEditor, schemaName);
     this.props.actions.addTable(this.props.queryEditor, tableName, schemaName);
   }
 
diff --git a/superset/assets/src/components/TableSelector.jsx b/superset/assets/src/components/TableSelector.jsx
index ba2cebb..940e1c2 100644
--- a/superset/assets/src/components/TableSelector.jsx
+++ b/superset/assets/src/components/TableSelector.jsx
@@ -170,13 +170,8 @@ export default class TableSelector extends React.PureComponent {
       this.setState({ tableName: '' });
       return;
     }
-    const namePieces = tableOpt.value.split('.');
-    let tableName = namePieces[0];
-    let schemaName = this.props.schema;
-    if (namePieces.length > 1) {
-      schemaName = namePieces[0];
-      tableName = namePieces[1];
-    }
+    const schemaName = tableOpt.value.schema;
+    const tableName = tableOpt.value.table;
     if (this.props.tableNameSticky) {
       this.setState({ tableName }, this.onChange);
     }
diff --git a/superset/cli.py b/superset/cli.py
index 7b441b4..edb0102 100755
--- a/superset/cli.py
+++ b/superset/cli.py
@@ -288,9 +288,9 @@ def update_datasources_cache():
         if database.allow_multi_schema_metadata_fetch:
             print('Fetching {} datasources ...'.format(database.name))
             try:
-                database.all_table_names_in_database(
+                database.get_all_table_names_in_database(
                     force=True, cache=True, cache_timeout=24 * 60 * 60)
-                database.all_view_names_in_database(
+                database.get_all_view_names_in_database(
                     force=True, cache=True, cache_timeout=24 * 60 * 60)
             except Exception as e:
                 print('{}'.format(str(e)))
diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py
index 35a591f..67aba12 100644
--- a/superset/db_engine_specs.py
+++ b/superset/db_engine_specs.py
@@ -122,6 +122,7 @@ class BaseEngineSpec(object):
     force_column_alias_quotes = False
     arraysize = 0
     max_column_name_length = 0
+    try_remove_schema_from_table_name = True
 
     @classmethod
     def get_time_expr(cls, expr, pdf, time_grain, grain):
@@ -279,33 +280,32 @@ class BaseEngineSpec(object):
         return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))
 
     @classmethod
-    def fetch_result_sets(cls, db, datasource_type):
-        """Returns a list of tables [schema1.table1, schema2.table2, ...]
+    def get_all_datasource_names(cls, db, datasource_type: str) \
+            -> List[utils.DatasourceName]:
+        """Returns a list of all tables or views in database.
 
-        Datasource_type can be 'table' or 'view'.
-        Empty schema corresponds to the list of full names of the all
-        tables or views: <schema>.<result_set_name>.
+        :param db: Database instance
+        :param datasource_type: Datasource_type can be 'table' or 'view'
+        :return: List of all datasources in database or schema
         """
-        schemas = db.all_schema_names(cache=db.schema_cache_enabled,
-                                      cache_timeout=db.schema_cache_timeout,
-                                      force=True)
-        all_result_sets = []
+        schemas = db.get_all_schema_names(cache=db.schema_cache_enabled,
+                                          cache_timeout=db.schema_cache_timeout,
+                                          force=True)
+        all_datasources: List[utils.DatasourceName] = []
         for schema in schemas:
             if datasource_type == 'table':
-                all_datasource_names = db.all_table_names_in_schema(
+                all_datasources += db.get_all_table_names_in_schema(
                     schema=schema, force=True,
                     cache=db.table_cache_enabled,
                     cache_timeout=db.table_cache_timeout)
             elif datasource_type == 'view':
-                all_datasource_names = db.all_view_names_in_schema(
+                all_datasources += db.get_all_view_names_in_schema(
                     schema=schema, force=True,
                     cache=db.table_cache_enabled,
                     cache_timeout=db.table_cache_timeout)
             else:
                 raise Exception(f'Unsupported datasource_type: {datasource_type}')
-            all_result_sets += [
-                '{}.{}'.format(schema, t) for t in all_datasource_names]
-        return all_result_sets
+        return all_datasources
 
     @classmethod
     def handle_cursor(cls, cursor, query, session):
@@ -352,11 +352,17 @@ class BaseEngineSpec(object):
 
     @classmethod
     def get_table_names(cls, inspector, schema):
-        return sorted(inspector.get_table_names(schema))
+        tables = inspector.get_table_names(schema)
+        if schema and cls.try_remove_schema_from_table_name:
+            tables = [re.sub(f'^{schema}\\.', '', table) for table in tables]
+        return sorted(tables)
 
     @classmethod
     def get_view_names(cls, inspector, schema):
-        return sorted(inspector.get_view_names(schema))
+        views = inspector.get_view_names(schema)
+        if schema and cls.try_remove_schema_from_table_name:
+            views = [re.sub(f'^{schema}\\.', '', view) for view in views]
+        return sorted(views)
 
     @classmethod
     def get_columns(cls, inspector: Inspector, table_name: str, schema: str) -> list:
@@ -528,6 +534,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
 class PostgresEngineSpec(PostgresBaseEngineSpec):
     engine = 'postgresql'
     max_column_name_length = 63
+    try_remove_schema_from_table_name = False
 
     @classmethod
     def get_table_names(cls, inspector, schema):
@@ -685,29 +692,25 @@ class SqliteEngineSpec(BaseEngineSpec):
         return "datetime({col}, 'unixepoch')"
 
     @classmethod
-    def fetch_result_sets(cls, db, datasource_type):
-        schemas = db.all_schema_names(cache=db.schema_cache_enabled,
-                                      cache_timeout=db.schema_cache_timeout,
-                                      force=True)
-        all_result_sets = []
+    def get_all_datasource_names(cls, db, datasource_type: str) \
+            -> List[utils.DatasourceName]:
+        schemas = db.get_all_schema_names(cache=db.schema_cache_enabled,
+                                          cache_timeout=db.schema_cache_timeout,
+                                          force=True)
         schema = schemas[0]
         if datasource_type == 'table':
-            all_datasource_names = db.all_table_names_in_schema(
+            return db.get_all_table_names_in_schema(
                 schema=schema, force=True,
                 cache=db.table_cache_enabled,
                 cache_timeout=db.table_cache_timeout)
         elif datasource_type == 'view':
-            all_datasource_names = db.all_view_names_in_schema(
+            return db.get_all_view_names_in_schema(
                 schema=schema, force=True,
                 cache=db.table_cache_enabled,
                 cache_timeout=db.table_cache_timeout)
         else:
             raise Exception(f'Unsupported datasource_type: {datasource_type}')
 
-        all_result_sets += [
-            '{}.{}'.format(schema, t) for t in all_datasource_names]
-        return all_result_sets
-
     @classmethod
     def convert_dttm(cls, target_type, dttm):
         iso = dttm.isoformat().replace('T', ' ')
@@ -1107,24 +1110,19 @@ class PrestoEngineSpec(BaseEngineSpec):
         return 'from_unixtime({col})'
 
     @classmethod
-    def fetch_result_sets(cls, db, datasource_type):
-        """Returns a list of tables [schema1.table1, schema2.table2, ...]
-
-        Datasource_type can be 'table' or 'view'.
-        Empty schema corresponds to the list of full names of the all
-        tables or views: <schema>.<result_set_name>.
-        """
-        result_set_df = db.get_df(
+    def get_all_datasource_names(cls, db, datasource_type: str) \
+            -> List[utils.DatasourceName]:
+        datasource_df = db.get_df(
             """SELECT table_schema, table_name FROM INFORMATION_SCHEMA.{}S
                ORDER BY concat(table_schema, '.', table_name)""".format(
                 datasource_type.upper(),
             ),
             None)
-        result_sets = []
-        for unused, row in result_set_df.iterrows():
-            result_sets.append('{}.{}'.format(
-                row['table_schema'], row['table_name']))
-        return result_sets
+        datasource_names: List[utils.DatasourceName] = []
+        for unused, row in datasource_df.iterrows():
+            datasource_names.append(utils.DatasourceName(
+                schema=row['table_schema'], table=row['table_name']))
+        return datasource_names
 
     @classmethod
     def extra_table_metadata(cls, database, table_name, schema_name):
@@ -1385,9 +1383,9 @@ class HiveEngineSpec(PrestoEngineSpec):
         hive.Cursor.fetch_logs = patched_hive.fetch_logs
 
     @classmethod
-    def fetch_result_sets(cls, db, datasource_type):
-        return BaseEngineSpec.fetch_result_sets(
-            db, datasource_type)
+    def get_all_datasource_names(cls, db, datasource_type: str) \
+            -> List[utils.DatasourceName]:
+        return BaseEngineSpec.get_all_datasource_names(db, datasource_type)
 
     @classmethod
     def fetch_data(cls, cursor, limit):
diff --git a/superset/models/core.py b/superset/models/core.py
index e16a234..047a3dd 100644
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -23,6 +23,7 @@ import functools
 import json
 import logging
 import textwrap
+from typing import List
 
 from flask import escape, g, Markup, request
 from flask_appbuilder import Model
@@ -65,6 +66,7 @@ metadata = Model.metadata  # pylint: disable=no-member
 
 PASSWORD_MASK = 'X' * 10
 
+
 def set_related_perm(mapper, connection, target):  # noqa
     src_class = target.cls_model
     id_ = target.datasource_id
@@ -184,7 +186,7 @@ class Slice(Model, AuditMixinNullable, ImportMixin):
             description=self.description,
             cache_timeout=self.cache_timeout)
 
-    @datasource.getter
+    @datasource.getter  # type: ignore
     @utils.memoized
     def get_datasource(self):
         return (
@@ -210,7 +212,7 @@ class Slice(Model, AuditMixinNullable, ImportMixin):
         datasource = self.datasource
         return datasource.url if datasource else None
 
-    @property
+    @property  # type: ignore
     @utils.memoized
     def viz(self):
         d = json.loads(self.params)
@@ -930,100 +932,87 @@ class Database(Model, AuditMixinNullable, ImportMixin):
     @cache_util.memoized_func(
         key=lambda *args, **kwargs: 'db:{}:schema:None:table_list',
         attribute_in_key='id')
-    def all_table_names_in_database(self, cache=False,
-                                    cache_timeout=None, force=False):
+    def get_all_table_names_in_database(self, cache: bool = False,
+                                        cache_timeout: bool = None,
+                                        force=False) -> List[utils.DatasourceName]:
         """Parameters need to be passed as keyword arguments."""
         if not self.allow_multi_schema_metadata_fetch:
             return []
-        return self.db_engine_spec.fetch_result_sets(self, 'table')
+        return self.db_engine_spec.get_all_datasource_names(self, 'table')
 
     @cache_util.memoized_func(
         key=lambda *args, **kwargs: 'db:{}:schema:None:view_list',
         attribute_in_key='id')
-    def all_view_names_in_database(self, cache=False,
-                                   cache_timeout=None, force=False):
+    def get_all_view_names_in_database(self, cache: bool = False,
+                                       cache_timeout: bool = None,
+                                       force: bool = False) -> List[utils.DatasourceName]:
         """Parameters need to be passed as keyword arguments."""
         if not self.allow_multi_schema_metadata_fetch:
             return []
-        return self.db_engine_spec.fetch_result_sets(self, 'view')
+        return self.db_engine_spec.get_all_datasource_names(self, 'view')
 
     @cache_util.memoized_func(
         key=lambda *args, **kwargs: 'db:{{}}:schema:{}:table_list'.format(
             kwargs.get('schema')),
         attribute_in_key='id')
-    def all_table_names_in_schema(self, schema, cache=False,
-                                  cache_timeout=None, force=False):
+    def get_all_table_names_in_schema(self, schema: str, cache: bool = False,
+                                      cache_timeout: int = None, force: bool = False):
         """Parameters need to be passed as keyword arguments.
 
         For unused parameters, they are referenced in
         cache_util.memoized_func decorator.
 
         :param schema: schema name
-        :type schema: str
         :param cache: whether cache is enabled for the function
-        :type cache: bool
         :param cache_timeout: timeout in seconds for the cache
-        :type cache_timeout: int
         :param force: whether to force refresh the cache
-        :type force: bool
-        :return: table list
-        :rtype: list
+        :return: list of tables
         """
-        tables = []
         try:
             tables = self.db_engine_spec.get_table_names(
                 inspector=self.inspector, schema=schema)
+            return [utils.DatasourceName(table=table, schema=schema) for table in tables]
         except Exception as e:
             logging.exception(e)
-        return tables
 
     @cache_util.memoized_func(
         key=lambda *args, **kwargs: 'db:{{}}:schema:{}:view_list'.format(
             kwargs.get('schema')),
         attribute_in_key='id')
-    def all_view_names_in_schema(self, schema, cache=False,
-                                 cache_timeout=None, force=False):
+    def get_all_view_names_in_schema(self, schema: str, cache: bool = False,
+                                     cache_timeout: int = None, force: bool = False):
         """Parameters need to be passed as keyword arguments.
 
         For unused parameters, they are referenced in
         cache_util.memoized_func decorator.
 
         :param schema: schema name
-        :type schema: str
         :param cache: whether cache is enabled for the function
-        :type cache: bool
         :param cache_timeout: timeout in seconds for the cache
-        :type cache_timeout: int
         :param force: whether to force refresh the cache
-        :type force: bool
-        :return: view list
-        :rtype: list
+        :return: list of views
         """
-        views = []
         try:
             views = self.db_engine_spec.get_view_names(
                 inspector=self.inspector, schema=schema)
+            return [utils.DatasourceName(table=view, schema=schema) for view in views]
         except Exception as e:
             logging.exception(e)
-        return views
 
     @cache_util.memoized_func(
         key=lambda *args, **kwargs: 'db:{}:schema_list',
         attribute_in_key='id')
-    def all_schema_names(self, cache=False, cache_timeout=None, force=False):
+    def get_all_schema_names(self, cache: bool = False, cache_timeout: int = None,
+                             force: bool = False) -> List[str]:
         """Parameters need to be passed as keyword arguments.
 
         For unused parameters, they are referenced in
         cache_util.memoized_func decorator.
 
         :param cache: whether cache is enabled for the function
-        :type cache: bool
         :param cache_timeout: timeout in seconds for the cache
-        :type cache_timeout: int
         :param force: whether to force refresh the cache
-        :type force: bool
         :return: schema list
-        :rtype: list
         """
         return self.db_engine_spec.get_schema_names(self.inspector)
 
@@ -1232,7 +1221,7 @@ class DatasourceAccessRequest(Model, AuditMixinNullable):
     def datasource(self):
         return self.get_datasource
 
-    @datasource.getter
+    @datasource.getter  # type: ignore
     @utils.memoized
     def get_datasource(self):
         # pylint: disable=no-member
diff --git a/superset/security.py b/superset/security.py
index f8ae057..89eab5d 100644
--- a/superset/security.py
+++ b/superset/security.py
@@ -17,6 +17,7 @@
 # pylint: disable=C,R,W
 """A set of constants and methods to manage permissions and security"""
 import logging
+from typing import List
 
 from flask import g
 from flask_appbuilder.security.sqla import models as ab_models
@@ -26,6 +27,7 @@ from sqlalchemy import or_
 from superset import sql_parse
 from superset.connectors.connector_registry import ConnectorRegistry
 from superset.exceptions import SupersetSecurityException
+from superset.utils.core import DatasourceName
 
 
 class SupersetSecurityManager(SecurityManager):
@@ -240,7 +242,9 @@ class SupersetSecurityManager(SecurityManager):
                     subset.add(t.schema)
         return sorted(list(subset))
 
-    def accessible_by_user(self, database, datasource_names, schema=None):
+    def get_datasources_accessible_by_user(
+            self, database, datasource_names: List[DatasourceName],
+            schema: str = None) -> List[DatasourceName]:
         from superset import db
         if self.database_access(database) or self.all_datasource_access():
             return datasource_names
diff --git a/superset/utils/core.py b/superset/utils/core.py
index 3b41457..2defa70 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -32,7 +32,7 @@ import signal
 import smtplib
 import sys
 from time import struct_time
-from typing import List, Optional, Tuple
+from typing import List, NamedTuple, Optional, Tuple
 from urllib.parse import unquote_plus
 import uuid
 import zlib
@@ -1100,3 +1100,8 @@ def MediumText() -> Variant:
 
 def shortid() -> str:
     return '{}'.format(uuid.uuid4())[-12:]
+
+
+class DatasourceName(NamedTuple):
+    table: str
+    schema: str
diff --git a/superset/views/core.py b/superset/views/core.py
index 883a2d9..0a6ddef 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -22,7 +22,7 @@ import logging
 import os
 import re
 import traceback
-from typing import List  # noqa: F401
+from typing import Dict, List  # noqa: F401
 from urllib import parse
 
 from flask import (
@@ -311,7 +311,7 @@ class DatabaseView(SupersetModelView, DeleteMixin, YamlExportMixin):  # noqa
         db.set_sqlalchemy_uri(db.sqlalchemy_uri)
         security_manager.add_permission_view_menu('database_access', db.perm)
         # adding a new database we always want to force refresh schema list
-        for schema in db.all_schema_names():
+        for schema in db.get_all_schema_names():
             security_manager.add_permission_view_menu(
                 'schema_access', security_manager.get_schema_perm(db, schema))
 
@@ -1545,7 +1545,7 @@ class Superset(BaseSupersetView):
             .first()
         )
         if database:
-            schemas = database.all_schema_names(
+            schemas = database.get_all_schema_names(
                 cache=database.schema_cache_enabled,
                 cache_timeout=database.schema_cache_timeout,
                 force=force_refresh)
@@ -1570,50 +1570,57 @@ class Superset(BaseSupersetView):
         database = db.session.query(models.Database).filter_by(id=db_id).one()
 
         if schema:
-            table_names = database.all_table_names_in_schema(
+            tables = database.get_all_table_names_in_schema(
                 schema=schema, force=force_refresh,
                 cache=database.table_cache_enabled,
-                cache_timeout=database.table_cache_timeout)
-            view_names = database.all_view_names_in_schema(
+                cache_timeout=database.table_cache_timeout) or []
+            views = database.get_all_view_names_in_schema(
                 schema=schema, force=force_refresh,
                 cache=database.table_cache_enabled,
-                cache_timeout=database.table_cache_timeout)
+                cache_timeout=database.table_cache_timeout) or []
         else:
-            table_names = database.all_table_names_in_database(
+            tables = database.get_all_table_names_in_database(
                 cache=True, force=False, cache_timeout=24 * 60 * 60)
-            view_names = database.all_view_names_in_database(
+            views = database.get_all_view_names_in_database(
                 cache=True, force=False, cache_timeout=24 * 60 * 60)
-        table_names = security_manager.accessible_by_user(database, table_names, schema)
-        view_names = security_manager.accessible_by_user(database, view_names, schema)
+        tables = security_manager.get_datasources_accessible_by_user(
+            database, tables, schema)
+        views = security_manager.get_datasources_accessible_by_user(
+            database, views, schema)
+
+        def get_datasource_label(ds_name: utils.DatasourceName) -> str:
+            return ds_name.table if schema else f'{ds_name.schema}.{ds_name.table}'
 
         if substr:
-            table_names = [tn for tn in table_names if substr in tn]
-            view_names = [vn for vn in view_names if substr in vn]
+            tables = [tn for tn in tables if substr in get_datasource_label(tn)]
+            views = [vn for vn in views if substr in get_datasource_label(vn)]
 
         if not schema and database.default_schemas:
-            def get_schema(tbl_or_view_name):
-                return tbl_or_view_name.split('.')[0] if '.' in tbl_or_view_name else None
-
             user_schema = g.user.email.split('@')[0]
             valid_schemas = set(database.default_schemas + [user_schema])
 
-            table_names = [tn for tn in table_names if get_schema(tn) in valid_schemas]
-            view_names = [vn for vn in view_names if get_schema(vn) in valid_schemas]
+            tables = [tn for tn in tables if tn.schema in valid_schemas]
+            views = [vn for vn in views if vn.schema in valid_schemas]
 
-        max_items = config.get('MAX_TABLE_NAMES') or len(table_names)
-        total_items = len(table_names) + len(view_names)
-        max_tables = len(table_names)
-        max_views = len(view_names)
+        max_items = config.get('MAX_TABLE_NAMES') or len(tables)
+        total_items = len(tables) + len(views)
+        max_tables = len(tables)
+        max_views = len(views)
         if total_items and substr:
-            max_tables = max_items * len(table_names) // total_items
-            max_views = max_items * len(view_names) // total_items
-
-        table_options = [{'value': tn, 'label': tn}
-                         for tn in table_names[:max_tables]]
-        table_options.extend([{'value': vn, 'label': '[view] {}'.format(vn)}
-                              for vn in view_names[:max_views]])
+            max_tables = max_items * len(tables) // total_items
+            max_views = max_items * len(views) // total_items
+
+        def get_datasource_value(ds_name: utils.DatasourceName) -> Dict[str, str]:
+            return {'schema': ds_name.schema, 'table': ds_name.table}
+
+        table_options = [{'value': get_datasource_value(tn),
+                          'label': get_datasource_label(tn)}
+                         for tn in tables[:max_tables]]
+        table_options.extend([{'value': get_datasource_value(vn),
+                               'label': f'[view] {get_datasource_label(vn)}'}
+                              for vn in views[:max_views]])
         payload = {
-            'tableLength': len(table_names) + len(view_names),
+            'tableLength': len(tables) + len(views),
             'options': table_options,
         }
         return json_success(json.dumps(payload))
diff --git a/tests/db_engine_specs_test.py b/tests/db_engine_specs_test.py
index e0d914f..e190014 100644
--- a/tests/db_engine_specs_test.py
+++ b/tests/db_engine_specs_test.py
@@ -464,3 +464,22 @@ class DbEngineSpecsTestCase(SupersetTestCase):
         query = str(sel.compile(dialect=dialect, compile_kwargs={'literal_binds': True}))
         query_expected = "SELECT col, unicode_col \nFROM tbl \nWHERE col = 'abc' AND unicode_col = N'abc'"  # noqa
         self.assertEqual(query, query_expected)
+
+    def test_get_table_names(self):
+        inspector = mock.Mock()
+        inspector.get_table_names = mock.Mock(return_value=['schema.table', 'table_2'])
+        inspector.get_foreign_table_names = mock.Mock(return_value=['table_3'])
+
+        """ Make sure base engine spec removes schema name from table name
+        ie. when try_remove_schema_from_table_name == True. """
+        base_result_expected = ['table', 'table_2']
+        base_result = db_engine_specs.BaseEngineSpec.get_table_names(
+            schema='schema', inspector=inspector)
+        self.assertListEqual(base_result_expected, base_result)
+
+        """ Make sure postgres doesn't try to remove schema name from table name
+        ie. when try_remove_schema_from_table_name == False. """
+        pg_result_expected = ['schema.table', 'table_2', 'table_3']
+        pg_result = db_engine_specs.PostgresEngineSpec.get_table_names(
+            schema='schema', inspector=inspector)
+        self.assertListEqual(pg_result_expected, pg_result)


Mime
View raw message