madlib-dev mailing list archives

From GitBox <...@apache.org>
Subject [GitHub] [madlib] kaknikhil commented on a change in pull request #361: Minibatch Preprocessor DL: Add optional num_classes param.
Date Mon, 01 Apr 2019 21:20:48 GMT
kaknikhil commented on a change in pull request #361: Minibatch Preprocessor DL: Add optional num_classes param.
URL: https://github.com/apache/madlib/pull/361#discussion_r271054629
 
 

 ##########
 File path: src/ports/postgres/modules/utilities/minibatch_preprocessing.py_in
 ##########
 @@ -363,21 +365,70 @@ class MiniBatchPreProcessorDL(MiniBatchPreProcessor):
 
         self._validate_args()
         self.num_of_buffers = self._get_num_buffers()
-        self.to_one_hot_encode = True
         if is_valid_psql_type(self.dependent_vartype, NUMERIC | ONLY_ARRAY):
             self.dependent_levels = None
         else:
             self.dependent_levels = get_distinct_col_levels(
-                self.source_table, self.dependent_varname, self.dependent_vartype)
+                self.source_table, self.dependent_varname,
+                self.dependent_vartype, exclude_nulls=False)
+            # if any class level was NULL in sql, that would show up as
+            # None in self.dependent_levels. Replace all None with NULL
+            # in the list.
+            self.dependent_levels = ['NULL' if level is None else level
+                for level in self.dependent_levels]
+            self._validate_num_classes()
+
+    def _validate_num_classes(self):
+        if self.num_classes is not None and \
+            self.num_classes < len(self.dependent_levels):
+            plpy.error("{0}: Invalid num_classes value specified. It must "\
+                "be equal to or greater than distinct class values found "\
+                "in table ({1}).".format(
+                    self.module_name, len(self.dependent_levels)))
+
+    def get_dep_var_array_expr(self):
+        """
+        :param dependent_varname: Name of the dependent variable
+        :param num_classes: Number of class values to consider in 1-hot
+        :return:
+            This function returns a tuple of
+            1. A string with transformed dependent varname depending on its type
+            2. All the distinct dependent class levels encoded as a string
+
+            If dep_type == numeric[] , do not encode
+                    1. dependent_varname = rings
+                        transformed_value = ARRAY[rings]
+                    2. dependent_varname = ARRAY[a, b, c]
+                        transformed_value = ARRAY[a, b, c]
+            else if dep_type in ("text", "boolean"), encode:
+                    3. dependent_varname = rings (encoding)
+                        transformed_value = ARRAY[rings=1, rings=2, rings=3]
+        """
+        # Assuming the input NUMERIC[] is already one_hot_encoded,
+        # so casting to INTEGER[]
+        if is_valid_psql_type(self.dependent_vartype, NUMERIC | ONLY_ARRAY):
+            return self.dependent_varname + '::INTEGER[]'
+
+        # For DL use case, we want to allow NULL as a valid class value,
+        # so the query must have 'IS NOT DISTINCT FROM' instead of '='
+        # like in the generic get_one_hot_encoded_expr() defined in
+        # db_utils.py_in. We also have this optional 'num_classes' param
+        # that affects the logic of 1-hot encoding. Since this is very
+        # specific to minibatch_preprocessing_dl for now, let's keep
+        # it here instead of refactoring it out to a generic helper function.
+        one_hot_encoded_expr = ["({0}) IS NOT DISTINCT FROM {1}".format(
+            self.dependent_varname, c) for c in self.dependent_levels]
+        if self.num_classes:
+            one_hot_encoded_expr.extend(['0'
+                for i in range(self.num_classes-len(self.dependent_levels))])
 
 Review comment:
   Can we create a class variable for `self.num_classes-len(self.dependent_levels)` in the constructor, so that it can be reused and is easier to read?
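
   A minimal sketch of what that could look like, written as standalone Python rather than against the MADlib code base. The class name `OneHotExprBuilder`, the attribute name `padding_size`, and the method name `get_one_hot_encoded_exprs` are hypothetical, `ValueError` stands in for `plpy.error`, and the class levels are assumed to arrive already formatted as SQL literals:

```python
class OneHotExprBuilder(object):
    """Hypothetical stand-in for MiniBatchPreProcessorDL; not MADlib code."""

    def __init__(self, dependent_varname, dependent_levels, num_classes=None):
        self.dependent_varname = dependent_varname
        # As in the diff: a NULL class level shows up as None and is
        # rendered as the SQL keyword NULL.
        self.dependent_levels = ['NULL' if level is None else level
                                 for level in dependent_levels]
        self.num_classes = num_classes
        self._validate_num_classes()
        # Computed once here so later code can reuse a named value instead
        # of repeating num_classes - len(dependent_levels).
        if num_classes is None:
            self.padding_size = 0
        else:
            self.padding_size = num_classes - len(self.dependent_levels)

    def _validate_num_classes(self):
        # ValueError is an assumption standing in for plpy.error.
        if self.num_classes is not None and \
                self.num_classes < len(self.dependent_levels):
            raise ValueError(
                "Invalid num_classes value specified. It must be equal to "
                "or greater than distinct class values found in table "
                "({0}).".format(len(self.dependent_levels)))

    def get_one_hot_encoded_exprs(self):
        # IS NOT DISTINCT FROM lets NULL match as a valid class value.
        exprs = ["({0}) IS NOT DISTINCT FROM {1}".format(
            self.dependent_varname, c) for c in self.dependent_levels]
        # Pad with constant zeros up to num_classes entries.
        exprs.extend(['0'] * self.padding_size)
        return exprs


builder = OneHotExprBuilder("rings", ["'a'", None], num_classes=4)
exprs = builder.get_one_hot_encoded_exprs()
```

   With two distinct levels and `num_classes=4`, `padding_size` is 2 and the list ends with two `'0'` entries. When `num_classes` is omitted, `padding_size` is 0 and no padding is added.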

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services
