madlib-dev mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From kaknikhil <...@git.apache.org>
Subject [GitHub] madlib pull request #230: Balanced sets final
Date Fri, 02 Feb 2018 20:32:25 GMT
Github user kaknikhil commented on a diff in the pull request:

    https://github.com/apache/madlib/pull/230#discussion_r165527448
  
    --- Diff: src/ports/postgres/modules/sample/balance_sample.py_in ---
    @@ -0,0 +1,748 @@
    +# coding=utf-8
    +#
    +# Licensed to the Apache Software Foundation (ASF) under one
    +# or more contributor license agreements.  See the NOTICE file
    +# distributed with this work for additional information
    +# regarding copyright ownership.  The ASF licenses this file
    +# to you under the Apache License, Version 2.0 (the
    +# "License"); you may not use this file except in compliance
    +# with the License.  You may obtain a copy of the License at
    +#
    +#   http://www.apache.org/licenses/LICENSE-2.0
    +#
    +# Unless required by applicable law or agreed to in writing,
    +# software distributed under the License is distributed on an
    +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    +# KIND, either express or implied.  See the License for the
    +# specific language governing permissions and limitations
    +# under the License.
    +
    +m4_changequote(`<!', `!>')
    +
    +import math
    +
    +if __name__ != "__main__":
    +    import plpy
    +    from utilities.control import MinWarning
    +    from utilities.utilities import _assert
    +    from utilities.utilities import extract_keyvalue_params
    +    from utilities.utilities import unique_string
    +    from utilities.validate_args import columns_exist_in_table
    +    from utilities.validate_args import get_cols
    +    from utilities.validate_args import table_exists
    +    from utilities.validate_args import table_is_empty
    +else:
    +    # Used only for Unit Testing
    +    # FIXME: repeating a function from utilities that is needed by the unit test.
    +    # This should be removed once a unittest framework is used for testing.
    +    import random
    +    import time
    +
    +    def unique_string(desp='', **kwargs):
    +        """
    +        Generate random temporary names for temp table and other names.
    +        It has a SQL interface so both SQL and Python functions can call it.
    +        """
    +        r1 = random.randint(1, 100000000)
    +        r2 = int(time.time())
    +        r3 = int(time.time()) % random.randint(1, 100000000)
    +        u_string = "__madlib_temp_" + desp + str(r1) + "_" + str(r2) + "_" + str(r3) + "__"
    +        return u_string
    +# ------------------------------------------------------------------------------
    +
    +UNIFORM = 'uniform'
    +UNDERSAMPLE = 'undersample'
    +OVERSAMPLE = 'oversample'
    +NOSAMPLE = 'nosample'
    +
    +NEW_ID_COLUMN = '__madlib_id__'
    +NULL_IDENTIFIER = '__madlib_null_id__'
    +
    +def _get_frequency_distribution(source_table, class_col):
    +    """ Returns a dict containing the number of rows associated with each class
    +        level. Each class level value is converted to a string using ::text.
    +    """
    +    query_result = plpy.execute("""
    +                    SELECT {class_col}::text AS classes,
    +                           count(*) AS class_count
    +                    FROM {source_table}
    +                    GROUP BY {class_col}
    +                 """.format(**locals()))
    +    actual_level_counts = {}
    +    for each_row in query_result:
    +        level = each_row['classes']
    +        if level:
    +            level = level.strip()
    +        actual_level_counts[level] = each_row['class_count']
    +    return actual_level_counts
    +
    +
    +def _validate_and_get_sampling_strategy(sampling_strategy_str, output_table_size,
    +                            supported_strategies=None, default=UNIFORM):
    +    """ Returns the sampling strategy based on the class_sizes input param.
    +        @param sampling_strategy_str The sampling strategy specified by the
    +                                         user (class_sizes param)
    +        @returns:
    +            Str. One of [UNIFORM, UNDERSAMPLE, OVERSAMPLE]. Default is UNIFORM.
    +    """
    +    if not sampling_strategy_str:
    +        sampling_strategy_str = default
    +    else:
    +        if len(sampling_strategy_str) < 3:
    +            # Require at least 3 characters since UNIFORM and UNDERSAMPLE have
    +            # common prefix substring
    +            plpy.error("Sample: Invalid class_sizes parameter")
    +
    +        if not supported_strategies:
    +            supported_strategies = [UNIFORM, UNDERSAMPLE, OVERSAMPLE]
    +        try:
    +            # allow user to specify a prefix substring of
    +            # supported strategies.
    +            sampling_strategy_str = next(x for x in supported_strategies
    +                                         if x.startswith(sampling_strategy_str.lower()))
    +        except StopIteration:
    +            # next() raises StopIteration if no element is found
    +            plpy.error("Sample: Invalid class_sizes parameter: "
    +                       "{0}. Supported class_size parameters are ({1})"
    +                       .format(sampling_strategy_str, ','.join(sorted(supported_strategies))))
    +
    +    _assert(sampling_strategy_str.lower() in (UNIFORM, UNDERSAMPLE, OVERSAMPLE) or
    +            (sampling_strategy_str.find('=') > 0),
    +            "Sample: Invalid class size ({sampling_strategy_str}).".format(**locals()))
    +
    +    _assert(not(sampling_strategy_str.lower() == 'oversample' and output_table_size),
    +            "Sample: Cannot set output_table_size with oversampling.")
    +
    +    _assert(not(sampling_strategy_str.lower() == 'undersample' and output_table_size),
    +            "Sample: Cannot set output_table_size with undersampling.")
    +
    +    return sampling_strategy_str
    +# ------------------------------------------------------------------------------
    +
    +
    +def _choose_strategy(actual_count, desired_count):
    +    """ Choose sampling strategy by comparing actual and desired sample counts
    +
    +    @param actual_count: Actual number of samples for some level
    +    @param desired_count: Desired number of sample for the level
    +    @returns:
    +        Str. Sampling strategy string (either UNDERSAMPLE or OVERSAMPLE)
    +    """
    +    # OVERSAMPLE when the actual count is less than the desired count
    +    # UNDERSAMPLE when the actual count is more than the desired count
    +
    +    # If the actual count for a class level is the same as desired count, then
    +    # we could potentially return the input rows as is.  This, however,
    +    # precludes the case of bootstrapping (i.e. returning same  number of rows
    +    # but after sampling with replacement).  Hence, we treat the actual=desired
    +    # as UNDERSAMPLE.  It's specifically set to UNDERSAMPLE since it provides
    +    # both 'with' and 'without' replacement  (OVERSAMPLE is always with
    +    # replacement and NOSAMPLE is always without replacement)
    +    if actual_count < desired_count:
    +        return OVERSAMPLE
    +    else:
    +        return UNDERSAMPLE
    +# -------------------------------------------------------------------------
    +
    +def _get_target_level_counts(sampling_strategy_str, desired_level_counts,
    +                             actual_level_counts, output_table_size):
    +    """
    +    @param sampling_strategy_str: one of [UNIFORM, UNDERSAMPLE, OVERSAMPLE, None].
    +                               This is 'None' only if this is user-defined, i.e.,
    +                               a comma separated list of class levels and number of
    +                               rows desired pairs.
    +    @param desired_level_counts: Dict that is defined and populated only when
    --- End diff --
    
    We can reword this as something like: `If the previous arg sampling_strategy_str is None,
    this dict contains the class levels and the corresponding number of rows specified by the
    user`.


---

Mime
View raw message