GitHub user kaknikhil commented on a diff in the pull request:
https://github.com/apache/madlib/pull/230#discussion_r165529067
--- Diff: src/ports/postgres/modules/sample/balance_sample.py_in ---
@@ -0,0 +1,748 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file EXCEPT in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+m4_changequote(`<!', `!>')
+
+import math
+
if __name__ != "__main__":
    import plpy
    from utilities.control import MinWarning
    from utilities.utilities import _assert
    from utilities.utilities import extract_keyvalue_params
    from utilities.utilities import unique_string
    from utilities.validate_args import columns_exist_in_table
    from utilities.validate_args import get_cols
    from utilities.validate_args import table_exists
    from utilities.validate_args import table_is_empty
else:
    # Used only for unit testing.
    # FIXME: repeating a function from utilities that is needed by the unit test.
    # This should be removed once a unittest framework is used for testing.
    import random
    import time

    def unique_string(desp='', **kwargs):
        """
        Generate a random temporary name for temp tables and other objects.

        Local copy of the utilities function, used only when this module is
        run directly. The result has the form
        __madlib_temp_<desp><r1>_<r2>_<r3>__

        @param desp: Optional descriptor embedded in the generated name.
        @param kwargs: Accepted and ignored (kept for signature compatibility).
        @returns:
            Str. The generated name.
        """
        r1 = random.randint(1, 100000000)
        r2 = int(time.time())
        r3 = int(time.time()) % random.randint(1, 100000000)
        # The trailing "__" must be part of the concatenation; a bare "__" on
        # its own line is a no-op statement and would be silently dropped.
        u_string = ("__madlib_temp_" + desp + str(r1) + "_" + str(r2) + "_" +
                    str(r3) + "__")
        return u_string
+# ------------------------------------------------------------------------------
+
# Sampling strategy names. UNIFORM, UNDERSAMPLE and OVERSAMPLE form the
# default set of user-selectable class_sizes strategies; NOSAMPLE is an
# additional strategy name used elsewhere in the module.
UNIFORM, UNDERSAMPLE, OVERSAMPLE, NOSAMPLE = (
    'uniform', 'undersample', 'oversample', 'nosample')

# Reserved identifiers, wrapped in double underscores -- presumably to avoid
# clashing with user column names / class levels; confirm at their call sites.
NEW_ID_COLUMN, NULL_IDENTIFIER = '__madlib_id__', '__madlib_null_id__'
+
def _get_frequency_distribution(source_table, class_col):
    """ Count how many rows of source_table fall into each class level.

    @param source_table: Name of the table to scan.
    @param class_col: Name of the column holding the class level.
    @returns:
        Dict mapping each class level (cast with ::text, then stripped of
        surrounding whitespace) to its row count. Falsy levels, e.g. None
        for a NULL level, are used as keys unchanged.
    """
    frequency_sql = """
        SELECT {class_col}::text AS classes,
               count(*) AS class_count
        FROM {source_table}
        GROUP BY {class_col}
        """.format(class_col=class_col, source_table=source_table)

    def _normalize(level):
        # Only non-empty strings are stripped; None / '' pass through as-is.
        return level.strip() if level else level

    return {_normalize(row['classes']): row['class_count']
            for row in plpy.execute(frequency_sql)}
+
+
def _validate_and_get_sampling_strategy(sampling_strategy_str, output_table_size,
                                        supported_strategies=None, default=UNIFORM):
    """ Returns the sampling strategy based on the class_sizes input param.

        @param sampling_strategy_str: The sampling strategy specified by the
                                      user (class_sizes param). A
                                      case-insensitive prefix of at least 3
                                      characters of a supported strategy is
                                      accepted.
        @param output_table_size: Requested output table size. Must not be
                                  set together with OVERSAMPLE or UNDERSAMPLE.
        @param supported_strategies: Strategies the user may choose from.
                                     Defaults to [UNIFORM, UNDERSAMPLE,
                                     OVERSAMPLE] when empty/None.
        @param default: Strategy used when sampling_strategy_str is empty.
        @returns:
            Str. The full name of the selected strategy (UNIFORM by default).
        Errors out via plpy.error / _assert when the strategy is too short,
        does not match any supported strategy, or is combined with
        output_table_size in an unsupported way.
    """
    if not supported_strategies:
        supported_strategies = [UNIFORM, UNDERSAMPLE, OVERSAMPLE]

    if not sampling_strategy_str:
        sampling_strategy_str = default
    else:
        if len(sampling_strategy_str) < 3:
            # Require at least 3 characters since UNIFORM and UNDERSAMPLE have
            # common prefix substring
            plpy.error("Sample: Invalid class_sizes parameter")

        try:
            # allow user to specify a prefix substring of
            # supported strategies.
            sampling_strategy_str = next(x for x in supported_strategies
                                         if x.startswith(sampling_strategy_str.lower()))
        except StopIteration:
            # next() raises StopIteration if no element is found
            plpy.error("Sample: Invalid class_sizes parameter: "
                       "{0}. Supported class_size parameters are ({1})"
                       .format(sampling_strategy_str, ','.join(sorted(supported_strategies))))

    # Validate against the built-in strategies AND the caller-supplied ones,
    # so a strategy that was just selected from supported_strategies above
    # (e.g. NOSAMPLE) is not rejected here.
    valid_strategies = set([UNIFORM, UNDERSAMPLE, OVERSAMPLE])
    valid_strategies.update(s.lower() for s in supported_strategies)
    _assert(sampling_strategy_str.lower() in valid_strategies or
            (sampling_strategy_str.find('=') > 0),
            "Sample: Invalid class size ({sampling_strategy_str})."
            .format(sampling_strategy_str=sampling_strategy_str))

    _assert(not(sampling_strategy_str.lower() == OVERSAMPLE and output_table_size),
            "Sample: Cannot set output_table_size with oversampling.")

    _assert(not(sampling_strategy_str.lower() == UNDERSAMPLE and output_table_size),
            "Sample: Cannot set output_table_size with undersampling.")

    return sampling_strategy_str
+# ------------------------------------------------------------------------------
+
+
def _choose_strategy(actual_count, desired_count):
    """ Pick the sampling strategy for one class level.

    @param actual_count: Number of rows the level has in the input
    @param desired_count: Number of rows wanted for the level
    @returns:
        Str. OVERSAMPLE if the level has fewer rows than desired, otherwise
        UNDERSAMPLE (including when the two counts are equal).
    """
    # A tie deliberately maps to UNDERSAMPLE instead of passing the rows
    # through unchanged: returning the input as-is would rule out
    # bootstrapping (same number of rows, sampled with replacement).
    # UNDERSAMPLE supports both 'with' and 'without' replacement, whereas
    # OVERSAMPLE is always with replacement and NOSAMPLE always without.
    needs_more_rows = actual_count < desired_count
    return OVERSAMPLE if needs_more_rows else UNDERSAMPLE
+# -------------------------------------------------------------------------
+
+def _get_target_level_counts(sampling_strategy_str, desired_level_counts,
+ actual_level_counts, output_table_size):
+ """
+ @param sampling_strategy_str: one of [UNIFORM, UNDERSAMPLE, OVERSAMPLE, None].
+ This is 'None' only if this is user-defined, i.e.,
+ a comma separated list of class levels and number of
+ rows desired pairs.
+ @param desired_level_counts: Dict that is defined and populated only when
+ sampling_strategy_str is None.
+ @param actual_level_counts: Dict of various class levels and number of rows
+ in each of them in the input table
+ @param output_table_size: Size of the desired output table (NULL or Integer)
+
+ @returns:
+ Dict. Number of samples to be drawn, and the sampling strategy to be
+ used for each class level.
+ """
+ target_level_counts = {}
+ if not sampling_strategy_str:
--- End diff --
The more I look at this function, the more it feels like the if and else branches
could be two separate functions, though I understand that would add a bit more indirection.
I don't feel strongly about this; just my two cents.
---
|