singa-commits mailing list archives

From wang...@apache.org
Subject [01/12] incubator-singa git commit: Transfer code from nusinga repo to singa apache repo. New communication framework is implemented to unify the frameworks of existing distributed deep learning systems. Communication is now implemented using ZeroMQ. API
Date Sun, 03 May 2015 14:04:06 GMT
Repository: incubator-singa
Updated Branches:
  refs/heads/master 95b1e6dd3 -> b2dc51d23


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/utils/cluster.cc
----------------------------------------------------------------------
diff --git a/src/utils/cluster.cc b/src/utils/cluster.cc
new file mode 100644
index 0000000..ac47422
--- /dev/null
+++ b/src/utils/cluster.cc
@@ -0,0 +1,52 @@
+#include <glog/logging.h>
+#include <fcntl.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <algorithm>
+#include <fstream>
+#include "utils/cluster.h"
+#include "proto/cluster.pb.h"
+namespace singa {
+
+std::shared_ptr<Cluster> Cluster::instance_;
+Cluster::Cluster(const ClusterProto &cluster, int procs_id) {
+  procs_id_=procs_id;
+  cluster_ = cluster;
+  SetupFolders(cluster);
+  int nprocs;
+  if(server_worker_separate())
+    nprocs=nworker_procs()+nserver_procs();
+  else
+    nprocs=std::max(nworker_procs(), nserver_procs());
+  CHECK_LT(procs_id, nprocs);
+  if (cluster_.has_nprocs())
+    CHECK_EQ(cluster.nprocs(), nprocs);
+  else
+    cluster_.set_nprocs(nprocs);
+  if(nprocs>1){
+    std::ifstream ifs(cluster.hostfile(), std::ifstream::in);
+    std::string line;
+    while(std::getline(ifs, line)&&endpoints_.size()<nprocs){
+      endpoints_.push_back(line);
+    }
+    CHECK_EQ(endpoints_.size(), nprocs);
+  }
+}
+
+void Cluster::SetupFolders(const ClusterProto &cluster){
+  // create visualization folder
+  mkdir(vis_folder().c_str(),  S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH);
+}
+
+shared_ptr<Cluster> Cluster::Get(const ClusterProto& cluster, int procs_id){
+  instance_.reset(new Cluster(cluster, procs_id));
+  return instance_;
+}
+
+shared_ptr<Cluster> Cluster::Get() {
+  if(!instance_) {
+    LOG(ERROR)<<"The first call to Get should "
+              <<"provide the sys/model conf path";
+  }
+  return instance_;
+}
+}  // namespace singa

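For context, a minimal usage sketch of the Cluster singleton defined above (assuming the accessors declared in utils/cluster.h; the helper function name here is illustrative, not part of the commit):

#include "utils/cluster.h"
#include "proto/cluster.pb.h"

// Hypothetical setup helper: the first Get() call constructs the singleton
// and validates procs_id against the cluster topology; later calls return
// the same instance (or log an error if no instance exists yet).
void SetupCluster(const singa::ClusterProto& proto, int procs_id) {
  auto cluster = singa::Cluster::Get(proto, procs_id);
  auto same = singa::Cluster::Get();  // same shared_ptr as above
}
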
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/utils/common.cc
----------------------------------------------------------------------
diff --git a/src/utils/common.cc b/src/utils/common.cc
new file mode 100644
index 0000000..0697060
--- /dev/null
+++ b/src/utils/common.cc
@@ -0,0 +1,89 @@
+#include <fcntl.h>
+#include <unistd.h>
+#include <cstdarg>
+#include <cstdio>
+#include <google/protobuf/io/coded_stream.h>
+#include <google/protobuf/text_format.h>
+#include <google/protobuf/io/zero_copy_stream_impl.h>
+#include "utils/common.h"
+using std::ios;
+using std::max;
+using google::protobuf::io::FileInputStream;
+using google::protobuf::io::FileOutputStream;
+using google::protobuf::io::ZeroCopyInputStream;
+using google::protobuf::io::CodedInputStream;
+using google::protobuf::io::ZeroCopyOutputStream;
+using google::protobuf::io::CodedOutputStream;
+
+namespace singa {
+
+const int kBufLen=1024;
+std::string IntVecToString(const vector<int>& vec) {
+  string disp="(";
+  for(int x: vec)
+    disp+=std::to_string(x)+", ";
+  return disp+")";
+}
+
+/**
+ * Formatted string.
+ */
+string VStringPrintf(string fmt, va_list l) {
+  char buffer[32768];
+  vsnprintf(buffer, 32768, fmt.c_str(), l);
+  return string(buffer);
+}
+
+/**
+ * Formatted string.
+ */
+string StringPrintf(string fmt, ...) {
+  va_list l;
+  va_start(l, fmt);
+  string result = VStringPrintf(fmt, l);
+  va_end(l);
+  return result;
+}
+
+void Debug() {
+  int i = 0;
+  char hostname[256];
+  gethostname(hostname, sizeof(hostname));
+  printf("PID %d on %s ready for attach\n", getpid(), hostname);
+  fflush(stdout);
+  while (0 == i)
+    sleep(5);
+}
+
+// the proto related functions are from Caffe.
+void ReadProtoFromTextFile(const char* filename,
+    ::google::protobuf::Message* proto) {
+  int fd = open(filename, O_RDONLY);
+  CHECK_NE(fd, -1) << "File not found: " << filename;
+  FileInputStream* input = new FileInputStream(fd);
+  CHECK(google::protobuf::TextFormat::Parse(input, proto));
+  delete input;
+  close(fd);
+}
+void WriteProtoToTextFile(const Message& proto, const char* filename) {
+  int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644);
+  CHECK_NE(fd, -1) << "Cannot open file: " << filename;
+  FileOutputStream* output = new FileOutputStream(fd);
+  CHECK(google::protobuf::TextFormat::Print(proto, output));
+  delete output;
+  close(fd);
+}
+void ReadProtoFromBinaryFile(const char* filename, Message* proto) {
+  int fd = open(filename, O_RDONLY);
+  CHECK_NE(fd, -1) << "File not found: " << filename;
+  ZeroCopyInputStream* raw_input = new FileInputStream(fd);
+  CodedInputStream* coded_input = new CodedInputStream(raw_input);
+  // upper limit 512MB, warning threshold 256MB
+  coded_input->SetTotalBytesLimit(536870912, 268435456);
+  CHECK(proto->ParseFromCodedStream(coded_input));
+  delete coded_input;
+  delete raw_input;
+  close(fd);
+}
+void WriteProtoToBinaryFile(const Message& proto, const char* filename) {
+  int fd= open(filename, O_CREAT|O_WRONLY|O_TRUNC, 0644);
+  CHECK_NE(fd, -1) << "Cannot open file: " << filename;
+  CHECK(proto.SerializeToFileDescriptor(fd));
+  close(fd);  // release the descriptor
+}
+
+}  // namespace singa

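A usage sketch for the proto I/O helpers above (the file names are illustrative; ClusterProto is the message from proto/cluster.pb.h shown earlier):

#include "utils/common.h"
#include "proto/cluster.pb.h"

void RoundTrip(const singa::ClusterProto& conf) {
  // text format: human-readable, e.g. for hand-edited job configs
  singa::WriteProtoToTextFile(conf, "cluster.conf");
  singa::ClusterProto parsed;
  singa::ReadProtoFromTextFile("cluster.conf", &parsed);
  // binary format: compact, e.g. for checkpoints; reads are capped at 512MB
  singa::WriteProtoToBinaryFile(parsed, "cluster.bin");
  singa::ReadProtoFromBinaryFile("cluster.bin", &parsed);
}
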
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/utils/data_shard.cc
----------------------------------------------------------------------
diff --git a/src/utils/data_shard.cc b/src/utils/data_shard.cc
new file mode 100644
index 0000000..df311e1
--- /dev/null
+++ b/src/utils/data_shard.cc
@@ -0,0 +1,207 @@
+#include <sys/stat.h>
+#include <glog/logging.h>
+
+#include "utils/data_shard.h"
+namespace singa {
+
+DataShard::DataShard(std::string folder, char mode, int capacity){
+  struct stat sb;
+  if(stat(folder.c_str(), &sb) == 0 && S_ISDIR(sb.st_mode)){
+    LOG(INFO)<<"Open shard folder "<<folder;
+  }else{
+    LOG(FATAL)<<"Cannot open shard folder "<<folder;
+  }
+
+  path_= folder+"/shard.dat";
+  if(mode==DataShard::kRead){
+    fdat_.open(path_, std::ios::in|std::ios::binary);
+    CHECK(fdat_.is_open())<<"Cannot open file "<<path_;
+  }
+  if(mode==DataShard::kCreate){
+    fdat_.open(path_, std::ios::binary|std::ios::out|std::ios::trunc);
+    CHECK(fdat_.is_open())<<"Cannot create file "<<path_;
+  }
+  if(mode==DataShard::kAppend){
+    int last_tuple=PrepareForAppend(path_);
+    fdat_.open(path_, std::ios::binary|std::ios::out|std::ios::in|std::ios::ate);
+    CHECK(fdat_.is_open())<<"Cannot create file "<<path_;
+    fdat_.seekp(last_tuple);
+  }
+
+  mode_=mode;
+  offset_=0;
+  bufsize_=0;
+  capacity_=capacity;
+  buf_=new char[capacity];
+}
+
+DataShard::~DataShard(){
+  delete[] buf_;  // matches new char[capacity] in the constructor
+  fdat_.close();
+}
+
+bool DataShard::Insert(const std::string& key, const Message& val) {
+  std::string str;
+  val.SerializeToString(&str);
+  return Insert(key, str);
+}
+// insert one complete tuple
+bool DataShard::Insert(const std::string& key, const std::string& val) {
+  if(keys_.find(key)!=keys_.end()||val.size()==0)
+    return false;
+  int size=key.size()+val.size()+2*sizeof(size_t);
+  if(offset_+size>capacity_){
+    fdat_.write(buf_, offset_);
+    offset_=0;
+    CHECK_LE(size, capacity_)<<"Tuple size is larger than the capacity. "
+      <<"Try a larger capacity.";
+  }
+  *reinterpret_cast<size_t*>(buf_+offset_)=key.size();
+  offset_+=sizeof(size_t);
+  memcpy(buf_+offset_, key.data(), key.size());
+  offset_+=key.size();
+  *reinterpret_cast<size_t*>(buf_+offset_)=val.size();
+  offset_+=sizeof(size_t);
+  memcpy(buf_+offset_, val.data(), val.size());
+  offset_+=val.size();
+  return true;
+}
+
+void DataShard::Flush() {
+  fdat_.write(buf_, offset_);
+  fdat_.flush();
+  offset_=0;
+}
+
+int DataShard::Next(std::string *key){
+  key->clear();
+  int ssize=sizeof(size_t);
+  if(!PrepareNextField(ssize))
+    return 0;
+  CHECK_LE(offset_+ssize, bufsize_);
+  int keylen=*reinterpret_cast<size_t*>(buf_+offset_);
+  offset_+=ssize;
+
+  if(!PrepareNextField(keylen))
+    return 0;
+  CHECK_LE(offset_+keylen, bufsize_);
+  for(int i=0;i<keylen;i++)
+    key->push_back(buf_[offset_+i]);
+  offset_+=keylen;
+
+  if(!PrepareNextField(ssize))
+    return 0;
+  CHECK_LE(offset_+ssize, bufsize_);
+  int vallen=*reinterpret_cast<size_t*>(buf_+offset_);
+  offset_+=ssize;
+
+  if(!PrepareNextField(vallen))
+    return 0;
+  CHECK_LE(offset_+vallen, bufsize_);
+  return vallen;
+}
+
+bool DataShard::Next(std::string *key, Message* val) {
+  int vallen=Next(key);
+  if(vallen==0)
+    return false;
+  val->ParseFromArray(buf_+offset_, vallen);
+  offset_+=vallen;
+  return true;
+}
+
+bool DataShard::Next(std::string *key, std::string* val) {
+  int vallen=Next(key);
+  if(vallen==0)
+    return false;
+  val->clear();
+  for(int i=0;i<vallen;i++)
+    val->push_back(buf_[offset_+i]);
+  offset_+=vallen;
+  return true;
+}
+
+void DataShard::SeekToFirst(){
+  CHECK_EQ(mode_, kRead);
+  bufsize_=0;
+  offset_=0;
+  fdat_.close();
+  fdat_.open(path_, std::ios::in|std::ios::binary);
+  CHECK(fdat_.is_open())<<"Cannot open file "<<path_;
+}
+
+// if the buf does not have the next complete field, read data from disk
+bool DataShard::PrepareNextField(int size){
+  if(offset_+size>bufsize_){
+    bufsize_-=offset_;
+    CHECK_LE(bufsize_, offset_);
+    for(int i=0;i<bufsize_;i++)
+      buf_[i]=buf_[i+offset_];
+    offset_=0;
+    if(fdat_.eof())
+      return false;
+    else{
+      fdat_.read(buf_+bufsize_, capacity_-bufsize_);
+      bufsize_+=fdat_.gcount();
+    }
+  }
+  return true;
+}
+
+const int DataShard::Count() {
+  std::ifstream fin(path_, std::ios::in|std::ios::binary);
+  CHECK(fin.is_open())<<"Cannot open file "<<path_;
+  int count=0;
+  while(true){
+    size_t len;
+    fin.read(reinterpret_cast<char*>(&len), sizeof(len));
+    if(fin.good())
+      fin.seekg(len, std::ios_base::cur);
+    else break;
+    if(fin.good())
+      fin.read(reinterpret_cast<char*>(&len), sizeof(len));
+    else break;
+    if(fin.good())
+      fin.seekg(len, std::ios_base::cur);
+    else break;
+    if(!fin.good())
+      break;
+    count++;
+  }
+  fin.close();
+  return count;
+}
+
+int DataShard::PrepareForAppend(std::string path){
+  std::ifstream fin(path, std::ios::in|std::ios::binary);
+  if(!fin.is_open()){
+    fdat_.open(path, std::ios::out|std::ios::binary);
+    fdat_.flush();
+    fdat_.close();
+    return 0;
+  }
+
+  int last_tuple_offset=0;
+  char buf[256];
+  size_t len;
+  while(true){
+    memset(buf, 0, 256);
+    fin.read(reinterpret_cast<char*>(&len), sizeof(len));
+    if(!fin.good()) break;
+    CHECK_LT(len, sizeof(buf));  // keys longer than buf would overflow it
+    fin.read(buf, len);
+    if(fin.good())
+      fin.read(reinterpret_cast<char*>(&len), sizeof(len));
+    else break;
+    if(fin.good())
+      fin.seekg(len, std::ios_base::cur);
+    else break;
+    if(fin.good())
+      keys_.insert(std::string(buf));
+    else break;
+    last_tuple_offset=fin.tellg();
+  }
+  fin.close();
+  return last_tuple_offset;
+}
+}  // namespace singa

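Each tuple is laid out on disk as (key length, key bytes, value length, value bytes), with both lengths stored as raw size_t values. A minimal write/read sketch against the API above (the folder name and record contents are illustrative; the folder must already exist, since the constructor only stat()s it):

#include <string>
#include "utils/data_shard.h"

void ShardExample() {
  singa::DataShard out("./shard-example", singa::DataShard::kCreate, 1 << 20);
  out.Insert("img-0", std::string("raw bytes of record 0"));
  out.Flush();  // Insert() only fills the in-memory buffer; Flush() hits disk

  singa::DataShard in("./shard-example", singa::DataShard::kRead, 1 << 20);
  std::string key, val;
  while (in.Next(&key, &val)) {
    // process (key, val)
  }
}
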
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/utils/graph.cc
----------------------------------------------------------------------
diff --git a/src/utils/graph.cc b/src/utils/graph.cc
new file mode 100644
index 0000000..d1cece6
--- /dev/null
+++ b/src/utils/graph.cc
@@ -0,0 +1,148 @@
+#include <algorithm>
+#include <cstdio>
+#include "utils/graph.h"
+
+const string Graph::ToString() const {
+  map<string, string> info;
+  return ToString(info);
+}
+const string Graph::ToString(const map<string, string>& info) const {
+  map<string, int> nodeid;
+  string disp="{\"directed\":1,\n";
+
+  // add nodes
+  disp+="\"nodes\":[\n";
+  bool first=true;
+
+  vector<string> colors={"red", "blue", "black", "green"};
+  // see for more shapes at http://www.graphviz.org/doc/info/shapes.html
+  vector<string> shapes={"box", "ellipse"};
+  int id=0;
+  for(auto node: nodes_){
+    char str[1024];
+    string name=node->name();
+    string color=colors[(node->val().locationid)%colors.size()];
+    string shape;
+    string origin=node->val().origin;
+    if(origin=="kSlice"||origin=="kConcate"||origin=="kSplit"
+        ||origin=="kBridgeSrc"||origin=="kBridgeDst")
+      shape=shapes[1];
+    else
+      shape=shapes[0];
+    sprintf(str, "{\"id\":\"%s%s\", \"color\":\"%s\",\"shape\":\"%s\"}\n",
+        name.c_str(), info.find(name)!=info.end()?info.at(name).c_str():"",
+        color.c_str(), shape.c_str());
+    if(!first)
+      disp+=",";
+    else
+      first=false;
+    disp+=string(str);
+    nodeid[name]=id++;
+  }
+  disp+="]\n,";
+
+  // add edges
+  disp+="\"links\":[\n";
+  first=true;
+  for(auto src: nodes_)
+    for(auto dst: src->dstnodes()){
+      char str[1024];
+      sprintf(str, "{\"source\":%d, \"target\":%d, \"color\":\"%s\"}\n",
+          nodeid[src->name()], nodeid[dst->name()], "black");
+      if(!first)
+        disp+=",";
+      else
+        first=false;
+      disp+=string(str);
+    }
+  disp+="]\n";
+  return disp+"}";
+}
+bool Graph::Check() const {
+  return true;
+}
+
+
+// visit all dst nodes, then push the current node onto the stack
+void Graph::topology_sort_inner(SNode node,
+    map<string, bool> *visited,
+    std::stack<string> *stack) {
+  (*visited)[node->name()] = true;
+  const vector<SNode>& dstnodes=node->dstnodes();
+  for (auto it=dstnodes.rbegin();it!=dstnodes.rend();it++) {
+    if ((*visited)[(*it)->name()])
+      continue;
+    topology_sort_inner((*it),visited, stack);
+  }
+  stack->push(node->name());
+}
+
+// sort so that `bottom' nodes are placed at the front
+void Graph::Sort() {
+  // adjacent list from upper layers to lower layers
+  std::map<string, bool> visited;
+  // prepare adjacent list; input layers will be processed firstly,
+  // hence no need to sort them (mark them as visited)
+  for (SNode node: nodes_) {
+    visited[node->name()] = false;
+  }
+  // the `top' layer in the net will be placed at the bottom of the stack
+  // and then be processed (i.e., forward) at last
+  std::stack<string > stack;
+  for (SNode node: nodes_) {
+    if (visited[node->name()] == false)
+      topology_sort_inner(node, &visited, &stack);
+  }
+  nodes_.clear();
+
+  while (!stack.empty()) {
+    nodes_.push_back(name2node_[stack.top()]);
+    stack.pop();
+  }
+}
+
+
+
+SNode Graph::InsertSliceNode(SNode srcnode, const vector<SNode>& dstnodes,
+    const V& info, bool connect_dst){
+  V myinfo=info;
+  myinfo.origin="kSlice";
+  SNode node=AddNode("slice-"+srcnode->name(),myinfo);
+  AddEdge(srcnode, node);
+  if(connect_dst)
+    for(SNode dst: dstnodes)
+      AddEdge(node, dst);
+  return node;
+}
+SNode Graph::InsertConcateNode(const vector<SNode>&srcnodes, SNode dstnode,
+    const V& info){
+  V myinfo=info;
+  myinfo.origin="kConcate";
+  SNode node=AddNode("concate-"+dstnode->name(),myinfo);
+  AddEdge(node, dstnode);
+  for(SNode src: srcnodes)
+    AddEdge(src, node);
+  return node;
+}
+SNode Graph::InsertSplitNode(SNode srcnode, const vector<SNode>& dstnodes){
+  V myinfo=srcnode->val();
+  myinfo.origin="kSplit";
+  SNode node=AddNode("split-"+srcnode->name(), myinfo);
+  AddEdge(srcnode, node);
+  for(SNode dst: dstnodes)
+    AddEdge(node, dst);
+  return node;
+}
+std::pair<SNode, SNode> Graph::InsertBridgeNode(SNode srcnode, SNode dstnode){
+  LayerInfo info=srcnode->val();
+  info.origin="kBridgeSrc";
+  SNode src=AddNode("s-"+srcnode->name()+"-"+dstnode->name(), info);
+  info=dstnode->val();
+  info.origin="kBridgeDst";
+  SNode dst=AddNode("d-"+srcnode->name()+"-"+dstnode->name(), info);
+  AddEdge(srcnode, src);
+  AddEdge(src, dst);
+  AddEdge(dst, dstnode);
+  return pair<SNode, SNode>{src, dst};
+}
+
+

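For orientation, the intended call pattern for the Graph utilities above: Sort() produces a topological order (input layers first), and ToString() emits the directed graph as JSON, with node colors keyed by partition location and ellipse shapes marking the auto-inserted slice/concate/split/bridge nodes. A brief sketch, assuming AddNode/AddEdge as used in this file:

#include <string>
#include "utils/graph.h"

void DumpGraph(Graph* g) {
  g->Sort();  // after this, nodes are in forward (topological) order
  std::string json = g->ToString();  // {"directed":1, "nodes":[...], "links":[...]}
  // hand the JSON to a visualization front-end
}
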
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/utils/param.cc
----------------------------------------------------------------------
diff --git a/src/utils/param.cc b/src/utils/param.cc
new file mode 100644
index 0000000..d64c65d
--- /dev/null
+++ b/src/utils/param.cc
@@ -0,0 +1,345 @@
+#include <glog/logging.h>
+#include <cmath>
+#include <chrono>
+#include <cstdio>
+#include <cstring>
+#include <random>
+#include "utils/param.h"
+#include "mshadow/tensor.h"
+#include "utils/singleton.h"
+using namespace mshadow;
+using std::vector;
+using std::string;
+namespace singa {
+
+Param::Param(){
+  owner_=-1;
+  fan_in_=0;
+  set_version(-1);
+}
+
+Param::~Param(){}
+
+Msg* Param::GenPutMsg(void* arg){
+  char buf[256];
+  int v=*(int*)arg;
+  sprintf(buf, "%d %d %f %f", v, size(),
+      learning_rate_multiplier(), weight_decay_multiplier());
+  Msg* msg=new Msg();
+  msg->set_type(kPut);
+  msg->add_frame(buf, strlen(buf));
+  msg->add_frame(mutable_cpu_data(), size()*sizeof(float));
+  return msg;
+}
+
+Msg* Param::GenGetMsg(void* arg){
+  char buf[16];  // large enough for any int formatted with %d
+  int v=*(int*)arg;
+  sprintf(buf, "%d", v);
+  Msg* msg=new Msg();
+  msg->set_type(kGet);
+  msg->add_frame(buf, strlen(buf));
+  return msg;
+}
+
+Msg* Param::GenUpdateMsg(void* arg){
+  char buf[16];  // large enough for any int formatted with %d
+  int v=*(int*)arg;
+  sprintf(buf, "%d", v);
+  Msg* msg=new Msg();
+  msg->set_type(kUpdate);
+  msg->add_frame(buf, strlen(buf));
+
+  msg->add_frame(mutable_cpu_grad(), size()*sizeof(float));
+  return msg;
+}
+
+Msg* Param::GenSyncMsg(void* arg){
+  return nullptr;
+}
+
+Msg* Param::HandlePutMsg(Msg** msg){
+  int v, size;
+  float lr, wc;
+  sscanf(static_cast<char*>((*msg)->frame_data()), "%d %d %f %f",
+      &v, &size, &lr, &wc);
+  set_version(v);
+  proto_.set_learning_rate_multiplier(lr);
+  proto_.set_weight_decay_multiplier(wc);
+  CHECK((*msg)->next_frame());
+  vector<int> shape{size};
+  data_.Reshape(shape);
+  grad_.Reshape(shape);
+  history_.Reshape(shape);
+  CHECK_EQ(size* sizeof(float), (*msg)->frame_size());
+  memcpy(data_.mutable_cpu_data(), (*msg)->frame_data(), size*sizeof(float));
+  delete (*msg);
+  *msg=nullptr;
+  return nullptr;
+}
+
+Msg* Param::HandleGetMsg(Msg** msg){
+  int v;
+  sscanf(static_cast<char*>((*msg)->frame_data()), "%d", &v);
+  CHECK_LE(v, version());
+  CHECK(!(*msg)->next_frame());
+  (*msg)->add_frame(data_.mutable_cpu_data(), sizeof(float)*size());
+  (*msg)->SwapAddr();
+  (*msg)->set_type(kRGet);
+  return *msg;
+}
+
+int Param::ParseUpdateMsg(Msg** msg){
+  int v;
+  sscanf(static_cast<char*>((*msg)->frame_data()), "%d", &v);
+  CHECK_LE(v, version());
+  CHECK((*msg)->next_frame());
+  memcpy(mutable_cpu_grad(), (*msg)->frame_data(),(*msg)->frame_size());
+  delete (*msg);
+  *msg=nullptr;
+  return 1;
+}
+
+Msg* Param::GenUpdateResponseMsg(void* arg){
+  Msg* msg=new Msg();
+  char buf[16];  // large enough for any int formatted with %d
+  sprintf(buf, "%d", version());
+  msg->set_type(kRUpdate);
+  msg->set_target(id());
+  msg->add_frame(buf, strlen(buf));
+  msg->add_frame(mutable_cpu_data(), size()*sizeof(float));
+  return msg;
+}
+
+
+Msg* Param::HandleSyncMsg(Msg** msg){
+  delete *msg;
+  *msg=nullptr;
+  return nullptr;
+}
+
+int Param::ParseSyncResponseMsg(Msg** msg){
+  delete *msg;
+  *msg=nullptr;
+  return 1;
+}
+int Param::ParsePutResponseMsg(Msg **msg){
+  return ParseSyncResponseMsg(msg);
+}
+int Param::ParseGetResponseMsg(Msg **msg){
+  int v;
+  sscanf(static_cast<char*>((*msg)->frame_data()), "%d", &v);
+  set_version(v);
+  CHECK((*msg)->next_frame());
+  memcpy(mutable_cpu_data(), (*msg)->frame_data(), (*msg)->frame_size());
+  return 1;
+}
+int Param::ParseUpdateResponseMsg(Msg **msg){
+  return ParseGetResponseMsg(msg);
+}
+
+void Param::Setup(const ParamProto& proto, const vector<int>& shape,
+    int fan_in){
+  data_.Reshape(shape);
+  grad_.Reshape(shape);
+  history_.Reshape(shape);
+  proto_=proto;
+  fan_in_=fan_in;
+}
+
+void Param::Init(int v){
+  proto_.set_version(v);
+  Tensor<cpu, 1> data(data_.mutable_cpu_data(), Shape1(data_.count()));
+  unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
+  auto random=ASingleton<Random<cpu>>::Instance(seed);
+  switch (proto_.init_method()) {
+  case ParamProto::kConstant:
+    data=proto_.value();
+    break;
+  case ParamProto::kUniform:
+    random->SampleUniform(data, proto_.low(), proto_.high());
+    if(proto_.value())
+      data*= proto_.value();
+    break;
+  case ParamProto::kUniformSqrtFanIn:
+    CHECK_GT(fan_in_,0);
+    random->SampleUniform(data, proto_.low(), proto_.high());
+    if(proto_.value())
+      data*= proto_.value()/ sqrt(fan_in_ / 3.0f);
+    break;
+  case ParamProto::kUniformSqrtFanInOut:
+    random->SampleUniform(data, proto_.low(), proto_.high());
+    if(proto_.value())
+      data*= proto_.value()/ sqrt(data_.shape()[0] +data_.shape()[1]);
+    break;
+  case ParamProto::kGaussian:
+    random->SampleGaussian(data, proto_.mean(), proto_.std());
+    if(proto_.value())
+      data*= proto_.value();
+    break;
+  case ParamProto::kGaussainSqrtFanIn:
+    random->SampleGaussian(data, proto_.mean(), proto_.std());
+    if(proto_.value())
+      data*= proto_.value()/ sqrt(data_.shape()[0]);
+    break;
+  default:
+    LOG(ERROR) << "Illegal parameter init method";
+    break;
+  }
+}
+
+/**************************RandomSyncParam********************************
+const vector<int> RandomSyncParam::RandomSample(int seed, int m, int n){
+  vector<int> samples(m);
+  std::mt19937 gen(seed);
+  std::uniform_real_distribution<float> dist(0.f,1.f);
+  for(int i=0,k=0;i<n&&k<m;i++)
+    if((m-k)*1.0f/(n-i)>dist(gen)){
+      samples[k++]=i;
+    }
+  return samples;
+}
+
+zmsg_t* RandomSyncParam::HandleSyncMsg(zmsg_t** msg){
+  int64_t start=zclock_mono();
+  char* control=zframe_strdup(zmsg_first(*msg));
+  int seed, count;
+  sscanf(control, "%d-%d", &seed,&count);
+  delete control;
+  zframe_t* syncframe=zmsg_next(*msg);
+  CHECK_EQ(zframe_size(syncframe), count*sizeof(float));
+  float* syncptr=(float*)zframe_data(syncframe);
+  float* dptr=data_.mutable_cpu_data();
+  int k=0;
+  if(count==data_.count()){
+    for(int idx=0;idx<count;idx++){
+      float x=dptr[idx];
+      dptr[idx]+=syncptr[k];
+      syncptr[k]=x;
+      k++;
+    }
+  }else{
+    for(int idx: RandomSample(seed, count, data_.count())){
+      float x=dptr[idx];
+      dptr[idx]+=syncptr[k];
+      syncptr[k]=x;
+      k++;
+    }
+  }
+  CHECK_EQ(k,count);
+  CHECK_EQ(zframe_size(syncframe), count*sizeof(float));
+  return *msg;
+}
+
+zmsg_t *RandomSyncParam::GenSyncMsgFromWorker(float sample_ratio){
+  int64_t start=zclock_mono();
+  zmsg_t* msg=zmsg_new();
+  unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
+  int m=data_.count()*sample_ratio;
+  zmsg_addstrf(msg, "%u-%d", seed, m);
+  float* updateptr=new float[m];
+  float* dptr=data_.mutable_cpu_data();
+  float* sdptr=snapshot_.mutable_cpu_data();
+  int k=0;
+  if(m==data_.count()){
+    for(int idx=0;idx<m;idx++)
+      updateptr[k++]=dptr[idx]-sdptr[idx];
+  }else{
+    const vector<int> samples=RandomSample(seed, m, data_.count());
+    for(int idx:samples){
+      updateptr[k++]=dptr[idx]-sdptr[idx];
+    }
+  }
+  CHECK_EQ(k,m);
+  zframe_t* frame=zframe_new(updateptr, sizeof(float)*m);
+  zmsg_append(msg, &frame);
+  delete updateptr;
+  worker_gen_sync+=zclock_mono()-start;
+  return msg;
+}
+
+void RandomSyncParam::ParseSyncMsgFromPS(zmsg_t** msg){
+  int64_t start=zclock_mono();
+  //LOG(ERROR)<<"worker sync "<<id();
+  char* control=zmsg_popstr(*msg);
+  int seed, count;
+  sscanf(control, "%u-%d", &seed, &count);
+  //LOG(ERROR)<<"worker sync "<<id()<<" "<<control;
+  delete control;
+  zframe_t* psdataframe=zmsg_pop(*msg);
+  CHECK_EQ(zframe_size(psdataframe), count*sizeof(float));
+  float* psdptr=(float*)zframe_data(psdataframe);
+  float* dptr=data_.mutable_cpu_data();
+  float* sdptr=snapshot_.mutable_cpu_data();
+  int k=0;
+  if(count==data_.count()){
+    for(int idx=0;idx<count;idx++){
+      dptr[idx]+=psdptr[k++]-sdptr[idx];
+      sdptr[idx]=dptr[idx];
+    }
+  }else{
+    for(int idx: RandomSample(seed, count, data_.count())){
+      dptr[idx]+=psdptr[k++]-sdptr[idx];
+      sdptr[idx]=dptr[idx];
+    }
+  }
+  zframe_destroy(&psdataframe);
+  worker_handle_sync+=zclock_mono()-start;
+  zmsg_destroy(msg);
+}
+
+
+void RandomSyncParam::Setup(const ParamProto& proto, const vector<int>& shape,
+    int fan_in){
+  Param::Setup(proto, shape, fan_in);
+  snapshot_.Reshape(shape);
+}
+
+void RandomSyncParam::Init(){
+  Param::Init();
+  memcpy(snapshot_.mutable_cpu_data(), data_.mutable_cpu_data(),
+      sizeof(float)*data_.count());
+}
+*/
+
+/***************************ElasticParam************************************
+zmsg_t* ElasticParam::HandleSyncMsg(zmsg_t** msg){
+  int64_t start=zclock_mono();
+  char* control=zframe_strdup(zmsg_first(*msg));
+  float alpha;int count;
+  sscanf(control, "%f-%d", &alpha,&count);
+  delete control;
+  zframe_t* syncframe=zmsg_next(*msg);
+  CHECK_EQ(size(), count);
+  Tensor<cpu, 1> server(data_.mutable_cpu_data(), Shape1(count));
+  Tensor<cpu, 1> worker((float*)zframe_data(syncframe), Shape1(count));
+  worker=(worker-server)*alpha;
+  server+=worker;
+  return *msg;
+}
+
+zmsg_t *ElasticParam::GenSyncMsgFromWorker(float alpha){
+  int64_t start=zclock_mono();
+  zmsg_t* msg=zmsg_new();
+  zmsg_addstrf(msg, "%f-%d", alpha, size());
+  zmsg_addmem(msg, mutable_cpu_data(), sizeof(float)*size());
+  worker_gen_sync+=zclock_mono()-start;
+  return msg;
+}
+
+void ElasticParam::ParseSyncMsgFromPS(zmsg_t** msg){
+  int64_t start=zclock_mono();
+  //LOG(ERROR)<<"worker sync "<<id();
+  char* control=zmsg_popstr(*msg);
+  float alpha;int count;
+  sscanf(control, "%f-%d", &alpha, &count);
+  delete control;
+  zframe_t* frame=zmsg_pop(*msg);
+  CHECK_EQ(zframe_size(frame), count*sizeof(float));
+  Tensor<cpu, 1> diff((float*)zframe_data(frame), Shape1(count));
+  Tensor<cpu, 1> data(mutable_cpu_data(), Shape1(count));
+  data-=diff;
+  zframe_destroy(&frame);
+  zmsg_destroy(msg);
+  worker_handle_sync+=zclock_mono()-start;
+}
+*/
+}  // namespace singa

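To make the message flow above concrete: a worker packs a parameter into a Msg with GenPutMsg() and the server consumes it in HandlePutMsg(), which reshapes its local blobs from the header frame and copies the data frame. A minimal sketch, assuming Param and Msg as used in this file (the version variable is illustrative):

#include "utils/param.h"

void PutRoundTrip(singa::Param* worker_side, singa::Param* server_side) {
  int version = 0;
  // header frame: "version size lr_multiplier wd_multiplier"; second frame: data
  singa::Msg* msg = worker_side->GenPutMsg(&version);
  // ... msg travels through the ZeroMQ-based communication layer ...
  singa::Msg* reply = server_side->HandlePutMsg(&msg);  // frees msg, sets it to nullptr
  // reply is nullptr: kPut expects no response payload
}
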
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/utils/updater.cc
----------------------------------------------------------------------
diff --git a/src/utils/updater.cc b/src/utils/updater.cc
new file mode 100644
index 0000000..0b89ee8
--- /dev/null
+++ b/src/utils/updater.cc
@@ -0,0 +1,192 @@
+#include <cmath>
+#include "utils/updater.h"
+#include "mshadow/tensor.h"
+#include "mshadow/cxxnet_op.h"
+#include "proto/model.pb.h"
+using namespace mshadow;
+using namespace mshadow::expr;
+
+namespace singa {
+
+float Updater::GetLearningRate(int step){
+  float ret = 0., r = 0., base=proto_.base_learning_rate();
+  int freq=0;
+  switch (proto_.learning_rate_change_method()) {
+    case UpdaterProto_ChangeProto_kFixed:
+      ret = base;
+      break;
+    case UpdaterProto_ChangeProto_kLinear:
+      // a is init, b is the final
+      freq=proto_.learning_rate_change_frequency();
+      r = step * 1.0  / freq;
+      ret = (1.0 - r) * base + r * proto_.final_learning_rate();
+      break;
+    case UpdaterProto_ChangeProto_kExponential:
+      // a is init, b is the final, from convnet
+      CHECK_EQ(base, 2 * proto_.final_learning_rate())
+        << "the final learning rate should be half of the base rate";
+      freq=proto_.learning_rate_change_frequency();
+      ret = base / pow(2, step * 1. / freq);
+      break;
+    case UpdaterProto_ChangeProto_kInverse_t:
+      // a is init, b is the final, from convnet
+      CHECK_EQ(base, 2 * proto_.final_learning_rate())
+        << "the final learning rate should be half of the base rate";
+      ret = base / (1. + step * 1. / proto_.final_learning_rate());
+      break;
+    case UpdaterProto_ChangeProto_kInverse:
+      // a is init, b is gamma, c is pow
+      ret=base*pow(1.f+proto_.gamma()*step, -proto_.pow());
+      break;
+    case UpdaterProto_ChangeProto_kStep:
+      // a is the base learning rate, b is gamma, from caffe
+      // notice it is step/change_steps, not step*1.0/change_steps
+      freq=proto_.learning_rate_change_frequency();
+      ret = base * pow(proto_.gamma(), step / freq);
+      break;
+    case UpdaterProto_ChangeProto_kFixedStep:
+      for(int i=0;i<proto_.step_size();i++){
+        if(step>proto_.step(i))
+          ret=proto_.step_lr(i);
+      }
+      break;
+    default:
+      LOG(ERROR) << "Wrong hyper-parameter update method";
+  }
+  return ret;
+}
+
+/***********************SGD with momentum******************************/
+void SGDUpdater::Init(const UpdaterProto& proto){
+  Updater::Init(proto);
+  base_lr_=proto.base_learning_rate();
+  //CHECK_GT(base_lr_, 0);
+  momentum_=proto.momentum();
+  weight_decay_=proto.weight_decay();
+}
+
+void SGDUpdater::Update(int step, shared_ptr<Param> param, float grad_scale){
+  Shape<1> s=Shape1(param->size());
+  Tensor<cpu, 1> data(param->mutable_cpu_data(), s);
+  Tensor<cpu, 1> grad(param->mutable_cpu_grad(), s);
+  float lr=GetLearningRate(step)*param->learning_rate_multiplier();
+  float wd=weight_decay_*param->weight_decay_multiplier();
+  if(wd>0){ // L2 regularization
+    grad+=data*wd;
+  }
+  if(momentum_>0){
+    Tensor<cpu, 1> history(param->mutable_cpu_history(), s);
+    if(step==0) history=0;
+    history=history*momentum_-lr*grad;
+    data+=history;
+  }else{
+    grad*=-lr;
+    data+=grad;
+  }
+}
+
+/***********************Nesterov******************************/
+void NesterovUpdater::Init(const UpdaterProto& proto){
+  Updater::Init(proto);
+  base_lr_=proto.base_learning_rate();
+  CHECK_GT(base_lr_, 0);
+  momentum_=proto.momentum();  // Update() below reads momentum_
+  weight_decay_=proto.weight_decay();
+}
+
+void NesterovUpdater::Update(int step, shared_ptr<Param> param, float grad_scale){
+  Shape<1> s=Shape1(param->size());
+  Tensor<cpu, 1> data(param->mutable_cpu_data(), s);
+  Tensor<cpu, 1> grad(param->mutable_cpu_grad(), s);
+  Tensor<cpu, 1> history(param->mutable_cpu_history(), s);
+  TensorContainer<cpu, 1> tmp(s);
+  if(step==0) history=0;
+  float lr=GetLearningRate(step)*param->learning_rate_multiplier();
+  float wd=weight_decay_*param->weight_decay_multiplier();
+  if(wd>0){ // L2 regularization
+    grad+=data*wd;
+  }
+  Copy(tmp, history);
+  history=history*momentum_+lr*grad;
+  tmp=history*(1+momentum_)-tmp*momentum_;
+  data-=tmp;
+}
+/***********************AdaGrad******************************/
+void AdaGradUpdater::Init(const UpdaterProto& proto){
+  Updater::Init(proto);
+  base_lr_=proto.base_learning_rate();
+  CHECK_GT(base_lr_, 0);
+  delta_=proto.delta();
+  weight_decay_=proto.weight_decay();
+}
+
+void AdaGradUpdater::Update(int step, shared_ptr<Param> param, float grad_scale){
+  Shape<1> s=Shape1(param->size());
+  Tensor<cpu, 1> data(param->mutable_cpu_data(), s);
+  Tensor<cpu, 1> grad(param->mutable_cpu_grad(), s);
+  Tensor<cpu, 1> history(param->mutable_cpu_history(), s);
+  if(step==0) history=0;
+  history+=F<op::square>(grad*grad_scale);
+  float lr=GetLearningRate(step)*param->learning_rate_multiplier();
+  float wd=weight_decay_*param->weight_decay_multiplier();
+  if(wd>0){ // L2 regularization
+    grad+=data*wd;
+  }
+  data-=lr*grad/(F<op::sqrtop>(history,delta_));
+}
+
+/***********************RMSProp******************************/
+void RMSPropUpdater::Init(const UpdaterProto& proto){
+  Updater::Init(proto);
+  base_lr_=proto.base_learning_rate();
+  CHECK_GT(base_lr_, 0);
+  delta_=proto.delta();
+  rho_=proto.rho();
+  weight_decay_=proto.weight_decay();
+}
+
+void RMSPropUpdater::Update(int step, shared_ptr<Param> param, float grad_scale){
+  Shape<1> s=Shape1(param->size());
+  Tensor<cpu, 1> data(param->mutable_cpu_data(), s);
+  Tensor<cpu, 1> grad(param->mutable_cpu_grad(), s);
+  Tensor<cpu, 1> history(param->mutable_cpu_history(), s);
+  if(step==0) history=0;
+  history=history*rho_+(1-rho_)*F<op::square>(grad*grad_scale);
+  float lr=GetLearningRate(step)*param->learning_rate_multiplier();
+  float wd=weight_decay_*param->weight_decay_multiplier();
+  if(wd>0){ // L2 regularization
+    grad+=data*wd;
+  }
+  data-=lr*grad/(F<op::sqrtop>(history,delta_));
+}
+
+/***********************AdaDelta******************************
+void AdaDeltaUpdater::Init(const UpdaterProto& proto){
+  Updater::Init(proto);
+  delta_=proto.delta();
+  rho_=proto.rho();
+  weight_decay_=proto.weight_decay();
+}
+
+void AdaDeltaUpdater::Update(int step, shared_ptr<Param> param, float grad_scale){
+  Shape<1> s=Shape1(param->size());
+  Tensor<cpu, 1> data(param->mutable_cpu_data(), s);
+  Tensor<cpu, 1> grad(param->mutable_cpu_grad(), s);
+  Tensor<cpu, 1> history(param->mutable_cpu_history(), s);
+  Tensor<cpu, 1> update(param->mutable_cpu_update(), s);
+  TensorContainer<cpu, 1> tmp(s);
+  float wd=weight_decay_*param->weight_decay_multiplier();
+  if(wd>0){ // L2 regularization
+    grad+=data*wd;
+  }
+  if(step==0){
+    history=0;
+    update=0;
+  }
+  history=history*rho_+(1-rho_)*F<op::square>(grad*grad_scale);
+  tmp=grad*F<op::sqrtop>(update, delta_)/F<op::sqrtop>(history, delta_);
+  update=rho_*update+(1-rho_)*F<op::square>(tmp);
+  data-=tmp;
+}
+*/
+
+}  // namespace singa

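For a concrete feel of the schedules in GetLearningRate(), a standalone sketch of the kStep and kExponential rules (the base rate, gamma, and frequency below are illustrative, not proto defaults):

#include <cmath>
#include <cstdio>
#include <initializer_list>

int main() {
  float base = 0.01f, gamma = 0.1f;
  int freq = 100000;
  // kStep: integer division, so the rate drops by gamma once every freq steps
  for (int step : {0, 99999, 100000, 200000})
    std::printf("kStep lr(%d) = %g\n", step,
                base * std::pow(gamma, step / freq));
  // kExponential: the rate halves every freq steps
  for (int step : {0, 100000, 200000})
    std::printf("kExp  lr(%d) = %g\n", step,
                base / std::pow(2.0f, step * 1.0f / freq));
  return 0;
}
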
