From dev-return-2765-apmail-madlib-dev-archive=madlib.apache.org@madlib.apache.org Fri Feb 2 20:32:42 2018 Return-Path: X-Original-To: apmail-madlib-dev-archive@minotaur.apache.org Delivered-To: apmail-madlib-dev-archive@minotaur.apache.org Received: from mail.apache.org (hermes.apache.org [140.211.11.3]) by minotaur.apache.org (Postfix) with SMTP id 17F6017FB1 for ; Fri, 2 Feb 2018 20:32:42 +0000 (UTC) Received: (qmail 93768 invoked by uid 500); 2 Feb 2018 20:32:41 -0000 Delivered-To: apmail-madlib-dev-archive@madlib.apache.org Received: (qmail 93738 invoked by uid 500); 2 Feb 2018 20:32:41 -0000 Mailing-List: contact dev-help@madlib.apache.org; run by ezmlm Precedence: bulk List-Help: List-Unsubscribe: List-Post: List-Id: Reply-To: dev@madlib.apache.org Delivered-To: mailing list dev@madlib.apache.org Received: (qmail 93447 invoked by uid 99); 2 Feb 2018 20:32:41 -0000 Received: from git1-us-west.apache.org (HELO git1-us-west.apache.org) (140.211.11.23) by apache.org (qpsmtpd/0.29) with ESMTP; Fri, 02 Feb 2018 20:32:41 +0000 Received: by git1-us-west.apache.org (ASF Mail Server at git1-us-west.apache.org, from userid 33) id 1DD60E96D8; Fri, 2 Feb 2018 20:32:41 +0000 (UTC) From: kaknikhil To: dev@madlib.apache.org Reply-To: dev@madlib.apache.org References: In-Reply-To: Subject: [GitHub] madlib pull request #230: Balanced sets final Content-Type: text/plain Message-Id: <20180202203241.1DD60E96D8@git1-us-west.apache.org> Date: Fri, 2 Feb 2018 20:32:41 +0000 (UTC) Github user kaknikhil commented on a diff in the pull request: https://github.com/apache/madlib/pull/230#discussion_r165529067 --- Diff: src/ports/postgres/modules/sample/balance_sample.py_in --- @@ -0,0 +1,748 @@ +# coding=utf-8 +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file EXCEPT in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +m4_changequote(`') + +import math + +if __name__ != "__main__": + import plpy + from utilities.control import MinWarning + from utilities.utilities import _assert + from utilities.utilities import extract_keyvalue_params + from utilities.utilities import unique_string + from utilities.validate_args import columns_exist_in_table + from utilities.validate_args import get_cols + from utilities.validate_args import table_exists + from utilities.validate_args import table_is_empty +else: + # Used only for Unit Testing + # FIXME: repeating a function from utilities that is needed by the unit test. + # This should be removed once a unittest framework in used for testing. + import random + import time + + def unique_string(desp='', **kwargs): + """ + Generate random remporary names for temp table and other names. + It has a SQL interface so both SQL and Python functions can call it. + """ + r1 = random.randint(1, 100000000) + r2 = int(time.time()) + r3 = int(time.time()) % random.randint(1, 100000000) + u_string = "__madlib_temp_" + desp + str(r1) + "_" + str(r2) + "_" + str(r3) + "__" + return u_string +# ------------------------------------------------------------------------------ + +UNIFORM = 'uniform' +UNDERSAMPLE = 'undersample' +OVERSAMPLE = 'oversample' +NOSAMPLE = 'nosample' + +NEW_ID_COLUMN = '__madlib_id__' +NULL_IDENTIFIER = '__madlib_null_id__' + +def _get_frequency_distribution(source_table, class_col): + """ Returns a dict containing the number of rows associated with each class + level. Each class level value is converted to a string using ::text. + """ + query_result = plpy.execute(""" + SELECT {class_col}::text AS classes, + count(*) AS class_count + FROM {source_table} + GROUP BY {class_col} + """.format(**locals())) + actual_level_counts = {} + for each_row in query_result: + level = each_row['classes'] + if level: + level = level.strip() + actual_level_counts[level] = each_row['class_count'] + return actual_level_counts + + +def _validate_and_get_sampling_strategy(sampling_strategy_str, output_table_size, + supported_strategies=None, default=UNIFORM): + """ Returns the sampling strategy based on the class_sizes input param. + @param sampling_strategy_str The sampling strategy specified by the + user (class_sizes param) + @returns: + Str. One of [UNIFORM, UNDERSAMPLE, OVERSAMPLE]. Default is UNIFORM. + """ + if not sampling_strategy_str: + sampling_strategy_str = default + else: + if len(sampling_strategy_str) < 3: + # Require at least 3 characters since UNIFORM and UNDERSAMPLE have + # common prefix substring + plpy.error("Sample: Invalid class_sizes parameter") + + if not supported_strategies: + supported_strategies = [UNIFORM, UNDERSAMPLE, OVERSAMPLE] + try: + # allow user to specify a prefix substring of + # supported strategies. + sampling_strategy_str = next(x for x in supported_strategies + if x.startswith(sampling_strategy_str.lower())) + except StopIteration: + # next() returns a StopIteration if no element found + plpy.error("Sample: Invalid class_sizes parameter: " + "{0}. Supported class_size parameters are ({1})" + .format(sampling_strategy_str, ','.join(sorted(supported_strategies)))) + + _assert(sampling_strategy_str.lower() in (UNIFORM, UNDERSAMPLE, OVERSAMPLE) or + (sampling_strategy_str.find('=') > 0), + "Sample: Invalid class size ({sampling_strategy_str}).".format(**locals())) + + _assert(not(sampling_strategy_str.lower() == 'oversample' and output_table_size), + "Sample: Cannot set output_table_size with oversampling.") + + _assert(not(sampling_strategy_str.lower() == 'undersample' and output_table_size), + "Sample: Cannot set output_table_size with undersampling.") + + return sampling_strategy_str +# ------------------------------------------------------------------------------ + + +def _choose_strategy(actual_count, desired_count): + """ Choose sampling strategy by comparing actual and desired sample counts + + @param actual_count: Actual number of samples for some level + @param desired_count: Desired number of sample for the level + @returns: + Str. Sampling strategy string (either UNDERSAMPlE or OVERSAMPLE) + """ + # OVERSAMPLE when the actual count is less than the desired count + # UNDERSAMPLE when the actual count is more than the desired count + + # If the actual count for a class level is the same as desired count, then + # we could potentially return the input rows as is. This, however, + # precludes the case of bootstrapping (i.e. returning same number of rows + # but after sampling with replacement). Hence, we treat the actual=desired + # as UNDERSAMPLE. It's specifically set to UNDERSAMPLE since it provides + # both 'with' and 'without' replacement (OVERSAMPLE is always with + # replacement and NOSAMPLE is always without replacement) + if actual_count < desired_count: + return OVERSAMPLE + else: + return UNDERSAMPLE +# ------------------------------------------------------------------------- + +def _get_target_level_counts(sampling_strategy_str, desired_level_counts, + actual_level_counts, output_table_size): + """ + @param sampling_strategy_str: one of [UNIFORM, UNDERSAMPLE, OVERSAMPLE, None]. + This is 'None' only if this is user-defined, i.e., + a comma separated list of class levels and number of + rows desired pairs. + @param desired_level_counts: Dict that is defined and populated only when + sampling_strategy_str is None. + @param actual_level_counts: Dict of various class levels and number of rows + in each of them in the input table + @param output_table_size: Size of the desired output table (NULL or Integer) + + @returns: + Dict. Number of samples to be drawn, and the sampling strategy to be + used for each class level. + """ + target_level_counts = {} + if not sampling_strategy_str: --- End diff -- The more i look at this function the more it feels like that the if and else conditions could be two different functions but I understand that it does add a bit more redirection. I don't feel strongly about this, just my two cents. ---