singa-dev mailing list archives

From "Chris Yeung (Jira)" <>
Subject [jira] [Created] (SINGA-505) Buffer Operators / Change the Autograd operators to be bufferable
Date Fri, 31 Jan 2020 09:08:00 GMT
Chris Yeung created SINGA-505:

             Summary: Buffer Operators / Change the Autograd operators to be bufferable
                 Key: SINGA-505
             Project: Singa
          Issue Type: Improvement
          Components: Core
            Reporter: Chris Yeung

We can buffer the operators so that all the operators used in autograd can be extracted and,
once scheduled, assembled into a graph. The simplest scheduler can run the buffered operators
in FIFO order. A more complex scheduling algorithm could take the dependencies between
operators into account and run independent operators in parallel. Another clear advantage is
that running the graph only requires running the buffered operators that the autograd functions
called, so the autograd Python code does not need to run again during the training process.
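
As a rough illustration of the idea above (not the SINGA implementation), the sketch below
records operators in a buffer and replays them in FIFO order. Block, OpBuffer, submit, run
and levels are hypothetical names invented for this example; levels groups operators by
dependency depth, under the simplifying assumption that each block is written by at most one
operator.

```python
from collections import deque


class Block:
    """Hypothetical stand-in for a device memory block; holds one value."""

    def __init__(self, data=None):
        self.data = data


class OpBuffer:
    """Hypothetical buffer that records operators instead of running them."""

    def __init__(self):
        self.ops = deque()

    def submit(self, fn, reads, writes):
        # Record the operator with the blocks it reads and writes.
        self.ops.append((fn, reads, writes))

    def run(self):
        # Simplest schedule: replay in submission (FIFO) order.
        # The buffer is kept, so the same ops can be replayed every iteration.
        for fn, reads, writes in self.ops:
            fn(reads, writes)

    def levels(self):
        # Group ops by dependency depth; ops in one group do not depend on
        # each other (assumes each block has at most one writer).
        depth_of = {}  # id(block) -> level of the op that writes it
        groups = []
        for fn, reads, writes in self.ops:
            level = 1 + max((depth_of.get(id(b), -1) for b in reads), default=-1)
            for b in writes:
                depth_of[id(b)] = level
            while len(groups) <= level:
                groups.append([])
            groups[level].append(fn.__name__)
        return groups


def add_op(reads, writes):
    writes[0].data = reads[0].data + reads[1].data


def scale_op(reads, writes):
    writes[0].data = reads[0].data * 2


def neg_op(reads, writes):
    writes[0].data = -reads[0].data


x, y, tmp, out, z = Block(1.0), Block(2.0), Block(), Block(), Block()
buf = OpBuffer()
buf.submit(add_op, [x, y], [tmp])
buf.submit(scale_op, [tmp], [out])
buf.submit(neg_op, [y], [z])
nothing_ran_yet = out.data is None  # ops are only recorded at this point

buf.run()
first_result = out.data

# A second iteration only updates the input block and replays the buffer;
# the Python code that built the ops is not executed again.
x.data = 10.0
buf.run()
second_result = out.data
```

Here add_op and neg_op land in the same dependency level, so a scheduler could run them in
parallel, while scale_op has to wait for add_op.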

This ticket therefore has two purposes:

1. Change the core components (e.g. tensor, device) to support buffering.

2. Change all the autograd operators to be bufferable, i.e. their inputs and outputs should be
inside the block. For example, the SoftMax backward cannot be buffered because it does not
operate through the block; it uses numpy instead:
    def backward(self, dy):
        # calculations are made on numpy array
        if self.axis == 1:
            dy = singa.DefaultTranspose(dy)
        grad = ctensor2numpy(dy)
        output = ctensor2numpy(self.output)
        out_1 = np.einsum("ki,ki->ki", grad, output)
        medium_out = np.einsum("ki,kj->kij", output, output)
        out_2 = np.einsum("kij,kj->ki", medium_out, grad)
        out = out_1 - out_2
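        dx = CTensor(out_1.shape)
        # copy the numpy result into dx; otherwise `out` is never used
        dx.CopyFloatDataFromHostPtr(out.flatten())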
        if self.axis == 0:
            return dx
        elif self.axis == 1:
            return singa.DefaultTranspose(dx)
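
One possible route to a bufferable version, offered only as a sketch: the three einsum calls
above reduce to an elementwise product, a row-wise sum and a subtraction, which are the kind
of operations that could be expressed on tensors instead of numpy arrays. The code below uses
numpy purely to check that algebra; the function names are hypothetical and no particular
SINGA tensor API is assumed.

```python
import numpy as np


def softmax_backward_einsum(grad, output):
    # Same computation as the numpy-based backward quoted above.
    out_1 = np.einsum("ki,ki->ki", grad, output)
    medium_out = np.einsum("ki,kj->kij", output, output)
    out_2 = np.einsum("kij,kj->ki", medium_out, grad)
    return out_1 - out_2


def softmax_backward_elementwise(grad, output):
    # dx = y * (dy - sum_j(dy_j * y_j)), using only elementwise multiply,
    # a row-wise sum and a subtraction.
    row_sum = np.sum(grad * output, axis=1, keepdims=True)
    return output * (grad - row_sum)


rng = np.random.default_rng(0)
logits = rng.standard_normal((4, 5))
exp = np.exp(logits)
output = exp / exp.sum(axis=1, keepdims=True)  # softmax over each row
grad = rng.standard_normal((4, 5))

results_match = np.allclose(
    softmax_backward_einsum(grad, output),
    softmax_backward_elementwise(grad, output),
)
```

The elementwise form also avoids materialising the (k, i, j) intermediate that medium_out
creates.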

This message was sent by Atlassian Jira
