madlib-dev mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jingyimei <...@git.apache.org>
Subject [GitHub] madlib pull request #289: RF: Add impurity variable importance
Date Tue, 10 Jul 2018 21:32:29 GMT
Github user jingyimei commented on a diff in the pull request:

    https://github.com/apache/madlib/pull/289#discussion_r201501719
  
    --- Diff: src/ports/postgres/modules/recursive_partitioning/random_forest.py_in ---
    @@ -1291,38 +1300,64 @@ def _create_group_table(
             schema_madlib, output_table_name, oob_error_table,
             importance_table, cat_features_info_table, grp_key_to_grp_cols,
             grouping_cols, tree_terminated):
    -    """ Ceate the group table for random forest"""
    +    """ Create the group table for random forest"""
    +
    +    cat_var_importance_str = ''
    +    con_var_importance_str = ''
    +    impurity_var_importance_str = ''
    +    left_join_importance_table_str = ''
    +    join_impurity_table_str = ''
    +
    +    if importance_table:
    +        impurity_var_importance_table_name = unique_string(desp='impurity')
    +        plpy.execute("""
    +            CREATE TEMP TABLE {impurity_var_importance_table_name} AS
    +            SELECT
    +                gid,
    +                {schema_madlib}.array_avg(impurity_var_importance, False) AS impurity_var_importance
    +            FROM {output_table_name}
    +            GROUP BY gid
    +            """.format(**locals()))
    +
    +        cat_var_importance_str = ", cat_var_importance AS oob_cat_var_importance,"
    +        con_var_importance_str = "con_var_importance AS oob_con_var_importance,"
    +        impurity_var_importance_str = "impurity_var_importance"
    +        left_join_importance_table_str = """LEFT OUTER JOIN {importance_table}
    +            USING (gid)""".format(importance_table=importance_table)
    +        join_impurity_table_str = """JOIN {impurity_var_importance_table_name} USING
(gid)""".format(impurity_var_importance_table_name=impurity_var_importance_table_name)
    +
         grouping_cols_str = ('' if grouping_cols is None
                              else grouping_cols + ",")
         group_table_name = add_postfix(output_table_name, "_group")
    +
         sql_create_group_table = """
                 CREATE TABLE {group_table_name} AS
                 SELECT
                     gid,
                     {grouping_cols_str}
    -                grp_finished as success,
    +                grp_finished AS success,
                     cat_n_levels,
                     cat_levels_in_text,
    -                oob_error,
    -                cat_var_importance,
    -                con_var_importance
    +                oob_error
    +                {cat_var_importance_str}
    +                {con_var_importance_str}
    +                {impurity_var_importance_str}
                 FROM
                     {oob_error_table}
                 JOIN
                     {grp_key_to_grp_cols}
                 USING (gid)
                 JOIN (
                     SELECT
    -                    unnest($1) as grp_key,
    -                    unnest($2) as grp_finished
    +                    unnest($1) AS grp_key,
    +                    unnest($2) AS grp_finished
                 ) tree_terminated
                 USING (grp_key)
                 JOIN
                     {cat_features_info_table}
                 USING (gid)
    -            LEFT OUTER JOIN
    -                {importance_table}
    -            USING (gid)
    +            {left_join_importance_table_str}
    --- End diff --
    
    We could name it `oob_importance_table_*` to explicitly distinguish it from the impurity importance table.


---

Mime
View raw message