madlib-dev mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From GitBox <...@apache.org>
Subject [GitHub] [madlib] reductionista commented on a change in pull request #362: DL: Remove num_classes param from madlib_keras_fit()
Date Thu, 04 Apr 2019 22:48:56 GMT
reductionista commented on a change in pull request #362: DL: Remove num_classes param from madlib_keras_fit()
URL: https://github.com/apache/madlib/pull/362#discussion_r272395326
 
 

 ##########
 File path: src/ports/postgres/modules/utilities/model_arch_info.py_in
 ##########
 @@ -21,69 +21,45 @@ m4_changequote(`<!', `!>')
 
 import sys
 import json
+import plpy
 
-def get_layers(arch):
-    d = json.loads(arch)
+def _get_layers(model_arch):
+    d = json.loads(model_arch)
     config = d['config']
     if type(config) == list:
-        return config  # In keras 1.x, all models are sequential
+        return config  # In keras 2.1.x, all models are sequential
     elif type(config) == dict and 'layers' in config:
         layers = config['layers']
         if type(layers) == list:
             return config['layers']  # In keras 2.x, only sequential models are supported
-    plpy.error('Unable to read input_shape from keras model arch.  Note: only sequential keras models are supported.')
-    return None
+    plpy.error("Unable to read model architecture JSON.")
 
-def get_input_shape(arch):
-    layers = get_layers(arch)
-    return layers[0]['config']['batch_input_shape'][1:]
+def get_input_shape(model_arch):
+    arch_layers = _get_layers(model_arch)
+    if 'batch_input_shape' in arch_layers[0]['config']:
+        return arch_layers[0]['config']['batch_input_shape'][1:]
+    plpy.error('Unable to get input shape from model architecture.')
 
-def print_model_arch_layers(arch):
-    layers = get_layers(arch)
+def get_num_classes(model_arch):
+    arch_layers = _get_layers(model_arch)
+    if 'units' in arch_layers[-1]['config']:
+        return arch_layers[-1]['config']['units']
+    plpy.error('Unable to get number of classes from model architecture.')
 
-    print("\nModel arch layers:")
+def get_model_arch_layers_str(model_arch):
+    arch_layers = _get_layers(model_arch)
+    layers = "Model arch layers:\n"
     first = True
-    for layer in layers:
+    for layer in arch_layers:
         if first:
             first = False
         else:
-            print("   |")
-            print("   V")
+            layers = "{0}   |\n".format(layers)
+            layers = "{0}   V\n".format(layers)
         class_name = layer['class_name']
         config = layer['config']
         if class_name == 'Dense':
-            print("{0}[{1}]".class_name)
+            layers = "{0}{1}[{2}]\n".format(layers, config, class_name)
 
 Review comment:
   This prints out a bunch of JSON. This line looks like it was already broken before this change, but I think it should be:
   ```
   layers += '{0}[{1}]'.format(class_name,layer['units'])
   ```

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Mime
View raw message