From dev-return-4319-archive-asf-public=cust-asf.ponee.io@madlib.apache.org Mon Apr 1 21:20:56 2019
Return-Path:
X-Original-To: archive-asf-public@cust-asf.ponee.io
Delivered-To: archive-asf-public@cust-asf.ponee.io
Received: from mail.apache.org (hermes.apache.org [140.211.11.3])
    by mx-eu-01.ponee.io (Postfix) with SMTP id AE813180672
    for ; Mon, 1 Apr 2019 23:20:55 +0200 (CEST)
Received: (qmail 31479 invoked by uid 500); 1 Apr 2019 21:20:49 -0000
Mailing-List: contact dev-help@madlib.apache.org; run by ezmlm
Precedence: bulk
List-Help:
List-Unsubscribe:
List-Post:
List-Id:
Reply-To: dev@madlib.apache.org
Delivered-To: mailing list dev@madlib.apache.org
Received: (qmail 30909 invoked by uid 99); 1 Apr 2019 21:20:49 -0000
Received: from ec2-52-202-80-70.compute-1.amazonaws.com (HELO gitbox.apache.org) (52.202.80.70)
    by apache.org (qpsmtpd/0.29) with ESMTP; Mon, 01 Apr 2019 21:20:49 +0000
From: GitBox
To: dev@madlib.apache.org
Subject: [GitHub] [madlib] kaknikhil commented on a change in pull request #361: Minibatch Preprocessor DL: Add optional num_classes param.
Message-ID: <155415364886.30588.16071196727354139615.gitbox@gitbox.apache.org>
Date: Mon, 01 Apr 2019 21:20:48 -0000
Content-Type: text/plain; charset=utf-8
Content-Transfer-Encoding: 8bit

kaknikhil commented on a change in pull request #361: Minibatch Preprocessor DL: Add optional num_classes param.
URL: https://github.com/apache/madlib/pull/361#discussion_r271054629

 ##########
 File path: src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
 ##########
 @@ -363,21 +365,70 @@ class MiniBatchPreProcessorDL(MiniBatchPreProcessor):
          self._validate_args()
          self.num_of_buffers = self._get_num_buffers()
 -        self.to_one_hot_encode = True
          if is_valid_psql_type(self.dependent_vartype, NUMERIC | ONLY_ARRAY):
              self.dependent_levels = None
          else:
              self.dependent_levels = get_distinct_col_levels(
 -                self.source_table, self.dependent_varname, self.dependent_vartype)
 +                self.source_table, self.dependent_varname,
 +                self.dependent_vartype, exclude_nulls=False)
 +            # if any class level was NULL in sql, that would show up as
 +            # None in self.dependent_levels. Replace all None with NULL
 +            # in the list.
 +            self.dependent_levels = ['NULL' if level is None else level
 +                for level in self.dependent_levels]
 +            self._validate_num_classes()
 +
 +    def _validate_num_classes(self):
 +        if self.num_classes is not None and \
 +            self.num_classes < len(self.dependent_levels):
 +            plpy.error("{0}: Invalid num_classes value specified. It must "\
 +                "be equal to or greater than distinct class values found "\
 +                "in table ({1}).".format(
 +                    self.module_name, len(self.dependent_levels)))
 +
 +    def get_dep_var_array_expr(self):
 +        """
 +        :param dependent_varname: Name of the dependent variable
 +        :param num_classes: Number of class values to consider in 1-hot
 +        :return:
 +        This function returns a tuple of
 +        1. A string with transformed dependent varname depending on it's type
 +        2. All the distinct dependent class levels encoded as a string
 +
 +        If dep_type == numeric[] , do not encode
 +                1. dependent_varname = rings
 +                    transformed_value = ARRAY[rings]
 +                2. dependent_varname = ARRAY[a, b, c]
 +                    transformed_value = ARRAY[a, b, c]
 +        else if dep_type in ("text", "boolean"), encode:
 +                3. dependent_varname = rings (encoding)
 +                    transformed_value = ARRAY[rings=1, rings=2, rings=3]
 +        """
 +        # Assuming the input NUMERIC[] is already one_hot_encoded,
 +        # so casting to INTEGER[]
 +        if is_valid_psql_type(self.dependent_vartype, NUMERIC | ONLY_ARRAY):
 +            return self.dependent_varname + '::INTEGER[]'
 +
 +        # For DL use case, we want to allow NULL as a valid class value,
 +        # so the query must have 'IS NOT DISTINCT FROM' instead of '='
 +        # like in the generic get_one_hot_encoded_expr() defined in
 +        # db_utils.py_in. We also have this optional 'num_classes' param
 +        # that affects the logic of 1-hot encoding. Since this is very
 +        # specific to minibatch_preprocessing_dl for now, let's keep
 +        # it here instead of refactoring it out to a generic helper function.
 +        one_hot_encoded_expr = ["({0}) IS NOT DISTINCT FROM {1}".format(
 +            self.dependent_varname, c) for c in self.dependent_levels]
 +        if self.num_classes:
 +            one_hot_encoded_expr.extend(['0'
 +                for i in range(self.num_classes-len(self.dependent_levels))])

 Review comment:
   Can we create a class variable for `self.num_classes-len(self.dependent_levels)` in the constructor so that it can be reused and is easy to read?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services
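
One way the review comment's suggestion might look is sketched below. It is not the PR's code: `OneHotExprBuilder` and `padding_size` are hypothetical names, `ValueError` stands in for `plpy.error` so the snippet runs outside the database, and the class levels are assumed to arrive as ready-to-use SQL literals. The method returns only the list of expressions, since the rest of `get_dep_var_array_expr` is not shown in the hunk.

```python
class OneHotExprBuilder(object):
    """Hypothetical stand-in for the relevant part of MiniBatchPreProcessorDL."""

    def __init__(self, dependent_varname, dependent_levels, num_classes=None):
        self.dependent_varname = dependent_varname
        # As in the diff: a NULL class level arrives as None and is
        # rendered as the SQL keyword NULL.
        self.dependent_levels = ['NULL' if level is None else level
                                 for level in dependent_levels]
        self.num_classes = num_classes
        self._validate_num_classes()
        # Computed once here so every later use reads the same named value
        # (hypothetical attribute, following the reviewer's suggestion).
        if self.num_classes is None:
            self.padding_size = 0
        else:
            self.padding_size = self.num_classes - len(self.dependent_levels)

    def _validate_num_classes(self):
        # ValueError replaces plpy.error for this standalone sketch.
        if self.num_classes is not None and \
                self.num_classes < len(self.dependent_levels):
            raise ValueError(
                "Invalid num_classes value specified. It must be equal to "
                "or greater than distinct class values found in table "
                "({0}).".format(len(self.dependent_levels)))

    def get_dep_var_array_expr(self):
        # IS NOT DISTINCT FROM treats NULL as an ordinary class value.
        exprs = ["({0}) IS NOT DISTINCT FROM {1}".format(
            self.dependent_varname, c) for c in self.dependent_levels]
        # Pad with constant zeros up to num_classes.
        exprs.extend(['0'] * self.padding_size)
        return exprs


builder = OneHotExprBuilder("rings", ["1", "2", None], num_classes=5)
print(builder.padding_size)
print(builder.get_dep_var_array_expr())
```

With three distinct levels and `num_classes=5`, the sketch pads the expression list with two `'0'` entries. Naming the difference once also gives validation and encoding a single place to agree on how many padding columns exist.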