madlib-dev mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From iyerr3 <...@git.apache.org>
Subject [GitHub] incubator-madlib pull request: SVM: Add cross validation support a...
Date Tue, 24 Nov 2015 22:25:29 GMT
Github user iyerr3 commented on a diff in the pull request:

    https://github.com/apache/incubator-madlib/pull/4#discussion_r45804084
  
    --- Diff: src/ports/postgres/modules/svm/svm.py_in ---
    @@ -66,211 +253,118 @@ def svm(schema_madlib, source_table, model_table,
         """
         Executes the linear support vector classification algorithm.
         """
    -    # verbosing
    +   # verbosing
         verbosity_level = "info" if verbose else "error"
         with MinWarning(verbosity_level):
    -        # validate input
    -        input_tbl_valid(source_table, 'SVM')
    -        _assert(is_var_valid(source_table, dependent_varname),
    -                "SVM error: invalid dependent_varname ('" + str(dependent_varname) +
    -                "') for source_table (" + source_table + ")!")
    -        _assert(is_var_valid(source_table, independent_varname),
    -                "SVM error: invalid independent_varname ('" + str(independent_varname) +
    -                "') for source_table (" + source_table + ")!")
    -
    -        dep_type = get_expr_type(dependent_varname, source_table)
    -        if '[]' in dep_type:
    -            plpy.error("SVM error: dependent_varname cannot be of array type!")
    -
    -        # validate output tables
    -        output_tbl_valid(model_table, 'SVM')
    -        summary_table = add_postfix(model_table, "_summary")
    -        output_tbl_valid(summary_table, 'SVM')
    -
    -        # arguments for iterating
    -        n_features = plpy.execute("SELECT array_upper({0}, 1) AS dim "
    -                                  "FROM {1} LIMIT 1".
    -                                  format(independent_varname, source_table)
    -                                  )[0]['dim']
    -        if grouping_col:
    -            grouping_list = [i + "::text"
    -                             for i in explicit_bool_to_text(
    -                                source_table,
    -                                _string_to_array_with_quotes(grouping_col),
    -                                schema_madlib)]
    -            grouping_str = ','.join(grouping_list)
    -        else:
    -            grouping_str = "Null"
    -        grouping_str1 = "" if not grouping_col else grouping_col + ","
    -        grouping_str2 = "1 = 1" if not grouping_col else grouping_col
    -
    -        args = {
    -                'rel_args': unique_string(desp='rel_args'),
    -                'rel_state': unique_string(desp='rel_state'),
    -                'col_grp_iteration': unique_string(desp='col_grp_iteration'),
    -                'col_grp_state': unique_string(desp='col_grp_state'),
    -                'col_grp_key': unique_string(desp='col_grp_key'),
    -                'col_n_tuples': unique_string(desp='col_n_tuples'),
    -                'state_type': "double precision[]",
    -                'rel_source': source_table,
    -                'col_ind_var': independent_varname,
    -                'col_dep_var': dependent_varname}
    -        args.update(locals())
    -        # variables defined above cannot be moved below this line
    -        # -------------------------------------------------------
    -
    -        # other params
    -        kernel_func = 'linear' if not kernel_func else kernel_func.lower()
    -        # Add non-linear kernels below after implementing them.
    -        supported_kernels = ['linear']
    -        try:
    -            # allow user to specify a prefix substring of
    -            # supported kernel function names. This works because the supported
    -            # kernel functions have unique prefixes.
    -            kernel_func = next(x for x in supported_kernels if x.startswith(kernel_func))
    -        except StopIteration:
    -            # next() returns a StopIteration if no element found
    -            plpy.error("SVM Error: Invalid kernel function: {0}. "
    -                       "Supported kernel functions are ({1})"
    -                       .format(kernel_func, ','.join(sorted(supported_kernels))))
    -
    -        if grouping_col:
    -            cols_in_tbl_valid(source_table, _string_to_array_with_quotes(grouping_col), 'SVM')
    -            intersect = frozenset(_string_to_array(grouping_col)).intersection(
    -                                    frozenset(
    -                                        ('coef', '__random_feature_data',
    -                                         '__random_feature_data', 'loss'
    -                                         'num_rows_processed', 'num_rows_skipped',
    -                                         'norm_of_gradient', 'num_iterations')))
    -            if len(intersect) > 0:
    -                plpy.error("SVM error: Conflicting grouping column name.\n"
    -                           "Some predefined keyword(s) ({0}) are not allowed!".format(
    -                                ', '.join(intersect)))
    -
    -        args.update(_extract_params(schema_madlib, params))
    -        args.update(_process_epsilon(is_svc, args))
    -
    -        if not is_svc:
    -            # transform col_dep_var to binary (1 or -1) if classification
    -            args.update({
    -                    'col_dep_var_trans': dependent_varname,
    -                    'mapping': 'NULL',
    -                    'method': 'SVR'})
    -        else:
    -            # dependent variable mapping
    -            dep_labels=plpy.execute("""
    -                SELECT {dependent_varname} AS y
    -                FROM {source_table}
    -                WHERE ({dependent_varname}) IS NOT NULL
    -                GROUP BY ({dependent_varname})
    -                ORDER BY ({dependent_varname})""".format(**locals()))
    -            dep_var_mapping = ["'" + d['y'] + "'" if isinstance(d['y'], basestring)
    -                               else str(d['y']) for d in dep_labels]
    -            if len(dep_var_mapping) != 2:
    -                plpy.error("SVM error: Classification currently only supports binary output")
    -
    -            col_dep_var_trans = (
    -                """
    -                CASE WHEN ({col_dep_var}) IS NULL THEN NULL
    -                    WHEN ({col_dep_var}) = {mapped_value_for_negative} THEN -1.0
    -                    ELSE 1.0
    -                END
    -                """
    -                .format(col_dep_var=dependent_varname,
    -                        mapped_value_for_negative=dep_var_mapping[0])
    -                )
    -
    -            args.update({
    -                 'mapped_value_for_negative': dep_var_mapping[0],
    -                 'col_dep_var_trans': col_dep_var_trans,
    -                 'mapping': dep_var_mapping[0] + "," + dep_var_mapping[1],
    -                 'method': 'SVC'})
    -
    -        args['stepsize'] = args['init_stepsize']
    -        args['is_l2'] = True if args['norm'] == 'l2' else False
    -
    -        # place holder for compatibility
    -        plpy.execute("CREATE TABLE pg_temp.{0} AS SELECT 1".format(args['rel_args']))
    -        # actual iterative algorithm computation
    -        n_iters_run = _compute_svm(args)
    -
    -        # organizing results
    -        groupby_str = "GROUP BY {grouping_col}, {col_grp_key}".format(**args) if grouping_col else ""
    -        using_str = "USING ({col_grp_key})".format(**args) if grouping_col else "ON TRUE"
    -        model_table_query = """
    -            CREATE TABLE {model_table} AS
    -                SELECT
    -                    {grouping_str1}
    -                    (result).coefficients           AS coef,
    -                    (result).loss                   AS loss,
    -                    (result).norm_of_gradient       AS norm_of_gradient,
    -                    {n_iters_run}                   AS num_iterations,
    -                    (result).num_rows_processed     AS num_rows_processed,
    -                    n_tuples_including_nulls - (result).num_rows_processed
    -                                                    AS num_rows_skipped,
    -                    NULL                            AS __random_feature_data,
    -                    ARRAY[{mapping}]::{dep_type}[]  AS dep_var_mapping
    -                FROM
    -                (
    -                    SELECT
    -                        {schema_madlib}.internal_linear_svm_igd_result(
    -                            {col_grp_state}
    -                        ) AS result,
    -                        {col_grp_key}
    -                    FROM {rel_state}
    -                    WHERE {col_grp_iteration} = {n_iters_run}
    -                ) rel_state_subq
    -                JOIN
    -                (
    -                    SELECT
    -                        {grouping_str1}
    -                        count(*) AS n_tuples_including_nulls,
    -                        array_to_string(ARRAY[{grouping_str}],
    -                                        ','
    -                                       ) AS {col_grp_key}
    -                    FROM {source_table}
    -                    {groupby_str}
    -                ) n_tuples_including_nulls_subq
    -                {using_str}
    -            """.format(n_iters_run=n_iters_run,
    -                       groupby_str=groupby_str,
    -                       using_str=using_str, **args)
    -        plpy.execute(model_table_query)
    -
    -        if isinstance(args['lambda'], list):
    -            args['lambda_str'] = '{' + ','.join(str(e) for e in args['lambda']) + '}'
    -        else:
    -            args['lambda_str'] = str(args['lambda'])
    -
    -        plpy.execute("""
    -                CREATE TABLE {summary_table} AS
    -                SELECT
    -                    '{method}'::text                    AS method,
    -                    '__MADLIB_VERSION__'::text          AS version_number,
    -                    '{source_table}'::text              AS source_table,
    -                    '{model_table}'::text               AS model_table,
    -                    '{dependent_varname}'::text         AS dependent_varname,
    -                    '{independent_varname}'::text       AS independent_varname,
    -                    'linear'::text                      AS kernel_func,
    -                    NULL::text                          AS kernel_params,
    -                    '{grouping_text}'::text             AS grouping_col,
    -                    'init_stepsize={init_stepsize}, '   ||
    -                        'decay_factor={decay_factor}, ' ||
    -                        'max_iter={max_iter}, '         ||
    -                        'tolerance={tolerance}'::text   AS optim_params,
    -                    'lambda={lambda_str}, ' ||
    -                        'norm={norm}, '     ||
    -                        'n_folds={n_folds}'::text       AS reg_params,
    -                    count(*)::integer                   AS num_all_groups,
    -                    0::integer                          AS num_failed_groups,
    -                    sum(num_rows_processed)::bigint     AS total_rows_processed,
    -                    sum(num_rows_skipped)::bigint       AS total_rows_skipped,
    -                    '{epsilon}'::double precision       AS epsilon,
    -                    '{eps_table}'::text                 AS eps_table
    -                FROM {model_table};
    -                """.format(grouping_text="NULL" if not grouping_col else grouping_col,
    -                           **args))
    -# ------------------------------------------------------------------------------
    +        _verify_table(source_table,
    +                      model_table,
    +                      dependent_varname,
    +                      independent_varname)
    +        args = locals()
    +        args['params_dict'] = _extract_params(schema_madlib, params)
    +        _cross_validate_svm(args)
    +        _svm_parsed_params(**args)
    +
    +
    +def _cross_validate_svm(args):
    +    # updating params_dict will also update
    +    # also update args['params_dict']
    +    params_dict = args['params_dict']
    +
    +    if params_dict['n_folds'] > 1 and args['grouping_col']:
    +        plpy.error('SVM error: cross validation '
    +                   'with grouping is not supported!')
    +
    +    # currently only support cross validation
    +    # on lambda and epsilon
    +    cv_params = {}
    +    if len(params_dict['lambda']) > 1:
    +        cv_params['lambda'] = params_dict['lambda']
    +    else:
    +        params_dict['lambda'] = params_dict['lambda'][0]
    +    if len(params_dict['epsilon']) > 1 and not args['is_svc']:
    +        cv_params['epsilon'] = params_dict['epsilon']
    +    else:
    +        params_dict['epsilon'] = params_dict['epsilon'][0]
    +    if len(params_dict['init_stepsize']) > 1:
    +        cv_params['init_stepsize'] = params_dict['init_stepsize']
    +    else:
    +        params_dict['init_stepsize'] = params_dict['init_stepsize'][0]
    +    if len(params_dict['max_iter']) > 1:
    +        cv_params['max_iter'] = params_dict['max_iter']
    +    else:
    +        params_dict['max_iter'] = params_dict['max_iter'][0]
    +    if len(params_dict['decay_factor']) > 1:
    +        cv_params['decay_factor'] = params_dict['decay_factor']
    +    else:
    +        params_dict['decay_factor'] = params_dict['decay_factor'][0]
    +
    +    if not cv_params and params_dict['n_folds'] <= 1:
    +        return
    +
    +    if cv_params and params_dict['n_folds'] <= 1:
    +        plpy.error("SVM Error: parameters must be a scalar "
    +                   "or of length 1 when n_folds is 0 or 1")
    +        return
    +
    +    if not cv_params and params_dict['n_folds'] > 1:
    +        plpy.warning('SVM Warning: no cross validate params provided! '
    +                     'Ignore {}-folds cross validation request.'
    +                     .format(params_dict['n_folds']))
    +        return
    +
    +    scorer = 'classification' if args['is_svc'] else 'regression'
    +    sub_args = {'params_dict':cv_params}
    +    cv = CrossValidator(_svm_parsed_params,svm_predict,scorer,args)
    +    val_res = cv.validate(sub_args, params_dict['n_folds']).sorted()
    +    val_res.output_tbl(params_dict['validation_result'])
    +    params_dict.update(val_res.first('sub_args')['params_dict'])
    +
    +
    +def _svm_parsed_params(schema_madlib, source_table, model_table,
    +                       dependent_varname, independent_varname, kernel_func,
    +                       kernel_params, grouping_col, params_dict, is_svc,
    +                       verbose, **kwargs):
    +    """
    +    Executes the linear support vector classification algorithm.
    +    """
    +    grouping_str = _verify_grouping(schema_madlib,
    +                                    source_table,
    +                                    grouping_col)
    +
    +    kernel_func = _verify_kernel(kernel_func)
    +
    +    # arguments for iterating
    +    n_features = num_features(source_table,
    +                              independent_varname)
    +
    +    args = {
    +            'rel_args': unique_string(desp='rel_args'),
    +            'rel_state': unique_string(desp='rel_state'),
    +            'col_grp_iteration': unique_string(desp='col_grp_iteration'),
    +            'col_grp_state': unique_string(desp='col_grp_state'),
    +            'col_grp_key': unique_string(desp='col_grp_key'),
    +            'col_n_tuples': unique_string(desp='col_n_tuples'),
    +            'state_type': "double precision[]",
    +            'n_features': n_features,
    +            'verbose': verbose,
    +            'schema_madlib': schema_madlib,
    +            'grouping_str': grouping_str,
    +            'grouping_col': grouping_col,
    +            'rel_source': source_table,
    +            'col_ind_var': independent_varname,
    +            'col_dep_var': dependent_varname}
    +
    +    args.update(_verify_params_dict(params_dict))
    +    args.update(_process_epsilon(is_svc, args))
    +    args.update(_svc_or_svr(is_svc, source_table, dependent_varname))
    +
    +    # place holder for compatibility
    +    plpy.execute("CREATE TABLE pg_temp.{0} AS SELECT 1".format(args['rel_args']))
    +    # actual iterative algorithm computation
    +    n_iters_run = _compute_svm(args)
    +    _summary(n_iters_run, model_table, args)
    +>>>>>>> b105d1c... SVM: Add cross validation support and generic CrossValidator class
    --- End diff --
    
    @mktal This line is invalid: it is a leftover merge-conflict marker (`>>>>>>> b105d1c...`) and should be removed. Also, run the .py files through a linter; there are a few places where we're not following the PEP 8 guidelines.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---

Mime
View raw message