Github user orhankislal commented on a diff in the pull request:
https://github.com/apache/madlib/pull/223#discussion_r161297926
--- Diff: src/ports/postgres/modules/sample/balance_sample.py_in ---
@@ -0,0 +1,994 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file EXCEPT in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import math
+import plpy
+import re
+from collections import defaultdict
+from fractions import Fraction
+from utilities.control import MinWarning
+from utilities.utilities import _assert
+from utilities.utilities import unique_string
+from utilities.validate_args import table_exists
+from utilities.validate_args import columns_exist_in_table
+from utilities.validate_args import table_is_empty
+from utilities.validate_args import get_cols
+from utilities.utilities import py_list_to_sql_string
+
+
+m4_changequote(`<!', `!>')
+
def balance_sample(schema_madlib, source_table, output_table, class_col,
                   class_sizes, output_table_size, grouping_cols,
                   with_replacement, **kwargs):
    """
    Balance sampling function.

    Builds ``output_table`` from ``source_table`` by under- and/or
    over-sampling the values of ``class_col``. Temporary views created along
    the way are dropped before returning.

    Args:
        @param schema_madlib      MADlib schema name (not referenced here).
        @param source_table       Input table name.
        @param output_table       Output table name.
        @param class_col          Name of the column containing the class to
                                  be balanced.
        @param class_sizes        Defines the size of the different class
                                  values: 'uniform' (used when NULL/empty),
                                  'undersample', 'oversample', or a
                                  comma-delimited list of class:size pairs.
        @param output_table_size  Desired size of the output data set.
        @param grouping_cols      The columns that define the grouping.
        @param with_replacement   The sampling method (sample with
                                  replacement when true).
    """
    with MinWarning("warning"):

        class_counts = unique_string(desp='class_counts')
        desired_sample_per_class = unique_string(desp='desired_sample_per_class')
        desired_counts = unique_string(desp='desired_counts')

        if not class_sizes or class_sizes.strip().lower() in ('null', ''):
            class_sizes = 'uniform'

        _validate_strs(source_table, output_table, class_col, class_sizes,
                       output_table_size, grouping_cols, with_replacement)

        source_table_columns = ','.join(get_cols(source_table))

        _create_frequency_distribution(class_counts, source_table, class_col)
        temp_views = [class_counts]

        if class_sizes.lower() == 'undersample' and not with_replacement:
            # Random undersample without replacement is handled entirely by
            # the dedicated helper.
            _undersampling_with_no_replacement(
                source_table, output_table, class_col, class_sizes,
                output_table_size, grouping_cols, with_replacement,
                class_counts, source_table_columns)
            _delete_temp_views(temp_views)
            return

        # Create views for the true and desired sample sizes of classes.
        #
        # include_unsampled_classes tracks whether classes that need no
        # sampling (category 'nosample') are copied into the output. It may
        # be turned off by _validate_format_and_values for comma-delimited
        # class_sizes.
        include_unsampled_classes = True
        sampling_with_comma_delimited_class_sizes = class_sizes.find(':') > 0

        if sampling_with_comma_delimited_class_sizes:
            # Compute sample sizes based on the comma-delimited class_sizes
            # list and/or output_table_size.
            class_sizes, include_unsampled_classes = _validate_format_and_values(
                class_sizes, source_table, class_col, output_table_size,
                class_counts, include_unsampled_classes)
            # Only valid condition for sampling is
            # desired_sample_sizes <= output_table_size.
            temp_views.extend(_create_desired_and_actual_sampling_views(
                class_counts, desired_sample_per_class, desired_counts,
                source_table, output_table, class_col, class_sizes,
                output_table_size, include_unsampled_classes))

        if class_sizes.lower() == 'uniform':
            # Compute sample sizes based on a uniform distribution of
            # class sizes.
            temp_views.extend(_compute_uniform_class_sizes(
                class_counts, desired_sample_per_class, desired_counts,
                source_table, output_table, class_col, class_sizes,
                output_table_size))

        oversampling_specific_classes = False
        # class label (text) -> desired sample size (text) for every class
        # in the 'undersample' category.
        desired_undersample_class_sizes = defaultdict(str)

        if sampling_with_comma_delimited_class_sizes or class_sizes.lower() == 'uniform':
            oversampling_specific_classes = plpy.execute("""
                SELECT * FROM {desired_sample_per_class}
                WHERE category = 'oversample'
                """.format(**locals())).nrows() > 0
            if oversampling_specific_classes:
                # Oversampling requires sampling with replacement.
                with_replacement = True

            undersampling_res = plpy.execute("""
                SELECT array_agg(classes::text || ':' || sample_class_size::text)
                    AS undersample_set
                FROM {desired_sample_per_class}
                WHERE category = 'undersample'
                """.format(**locals()))
            if (undersampling_res.nrows() > 0 and
                    undersampling_res[0]['undersample_set'] is not None):
                for val in undersampling_res[0]['undersample_set']:
                    # Split on the LAST ':' (the size part is numeric) so
                    # class labels that contain ':' stay intact.
                    clas, _, clas_limit = val.rpartition(':')
                    desired_undersample_class_sizes[clas] = clas_limit

        # Aggregate that picks the target class size: 'max' for oversample,
        # 'min' for undersample with replacement, None otherwise.
        func_name = None
        if class_sizes.lower() == 'oversample':
            with_replacement = True
            func_name = 'max'
        if class_sizes.lower() == 'undersample' and with_replacement:
            func_name = 'min'

        # Resampling with replacement needs a target: per-class oversample
        # sizes or a min/max class size. Without either, fall through to the
        # union of per-category samples below.
        if with_replacement and (oversampling_specific_classes or func_name):
            # Number the rows of each class 1..class_count.
            classwise_row_numbering_sql = """
                SELECT
                    *,
                    row_number() OVER(PARTITION BY {class_col})
                        AS __row_no
                FROM
                    {source_table}
                """.format(**locals())
            if oversampling_specific_classes:
                classwise_row_numbering_sql += """ WHERE {class_col}::text in
                    (SELECT classes
                     FROM {desired_sample_per_class}
                     WHERE category like 'oversample')
                    """.format(**locals())

            # Draw independent random row numbers for each class that has a
            # different row count than the target. floor(random() * n) + 1
            # is uniform on 1..n; rounding random() * (n - 1) + 1 would give
            # the first and last rows only half the weight of the others.
            if oversampling_specific_classes:
                random_targetclass_size_sample_number_gen_sql = """
                    SELECT
                        {desired_sample_per_class}.classes,
                        generate_series(1, sample_class_size::int) AS _i,
                        (floor(random() * {class_counts}.class_count) + 1)::int
                            AS __row_no
                    FROM
                        {class_counts},
                        {desired_sample_per_class}
                    WHERE
                        {desired_sample_per_class}.classes = {class_counts}.classes
                        AND category like 'oversample'
                    """.format(**locals())
            else:
                random_targetclass_size_sample_number_gen_sql = """
                    SELECT
                        classes,
                        generate_series(1, target_class_size::int) AS _i,
                        (floor(random() * {class_counts}.class_count) + 1)::int
                            AS __row_no
                    FROM
                        (SELECT
                            {func_name}(class_count) AS target_class_size
                        FROM {class_counts})
                        AS foo,
                        {class_counts}
                    WHERE {class_counts}.class_count != target_class_size
                    """.format(**locals())

            # Match the random row numbers with the row identifiers.
            sample_otherclass_set = """
                SELECT
                    {source_table_columns}
                FROM
                    ({classwise_row_numbering_sql}) AS f1
                RIGHT JOIN
                    ({random_targetclass_size_sample_number_gen_sql}) AS f2
                ON (f1.__row_no = f2.__row_no) AND
                    (f1.{class_col}::text = f2.classes)
                """.format(**locals())

            if not oversampling_specific_classes:
                # Classes already at the target size are copied as is.
                targetclass_set = """
                    SELECT
                        {source_table_columns}
                    FROM {source_table}
                    WHERE {class_col}::text IN
                        (SELECT
                            classes AS target_class
                        FROM {class_counts}
                        WHERE class_count in
                            (SELECT {func_name}(class_count) FROM {class_counts}))
                    """.format(**locals())

                # Combine target and other sampled classes.
                output_sql = """
                    CREATE TABLE {output_table} AS (
                        SELECT {source_table_columns}
                        FROM
                            ({targetclass_set}) AS a
                        UNION ALL
                        ({sample_otherclass_set}))
                    """.format(**locals())
                plpy.execute(output_sql)

                _delete_temp_views(temp_views)
                return

        # Unsampled classes are copied as is, unless the class_sizes
        # validation decided that only the listed classes belong in the
        # output (WHERE FALSE then selects nothing).
        if include_unsampled_classes:
            nosample_filter = "category LIKE 'nosample'"
        else:
            nosample_filter = "FALSE"
        nosample_classset_sql = """
            SELECT
                {source_table_columns}
            FROM {source_table}
            WHERE {class_col}::text IN
                (SELECT
                    classes
                FROM {desired_sample_per_class}
                WHERE {nosample_filter})
            """.format(**locals())

        # Union all undersampled classes (random subset, no replacement).
        undersampling_classset_sql = ''
        if desired_undersample_class_sizes:
            undersampling_classset_sql = ' UNION ALL ' + ' UNION ALL '.join(
                """
                (SELECT {source_table_columns}
                FROM {source_table}
                WHERE {class_col}::text = {clas}
                ORDER BY random()
                LIMIT {limit_bound})
                """.format(source_table_columns=source_table_columns,
                           source_table=source_table,
                           class_col=class_col,
                           # Escape quotes so the label is a valid SQL literal.
                           clas="'" + clas.replace("'", "''") + "'",
                           limit_bound=clas_limit)
                for clas, clas_limit in desired_undersample_class_sizes.items())

        # Union all oversampled classes (built in the with-replacement path).
        oversampling_specific_classes_classset_sql = ''
        if oversampling_specific_classes:
            oversampling_specific_classes_classset_sql = """
                UNION ALL
                ({sample_otherclass_set})
                """.format(**locals())

        # Always create the output table, even when no class needed
        # over- or undersampling.
        output_sql = """
            CREATE TABLE {output_table} AS (
                SELECT {source_table_columns}
                FROM
                    ({nosample_classset_sql}) AS a
                {oversampling_specific_classes_classset_sql}
                {undersampling_classset_sql})
            """.format(**locals())
        plpy.execute(output_sql)

        _delete_temp_views(temp_views)
        return
+
+"""
+ Delete all temp views
+"""
+def _delete_temp_views(temp_views):
+ for temp_view in temp_views:
+ plpy.execute("DROP VIEW IF EXISTS {0} cascade".format(temp_view))
+ return
+
def _undersampling_with_no_replacement(source_table, output_table, class_col,
                                       class_sizes, output_table_size,
                                       grouping_cols, with_replacement,
                                       class_counts, source_table_columns):
    """
    Random undersample without replacement.

    Every class is reduced to the size of the smallest class. Rows of the
    smallest class(es) are copied unchanged. Every other class contributes
    ``ORDER BY random() LIMIT <smallest class size>`` of its rows. The
    result is written to ``output_table``.

    class_sizes, output_table_size, grouping_cols and with_replacement are
    not used here; they are kept for signature compatibility.
    """
    distinct_class_labels = plpy.execute("""
        SELECT array_agg(DISTINCT {class_col}::text) AS labels
        FROM {source_table}
        """.format(**locals()))[0]['labels']

    limit_bound = plpy.execute("""
        SELECT MIN(class_count)::int AS min
        FROM {class_counts}""".format(**locals()))[0]['min']

    minority_class = plpy.execute("""
        SELECT array_agg(classes::text) AS minority_class
        FROM {class_counts}
        WHERE class_count = {limit_bound}
        """.format(**locals()))[0]['minority_class']

    def quote(label):
        # Render a class label as a SQL string literal (escape single quotes).
        return "'" + label.replace("'", "''") + "'"

    # NULL labels never matched the old "= 'None'" filters either; skip them.
    minority_labels = [cl for cl in (minority_class or []) if cl is not None]
    majority_labels = [cl for cl in (distinct_class_labels or [])
                       if cl is not None and cl not in minority_labels]

    # One branch per class. Branches are joined with UNION ALL: plain UNION
    # would drop identical rows and break the per-class counts. Building a
    # list also avoids the dangling "UNION UNION" produced when there is
    # exactly one non-minority class.
    sample_branches = ["""
        (SELECT {source_table_columns}
         FROM {source_table}
         WHERE {class_col}::text = {label}
         ORDER BY random()
         LIMIT {limit_bound})""".format(source_table_columns=source_table_columns,
                                        source_table=source_table,
                                        class_col=class_col,
                                        label=quote(cl),
                                        limit_bound=limit_bound)
                       for cl in majority_labels]
    if minority_labels:
        sample_branches.append("""
        (SELECT {source_table_columns}
         FROM {source_table}
         WHERE {class_col}::text IN ({labels}))""".format(
            source_table_columns=source_table_columns,
            source_table=source_table,
            class_col=class_col,
            labels=', '.join(quote(cl) for cl in minority_labels)))

    if not sample_branches:
        plpy.error("Sample: No non-NULL class values found in {0}".format(class_col))

    foo_table = unique_string(desp='foo')
    output_sql = """
        CREATE TABLE {output_table} AS
        SELECT {source_table_columns}
        FROM ({union_sql}) AS {foo_table}
        """.format(output_table=output_table,
                   source_table_columns=source_table_columns,
                   union_sql=' UNION ALL '.join(sample_branches),
                   foo_table=foo_table)
    plpy.execute(output_sql)
+
class UniqueDict(dict):
    """
    Mapping of class label -> sample size that validates every insertion.

    Captures cases where classes are specified multiple times in the
    comma-delimited class_sizes string, e.g. class_sizes = '5:6,5:4'. It also
    rejects sizes that are neither a fraction strictly between 0.0 and 1.0
    nor a whole number >= 1. Both checks report through plpy.error.
    """
    def __setitem__(self, classkey, class_size):
        if classkey in self:
            plpy.error("Sample: Repeated classes in class_sizes")
            return
        # `and`, not `or`: "x > 0.0 or x < 1.0" is true for every number.
        is_fraction = 0.0 < class_size < 1.0
        # float(...).is_integer() ensures only whole numbers are accepted
        # as a class's sample size.
        is_whole_number = class_size >= 1.0 and float(class_size).is_integer()
        if is_fraction or is_whole_number:
            dict.__setitem__(self, classkey, class_size)
        else:
            plpy.error("Sample: Sample size should be a fraction between (0.0,1.0) "
                       "or a whole number greater than 1")
+
+"""
+ Check if the class sizes (passed as strings) are alphanumeric, float or whole numbers.
+ Error out on alphanumeric values.
+ e.g. class_sizes = '3:a,5:b
+ Returns type of the size as int or float.
+"""
+def _check_value_type(value):
+ try:
+ float(value)
+ except ValueError:
+ plpy.error("Sample: Specify either fractions (0.0,1.0) or whole numbers for class
sample size.")
+
+ if re.match("^\d+?\.\d+?$", value) is not None and not float(value).is_integer():
+ valueType = float
+ elif (re.match(r"[-+]?\d+$", value) is not None):
+ valueType = int
+ return valueType
def _validate_classes(all_classes, source_table, class_col):
    """
    Check that every specified class is present in class_col.

    @param all_classes   Comma-separated SQL text literals (e.g. "'a', 'b'"),
                         spliced verbatim into ARRAY[...].
    @param source_table  Table to check against.
    @param class_col     Class column of source_table (compared as text).

    Errors out (plpy.error) naming the classes that do not exist.
    """
    nonexisting_classes = plpy.execute("""
        SELECT unnest(ARRAY[{all_classes}]) AS missing_class
        EXCEPT
        SELECT DISTINCT {class_col}::text
        FROM {source_table}
        """.format(all_classes=all_classes,
                   class_col=class_col,
                   source_table=source_table))

    if nonexisting_classes.nrows() > 0:
        missing = ', '.join(str(row['missing_class']) for row in nonexisting_classes)
        plpy.error("Sample: Specified classes do not exist in {0}: {1}".format(
            class_col, missing))
+
def _validate_format_and_values(class_sizes, source_table, class_col,
                                output_table_size, class_counts,
                                include_unsampled_classes):
    """
    Validate a comma-delimited class_sizes string and resolve its fractions.

        1. Checks every entry has the form class_name:class_size.
        2. Checks if the same class is specified multiple times.
        3. Checks each size is a fraction in (0.0,1.0) or a whole number >= 1.
        4. Checks the specified classes are present in class_col.
        5. Checks the fraction sum is at most 1.0, and below 1.0 when
           whole-number sizes are also given.
        6. Checks output_table_size is not smaller than the total desired
           size of the classes specified in class_sizes.

    Returns:
        (class_sizes, include_unsampled_classes). The first element is the
        class_sizes string, returned unchanged when no fractions were given
        and otherwise rebuilt with fractions converted to whole numbers by
        _compute_class_sizes. The second is include_unsampled_classes,
        negated when the listed classes alone account for the whole output.
    """
    cs_dict = UniqueDict()

    numeric_value_sum = 0
    # Fractions are summed exactly, so e.g. 0.1+0.2+0.7 compares equal to 1
    # instead of 0.9999999999999999.
    fraction_value_sum = Fraction(0)

    for class_spec in class_sizes.split(','):
        class_and_size = class_spec.split(':')
        if len(class_and_size) != 2:
            plpy.error("Sample: Invalid class_sizes entry '{0}'. Use the format "
                       "class_name:class_size.".format(class_spec))
        size_str = class_and_size[1]
        value_type = _check_value_type(size_str)
        class_size = value_type(size_str)
        # UniqueDict rejects repeated classes and out-of-range sizes.
        cs_dict[class_and_size[0].strip()] = class_size
        if value_type == float:
            # repr() gives the shortest decimal form, e.g. '0.1' -> 1/10.
            fraction_value_sum += Fraction(repr(class_size))
        else:
            numeric_value_sum += class_size

    # Check that the specified classes are present in class_col. Labels are
    # rendered as quoted SQL literals with embedded quotes escaped.
    all_classes = ', '.join("'" + clas.replace("'", "''") + "'"
                            for clas in cs_dict)
    _validate_classes(all_classes, source_table, class_col)

    # Error out if the fraction sum exceeds 1.0, or equals 1.0 while other
    # classes are also given whole-number sizes.
    if fraction_value_sum > 1 or (fraction_value_sum == 1 and numeric_value_sum != 0):
        plpy.error("Sample: Fraction sum must be < 1.0 when any other class is "
                   "also specified as class_name:class_size-in-whole-numbers. "
                   "Fraction sum can be at most 1.0.")

    total_table_size = plpy.execute("""
        SELECT
            count(*) AS total
        FROM {source_table}
        """.format(source_table=source_table))[0]['total']

    # Only whole numbers were specified.
    if fraction_value_sum == 0:
        if not output_table_size:
            # Sample remaining classes uniformly.
            return class_sizes, include_unsampled_classes

        if output_table_size < numeric_value_sum:
            plpy.error("Sample: Output table size ({0}) must be more than total "
                       "specified sample size i.e. {1}".format(output_table_size,
                                                               numeric_value_sum))

        if output_table_size == numeric_value_sum:
            # Do not sample other classes.
            return class_sizes, not include_unsampled_classes

        # Sample remaining classes uniformly with target table size.
        return class_sizes, include_unsampled_classes

    # Only fractions were specified, and they sum to 1.0.
    if fraction_value_sum == 1:
        # Do not sample other classes.
        return (_compute_class_sizes(cs_dict, total_table_size)[0],
                not include_unsampled_classes)

    # 0 < fraction sum < 1; whole numbers may also be present.
    sum_remaining_class_samples = plpy.execute("""
        SELECT sum({class_counts}.class_count) AS remaining_classes
        FROM {class_counts}
        WHERE classes NOT IN ({all_classes})
        """.format(class_counts=class_counts, all_classes=all_classes))

    if not output_table_size:
        # Derive the output size x from the unlisted classes. Example:
        # male=.4, output_table_size=NULL, and the other categorical values
        # are female=10M, other=1M. Solve .4x + 10M + 1M = x, giving
        # x = 18.3M.
        remaining_fraction = 1.0 - float(fraction_value_sum)

        if sum_remaining_class_samples.nrows() > 0:
            # sum() over zero rows is NULL (None) when every class is listed.
            numeric_value_sum += sum_remaining_class_samples[0]['remaining_classes'] or 0

        if numeric_value_sum == 0:
            # Rare case: every class in class_col is listed with a fraction.
            return (_compute_class_sizes(cs_dict, total_table_size)[0],
                    not include_unsampled_classes)

        # Compute total_desired_sample_size as x.
        x = math.ceil(float(numeric_value_sum) / remaining_fraction)
        class_size, _ = _compute_class_sizes(cs_dict, x)
        return class_size, include_unsampled_classes

    # output_table_size is given: compute the total desired sample size and
    # check it against output_table_size.
    output_table_size = float(output_table_size)
    class_size, total_desired_sample_size = _compute_class_sizes(cs_dict, output_table_size)

    if total_desired_sample_size > output_table_size:
        plpy.error("Sample: Output table size ({0}) must be more than total "
                   "desired sample size i.e {1}".format(output_table_size,
                                                        total_desired_sample_size))

    if total_desired_sample_size == output_table_size:
        # Do not sample other classes.
        return class_size, not include_unsampled_classes

    # Sample other classes uniformly.
    return class_size, include_unsampled_classes
+
+"""
+ Proportions of class sizes are multiplied by output_table_size to get desired whole
number value for class sizes.
+ A total_desired_size is sum of all whole number class sizes
+"""
+def _compute_class_sizes(cs_dict, x):
+ class_size = ''
+ total_desired_size = 0
+ for clas, class_val in cs_dict.iteritems():
+ if float(class_val).is_integer():
+ total_desired_size += class_val
+ class_size += str(clas) + ':' + str(class_val) + ','
+ else:
+ class_val = int(round(class_val * x, 0))
+ total_desired_size += class_val
+ class_size += str(clas) + ':' + str(class_val) + ','
+
+ return class_size[:-1], total_desired_size
+
+"""
+ Create view to store class counts of classes in class_col
+"""
+def _create_frequency_distribution(class_counts, source_table, class_col):
+
+ grp_by = "GROUP BY {0}".format(class_col)
--- End diff --
We can move this inside the next command.
---
|