singa-dev mailing list archives

From GitBox <...@apache.org>
Subject [GitHub] [singa] joddiy commented on issue #691: Add save and load method for Module class
Date Fri, 15 May 2020 10:00:29 GMT

joddiy commented on issue #691:
URL: https://github.com/apache/singa/issues/691#issuecomment-629147694


   ## Conclusion first
   
   Good news: 
   > ONNX can now define the loss and the optimizer within its format. However, the only supported loss operators are `NegativeLogLikelihoodLoss` and `SoftmaxCrossEntropyLoss`, and the only supported optimizer operators are `Adagrad`, `Adam`, and `Momentum` (SGD with standard momentum). 
   
   Bad news:
   > We need to upgrade ONNX to 1.7, which was released last week and may not be very stable yet. In this release, ONNX defines a fairly complicated node called `GraphCall` to specify which gradients should be computed and how to update the tensors using those gradients. Since we update the weights right after the backward pass, this part may not be useful for us.
   
   ## ONNX Training Preview (TrainingInfoProto)
   
   Last week, the ONNX team released a new version, [1.7.0](https://github.com/onnx/onnx/releases/tag/v1.7.0), which upgrades its opset version to 12. In this new release, they added a new feature called [`TrainingInfoProto`](https://github.com/onnx/onnx/blob/3368834cf0b1f0ab9838cf6bdf78a27299d08187/onnx/onnx.proto3#L211-L316).

   
   This new feature describes training information. It has two main parts: `initialization-step` and `training-algorithm-step`.
   
   ### initialization-step
   
   `initialization-step` means the developer can define an `initialization`. By its type, the `initialization` is a formal ONNX graph. It has no inputs but several outputs. The developer can define nodes in this graph, such as `RandomNormal` or `RandomUniform`, and in another field called `initialization_binding`, assign these outputs to specific tensors in the inference graph.
   
   The currently supported random methods are `RandomNormal` and `RandomUniform`.
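   
   As a rough sketch (assuming onnx >= 1.7 and its Python protobuf bindings; the tensor name `W` and its shape are hypothetical), an `initialization` graph plus an `initialization_binding` entry could be assembled like this:
   
   ```python
   import onnx
   from onnx import helper, TensorProto
   
   # Hypothetical: re-initialize a weight tensor "W" of shape (64, 3, 3, 3)
   # with a RandomNormal node inside the initialization graph.
   init_out = helper.make_tensor_value_info("W_init", TensorProto.FLOAT, [64, 3, 3, 3])
   random_node = helper.make_node(
       "RandomNormal", inputs=[], outputs=["W_init"],
       dtype=TensorProto.FLOAT, mean=0.0, scale=1.0, shape=[64, 3, 3, 3])
   
   # The initialization graph has no inputs, only outputs.
   init_graph = helper.make_graph([random_node], "init_graph", inputs=[], outputs=[init_out])
   
   training_info = onnx.TrainingInfoProto()
   training_info.initialization.CopyFrom(init_graph)
   # Bind the initialization output "W_init" to the inference-graph tensor "W".
   binding = training_info.initialization_binding.add()
   binding.key, binding.value = "W", "W_init"
   ```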
   
   ### training-algorithm-step
   
   `training-algorithm-step` defines a field called `algorithm`. It is an inference graph that represents one step of a training algorithm. Given the required inputs, it computes outputs that update tensors in its own graph or in the main computation graph. `update_binding` contains key-value pairs of strings that assign these outputs to specific tensors.
   
   In general, this graph contains a loss node, a gradient node, an optimizer node, an increment of the iteration count, and some calls to the inference graph. The field `algorithm.node` is the only place where the user can use the `GraphCall` operator. 
   
   #### Loss node
   
   - `NegativeLogLikelihoodLoss`
   - `SoftmaxCrossEntropyLoss`
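   
   A minimal sketch of creating one of these loss nodes with the ONNX Python helpers (the tensor names `scores`, `labels`, and `loss` are hypothetical):
   
   ```python
   from onnx import helper
   
   # "scores" are the raw logits from the inference graph, "labels" are the
   # target class indices; "reduction" may be "none", "sum", or "mean".
   loss_node = helper.make_node(
       "SoftmaxCrossEntropyLoss",
       inputs=["scores", "labels"],
       outputs=["loss"],
       reduction="mean")
   ```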
   
   
   #### Optimizer node
   
   - `Adagrad`
   - `Adam`
   - `Momentum`: SGD with standard momentum
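   
   A sketch of a `Momentum` node; these training operators live in the preview domain `ai.onnx.preview.training`, and, as far as I understand, the variadic inputs list the optimized tensors, then their gradients, then their momentum buffers (all names below are hypothetical):
   
   ```python
   from onnx import helper
   
   # Single-tensor case: update weight "W" using its gradient "dY_dW" and
   # momentum buffer "V", given learning rate "R" and iteration count "T".
   momentum_node = helper.make_node(
       "Momentum",
       inputs=["R", "T", "W", "dY_dW", "V"],
       outputs=["W_new", "V_new"],
       domain="ai.onnx.preview.training",
       alpha=0.9,              # momentum decay coefficient
       beta=1.0,               # gradient scaling coefficient
       mode="standard",        # standard (not Nesterov) momentum
       norm_coefficient=0.0)   # L2 regularization strength
   ```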
   
   #### Gradient node
   
   The gradient node only defines the information necessary to compute the gradients for the whole graph. For example, in the following graph, the gradient node defines its inputs as `xs` (intermediate weights), `zs` (inputs of the graph), and `y` (the output of the graph), and its outputs as `dY/dW` and `dY/dZ`, whose order corresponds to the inputs in `xs`. 
   
   It doesn't define any logic about how to compute `dY/dW` and `dY/dZ`.
   
   ```
   W --> Conv --> H --> Gemm --> Y
   |      ^              ^
   |      |              |
   |      X              Z
   |      |              |
   |      |   .----------'
   |      |   |  (W/Z/X is the 1st/2nd/3rd input of Gradient as shown in
   |      |   |   "xs" followed by "zs")
   |      v   v
   '---> Gradient(xs=["W", "Z"], zs=["X"], y="Y")
          |   |
          |   '-----------------------------------> dY/dW (1st output of Gradient)
          |
          '---------------------------------------> dY/dZ (2nd output of Gradient)
   ```
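   
   The same `Gradient` node from the diagram, expressed with the ONNX Python helpers (the operator sits in the preview training domain; the output names are hypothetical stand-ins for `dY/dW` and `dY/dZ`):
   
   ```python
   from onnx import helper
   
   # Gradient declares *what* to differentiate, not how: xs are the tensors to
   # differentiate with respect to, zs are further inputs, y is the target output.
   gradient_node = helper.make_node(
       "Gradient",
       inputs=["W", "Z", "X"],        # values of the tensors named in xs, then zs
       outputs=["dY_dW", "dY_dZ"],    # one gradient per entry in xs, same order
       domain="ai.onnx.preview.training",
       xs=["W", "Z"],
       zs=["X"],
       y="Y")
   ```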
   
   #### GraphCall node
   
   The `GraphCall` operator invokes a graph stored inside `TrainingInfoProto`'s `algorithm` field. The inputs and outputs of `GraphCall` are bound to those of the invoked graph by position.
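   
   As a sketch, such a call node might be constructed like this (the graph name and tensor names follow the diagram below; positional binding means the i-th input/output of the call maps to the i-th input/output of the invoked graph):
   
   ```python
   from onnx import helper
   
   # Invoke the inference graph "MyInferenceGraph" on a fresh batch (X_1, Z_1);
   # inputs and outputs are matched positionally to the called graph's signature.
   call_node = helper.make_node(
       "GraphCall",
       inputs=["X_1", "Z_1"],
       outputs=["Y_1"],
       domain="ai.onnx.preview.training",
       graph_name="MyInferenceGraph")
   ```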
   
   Based on the above inference graph, `GraphCall` can be used like this:
   
   ```
   .-------- W (a global and mutable variable from
   |         |  the inference graph)
   |         |
   |   .-----'-----------.
   |   |                 |
   |   |                 v
   |   | .-- X_1 --> GraphCall(graph_name="MyInferenceGraph")
   |   | |            |  |
   |   | |            |  |
   |   | |   Z_1 -----'  |
   |   | |    |          V
   |   | |    |         Y_1 ---> Loss ---> O
   |   | |    |                    ^
   |   | |    |                    |
   |   | `--. |                    C
   |   |    | |                    |
   |   |    | |   .----------------'
   |   |    | |   |
   |   |    v v   v
   |   `--> Gradient(xs=["W"], zs=["X_1", "Z_1", "C"], y="O")
   |        |
   |        v
   |      dO_dW (gradient of W)      1 (a scalar one)
   |        |                        |
   |        V                        v
   |       Div <--- T ------------> Add ---> T_new
   |        |    (T is the number of training iterations.
   |        |     T is also globally visible and mutable.)
   |        v
   `-----> Sub ----> W_new
   ```
   
   The previous section's inference graph is called by `GraphCall(graph_name="MyInferenceGraph")`,
and it uses a new batch of inputs (`X_1`, `Z_1`) to compute `Y_1`. 
   
   `Gradient` defines the gradients the graph should compute; finally, it produces `W_new` and `T_new`.
   
   Then it uses the following `update_binding` to update the tensors:
   
   ```
   update_binding: {"W": "W_new", "T": "T_new"}
   ```
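   
   In the Python protobuf API, populating `update_binding` might look like the sketch below. The algorithm graph here is deliberately tiny (only the `GraphCall` and loss nodes; the `Gradient` and `Momentum` nodes that would actually produce `W_new` and `T_new` are omitted for brevity), and all shapes are hypothetical:
   
   ```python
   import onnx
   from onnx import helper, TensorProto
   
   call_node = helper.make_node("GraphCall", ["X_1", "Z_1"], ["Y_1"],
                                domain="ai.onnx.preview.training",
                                graph_name="MyInferenceGraph")
   loss_node = helper.make_node("SoftmaxCrossEntropyLoss", ["Y_1", "C"], ["O"],
                                reduction="mean")
   
   x1 = helper.make_tensor_value_info("X_1", TensorProto.FLOAT, [1, 3, 224, 224])
   z1 = helper.make_tensor_value_info("Z_1", TensorProto.FLOAT, [1, 10])
   c = helper.make_tensor_value_info("C", TensorProto.INT64, [1])
   o = helper.make_tensor_value_info("O", TensorProto.FLOAT, [])
   algorithm = helper.make_graph([call_node, loss_node], "training_step",
                                 inputs=[x1, z1, c], outputs=[o])
   
   training_info = onnx.TrainingInfoProto()
   training_info.algorithm.CopyFrom(algorithm)
   
   # update_binding: {"W": "W_new", "T": "T_new"} expressed as repeated
   # StringStringEntryProto entries (key = inference-graph tensor, value = algorithm output).
   for key, value in {"W": "W_new", "T": "T_new"}.items():
       entry = training_info.update_binding.add()
       entry.key, entry.value = key, value
   ```
   
   A `ModelProto` at IR version 7 carries these records in its repeated `training_info` field, so the final step would be something like `model.training_info.append(training_info)`.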




