singa-commits mailing list archives

From wang...@apache.org
Subject [02/12] incubator-singa git commit: Transfer code from nusinga repo to singa apache repo. New communication framework is implemented to unify the frameworks of existing distributed deep learning systems. Communication is now implemented using ZeroMQ. API
Date Sun, 03 May 2015 14:04:07 GMT
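
The new communication layer is exercised by src/test/test_communication.cc in this
commit. For orientation, here is a minimal sketch of the Dealer/Router ping-pong
pattern that test uses; it sticks to the Msg/Dealer/Router API added by the commit,
and the endpoint "inproc://router" and the PING payload are taken from that test.
The function name PingOnce and the stack allocation of Dealer are illustrative
choices, not part of the commit.

    #include "communication/msg.h"
    #include "communication/socket.h"
    using namespace singa;

    // Connects a dealer to the stub router, sends a PING frame and waits for the
    // reply, mirroring Connect()/DealerPingPong() in test_communication.cc.
    void PingOnce() {
      Dealer dealer;
      dealer.Connect("inproc://router");
      Msg ping;
      ping.set_src(0, 0, 0);      // (group id, id, flag): flag 0 marks a worker dealer
      ping.set_dst(0, 0, 2);      // flag 2 addresses the stub router
      ping.set_type(0);
      ping.add_frame("PING", 4);
      dealer.Send(&ping);
      Msg* reply = dealer.Receive();  // a Router bound to the same endpoint answers
      delete reply;
    }
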
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/dist_test/test_table_server.cc
----------------------------------------------------------------------
diff --git a/src/test/dist_test/test_table_server.cc b/src/test/dist_test/test_table_server.cc
new file mode 100644
index 0000000..5f3612c
--- /dev/null
+++ b/src/test/dist_test/test_table_server.cc
@@ -0,0 +1,357 @@
+//  Copyright © 2014 Anh Dinh. All Rights Reserved.
+
+#include "core/global-table.h"
+#include "core/common.h"
+#include "core/table.h"
+#include "core/table_server.h"
+#include "utils/global_context.h"
+#include "utils/common.h"
+#include <gflags/gflags.h>
+#include "proto/model.pb.h"
+#include "proto/common.pb.h"
+#include "worker.h"
+#include "coordinator.h"
+#include "utils/common.h"
+#include "utils/proto_helper.h"
+
+#include <cmath>
+#include <stdlib.h>
+#include <vector>
+#include <iostream>
+#include <fstream>
+
+
+/**
+ * Test for table server access. The table is of type <VKey, SGDValue>.
+ */
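+// Overall flow (see main() below): the coordinator assigns shards to table servers
+// in round-robin fashion and loads SIZE tuples; each table-server process starts a
+// TableServer and handles the shard assignment; the remaining workers handle the
+// assignment as well, then run update(), update(), get(); finally every process
+// goes through shutdown().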
+DEFINE_bool(restore_mode, false, "restore from checkpoint file");
+using namespace lapis;
+using std::vector;
+
+DEFINE_int32(checkpoint_frequency, 5000, "frequency for checkpointing");
+DEFINE_int32(checkpoint_after, 1, "checkpoint after this many steps");
+DEFINE_string(par_mode, "hybrid",  "time training algorithm");
+DEFINE_bool(restore, false, "restore from checkpoint file");
+
+DEFINE_string(db_backend, "lmdb", "backend db");
+DEFINE_string(system_conf, "examples/imagenet12/system.conf", "configuration file for node roles");
+DEFINE_string(model_conf, "examples/imagenet12/model.conf", "DL model configuration file");
+DEFINE_string(checkpoint_dir,"/data1/wangwei/lapis/","checkpoint dir");
+DEFINE_int32(threshold,1000000, "max # of parameters in a vector");
+DEFINE_int32(iterations,5,"number of get/put iterations");
+DEFINE_int32(workers,2,"number of workers doing get/put");
+DECLARE_bool(checkpoint_enabled);
+
+
+/**
+ * Get and update handler for VKey.
+ */
+struct AnhUpdateHandler: BaseUpdateHandler<VKey, SGDValue> {
+	bool Update(SGDValue *a, const SGDValue &b) {
+
+		float * adptr = a->mutable_data()->mutable_value()->mutable_data();
+		const float*bdptr = b.grad(0).value().data();
+		for (int i = 0; i < b.grad(0).value_size(); i++)
+			adptr[i] += bdptr[i];
+
+		return true;
+	}
+
+	bool Get(const VKey k, const SGDValue &val, SGDValue *ret) {
+		*ret = val;
+		return true;
+	}
+
+	bool is_checkpointable(const VKey k, const SGDValue v) {
+		return false; // do not checkpoint any tuple in this test
+	}
+};
+
+typedef map<int, GlobalTable*> Map;
+Map tables;
+shared_ptr<NetworkThread> network;
+shared_ptr<GlobalContext> context;
+std::vector<ServerState*> server_states;
+TableServer *table_server;
+
+#define SIZE 16
+int tuple_sizes[SIZE] = {27448736, 16777216, 4096000, 1327104, 884736, 884736, 614400,14112,4096,4096,1000,384,384,256,256,96};
+
+/**
+ * Initialize tables.
+ */
+void create_mem_table(int id, int num_shards){
+
+	TableDescriptor *info = new TableDescriptor(id, num_shards);
+	  info->key_marshal = new Marshal<VKey>();
+	  info->value_marshal = new Marshal<SGDValue>();
+	  info->sharder = new VKeySharder;
+	  info->accum = new AnhUpdateHandler;
+	  info->partition_factory = new typename SparseTable<VKey, SGDValue>::Factory;
+	  auto table=new TypedGlobalTable<VKey, SGDValue>();
+	  table->Init(info);
+	  tables[id] = table;
+}
+
+/**
+ * Coordinator assigns shards to processes.
+ * @param id table ID.
+ */
+void coordinator_assign_tables(int id) {
+
+	// wait for the servers to be up.
+	for (int i = 0; i < context->num_procs(); i++) {
+		RegisterWorkerRequest req;
+		int src = 0;
+		//  adding memory server.
+		if (context->IsTableServer(i)) {
+			VLOG(3)<< "Waiting for message from table server " << i;
+			network->Read(MPI::ANY_SOURCE, MTYPE_REGISTER_WORKER, &req, &src);
+			server_states.push_back(new ServerState(i));
+		}
+	}
+
+	VLOG(3) << " All servers registered and started up. Ready to go";
+	VLOG(3) << "num of shards " << tables[id]->num_shards() << " for table " << id;
+
+	// assign table shards to servers in round-robin fashion.
+	int server_idx = 0;
+	for (int shard = 0; shard < tables[id]->num_shards(); ++shard) {
+		ServerState &server = *server_states[server_idx];
+		VLOG(3) << "Assigning table (" << id << "," << shard << ") to server "
+				<< server_states[server_idx]->server_id;
+		server.shard_id = shard;
+		server.local_shards.insert(new TaskId(id, shard));
+		server_idx = (server_idx + 1) % server_states.size();
+	}
+	ShardAssignmentRequest req;
+	for (size_t i = 0; i < server_states.size(); ++i) {
+		ServerState &server = *server_states[i];
+		for (auto * task : server.local_shards) {
+			ShardAssignment *s = req.add_assign();
+			s->set_new_worker(server.server_id);
+			s->set_table(task->table);
+			s->set_shard(task->shard);
+			//  update local tables
+			GlobalTable *t = tables.at(task->table);
+			t->get_partition_info(task->shard)->owner = server.server_id;
+			delete task;
+		}
+	}
+
+	network->SyncBroadcast(MTYPE_SHARD_ASSIGNMENT, MTYPE_SHARD_ASSIGNMENT_DONE,
+			req);
+	VLOG(3) << "done table assignment... ";
+}
+
+
+void table_init(){
+	table_server = new TableServer();
+	table_server->StartTableServer(tables);
+	VLOG(3) << "table server started on process "<< NetworkThread::Get()->id();
+}
+
+
+/**
+ * Coordinator loads SIZE tuples into the table.
+ */
+void coordinator_load_data() {
+	auto table = static_cast<TypedGlobalTable<VKey, SGDValue>*>(tables[0]);
+	for (int i = 0; i < SIZE; i++) {
+		VKey key;
+		SGDValue x;
+		DAryProto *data = x.mutable_data();
+		DAryProto *grad = x.add_grad();
+		for (int j = 0; j < tuple_sizes[i]; j++) {
+			data->add_value(j * 1.0f);
+			grad->add_value(j * 1.0f);
+		}
+		key.set_key(i);
+		table->put(key, x);
+	}
+	VLOG(3) << "Done loading " << SIZE << " tuples ...";
+}
+
+/**
+ * Worker sends asynchronous get requests for all SIZE tuples, then collects the responses.
+ */
+void get() {
+	auto table = static_cast<TypedGlobalTable<VKey,SGDValue>*>(tables[0]);
+	SGDValue value;
+	for (int i = 0; i < SIZE; i++) {
+		VKey key;
+		key.set_key(i);
+		table->async_get(key, &value);
+	}
+	VLOG(3) << "Done sending get requests ...";
+
+	for (int i = 0; i < SIZE; i++) {
+		VKey key;
+		while (!table->async_get_collect(&key, &value))
+			Sleep(0.0001);
+	}
+}
+
+/**
+ * Worker updates tuples.
+ */
+void update() {
+	auto table = static_cast<TypedGlobalTable<VKey, SGDValue>*>(tables[0]);
+	for (int i = 0; i < SIZE; i++) {
+		VKey key;
+		key.set_key(i);
+
+		SGDValue x;
+		DAryProto *grad = x.add_grad();
+		for (int j = 0; j < tuple_sizes[i]; j++)
+			grad->add_value(j * 1.0f);
+
+		table->update(key, x);
+	}
+	VLOG(3) << "Done updating " << SIZE << " tuples ...";
+}
+
+
+void worker_test_data() {
+	//get(size);
+	update();
+	update();
+	get();
+	/*
+	update(table, tuples);
+	update(table, tuples);
+	update(table, tuples);
+	get(table, tuples);
+	*/
+}
+
+/**
+ * Shutdown the process.
+ */
+void shutdown() {
+	if (context->AmICoordinator()) {
+		EmptyMessage msg;
+		for (int i = 0; i < context->num_procs() - 1; i++)
+			network->Read(MPI::ANY_SOURCE, MTYPE_WORKER_END, &msg);
+		EmptyMessage shutdown_msg;
+		for (int i = 0; i < network->size() - 1; i++) {
+			network->Send(i, MTYPE_SHUTDOWN, shutdown_msg);
+		}
+		//network->Flush();
+		network->Shutdown();
+	} else {
+		//network->Flush();
+		network->Send(context->num_procs() - 1, MTYPE_WORKER_END,
+				EmptyMessage());
+		EmptyMessage msg;
+		network->Read(context->num_procs() - 1, MTYPE_SHUTDOWN, &msg);
+
+		if (context->AmITableServer()){
+			RequestDispatcher::Get()->PrintStats();
+			table_server->ShutdownTableServer();
+		}
+
+		network->Shutdown();
+	}
+}
+
+/**
+ * Worker handles the shard assignment from the coordinator.
+ */
+void HandleShardAssignment() {
+
+	ShardAssignmentRequest shard_req;
+	auto mpi = NetworkThread::Get();
+	mpi->Read(GlobalContext::kCoordinator, MTYPE_SHARD_ASSIGNMENT, &shard_req);
+
+	//  request read from coordinator
+	for (int i = 0; i < shard_req.assign_size(); i++) {
+		const ShardAssignment &a = shard_req.assign(i);
+		GlobalTable *t = tables.at(a.table());
+		t->get_partition_info(a.shard())->owner = a.new_worker();
+
+		//if local shard, create check-point files
+		if (FLAGS_checkpoint_enabled && t->is_local_shard(a.shard())) {
+			string checkpoint_file = StringPrintf("%s/checkpoint_%d",
+					FLAGS_checkpoint_dir.c_str(), a.shard());
+			char hostname[256];
+			gethostname(hostname, sizeof(hostname));
+
+			FILE *tmp_file = fopen(checkpoint_file.c_str(), "r");
+			if (tmp_file) { // exists -> open for reading and writing
+				fclose(tmp_file);
+				auto cp = t->checkpoint_files();
+
+				if (FLAGS_restore_mode) { //open in read mode to restore, then close
+					LogFile *file = new LogFile(checkpoint_file, "rw", 0);
+					int table_size = file->read_latest_table_size();
+					delete file;
+
+					double start = Now();
+					(*cp)[a.shard()] = new LogFile(checkpoint_file, "r",
+							a.shard());
+					t->Restore(a.shard());
+					delete (*cp)[a.shard()];
+					double end = Now();
+					LOG(ERROR) << "restore time\t" << end - start << "\tfor\t"
+							<< table_size << "\tthreshold\t" << FLAGS_threshold;
+				}
+				(*cp)[a.shard()] = new LogFile(checkpoint_file, "a", a.shard());
+			} else { // does not exist -> open for writing the first time
+				auto cp = t->checkpoint_files();
+				(*cp)[a.shard()] = new LogFile(checkpoint_file, "w", a.shard());
+			}
+		}
+	}
+
+	EmptyMessage empty;
+	mpi->Send(GlobalContext::kCoordinator, MTYPE_SHARD_ASSIGNMENT_DONE, empty);
+	VLOG(3) << "Done handling shard assignment ...";
+
+}
+
+
+int main(int argc, char **argv) {
+	FLAGS_logtostderr = 1;
+	int provided;
+	MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &provided);
+	google::InitGoogleLogging(argv[0]);
+	gflags::ParseCommandLineFlags(&argc, &argv, true);
+
+	context = GlobalContext::Get(FLAGS_system_conf);
+	network = NetworkThread::Get();
+
+	ModelProto model;
+	ReadProtoFromTextFile(FLAGS_model_conf.c_str(), &model);
+
+	create_mem_table(0, context->num_table_servers());
+
+	if (context->AmICoordinator()) {
+		coordinator_assign_tables(0);
+		coordinator_load_data();
+		network->barrier();
+	} else {
+		if (context->AmITableServer()) {
+			table_init();
+			HandleShardAssignment();
+			network->barrier();
+		} else {
+			HandleShardAssignment();
+			network->barrier();
+			Sleep(1);
+			VLOG(3) << "Worker cleared the barrier ...";
+			worker_test_data();
+		}
+	}
+
+	shutdown();
+	return 0;
+}
+
+

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/dist_test/test_tuple.cc
----------------------------------------------------------------------
diff --git a/src/test/dist_test/test_tuple.cc b/src/test/dist_test/test_tuple.cc
new file mode 100644
index 0000000..727f8e3
--- /dev/null
+++ b/src/test/dist_test/test_tuple.cc
@@ -0,0 +1,258 @@
+#include <cstdio>
+#include <iostream>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "server.h"
+#include "proto/worker.pb.h"
+#include "utils/network_service.h"
+#include "core/common.h"
+#include "core/network_queue.h"
+#include "proto/model.pb.h"
+#include "proto/common.pb.h"
+#include "utils/global_context.h"
+
+/**
+ * @file test_tuple.cc
+ *
+ * Test performance of TableServer put/get/update operations.
+ */
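+// Each request below is a RequestBase carrying the operation as a protobuf extension
+// (PutRequest/UpdateRequest/GetRequest): put/update attach a TableData tuple of TKey
+// and TVal, while get carries only a TKey. The request is routed to
+// shard = tid % num_servers and sent via NetworkService::Get()->Send(shard, MTYPE_REQUEST, request).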
+DECLARE_double(sleep_time);
+
+using namespace lapis;
+using namespace std;
+using std::vector;
+
+#define NKEYS 1000
+#define TUPLE_SIZE 50000000
+
+#ifndef FLAGS_v
+  DEFINE_int32(v, 3, "vlog controller");
+#endif
+
+
+#define SIZE 16
+#define THRESHOLD 500000
+int tuple_sizes[SIZE] = {37448736, 16777216, 4096000, 1327104, 884736, 884736, 614400,14112,4096,4096,1000,384,384,256,256,96};
+vector<int> valsizes;
+int collect_size;
+int num_tuples;
+
+void Put(int tid, int size, int version) {
+	RequestBase request;
+	request.set_table(0);
+	request.set_source(NetworkService::Get()->id());
+	PutRequest *put_req = request.MutableExtension(PutRequest::name);
+	int shard = tid % GlobalContext::Get()->num_servers();
+	put_req->set_shard(shard);
+	TableData *tuple = put_req->mutable_data();
+
+	TKey* key = tuple->mutable_key();
+	TVal* val = tuple->mutable_value();
+
+	key->set_id(tid);
+	key->set_version(version);
+
+	DAryProto *data = val->mutable_data();
+	for (int i = 0; i < size; i++){
+		data->add_value(0.0f);
+	}
+
+	// TODO check the msg type
+	NetworkService::Get()->Send(shard, MTYPE_REQUEST, request);
+}
+
+void Update(int tid, int size, int version) {
+	RequestBase request;
+	request.set_table(0);
+	request.set_source(NetworkService::Get()->id());
+	UpdateRequest *update_req = request.MutableExtension(UpdateRequest::name);
+	int shard = tid % GlobalContext::Get()->num_servers();
+	update_req->set_shard(shard);
+	TableData *tuple = update_req->mutable_data();
+
+	TKey* key = tuple->mutable_key();
+	TVal* val = tuple->mutable_value();
+
+	key->set_id(tid);
+	key->set_version(version);
+
+	DAryProto *data = val->mutable_grad();
+	for (int i = 0; i < size; i++)
+		data->add_value(1.0f);
+	// TODO check the msg type
+	NetworkService::Get()->Send(shard, MTYPE_REQUEST, request);
+}
+
+void print_result(TableData *data){
+	TKey *key = data->mutable_key();
+	TVal *val = data->mutable_value();
+	int k = key->id();
+	VLOG(3) << "key = " << k;
+	string s;
+	for (int i=0; i<TUPLE_SIZE; i++)
+		s.append(to_string(val->mutable_data()->value(i))).append(" ");
+	VLOG(3) << "val = " <<s;
+}
+
+void AsyncGet(int tid, int version) {
+	RequestBase request;
+	request.set_table(0);
+	request.set_source(GlobalContext::Get()->rank()); //NetworkService::Get()->id());
+	GetRequest *get_req = request.MutableExtension(GetRequest::name);
+	int shard = tid % GlobalContext::Get()->num_servers();
+	get_req->set_shard(shard);
+
+	TKey *key = get_req->mutable_key();
+	key->set_id(tid);
+	key->set_version(version);
+	NetworkService::Get()->Send(shard, MTYPE_REQUEST, request);
+
+}
+
+void Collect(){
+	int count = collect_size;
+	double start_collect = Now();
+	while (count) {
+		while (true) {
+			Message *resp = NetworkService::Get()->Receive();
+			if (!resp)
+				Sleep(FLAGS_sleep_time);
+			else {
+				delete resp;
+				break;
+			}
+		}
+		count--;
+	}
+	double end_collect = Now();
+	VLOG(3) << "Collected " << collect_size << " tuples in " << (end_collect-start_collect);
+}
+
+/**
+ * Workers wait at the barrier, then one of them sends a SHUTDOWN message
+ * to all table servers.
+ */
+void worker_send_shutdown(int id){
+	auto gc = lapis::GlobalContext::Get();
+	NetworkService *network_service_ = NetworkService::Get().get();
+	MPI_Barrier(gc->workergroup_comm());
+	if (gc->rank()==id){
+		for (int i=0; i<gc->num_procs(); i++){
+			if (gc->IsTableServer(i)){
+				EmptyMessage msg;
+				network_service_->Send(i, MTYPE_SHUTDOWN,msg);
+			}
+		}
+	}
+}
+
+/**
+ * The worker whose rank equals id puts the data; the others wait at the barrier.
+ */
+void worker_load_data(int id){
+	auto gc = lapis::GlobalContext::Get();
+	for (int i = 0; i < SIZE; i++) {
+		int m = tuple_sizes[i];
+		if (m < THRESHOLD)
+			valsizes.push_back(m);
+		else {
+			for (int j = 0; j < m / THRESHOLD; j++)
+				valsizes.push_back(THRESHOLD);
+			if (m % THRESHOLD)
+				valsizes.push_back(m%THRESHOLD);
+		}
+	}
+	num_tuples = (int)valsizes.size();
+	collect_size = 0;
+	for (int i=0; i<num_tuples; i++)
+		if (i%gc->group_size()==gc->worker_id())
+			collect_size++;
+
+	if (gc->rank()==id){
+		for (size_t i=0; i<valsizes.size(); i++)
+			Put(i,valsizes[i],0);
+		VLOG(3) << "Done loading data, num_keys = "<<valsizes.size() << " process " << id;
+	}
+	VLOG(3) << "Collect size = " << collect_size;
+	MPI_Barrier(gc->workergroup_comm());
+}
+
+void worker_update_data() {
+	auto gc = lapis::GlobalContext::Get();
+	for (int i = 0; i < num_tuples; i++)
+		if (i%gc->group_size()==gc->worker_id())
+			Update(i,valsizes[i],0);
+
+	VLOG(3) << "Done update ... for "<<collect_size << " tuples ";
+}
+
+/*
+ * Async get.
+ */
+void worker_get_data(){
+	auto gc = lapis::GlobalContext::Get();
+	for (int i=0; i<num_tuples; i++)
+		if (i%gc->group_size()==gc->worker_id())
+			AsyncGet(i,0);
+	Collect();
+	VLOG(3) << "Done collect ...";
+}
+
+void start_network_service_for_worker(){
+	NetworkService *network_service_ = NetworkService::Get().get();
+	network_service_->Init(GlobalContext::Get()->rank(), Network::Get().get(), new SimpleQueue());
+	network_service_->StartNetworkService();
+}
+
+int main(int argc, char **argv) {
+	google::InitGoogleLogging(argv[0]);
+	gflags::ParseCommandLineFlags(&argc, &argv, true);
+
+	int provided;
+	MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &provided);
+	FLAGS_logtostderr = 1;
+
+	// Init GlobalContext
+	Cluster cluster;
+	cluster.set_server_start(0);
+	cluster.set_server_end(8);
+	cluster.set_worker_start(8);
+	cluster.set_worker_end(24);
+	cluster.set_group_size(8);
+	cluster.set_data_folder("/data1/wangwei/lapis");
+
+	auto gc = lapis::GlobalContext::Get(cluster);
+
+	// worker or table server
+	if (gc->AmITableServer()) {
+		lapis::TableServer server;
+		SGDProto sgd;
+		sgd.set_learning_rate(0.01);
+		sgd.set_momentum(0.9);
+		sgd.set_weight_decay(0.1);
+		sgd.set_gamma(0.5);
+		sgd.set_learning_rate_change_steps(1);
+		server.Start(sgd);
+	} else {
+		start_network_service_for_worker();
+		worker_load_data(cluster.worker_start());
+		for (int i=0; i<10; i++){
+			worker_update_data();
+			worker_get_data();
+		}
+		worker_send_shutdown(cluster.worker_start());
+		NetworkService::Get()->Shutdown();
+	}
+	gc->Finalize();
+	MPI_Finalize();
+	VLOG(3) << "End, process "<< gc->rank();
+	return 0;
+}
+

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/model/test_blob.cc
----------------------------------------------------------------------
diff --git a/src/test/model/test_blob.cc b/src/test/model/test_blob.cc
new file mode 100644
index 0000000..75f1921
--- /dev/null
+++ b/src/test/model/test_blob.cc
@@ -0,0 +1,58 @@
+// Copyright © 2014 Wei Wang. All Rights Reserved.
+// 2014-07-18 19:44
+#include <gtest/gtest.h>
+#include "proto/model.pb.h"
+#include "model/lapis.h"
+
+namespace lapis {
+class BlobTest : public ::testing::Test {
+ public:
+  BlobTest() : blob1(new Blob()), blob2(new Blob()) {}
+  ~BlobTest() {
+    delete blob1;
+    delete blob2;
+  }
+ protected:
+  Blob *blob1, *blob2;
+  Blob blob3, blob4;
+};
+
+TEST_F(BlobTest, Constructor) {
+  EXPECT_EQ(blob1->length(), 0);
+  EXPECT_EQ(blob1->width(), 0);
+  EXPECT_EQ(blob1->height(), 0);
+  EXPECT_EQ(blob3.length(), 0);
+  EXPECT_EQ(blob3.width(), 0);
+  EXPECT_EQ(blob3.height(), 0);
+  EXPECT_TRUE(blob2->dptr == nullptr);
+  EXPECT_TRUE(blob4.dptr == nullptr);
+}
+
+TEST_F(BlobTest, TestResize) {
+  blob1->Resize(10,1,1,1);
+  EXPECT_EQ(blob1->length(), 10);
+  EXPECT_EQ(blob1->num(), 10);
+  EXPECT_EQ(blob1->height(), 1);
+  EXPECT_EQ(blob1->width(), 1);
+  EXPECT_TRUE(blob1->dptr != nullptr);
+  blob2->Resize(4,1,1,3);
+  EXPECT_EQ(blob2->length(), 12);
+  EXPECT_EQ(blob2->num(), 4);
+  EXPECT_EQ(blob2->height(), 1);
+  EXPECT_EQ(blob2->width(), 3);
+  EXPECT_TRUE(blob2->dptr != nullptr);
+  blob3.Resize(5,1,4,3);
+  EXPECT_EQ(blob3.length(), 60);
+  EXPECT_EQ(blob3.num(), 5);
+  EXPECT_EQ(blob3.height(), 4);
+  EXPECT_EQ(blob3.width(), 3);
+  EXPECT_TRUE(blob3.dptr != nullptr);
+  blob4.Resize(6,5,4,3);
+  EXPECT_EQ(blob4.length(), 360);
+  EXPECT_EQ(blob4.num(), 6);
+  EXPECT_EQ(blob4.height(), 4);
+  EXPECT_EQ(blob4.width(), 3);
+  EXPECT_TRUE(blob4.dptr != nullptr);
+}
+
+}  // namespace lapis

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/model/test_data_layer.cc
----------------------------------------------------------------------
diff --git a/src/test/model/test_data_layer.cc b/src/test/model/test_data_layer.cc
new file mode 100644
index 0000000..49519a5
--- /dev/null
+++ b/src/test/model/test_data_layer.cc
@@ -0,0 +1,178 @@
+// Copyright © 2014 Wei Wang. All Rights Reserved.
+// 2014-08-01 16:09
+
+#include <gtest/gtest.h>
+#include <glog/logging.h>
+#include <map>
+#include <vector>
+
+#include "model/data_layer.h"
+#include "model/trainer.h"
+#include "model/sgd_trainer.h"
+#include "model/conv_edge.h"
+#include "model/relu_layer.h"
+#include "proto/model.pb.h"
+
+#include "utils/proto_helper.h"
+
+namespace lapis {
+class ModelTest : public ::testing::Test {
+ public:
+  ModelTest () {
+    ReadProtoFromTextFile("src/test/data/model.conf", &model_proto);
+  }
+ protected:
+  ModelProto model_proto;
+};
+/**********************************************************************
+ * DataLayer Test
+ **********************************************************************/
+class DataLayerTest : public ModelTest {
+ public:
+   DataLayerTest() {
+     label_layer.Init(model_proto.net().layer(0));
+     img_layer.Init(model_proto.net().layer(1));
+     Trainer::InitDataSource(model_proto.trainer().train_data(), &sources);
+     EXPECT_EQ(2, sources.size());
+     sources[0]->LoadData(nullptr);
+     sources[1]->LoadData(nullptr);
+     DLOG(INFO)<<"after init datasources";
+     label_layer.Setup(2, TrainerProto::kBackPropagation, sources);
+     DLOG(INFO)<<"after setup label layer";
+     img_layer.Setup(2, TrainerProto::kBackPropagation, sources);
+     DLOG(INFO)<<"after setup img layer";
+   }
+   ~DataLayerTest() {
+     for(auto& source: sources)
+       delete source;
+   }
+ protected:
+  DataLayer img_layer, label_layer;
+  std::vector<DataSource*> sources;
+};
+
+TEST_F(DataLayerTest, InitSetupForward) {
+  EXPECT_TRUE(label_layer.HasInput());
+  EXPECT_TRUE(img_layer.HasInput());
+  EXPECT_STREQ("DataLayer", DataLayer::kType.c_str());
+
+  EXPECT_EQ(2, label_layer.feature(nullptr).num());
+  EXPECT_EQ(1, label_layer.feature(nullptr).channels());
+  EXPECT_EQ(1, label_layer.feature(nullptr).height());
+  EXPECT_EQ(1, label_layer.feature(nullptr).width());
+
+  EXPECT_EQ(2, img_layer.feature(nullptr).num());
+  EXPECT_EQ(3, img_layer.feature(nullptr).channels());
+  EXPECT_EQ(227, img_layer.feature(nullptr).height());
+  EXPECT_EQ(227, img_layer.feature(nullptr).width());
+
+  img_layer.Forward();
+}
+// TODO(wangwei) test this after outgoing edges are tested
+
+/**********************************************************************
+ * ConvEdge Test
+ **********************************************************************/
+class ConvEdgeTest : public DataLayerTest {
+ public:
+  ConvEdgeTest() {
+    relu.Init(model_proto.net().layer(2));
+    DLOG(INFO)<<"init both layers";
+    layer_map["input_img"]=&img_layer;
+    layer_map["hidden1_relu"]=&relu;
+
+    edge_proto=model_proto.net().edge(0);
+    convedge.Init(edge_proto, layer_map);
+    convedge.Setup(true);
+  }
+ protected:
+  std::map<std::string, Layer*> layer_map;
+  ConvEdge convedge;
+  EdgeProto edge_proto;
+  ReLULayer relu;
+};
+
+TEST_F(ConvEdgeTest, InitSetupForward) {
+  Layer* dest=layer_map.at("hidden1_relu");
+  Blob &b=dest->feature(&convedge);
+  EXPECT_EQ(0,b.num());
+  convedge.SetupTopBlob(&b);
+  int conv_height = (227 + 2 * edge_proto.pad() - edge_proto.kernel_size())
+    / edge_proto.stride() + 1;
+  int conv_width=conv_height;
+  CHECK_EQ(2, b.num());
+  CHECK_EQ(edge_proto.num_output(), b.channels());
+  CHECK_EQ(conv_height, b.height());
+  CHECK_EQ(conv_width, b.width());
+  DLOG(INFO)<<"after shape check";
+
+  Layer* src=layer_map["input_img"];
+  convedge.Forward(src->feature(&convedge), &b, true);
+}
+
+/**********************************************************************
+ * ReLULayer Test
+ **********************************************************************/
+class ReLULayerTest : public ConvEdgeTest {
+ public:
+  ReLULayerTest() {
+    relu.Setup(2, TrainerProto::kBackPropagation, sources);
+    relu_proto=model_proto.net().layer(3);
+  }
+ protected:
+  LayerProto relu_proto;
+};
+
+TEST_F(ReLULayerTest, ForwardWithoutDropout) {
+  EXPECT_EQ(2, relu.feature(&convedge).num());
+  EXPECT_EQ(2, relu.gradient(&convedge).num());
+
+  relu.Forward();
+}
+/**********************************************************************
+ * PoolingEdge Test
+class PoolingEdgeTest : public ReLULayerTest {
+ public:
+  PoolingEdgeTest() {
+    linearlayer.Init(model.net().layer(3));
+    pooledge.Init(model.net().edge(1));
+  }
+
+ protected:
+  PoolingEdge pooledge;
+  LinearLayer linearlayer;
+}
+ **********************************************************************/
+/**********************************************************************
+ * LinearLayer Test
+ **********************************************************************/
+
+/**********************************************************************
+ * LRNEdge Test
+ **********************************************************************/
+
+/**********************************************************************
+ * InnerProductEdge Test
+ **********************************************************************/
+
+/**********************************************************************
+ * SoftmaxLayerLossEdge Test
+ **********************************************************************/
+
+
+
+
+/**********************************************************************
+ * SGDTrainer Test
+ **********************************************************************/
+class SGDTrainerTest : public ModelTest {
+ protected:
+  SGDTrainer sgd;
+};
+
+TEST_F(SGDTrainerTest, Init) {
+  sgd.Init(model_proto.trainer());
+  EXPECT_TRUE(Trainer::phase==Phase::kInit);
+}
+
+}  // namespace lapis

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/model/test_label_source.cc
----------------------------------------------------------------------
diff --git a/src/test/model/test_label_source.cc b/src/test/model/test_label_source.cc
new file mode 100644
index 0000000..9b25c2a
--- /dev/null
+++ b/src/test/model/test_label_source.cc
@@ -0,0 +1,59 @@
+// Copyright © 2014 Wei Wang. All Rights Reserved.
+// 2014-07-21 19:40
+
+#include <gtest/gtest.h>
+#include <glog/logging.h>
+#include "proto/model.pb.h"
+#include "disk/label_source.h"
+
+namespace lapis {
+class LabelSourceTest : public ::testing::Test {
+ public:
+  LabelSourceTest() {
+    DataSourceProto ds;
+    ds.set_path("src/test/data/label_source.dat");
+    ds.set_size(12);
+    ds.set_name("label source");
+    ls.Init(ds);
+  }
+
+ protected:
+  LabelSource ls;
+};
+
+TEST_F(LabelSourceTest, LoadData) {
+  auto ptr2names = ls.LoadData(nullptr);
+  EXPECT_EQ(12, ptr2names->size());
+  EXPECT_STREQ("img0.JPEG", ptr2names->at(0).c_str());
+  EXPECT_STREQ("img1.JPEG", ptr2names->at(1).c_str());
+  EXPECT_STREQ("img5.JPEG", ptr2names->at(5).c_str());
+  EXPECT_STREQ("img10.JPEG", ptr2names->at(10).c_str());
+  EXPECT_STREQ("img11.JPEG", ptr2names->at(11).c_str());
+}
+
+TEST_F(LabelSourceTest, GetData) {
+  ls.LoadData(nullptr);
+  Blob b;
+  b.Resize(1, 1, 1, 5);
+  ls.GetData(&b);
+  const float *val = b.dptr;
+  EXPECT_EQ(0.0f, val[0]);
+  EXPECT_EQ(1.0f, val[1]);
+  EXPECT_EQ(4.0f, val[2]);
+  EXPECT_EQ(9.0f, val[3]);
+  EXPECT_EQ(16.0f, val[4]);
+  ls.GetData(&b);
+  EXPECT_EQ(4.0f, val[0]);
+  EXPECT_EQ(5.0f, val[1]);
+  EXPECT_EQ(6.0f, val[2]);
+  EXPECT_EQ(7.0f, val[3]);
+  EXPECT_EQ(8.0f, val[4]);
+  ls.GetData(&b);
+  EXPECT_EQ(1.0f, val[0]);
+  EXPECT_EQ(2.0f, val[1]);
+  EXPECT_EQ(0.0f, val[2]);
+  EXPECT_EQ(1.0f, val[3]);
+  EXPECT_EQ(4.0f, val[4]);
+}
+
+}  // namespace lapis

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/model/test_param.cc
----------------------------------------------------------------------
diff --git a/src/test/model/test_param.cc b/src/test/model/test_param.cc
new file mode 100644
index 0000000..520fbe2
--- /dev/null
+++ b/src/test/model/test_param.cc
@@ -0,0 +1,138 @@
+#include <gtest/gtest.h>
+#include <glog/logging.h>
+#include "proto/model.pb.h"
+
+#include "utils/param.h"
+
+using namespace singa;
+
+class ParamTest : public ::testing::Test {
+ public:
+  ParamTest() {
+    wp.set_name("weight");
+    wp.add_shape(3);
+    wp.add_shape(4);
+    bp.set_name("bias");
+    bp.add_shape(4);
+  }
+ protected:
+  Param w, b;
+  ParamProto wp, bp;
+};
+
+TEST_F(ParamTest, ConstantInit) {
+  bp.set_init_method(ParamProto::kConstant);
+  bp.set_value(0.5);
+  b.Init(bp);
+  const float *val = b.content().dptr;
+  EXPECT_EQ(0.5f, val[0]);
+  EXPECT_EQ(0.5f, val[1]);
+  EXPECT_EQ(0.5f, val[2]);
+  EXPECT_EQ(0.5f, val[3]);
+  wp.set_init_method(ParamProto::kConstant);
+  wp.set_value(1.5);
+  w.Init(wp);
+  val = w.content().dptr;
+  EXPECT_EQ(1.5f, val[0]);
+  EXPECT_EQ(1.5f, val[3]);
+  EXPECT_EQ(1.5f, val[4]);
+  EXPECT_EQ(1.5f, val[11]);
+}
+
+TEST_F(ParamTest, UniformInit) {
+  bp.set_init_method(ParamProto::kUniform);
+  bp.set_value(1.0f);
+  b.Init(bp);
+  const float *val = b.content().dptr;
+  EXPECT_TRUE(val[0] >= -1 && val[0] <= 1);
+  EXPECT_TRUE(val[1] >= -1 && val[1] <= 1);
+  EXPECT_TRUE(val[2] >= -1 && val[2] <= 1);
+  EXPECT_TRUE(val[3] >= -1 && val[3] <= 1);
+  wp.set_init_method(ParamProto::kUniform);
+  wp.set_value(1.0f);
+  w.Init(wp);
+  val = w.content().dptr;
+  EXPECT_TRUE(val[0] >= -1 && val[0] <= 1);
+  EXPECT_TRUE(val[3] >= -1 && val[3] <= 1);
+  EXPECT_TRUE(val[4] >= -1 && val[4] <= 1);
+  EXPECT_TRUE(val[11] >= -1 && val[11] <= 1);
+}
+
+TEST_F(ParamTest, UniformSqrtFanInInit) {
+  wp.set_init_method(ParamProto::kUniformSqrtFanIn);
+  wp.set_value(2.0f);
+  w.Init(wp);
+  const float *val = w.content().dptr;
+  EXPECT_TRUE(val[0] >= -2 && val[0] <= 2);
+  EXPECT_TRUE(val[3] >= -2 && val[3] <= 2);
+  EXPECT_TRUE(val[4] >= -2 && val[4] <= 2);
+  EXPECT_TRUE(val[11] >= -2 && val[11] <= 2);
+}
+
+
+TEST_F(ParamTest, UniformSqrtFanInOutInit) {
+  wp.set_init_method(ParamProto::kUniformSqrtFanInOut);
+  wp.set_value(1.0f);
+  float low=1.0f, high=5.0f;
+  wp.set_low(low);
+  wp.set_high(high);
+  w.Init(wp);
+  const float *val = w.content().dptr;
+  /*
+  LOG(INFO) << val[0] << " " << val[1] << " " << val[2] << " " << val[3];
+  LOG(INFO) << val[4] << " " << val[5] << " " << val[6] << " " << val[7];
+  LOG(INFO) << val[8] << " " << val[9] << " " << val[10] << " " << val[11];
+  */
+  float factor = wp.value() / sqrt(wp.shape(0) + wp.shape(1));
+  low=low*factor;
+  high=high*factor;
+  LOG(INFO)<<low<<" "<<high;
+  EXPECT_TRUE(val[0] >= low && val[0] <= high);
+  EXPECT_TRUE(val[3] >= low && val[3] <= high);
+  EXPECT_TRUE(val[4] >= low && val[4] <= high);
+  EXPECT_TRUE(val[11] >= low && val[11] <= high);
+}
+
+TEST_F(ParamTest, GaussianInit) {
+  int len=5000; float mean=0.0f, std=1.0f;
+  ParamProto p;
+  p.set_name("bias");
+  p.add_shape(1);
+  p.add_shape(len);
+  p.set_init_method(ParamProto::kGaussain);
+  p.set_value(1.0f);
+  p.set_mean(mean);
+  p.set_std(std);
+  w.Init(p);
+
+  const float *val = w.content().dptr;
+  float dmean=0.0f;
+  for(int i=0;i<len;i++)
+    dmean+=val[i];
+  dmean/=len;
+  float dstd=0.0f;
+  for(int i=0;i<len;i++)
+    dstd+=(dmean-val[i])*(dmean-val[i]);
+  dstd=sqrt(dstd/len);  // sample standard deviation, not variance
+  EXPECT_TRUE(std::abs(mean-dmean)<0.1);
+  EXPECT_TRUE(std::abs(std-dstd)<0.1);
+  /*
+  LOG(INFO) << val[0] << " " << val[1] << " " << val[2] << " " << val[3];
+  LOG(INFO) << val[4] << " " << val[5] << " " << val[6] << " " << val[7];
+  LOG(INFO) << val[8] << " " << val[9] << " " << val[10] << " " << val[11];
+  */
+}
+
+TEST_F(ParamTest, GaussianSqrtFanInInit) {
+  wp.set_init_method(ParamProto::kGaussainSqrtFanIn);
+  wp.set_value(1.0f);
+  wp.set_mean(0);
+  wp.set_std(1.0f);
+  w.Init(wp);
+  //const float *val = w.content().dptr;
+  /*
+  LOG(INFO) << val[0] << " " << val[1] << " " << val[2] << " " << val[3];
+  LOG(INFO) << val[4] << " " << val[5] << " " << val[6] << " " << val[7];
+  LOG(INFO) << val[8] << " " << val[9] << " " << val[10] << " " << val[11];
+  */
+}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/model/test_proto.cc
----------------------------------------------------------------------
diff --git a/src/test/model/test_proto.cc b/src/test/model/test_proto.cc
new file mode 100644
index 0000000..f6d81fd
--- /dev/null
+++ b/src/test/model/test_proto.cc
@@ -0,0 +1,67 @@
+// Copyright © 2014 Wei Wang. All Rights Reserved.
+// 2014-07-15 21:54
+#include <glog/logging.h>
+#include <gtest/gtest.h>
+#include "proto/model.pb.h"
+#include "utils/proto_helper.h"
+namespace lapis {
+
+// use const Message& m=..., otherwise may lead to segment fault
+TEST(ProtoTest, ReadFromFile) {
+  ModelProto model;
+  LOG(INFO)<<"start....";
+  lapis::ReadProtoFromTextFile("src/test/data/model.conf", &model);
+  LOG(INFO)<<"after reading file...";
+  EXPECT_STREQ("caffe_config", model.name().c_str());
+
+  // layer and edge size
+  const NetProto& net = model.net();
+  EXPECT_EQ(15, net.layer().size());
+  EXPECT_EQ(14, net.edge().size());
+  LOG(INFO)<<"after size check...";
+
+  // layer config
+  LayerProto layer1 = net.layer().Get(1);
+  EXPECT_STREQ("input_img", layer1.name().c_str());
+  EXPECT_STREQ("DataLayer", layer1.type().c_str());
+  LOG(INFO)<<"after datalayer check...";
+  // edge config
+  EdgeProto edge0 = net.edge().Get(0);
+  EXPECT_STREQ("input_img-hidden1_relu", edge0.name().c_str());
+  EXPECT_STREQ("ConvEdge", edge0.type().c_str());
+  EXPECT_EQ(2, edge0.param().size());
+  LOG(INFO)<<"after first edge check...";
+  // param config
+  ParamProto param1 = edge0.param().Get(0);
+  EXPECT_TRUE(ParamProto::kGaussain == param1.init_method());
+  EXPECT_EQ(0.0f, param1.mean());
+  EXPECT_EQ(0.01f, param1.std());
+  EXPECT_EQ(1.0f, param1.learning_rate_multiplier());
+  LOG(INFO)<<"after param of first edge check...";
+
+  ParamProto param2 = edge0.param().Get(1);
+  EXPECT_TRUE(ParamProto::kConstant == param2.init_method());
+  EXPECT_EQ(0.0f, param2.value());
+  EXPECT_EQ(0.0f, param2.weight_decay_multiplier());
+  LOG(INFO)<<"after param of second edge check...";
+
+  // trainer config
+  const TrainerProto& trainer = model.trainer();
+  const SGDProto& sgd=trainer.sgd();
+  EXPECT_EQ(227, sgd.train_batchsize());
+  EXPECT_EQ(0.01f, sgd.base_learning_rate());
+  EXPECT_TRUE(SGDProto::kStep== sgd.learning_rate_change());
+  LOG(INFO)<<"after sgd check...";
+
+  // data source config
+  EXPECT_EQ(2,trainer.train_data().size());
+  LOG(INFO)<<"after size check...";
+  const DataSourceProto& data=trainer.train_data(0);
+  LOG(INFO)<<"after get data...";
+  EXPECT_STREQ("RGBDirSource", data.type().c_str());
+  LOG(INFO)<<"after type check...";
+  EXPECT_EQ(50000, data.size());
+  EXPECT_EQ(3, data.channels());
+  LOG(INFO)<<"after data source check...";
+}
+} // namespace lapis

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/model/test_rgb_dir_source.cc
----------------------------------------------------------------------
diff --git a/src/test/model/test_rgb_dir_source.cc b/src/test/model/test_rgb_dir_source.cc
new file mode 100644
index 0000000..36ac21a
--- /dev/null
+++ b/src/test/model/test_rgb_dir_source.cc
@@ -0,0 +1,63 @@
+// Copyright © 2014 Wei Wang. All Rights Reserved.
+// 2014-07-21 21:52
+
+#include <gtest/gtest.h>
+#include <glog/logging.h>
+#include <algorithm>
+
+#include "proto/model.pb.h"
+#include "disk/rgb_dir_source.h"
+#include "disk/label_source.h"
+
+namespace lapis {
+class RGBDirSourceTest : public ::testing::Test {
+ public:
+  RGBDirSourceTest() {
+    DataSourceProto ds;
+    ds.set_path("src/test/data/rgb_dir");
+    ds.set_mean_file("src/test/data/imagenet_mean.binaryproto");
+    ds.set_size(3);
+    ds.set_height(256);
+    ds.set_width(256);
+    ds.set_offset(2);
+    ds.set_name("rgb dir source");
+    rgbs.Init(ds);
+  }
+
+ protected:
+  RGBDirSource rgbs;
+};
+
+TEST_F(RGBDirSourceTest, LoadDataNoInputKeys) {
+  auto &ptr2names = rgbs.LoadData(nullptr);
+  EXPECT_EQ(3, ptr2names->size());
+  sort(ptr2names->begin(), ptr2names->end());
+  EXPECT_STREQ("img0.JPEG", ptr2names->at(0).c_str());
+  EXPECT_STREQ("img1.JPEG", ptr2names->at(1).c_str());
+  EXPECT_STREQ("img2.JPEG", ptr2names->at(2).c_str());
+}
+
+TEST_F(RGBDirSourceTest, LoadDataWithInputKeys) {
+  LabelSource ls;
+  DataSourceProto ds;
+  ds.set_path("src/test/data/label_source.dat");
+  ds.set_name("label source");
+  ds.set_size(3);
+  ls.Init(ds);
+  auto ptr2names1 = ls.LoadData(nullptr);
+  auto ptr2names2 = rgbs.LoadData(ptr2names1);
+  EXPECT_EQ(3, ptr2names2->size());
+  for (int i = 0; i < 3; i++)
+    EXPECT_STREQ(ptr2names1->at(i).c_str(), ptr2names2->at(i).c_str());
+}
+
+TEST_F(RGBDirSourceTest, GetData) {
+  Blob b;
+  b.Resize(256,256,3,2);
+  rgbs.LoadData(nullptr);
+  rgbs.GetData(&b);
+  rgbs.GetData(&b);
+  rgbs.GetData(&b);
+}
+}  // namespace lapis
+

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/test_cluster.cc
----------------------------------------------------------------------
diff --git a/src/test/test_cluster.cc b/src/test/test_cluster.cc
new file mode 100644
index 0000000..d86463a
--- /dev/null
+++ b/src/test/test_cluster.cc
@@ -0,0 +1,95 @@
+#include <fstream>
+#include "gtest/gtest.h"
+#include "proto/cluster.pb.h"
+#include "utils/cluster.h"
+
+using namespace singa;
+
+string folder="src/test/data/";
+/*
+ClusterProto GenClusterProto(){
+  ClusterProto proto;
+  int nworker=6, nserver=4;
+  proto.set_nworkers(nworker);
+  proto.set_nservers(nserver);
+  proto.set_nworkers_per_group(3);
+  proto.set_nservers_per_group(2);
+  proto.set_nthreads_per_worker(1);
+  proto.set_nthreads_per_server(2);
+
+  proto.set_hostfile(folder+"/hostfile");
+
+  std::ofstream fout(folder+"/hostfile", std::ofstream::out);
+  for(int i=0;i<nworker+nserver;i++){
+    char tmp[20];
+    sprintf(tmp, "awan-0-%02d-0", i);
+    fout<<tmp<<std::endl;
+  }
+  fout.flush();
+  fout.close();
+  return proto;
+}
+
+TEST(ClusterTest, NoServer){
+  ClusterProto proto=GenClusterProto();
+  proto.set_nservers(0);
+  auto cluster=Cluster::Get(proto, 0);
+  ASSERT_EQ(proto.nworkers(),cluster->nworkers());
+  ASSERT_EQ(0, cluster->nservers());
+  ASSERT_EQ(proto.nworkers_per_group(),cluster->nworkers_per_group());
+  ASSERT_EQ(proto.nservers_per_group(),cluster->nservers_per_group());
+  ASSERT_FALSE(cluster->AmIServer());
+  ASSERT_TRUE(cluster->AmIWorker());
+  ASSERT_EQ(0,cluster->group_procs_id());
+  ASSERT_EQ(0,cluster->group_id());
+  ASSERT_EQ(2, cluster->nworker_groups());
+  ASSERT_EQ(0, cluster->nserver_groups());
+  ASSERT_STREQ("awan-0-00-0", cluster->host_addr().c_str());
+
+  cluster=Cluster::Get(proto, 5);
+  ASSERT_EQ(2,cluster->group_procs_id());
+  ASSERT_EQ(1,cluster->group_id());
+  ASSERT_EQ(2, cluster->nworker_groups());
+  ASSERT_EQ(0, cluster->nserver_groups());
+  ASSERT_STREQ("awan-0-05-0", cluster->host_addr().c_str());
+}
+
+TEST(ClusterTest, SingleServerGroup){
+  ClusterProto proto=GenClusterProto();
+  proto.set_nservers(2);
+  auto cluster=Cluster::Get(proto, 3);
+  ASSERT_FALSE(cluster->AmIServer());
+  ASSERT_TRUE(cluster->AmIWorker());
+  ASSERT_EQ(0,cluster->group_procs_id());
+  ASSERT_EQ(1,cluster->group_id());
+  ASSERT_EQ(2, cluster->nworker_groups());
+  ASSERT_EQ(1, cluster->nserver_groups());
+  ASSERT_STREQ("awan-0-03-0", cluster->host_addr().c_str());
+
+  cluster=Cluster::Get(proto, 7);
+  ASSERT_EQ(1,cluster->group_procs_id());
+  ASSERT_EQ(0,cluster->group_id());
+  ASSERT_EQ(2, cluster->nworker_groups());
+  ASSERT_EQ(1, cluster->nserver_groups());
+  ASSERT_STREQ("awan-0-07-0", cluster->host_addr().c_str());
+}
+
+TEST(ClusterTest, MultiServerGroups){
+  ClusterProto proto=GenClusterProto();
+  auto cluster=Cluster::Get(proto, 7);
+  ASSERT_EQ(1,cluster->group_procs_id());
+  ASSERT_EQ(0,cluster->group_id());
+  ASSERT_EQ(2, cluster->nworker_groups());
+  ASSERT_EQ(2, cluster->nserver_groups());
+  ASSERT_STREQ("awan-0-07-0", cluster->host_addr().c_str());
+
+  cluster=Cluster::Get(proto, 8);
+  ASSERT_TRUE(cluster->AmIServer());
+  ASSERT_FALSE(cluster->AmIWorker());
+  ASSERT_EQ(0,cluster->group_procs_id());
+  ASSERT_EQ(1,cluster->group_id());
+  ASSERT_EQ(2, cluster->nworker_groups());
+  ASSERT_EQ(2, cluster->nserver_groups());
+  ASSERT_STREQ("awan-0-08-0", cluster->host_addr().c_str());
+}
+*/

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/test_communication.cc
----------------------------------------------------------------------
diff --git a/src/test/test_communication.cc b/src/test/test_communication.cc
new file mode 100644
index 0000000..c9c035f
--- /dev/null
+++ b/src/test/test_communication.cc
@@ -0,0 +1,158 @@
+#include <thread>
+#include <vector>
+#include "gtest/gtest.h"
+#include "communication/msg.h"
+#include "communication/socket.h"
+using std::vector;
+using namespace singa;
+
+const char* ping="PING",*pong="PONG";
+/**
+ * Connect dealer with (gid, id, flag) to stub router
+ */
+void Connect(Dealer* dealer, int gid, int id, int flag){
+  dealer->Connect("inproc://router");
+  Msg msg;
+  msg.set_src(gid, id, flag);
+  msg.set_dst(0,0,2);
+  msg.set_type(0);
+  msg.add_frame(ping, 4);
+  dealer->Send(&msg);
+}
+
+/**
+ * Dealer thread, ping-pong with the stub router
+ */
+void DealerPingPong(int id){
+  Dealer* dealer=new Dealer();
+  Connect(dealer, 0, id, 0);
+  Msg* msg=dealer->Receive();
+  int flag=msg->src_flag();
+  ASSERT_EQ(2, flag);
+  ASSERT_EQ(0, msg->dst_group_id());
+  ASSERT_EQ(id, msg->dst_id());
+  ASSERT_STREQ(pong, (char*)msg->frame_data());
+  delete msg;
+  delete dealer;
+}
+
+/**
+ * Worker thread, connect to router and communicate with server thread
+ */
+void WorkerDealer(int sid, int did){
+  Dealer* dealer=new Dealer();
+  Connect(dealer, 0, sid, 0);
+  for(int i=0;i<2;i++){
+    {
+      Msg msg;
+      msg.set_src(0, sid, 0);
+      msg.set_dst(0, did, 1);
+      msg.set_type(3);
+      msg.set_target(i);
+      dealer->Send(&msg);
+    }
+    {
+      Msg *msg=dealer->Receive();
+      ASSERT_EQ(0, msg->src_group_id());
+      ASSERT_EQ(did, msg->src_id());
+      ASSERT_EQ(1, msg->src_flag());
+      delete msg;
+    }
+  }
+  delete dealer;
+}
+
+/**
+ * Server thread, connect to router and communicate with worker thread
+ */
+void ServerDealer(int id, int n){
+  Dealer* dealer=new Dealer();
+  Connect(dealer, 0, id, 1);
+  for(int i=0;i<n;i++){
+    Msg *msg=dealer->Receive();
+    Msg reply;
+    reply.set_dst(msg->src_group_id(), msg->src_id(), msg->src_flag());
+    reply.set_src(0, id, 1);
+    dealer->Send(&reply);
+    delete msg;
+  }
+  delete dealer;
+}
+
+TEST(CommunicationTest, DealerRouterPingPong){
+  int n=2;
+  vector<std::thread> threads;
+  for(int i=0;i<n;i++)
+    threads.push_back(std::thread(DealerPingPong, i));
+  Router* router=new Router();
+  router->Bind("");
+  for(int k=0;k<n;k++){
+    Msg* msg=router->Receive();
+    ASSERT_EQ(0, msg->src_group_id());
+    ASSERT_EQ(2, msg->dst_flag());
+    ASSERT_STREQ(ping, (char*)msg->frame_data());
+
+    Msg reply;
+    reply.set_src(0,0,2);
+    reply.set_dst(msg->src_group_id(), msg->src_id(), msg->src_flag());
+    reply.add_frame(pong, 4);
+    router->Send(&reply);
+    delete msg;
+  }
+
+  delete router;
+  for(auto& thread:threads)
+    thread.join();
+}
+
+TEST(CommunicationTest, nWorkers1Server){
+  int nworker=2;
+  vector<std::thread> threads;
+  for(int i=0;i<nworker;i++)
+    threads.push_back(std::thread(WorkerDealer, i, 0));
+  //threads.push_back(std::thread(ServerDealer, 0, 4));
+  Router* router=new Router();
+  router->Bind("");
+  int nmsg=4*nworker;
+  int k=0;
+  while(nmsg>0){
+    Msg* msg=router->Receive();
+    if(2== msg->dst_flag()){
+      ASSERT_STREQ(ping, (char*)msg->frame_data());
+      k++;
+      if(k==nworker)
+        threads.push_back(std::thread(ServerDealer, 0, 2*nworker));
+    }else{
+      nmsg--;
+      router->Send(msg);
+    }
+    delete msg;
+  }
+  delete router;
+  for(auto& thread:threads)
+    thread.join();
+}
+
+TEST(CommunicationTest, 2Workers2Server){
+  vector<std::thread> threads;
+  threads.push_back(std::thread(WorkerDealer, 0, 0));
+  threads.push_back(std::thread(WorkerDealer, 1, 1));
+  threads.push_back(std::thread(ServerDealer, 0, 2));
+  threads.push_back(std::thread(ServerDealer, 1, 2));
+  Router* router=new Router();
+  router->Bind("");
+  int n=8;
+  while(n>0){
+    Msg* msg=router->Receive();
+    if(2== msg->dst_flag()){
+      ASSERT_STREQ(ping, (char*)msg->frame_data());
+    }else{
+      n--;
+      router->Send(msg);
+    }
+    delete msg;
+  }
+  delete router;
+  for(auto& thread:threads)
+    thread.join();
+}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/test/test_shard.cc
----------------------------------------------------------------------
diff --git a/src/test/test_shard.cc b/src/test/test_shard.cc
new file mode 100644
index 0000000..c96d876
--- /dev/null
+++ b/src/test/test_shard.cc
@@ -0,0 +1,56 @@
+#include <gtest/gtest.h>
+#include <sys/stat.h>
+
+#include "utils/data_shard.h"
+
+std::string key[]={"firstkey","secondkey","3key", "key4", "key5"};
+std::string tuple[]={"firsttuple","2th-tuple","thridtuple", "tuple4", "tuple5"};
+
+using namespace singa;
+
+TEST(DataShardTest, CreateDataShard){
+  std::string path="src/test/data/shard_test";
+  mkdir(path.c_str(), 0755);
+  DataShard shard(path, DataShard::kCreate, 50);
+  shard.Insert(key[0], tuple[0]);
+  shard.Insert(key[1], tuple[1]);
+  shard.Insert(key[2], tuple[2]);
+  shard.Flush();
+}
+
+TEST(DataShardTest, AppendDataShard){
+  std::string path="src/test/data/shard_test";
+  DataShard shard(path, DataShard::kAppend, 50);
+  shard.Insert(key[3], tuple[3]);
+  shard.Insert(key[4], tuple[4]);
+  shard.Flush();
+}
+TEST(DataShardTest, CountDataShard){
+  std::string path="src/test/data/shard_test";
+  DataShard shard(path, DataShard::kRead, 50);
+  int count=shard.Count();
+  ASSERT_EQ(5, count);
+}
+
+TEST(DataShardTest, ReadDataShard){
+  std::string path="src/test/data/shard_test";
+  DataShard shard(path, DataShard::kRead, 50);
+  std::string k, t;
+  ASSERT_TRUE(shard.Next(&k, &t));
+  ASSERT_STREQ(key[0].c_str(), k.c_str());
+  ASSERT_STREQ(tuple[0].c_str(), t.c_str());
+  ASSERT_TRUE(shard.Next(&k, &t));
+  ASSERT_STREQ(key[1].c_str(), k.c_str());
+  ASSERT_STREQ(tuple[1].c_str(), t.c_str());
+  ASSERT_TRUE(shard.Next(&k, &t));
+  ASSERT_TRUE(shard.Next(&k, &t));
+  ASSERT_TRUE(shard.Next(&k, &t));
+  ASSERT_STREQ(key[4].c_str(), k.c_str());
+  ASSERT_STREQ(tuple[4].c_str(), t.c_str());
+
+  ASSERT_FALSE(shard.Next(&k, &t));
+  shard.SeekToFirst();
+  ASSERT_TRUE(shard.Next(&k, &t));
+  ASSERT_STREQ(key[0].c_str(), k.c_str());
+  ASSERT_STREQ(tuple[0].c_str(), t.c_str());
+}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/trainer/pm_server.cc
----------------------------------------------------------------------
diff --git a/src/trainer/pm_server.cc b/src/trainer/pm_server.cc
new file mode 100644
index 0000000..28fa28d
--- /dev/null
+++ b/src/trainer/pm_server.cc
@@ -0,0 +1,99 @@
+#include <gflags/gflags.h>
+#include <glog/logging.h>
+#include "trainer/pm_server.h"
+#include "utils/singleton.h"
+#include "utils/factory.h"
+#include <vector>
+
+using std::vector;
+
+namespace singa{
+void PMServer::Setup(int group_id, int server_id, shared_ptr<ParamShard> shard,
+      const UpdaterProto& proto){
+  group_id_=group_id;
+  server_id_=server_id;
+  shard_=shard;
+  updater_=shared_ptr<Updater>(Singleton<Factory<Updater>>::Instance()
+      ->Create("Updater"));
+  updater_->Init(proto);
+}
+
+PMServer::~PMServer(){
+}
+
+bool PMServer::SyncNow(){
+  return false;
+}
+Msg* PMServer::HandlePut(Msg **msg){
+  int id=(*msg)->target();
+  shared_ptr<Param> param=nullptr;
+  if(shard_->find(id)!=shard_->end()){
+    LOG(ERROR)<<"Param ("<<id<<") is put more than once";
+    param=shard_->at(id);
+  }else{
+    param=shared_ptr<Param>(Singleton<Factory<Param>>::Instance()
+        ->Create("Param"));
+    param->set_id(id);
+    (*shard_)[id]=param;
+  }
+  return param->HandlePutMsg(msg);
+}
+
+Msg* PMServer::HandleGet(Msg **msg){
+  int id=(*msg)->target();
+  shared_ptr<Param> param=nullptr;
+  if(shard_->find(id)!=shard_->end()){
+    param=shard_->at(id);
+    return param->HandleGetMsg(msg);
+  } else {
+    // re-construct msg to be re-queued;
+    // the calling function will send this message off.
+    return *msg;
+  }
+}
+
+Msg* PMServer::HandleUpdate(Msg **msg) {
+  int id=(*msg)->target();
+  shared_ptr<Param> param=nullptr;
+  if(shard_->find(id)!=shard_->end()){
+    // response of the format: <identity><type: kData><paramId><param content>
+    param=shard_->at(id);
+    Msg* tmp=static_cast<Msg*>((*msg)->CopyAddr());
+    param->ParseUpdateMsg(msg);
+    updater_->Update(param->version(), param);
+    param->set_version(param->version()+1);
+    auto response=param->GenUpdateResponseMsg();
+    tmp->SwapAddr();
+    response->SetAddr(tmp);
+    delete tmp;
+    return response;
+  } else {
+    LOG(ERROR)<<"Param ("<<id<<") is not maintained by server ("<<group_id_
+      <<", "<<server_id_<<")";
+    // re-construct msg to be re-queued.
+    return *msg;
+  }
+}
+
+Msg* PMServer::HandleSyncRequest(Msg **msg){
+  int id=(*msg)->target();
+  shared_ptr<Param> param=nullptr;
+  if(shard_->find(id)!=shard_->end()){
+    // response of the format: <identity><type: kData><paramId><param content>
+    param=shard_->at(id);
+    return param->HandleSyncMsg(msg);
+  } else {
+    // re-construct msg to be re-queued.
+    return *msg;
+  }
+}
+
+int PMServer::HandleSyncResponse(Msg **msg){
+  int id=(*msg)->target();
+  CHECK(shard_->find(id)!=shard_->end());
+  return shard_->at(id)->ParseSyncResponseMsg(msg);
+}
+
+} // namespace singa
+
+

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/trainer/pm_worker.cc
----------------------------------------------------------------------
diff --git a/src/trainer/pm_worker.cc b/src/trainer/pm_worker.cc
new file mode 100644
index 0000000..7269578
--- /dev/null
+++ b/src/trainer/pm_worker.cc
@@ -0,0 +1,344 @@
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <fcntl.h>
+#include "gflags/gflags.h"
+#include <glog/logging.h>
+#include "proto/model.pb.h"
+#include "trainer/pm_worker.h"
+#include "mshadow/tensor.h"
+#include "utils/cluster.h"
+
+
+namespace singa{
+
+void PMWorker::Setup(int group_id, int worker_id,
+    shared_ptr<ParamShard> shard){
+  group_id_=group_id;
+  worker_id_=worker_id;
+  shard_=shard;
+}
+int PMWorker::Sharding(int param_id){
+  return param_id%Cluster::Get()->nservers_per_group();
+}
+/*
+int PMWorker::Sharding(int param_id){
+  static map<int, int> id2procs;
+  if(id2procs.find(param_id)==id2procs.end()){
+  auto cluster=Cluster::Get();
+  int server_group=group_id_%cluster->nserver_groups();
+  int nprocs_per_server_group=
+    cluster->nservers_per_group()/cluster->nservers_per_procs();
+  int procsid=server_group*nprocs_per_server_group+
+    param_id%nprocs_per_server_group;
+  procsid= cluster->server_worker_separate()?
+    cluster->nworker_procs()+procsid:procsid;
+  id2procs[param_id]=procsid;
+  }
+  return id2procs[param_id];
+}
+*/
+
+Msg* PMWorker::Put(Msg** msg){
+  return *msg;
+}
+
+Msg* PMWorker::Put(shared_ptr<Param> param, int step){
+  param->set_version(step);
+  // only owner can put shared parameter
+  if(param->owner()<0||param->owner()==param->id()){
+    Msg* msg= param->GenPutMsg(&step);
+    msg->set_src(group_id_, worker_id_, kWorkerParam);
+    msg->set_dst(group_id_/Cluster::Get()->nworker_groups_per_server_group(),
+        Sharding(param->id()), kServer);
+    msg->set_type(kPut);
+    msg->set_target(param->id());
+    return msg;
+  }else
+    return nullptr;
+}
+
+Msg* PMWorker::Get(Msg** msg){
+  return *msg;
+}
+
+Msg* PMWorker::Get(shared_ptr<Param> param, int step){
+  param->set_version(step);
+  bool send=false;
+  int id=param->id();
+  shared_ptr<ParamCounter> entry=nullptr;
+  if(param->owner()>=0){
+    entry=shard_->at(id);
+    entry->nGet++;
+    send=entry->nGet/entry->nLocal==step;
+  }
+  if(param->owner()<0||send){
+    Msg* msg=nullptr;
+    if(param->owner()<0){
+      msg=param->GenGetMsg(&step);
+      msg->set_dst(group_id_/Cluster::Get()->nworker_groups_per_server_group(),
+          Sharding(id), kServer);
+    } else {
+      msg=entry->param->GenGetMsg(&step);
+      msg->set_dst(entry->owner_procs,kStub);
+    }
+    msg->set_src(group_id_, worker_id_, kWorkerParam);
+    msg->set_type(kGet);
+    msg->set_target(id);
+    return msg;
+  }else
+    return nullptr;
+}
+
+Msg* PMWorker::Update(Msg** msg){
+  return *msg;
+}
+Msg* PMWorker::Update(shared_ptr<Param> param, int step){
+  param->set_version(step);
+  bool send=false;
+  int id=param->id();
+  shared_ptr<ParamCounter> entry;
+  if(param->owner()>=0){
+    entry=shard_->at(param->id());
+    entry->nGet++;
+    send=entry->nGet/entry->nLocal==step;
+    auto shape=mshadow::Shape1(param->size());
+    mshadow::Tensor<mshadow::cpu,1> grad(param->mutable_cpu_grad(), shape);
+    mshadow::Tensor<mshadow::cpu,1> agg(entry->param->mutable_cpu_grad(), shape);
+    agg+=grad;
+  }
+  if(param->owner()<0||send){
+    Msg* msg=nullptr;
+    if(param->owner()<0){
+      msg=param->GenUpdateMsg(&step);
+      msg->set_dst(group_id_/Cluster::Get()->nworker_groups_per_server_group(),
+          Sharding(id), kServer);
+    } else {
+      msg=entry->param->GenUpdateMsg(&step);
+      msg->set_dst(entry->owner_procs,kStub);
+      memset(param->mutable_cpu_data(), 0, sizeof(float)*param->size());
+    }
+    msg->set_type(kUpdate);
+    msg->set_target(id);
+    msg->set_src(group_id_, worker_id_, kWorkerParam);
+    return msg;
+  }else
+    return nullptr;
+}
+
+Msg* PMWorker::Collect(Msg** msg){
+  int id=(*msg)->target();
+  int type=(*msg)->type();
+  auto pp=shard_->at(id)->param;
+  if(type==kRGet){
+    pp->ParseGetResponseMsg(msg);
+  }else if(type==kRUpdate){
+    pp->ParseUpdateResponseMsg(msg);
+  }
+  if(pp->owner()>=0){
+    // forwarding to workers on other procs
+  }
+  delete (*msg);
+  *msg=nullptr;
+  return nullptr;
+}
+
+/*
+//id is the global worker id
+SingaClient::SingaClient(int global_id, Topology &topology, vector<string> &hosts) {
+	//Read the config files and store endpoints
+	id_ = global_id;
+
+	int n_workers = hosts.size() - topology.nservers();
+	int n_worker_groups = topology.nworker_groups();
+	int group_size = n_workers/n_worker_groups;
+	int server_group_size = topology.nservers()/topology.server_group_size();
+	FLAGS_client_threads = topology.worker_threads();
+
+	local_id_ = (id_-topology.nservers())%group_size;//local worker id.
+	group_id_ = (id_-topology.nservers())/group_size;
+
+	VLOG(3) << "Parsing client config for "<<hosts[id_];
+
+	//connect to all server in the server group group_id_
+	int start_server_idx = group_id_*server_group_size;
+	int end_server_idx = start_server_idx+server_group_size;
+
+	for (int i = start_server_idx; i < end_server_idx; i++) {
+		char *neighbor_endpoint = (char*) malloc(256);
+		sprintf(neighbor_endpoint, "tcp://%s:%d", hosts[i].c_str(), topology.port());
+		neighbors_.push_back(neighbor_endpoint);
+		VLOG(3) << "Worker neighbor (server): "<<neighbor_endpoint;
+	}
+
+	sprintf(backend_endpoint_, "inproc://singanus%d",id_);
+
+	//Create shared paramshard
+	param_shard_ = new ParamShard(id_,0);
+}
+
+void SingaClient::StartClient(){
+	//Create and connect sockets to the server
+	vector<void *> server_sockets;
+	zctx_t *context = zctx_new();
+	int nservers = neighbors_.size();
+	int rc;
+	for (int i=0; i<nservers; i++){
+		void *socket = zsocket_new(context, ZMQ_DEALER);
+		rc = zsocket_connect(socket, neighbors_[i]);
+		VLOG(3) << "Connected to neighbor " <<neighbors_[i];
+		assert(rc==0);
+		server_sockets.push_back(socket);
+	}
+
+	//Create and bind backend socket
+	void *backend = zsocket_new(context, ZMQ_ROUTER);
+	rc = zsocket_bind(backend, backend_endpoint_);
+	assert(rc==0);
+
+	//Start client threads
+	for (int i=0; i<FLAGS_client_threads; i++){
+		void * socket = zthread_fork(context, ClientThread, this);
+		zmsg_t *control_msg = zmsg_new();
+		if (i==0 && local_id_==0)
+			zmsg_pushstr(control_msg,POPULATE);
+		else
+			zmsg_pushstr(control_msg, WAIT);
+		zmsg_send(&control_msg, socket);
+	}
+
+	//Start the message loop
+	bool is_running = true;
+	int nsockets= nservers+1;
+	while (is_running) {
+		zmq_pollitem_t items[nsockets];
+		for (int i = 0; i < nsockets-1; i++)
+			items[i] = {server_sockets[i], 0, ZMQ_POLLIN, 0};
+		items[nsockets-1] = {backend, 0, ZMQ_POLLIN, 0};
+
+		int rc = zmq_poll(items,nsockets,-1);
+		if (rc<0) break;
+
+		for (int i=0; i<nsockets-1; i++){
+			if (items[i].revents & ZMQ_POLLIN){
+				zmsg_t *msg = zmsg_recv(server_sockets[i]);
+				if (!msg){
+					is_running = false;
+					break;
+				}
+				//forward to backend
+				zmsg_send(&msg, backend);
+			}
+		}
+		if (items[nsockets-1].revents & ZMQ_POLLIN){
+			//compute serverId from paramId and forward to the socket
+			zmsg_t *msg = zmsg_recv(backend);
+			if (!msg) is_running=false;
+			zframe_t *identity = zmsg_pop(msg);
+			zframe_t *type = zmsg_pop(msg);
+			int paramId;
+			sscanf(zmsg_popstr(msg), "%d", &paramId);
+			zmsg_pushstrf(msg,"%d",paramId);
+			zmsg_prepend(msg,&type);
+			zmsg_prepend(msg,&identity);
+			zmsg_send(&msg, server_sockets[param_to_server_id(paramId)]);
+		}
+	}
+
+	zsocket_destroy(context, backend);
+	for (int i=0; i<nsockets-1; i++)
+		zsocket_destroy(context, server_sockets[i]);
+	zctx_destroy(&context);
+}
+
+vector<Param*> gen_random_params() {
+	int size[] = { 1960000, 2500, 5000000, 2000, 3000000, 1500, 1500000, 1000, 500000, 500, 5000, 10 };
+	vector<Param*> params;
+	for (int i = 0; i < 12; i++) {
+		ParamProto proto;
+		proto.set_id(i);
+		proto.set_init_method(ParamProto::kGaussain);
+		Param* p = new Param();
+		p->Setup(proto, vector<int> { size[i] }, 0);
+		p->Init();
+		params.push_back(p);
+	}
+	return params;
+}
+
+//simple mapping
+int SingaClient::param_to_server_id(int paramId){
+	return paramId % neighbors_.size();
+}
+
+void ClientThread(void *args, zctx_t *ctx, void *pipe){
+	SingaClient *client = static_cast<SingaClient*>(args);
+
+	//Create back-end socket and connect to the main thread
+	void *backend = zsocket_new(ctx, ZMQ_DEALER);
+	int rc = zsocket_connect(backend, client->backend_endpoint());
+	assert(rc==0);
+	//Create PMClient object
+	PMClient *pmclient = new PMClient(client->id(), client->param_shard(), backend);
+
+	//FOR TESTING ONLY. REMOVE THIS!
+	//wait for control from main thread
+	vector<Param*> params = gen_random_params();
+	zmsg_t *control_msg = zmsg_recv(pipe);
+	zframe_t *msg = zmsg_pop(control_msg);
+	if (zframe_streq(msg,WAIT))
+		zclock_sleep(2000); //2s
+	else{
+		for (int i=0; i<params.size(); i++){
+			pmclient->Put(i, params[i]);
+		}
+		VLOG(3)<<"Done PUT requests for populating servers.";
+		zclock_sleep(2000);
+	}
+	zframe_destroy(&msg);
+	//END TESTING
+	LOG(ERROR) << "Done putting";
+
+	//first, get the params
+
+	test_get(pmclient);
+	test_collect(pmclient);
+
+
+	int iterations = 1;
+	while (iterations<=200){
+		VLOG(3) << "Iteration "<<iterations;
+		test_update(pmclient, params);
+		test_collect(pmclient);
+		iterations++;
+	}
+
+	zsocket_destroy(ctx, backend);
+}
+
+void test_get(PMClient *client){
+	for (int i=0; i<12; i++){
+		Param pm;
+		int status = client->Get(i, &pm);
+		assert(status==NON_LOCAL);
+	}
+}
+
+void test_collect(PMClient *client){
+	for (int i=0; i<12; i++){
+		Param pm;
+		int64_t start_time = zclock_time();
+		while (!client->Collect(&pm))
+			zclock_sleep(1);
+		int64_t end_time = zclock_time();
+		VLOG(3) << "Collected: " <<(end_time-start_time);
+	}
+}
+
+void test_update(PMClient *client, vector<Param*> params){
+	for (int i=0; i<params.size(); i++)
+		client->Update(i, params[i]);
+}
+*/
+
+
+} //namespace singa
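The gradient aggregation at the top of this section (before the ownership check) relies on mshadow expression templates: agg += grad adds this worker's gradient into the shared entry's gradient buffer element by element. A plain-C++ sketch of the same operation, with raw float arrays standing in for the Param buffers (names and sizes are illustrative only, not SINGA API):

    // Equivalent of the mshadow expression "agg += grad" above:
    // accumulate a worker-local gradient into the shared aggregation buffer.
    void AggregateGrad(float* agg, const float* grad, int size) {
      for (int i = 0; i < size; ++i)
        agg[i] += grad[i];
    }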

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/trainer/server.cc
----------------------------------------------------------------------
diff --git a/src/trainer/server.cc b/src/trainer/server.cc
new file mode 100644
index 0000000..bf0ad03
--- /dev/null
+++ b/src/trainer/server.cc
@@ -0,0 +1,68 @@
+#include <list>
+#include <tuple>
+#include <queue>
+#include "trainer/server.h"
+#include "utils/param.h"
+#include "utils/singleton.h"
+#include "utils/factory.h"
+#include "utils/cluster.h"
+
+
+namespace singa {
+Server::Server(int group_id, int server_id):
+  group_id_(group_id), server_id_(server_id){}
+
+void Server::Setup(const UpdaterProto& proto,
+    shared_ptr<PMServer::ParamShard> shard,
+    shared_ptr<Dealer> dealer){
+	//VLOG(3) << "Parsing config file for host "<<hosts[id_] << " server id = " <<id_;
+  pmserver_=shared_ptr<PMServer>(Singleton<Factory<PMServer>>::Instance()
+      ->Create("PMServer"));
+  pmserver_->Setup(group_id_, server_id_, shard, proto);
+  dealer_=dealer;
+}
+
+void Server::Run(){
+  Msg* ping=new Msg();
+  ping->set_src(group_id_, server_id_, kServer);
+  ping->set_dst(0,0,kStub);
+  ping->set_type(kConnect);
+  dealer_->Send(ping);
+  int timeout=Cluster::Get()->server_timeout();
+  Poller poller;
+  poller.Add(dealer_.get());
+	//start recv loop and process requests
+  while (true){
+    Msg* msg=dealer_->Receive();
+    if (msg==nullptr)
+      break;
+    Msg* response=nullptr;
+    int type=msg->type();
+    switch (type){
+      case kPut:
+        response = pmserver_->HandlePut(&msg);
+        break;
+      case kGet:
+        response = pmserver_->HandleGet(&msg);
+        break;
+      case kUpdate:
+        response = pmserver_->HandleUpdate(&msg);
+        break;
+      case kSyncRequest:
+        VLOG(3)<<"Handle SYNC-REQUEST";
+        response = pmserver_->HandleSyncRequest(&msg);
+        break;
+      case kSyncResponse:
+        VLOG(3) << "Handle SYNC response";
+        pmserver_->HandleSyncResponse(&msg);
+        break;
+    }
+
+    if (response!=nullptr)
+      dealer_->Send(response);
+  }
+}
+
+
+
+} /* singa */
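Server::Setup above obtains its PMServer by string id through the Singleton/Factory utilities (registered in RegisterDefaultClasses in trainer.cc below). The concrete Factory lives in utils/factory.h, which is outside this hunk; a self-contained sketch of the string-keyed factory pattern it presumably implements (hypothetical, for illustration only):

    #include <functional>
    #include <map>
    #include <string>

    // Hypothetical stand-in for utils/factory.h: maps a string id to a
    // creator callback that returns a newly allocated object.
    template <typename T>
    class SimpleFactory {
     public:
      void Register(const std::string& id, std::function<T*()> creator) {
        creators_[id] = creator;
      }
      T* Create(const std::string& id) {
        return creators_.at(id)();  // throws std::out_of_range if unregistered
      }
     private:
      std::map<std::string, std::function<T*()>> creators_;
    };

With such a factory, Create("PMServer") returns a fresh PMServer instance, which Server::Setup then wraps in a shared_ptr.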

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/trainer/trainer.cc
----------------------------------------------------------------------
diff --git a/src/trainer/trainer.cc b/src/trainer/trainer.cc
new file mode 100644
index 0000000..3621b7e
--- /dev/null
+++ b/src/trainer/trainer.cc
@@ -0,0 +1,206 @@
+#include <thread>
+#include <vector>
+#include <map>
+#include <glog/logging.h>
+#include "trainer/trainer.h"
+using std::vector;
+using std::map;
+
+namespace singa {
+int ProcsIDOf(int group_id, int id, int flag){
+  int procsid=-1;
+  auto cluster=Cluster::Get();
+  if(flag==kServer){
+    procsid=group_id*cluster->nservers_per_group()/
+      cluster->nservers_per_procs()+id/cluster->nservers_per_procs();
+    if(cluster->server_worker_separate())
+      procsid+=cluster->nworker_procs();
+  }else if(flag==kWorkerLayer || flag==kWorkerParam){
+    procsid=group_id*cluster->nworkers_per_group()
+      /cluster->nworkers_per_procs();
+    if(cluster->nworkers_per_group()>cluster->nworkers_per_procs())
+      procsid+=id/cluster->nworkers_per_procs();
+  }else{
+    LOG(ERROR)<<"Unkown flag ("<<flag<<")";
+  }
+  return procsid;
+}
+
+void Trainer::RegisterDefaultClasses(const singa::ModelProto& proto){
+  // register all layers appearing in the neural net
+  singa::NeuralNet::RegisterLayers();
+  Singleton<Factory<singa::Param>>::Instance()->Register(
+      "Param", CreateInstance(singa::Param, singa::Param));
+  Singleton<Factory<singa::Updater>>::Instance() ->Register(
+      "Updater", CreateInstance(singa::SGDUpdater, singa::Updater));
+  Singleton<Factory<singa::PMWorker>>::Instance() ->Register(
+      "PMWorker", CreateInstance(singa::PMWorker, singa::PMWorker));
+  Singleton<Factory<singa::PMServer>>::Instance() ->Register(
+      "PMServer", CreateInstance(singa::PMServer, singa::PMServer));
+}
+
+void Trainer::Start(const ModelProto& mproto, const ClusterProto& cproto,
+    int procs_id){
+  RegisterDefaultClasses(mproto);
+
+  auto cluster=Cluster::Get(cproto, procs_id);
+  // create servers
+  vector<shared_ptr<Server>> servers;
+  int nSocket=1; // the first socket is the router
+  if(cluster->has_server()){
+    int pid=cluster->procs_id();
+    if(cluster->server_worker_separate())
+      pid-=cluster->nworker_procs();
+    int gid=pid*cluster->nservers_per_procs()/cluster->nservers_per_group();
+    int start=pid*cluster->nservers_per_procs()%cluster->nservers_per_group();
+    int end=start+cluster->nservers_per_group();
+    // the ParamShard for servers consists of a dictionary of Param objects
+    auto shard=make_shared<PMServer::ParamShard>();
+    for(int sid=start;sid<end;sid++){
+      auto server=make_shared<Server>(gid, sid);
+      auto dealer=make_shared<Dealer>(nSocket++);
+      dealer->Connect(kInprocRouterEndpoint);
+      server->Setup(mproto.updater(), shard, dealer);
+      servers.push_back(server);
+    }
+  }
+
+  // create workers
+  vector<shared_ptr<Worker>> workers;
+  if(cluster->has_worker()){
+    auto net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kTrain);
+    int pid=cluster->procs_id();
+    int gstart, gend, wstart, wend;
+    if(cluster->nworkers_per_group()>=cluster->nworkers_per_procs()){
+      // all workers in this process belong to the same group
+      gstart=pid*cluster->nworkers_per_procs()/cluster->nworkers_per_group();
+      gend=gstart+1;
+      wstart=pid*cluster->nworkers_per_procs()%cluster->nworkers_per_group();
+      wend=wstart+cluster->nworkers_per_group();
+    }else{
+      // there are multiple worker groups in this process
+      CHECK_EQ(cluster->nworkers_per_procs()%cluster->nworkers_per_group(),0);
+      int groups_per_procs=
+        cluster->nworkers_per_procs()/cluster->nworkers_per_group();
+      gstart=pid*groups_per_procs;
+      gend=(pid+1)*groups_per_procs;
+      wstart=0;
+      wend=cluster->nworkers_per_group();
+    }
+    for(int gid=gstart;gid<gend;gid++){
+      shared_ptr<NeuralNet> train_net, test_net, validation_net;
+      if(gid==gstart)
+        train_net=net;
+      else{
+        train_net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kTrain);
+        // the train net for other groups may share parameter values from the
+        // first group
+        if(mproto.hogwild())
+          train_net->ShareParams(net, kValueOnly);
+      }
+      if(gid==0){
+        // validation and test are performed only by the first group
+        if(mproto.test_steps()){
+          test_net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kTest);
+          if(test_net!=nullptr)
+            test_net->ShareParams(train_net, kValueOnly);
+        }
+        if(mproto.validation_steps()){
+          validation_net=NeuralNet::SetupNeuralNet(mproto.neuralnet(), kValidation);
+          if(validation_net!=nullptr)
+            validation_net->ShareParams(train_net, kValueOnly);
+        }
+      }
+      // create ParamShard for the workers
+      auto shard=make_shared<PMWorker::ParamShard>();
+      for(auto layer: train_net->layers()){
+        int procsid=ProcsIDOf(gid, layer->locationid(),kWorkerParam);
+        int local=procsid==cluster->procs_id();
+        for(auto param: layer->GetParams()){
+          int owner=param->owner()<0||param->owner()==param->id()?procsid:-1;
+          if(shard->find(param->id())==shard->end())
+            (*shard)[param->id()]=make_shared<ParamCounter>(param, local, owner);
+          else
+            shard->at(param->id())->AddParam(param, local, owner);
+        }
+      }
+      for(int wid=wstart;wid<wend;wid++){
+        shared_ptr<Worker> worker=nullptr;
+        if(mproto.alg()==ModelProto_GradCalcAlg_kBackPropagation)
+          worker=make_shared<BPWorker>(gid, wid);
+        else{
+        // TODO add CDWorker
+        }
+        auto layer_dealer=make_shared<Dealer>(nSocket++);
+        auto param_dealer=make_shared<Dealer>(nSocket++);
+        layer_dealer->Connect(kInprocRouterEndpoint);
+        param_dealer->Connect(kInprocRouterEndpoint);
+        worker->Setup(mproto, train_net, shard, layer_dealer, param_dealer);
+        worker->set_test_net(test_net);
+        worker->set_validation_net(validation_net);
+        workers.push_back(worker);
+      }
+    }
+  }
+
+#ifdef USE_MPI
+  for(int i=0;i<nSocket;i++){
+    MPIQueues.push_back(make_shared<SafeQueue>());
+  }
+#endif
+  vector<std::thread> threads;
+  for(auto server: servers)
+    threads.push_back(std::thread(&Server::Run,server));
+  for(auto worker: workers)
+    threads.push_back(std::thread(&Worker::Run,worker));
+  Run();
+  for(auto& thread: threads)
+    thread.join();
+}
+
+void Trainer::Run(){
+  auto cluster=Cluster::Get();
+  auto router=make_shared<Router>();
+  router->Bind(kInprocRouterEndpoint);
+  if(cluster->nprocs()>1)
+    router->Bind(cluster->endpoint());
+
+  map<int, shared_ptr<Dealer>> interprocs_dealers;
+  Poller poller;
+  poller.Add(router.get());
+  int timeout=cluster->stub_timeout();
+  while(true){
+    Msg* msg=router->Receive();
+    if(msg==nullptr){
+      LOG(ERROR)<<"Connection broken!";
+      exit(0);
+    }
+    int dst_flag=msg->dst_flag();
+    int type=msg->type();
+    int group_id, id, procs_id;
+    switch (dst_flag){ // TODO process other requests, e.g. RESTful
+      case kStub:
+        if(type==kConnect){
+          delete msg;
+        }else{
+          // TODO processing requests for worker group spanning multiple procs.
+          LOG(ERROR)<<"Unkown message type ("<<type<<") to stub";
+        }
+        break;
+      default:
+        group_id=msg->dst_group_id();
+        id=msg->dst_id();
+        procs_id=ProcsIDOf(group_id, id, dst_flag);
+        if(procs_id!=cluster->procs_id()){
+          if (interprocs_dealers.find(procs_id)==interprocs_dealers.end())
+            interprocs_dealers[procs_id]=make_shared<Dealer>(procs_id);
+          interprocs_dealers[procs_id]->Send(msg);
+        } else
+          router->Send(msg);
+        break;
+    }
+  }
+}
+} /* singa */
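The ProcsIDOf mapping at the top of trainer.cc is easiest to check with concrete numbers. Under a hypothetical layout of 2 worker groups of 4 workers each and 2 workers per process, worker (group 1, id 3) resolves to process 1*4/2 + 3/2 = 3. A standalone sketch of the worker branch with the cluster constants passed in explicitly (illustrative only; the real code reads them from Cluster::Get()):

    #include <cassert>

    // Worker branch of ProcsIDOf with the cluster constants hard-coded.
    int WorkerProcsID(int group_id, int worker_id,
                      int nworkers_per_group, int nworkers_per_procs) {
      int procsid = group_id * nworkers_per_group / nworkers_per_procs;
      if (nworkers_per_group > nworkers_per_procs)
        procsid += worker_id / nworkers_per_procs;
      return procsid;
    }

    int main() {
      // 2 groups x 4 workers, 2 workers per process => 4 worker processes.
      assert(WorkerProcsID(0, 0, 4, 2) == 0);
      assert(WorkerProcsID(0, 3, 4, 2) == 1);
      assert(WorkerProcsID(1, 3, 4, 2) == 3);
      return 0;
    }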

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/trainer/worker.cc
----------------------------------------------------------------------
diff --git a/src/trainer/worker.cc b/src/trainer/worker.cc
new file mode 100644
index 0000000..047ec2d
--- /dev/null
+++ b/src/trainer/worker.cc
@@ -0,0 +1,299 @@
+#include <glog/logging.h>
+#include <thread>
+#include <memory>
+#include <iostream>
+#include "utils/singleton.h"
+#include "utils/factory.h"
+#include "trainer/worker.h"
+#include "proto/model.pb.h"
+using std::thread;
+namespace singa {
+Worker::Worker( int group_id, int worker_id):
+   group_id_(group_id), worker_id_(worker_id){
+}
+
+void Worker::Setup(const ModelProto& model,
+    shared_ptr<NeuralNet> train_net,
+    shared_ptr<PMWorker::ParamShard> shard,
+    shared_ptr<Dealer> layer_dealer,
+    shared_ptr<Dealer> param_dealer){
+  train_net_=train_net;
+  modelproto_=model;
+  layer_dealer_=layer_dealer;
+  param_dealer_=param_dealer;
+  if(layer_dealer_!=nullptr)
+    layer_poller_.Add(layer_dealer_.get());
+  if(param_dealer_!=nullptr)
+    param_poller_.Add(param_dealer_.get());
+  pmworker_=shared_ptr<PMWorker>(Singleton<Factory<PMWorker>>::Instance()
+      ->Create("PMWorker"));
+  pmworker_->Setup(group_id_, worker_id_, shard);
+  step_=modelproto_.step();
+  // init params
+  for(auto layer: train_net->layers())
+    if(group_id_==0&&layer->locationid()==worker_id_)
+      for(auto param: layer->GetParams()){
+        if(param->owner()<0||param->owner()==param->id()){
+          param->Init();
+          Put(param, step_);
+        }
+        Get(param, step_);
+      }
+}
+
+void Worker::Run(){
+  step_=modelproto_.step();
+  Performance perf(train_net_);
+  try{
+    while(!StopNow(step_)){
+      RunOneBatch(step_, &perf);
+      step_++;
+    }
+  }catch(WorkerException& e){
+    LOG(ERROR)<<e.what();
+  }
+}
+int Worker::Put(shared_ptr<Param> param, int step){
+  auto msg=pmworker_->Put(param, step);
+  if(msg!=nullptr)
+    param_dealer_->Send(msg);
+  return 1;
+}
+int Worker::Get(shared_ptr<Param> param, int step){
+  if(param->version()<step){
+    auto msg=pmworker_->Get(param, step);
+    if(msg!=nullptr)
+      param_dealer_->Send(msg);
+  }
+  return 1;
+}
+int Worker::Update(shared_ptr<Param> param, int step){
+  auto msg=pmworker_->Update(param, step);
+  if(msg!=nullptr)
+    param_dealer_->Send(msg);
+  return 1;
+}
+int Worker::Collect(shared_ptr<Param> param, int step){
+  while(param->version()<step){
+    Msg* msg=param_dealer_->Receive();
+    if(msg==nullptr)
+      return 0;
+    pmworker_->Collect(&msg);
+  }
+  return 1;
+}
+
+void Worker::RunOneBatch(int step, Performance* perf){
+  //DLOG(ERROR)<<"Step "<<step;
+  // Test will call Pull which updates the sync time
+  // Hence we store the sync time, and restore it later
+  //float tSyncData=tSyncData_, tSyncParam=tSyncParam_;
+  if(ValidateNow(step)){
+    LOG(ERROR)<<"Validation at step "<<step;
+    Test(validation_net_, modelproto_.validation_steps(), perf!=nullptr);
+  }
+  if(TestNow(step)){
+    LOG(ERROR)<<"Test at step "<<step;
+    Test(test_net_, modelproto_.test_steps(), perf!=nullptr);
+  }
+  //tSyncData_=tSyncData; tSyncParam_=tSyncParam;
+
+  TrainOneBatch(step);
+  if(perf!=nullptr){
+    perf->Update();
+    if(DisplayNow(step)){
+      LOG(ERROR)<<"Training at step "<<step;
+      LOG(ERROR)<<"\t"<<perf->ToString();
+      perf->Reset();
+      //LOG(ERROR)<<"\t"<<TimerInfo();
+    }
+  }
+
+  /*
+  if(CheckpointNow(step)){
+    pm_->Checkpoint(cluster_->workspace()+"/snapshot-"+std::to_string(step));
+  }
+  */
+}
+
+void Worker::ReceiveBlobs(shared_ptr<NeuralNet> net){
+  /*
+  int type;
+  char *name;
+  int64_t tick=zclock_mono();
+  zframe_t* frame=zframe_new_empty();
+
+  zsock_recv(pull_, "isf", &type, &name, &frame);
+  if(type==kDataFrame){
+    auto* dst=static_cast<BridgeDstLayer*>(
+        net->name2layer(string(name)).get());
+    memcpy(dst->mutable_data()->mutable_cpu_data(), zframe_data(frame),
+        zframe_size(frame));
+    dst->set_ready(true);
+  }else if(type==kGradFrame){
+    auto* src=static_cast<BridgeSrcLayer*>(net->name2layer(string(name)).get());
+    memcpy(src->mutable_grad()->mutable_cpu_data(), zframe_data(frame),
+        zframe_size(frame));
+    src->set_ready(true);
+  }
+  zframe_destroy(&frame);
+  delete name;
+  tSyncData_+=zclock_mono()-tick;
+  */
+}
+
+void Worker::SendBlob(){
+
+}
+
+void Worker::Test(shared_ptr<NeuralNet> net, int nsteps, bool disperf){
+  Performance perf(net);
+  for(int step=0;step<nsteps;step++){
+    TestOneBatch(net, step, kTest);
+    if(disperf)
+      perf.Update();
+  }
+  if(disperf)
+    LOG(ERROR)<<"\t"<<perf.ToString();
+}
+
+/****************************BPWorker**********************************/
+
+void BPWorker::Forward(shared_ptr<NeuralNet> net, int step,  bool training){
+  auto& layers=net->layers();
+  for(auto& layer: layers){
+    if(layer->locationid()==worker_id_){
+      if(layer->is_bridgedstlayer()){
+        //auto* dst=static_cast<BridgeDstLayer*>(layer.get());
+        // receive fea blobs
+      }
+      if(training){
+        for(shared_ptr<Param> p: layer->GetParams()){
+          if(Collect(p, step)==0){
+            throw WorkerException();
+          }
+        }
+      }
+      layer->ComputeFeature(training);
+      if(layer->is_bridgesrclayer()){
+        // send fea blobs
+      }
+      if(training&&DisplayDebugInfo(step)&&layer->mutable_data()!=nullptr){
+        LOG(INFO)<<StringPrintf("Forward layer  %10s data norm1 %13.9f",
+            layer->name().c_str(), layer->data().asum_data());
+      }
+    }
+  }
+}
+
+void BPWorker::Backward(shared_ptr<NeuralNet> net, int step){
+  auto& layers=net->layers();
+  for (auto it = layers.rbegin(); it != layers.rend(); it++){
+    shared_ptr<Layer> layer=*it;
+    if(layer->locationid()==worker_id_){
+      if(layer->is_bridgesrclayer()){
+        //auto* src=static_cast<BridgeSrcLayer*>(layer.get());
+        // receive grad blobs
+      }
+      layer->ComputeGradient();
+      if(DisplayDebugInfo(step)&&layer->mutable_grad()!=nullptr){
+        LOG(INFO)<<StringPrintf("Backward layer %10s grad norm1 %13.9f\t",
+            layer->name().c_str(), layer->grad().asum_data());
+        for(shared_ptr<Param> p: layer->GetParams())
+          LOG(INFO)<<StringPrintf("param id %2d, name %10s,\
+              value norm1 %13.9f, grad norm1 %13.9f",
+              p->id(), p->name().c_str(),
+              p->data().asum_data(), p->grad().asum_data());
+      }
+      for(shared_ptr<Param> p: layer->GetParams()){
+        Update(p, step);
+      }
+      if(layer->is_bridgedstlayer()){
+        // send grad blobs
+      }
+    }
+  }
+}
+
+void BPWorker::TrainOneBatch(int step){
+  Forward(train_net_, step, true);
+  Backward(train_net_, step);
+}
+
+void BPWorker::TestOneBatch(shared_ptr<NeuralNet> net,int step, Phase phase){
+  Forward(net, step, false);
+}
+
+/*********************Implementation for Performance class*******************/
+Performance::Performance(shared_ptr<NeuralNet> net):net_(net), counter_(0){
+  for(auto& layer: net->losslayers()){
+    name_.push_back(layer->name());
+    metric_.push_back(vector<float>{});
+    metric_.back().resize(layer->metric().count(),0.f);
+  }
+}
+
+void Performance::Update(){
+  const auto& losslayers=net_->losslayers();
+  for(size_t i=0;i<losslayers.size();i++){
+    const float * ptr=losslayers[i]->metric().cpu_data();
+    vector<float>& m=metric_.at(i);
+    for(int j=0;j<losslayers[i]->metric().count();j++)
+      m[j]+=ptr[j];
+  }
+  counter_++;
+}
+
+void Performance::Reset(){
+  for(auto& m: metric_)
+    for(auto& x: m)
+      x=0.f;
+  counter_=0;
+}
+
+string Performance::ToString(){
+  string disp="";
+  for(size_t i=0;i<metric_.size();i++){
+    disp+="Output from "+name_[i]+" layer ";
+    vector<float> m=metric_.at(i);
+    for(size_t j=0;j<m.size();j++)
+        disp+=std::to_string(j)+" : "+std::to_string(m[j]/counter_)+"\t";
+    disp+="\n";
+  }
+  return disp;
+}
+/*
+void Executor::Setup(int local_threadid, const ModelProto& model){
+  tForward_=tBackward_=tSyncData_=tSyncParam_=0;
+  modelproto_=model;
+  local_threadid_=local_threadid;
+  if(model.prefetch()){
+    for(auto& layer: train_net_->datalayers()){
+      if(cluster_->group_threadid(local_threadid_)==layer->locationid())
+        localDataLayers_.push_back(layer);
+    }
+    if(localDataLayers_.size())
+      prefetch_thread_=std::thread(Executor::PrefetchData,
+          std::ref(localDataLayers_), true,1);
+  }
+  int gthreadid=cluster_->group_threadid(local_threadid);
+}
+
+void Executor::PrefetchData(const vector<DataLayer*>& datalayers, bool training,
+    int steps){
+  if(datalayers.size()==0)
+    return;
+  for(int i=0;i<steps;i++){
+    for(auto& layer: datalayers){
+      layer->Prefetching(training);
+      for(auto& dstlayer: layer->dstlayers()){
+        CHECK(dstlayer->is_parserlayer());
+        auto parserlayer=static_cast<ParserLayer*>(dstlayer.get());
+        parserlayer->Prefetching(training);
+      }
+    }
+  }
+}
+*/
+
+}  // namespace singa
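Worker::Collect above blocks until the Param's version reaches the requested step, draining replies from the parameter dealer one at a time and returning 0 only if the connection goes away. A toy model of that version-based handshake, with a queue standing in for the dealer (types and names are hypothetical, not the SINGA API):

    #include <deque>

    struct ToyParam { int version = 0; };       // stand-in for Param

    // Returns 1 once the parameter is fresh enough for 'step',
    // 0 if the reply queue runs dry first (i.e. the connection broke).
    int Collect(ToyParam* param, std::deque<int>* replies, int step) {
      while (param->version < step) {
        if (replies->empty())
          return 0;
        param->version = replies->front();      // apply one server reply
        replies->pop_front();
      }
      return 1;
    }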

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b2dc51d2/src/utils/blob.cc
----------------------------------------------------------------------
diff --git a/src/utils/blob.cc b/src/utils/blob.cc
new file mode 100644
index 0000000..92fc989
--- /dev/null
+++ b/src/utils/blob.cc
@@ -0,0 +1,330 @@
+/**
+ * The code is adapted from that of Caffe whose license is attached.
+ *
+ * COPYRIGHT
+ * All contributions by the University of California:
+ * Copyright (c) 2014, The Regents of the University of California (Regents)
+ * All rights reserved.
+ * All other contributions:
+ * Copyright (c) 2014, the respective contributors
+ * All rights reserved.
+ * Caffe uses a shared copyright model: each contributor holds copyright over
+ * their contributions to Caffe. The project versioning records all such
+ * contribution and copyright details. If a contributor wants to further mark
+ * their specific copyright on a particular contribution, they should indicate
+ * their copyright solely in the commit message of the change when it is
+ * committed.
+ * LICENSE
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ * CONTRIBUTION AGREEMENT
+ * By contributing to the BVLC/caffe repository through pull-request, comment,
+ * or otherwise, the contributor releases their content to the
+ * license and copyright terms herein.
+ */
+#include <utility>
+#include <math.h>
+#include <cblas.h>
+#include "utils/blob.h"
+/*********************SyncedMemory implementation************************/
+
+#define NO_GPU LOG(FATAL) << "CPU-only Mode: cannot make GPU call."
+// Instantiate a class with float and double specifications.
+#define INSTANTIATE_CLASS(classname) \
+  template class classname<float>; \
+  template class classname<double>
+// Disable the copy and assignment operator for a class.
+#define DISABLE_COPY_AND_ASSIGN(classname) \
+private:\
+  classname(const classname&);\
+  classname& operator=(const classname&)
+
+#ifndef CPU_ONLY
+// CUDA: various checks for different function calls.
+#define CUDA_CHECK(condition) \
+  /* Code block avoids redefinition of cudaError_t error */ \
+  do { \
+    cudaError_t error = condition; \
+    CHECK_EQ(error, cudaSuccess) << " " << cudaGetErrorString(error); \
+  } while (0)
+
+#define CUBLAS_CHECK(condition) \
+  do { \
+    cublasStatus_t status = condition; \
+    CHECK_EQ(status, CUBLAS_STATUS_SUCCESS) << " " \
+      << caffe::cublasGetErrorString(status); \
+  } while (0)
+
+#define CURAND_CHECK(condition) \
+  do { \
+    curandStatus_t status = condition; \
+    CHECK_EQ(status, CURAND_STATUS_SUCCESS) << " " \
+      << caffe::curandGetErrorString(status); \
+  } while (0)
+
+#endif // CPU_ONLY
+
+
+SyncedMemory::~SyncedMemory() {
+  if (cpu_ptr_ && own_cpu_data_) {
+    FreeHost(cpu_ptr_);
+  }
+
+#ifndef CPU_ONLY
+  if (gpu_ptr_) {
+    CUDA_CHECK(cudaFree(gpu_ptr_));
+  }
+#endif  // CPU_ONLY
+}
+
+inline void SyncedMemory::to_cpu() {
+  switch (head_) {
+  case UNINITIALIZED:
+    MallocHost(&cpu_ptr_, size_);
+    memset(cpu_ptr_,0, size_);
+    head_ = HEAD_AT_CPU;
+    own_cpu_data_ = true;
+    break;
+  case HEAD_AT_GPU:
+#ifndef CPU_ONLY
+    if (cpu_ptr_ == NULL) {
+      MallocHost(&cpu_ptr_, size_);
+      own_cpu_data_ = true;
+    }
+    CUDA_CHECK(cudaMemcpy(cpu_ptr_, gpu_ptr_, size_, cudaMemcpyDefault));
+    head_ = SYNCED;
+#else
+    NO_GPU;
+#endif
+    break;
+  case HEAD_AT_CPU:
+  case SYNCED:
+    break;
+  }
+}
+
+inline void SyncedMemory::to_gpu() {
+#ifndef CPU_ONLY
+  switch (head_) {
+  case UNINITIALIZED:
+    CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_));
+    CUDA_CHECK(cudaMemset(gpu_ptr_, 0, size_));  // NOLINT(caffe/alt_fn)
+    head_ = HEAD_AT_GPU;
+    break;
+  case HEAD_AT_CPU:
+    if (gpu_ptr_ == NULL) {
+      CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_));
+    }
+    CUDA_CHECK(cudaMemcpy( gpu_ptr_,cpu_ptr_, size_, cudaMemcpyDefault));
+    head_ = SYNCED;
+    break;
+  case HEAD_AT_GPU:
+  case SYNCED:
+    break;
+  }
+#else
+  NO_GPU;
+#endif
+}
+
+const void* SyncedMemory::cpu_data() {
+  to_cpu();
+  return (const void*)cpu_ptr_;
+}
+
+void SyncedMemory::set_cpu_data(void* data) {
+  CHECK(data);
+  if (own_cpu_data_) {
+    FreeHost(cpu_ptr_);
+  }
+  cpu_ptr_ = data;
+  head_ = HEAD_AT_CPU;
+  own_cpu_data_ = false;
+}
+
+const void* SyncedMemory::gpu_data() {
+#ifndef CPU_ONLY
+  to_gpu();
+  return (const void*)gpu_ptr_;
+#else
+  NO_GPU;
+#endif
+  return nullptr;
+}
+
+void* SyncedMemory::mutable_cpu_data() {
+  to_cpu();
+  head_ = HEAD_AT_CPU;
+  return cpu_ptr_;
+}
+
+void* SyncedMemory::mutable_gpu_data() {
+#ifndef CPU_ONLY
+  to_gpu();
+  head_ = HEAD_AT_GPU;
+  return gpu_ptr_;
+#else
+  NO_GPU;
+#endif
+  return nullptr;
+}
+
+/*********************Blob implementation************************/
+
+template <typename Dtype>
+Blob<Dtype>::Blob(const vector<int>& shape)
+  // capacity_ must be initialized before calling Reshape
+  : capacity_(0) {
+  Reshape(shape);
+}
+
+template <typename Dtype>
+void Blob<Dtype>::Reshape(const vector<int>& shape) {
+  count_=1;
+  shape_=shape;
+  for(size_t i=0;i<shape.size();i++){
+    CHECK(shape[i]);
+    count_*=shape[i];
+  }
+  if (count_ > capacity_) {
+    capacity_ = count_;
+    data_.reset(new SyncedMemory(capacity_ * sizeof(Dtype)));
+  }
+}
+
+template <typename Dtype>
+void Blob<Dtype>::ReshapeLike(const Blob<Dtype>& other) {
+  Reshape(other.shape());
+}
+
+template <typename Dtype>
+const Dtype* Blob<Dtype>::cpu_data() const {
+  CHECK(data_);
+  return (const Dtype*)data_->cpu_data();
+}
+
+template <typename Dtype>
+void Blob<Dtype>::set_cpu_data(Dtype* data) {
+  CHECK(data);
+  data_->set_cpu_data(data);
+}
+
+template <typename Dtype>
+const Dtype* Blob<Dtype>::gpu_data() const {
+  CHECK(data_);
+  return (const Dtype*)data_->gpu_data();
+}
+
+template <typename Dtype>
+Dtype* Blob<Dtype>::mutable_cpu_data() {
+  CHECK(data_);
+  return static_cast<Dtype*>(data_->mutable_cpu_data());
+}
+
+template <typename Dtype>
+Dtype* Blob<Dtype>::mutable_gpu_data() {
+  CHECK(data_);
+  return static_cast<Dtype*>(data_->mutable_gpu_data());
+}
+
+template <typename Dtype>
+void Blob<Dtype>::ShareData(const Blob& other) {
+  CHECK_EQ(count_, other.count());
+  data_ = other.data();
+}
+
+template <> float Blob<float>::asum_data() const {
+  if(count()==0)
+    return 0.f;
+  return cblas_sasum(count(), cpu_data(), 1)/count();
+}
+template <> float Blob<float>::sum_data() const {
+  if(count()==0)
+    return 0.f;
+  float sum=0.f;
+  const float *dptr=cpu_data();
+  for(int i=0;i<count();i++)
+    sum+=dptr[i];
+  return sum/count();
+}
+template <> unsigned int Blob<unsigned int>::asum_data() const {
+  NOT_IMPLEMENTED;
+  return 0;
+}
+
+template <> int Blob<int>::asum_data() const {
+  NOT_IMPLEMENTED;
+  return 0;
+}
+
+template <typename Dtype>
+void Blob<Dtype>::Swap(Blob& other){
+  CHECK_EQ(other.count(), count());
+  CHECK(std::equal(shape_.begin(), shape_.end(), other.shape_.begin()));
+  std::swap(data_, other.data_);
+  std::swap(capacity_, other.capacity_);
+}
+
+template <typename Dtype>
+void Blob<Dtype>::CopyFrom(const Blob& source, bool reshape) {
+  if (!std::equal(shape_.begin(),shape_.end(),source.shape_.begin())) {
+    if (reshape) {
+      Reshape(source.shape_);
+    } else {
+      LOG(FATAL) << "Trying to copy blobs of different sizes.";
+    }
+  }
+#ifndef CPU_ONLY
+  CUDA_CHECK(cudaMemcpy(static_cast<Dtype*>(data_->mutable_gpu_data()),
+            source.gpu_data(), sizeof(Dtype) * count_, cudaMemcpyDefault));
+#endif
+  memcpy(static_cast<Dtype*>(data_->mutable_cpu_data()),source.cpu_data(),
+        sizeof(Dtype)*count_);
+}
+
+/*
+template <typename Dtype>
+void Blob<Dtype>::FromProto(const BlobProto& proto) {
+  Reshape();
+  // copy data
+  Dtype* data_vec = mutable_cpu_data();
+  for (int i = 0; i < count_; ++i) {
+    data_vec[i] = proto.data(i);
+  }
+}
+*/
+
+template <typename Dtype>
+void Blob<Dtype>::ToProto(singa::BlobProto* proto) const {
+  proto->set_num(shape_[0]);
+  if(shape_.size()>1)
+    proto->set_channels(shape_[1]);
+  if(shape_.size()>2)
+    proto->set_height(shape_[2]);
+  if(shape_.size()>3)
+    proto->set_width(shape_[3]);
+  proto->clear_data();
+  const Dtype* data_vec = cpu_data();
+  for (int i = 0; i < count_; ++i) {
+    proto->add_data(data_vec[i]);
+  }
+}
+
+INSTANTIATE_CLASS(Blob);
+template class Blob<int>;
+template class Blob<unsigned int>;
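Blob pairs a shape with a SyncedMemory whose head_ state (UNINITIALIZED, HEAD_AT_CPU, HEAD_AT_GPU, SYNCED) records where the freshest copy of the data lives: the mutable accessors move the head to the requested device, while the const accessors only synchronize. A short CPU-only usage sketch against the API in this file (values are made up; Blob is assumed visible without namespace qualification, as it appears in blob.cc itself):

    #include <glog/logging.h>
    #include <vector>
    #include "utils/blob.h"

    void BlobExample() {
      Blob<float> b(std::vector<int>{2, 3});  // 6 floats, zeroed on first access
      float* dptr = b.mutable_cpu_data();     // head becomes HEAD_AT_CPU
      for (int i = 0; i < b.count(); ++i)
        dptr[i] = 1.0f;
      // asum_data() is the mean absolute value for float blobs, so this logs 1.
      LOG(INFO) << "mean |x| = " << b.asum_data();
    }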

