madlib-dev mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From kaknikhil <...@git.apache.org>
Subject [GitHub] madlib pull request #230: Balanced sets final
Date Fri, 02 Feb 2018 20:32:26 GMT
GitHub user kaknikhil commented on a diff in the pull request:

    https://github.com/apache/madlib/pull/230#discussion_r165736819
  
    --- Diff: src/ports/postgres/modules/sample/balance_sample.py_in ---
    @@ -0,0 +1,748 @@
    +# coding=utf-8
    +#
    +# Licensed to the Apache Software Foundation (ASF) under one
    +# or more contributor license agreements.  See the NOTICE file
    +# distributed with this work for additional information
    +# regarding copyright ownership.  The ASF licenses this file
    +# to you under the Apache License, Version 2.0 (the
    +# "License"); you may not use this file EXCEPT in compliance
    +# with the License.  You may obtain a copy of the License at
    +#
    +#   http://www.apache.org/licenses/LICENSE-2.0
    +#
    +# Unless required by applicable law or agreed to in writing,
    +# software distributed under the License is distributed on an
    +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    +# KIND, either express or implied.  See the License for the
    +# specific language governing permissions and limitations
    +# under the License.
    +
    +m4_changequote(`<!', `!>')
    +
    +import math
    +
    +if __name__ != "__main__":
    +    import plpy
    +    from utilities.control import MinWarning
    +    from utilities.utilities import _assert
    +    from utilities.utilities import extract_keyvalue_params
    +    from utilities.utilities import unique_string
    +    from utilities.validate_args import columns_exist_in_table
    +    from utilities.validate_args import get_cols
    +    from utilities.validate_args import table_exists
    +    from utilities.validate_args import table_is_empty
    +else:
    +    # Used only for Unit Testing
    +    # FIXME: repeating a function from utilities that is needed by the unit test.
    +    # This should be removed once a unittest framework in used for testing.
    +    import random
    +    import time
    +
    +    def unique_string(desp='', **kwargs):
    +        """
    +        Generate random remporary names for temp table and other names.
    +        It has a SQL interface so both SQL and Python functions can call it.
    +        """
    +        r1 = random.randint(1, 100000000)
    +        r2 = int(time.time())
    +        r3 = int(time.time()) % random.randint(1, 100000000)
    +        u_string = "__madlib_temp_" + desp + str(r1) + "_" + str(r2) + "_" + str(r3)
+ "__"
    +        return u_string
    +# ------------------------------------------------------------------------------
    +
    +# Sampling strategy names. UNIFORM, UNDERSAMPLE and OVERSAMPLE are the
    +# strategies a user can request through class_sizes; NOSAMPLE is assigned
    +# internally to class levels whose rows are copied from the input as is.
    +UNIFORM = 'uniform'
    +UNDERSAMPLE = 'undersample'
    +OVERSAMPLE = 'oversample'
    +NOSAMPLE = 'nosample'
    +
    +# NOTE(review): not referenced in the visible part of this module; presumably
    +# the name of an id column added to the output table -- confirm.
    +NEW_ID_COLUMN = '__madlib_id__'
    +# Placeholder text that stands in for a NULL class level when class values
    +# are compared as text (coalesce(class_col::text, NULL_IDENTIFIER)).
    +NULL_IDENTIFIER = '__madlib_null_id__'
    +
    +def _get_frequency_distribution(source_table, class_col):
    +    """ Returns a dict containing the number of rows associated with each class
    +        level. Each class level value is converted to a string using ::text.
    +    """
    +    query_result = plpy.execute("""
    +                    SELECT {class_col}::text AS classes,
    +                           count(*) AS class_count
    +                    FROM {source_table}
    +                    GROUP BY {class_col}
    +                 """.format(**locals()))
    +    actual_level_counts = {}
    +    for each_row in query_result:
    +        level = each_row['classes']
    +        if level:
    +            level = level.strip()
    +        actual_level_counts[level] = each_row['class_count']
    +    return actual_level_counts
    +
    +
    +def _validate_and_get_sampling_strategy(sampling_strategy_str, output_table_size,
    +                            supported_strategies=None, default=UNIFORM):
    +    """ Returns the sampling strategy based on the class_sizes input param.
    +        @param sampling_strategy_str The sampling strategy specified by the
    +                                         user (class_sizes param)
    +        @returns:
    +            Str. One of [UNIFORM, UNDERSAMPLE, OVERSAMPLE]. Default is UNIFORM.
    +    """
    +    if not sampling_strategy_str:
    +        sampling_strategy_str = default
    +    else:
    +        if len(sampling_strategy_str) < 3:
    +            # Require at least 3 characters since UNIFORM and UNDERSAMPLE have
    +            # common prefix substring
    +            plpy.error("Sample: Invalid class_sizes parameter")
    +
    +        if not supported_strategies:
    +            supported_strategies = [UNIFORM, UNDERSAMPLE, OVERSAMPLE]
    +        try:
    +            # allow user to specify a prefix substring of
    +            # supported strategies.
    +            sampling_strategy_str = next(x for x in supported_strategies
    +                                         if x.startswith(sampling_strategy_str.lower()))
    +        except StopIteration:
    +            # next() returns a StopIteration if no element found
    +            plpy.error("Sample: Invalid class_sizes parameter: "
    +                       "{0}. Supported class_size parameters are ({1})"
    +                       .format(sampling_strategy_str, ','.join(sorted(supported_strategies))))
    +
    +    _assert(sampling_strategy_str.lower() in (UNIFORM, UNDERSAMPLE, OVERSAMPLE) or
    +            (sampling_strategy_str.find('=') > 0),
    +            "Sample: Invalid class size ({sampling_strategy_str}).".format(**locals()))
    +
    +    _assert(not(sampling_strategy_str.lower() == 'oversample' and output_table_size),
    +            "Sample: Cannot set output_table_size with oversampling.")
    +
    +    _assert(not(sampling_strategy_str.lower() == 'undersample' and output_table_size),
    +            "Sample: Cannot set output_table_size with undersampling.")
    +
    +    return sampling_strategy_str
    +# ------------------------------------------------------------------------------
    +
    +
    +def _choose_strategy(actual_count, desired_count):
    +    """ Choose sampling strategy by comparing actual and desired sample counts
    +
    +    @param actual_count: Actual number of samples for some level
    +    @param desired_count: Desired number of sample for the level
    +    @returns:
    +        Str. Sampling strategy string (either UNDERSAMPlE or OVERSAMPLE)
    +    """
    +    # OVERSAMPLE when the actual count is less than the desired count
    +    # UNDERSAMPLE when the actual count is more than the desired count
    +
    +    # If the actual count for a class level is the same as desired count, then
    +    # we could potentially return the input rows as is.  This, however,
    +    # precludes the case of bootstrapping (i.e. returning same  number of rows
    +    # but after sampling with replacement).  Hence, we treat the actual=desired
    +    # as UNDERSAMPLE.  It's specifically set to UNDERSAMPLE since it provides
    +    # both 'with' and 'without' replacement  (OVERSAMPLE is always with
    +    # replacement and NOSAMPLE is always without replacement)
    +    if actual_count < desired_count:
    +        return OVERSAMPLE
    +    else:
    +        return UNDERSAMPLE
    +# -------------------------------------------------------------------------
    +
    +def _get_target_level_counts(sampling_strategy_str, desired_level_counts,
    +                             actual_level_counts, output_table_size):
    +    """
    +    @param sampling_strategy_str: one of [UNIFORM, UNDERSAMPLE, OVERSAMPLE, None].
    +                               This is 'None' only if this is user-defined, i.e.,
    +                               a comma separated list of class levels and number of
    +                               rows desired pairs.
    +    @param desired_level_counts: Dict that is defined and populated only when
    +                                    sampling_strategy_str is None.
    +    @param actual_level_counts: Dict of various class levels and number of rows
    +                                  in each of them in the input table
    +    @param output_table_size: Size of the desired output table (NULL or Integer)
    +
    +    @returns:
    +        Dict. Number of samples to be drawn, and the sampling strategy to be
    +              used for each class level.
    +    """
    +    target_level_counts = {}
    +    if not sampling_strategy_str:
    +        # This case implies user has provided a desired count for one or more
    +        # levels. Counts for the rest of the levels depend on 'output_table_size'.
    +        #   if 'output_table_size' = NULL, unspecified level counts remain as is
    +        #   if 'output_table_size' = <Integer>, divide remaining row count
    +        #                             uniformly among unspecified level counts
    +        for each_level, desired_count in desired_level_counts.items():
    +            sample_strategy = _choose_strategy(actual_level_counts[each_level],
    +                                               desired_count)
    +            target_level_counts[each_level] = (desired_count, sample_strategy)
    +
    +        remaining_levels = (set(actual_level_counts.keys()) -
    +                            set(desired_level_counts.keys()))
    +        if output_table_size:
    +            # Uniformly distribute across the remaining class levels
    +            remaining_rows = output_table_size - sum(desired_level_counts.values())
    +            if remaining_rows > 0:
    +                rows_per_level = math.ceil(float(remaining_rows) /
    +                                           len(remaining_levels))
    +                for each_level in remaining_levels:
    +                    sample_strategy = _choose_strategy(
    +                        actual_level_counts[each_level], rows_per_level)
    +                    target_level_counts[each_level] = (rows_per_level,
    +                                                       sample_strategy)
    +        else:
    +            # When output_table_size is unspecified, rows from the input table
    +            # are sampled as is for remaining class levels. This is same as the
    +            # NOSAMPLE strategy.
    +            for each_level in remaining_levels:
    +                target_level_counts[each_level] = (actual_level_counts[each_level],
    +                                                    NOSAMPLE)
    +    else:
    +        def ceil_of_mean(numbers):
    +            return math.ceil(float(sum(numbers)) / max(len(numbers), 1))
    +
    +        # UNIFORM: Ensure all level counts are same (size determined by output_table_size)
    +        # UNDERSAMPLE: Ensure all level counts are same as the minimum count
    +        # OVERSAMPLE: Ensure all level counts are same as the maximum count
    +        size_function = {UNDERSAMPLE: min,
    +                         OVERSAMPLE: max,
    +                         UNIFORM: ceil_of_mean
    +                         }[sampling_strategy_str]
    +        if sampling_strategy_str == UNIFORM and output_table_size:
    +            # Ignore actual counts for computing target sizes
    +            # if output_table_size is specified
    +            target_size_per_level = math.ceil(float(output_table_size) /
    +                                              len(actual_level_counts))
    +        else:
    +            target_size_per_level = size_function(actual_level_counts.values())
    +        for each_level, actual_count in actual_level_counts.items():
    +            sample_strategy = _choose_strategy(actual_count, target_size_per_level)
    +            target_level_counts[each_level] = (target_size_per_level,
    +                                               sample_strategy)
    +    return target_level_counts
    +
    +# -------------------------------------------------------------------------
    +
    +
    +def _get_sampling_strategy_specific_dict(target_class_sizes):
    +    """ Return three dicts, one each for undersampling, oversampling, and
    +        nosampling. The dict contains the number of samples to be drawn for
    +        each class level.
    +    """
    +    undersample_level_dict = {}
    +    oversample_level_dict = {}
    +    nosample_level_dict = {}
    +    for level, (count, strategy) in target_class_sizes.items():
    +        if strategy == UNDERSAMPLE:
    +            chosen_strategy = undersample_level_dict
    +        elif strategy == OVERSAMPLE:
    +            chosen_strategy = oversample_level_dict
    +        else:
    +            chosen_strategy = nosample_level_dict
    +        chosen_strategy[level] = count
    +    return (undersample_level_dict, oversample_level_dict, nosample_level_dict)
    +# ------------------------------------------------------------------------------
    +
    +
    +def _get_nosample_subquery(source_table, class_col, nosample_levels):
    +    """ Return the subquery for fetching all rows as is from the input table
    +        for specific class levels.
    +    """
    +    if not nosample_levels:
    +        return ''
    +    subquery = """
    +                SELECT *
    +                FROM {0}
    +                WHERE {1} in ({2}) OR {1} IS NULL
    +            """.format(source_table, class_col,
    +                       ','.join(["'{0}'".format(level)
    +                                for level in nosample_levels if level]))
    +    return subquery
    +# ------------------------------------------------------------------------------
    +
    +
    +def _get_without_replacement_subquery(schema_madlib, source_table,
    +                                      source_table_columns, class_col,
    +                                      actual_level_counts, desired_level_counts):
    +    """ Return the subquery for sampling without replacement for specific
    +        class levels.
    +    """
    +    if not desired_level_counts:
    +        return ''
    +    class_col_tmp = unique_string()
    +    row_number_col = unique_string()
    +    desired_count_col = unique_string()
    +
    +    null_value_string = "'{0}'".format(NULL_IDENTIFIER)
    +
    +    desired_level_counts_str = "VALUES " + \
    +            ','.join("({0}, {1})".
    +            format("'{0}'::text".format(k) if k else null_value_string, v)
    +            for k, v in desired_level_counts.items())
    +    subquery = """
    +            SELECT {source_table_columns}
    +            FROM
    +                (
    +                    SELECT {source_table_columns},
    +                           row_number() OVER (PARTITION BY {class_col} ORDER BY random())
AS {row_number_col},
    +                           {desired_count_col}
    +                    FROM
    +                    (
    +                        SELECT {source_table_columns},
    +                               {desired_count_col}
    +                        FROM
    +                            {source_table} s,
    +                            ({desired_level_counts_str})
    +                                q({class_col_tmp}, {desired_count_col})
    +                        WHERE {class_col_tmp} = coalesce({class_col}::text, '{null_level_val}')
    +                    ) q2
    +                ) q3
    +            WHERE {row_number_col} <= {desired_count_col}
    +        """.format(null_level_val=NULL_IDENTIFIER, **locals())
    +    return subquery
    +# ------------------------------------------------------------------------------
    +
    +
    +def _get_with_replacement_subquery(schema_madlib, source_table,
    +                                   source_table_columns, class_col,
    +                                   actual_level_counts, desired_level_counts):
    +    """ Return the query for sampling with replacement for specific class
    --- End diff --
    
    Maybe reword this as:
    """
    Return the query for sampling with replacement for specific class levels. Always used
    for oversampling, since oversampling always needs replacement. Used for undersampling
    only if the with_replacement flag is set to TRUE.
    """


---

Mime
View raw message