From: wangsh@apache.org
To: commits@singa.incubator.apache.org
Reply-To: dev@singa.incubator.apache.org
Date: Fri, 03 Jun 2016 07:48:36 -0000
Message-Id: <59c5ff74b424449e898d607287f52330@git.apache.org>
Subject: [31/60] incubator-singa git commit: SINGA-167 - Add Tensor Math function APIs

SINGA-167 - Add Tensor Math function APIs

Add basic linalg functions for Tensor
Add blas functions for Tensor.
Unify gemm and gemv in Tensor::Mult

this commit also contains code for Param class, which would be removed in the next commit.
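As a rough illustration of what "unify gemm and gemv in Tensor::Mult" could mean, the sketch below routes a single Mult entry point to either a matrix-vector or a matrix-matrix kernel depending on the rank of the right-hand operand. It is not the code from this commit: MiniTensor, Gemv and Gemm are hypothetical stand-ins written for this note, the kernels are naive loops rather than BLAS calls, and only the argument order of Mult follows the declarations in tensor.h.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for singa::Tensor: a shape plus row-major float data.
struct MiniTensor {
  std::vector<int> shape;
  std::vector<float> data;
};

// y = alpha * A * x + beta * y, with A of shape m x n (matrix-vector case).
void Gemv(int m, int n, float alpha, const float* A, const float* x,
          float beta, float* y) {
  for (int i = 0; i < m; ++i) {
    float acc = 0.0f;
    for (int j = 0; j < n; ++j) acc += A[i * n + j] * x[j];
    y[i] = alpha * acc + beta * y[i];
  }
}

// C = alpha * A * B + beta * C, with A of shape m x k and B of shape k x n.
void Gemm(int m, int n, int k, float alpha, const float* A, const float* B,
          float beta, float* C) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int p = 0; p < k; ++p) acc += A[i * k + p] * B[p * n + j];
      C[i * n + j] = alpha * acc + beta * C[i * n + j];
    }
  }
}

// Single entry point: choose GEMV or GEMM from the rank of rhs.
// ret must already hold the right number of elements.
void Mult(float alpha, const MiniTensor& lhs, float beta,
          const MiniTensor& rhs, MiniTensor* ret) {
  assert(lhs.shape.size() == 2);
  const int m = lhs.shape[0], k = lhs.shape[1];
  assert(!rhs.shape.empty() && rhs.shape[0] == k);
  if (rhs.shape.size() == 1) {
    assert(ret->data.size() == static_cast<std::size_t>(m));
    Gemv(m, k, alpha, lhs.data.data(), rhs.data.data(), beta,
         ret->data.data());
  } else {
    assert(rhs.shape.size() == 2);
    const int n = rhs.shape[1];
    assert(ret->data.size() == static_cast<std::size_t>(m) * n);
    Gemm(m, n, k, alpha, lhs.data.data(), rhs.data.data(), beta,
         ret->data.data());
  }
}

// Convenience overload: allocate the result, then compute ret = lhs * rhs.
MiniTensor Mult(const MiniTensor& lhs, const MiniTensor& rhs) {
  assert(lhs.shape.size() == 2);
  MiniTensor ret;
  const int m = lhs.shape[0];
  if (rhs.shape.size() == 1) {
    ret.shape = {m};
  } else {
    assert(rhs.shape.size() == 2);
    ret.shape = {m, rhs.shape[1]};
  }
  std::size_t count = 1;
  for (int d : ret.shape) count *= static_cast<std::size_t>(d);
  ret.data.assign(count, 0.0f);
  Mult(1.0f, lhs, 0.0f, rhs, &ret);
  return ret;
}
```

The point of the single entry point is that callers never choose between the level-2 and level-3 routine themselves; the shape of the operands decides.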
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/02851fac
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/02851fac
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/02851fac

Branch: refs/heads/dev
Commit: 02851fac11ae6455b60d1cd5be4c2b6f142696cf
Parents: e36bc92
Author: Wei Wang
Authored: Fri May 13 21:00:48 2016 +0800
Committer: wangwei
Committed: Tue May 17 00:40:23 2016 +0800

----------------------------------------------------------------------
 CMakeLists.txt                       |   2 +-
 include/singa/core/math.h            | 273 ---------------------
 include/singa/core/tensor.h          | 285 +++++++++++-----------
 include/singa/model/layer.h          |  23 +-
 include/singa/model/param.h          |  97 ++++++++
 src/core/device/device.cc            |   1 +
 src/core/math/cpp_math.cc            |  54 -----
 src/core/math/cuda_math.cc           |  48 ----
 src/core/math/opencl_math.cc         |  24 --
 src/core/tensor/tensor.cc            | 379 ++++++++++++++++++++++++++----
 src/core/tensor/tensor_math.h        | 302 ++++++++++++++++++++++++
 src/core/tensor/tensor_math_cpp.h    |  57 +++++
 src/core/tensor/tensor_math_cuda.h   |  53 +++++
 src/core/tensor/tensor_math_opencl.h |  28 +++
 src/model/layer/layer.cc             |   8 +
 src/proto/layer.proto                |  22 +-
 test/singa/test_cpp_math.cc          |   4 +-
 test/singa/test_tensor.cc            |  35 +--
 test/singa/test_tensor_math.cc       |  84 +++++++
 19 files changed, 1135 insertions(+), 644 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/CMakeLists.txt
----------------------------------------------------------------------
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 21b3804..67a82e5 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,6 +1,6 @@
 CMAKE_MINIMUM_REQUIRED(VERSION 2.6)
 PROJECT(singa)
-SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
+SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -std=c++11")
 # Flags
 IF(UNIX OR APPLE)
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/include/singa/core/math.h ---------------------------------------------------------------------- diff --git a/include/singa/core/math.h b/include/singa/core/math.h deleted file mode 100644 index 511d9ee..0000000 --- a/include/singa/core/math.h +++ /dev/null @@ -1,273 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef SINGA_CORE_MATH_H_ -#define SINGA_CORE_MATH_H_ -#include -#include "singa/core/common.h" -#include "singa/utils/logging.h" - -namespace singa { - -/// \file math.h Math functions for linear algebra, neural net and random -/// operations. -/// All functions have a template argument, DType for DataType, Lib for the -/// backend library, e.g., lib::Cublas, lib::Cudnn, etc. - -/// Some operations would have many config/hyper-parameters, e.g., Conv, and -/// these config vary among diff implementations, e.g., cuda/cudnn/opencl. -/// To separate the modules, we pass a OpConf pointer to the Tensor Op function. -/// The specific fields are implemented by inheriting OpConf, and casting the -/// pointer between the base and the sub-class. 
-class OpConf { - public: - template - T* CastTo() { - static_assert(std::is_base_of::value, - "The cast type must be a sub-class of OpConf"); - return static_cast(this); - } -}; - -// ================Linear algebra functions==================================== -template -void Sum(int count, const Blob* input, DType* ret, Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} - -template -void Abs(int count, const Blob* input, Blob* ret, Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} - -template -void Sign(int count, const Blob* input, Blob* ret, Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} - -/// Base is e, Neper number -template -void Exp(int count, const Blob* input, Blob* ret, Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} - -/// Natual logarithm, the base is e, Neper number. -template -void Log(int count, const Blob* input, Blob* ret, Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} - -template -void Sqrt(int count, const Blob* input, Blob* ret, Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} - -template -void Tanh(int count, const Blob* input, Blob* ret, Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} - -template -void Sigmoid(int count, const Blob* input, Blob* ret, Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} - -/// Do v^x for every v from the input tensor -template -void Pow(int count, DType x, const Blob* input, Blob* ret, Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} - -/// Do v^x for every v from the lhs and every x from rhs -template -void Pow(int count, const Blob* lhs, const Blob* rhs, Blob* ret, Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} - -/// Clamp every element into [low, high] -template -void Clamp(int count, DType low, DType high, const Blob* input, Blob* ret, - Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} - -/// ret = x + input -template -void Add(int count, DType x, const Blob* input, Blob* ret, Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} - -/// ret = x * 
input -/// div could be enabled by calling Mult with 1/x -template -void Mult(int count, DType x, const Blob* input, Blob* ret, Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} - -/// ret = lhs + rhs -template -void Add(int count, const Blob* lhs, const Blob* rhs, Blob* ret, Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} - -/// ret = lhs - rhs -template -void Sub(int count, const Blob* lhs, const Blob* rhs, Blob* ret, Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} - -/// ret = lhs * rhs -template -void Mult(int count, const Blob* lhs, const Blob* rhs, Blob* ret, - Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} - -/// ret = lhs / rhs -template -void Div(int count, const Blob* lhs, const Blob* rhs, Blob* ret, Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} - -/// outer-product. -/// lhs and rhs are vectors of len m and n. ret is matrix of shape m * n -template -void Outer(int m, int n, const Blob* lhs, const Blob* rhs, Blob* ret, - Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} - -// TODO(wangwei) unify SumRow and SumCol. -/// Sum the rows of the input matrix into a vector -template -void SumRow(int nrow, int ncol, const Blob* input, Blob* ret, Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} -/// Sum the rows of the input matrix into a vector -template -void SumCol(int nrow, int ncol, const Blob* input, Blob* ret, Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} - -// TODO(wangwei) unify AddRow and AddCol. 
-/// Add the vector v to every row of A as the row of ret -template -void AddRow(int nrow, int ncol, const Blob* A, const Blob* v, Blob* ret, - Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} - -/// Add the vector v to every column of A as the column of ret -template -void AddCol(int nrow, int ncol, const Blob* A, const Blob* v, Blob* ret, - Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} - -// ===== BLAS functions, ref to http://docs.nvidia.com/cuda/cublas -// ===== Level 1 -/// return the index of the element with the max value. -template -void Amax(int count, const Blob* input, int* ret, Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} - -/// return the index of the element with the min value. -template -void Amin(int count, const Blob* input, int* ret, Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} -/// ret = sum |x| for all x in input -template -void Asum(int count, const Blob* input, DType* ret, Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} - -/// ret = alpha * input + ret -template -void Axpy(int count, DType alpha, const Blob* input, Blob* ret, Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} - -/// ret *= x -template -void Scale(int count, DType x, Blob* ret, Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} - -template -void Dot(int count, const Blob* lhs, const Blob* rhs, DType* ret, - Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} - -// ===== Level 2 -/// ret = alpha * op(A) * v + beta * ret. -/// op(A) = A if trans = false; A^T otherwise; rows(A) = m, cols(A) = n. -template -void GEMV(bool trans, int m, int n, DType alpha, const Blob* A, const Blob* v, - DType beta, Blob* ret, Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} - -// ===== Level 3 -/// ret = alpha * op(A) * op(B) + beta * ret. -/// op(A) = A if trans = false; A^T otherwise; rows(A) = m, cols(A) = n. 
-template -void GEMV(bool transA, bool transB, int m, int n, int k, DType alpha, - const Blob* A, const Blob* B, DType beta, Blob* ret, Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} - -// ================Random functions=========================================== -// The random generator should be extracted from ctx. -template -void Uniform(int count, DType low, DType high, Blob* ret, Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} - -template -void Gaussian(int count, DType mean, DType std, Blob* ret, Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} - -/// each element of ret would be 1 with prob p and 0 with 1-p. 0<= p <= 1 -template -void Bernoulli(int count, DType p, Blob* ret, Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} - -/// ret[i] would be 1 with prob p[i] and 0 with 1-p[i]. 0<= p[i] <= 1 -template -void Bernoulli(int count, const Blob* p, Blob* ret, Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} - -// ================Neural net functions======================================= -/// Do 2D conv. 
-/// c is input image channel, w is input width, h is input height -/// nb_kernel is output channel, kw, and kh are kenerl width and height -/* -template -void Conv2D(int c, int w, int h, int nb_kernel, int kw, int kh, - const Blob* input, const Blob* kernel, Blob* ret, Context* ctx) { - LOG(FATAL) << "Not Implemented"; -} -*/ -} // namespace singa - -#endif // SINGA_CORE_MATH_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/include/singa/core/tensor.h ---------------------------------------------------------------------- diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h index 725f657..4278078 100644 --- a/include/singa/core/tensor.h +++ b/include/singa/core/tensor.h @@ -20,23 +20,29 @@ #define SINGA_CORE_TENSOR_H_ #include +#include #include "singa/core/common.h" #include "singa/core/device.h" -#include "singa/core/math.h" #include "singa/proto/core.pb.h" #include "singa/utils/logging.h" using std::vector; +using std::tuple; namespace singa { typedef vector Shape; inline int Product(Shape shape) { if (shape.size() == 0) return 0; + return Product(shape.begin(), shape.end()); +} + +inline int Product(vector::iterator begin, vector::iterator end) { + CHECK(begin != end); int v = 1; - for (auto s : shape) - v *= s; + for (auto it = begin; it < end; it++) + v *= *it; return v; } @@ -60,19 +66,20 @@ inline int SizeOf(DataType t) { class Tensor { public: ~Tensor(); - Tensor() = default; - explicit Tensor(const Shape& shape, DataType dtype = kFloat32); + Tensor(); + Tensor(Shape&& shape, DataType dtype = kFloat32); + Tensor(const Shape& shape, DataType dtype = kFloat32); + Tensor(Shape&& shape, Device* dev, DataType dtype = kFloat32); Tensor(const Shape& shape, Device* dev, DataType dtype = kFloat32); /// Copy Tensor to share the internal data. No deep copy. Tensor(const Tensor& from); - /// Copy Tensor to share the internal data. No deep copy. Tensor(Tensor&& from); /// For functions in xx_math.cc to access the blob. 
/// Users should not operate against Blob directly. - /// It will malloc memory for the tensor if not allocated before. + /// blob_ is allocated in constructors. Blob* blob() const { return blob_; } @@ -82,9 +89,9 @@ class Tensor { } /// Return immutable Tensor values with given type. - template - const T* data() { - return static_cast (blob()->data()); + template + const DType* data() const { + return static_cast (blob()->data()); } /// data type, including kFloat16, kFloat32, kInt @@ -96,20 +103,28 @@ class Tensor { return shape_; } + int nDim() const { + return shape_.size(); + } + bool transpose() const { return transpose_; } + /// Return number of total elements int Size() const { return blob_->size() / SizeOf(data_type_); } + /// Return memory size (i.e., Bytes) int MemSize() const { return blob_->size(); } + /// Reset the tensor shape, it may reallocate blob, if MemSize() changes. void ReShape(const Shape& shape); + /// Reset the data type, it would reallocate blob if type changes. void AsType(DataType type); /// Reset the device. @@ -119,8 +134,9 @@ class Tensor { /// Equivalent to ToDevice(host_dev). void ToHost(); - /// For init the tensor values, copy 'size' bytes data. - void CopyDataFromHostPtr(const void* src, size_t size); + /// For init the tensor values, copy 'num' elements. + template + void CopyDataFromHostPtr(const DType* src, int num); /// Copy data from another Tensor which may be on a diff device. /// Meta data would not be copied! @@ -141,49 +157,39 @@ class Tensor { /// Copy the meta info with data blob shared. 
void operator=(Tensor&& t); + void operator+=(const Tensor& t); - /* - void operator+=(Tensor&& t); + // void operator+=(Tensor&& t); void operator-=(const Tensor& t); - void operator-=(Tensor&& t); + // void operator-=(Tensor&& t); void operator*=(const Tensor& t); - void operator*=(Tensor&& t); + // void operator*=(Tensor&& t); void operator/=(const Tensor& t); - void operator/=(Tensor&& t); + // void operator/=(Tensor&& t); // Scalar operations. /// T is a scalar type - template - void operator+=(const T x); + template + void operator+=(DType x); /// T is a scalar type - template - void operator-=(const T x); + template + void operator-=(const DType x); /// T is a scalar type - template - void operator*=(const T x); + template + void operator*=(const DType x); /// T is a scalar type - template - void operator/=(const T x); - - void Log(int base = 2); - void Tanh(); - void Sigmoid(); - void ReLU(); - - // random functions. - void Uniform(float low, float high); - template - void Gaussian(float mean, float std); + template + void operator/=(const DType x); /// save Tensor into a proto msg // void ToProto(TensorProto* t); /// load Tensor from proto msg // void FromProto(const TensorProto& t); - */ + protected: bool transpose_ = false; DataType data_type_ = kFloat32; @@ -194,142 +200,131 @@ class Tensor { Shape shape_; }; -/// For tensors with sparse content, e.g., missing columns or rows. +// For tensors with sparse content, e.g., missing columns or rows. // class SparseTensor : public Tensor {}; -// ==================Simple Linear Algebra Operations========================= -/* -Tensor Tanh(const Tensor& t); -Tensor Log(const Tensor& t); -Tensor Sigmoid(const Tensor& t); -Tensor ReLU(const Tensor& t); -Tensor Softmax(const Tensor& t); -*/ +/// Copy 'num' elements of src to dst. +/// The first 'src_offset' ('dst_offset') elements will be skipped. 
void CopyData(Tensor* dst, const Tensor& src, - int msize, + int num, int src_offset = 0, int dst_offset = 0); -// element-wise ops +/// Copy 'nBytes' bytes of src data to dst. +/// The first 'src_offset' ('dst_offset') bytes will be skipped. +void CopyRawData(Tensor* dst, + const Tensor& src, + int nBytes, + int src_offset = 0, + int dst_offset = 0); + +// ==================Simple Linear Algebra Operations========================= +Tensor Abs(const Tensor& t); +Tensor Exp(const Tensor& t); +Tensor Log(const Tensor& t); +Tensor ReLU(const Tensor& t); +Tensor Sigmoid(const Tensor& t); +Tensor Sign(const Tensor& t); +Tensor Sqrt(const Tensor& t); +Tensor Tanh(const Tensor& t); + +/// Regarding the internal data as 2d, with shape_[0]*...*shape_[axis] rows, +/// and shape_[axis+1]*...*shape_[nDim()] columns. +/// and do softmax along each row. +Tensor Softmax(const Tensor& t, int axis = -1); +void Softmax(const Tensor& t, Tensor* ret, int axis = -1); + +/// Element-wise operation, ret[i]=t[i]^x +template +Tensor Pow(const Tensor& t, DType x); +/// Element-wise operation, ret[i]=t[i]^x +template +void Pow(const Tensor& t, DType x, Tensor* ret); +/// Element-wise operation, ret[i]=base[i]^exp[i] +Tensor Pow(const Tensor& base, Tensor exp); +/// Element-wise operation, ret[i]=base[i]^exp[i] +void Pow(const Tensor& base, const Tensor& exp, Tensor* ret); Tensor operator+(const Tensor& lhs, const Tensor& rhs); void Add(const Tensor& lhs, const Tensor& rhs, Tensor* ret); -/* Tensor operator-(const Tensor& lhs, const Tensor& rhs); void Sub(const Tensor& lhs, const Tensor& rhs, Tensor* ret); Tensor operator*(const Tensor& lhs, const Tensor& rhs); -void operator*(const Tensor& lhs, const Tensor& rhs, Tensor* ret); +void EltwiseMult(const Tensor& lhs, const Tensor& rhs, Tensor* ret); Tensor operator/(const Tensor& lhs, const Tensor& rhs); -void operator/(const Tensor& lhs, const Tensor& rhs, Tensor* ret); +void Div(const Tensor& lhs, const Tensor& rhs, Tensor* ret); -template 
-Tensor operator+(const T x, const Tensor& t); -template -void operator+(const T x, const Tensor& t, Tensor* ret); +template +Tensor operator+(const Tensor& t, DType x); +template +void Add(const Tensor& t, DType x, Tensor* ret); -template -Tensor operator-(const T x, const Tensor& t); -template -void operator-(const T x, const Tensor& t, Tensor* ret); +template +Tensor operator-(const Tensor& t, DType x); +template +void Sub(const Tensor& t, DType x, Tensor* ret); -template -Tensor operator*(const T x, const Tensor& t); -template -void operator*(const T x, const Tensor& t, Tensor* ret); +template +Tensor operator*(const Tensor& t, DType x); +template +void EltwiseMult(const Tensor& t, DType x, Tensor* ret); -template -Tensor operator/(const T x, const Tensor& t); -template -void operator/(const T x, const Tensor& t, Tensor* ret); +template +Tensor operator/(const Tensor& t, DType x); +template +void Div(const Tensor& t, DType x, Tensor* ret); //================Blas operations============================================ +// ===== Level 1 +// TODO(wangwei) make amax/amin/asum a member function of tensor +// void Amax(Tensor, Context* ctx); Get the index of the max value in a vector +// void Asum(Tensor Context* ctx); + +// template +// void Axpy(DType x, const Blob& t, Blob* ret, Context* ctx); + +/// Do matrix vector multipication or matrix matrix multiplication depdending +/// on the Tensor shape. ret = lhs * rhs +template Tensor Mult(const Tensor& lhs, const Tensor& rhs); +/// Do matrix vector multipication or matrix matrix multiplication depdending +/// on the Tensor shape. ret = lhs * rhs +template void Mult(const Tensor& lhs, const Tensor& rhs, Tensor* ret); -tempalte T Dot(const Tensor& lhs, const Tensor& rhs); - -//================Neural Net operations====================================== +/// Do matrix vector multipication or matrix matrix multiplication depdending +/// on the Tensor shape. 
ret = alpha lhs * rhs + beta * ret +template +Tensor Mult(DType alpha, const Tensor& lhs, DType beta, const Tensor& rhs); +/// Do matrix vector multipication or matrix matrix multiplication depdending +/// on the Tensor shape. ret = alpha lhs * rhs + beta * ret +template +void Mult(DType alpha, const Tensor& lhs, DType beta, const Tensor& rhs, + Tensor* C); -/// Convolution Op. 'Conf' is ConvConf; -void Conv(const OpConf* conf, - const Tensor& input, - const Tensor& W, - const Tensor &b, - Tensor* ret); +// tempalte T Dot(const Tensor& lhs, const Tensor& rhs); //================Random operations========================================== -Tensor Uniform(float low, float high, const Shape& shape, Device* dev); - -Tensor Gaussian(float mean, float std, const Shape& shape, Device* dev); -*/ -//============================================================================ -/// typedef DType accroding to type value. -/// DType would be used in the code block __VA_ARGS__. -#define TYPE_SWITCH(type, DType, ...) \ - do { \ - switch (type) { \ - case kFloat32: { \ - typedef float DType; \ - { __VA_ARGS__ } \ - break; \ - } \ - case kInt: { \ - typedef int DType; \ - { __VA_ARGS__ } \ - break; \ - } \ - case kChar: { \ - typedef char DType; \ - { __VA_ARGS__ } \ - break; \ - } \ - default: \ - LOG(FATAL) << "Unknow data type = " << DataType_Name(type); \ - } \ - } while (0) - -/// typedef DType and Lib according to values of type and lib respectively. -/// type is from DataType, and lib is from LibType. -/// DType and Lib would be used in __VA_ARGS__. -#define TYPE_LIB_SWITCH(dtype, DType, ltype, Lib, ...) 
\ - do { \ - const int _SwitchShift = 3; \ - int _SwitchHash = ((dtype) << _SwitchShift) + (ltype); \ - switch (_SwitchHash) { \ - case ((kFloat32 << _SwitchShift) + kCuda): { \ - typedef float DType; \ - typedef lib::Cuda Lib; \ - { __VA_ARGS__ } \ - break; \ - } \ - case ((kFloat32 << _SwitchShift) + kCudnn): { \ - typedef float DType; \ - typedef lib::Cudnn Lib; \ - { __VA_ARGS__ } \ - break; \ - } \ - case ((kFloat32 << _SwitchShift) + kCpp): { \ - typedef float DType; \ - typedef lib::Cpp Lib; \ - { __VA_ARGS__ } \ - break; \ - } \ - case ((kFloat32 << _SwitchShift) + kOpencl): { \ - typedef float DType; \ - typedef lib::Opencl Lib; \ - { __VA_ARGS__ } \ - break; \ - } \ - default: \ - LOG(FATAL) << "Unknown combination of data type " \ - << DataType_Name(dtype) << " and library " \ - << LibType_Name(ltype); \ - } \ - } while (0) - - +/// For each element x set x = 0 if random() < p; otherwise x = 1. +Tensor Bernoulli(float p, Blob* t); +/// Fill in Tensor 't' following uniform distribution. +Tensor Uniform(float low, DType high, Blob* t); +/// Fill in Tensor 't' following Gaussian distribution. 
+Tensor Gaussian(float mean, DType std, Blob* t); +//================Neural Net operations====================================== +// following API of cudnn, e.g., conv, pool, lrn, batchnorm, softmax +void ConvFwd(const ConvConf& conf, const Tensor& x, const Tensor& w, Tensor* y); +void ConvBwdBias(const ConvConf& conf, const Tensor& dy, Tensor* db); +void ConvBwdFilter(const ConvConf& conf, const Tensor& dy, const Tensor& x, + Tensor* dw); +void ConvBwdData(const ConvConf& conf, const Tensor& dy, const Tensor& w, + Tensor* db); +void PoolFwd(const PoolConf& conf, const Tensor& x, Tensor* y, + Tensor* mask = nullptr); +void PoolBwd(const PoolConf& conf, const Tensor& y, const Tensor& dy, + const Tensor& x, Tensor* dx); } // namespace singa #endif // SINGA_CORE_TENSOR_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/include/singa/model/layer.h ---------------------------------------------------------------------- diff --git a/include/singa/model/layer.h b/include/singa/model/layer.h index 37f3fa8..7b9b6d4 100644 --- a/include/singa/model/layer.h +++ b/include/singa/model/layer.h @@ -45,7 +45,9 @@ class Layer { } /// Set meta data fields configured in 'conf' (a proto message). - virtual void Setup(const LayerConf& conf) {} + virtual void Setup(const LayerConf& conf) { + name_ = conf.name(); + } /// Do feature transformation for given 'input' Tensor. /// It is the forward pass for feed-forward nets and rnn nets. @@ -67,6 +69,7 @@ class Layer { const vector& input) { return vector{}; } + // return /// Move the layer (including its parameters and other Tensor) onto the given /// device @@ -82,28 +85,26 @@ class Layer { } /// Serialize the layer info, including params)_, into a LayerConf message. - virtual std::string ToProto(LayerConf* param) const = 0; + virtual std::string ToProto(LayerConf* conf) const { + conf->set_name(name_); + } /// Serialize the layer info, including params_, into a string representing /// a LayerParameter message. 
- /* - std::string ToProtoStr() const { - std:: string str; - SerializeToString(&str); - } - */ + std::string ToProtoStr() const; /// Return all Param instances of this layer. - const vector params() const { return params_; } + /// Each layer could cache the Param objects. + /// To save memory of , it can also create it when this function + /// is called + const vector GetParam(); /// Each layer instance would optionally have a name. /// Used for debugging and logging. const std::string name() const { return name_; } - protected: std::string name_; - std::vector params_; }; } // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/include/singa/model/param.h ---------------------------------------------------------------------- diff --git a/include/singa/model/param.h b/include/singa/model/param.h new file mode 100644 index 0000000..b859b1c --- /dev/null +++ b/include/singa/model/param.h @@ -0,0 +1,97 @@ +/************************************************************ +* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. 
+* +*************************************************************/ + +#ifndef SINGA_MODEL_PARAM_H_ +#define SINGA_MODEL_PARAM_H_ +#include "singa/core/tensor.h" +#include +#include +using std::vector; +using std::string; +namespace singa { +/// Base Param class for storing set of parameters, e.g., a weight matrix or a +/// bias vector. +/// It includes multiple Tensor s for parameter values, gradients, etc. +class Param { + public: + ~Param(); + Param(const ParamSpec& conf); + Param(Param&& p); + Param(const Param& p); + void operator=(Param&& p); + void operator=(const Param& p); + + Tensor& value() { + return value_; + } + + Tensor& grad() { + return grad_; + } + + void set_value(const Tensor& t) { + value_ = t; + } + + void set_value(Tensor&& t) { + value_ = std::move(t); + } + + void set_grad(const Tensor& t) { + isGradValid_ = true; + grad_ = t; + } + + void set_grad(Tensor&& t) { + grad_ = std::move(t); + } + + // void Compress(); + // string ToString(); + + protected: + string name_; + Tensor value_; + float lr_mult_ = 1.0f, decay_mult_ = 1.0f; +}; + +class ParamGrad { +// return grad tensor or data to recover the grad tensor, e.g., if W = U * V +// then, ParamGrad could just store U and V. provide func for serailize and +// deserialize. +}; + +// updater just copy the ParamGrad to a device and submit ops to that device, e.g., +// add grad; check update_condidtion; apply sgd; copy back. +// consider rpc (no rmda). + +Param* CreateParam(string type) { + Param* p = nullptr; + if (type == "default") + p = new Param(); + else + LOG(FATAL) << "Currently param type " << type << " is not implemented." 
+ << "Pls use the 'default' type"; + return p; +} +#endif // SINGA_MODEL_PARAM_H_ + +} // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/src/core/device/device.cc ---------------------------------------------------------------------- diff --git a/src/core/device/device.cc b/src/core/device/device.cc index 5bdab6f..4976a32 100644 --- a/src/core/device/device.cc +++ b/src/core/device/device.cc @@ -49,6 +49,7 @@ void Device::FreeBlob(Blob* blob) { void Device::CopyData(Blob* dst, const Blob& src, int len, int dst_offset, int src_offset) { + memcpy(reinterpret_cast(dst->mutable_data()) + dst_offset, (const Byte*)src.data() + src_offset, len); } http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/src/core/math/cpp_math.cc ---------------------------------------------------------------------- diff --git a/src/core/math/cpp_math.cc b/src/core/math/cpp_math.cc deleted file mode 100644 index 638d693..0000000 --- a/src/core/math/cpp_math.cc +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "singa/core/math.h" -#include "singa/core/common.h" - -#ifdef USE_CBLAS -#include -#endif - -namespace singa { -template<> -void Add(int count, - const Blob* lhs, - const Blob* rhs, - Blob* ret, - Context* ctx) { - // CHECK_EQ(ctx->stream, nullptr); - float *dptr = static_cast(ret->mutable_data()); - const float *lptr = static_cast(lhs->data()); - const float *rptr = static_cast(rhs->data()); - for (int i = 0; i < count; i++) { - dptr[i] = lptr[i] + rptr[i]; - } -} - -#ifdef USE_CBLAS -template<> -void Dot(int count, - const Blob* lhs, - const Blob* rhs, - float* ret, - Context* ctx) { - float dptr = ret->mutable_data(), lptr = lhs->data(), rptr = rhs->data(); - *ret = cblas_sdot(count, lptr, 1, rptr, 1); -} - -#endif -} http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/src/core/math/cuda_math.cc ---------------------------------------------------------------------- diff --git a/src/core/math/cuda_math.cc b/src/core/math/cuda_math.cc deleted file mode 100644 index 1cff1c2..0000000 --- a/src/core/math/cuda_math.cc +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "singa/core/math.h" -#include "singa/core/common.h" - - -namespace singa { - -#ifdef USE_CUDA -template<> -void Add(int count, const Blob* lhs, const Blob* rhs, - Blob* ret, Context* ctx) { - cublasSetStream(ctx->handle, ctx->stream); - cublasScopy(ctx->handle, count, lhs->data(), 1, ret->mutable_data(), 1); - cublasSaxpy(ctx->handle, 1.0f, rhs->data(), 1, ret->mutable_data(), 1); -} - -#ifdef USE_CUDNN -template<> -void Conv(const OpConf *conf, - const Blob* input, - const Blob* W, - const Blob* b, - Blob* ret, - Context* ctx) { - // auto conv_conf = conf->CastTo(); - // conv op -} - -#endif -#endif -} http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/src/core/math/opencl_math.cc ---------------------------------------------------------------------- diff --git a/src/core/math/opencl_math.cc b/src/core/math/opencl_math.cc deleted file mode 100644 index 7012610..0000000 --- a/src/core/math/opencl_math.cc +++ /dev/null @@ -1,24 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "singa/core/math.h" - -namespace singa { - - -} http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/src/core/tensor/tensor.cc ---------------------------------------------------------------------- diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc index 8fdc2ed..51b785e 100644 --- a/src/core/tensor/tensor.cc +++ b/src/core/tensor/tensor.cc @@ -15,28 +15,42 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #include "singa/core/tensor.h" -#include "singa/core/math.h" +#include "./tensor_math.h" +#include "./tensor_math_cpp.h" +#include "./tensor_math_cuda.h" +#include "./tensor_math_opencl.h" namespace singa { + Tensor::~Tensor() { if (blob_ != nullptr && blob_->DecRefCount() == 0) device_->FreeBlob(blob_); blob_ = nullptr; } +Tensor::Tensor() { + device_ = &hostDeviceSingleton; +} + Tensor::Tensor(const Shape& shape, DataType dtype) : data_type_(dtype), device_(&hostDeviceSingleton), shape_(shape) { device_ = &hostDeviceSingleton; blob_ = device_->NewBlob(Product(shape_) * SizeOf(data_type_)); } - +Tensor::Tensor(Shape&& shape, DataType dtype) + : data_type_(dtype), device_(&hostDeviceSingleton), shape_(shape) { + device_ = &hostDeviceSingleton; + blob_ = device_->NewBlob(Product(shape_) * SizeOf(data_type_)); +} Tensor::Tensor(const Shape& shape, Device* device, DataType dtype) : data_type_(dtype), device_(device), shape_(shape) { blob_ = device_->NewBlob(Product(shape_) * SizeOf(data_type_)); } - +Tensor::Tensor(Shape&& shape, Device* device, DataType dtype) + : data_type_(dtype), device_(device), shape_(shape) { + blob_ = device_->NewBlob(Product(shape_) * SizeOf(data_type_)); +} Tensor::Tensor(const Tensor& t) : transpose_(t.transpose_), data_type_(t.data_type_), @@ -50,7 +64,7 @@ Tensor::Tensor(Tensor&& t) : transpose_(t.transpose_), data_type_(t.data_type_), device_(t.device_), - shape_(t.shape_) { + shape_(std::move(t.shape_)) { blob_ = t.blob_; 
t.blob_ = nullptr; } @@ -90,18 +104,26 @@ void Tensor::ToHost() { ToDevice(device_->host()); } -void Tensor::CopyDataFromHostPtr(const void* src, size_t size) { +template +void Tensor::CopyDataFromHostPtr(const DType* src, int num) { + CHECK_EQ(sizeof(DType), SizeOf(data_type_)) << "data_type is " + << DataType_Name(data_type_) + << " user given type is of size " + << sizeof(DType); if (src != nullptr) - device_->CopyDataFromHostPtr(blob(), src, size); + device_->CopyDataFromHostPtr(blob(), src, sizeof(DType) * num); else LOG(WARNING) << "Copy data from null host ptr"; } +template void Tensor::CopyDataFromHostPtr(const float* src, int num); void Tensor::CopyData(const Tensor& src) { CHECK_EQ(Size(), src.Size()); + CHECK(blob_ != nullptr); // Do copy only if the src's blob is already initialized. - if (src.blob_ != nullptr) - singa::CopyData(this, src, Size() * SizeOf(data_type_), 0, 0); + if (src.blob_ != nullptr) { + singa::CopyData(this, src, Size(), 0, 0); + } } Tensor Tensor::Clone() { @@ -112,8 +134,10 @@ Tensor Tensor::Clone() { } Tensor Tensor::T() const { + CHECK_EQ(shape_.size(), 2); Tensor t(*this); t.transpose_ = ~transpose_; + std::swap(shape_[0], shape_[1]); return t; } @@ -132,80 +156,315 @@ void Tensor::operator=(Tensor&& t) { if (blob_ != nullptr && blob_->DecRefCount() == 0) device_->FreeBlob(blob_); transpose_ = t.transpose_; - shape_ = t.shape_; + shape_ = std::move(t.shape_); device_ = t.device_; blob_ = t.blob_; t.blob_ = nullptr; } -void Tensor::operator+=(const Tensor& t) { - Add(*this, t, this); -} -// ====================Tensor Operations======================================= +#define GenUnaryTensorArgMemberFunction(op, fn) \ + void Tensor::op(const Tensor& t) { fn(*this, t, this); } + +GenUnaryTensorArgMemberFunction(operator+=, Add); +GenUnaryTensorArgMemberFunction(operator-=, Sub); +GenUnaryTensorArgMemberFunction(operator*=, EltwiseMult); +GenUnaryTensorArgMemberFunction(operator/=, Div); + +#define GenUnaryScalarArgMemberFunction(op, 
fn) \ + template <typename DType> \ + void Tensor::op(DType x) { \ + fn(*this, x, this); \ + } \ + template void Tensor::op(float x) + +GenUnaryScalarArgMemberFunction(operator-=, Sub); +GenUnaryScalarArgMemberFunction(operator+=, Add); +GenUnaryScalarArgMemberFunction(operator*=, EltwiseMult); +GenUnaryScalarArgMemberFunction(operator/=, Div); +// ====================Tensor Operations======================================= void CopyData(Tensor* dst, const Tensor& src, - int len, + int num, int dst_offset, int src_offset) { - CHECK_GE(src.MemSize(), src_offset + len); - CHECK_GE(dst->MemSize(), dst_offset + len); + CHECK_GE(src.Size(), src_offset + num); + CHECK_GE(dst->Size(), dst_offset + num); + int width = SizeOf(src.data_type()); + CHECK_EQ(width, SizeOf(dst->data_type())); + CopyRawData(dst, src, num * width, dst_offset * width, src_offset * width); +} + +void CopyRawData(Tensor* dst, + const Tensor& src, + int nBytes, + int dst_offset, + int src_offset) { + CHECK_GE(src.MemSize(), src_offset + nBytes); + CHECK_GE(dst->MemSize(), dst_offset + nBytes); Device* src_dev = src.device(), *dst_dev = dst->device(); Blob* src_blob = src.blob(), *dst_blob = dst->blob(); if (dst_dev->device_lib() != src_dev->device_lib()) { // let the none cpp device conduct copy op if (dst_dev->device_lib() == kCpp) { - src_dev->CopyData(dst_blob, *src_blob, len, dst_offset, src_offset); + src_dev->CopyData(dst_blob, *src_blob, nBytes, dst_offset, src_offset); } else if (src_dev->device_lib() == kCpp) { - dst_dev->CopyData(dst_blob, *src_blob, len, dst_offset, src_offset); + dst_dev->CopyData(dst_blob, *src_blob, nBytes, dst_offset, src_offset); } else { LOG(FATAL) << "Not support mem copy betwee Cuda and OpenCL device"; } } else { - src_dev->CopyData(dst_blob, *src_blob, len, dst_offset, src_offset); + src_dev->CopyData(dst_blob, *src_blob, nBytes, dst_offset, src_offset); } } +//============================================================================ +/// typedef DType according to type value. 
+/// DType would be used in the code block __VA_ARGS__. +#define TYPE_SWITCH(type, DType, ...) \ + do { \ + switch (type) { \ + case kFloat32: { \ + typedef float DType; \ + { __VA_ARGS__ } \ + break; \ + } \ + case kInt: { \ + typedef int DType; \ + { __VA_ARGS__ } \ + break; \ + } \ + case kChar: { \ + typedef char DType; \ + { __VA_ARGS__ } \ + break; \ + } \ + default: \ + LOG(FATAL) << "Unknow data type = " << DataType_Name(type); \ + } \ + } while (0) + +/// typedef DType and Lib according to values of type and lib respectively. +/// type is from DataType, and lib is from LibType. +/// DType and Lib would be used in __VA_ARGS__. +#define TYPE_LIB_SWITCH(dtype, DType, ltype, Lib, ...) \ + do { \ + const int _SwitchShift = 3; \ + int _SwitchHash = ((dtype) << _SwitchShift) + (ltype); \ + switch (_SwitchHash) { \ + case ((kFloat32 << _SwitchShift) + kCuda): { \ + typedef float DType; \ + typedef lib::Cuda Lib; \ + { __VA_ARGS__ } \ + break; \ + } \ + case ((kFloat32 << _SwitchShift) + kCudnn): { \ + typedef float DType; \ + typedef lib::Cudnn Lib; \ + { __VA_ARGS__ } \ + break; \ + } \ + case ((kFloat32 << _SwitchShift) + kCpp): { \ + typedef float DType; \ + typedef lib::Cpp Lib; \ + { __VA_ARGS__ } \ + break; \ + } \ + case ((kFloat32 << _SwitchShift) + kOpencl): { \ + typedef float DType; \ + typedef lib::Opencl Lib; \ + { __VA_ARGS__ } \ + break; \ + } \ + default: \ + LOG(FATAL) << "Unknown combination of data type " \ + << DataType_Name(dtype) << " and library " \ + << LibType_Name(ltype); \ + } \ + } while (0) + + +#define EltwiseUnaryTensorFn(fn, t, ret) \ + do { \ + TYPE_LIB_SWITCH(t.data_type(), DType, t.device()->device_lib(), Lib, { \ + ret->device()->Submit( \ + [t, ret](Context* ctx) { \ + fn(t.Size(), t.blob(), ret->blob(), ctx); \ + }, \ + {t.blob()}, {ret->blob()}); \ + }); \ + } while (0) + +#define GenUnaryTensorFunction(fn) \ + Tensor fn(const Tensor& t) { \ + Tensor ret(t.shape(), t.device(), t.data_type()); \ + auto* retptr = &ret; \ + 
EltwiseUnaryTensorFn(fn, t, retptr); \ + return ret; \ + } + +GenUnaryTensorFunction(Abs); +GenUnaryTensorFunction(Exp); +GenUnaryTensorFunction(Log); +GenUnaryTensorFunction(ReLU); +GenUnaryTensorFunction(Sigmoid); +GenUnaryTensorFunction(Sign); +GenUnaryTensorFunction(Sqrt); +GenUnaryTensorFunction(Tanh); -Tensor operator+(const Tensor& lhs, const Tensor& rhs) { - Tensor ret(lhs.shape(), lhs.device()); - Add(lhs, rhs, &ret); +Tensor Softmax(const Tensor& t, int axis) { + Tensor ret(t.shape(), t.device(), t.data_type()); + Softmax(t, &ret, axis); return ret; } -void Add(const Tensor& lhs, const Tensor& rhs, Tensor* ret) { - TYPE_LIB_SWITCH(lhs.data_type(), DType, lhs.device()->device_lib(), Lib, { +void Softmax(const Tensor& t, Tensor* ret, int axis) { + int nrow = 1, ncol = t.Size(), size = ncol; + CHECK_GE(axis, -1); + CHECK_GT(t.shape().size(), 0); + if (axis > -1) { + nrow = Product(t.shape().begin(), t.shape().begin() + axis + 1); + CHECK_EQ(size % nrow, 0) << "Size = " << size << " nrow = " << nrow; + ncol = size / nrow; + } + TYPE_LIB_SWITCH(t.data_type(), DType, t.device()->device_lib(), Lib, { ret->device()->Submit( - [lhs, rhs, ret](Context* ctx) { - Add(lhs.Size(), lhs.blob(), rhs.blob(), ret->blob(), ctx); + [nrow, ncol, t, ret](Context* ctx) { + Softmax(nrow, ncol, t.blob(), ret->blob(), ctx); }, - {lhs.blob(), rhs.blob()}, {ret->blob()}); - }); + {t.blob()}, {ret->blob()}); + }); } -/* -Tensor operator-(const Tensor& lhs, const Tensor& rhs) { - Tensor ret(lhs.shape(), lhs.device()); - Sub(lhs, rhs, &ret); + +#define EltwiseBinaryTensorFn(fn, lhs, rhs, ret) \ + do { \ + TYPE_LIB_SWITCH(lhs.data_type(), DType, lhs.device()->device_lib(), Lib, { \ + ret->device()->Submit( \ + CHECK_EQ(sizeof(DType), SizeOf(rhs.data_type())); \ + [lhs, rhs, ret](Context* ctx) { \ + fn(lhs.Size(), lhs.blob(), rhs.blob(), ret->blob(), \ + ctx); \ + }, \ + {lhs.blob(), rhs.blob()}, {ret->blob()}); \ + }); \ + } while (0) + +#define GenBinaryTensorFunction(op, fn) \ + Tensor 
op(const Tensor& lhs, const Tensor& rhs) { \ + Tensor ret(lhs.shape(), lhs.device(), lhs.data_type()); \ + fn(lhs, rhs, &ret); \ + return ret; \ + } \ + void fn(const Tensor& lhs, const Tensor& rhs, Tensor* ret) { \ + EltwiseBinaryTensorFn(fn, lhs, rhs, ret); \ + } + +GenBinaryTensorFunction(operator+, Add); +GenBinaryTensorFunction(operator-, Sub); +GenBinaryTensorFunction(operator*, EltwiseMult); +GenBinaryTensorFunction(operator/, Div); +GenBinaryTensorFunction(Pow, Pow); + +#define EltwiseTensorScalarFn(fn, t, x, ret) \ + do { \ + TYPE_LIB_SWITCH(t.data_type(), DType, t.device()->device_lib(), Lib, { \ + ret->device()->Submit( \ + static_assert(typeid(x) == typeid(DType), \ + "The Scalar type must match the Tensor data type"); \ + [t, x, ret](Context* ctx) { \ + fn(t.Size(), t.blob(), x, ret->blob(), ctx); \ + }, \ + {t.blob()}, {ret->blob()}); \ + }); \ + } while (0) + +#define GenTensorScalarFunction(op, fn) \ + template \ + Tensor op(const Tensor& t, DType x) { \ + Tensor ret(t.shape(), t.device(), t.data_type()); \ + fn(t, x, &ret); \ + return ret; \ + } \ + template \ + void fn(const Tensor& t, DType x, Tensor* ret) { \ + EltwiseTensorScalarFn(fn, t, x, ret); \ + } \ + template Tensor op(const Tensor& t, float x); \ + template void fn(const Tensor& t, const float x, Tensor* ret) + +GenTensorScalarFunction(operator+, Add); +GenTensorScalarFunction(operator-, Sub); +GenTensorScalarFunction(operator*, EltwiseMult); +GenTensorScalarFunction(operator/, Div); +GenTensorScalarFunction(Pow, Pow); + +// ================Blas operations============================================ +template +Tensor Mult(const Tensor& lhs, const Tensor& rhs) { + Tensor ret(lhs.shape(), lhs.device(), lhs.data_type()); + Mult(lhs, rhs, &ret); + return ret; +} +template Tensor Mult(const Tensor& lhs, const Tensor& rhs); + +template +void Mult(const Tensor& lhs, const Tensor& rhs, Tensor* ret) { + Mult(DType(1), lhs, DType(1), rhs, ret); +} +template void Mult(const Tensor& lhs, const 
Tensor& rhs, Tensor* ret); + +template +Tensor Mult(DType alpha, const Tensor& A, DType beta, const Tensor& B) { + Tensor ret(A.shape(), A.device(), A.data_type()); + Mult(alpha, A, beta, B, &ret); return ret; } +template Tensor Mult(float alpha, const Tensor& lhs, float beta, + const Tensor& rhs); -void Sub(const Tensor& lhs, const Tensor& rhs, Tensor *ret) { - TYPE_LIB_SWITCH(lhs.data_type(), DType, lhs.device()->device_lib(), Lib, { - ret->device()->Submit( - [lhs, rhs, ret](Context* ctx) { - Sub( - lhs.Size(), - lhs.blob(), - rhs.blob(), - ret->blob(), - ctx);} - , {lhs.blob(), rhs.blob()}, {ret->blob()}); +template +void Mult(SType alpha, const Tensor& A, SType beta, const Tensor& B, Tensor* C) +{ + CHECK_EQ(A.shape().size(), 2); + bool transA = A.transpose(); + int m = transA ? A.shape()[1] : A.shape()[0], n = 0; + if (B.shape().size() == 1) { + n = C->Size(); + TYPE_LIB_SWITCH(A.data_type(), DType, A.device()->device_lib(), Lib, { + static_assert(std::is_same::value, + "The scalar type must be the same as the tensor data type"); + C->device()->Submit( + [transA, m, n, alpha, A, beta, B, C](Context* ctx) { + GEMV(transA, m, n, alpha, A.blob(), + B.blob(), beta, C->blob(), ctx); + }, + {A.blob(), B.blob()}, {C->blob()}); }); + } else { + CHECK(!C->transpose()); + bool transB = B.transpose(); + int k = transB ? 
B.shape()[1] : B.shape()[0]; + n = C->shape()[1]; + CHECK_EQ(C->shape()[0], m); + CHECK_EQ(A.Size(), m * k); + CHECK_EQ(B.Size(), n * k); + TYPE_LIB_SWITCH(A.data_type(), DType, A.device()->device_lib(), Lib, { + static_assert(std::is_same::value, + "The scalar type must be the same as the tensor data type"); + C->device()->Submit( + [transA, transB, m, n, k, alpha, A, beta, B, C](Context* ctx) { + GEMM(transA, transB, m, n, k, alpha, A.blob(), + B.blob(), beta, C->blob(), ctx); + }, + {A.blob(), B.blob()}, {C->blob()}); + }); + } } +template void Mult(float alpha, const Tensor& lhs, float beta, + const Tensor& rhs, Tensor* ret); -// ================Blas operations============================================ // ================Neural Net operations====================================== - +/* void Conv(const OpConf* conf, const Tensor& input, const Tensor& W, const Tensor& b, Tensor* ret) { TYPE_LIB_SWITCH(input.data_type(), DType, input.device()->nn_lib(), Lib, { @@ -218,5 +477,33 @@ void Conv(const OpConf* conf, const Tensor& input, const Tensor& W, }); } */ +void Bernoulli(float threshold, Tensor* t) { + TYPE_LIB_SWITCH(t->data_type(), DType, t->device()->nn_lib(), Lib, { + t->device()->Submit( + [threshold, t](Context* ctx) { + Bernoulli(t->Size(), threshold, t->blob(), ctx); + }, + {}, {t->blob()}); + }); +} + +void Uniform(float low, float high, Tensor* t) { + TYPE_LIB_SWITCH(t->data_type(), DType, t->device()->nn_lib(), Lib, { + t->device()->Submit( + [low, high, t](Context* ctx) { + Uniform(t->Size(), low, high, t->blob(), ctx); + }, + {}, {t->blob()}); + }); +} +void Gaussian(float mean, float std, Tensor* t) { + TYPE_LIB_SWITCH(t->data_type(), DType, t->device()->nn_lib(), Lib, { + t->device()->Submit( + [mean, std, t](Context* ctx) { + Gaussian(t->Size(), mean, std, t->blob(), ctx); + }, + {}, {t->blob()}); + }); +} } // namespace singa http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/src/core/tensor/tensor_math.h 
---------------------------------------------------------------------- diff --git a/src/core/tensor/tensor_math.h b/src/core/tensor/tensor_math.h new file mode 100644 index 0000000..a4f68e3 --- /dev/null +++ b/src/core/tensor/tensor_math.h @@ -0,0 +1,302 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef SINGA_CORE_MATH_H_ +#define SINGA_CORE_MATH_H_ +#include <type_traits> +#include "singa/core/common.h" +#include "singa/utils/logging.h" + +namespace singa { + +/// \file tensor_math.h Math functions for linear algebra, neural net and random +/// operations. +/// All functions take two template arguments: DType for the data type and Lib +/// for the backend library, e.g., lib::Cublas, lib::Cudnn, etc. + +/// Some operations have many config/hyper-parameters, e.g., Conv, and +/// these configs vary among different implementations, e.g., cuda/cudnn/opencl. +/// To separate the modules, we pass an OpConf pointer to the Tensor Op function. +/// The specific fields are implemented by inheriting OpConf, and casting the +/// pointer between the base and the sub-class. 
+class OpConf { + public: + template <typename T> + T* CastTo() { + static_assert(std::is_base_of<OpConf, T>::value, + "The cast type must be a sub-class of OpConf"); + return static_cast<T*>(this); + } +}; + +// ================Linear algebra functions==================================== +/// ret[i] = |input[i]| +template <typename DType, typename Lib> +void Abs(int count, const Blob* input, Blob* ret, Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} + +/// sum all elements of input into ret +template <typename DType, typename Lib> +void Sum(int count, const Blob* input, DType* ret, Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} + +/// ret[i] = sign(input[i]) +template <typename DType, typename Lib> +void Sign(int count, const Blob* input, Blob* ret, Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} + +/// Exponential with base e (Napier's constant), ret[i]=exp(input[i]) +template <typename DType, typename Lib> +void Exp(int count, const Blob* input, Blob* ret, Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} + +/// Natural logarithm with base e (Napier's constant), ret[i]=log(input[i]). +template <typename DType, typename Lib> +void Log(int count, const Blob* input, Blob* ret, Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} + +/// Element-wise operation, ret[i]=sqrt(input[i]) +template <typename DType, typename Lib> +void Sqrt(int count, const Blob* input, Blob* ret, Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} + +/// Element-wise operation, ret[i]=tanh(input[i]) +template <typename DType, typename Lib> +void Tanh(int count, const Blob* input, Blob* ret, Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} +/// Element-wise operation, ret[i]=max(0, input[i]) +template <typename DType, typename Lib> +void ReLU(int count, const Blob* input, Blob* ret, Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} +/// Element-wise operation, ret[i]=sigmoid(input[i]) +template <typename DType, typename Lib> +void Sigmoid(int count, const Blob* input, Blob* ret, Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} + +/// Element-wise operation, do v^x for every v from the input tensor +template <typename DType, typename Lib> +void Pow(int count, const Blob* input, DType x, Blob* ret, Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} + +/// Element-wise operation, do v^x for every v from 
the lhs and every x from rhs +template <typename DType, typename Lib> +void Pow(int count, const Blob* lhs, const Blob* rhs, Blob* ret, Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} + +/// Element-wise operation, clamp every element into [low, high] +/// if x>high, then x=high; if x<low, then x=low. +template <typename DType, typename Lib> +void Clamp(int count, DType low, DType high, const Blob* input, Blob* ret, + Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} + +/// ret = input + x +template <typename DType, typename Lib> +void Add(int count, const Blob* input, DType x, Blob* ret, Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} +/// ret = input - x +template <typename DType, typename Lib> +void Sub(int count, const Blob* input, DType x, Blob* ret, Context* ctx) { + Add<DType, Lib>(count, input, -x, ret, ctx); +} +/// ret = input * x +template <typename DType, typename Lib> +void EltwiseMult(int count, const Blob* input, DType x, Blob* ret, Context* ctx) +{ + LOG(FATAL) << "Not Implemented"; +} +/// ret = input / x +template <typename DType, typename Lib> +void Div(int count, const Blob* input, DType x, Blob* ret, Context* ctx) { + EltwiseMult<DType, Lib>(count, input, DType(1) / x, ret, ctx); +} + +/// ret = lhs + rhs +template <typename DType, typename Lib> +void Add(int count, const Blob* lhs, const Blob* rhs, Blob* ret, Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} + +/// ret = lhs - rhs +template <typename DType, typename Lib> +void Sub(int count, const Blob* lhs, const Blob* rhs, Blob* ret, Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} + +/// ret = lhs * rhs +template <typename DType, typename Lib> +void EltwiseMult(int count, const Blob* lhs, const Blob* rhs, Blob* ret, + Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} + +/// ret = lhs / rhs +template <typename DType, typename Lib> +void Div(int count, const Blob* lhs, const Blob* rhs, Blob* ret, Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} + +/// outer-product. +/// lhs and rhs are vectors of len m and n. ret is matrix of shape m * n +template <typename DType, typename Lib> +void Outer(int m, int n, const Blob* lhs, const Blob* rhs, Blob* ret, + Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} + +// TODO(wangwei) unify SumRow and SumCol. 
+/// Sum the rows of the input matrix into a vector +template <typename DType, typename Lib> +void SumRow(int nrow, int ncol, const Blob* input, Blob* ret, Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} +/// Sum the columns of the input matrix into a vector +template <typename DType, typename Lib> +void SumCol(int nrow, int ncol, const Blob* input, Blob* ret, Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} + +// TODO(wangwei) unify AddRow and AddCol. +/// Add the vector v to every row of A as the row of ret +template <typename DType, typename Lib> +void AddRow(int nrow, int ncol, const Blob* A, const Blob* v, Blob* ret, + Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} + +/// Add the vector v to every column of A as the column of ret +template <typename DType, typename Lib> +void AddCol(int nrow, int ncol, const Blob* A, const Blob* v, Blob* ret, + Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} + +// ===== BLAS functions, ref to http://docs.nvidia.com/cuda/cublas +// ===== Level 1 +/// return the index of the element with the max value. +template <typename DType, typename Lib> +void Amax(int count, const Blob* input, int* ret, Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} + +/// return the index of the element with the min value. +template <typename DType, typename Lib> +void Amin(int count, const Blob* input, int* ret, Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} +/// ret = sum |x| for all x in input +template <typename DType, typename Lib> +void Asum(int count, const Blob* input, DType* ret, Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} + +/// ret = alpha * input + ret +template <typename DType, typename Lib> +void Axpy(int count, DType alpha, const Blob* input, Blob* ret, Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} + +/// ret *= x +template <typename DType, typename Lib> +void Scale(int count, DType x, Blob* ret, Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} + +template <typename DType, typename Lib> +void Dot(int count, const Blob* lhs, const Blob* rhs, DType* ret, + Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} + +// ===== Level 2 +/// ret = alpha * op(A) * v + beta * ret. +/// op(A) = A if trans = false; A^T otherwise; rows(op(A)) = m, cols(op(A)) = n. 
+template <typename DType, typename Lib> +void GEMV(bool trans, int m, int n, DType alpha, const Blob* A, const Blob* v, + DType beta, Blob* ret, Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} + +// ===== Level 3 +/// ret = alpha * op(A) * op(B) + beta * ret. +/// op(X) = X if transX = false; X^T otherwise; rows(ret) = m, cols(ret) = n. +template <typename DType, typename Lib> +void GEMM(bool transA, bool transB, int m, int n, int k, DType alpha, + const Blob* A, const Blob* B, DType beta, Blob* ret, Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} + +// ================Random functions=========================================== +/// Each element of ret would be 1 with prob p and 0 with 1-p, where p is the threshold. 0 <= p <= 1 +// Get the random generator from 'ctx' +// If DType is not float, then convert the threshold to DType +template <typename DType, typename Lib> +void Bernoulli(int count, float threshold, Blob* ret, Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} +// The random generator should be extracted from ctx. +// If DType is not float, then convert the low and high to DType +template <typename DType, typename Lib> +void Uniform(int count, float low, float high, Blob* ret, Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} +// The random generator should be extracted from ctx. 
+// If DType is not float, then convert the mean and std to DType +template <typename DType, typename Lib> +void Gaussian(int count, float mean, float std, Blob* ret, Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} + +// ================Neural net functions======================================= +template <typename DType, typename Lib> +void ConvFwd(ConvConf* conf, const Blob* x, const Blob* w, Blob* y, + Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} + +template <typename DType, typename Lib> +void ConvBwdBias(const ConvConf* conf, const Blob* dy, Blob* db, Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} + +template <typename DType, typename Lib> +void PoolFwd(const PoolConf* conf, const Blob* x, Blob* y, Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} + +template <typename DType, typename Lib> +void PoolBwd(const PoolConf* conf, const Blob* y, const Blob* dy, const Blob* x, + Blob* dx, Context* ctx) { + LOG(FATAL) << "Not Implemented"; +} + +} // namespace singa + +#endif // SINGA_CORE_MATH_H_ http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/src/core/tensor/tensor_math_cpp.h ---------------------------------------------------------------------- diff --git a/src/core/tensor/tensor_math_cpp.h b/src/core/tensor/tensor_math_cpp.h new file mode 100644 index 0000000..a953085 --- /dev/null +++ b/src/core/tensor/tensor_math_cpp.h @@ -0,0 +1,57 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef SINGA_CORE_TENSOR_TENSOR_MATH_CPP_H_
+#define SINGA_CORE_TENSOR_TENSOR_MATH_CPP_H_
+#include "./tensor_math.h"
+#include "singa/core/common.h"
+
+#ifdef USE_CBLAS
+#include <cblas.h>
+#endif
+
+namespace singa {
+template<>
+void Add(int count,
+         const Blob* lhs,
+         const Blob* rhs,
+         Blob* ret,
+         Context* ctx) {
+  // CHECK_EQ(ctx->stream, nullptr);
+  float *dptr = static_cast<float*>(ret->mutable_data());
+  const float *lptr = static_cast<const float*>(lhs->data());
+  const float *rptr = static_cast<const float*>(rhs->data());
+  for (int i = 0; i < count; i++) {
+    dptr[i] = lptr[i] + rptr[i];
+  }
+}
+
+#ifdef USE_CBLAS
+template<>
+void Dot(int count,
+         const Blob* lhs,
+         const Blob* rhs,
+         float* ret,
+         Context* ctx) {
+  float dptr = ret->mutable_data(), lptr = lhs->data(), rptr = rhs->data();
+  *ret = cblas_sdot(count, lptr, 1, rptr, 1);
+}
+
+#endif
+}
+
+#endif  // SINGA_CORE_TENSOR_TENSOR_MATH_CPP_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/src/core/tensor/tensor_math_cuda.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math_cuda.h b/src/core/tensor/tensor_math_cuda.h
new file mode 100644
index 0000000..e1c72d8
--- /dev/null
+++ b/src/core/tensor/tensor_math_cuda.h
@@ -0,0 +1,53 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SINGA_CORE_TENSOR_TENSOR_MATH_CUDA_H_
+#define SINGA_CORE_TENSOR_TENSOR_MATH_CUDA_H_
+#include "./tensor_math.h"
+#include "singa/core/common.h"
+
+
+namespace singa {
+
+#ifdef USE_CUDA
+template<>
+void Add(int count, const Blob* lhs, const Blob* rhs,
+         Blob* ret, Context* ctx) {
+  cublasSetStream(ctx->handle, ctx->stream);
+  cublasScopy(ctx->handle, count, lhs->data(), 1, ret->mutable_data(), 1);
+  cublasSaxpy(ctx->handle, 1.0f, rhs->data(), 1, ret->mutable_data(), 1);
+}
+
+#ifdef USE_CUDNN
+template<>
+void Conv(const OpConf *conf,
+          const Blob* input,
+          const Blob* W,
+          const Blob* b,
+          Blob* ret,
+          Context* ctx) {
+  // auto conv_conf = conf->CastTo();
+  // conv op
+}
+
+#endif
+#endif
+}
+
+
+#endif  // SINGA_CORE_TENSOR_TENSOR_MATH_CUDA_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/src/core/tensor/tensor_math_opencl.h
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor_math_opencl.h b/src/core/tensor/tensor_math_opencl.h
new file mode 100644
index 0000000..c4b1347
--- /dev/null
+++ b/src/core/tensor/tensor_math_opencl.h
@@ -0,0 +1,28 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SINGA_CORE_TENSOR_TENSOR_MATH_OPENCL_H_
+#include "./tensor_math.h"
+
+namespace singa {
+
+
+}
+
+
+#endif  // SINGA_CORE_TENSOR_TENSOR_MATH_OPENCL_H_

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/src/model/layer/layer.cc
----------------------------------------------------------------------
diff --git a/src/model/layer/layer.cc b/src/model/layer/layer.cc
index 1f0e34d..0e83cde 100644
--- a/src/model/layer/layer.cc
+++ b/src/model/layer/layer.cc
@@ -18,5 +18,13 @@
 #include "singa/model/layer.h"
 namespace singa {
+const vector ComputeFeature(int flag, const vector& input) {
+  const vector input_blobs;
+}
+
+void ComputeFeature(int flag, const vector& input) {
+  const vector input_blobs;
+
+}
 }  // namespace singa

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/src/proto/layer.proto
----------------------------------------------------------------------
diff --git a/src/proto/layer.proto b/src/proto/layer.proto
index bb87af9..0fbbb5d 100644
--- a/src/proto/layer.proto
+++ b/src/proto/layer.proto
@@ -97,6 +97,10 @@ message ParamSpec {
   // The multiplier on the global weight decay for this parameter.
   optional float decay_mult = 4 [default = 1.0];
+
+  // SINGA field for creating diff Param, e.g. SparseParam or CompressableParam
+  // Currently only have a default param implementation.
+  optional string type = 20 [default = "default"];
 }
 // NOTE
@@ -154,27 +158,27 @@ message LayerConf {
   optional ConcatConf concat_conf = 104;
   optional ContrastiveLossConf contrastive_loss_conf = 105;
   optional ConvolutionConf convolution_conf = 106;
-  optional DataConf data_conf = 107;
+  // optional DataConf data_conf = 107;
   optional DropoutConf dropout_conf = 108;
-  optional DummyDataConf dummy_data_conf = 109;
+  // optional DummyDataConf dummy_data_conf = 109;
   optional EltwiseConf eltwise_conf = 110;
   optional EmbedConf embed_conf = 137;
   optional ExpConf exp_conf = 111;
   optional FlattenConf flatten_conf = 135;
-  optional HDF5DataConf hdf5_data_conf = 112;
-  optional HDF5OutputConf hdf5_output_conf = 113;
+  // optional HDF5DataConf hdf5_data_conf = 112;
+  // optional HDF5OutputConf hdf5_output_conf = 113;
   optional HingeLossConf hinge_loss_conf = 114;
-  optional ImageDataConf image_data_conf = 115;
+  // optional ImageDataConf image_data_conf = 115;
   optional InfogainLossConf infogain_loss_conf = 116;
   optional InnerProductConf inner_product_conf = 117;
   optional LogConf log_conf = 134;
   optional LRNConf lrn_conf = 118;
-  optional MemoryDataConf memory_data_conf = 119;
+  // optional MemoryDataConf memory_data_conf = 119;
   optional MVNConf mvn_conf = 120;
   optional PoolingConf pooling_conf = 121;
   optional PowerConf power_conf = 122;
   optional PReLUConf prelu_conf = 131;
-  optional PythonConf python_conf = 130;
+  // optional PythonConf python_conf = 130;
   optional ReductionConf reduction_conf = 136;
   optional ReLUConf relu_conf = 123;
   optional ReshapeConf reshape_conf = 133;
@@ -185,7 +189,7 @@ message LayerConf {
   optional TanHConf tanh_conf = 127;
   optional ThresholdConf threshold_conf = 128;
   optional TileConf tile_conf = 138;
-  optional WindowDataConf window_data_conf = 129;
+  //optional WindowDataConf window_data_conf = 129;
 }
 // Message that stores hyper-parameters used to apply transformation
@@ -835,7 +839,7 @@ message PReLUConf {
   // Surpassing Human-Level Performance on ImageNet Classification, 2015.
   // Initial value of a_i. Default is a_i=0.25 for all i.
-  optional FillerParameter filler = 1;
+  optional FillerConf filler = 1;
   // Whether or not slope parameters are shared across channels.
   optional bool channel_shared = 2 [default = false];
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/test/singa/test_cpp_math.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_cpp_math.cc b/test/singa/test_cpp_math.cc
index 268785d..78c713f 100644
--- a/test/singa/test_cpp_math.cc
+++ b/test/singa/test_cpp_math.cc
@@ -20,8 +20,6 @@
 *************************************************************/
 #include "gtest/gtest.h"
-#include "singa/core/math.h"
+#include "../src/core/tensor/tensor_math_cpp.h"
-TEST(CppMath, Add) {
-}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/test/singa/test_tensor.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_tensor.cc b/test/singa/test_tensor.cc
index 04068ae..86200a8 100644
--- a/test/singa/test_tensor.cc
+++ b/test/singa/test_tensor.cc
@@ -15,7 +15,7 @@ TEST(TensorTest, TestConstructor) {
   EXPECT_NE(float_t.device(), nullptr);
-  singa::Tensor float16_t(singa::Shape{2,3}, singa::kFloat16);
+  singa::Tensor float16_t(Shape{2,3}, singa::kFloat16);
   EXPECT_EQ(singa::kFloat16, float16_t.data_type());
   EXPECT_EQ(6, float16_t.Size());
   EXPECT_EQ(12, float16_t.blob()->size());
@@ -68,7 +68,7 @@ TEST(TensorClass, ToDevice) {
 TEST(TensorClass, CopyDataFromHostPtr) {
   float data[] = {1.0f, 2.0f, 3.0f};
   Tensor t(Shape{3});
-  t.CopyDataFromHostPtr(data, sizeof(float) * 3);
+  t.CopyDataFromHostPtr(data, 3);
   const float* dptr = static_cast<const float*>(t.blob()->data());
   EXPECT_FLOAT_EQ(1.0f, dptr[0]);
   EXPECT_FLOAT_EQ(2.0f, dptr[1]);
@@ -78,7 +78,7 @@ TEST(TensorClass, CopyDataFromHostPtr) {
 TEST(TensorClass, CopyData) {
   float data[] = {1.0f, 2.0f, 3.0f};
   Tensor t(Shape{3});
-  t.CopyDataFromHostPtr(data, sizeof(float) * 3);
+  t.CopyDataFromHostPtr(data, 3);
   Tensor o(Shape{3});
   o.CopyData(t);
@@ -91,7 +91,7 @@ TEST(TensorClass, CopyData) {
 TEST(TensorClass, Clone) {
   float data[] = {1.0f, 2.0f, 3.0f};
   Tensor t(Shape{3});
-  t.CopyDataFromHostPtr(data, sizeof(float) * 3);
+  t.CopyDataFromHostPtr(data, 3);
   Tensor o = t.Clone();
   const float* dptr = static_cast<const float*>(o.blob()->data());
@@ -110,30 +110,5 @@ TEST(TensorClass, T) {
   EXPECT_TRUE((t.shape() == o.shape()));
 }
-TEST(TensorClass, Add) {
-  const float data[] = {1.0f, 2.0f, 3.0f, 1.1f, 2.1f, 3.1f};
-  Tensor t(Shape{3});
-  t.CopyDataFromHostPtr(data, sizeof(float) * 3);
-  Tensor o = t.Clone();
-  o += t;
-  const float* dptr = o.data();
-  EXPECT_FLOAT_EQ(2.0f, dptr[0]);
-  EXPECT_FLOAT_EQ(4.0f, dptr[1]);
-  EXPECT_FLOAT_EQ(6.0f, dptr[2]);
-
-  Tensor p(Shape{3});
-  o += p;
-  const float* dptr1 = o.data();
-  EXPECT_FLOAT_EQ(2.0f, dptr1[0]);
-  EXPECT_FLOAT_EQ(4.0f, dptr1[1]);
-  EXPECT_FLOAT_EQ(6.0f, dptr1[2]);
-
-  Tensor q(Shape{3});
-  q.CopyDataFromHostPtr(data + 3, sizeof(float) * 3);
-  t += q;
-  const float* dptr2 = t.data();
-  EXPECT_FLOAT_EQ(2.1f, dptr2[0]);
-  EXPECT_FLOAT_EQ(4.1f, dptr2[1]);
-  EXPECT_FLOAT_EQ(6.1f, dptr2[2]);
-}
+

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/02851fac/test/singa/test_tensor_math.cc
----------------------------------------------------------------------
diff --git a/test/singa/test_tensor_math.cc b/test/singa/test_tensor_math.cc
new file mode 100644
index 0000000..51e7cfb
--- /dev/null
+++ b/test/singa/test_tensor_math.cc
@@ -0,0 +1,84 @@
+#include "gtest/gtest.h"
+#include "singa/core/tensor.h"
+using singa::Tensor;
+using singa::Shape;
+using singa::Device;
+
+class TestTensorMath : public ::testing::Test {
+ protected:
+  virtual void SetUp() {
+    const float dat1[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+    const float dat2[] = {1.1f, 2.1f, 3.1f, 4.1f, 5.1f, 6.1f};
+    a.ReShape(singa::Shape{6});
+    b.ReShape(singa::Shape{6});
+    c.ReShape(singa::Shape{6, 1});
+    d.ReShape(singa::Shape{3, 2});
+
+    a.CopyDataFromHostPtr(dat1, 6);
+    b.CopyDataFromHostPtr(dat2, 6);
+  }
+  Tensor a, b, c, d;
+};
+
+TEST_F(TestTensorMath, MemberAddTensor) {
+  Tensor aa = a.Clone();
+  aa += a;
+  const float* dptr = aa.data();
+  EXPECT_FLOAT_EQ(2.0f, dptr[0]);
+  EXPECT_FLOAT_EQ(4.0f, dptr[1]);
+  EXPECT_FLOAT_EQ(6.0f, dptr[2]);
+
+  // check p is initialized to 0
+  Tensor p(Shape{6});
+  p += aa;
+  const float* dptr1 = p.data();
+  EXPECT_FLOAT_EQ(2.0f, dptr1[0]);
+  EXPECT_FLOAT_EQ(4.0f, dptr1[1]);
+  EXPECT_FLOAT_EQ(6.0f, dptr1[2]);
+
+  a += b;
+  const float* dptr2 = a.data();
+  EXPECT_FLOAT_EQ(2.1f, dptr2[0]);
+  EXPECT_FLOAT_EQ(4.1f, dptr2[1]);
+  EXPECT_FLOAT_EQ(6.1f, dptr2[2]);
+  EXPECT_FLOAT_EQ(12.1f, dptr2[5]);
+}
+/*
+TEST(TensorClass, SubTensor) {
+  Tensor a(Shape{2,3}), b(Shape{6});
+  float x[]={1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
+  float y[]={1.1f, 2.1f, 3.1f, 4.1f, 5.1f, 6.1f};
+  a.CopyDataFromHostPtr(x, 6);
+  b.CopyDataFromHostPtr(y, 6);
+  b -= a;
+  const float* dptr = b.data();
+  EXPECT_FLOAT_EQ(0.1f, dptr[0]);
+  EXPECT_FLOAT_EQ(0.1f, dptr[1]);
+  EXPECT_FLOAT_EQ(0.1f, dptr[2]);
+  EXPECT_FLOAT_EQ(0.1f, dptr[5]);
+}
+*/
+
+TEST_F(TestTensorMath, AddTensors) {
+  Tensor ret(a.shape(), a.device(), a.data_type());
+  Add(a, b, &ret);
+  const float* dptr = ret.data();
+  EXPECT_FLOAT_EQ(2.1f, dptr[0]);
+  EXPECT_FLOAT_EQ(4.1f, dptr[1]);
+  EXPECT_FLOAT_EQ(6.1f, dptr[2]);
+  EXPECT_FLOAT_EQ(12.1f, dptr[5]);
+
+  const Tensor d = a + b;
+  const float* dptr2 = d.data();
+  EXPECT_FLOAT_EQ(2.1f, dptr2[0]);
+  EXPECT_FLOAT_EQ(4.1f, dptr2[1]);
+  EXPECT_FLOAT_EQ(6.1f, dptr2[2]);
+  EXPECT_FLOAT_EQ(12.1f, dptr2[5]);
+
+  Add(a, b, &a);
+  const float* dptr1 = a.data();
+  EXPECT_FLOAT_EQ(2.1f, dptr1[0]);
+  EXPECT_FLOAT_EQ(4.1f, dptr1[1]);
+  EXPECT_FLOAT_EQ(6.1f, dptr1[2]);
+  EXPECT_FLOAT_EQ(12.1f, dptr1[5]);
+}
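
Editorial sketch, not part of the commit: the CPU kernels in this diff reduce to an element-wise loop (Add) and a dot product (Dot, via cblas_sdot). The standalone C++ below mirrors that arithmetic on std::vector<float> so it can be checked without SINGA or a BLAS library. The names AddSketch and DotSketch are hypothetical, and the plain loop in DotSketch is a stand-in for cblas_sdot with unit strides, not the library call itself.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical stand-ins for the CPU kernels in this commit. They operate on
// std::vector<float> instead of singa::Blob, so no SINGA headers are needed.

// Element-wise sum: (*ret)[i] = lhs[i] + rhs[i].
// ret may alias lhs or rhs, as in the test's Add(a, b, &a) call.
inline void AddSketch(const std::vector<float>& lhs,
                      const std::vector<float>& rhs,
                      std::vector<float>* ret) {
  assert(lhs.size() == rhs.size());
  ret->resize(lhs.size());  // no-op when ret aliases an input
  for (std::size_t i = 0; i < lhs.size(); ++i) {
    (*ret)[i] = lhs[i] + rhs[i];
  }
}

// Dot product with unit stride; a plain loop replaces cblas_sdot here.
inline float DotSketch(const std::vector<float>& lhs,
                       const std::vector<float>& rhs) {
  assert(lhs.size() == rhs.size());
  float sum = 0.0f;
  for (std::size_t i = 0; i < lhs.size(); ++i) {
    sum += lhs[i] * rhs[i];
  }
  return sum;
}
```

With the fixture data from test_tensor_math.cc (1.0..6.0 and 1.1..6.1), AddSketch gives 2.1 in the first slot and 12.1 in the last, which are the values the AddTensors test expects from the real kernel.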