singa-dev mailing list archives

From GitBox <>
Subject [GitHub] [singa] nudles commented on issue #691: Add save and load method for Module class
Date Fri, 15 May 2020 14:34:29 GMT

nudles commented on issue #691:

   **Updated on the night of May 15**
    class Layer:
        def get_params(self):
            """Return the params of this layer and its sublayers as a dict.

            Each key is `layername.param`; e.g., for self.W = Tensor() and
            self.b = Tensor() in a layer named conv1, the keys are conv1.W
            and conv1.b.
            """

        def get_states(self):
            """Return the states of this layer and its sublayers that are
            necessary for model training/evaluation/inference.

            The states include the params and other variables, e.g., the
            running mean and var of batch normalization layers.
            """
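The dotted naming used by `get_params` (e.g. `conv1.W`, or `blk1.conv1.W` for nested blocks) can be sketched by walking a layer's attributes recursively. This is a minimal sketch under assumptions: the `Layer`, `Conv2d`, `Blk` and `MyModel` classes are hypothetical toys (not SINGA's classes), numpy arrays stand in for `Tensor`, and names are derived from attribute names at call time rather than assigned once by `compile`.

```python
import numpy as np


class Layer:
    """Toy stand-in for a layer (hypothetical, not SINGA's Layer class).

    numpy arrays play the role of Tensor params, and any attribute that is
    itself a Layer is treated as a sublayer.
    """

    def get_params(self, prefix=""):
        """Return {dotted_name: array} for this layer and all its sublayers."""
        params = {}
        for attr, value in vars(self).items():
            name = prefix + attr
            if isinstance(value, np.ndarray):
                params[name] = value
            elif isinstance(value, Layer):
                # recurse with an extended prefix, so names become e.g. blk1.conv1.W
                params.update(value.get_params(prefix=name + "."))
        return params


class Conv2d(Layer):
    def __init__(self, channels=8, kernel_size=3):
        self.W = np.zeros((channels, channels, kernel_size, kernel_size), dtype=np.float32)
        self.b = np.zeros(channels, dtype=np.float32)


class Blk(Layer):
    def __init__(self):
        self.conv1 = Conv2d()
        self.conv2 = Conv2d()


class MyModel(Layer):
    def __init__(self):
        self.blk1 = Blk()
        self.blk2 = Blk()
```

Here `list(MyModel().get_params())` yields the keys from `blk1.conv1.W` through `blk2.conv2.b`, in attribute-definition order. Assigning the names once in `compile`, as proposed, avoids re-walking the layer tree on every call.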
    class Module(Layer):
        def compile(self, ...):
            """Set the name of each layer and its sublayers; the names are used
            as the keys of the dicts returned by get_params and get_states.
            There is then no need to configure layer names manually in the
            __init__ method of a layer. For instance,

                class Blk(Layer):
                    def __init__(self):
                        self.conv1 = Conv2d()
                        self.conv2 = Conv2d()

                class MyModel(Module):
                    def __init__(self):
                        self.blk1 = Blk()  # --> blk1.conv1, blk1.conv2
                        self.blk2 = Blk()  # --> blk2.conv1, blk2.conv2
            """
        # high priority
        def save_states(self, fpath, aux_states=None):
            """Save states.

            fpath: output file path (without the extension)
            aux_states (dict): values are standard data types or Tensor,
                e.g., epoch ID, learning rate, optimizer states
            """
            states = {**self.get_states(), **(aux_states or {}), **input_placeholders}
            tensor_dict = {}
            for k, v in states.items():
                if isinstance(v, Tensor):
                    tensor_dict[k] = v
                    states[k] = {'shape': v.shape, 'dtype': v.dtype}
            # save states as a JSON file
            # save tensor_dict via numpy, HDF5 or protobuf
            # zip the output files

        def load_states(self, fpath, dev, use_graph=True, graph_alg='sequence'):
            """Load the model onto dev.

            fpath: input file path (without the extension)
            """
            # unzip the input file
            # load the JSON file --> states
            # load the tensor files --> tensor_dict
            # put the tensors into states
            # states --> model_states + input_placeholders + aux_states
            self.compile(input_placeholders, dev, use_graph, graph_alg)
            # return the remaining states (aux_states) as a dict
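One possible on-disk layout for `save_states`/`load_states` is a zip archive holding a JSON file (non-tensor states plus a `{'shape', 'dtype'}` placeholder per tensor) and an NPZ file with the tensor values. This is a sketch under assumptions: numpy arrays stand in for `Tensor`, the states dict is passed in rather than obtained from `self.get_states()`, the member names `states.json` and `tensors.npz` are arbitrary choices, and input placeholders, device placement and `compile` are left out.

```python
import io
import json
import zipfile

import numpy as np


def save_states(fpath, states, aux_states=None):
    """Write states and aux_states to '<fpath>.zip' as JSON metadata + NPZ tensors."""
    merged = {**states, **(aux_states or {})}
    meta, tensors = {}, {}
    for k, v in merged.items():
        if isinstance(v, np.ndarray):
            tensors[k] = v
            # the JSON file keeps only a placeholder describing the tensor
            meta[k] = {"shape": list(v.shape), "dtype": str(v.dtype)}
        else:
            meta[k] = v  # must be JSON-serializable, e.g. int, float, str
    buf = io.BytesIO()
    np.savez(buf, **tensors)
    with zipfile.ZipFile(fpath + ".zip", "w") as zf:
        zf.writestr("states.json", json.dumps(meta))
        zf.writestr("tensors.npz", buf.getvalue())


def load_states(fpath):
    """Inverse of save_states: return one dict with the tensors put back."""
    with zipfile.ZipFile(fpath + ".zip") as zf:
        meta = json.loads(zf.read("states.json"))
        with np.load(io.BytesIO(zf.read("tensors.npz"))) as npz:
            tensors = {k: npz[k] for k in npz.files}
    for k, arr in tensors.items():
        # sanity check against the placeholder written by save_states
        if list(arr.shape) != meta[k]["shape"] or str(arr.dtype) != meta[k]["dtype"]:
            raise ValueError(f"tensor {k!r} does not match its metadata")
        meta[k] = arr
    return meta
```

Keeping the placeholders in the JSON file means the metadata can be inspected without loading any tensor data.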
    # lower priority
    def save(fpath, model):
        """Save the whole model, i.e., all of its attributes."""
        # attributes <-- model
        # replace all tensors in attributes --> {'shape': v.shape, 'dtype': v.dtype}
        # dump the tensors via numpy, protobuf or HDF5
        # dump the model via pickle
        # zip the output files

    def load(fpath, dev, use_graph, graph_alg):
        """Load a model saved by save()."""
        # unzip the input file
        # load the model via pickle
        # load the tensors
        # restore the tensors in the model attributes
        # return the model
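The lower-priority `save`/`load` pair, which pickles the whole model, could be sketched as follows. Assumptions: numpy arrays stand in for `Tensor`, only top-level array attributes are swapped for placeholders (arrays inside sublayers are simply pickled along with them), `TinyModel` is a hypothetical example class, and the `dev`, `use_graph` and `graph_alg` arguments of `load` are omitted.

```python
import copy
import io
import pickle
import zipfile

import numpy as np


def save(fpath, model):
    """Pickle `model` with its array attributes swapped for placeholders."""
    skeleton = copy.copy(model)  # shallow copy, so `model` itself is not modified
    tensors = {}
    for k, v in vars(model).items():
        if isinstance(v, np.ndarray):
            tensors[k] = v
            setattr(skeleton, k, {"shape": v.shape, "dtype": str(v.dtype)})
    buf = io.BytesIO()
    np.savez(buf, **tensors)
    with zipfile.ZipFile(fpath + ".zip", "w") as zf:
        zf.writestr("model.pkl", pickle.dumps(skeleton))
        zf.writestr("tensors.npz", buf.getvalue())


def load(fpath):
    """Unpickle the model skeleton, then restore its array attributes."""
    with zipfile.ZipFile(fpath + ".zip") as zf:
        model = pickle.loads(zf.read("model.pkl"))
        with np.load(io.BytesIO(zf.read("tensors.npz"))) as npz:
            for k in npz.files:
                setattr(model, k, npz[k])
    return model


class TinyModel:
    """Hypothetical model with one tensor attribute and one plain attribute."""

    def __init__(self):
        self.W = np.arange(6.0).reshape(2, 3)
        self.lr = 0.1
```

Unpickling can execute arbitrary code, so `load` should only be used on trusted files; the tensors themselves are read through `np.load`, which refuses pickled object arrays by default.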
    # handle ONNX
    def to_onnx(model):
        """Return an ONNX model."""

    class SONNXModel(Module):
        def __init__(self, onnx_model):
            self.store_output = store_output
            for layer_name, layer_config in get_layer(onnx_model):
                setattr(self, layer_name, CreateLayer(...))

        def forward(self, x, aux_output):
            # run forward according to the ONNX graph
            # return the last output + aux_output

    class MyModel(SONNXModel):
        def __init__(self, onnx):
            super().__init__(onnx)
            self.layer1 = Conv()
            self.layer2 = Conv()

        def forward(self, x):
            x1, x2 = super().forward(x, aux_output)
            x = self.layer1.forward(x2)
            return self.layer2.forward(x1) + x

        def train_one_batch(self, x, y):
            y_ = self.forward(x)
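`SONNXModel.forward` is described as running the ONNX graph and returning the last output plus `aux_output`. A toy version keeps every intermediate result in a dict keyed by output name, which makes returning extra named outputs cheap. The node-tuple format, the `OPS` table and the `GraphModel` class are hypothetical; a real importer would build layers from the ONNX graph instead.

```python
import numpy as np

# Hypothetical op table; a real importer would create SINGA layers instead.
OPS = {
    "Neg": lambda x: -x,
    "Relu": lambda x: np.maximum(x, 0),
    "Add": lambda a, b: a + b,
}


class GraphModel:
    """Toy stand-in for SONNXModel.forward (not the SINGA implementation).

    `nodes` is a list of (output_name, op_type, input_names) tuples that is
    already topologically sorted, as the ONNX spec requires for graph nodes.
    """

    def __init__(self, nodes, input_name):
        self.nodes = nodes
        self.input_name = input_name

    def forward(self, x, aux_output=()):
        values = {self.input_name: x}  # every intermediate output, by name
        for out, op, inputs in self.nodes:
            values[out] = OPS[op](*(values[name] for name in inputs))
        last = values[self.nodes[-1][0]]
        # return the last output followed by the requested intermediate outputs
        return (last, *(values[name] for name in aux_output))


model = GraphModel(
    [("h", "Neg", ["x"]), ("r", "Relu", ["h"]), ("y", "Add", ["r", "x"])],
    input_name="x",
)
y, h = model.forward(np.array([-1.0, 2.0]), aux_output=["h"])
```

With this three-node graph, `y` is `[0., 2.]` and the auxiliary output `h` is `[1., -2.]`, which is the two-value shape that `MyModel.forward` unpacks.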
   * Params: layer parameters (Tensors) that are updated via SGD; returned by `Layer.get_params()`.
   * States: params plus other variables that are necessary for model evaluation/inference; a superset of params, returned by `Layer.get_states()`.
   * Attributes: members of a class instance, i.e., `instance.__dict__`; a superset of states.
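The three tiers can be illustrated with a batch-normalization-style layer: the learned scale and bias are params, the running mean and variance are extra states needed for inference, and a hyperparameter such as momentum is only an attribute. `BatchNorm2d` here is a hypothetical toy class built on numpy arrays, not SINGA's layer.

```python
import numpy as np


class BatchNorm2d:
    """Toy layer (hypothetical) showing params < states < attributes."""

    def __init__(self, channels, momentum=0.9):
        # params: updated by SGD
        self.scale = np.ones(channels)
        self.bias = np.zeros(channels)
        # extra states: not trained, but needed for evaluation/inference
        self.running_mean = np.zeros(channels)
        self.running_var = np.ones(channels)
        # plain attribute: configuration, neither a param nor a state
        self.momentum = momentum

    def get_params(self):
        return {"scale": self.scale, "bias": self.bias}

    def get_states(self):
        states = dict(self.get_params())
        states.update(running_mean=self.running_mean, running_var=self.running_var)
        return states
```

For `bn = BatchNorm2d(4)`, the key sets satisfy `set(bn.get_params()) < set(bn.get_states()) < set(vars(bn))`.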

This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
