From: wangwei@apache.org
To: commits@singa.incubator.apache.org
Reply-To: dev@singa.incubator.apache.org
Date: Wed, 10 Aug 2016 06:01:56 -0000
Message-Id: <574cfcb341d64c03bc7c8b448184508c@git.apache.org>
Subject: [1/3] incubator-singa git commit: SINGA-232 Alexnet on Imagenet

Repository: incubator-singa
Updated Branches:
  refs/heads/dev 17bfb1967 -> 53639b7ce


SINGA-232 Alexnet on Imagenet

Implement AlexNet on ImageNet.
1. The model follows the ImageNet paper.
2. The data is created offline as multiple training files, one test file and
   one mean file, all in binary format. This part is implemented with the
   writer and encoder in SINGA.
3. Data is loaded with multiple threads. This part uses the reader, decoder
   and transformer in SINGA.
4. This example requires OpenCV.
5. snapshot, jpgencoder, timer and binfile_reader are slightly modified.
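Item 3 (loading data in multiple threads) relies on a prefetch, or double-buffering, pattern: while the network trains on batch k, a background `std::thread` loads batch k + 1 into a second buffer, and the trainer joins that thread before it touches the buffer or the read count again. The following is a minimal sketch of that pattern in plain C++11, independent of SINGA. `Batch`, `LoadBatch` and `TrainWithPrefetch` are hypothetical stand-ins for the `Tensor` buffers and `ILSVRC::AsyncLoadData` used in `alexnet.cc`, and the training step itself is omitted. Unlike `TrainOneEpoch`, which checks `n_read` right after `AsyncLoadData` returns, the sketch reads the shared count only after `join()`.

```cpp
#include <cstddef>
#include <thread>
#include <vector>

// Hypothetical stand-in for a decoded mini-batch; alexnet.cc uses SINGA
// Tensors (prefetch_x/prefetch_y) filled by ILSVRC::AsyncLoadData instead.
using Batch = std::vector<int>;

// Hypothetical loader: copies up to `batchsize` sample ids, starting at
// `*cursor`, into `out` and returns how many it read (0 once data runs out).
std::size_t LoadBatch(std::size_t total, std::size_t batchsize,
                      std::size_t *cursor, Batch *out) {
  out->clear();
  while (out->size() < batchsize && *cursor < total) {
    out->push_back(static_cast<int>((*cursor)++));
  }
  return out->size();
}

// Consumes batch k while a background thread loads batch k + 1 into the
// other buffer. Returns the number of full batches consumed; a trailing
// partial batch is dropped, like the LOG(WARNING) case in TrainOneEpoch.
std::size_t TrainWithPrefetch(std::size_t total, std::size_t batchsize) {
  std::size_t cursor = 0, num_batches = 0;
  Batch prefetch, current;
  // The first batch is loaded synchronously, as with data.LoadData(...).
  std::size_t n_read = LoadBatch(total, batchsize, &cursor, &prefetch);
  while (n_read == batchsize) {
    current.swap(prefetch);  // hand the loaded batch to the consumer
    // Only the loader thread touches n_read, cursor and prefetch until join().
    std::thread loader(
        [&] { n_read = LoadBatch(total, batchsize, &cursor, &prefetch); });
    // ... a real trainer would run one training step on `current` here ...
    ++num_batches;
    loader.join();  // n_read is read again only after the loader has finished
  }
  return num_batches;
}
```

For example, `TrainWithPrefetch(1000, 256)` consumes three full batches and drops the last 232 samples, which is the situation the warning in `TrainOneEpoch` reports.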
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/34d3ae68
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/34d3ae68
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/34d3ae68

Branch: refs/heads/dev
Commit: 34d3ae680e4b39b952524420e6fe6a29f9094801
Parents: db5478e
Author: xiangrui
Authored: Tue Aug 2 16:09:29 2016 +0800
Committer: Xiangrui
Committed: Mon Aug 8 16:02:13 2016 +0800

----------------------------------------------------------------------
 CMakeLists.txt                   |   2 +-
 examples/CMakeLists.txt          |   1 +
 examples/imagenet/CMakeLists.txt |  16 ++
 examples/imagenet/README.md      |  43 ++++
 examples/imagenet/alexnet.cc     | 410 ++++++++++++++++++++++++++++++++++
 examples/imagenet/create_data.sh |   3 +
 examples/imagenet/ilsvrc12.cc    |  69 ++++++
 examples/imagenet/ilsvrc12.h     | 380 +++++++++++++++++++++++++++++++
 examples/imagenet/run.sh         |   3 +
 include/singa/io/snapshot.h      |   8 +-
 include/singa/utils/timer.h      |   6 +-
 src/core/tensor/tensor.cc        |   8 +-
 src/io/binfile_reader.cc         |   6 +-
 src/io/jpg_encoder.cc            |   2 +-
 src/io/snapshot.cc               |   8 +-
 15 files changed, 950 insertions(+), 15 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/34d3ae68/CMakeLists.txt
----------------------------------------------------------------------
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 23f8ef6..8c6afad 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,7 +1,7 @@
 CMAKE_MINIMUM_REQUIRED(VERSION 2.6)
 
 PROJECT(singa)
-SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
+SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -O2 ")
 
 LIST(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Thirdparty)
 #message(STATUS "module path: ${CMAKE_MODULE_PATH}")

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/34d3ae68/examples/CMakeLists.txt
----------------------------------------------------------------------
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 3490c38..6014f27 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -1 +1,2 @@
 ADD_SUBDIRECTORY(cifar10)
+ADD_SUBDIRECTORY(imagenet)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/34d3ae68/examples/imagenet/CMakeLists.txt
----------------------------------------------------------------------
diff --git a/examples/imagenet/CMakeLists.txt b/examples/imagenet/CMakeLists.txt
new file mode 100644
index 0000000..71fbbb1
--- /dev/null
+++ b/examples/imagenet/CMakeLists.txt
@@ -0,0 +1,16 @@
+INCLUDE_DIRECTORIES(${CMAKE_CURRENT_SOURCE_DIR})
+INCLUDE_DIRECTORIES(${CMAKE_BINARY_DIR}/include)
+
+IF(USE_CUDNN)
+  IF(USE_OPENCV)
+    SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp ")
+    ADD_EXECUTABLE(imagenet alexnet.cc)
+    ADD_DEPENDENCIES(imagenet singa_core singa_model singa_utils singa_io)
+    TARGET_LINK_LIBRARIES(imagenet singa_core singa_utils singa_model singa_io protobuf ${SINGA_LIBKER_LIBS})
+
+    ADD_EXECUTABLE(createdata ilsvrc12.cc)
+    ADD_DEPENDENCIES(createdata singa_core singa_io singa_model singa_utils)
+    TARGET_LINK_LIBRARIES(createdata singa_core singa_utils singa_io singa_model protobuf ${SINGA_LIBKER_LIBS})
+    #SET_TARGET_PROPERTIES(createdata PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
+  ENDIF(USE_OPENCV)
+ENDIF(USE_CUDNN)

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/34d3ae68/examples/imagenet/README.md
----------------------------------------------------------------------
diff --git a/examples/imagenet/README.md b/examples/imagenet/README.md
new file mode 100644
index 0000000..2e0389a
--- /dev/null
+++ b/examples/imagenet/README.md
@@ -0,0 +1,43 @@
+# Example of AlexNet
+
+### Data download
+* Please follow steps 1-3 of [Instructions to create ImageNet 2012 data](https://github.com/amd/OpenCL-caffe/wiki/Instructions-to-create-ImageNet-2012-data)
+  to download and decompress the data.
+* You can download the training and validation lists with
+  [get_ilsvrc_aux.sh](https://github.com/BVLC/caffe/blob/master/data/ilsvrc12/get_ilsvrc_aux.sh)
+  or from [Imagenet](http://www.image-net.org/download-images).
+
+### Data preprocessing
+* Once you have downloaded the data and the lists,
+  transform the data into binary files by running:
+
+      sh create_data.sh
+
+  The script generates a test file (`test.bin`), a mean file (`mean.bin`) and
+  several training files (`trainX.bin`) in the specified output folder.
+* You can also change the parameters in `create_data.sh`.
+  + `-trainlist `: the training list file;
+  + `-trainfolder `: the folder of training images;
+  + `-testlist `: the test list file;
+  + `-testfolder `: the folder of test images;
+  + `-outdata `: the folder for the output files, i.e., the mean, training and test files.
+    The script generates all of these files in this folder;
+  + `-filesize `: number of training images stored in each binary file.
+
+### Training
+* After preparing the data, run the following command to train the AlexNet model:
+
+      sh run.sh
+* You may change the parameters in `run.sh`.
+  + `-epoch `: number of epochs to train, default 90;
+  + `-lr `: base learning rate; the learning rate is divided by 10 every 20 epochs,
+    more specifically, `epoch_lr = lr * pow(0.1, epoch / 20)`, where `epoch / 20` is integer division;
+  + `-batchsize `: mini-batch size; adjust it according to the available GPU memory;
+  + `-filesize `: number of training images stored in each binary file; it should be the
+    same as the `filesize` used in data preprocessing;
+  + `-ntrain `: number of training images;
+  + `-ntest `: number of test images;
+  + `-data `: the folder that stores the binary files, i.e., the output
+    folder of the data preprocessing step;
+  + `-pfreq `: how often (in batches) to print the current model status (loss and accuracy);
+  + `-nthreads `: number of threads used to load the data fed to the model.
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/34d3ae68/examples/imagenet/alexnet.cc
----------------------------------------------------------------------
diff --git a/examples/imagenet/alexnet.cc b/examples/imagenet/alexnet.cc
new file mode 100644
index 0000000..22cc88f
--- /dev/null
+++ b/examples/imagenet/alexnet.cc
@@ -0,0 +1,410 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+#include "singa/singa_config.h"
+#ifdef USE_OPENCV
+#include 
+#include "../../src/model/layer/cudnn_activation.h"
+#include "../../src/model/layer/cudnn_convolution.h"
+#include "../../src/model/layer/cudnn_dropout.h"
+#include "../../src/model/layer/cudnn_lrn.h"
+#include "../../src/model/layer/cudnn_pooling.h"
+#include "../../src/model/layer/dense.h"
+#include "../../src/model/layer/flatten.h"
+#include "./ilsvrc12.h"
+#include "singa/io/snapshot.h"
+#include "singa/model/feed_forward_net.h"
+#include "singa/model/initializer.h"
+#include "singa/model/metric.h"
+#include "singa/model/optimizer.h"
+#include "singa/utils/channel.h"
+#include "singa/utils/string.h"
+#include "singa/utils/timer.h"
+namespace singa {
+
+LayerConf GenConvConf(string name, int nb_filter, int kernel, int stride,
+                      int pad, float std, float bias = .0f) {
+  LayerConf conf;
+  conf.set_name(name);
+  conf.set_type("CudnnConvolution");
+  ConvolutionConf *conv = conf.mutable_convolution_conf();
+  conv->set_num_output(nb_filter);
+  conv->add_kernel_size(kernel);
+  conv->add_stride(stride);
+  conv->add_pad(pad);
+  conv->set_bias_term(true);
+
+  ParamSpec *wspec = conf.add_param();
+  wspec->set_name(name + "_weight");
+  auto wfill = wspec->mutable_filler();
+  wfill->set_type("Gaussian");
+  wfill->set_std(std);
+
+  ParamSpec *bspec = conf.add_param();
+  bspec->set_name(name + "_bias");
+  bspec->set_lr_mult(2);
+  bspec->set_decay_mult(0);
+  auto bfill = bspec->mutable_filler();
+  bfill->set_value(bias);
+  return conf;
+}
+
+LayerConf GenPoolingConf(string name, bool max_pool, int kernel, int stride,
+                         int pad) {
+  LayerConf conf;
+  conf.set_name(name);
+  conf.set_type("CudnnPooling");
+  PoolingConf *pool = conf.mutable_pooling_conf();
+  pool->set_kernel_size(kernel);
+  pool->set_stride(stride);
+  pool->set_pad(pad);
+  if (!max_pool) pool->set_pool(PoolingConf_PoolMethod_AVE);
+  return conf;
+}
+
+LayerConf GenReLUConf(string name) {
+  LayerConf conf;
+  conf.set_name(name);
+  conf.set_type("RELU");
+  return conf;
+}
+
+LayerConf GenDenseConf(string name, int num_output, float std, float wd,
+                       float bias = .0f) {
+  LayerConf conf;
+  conf.set_name(name);
+  conf.set_type("Dense");
+  DenseConf *dense = conf.mutable_dense_conf();
+  dense->set_num_output(num_output);
+
+  ParamSpec *wspec = conf.add_param();
+  wspec->set_name(name + "_weight");
+  wspec->set_decay_mult(wd);
+  auto wfill = wspec->mutable_filler();
+  wfill->set_type("Gaussian");
+  wfill->set_std(std);
+
+  ParamSpec *bspec = conf.add_param();
+  bspec->set_name(name + "_bias");
+  bspec->set_lr_mult(2);
+  bspec->set_decay_mult(0);
+  auto bfill = bspec->mutable_filler();
+  bfill->set_value(bias);
+
+  return conf;
+}
+
+LayerConf GenLRNConf(string name) {
+  LayerConf conf;
+  conf.set_name(name);
+  conf.set_type("CudnnLRN");
+  LRNConf *lrn = conf.mutable_lrn_conf();
+  lrn->set_local_size(5);
+  lrn->set_alpha(1e-04);
+  lrn->set_beta(0.75);
+  return conf;
+}
+
+LayerConf GenFlattenConf(string name) {
+  LayerConf conf;
+  conf.set_name(name);
+  conf.set_type("Flatten");
+  return conf;
+}
+
+LayerConf GenDropoutConf(string name, float dropout_ratio) {
+  LayerConf conf;
+  conf.set_name(name);
+  conf.set_type("CudnnDropout");
+  DropoutConf *dropout = conf.mutable_dropout_conf();
+  dropout->set_dropout_ratio(dropout_ratio);
+  return conf;
+}
+
+FeedForwardNet CreateNet() {
+  FeedForwardNet net;
+  Shape s{3, 227, 227};
+
+  net.Add(new CudnnConvolution(), GenConvConf("conv1", 96, 11, 4, 0, 0.01), &s);
+  net.Add(new CudnnActivation(), GenReLUConf("relu1"));
+  net.Add(new CudnnPooling(), GenPoolingConf("pool1", true, 3, 2, 0));
+  net.Add(new CudnnLRN(), GenLRNConf("lrn1"));
+  net.Add(new CudnnConvolution(),
+          GenConvConf("conv2", 256, 5, 1, 2, 0.01, 1.0));
+  net.Add(new CudnnActivation(), GenReLUConf("relu2"));
+  net.Add(new CudnnPooling(), GenPoolingConf("pool2", true, 3, 2, 0));
+  net.Add(new CudnnLRN(), GenLRNConf("lrn2"));
+  net.Add(new CudnnConvolution(), GenConvConf("conv3", 384, 3, 1, 1, 0.01));
+  net.Add(new CudnnActivation(), GenReLUConf("relu3"));
+  net.Add(new CudnnConvolution(),
+          GenConvConf("conv4", 384, 3, 1, 1, 0.01, 1.0));
+  net.Add(new CudnnActivation(), GenReLUConf("relu4"));
+  net.Add(new CudnnConvolution(),
+          GenConvConf("conv5", 256, 3, 1, 1, 0.01, 1.0));
+  net.Add(new CudnnActivation(), GenReLUConf("relu5"));
+  net.Add(new CudnnPooling(), GenPoolingConf("pool5", true, 3, 2, 0));
+  net.Add(new Flatten(), GenFlattenConf("flat"));
+  net.Add(new Dense(), GenDenseConf("ip6", 4096, 0.005, 1, 1.0));
+  net.Add(new CudnnActivation(), GenReLUConf("relu6"));
+  net.Add(new CudnnDropout(), GenDropoutConf("drop6", 0.5));
+  net.Add(new Dense(), GenDenseConf("ip7", 4096, 0.005, 1, 1.0));
+  net.Add(new CudnnActivation(), GenReLUConf("relu7"));
+  net.Add(new CudnnDropout(), GenDropoutConf("drop7", 0.5));
+  net.Add(new Dense(), GenDenseConf("ip8", 1000, 0.01, 1));
+
+  return net;
+}
+
+void TrainOneEpoch(FeedForwardNet &net, ILSVRC &data,
+                   std::shared_ptr device, int epoch, string bin_folder,
+                   size_t num_train_files, size_t batchsize, float lr,
+                   Channel *train_ch, size_t pfreq, int nthreads) {
+  float loss = 0.0f, metric = 0.0f;
+  float load_time = 0.0f, train_time = 0.0f;
+  size_t b = 0;
+  size_t n_read;
+  Timer timer, ttr;
+  Tensor prefetch_x, prefetch_y;
+  string binfile = bin_folder + "/train1.bin";
+  timer.Tick();
+  data.LoadData(kTrain, binfile, batchsize, &prefetch_x, &prefetch_y, &n_read,
+                nthreads);
+  load_time += timer.Elapsed();
+  CHECK_EQ(n_read, batchsize);
+  Tensor train_x(prefetch_x.shape(), device);
+  Tensor train_y(prefetch_y.shape(), device, kInt);
+  std::thread th;
+  for (size_t fno = 1; fno <= num_train_files; fno++) {
+    binfile = bin_folder + "/train" + std::to_string(fno) + ".bin";
+    while (true) {
+      if (th.joinable()) {
+        th.join();
+        load_time += timer.Elapsed();
+        // LOG(INFO) << "num of samples: " << n_read;
+        if (n_read < batchsize) {
+          if (n_read > 0) {
+            LOG(WARNING) << "Pls set batchsize to make num_total_samples "
+                         << "% batchsize == 0. Otherwise, the last " << n_read
+                         << " samples would not be used";
+          }
+          break;
+        }
+      }
+      if (n_read == batchsize) {
+        train_x.CopyData(prefetch_x);
+        train_y.CopyData(prefetch_y);
+      }
+      timer.Tick();
+      th = data.AsyncLoadData(kTrain, binfile, batchsize, &prefetch_x,
+                              &prefetch_y, &n_read, nthreads);
+      if (n_read < batchsize) continue;
+      CHECK_EQ(train_x.shape(0), train_y.shape(0));
+      ttr.Tick();
+      auto ret = net.TrainOnBatch(epoch, train_x, train_y);
+      train_time += ttr.Elapsed();
+      loss += ret.first;
+      metric += ret.second;
+      b++;
+    }
+    if (b % pfreq == 0) {
+      train_ch->Send(
+          "Epoch " + std::to_string(epoch) + ", training loss = " +
+          std::to_string(loss / b) + ", accuracy = " +
+          std::to_string(metric / b) + ", lr = " + std::to_string(lr) +
+          ", time of loading " + std::to_string(batchsize) + " images = " +
+          std::to_string(load_time / b) +
+          " ms, time of training (batchsize = " + std::to_string(batchsize) +
+          ") = " + std::to_string(train_time / b) + " ms.");
+      loss = 0.0f;
+      metric = 0.0f;
+      load_time = 0.0f;
+      train_time = 0.0f;
+      b = 0;
+    }
+  }
+}
+
+void TestOneEpoch(FeedForwardNet &net, ILSVRC &data,
+                  std::shared_ptr device, int epoch, string bin_folder,
+                  size_t num_test_images, size_t batchsize, Channel *val_ch,
+                  int nthreads) {
+  float loss = 0.0f, metric = 0.0f;
+  float load_time = 0.0f, eval_time = 0.0f;
+  size_t n_read;
+  string binfile = bin_folder + "/test.bin";
+  Timer timer, tte;
+  Tensor prefetch_x, prefetch_y;
+  timer.Tick();
+  data.LoadData(kEval, binfile, batchsize, &prefetch_x, &prefetch_y, &n_read,
+                nthreads);
+  load_time += timer.Elapsed();
+  Tensor test_x(prefetch_x.shape(), device);
+  Tensor test_y(prefetch_y.shape(), device, kInt);
+  int remain = (int)num_test_images - n_read;
+  CHECK_EQ(n_read, batchsize);
+  std::thread th;
+  while (true) {
+    if (th.joinable()) {
+      th.join();
+      load_time += timer.Elapsed();
+      remain -= n_read;
+      if (remain < 0) break;
+      if (n_read < batchsize) break;
+    }
+    test_x.CopyData(prefetch_x);
+    test_y.CopyData(prefetch_y);
+    timer.Tick();
+    th = data.AsyncLoadData(kEval, binfile, batchsize, &prefetch_x, &prefetch_y,
+                            &n_read, nthreads);
+
+    CHECK_EQ(test_x.shape(0), test_y.shape(0));
+    tte.Tick();
+    auto ret = net.EvaluateOnBatch(test_x, test_y);
+    eval_time += tte.Elapsed();
+    ret.first.ToHost();
+    ret.second.ToHost();
+    loss += Sum(ret.first);
+    metric += Sum(ret.second);
+  }
+  loss /= num_test_images;
+  metric /= num_test_images;
+  val_ch->Send("Epoch " + std::to_string(epoch) + ", val loss = " +
+               std::to_string(loss) + ", accuracy = " + std::to_string(metric) +
+               ", time of loading " + std::to_string(num_test_images) +
+               " images = " + std::to_string(load_time) +
+               " ms, time of evaluating " + std::to_string(num_test_images) +
+               " images = " + std::to_string(eval_time) + " ms.");
+}
+
+void Checkpoint(FeedForwardNet &net, string prefix) {
+  Snapshot snapshot(prefix, Snapshot::kWrite, 200);
+  auto names = net.GetParamNames();
+  auto values = net.GetParamValues();
+  for (size_t k = 0; k < names.size(); k++) {
+    values.at(k).ToHost();
+    snapshot.Write(names.at(k), values.at(k));
+  }
+  LOG(INFO) << "Write snapshot into " << prefix;
+}
+
+void Train(int num_epoch, float lr, size_t batchsize, size_t train_file_size,
+           string bin_folder, size_t num_train_images, size_t num_test_images,
+           size_t pfreq, int nthreads) {
+  ILSVRC data;
+  data.ReadMean(bin_folder + "/mean.bin");
+  auto net = CreateNet();
+  auto cuda = std::make_shared(0);
+  net.ToDevice(cuda);
+  SGD sgd;
+  OptimizerConf opt_conf;
+  opt_conf.set_momentum(0.9);
+  auto reg = opt_conf.mutable_regularizer();
+  reg->set_coefficient(0.0005);
+  sgd.Setup(opt_conf);
+  sgd.SetLearningRateGenerator(
+      [lr](int epoch) { return lr * std::pow(0.1, epoch / 20); });
+
+  SoftmaxCrossEntropy loss;
+  Accuracy acc;
+  net.Compile(true, &sgd, &loss, &acc);
+
+  Channel *train_ch = GetChannel("train_perf");
+  train_ch->EnableDestStderr(true);
+  Channel *val_ch = GetChannel("val_perf");
+  val_ch->EnableDestStderr(true);
+  size_t num_train_files = num_train_images / train_file_size +
+                           (num_train_images % train_file_size ? 1 : 0);
+  for (int epoch = 0; epoch < num_epoch; epoch++) {
+    float epoch_lr = sgd.GetLearningRate(epoch);
+    TrainOneEpoch(net, data, cuda, epoch, bin_folder, num_train_files,
+                  batchsize, epoch_lr, train_ch, pfreq, nthreads);
+    if (epoch % 10 == 0 && epoch > 0) {
+      string prefix = "snapshot_epoch" + std::to_string(epoch);
+      Checkpoint(net, prefix);
+    }
+    TestOneEpoch(net, data, cuda, epoch, bin_folder, num_test_images, batchsize,
+                 val_ch, nthreads);
+  }
+}
+}
+
+int main(int argc, char **argv) {
+  singa::InitChannel(nullptr);
+
+  if (argc == 1) {
+    std::cout << "Usage:\n"
+              << "\t-epoch : number of epochs to train, default is 90;\n"
+              << "\t-lr : base learning rate;\n"
+              << "\t-batchsize : batch size; adjust it according "
+                 "to your memory;\n"
+              << "\t-filesize : number of training images stored in "
+                 "each binary file;\n"
+              << "\t-ntrain : number of training images;\n"
+              << "\t-ntest : number of test images;\n"
+              << "\t-data : the folder that stores the binary files;\n"
+              << "\t-pfreq : how often (in batches) to print the current "
+                 "model status (loss and accuracy);\n"
+              << "\t-nthreads : the number of threads used to load the data "
+                 "fed to the model.\n";
+    return 0;
+  }
+  int pos = singa::ArgPos(argc, argv, "-epoch");
+  int nEpoch = 90;
+  if (pos != -1) nEpoch = atoi(argv[pos + 1]);
+
+  pos = singa::ArgPos(argc, argv, "-lr");
+  float lr = 0.01;
+  if (pos != -1) lr = atof(argv[pos + 1]);
+
+  pos = singa::ArgPos(argc, argv, "-batchsize");
+  int batchsize = 256;
+  if (pos != -1) batchsize = atoi(argv[pos + 1]);
+
+  pos = singa::ArgPos(argc, argv, "-filesize");
+  size_t train_file_size = 1280;
+  if (pos != -1) train_file_size = atoi(argv[pos + 1]);
+
+  pos = singa::ArgPos(argc, argv, "-ntrain");
+  size_t num_train_images = 1281167;
+  if (pos != -1) num_train_images = atoi(argv[pos + 1]);
+
+  pos = singa::ArgPos(argc, argv, "-ntest");
+  size_t num_test_images = 50000;
+  if (pos != -1) num_test_images = atoi(argv[pos + 1]);
+
+  pos = singa::ArgPos(argc, argv, "-data");
+  string bin_folder = "imagenet_data";
+  if (pos != -1) bin_folder = argv[pos + 1];
+
+  pos = singa::ArgPos(argc, argv, "-pfreq");
+  size_t pfreq = 100;
+  if (pos != -1) pfreq = atoi(argv[pos + 1]);
+
+  pos = singa::ArgPos(argc, argv, "-nthreads");
+  int nthreads = 12;
+  if (pos != -1) nthreads = atoi(argv[pos + 1]);
+
+  LOG(INFO) << "Start training";
+  singa::Train(nEpoch, lr, batchsize, train_file_size, bin_folder,
+               num_train_images, num_test_images, pfreq, nthreads);
+  LOG(INFO) << "End training";
+}
+#endif

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/34d3ae68/examples/imagenet/create_data.sh
----------------------------------------------------------------------
diff --git a/examples/imagenet/create_data.sh b/examples/imagenet/create_data.sh
new file mode 100755
index 0000000..dd3d9b8
--- /dev/null
+++ b/examples/imagenet/create_data.sh
@@ -0,0 +1,3 @@
+#!/usr/bin/env sh
+../../build/bin/createdata -trainlist "imagenet/label/train.txt" -trainfolder "imagenet/ILSVRC2012_img_train" \
+  -testlist "imagenet/label/val.txt" -testfolder "imagenet/ILSVRC2012_img_val" -outdata "imagenet_data" -filesize 1280

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/34d3ae68/examples/imagenet/ilsvrc12.cc
----------------------------------------------------------------------
diff --git a/examples/imagenet/ilsvrc12.cc b/examples/imagenet/ilsvrc12.cc
new file mode 100644
index 0000000..2bc07f2
--- /dev/null
+++ b/examples/imagenet/ilsvrc12.cc
@@ -0,0 +1,69 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+#include "singa/singa_config.h"
+#ifdef USE_OPENCV
+#include "ilsvrc12.h"
+#include "singa/utils/channel.h"
+#include "singa/utils/string.h"
+int main(int argc, char **argv) {
+  if (argc == 1) {
+    std::cout << "Usage:\n"
+              << "\t-trainlist : the training list file;\n"
+              << "\t-trainfolder : the folder of training images;\n"
+              << "\t-testlist : the test list file;\n"
+              << "\t-testfolder : the folder of test images;\n"
+              << "\t-outdata : the folder to save output files;\n"
+              << "\t-filesize : number of training images stored in "
+                 "each binary file.\n";
+    return 0;
+  }
+  int pos = singa::ArgPos(argc, argv, "-trainlist");
+  string train_image_list = "imagenet/label/train.txt";
+  if (pos != -1) train_image_list = argv[pos + 1];
+
+  pos = singa::ArgPos(argc, argv, "-trainfolder");
+  string train_image_folder = "imagenet/ILSVRC2012_img_train";
+  if (pos != -1) train_image_folder = argv[pos + 1];
+
+  pos = singa::ArgPos(argc, argv, "-testlist");
+  string test_image_list = "imagenet/label/val.txt";
+  if (pos != -1) test_image_list = argv[pos + 1];
+
+  pos = singa::ArgPos(argc, argv, "-testfolder");
+  string test_image_folder = "imagenet/ILSVRC2012_img_val";
+  if (pos != -1) test_image_folder = argv[pos + 1];
+
+  pos = singa::ArgPos(argc, argv, "-outdata");
+  string bin_folder = "imagenet_data";
+  if (pos != -1) bin_folder = argv[pos + 1];
+
+  pos = singa::ArgPos(argc, argv, "-filesize");
+  size_t train_file_size = 1280;
+  if (pos != -1) train_file_size = atoi(argv[pos + 1]);
+  singa::ILSVRC data;
+  LOG(INFO) << "Creating training and test data...";
+  data.CreateTrainData(train_image_list, train_image_folder, bin_folder,
+                       train_file_size);
+  data.CreateTestData(test_image_list, test_image_folder, bin_folder);
+  LOG(INFO) << "Data created!";
+  return 0;
+}
+#endif  // USE_OPENCV

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/34d3ae68/examples/imagenet/ilsvrc12.h
----------------------------------------------------------------------
diff --git a/examples/imagenet/ilsvrc12.h b/examples/imagenet/ilsvrc12.h
new file mode 100644
index 0000000..a6d4238
--- /dev/null
+++ b/examples/imagenet/ilsvrc12.h
@@ -0,0 +1,380 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+#include "singa/singa_config.h"
+#ifdef USE_OPENCV
+#ifndef SINGA_EXAMPLES_IMAGENET_ILSVRC12_H_
+#define SINGA_EXAMPLES_IMAGENET_ILSVRC12_H_
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include "singa/core/tensor.h"
+#include "singa/io/decoder.h"
+#include "singa/io/encoder.h"
+#include "singa/io/reader.h"
+#include "singa/io/transformer.h"
+#include "singa/io/writer.h"
+#include "singa/proto/io.pb.h"
+#include "singa/utils/timer.h"
+
+using std::string;
+using namespace singa::io;
+namespace singa {
+/// For reading ILSVRC2012 image data as tensors.
+class ILSVRC {
+ public:
+  /// Set up the encoder, decoder and transformer
+  ILSVRC();
+  ~ILSVRC() {
+    if (encoder != nullptr) delete encoder;
+    if (decoder != nullptr) delete decoder;
+    if (transformer != nullptr) delete transformer;
+    if (reader != nullptr) {
+      reader->Close();
+      delete reader;
+    }
+    if (writer != nullptr) {
+      writer->Close();
+      delete writer;
+    }
+  }
+  /// Create binary files for training data
+  /// train_image_list: list file of training images
+  /// train_image_folder: folder that stores the original training images
+  /// train_bin_folder: folder to store the binary files
+  /// train_file_size: number of images contained in one binary file
+  void CreateTrainData(string train_image_list, string train_image_folder,
+                       string train_bin_folder, size_t train_file_size);
+  /// Create binary files for test data
+  /// test_image_list: list file of test images
+  /// test_image_folder: folder that stores the original test images
+  /// test_bin_folder: folder to save the binary files
+  void CreateTestData(string test_image_list, string test_image_folder,
+                      string test_bin_folder);
+  /// Load data from a binary file, return pair
+  /// suppose the data will be loaded file by file.
+  /// flag: kTrain or kEval
+  /// file: binary file that stores the images
+  /// read_size: number of images to be loaded
+  /// offset: offset in the file
+  /// n_read: number of images actually read
+  size_t LoadData(int flag, string file, size_t read_size, Tensor *x, Tensor *y,
+                  size_t *n_read, int nthreads);
+
+  std::thread AsyncLoadData(int flag, string file, size_t read_size, Tensor *x,
+                            Tensor *y, size_t *n_read, int nthreads);
+
+  void DecodeTransform(int flag, int thid, int nthreads,
+                       vector images, Tensor *x, Tensor *y);
+  std::thread AsyncDecodeTransform(int flag, int thid, int nthreads,
+                                   vector images, Tensor *x,
+                                   Tensor *y);
+
+  /// Read mean from path
+  void ReadMean(string path);
+
+ protected:
+  /// Read one image at path, resize the image
+  Tensor ReadImage(string path);
+  /// Write buff to the file in kCreate/kAppend mode
+  void Write(string outfile, singa::io::Mode mode);
+  void WriteMean(Tensor &mean, string path);
+
+ private:
+  /// size for resizing
+  const size_t kImageSize = 256;
+  const size_t kImageNBytes = 3 * kImageSize * kImageSize;
+  /// size for cropping
+  const size_t kCropSize = 227;
+  Tensor mean;
+  string last_read_file = "";
+
+  JPGEncoder *encoder = nullptr;
+  JPGDecoder *decoder = nullptr;
+  ImageTransformer *transformer = nullptr;
+  BinFileReader *reader = nullptr;
+  BinFileWriter *writer = nullptr;
+};
+
+ILSVRC::ILSVRC() {
+  EncoderConf en_conf;
+  en_conf.set_image_dim_order("CHW");
+  encoder = new JPGEncoder();
+  encoder->Setup(en_conf);
+
+  DecoderConf de_conf;
+  de_conf.set_image_dim_order("CHW");
+  decoder = new JPGDecoder();
+  decoder->Setup(de_conf);
+
+  TransformerConf trans_conf;
+  trans_conf.add_crop_shape(kCropSize);
+  trans_conf.add_crop_shape(kCropSize);
+  trans_conf.set_image_dim_order("CHW");
+  trans_conf.set_horizontal_mirror(true);
+  transformer = new ImageTransformer();
+  transformer->Setup(trans_conf);
+}
+
+Tensor ILSVRC::ReadImage(string path) {
+  cv::Mat mat = cv::imread(path, CV_LOAD_IMAGE_COLOR);
+  CHECK(mat.data != NULL) << "OpenCV load image fail: " << path;
+  cv::Size size(kImageSize, kImageSize);
+  cv::Mat resized;
+  cv::resize(mat, resized, size);
+  CHECK_EQ((size_t)resized.size().height, kImageSize);
+  CHECK_EQ((size_t)resized.size().width, kImageSize);
+  // dimension_order: CHW
+  Shape shape{(size_t)resized.channels(), (size_t)resized.rows,
+              (size_t)resized.cols};
+  Tensor image(shape, singa::kUChar);
+  unsigned char *data = new unsigned char[kImageNBytes];
+  for (int i = 0; i < resized.rows; i++)
+    for (int j = 0; j < resized.cols; j++)
+      for (int k = 0; k < resized.channels(); k++)
+        data[k * kImageSize * kImageSize + i * kImageSize + j] =
+            resized.at<cv::Vec3b>(i, j)[k];
+  image.CopyDataFromHostPtr(data, kImageNBytes);
+  delete[] data;
+
+  return image;
+}
+
+void ILSVRC::WriteMean(Tensor &mean, string path) {
+  Tensor mean_lb(Shape{1}, kInt);
+  std::vector<Tensor> input;
+  input.push_back(mean);
+  input.push_back(mean_lb);
+  BinFileWriter bfwriter;
+  bfwriter.Open(path, kCreate);
+  bfwriter.Write(path, encoder->Encode(input));
+  bfwriter.Flush();
+  bfwriter.Close();
+}
+
+void ILSVRC::CreateTrainData(string image_list, string input_folder,
+                             string output_folder, size_t file_size = 12800) {
+  std::vector<std::pair<std::string, int>> file_list;
+  size_t *sum = new size_t[kImageNBytes];
+  for (size_t i = 0; i < kImageNBytes; i++) sum[i] = 0u;
+  string image_file_name;
+  int label;
+  string outfile;
+  std::ifstream image_list_file(image_list.c_str(), std::ios::in);
+  while (image_list_file >> image_file_name >> label)
+    file_list.push_back(std::make_pair(image_file_name, label));
+  LOG(INFO) << "Data Shuffling";
+  std::shuffle(file_list.begin(), file_list.end(),
+               std::default_random_engine());
+  LOG(INFO) << "Total number of training images is " << file_list.size();
+  size_t num_train_images = file_list.size();
+  num_train_images = 12900;
+  if (file_size == 0) file_size = num_train_images;
+  // todo: accelerate with omp
+  for (size_t imageid = 0; imageid < num_train_images; imageid++) {
+    string path = input_folder + "/" + file_list[imageid].first;
+    Tensor image = ReadImage(path);
+    auto image_data = image.data<unsigned char>();
+    for (size_t i = 0; i < kImageNBytes; i++)
+      sum[i] += static_cast<size_t>(image_data[i]);
+    label = file_list[imageid].second;
+    Tensor lb(Shape{1}, kInt);
+    lb.CopyDataFromHostPtr(&label, 1);
+    std::vector<Tensor> input;
+    input.push_back(image);
+    input.push_back(lb);
+    // LOG(INFO) << path << "\t" << label;
+    string encoded_str = encoder->Encode(input);
+    if (writer == nullptr) {
+      writer = new BinFileWriter();
+      outfile = output_folder + "/train" +
+                std::to_string(imageid / file_size + 1) + ".bin";
+      writer->Open(outfile, kCreate);
+    }
+    writer->Write(path, encoded_str);
+    if ((imageid + 1) % file_size == 0) {
+      writer->Flush();
+      writer->Close();
+      LOG(INFO) << "Write " << file_size << " images into " << outfile;
+      delete writer;
+      writer = nullptr;
+    }
+  }
+  if (writer != nullptr) {
+    writer->Flush();
+    writer->Close();
+    LOG(INFO) << "Write " << num_train_images % file_size << " images into "
+              << outfile;
+    delete writer;
+    writer = nullptr;
+  }
+  size_t num_file =
+      num_train_images / file_size + ((num_train_images % file_size) ? 1 : 0);
+  LOG(INFO) << "Write " << num_train_images << " images into " << num_file
+            << " binary files";
+  Tensor mean = Tensor(Shape{3, kImageSize, kImageSize}, kUChar);
+  unsigned char *mean_data = new unsigned char[kImageNBytes];
+  for (size_t i = 0; i < kImageNBytes; i++)
+    mean_data[i] = static_cast<unsigned char>(sum[i] / num_train_images);
+  mean.CopyDataFromHostPtr(mean_data, kImageNBytes);
+  string mean_path = output_folder + "/mean.bin";
+  WriteMean(mean, mean_path);
+  delete[] mean_data;
+  delete[] sum;
+}
+
+void ILSVRC::CreateTestData(string image_list, string input_folder,
+                            string output_folder) {
+  std::vector<std::pair<std::string, int>> file_list;
+  string image_file_name;
+  string outfile = output_folder + "/test.bin";
+  int label;
+  std::ifstream image_list_file(image_list.c_str(), std::ios::in);
+  while (image_list_file >> image_file_name >> label)
+    file_list.push_back(std::make_pair(image_file_name, label));
+  LOG(INFO) << "Total number of test images is " << file_list.size();
+  size_t num_test_images = file_list.size();
+  num_test_images = 500;
+  for (size_t imageid = 0; imageid < num_test_images; imageid++) {
+    string path = input_folder + "/" + file_list[imageid].first;
+    Tensor image = ReadImage(path);
+    label = file_list[imageid].second;
+    Tensor lb(Shape{1}, singa::kInt);
+    lb.CopyDataFromHostPtr(&label, 1);
+    std::vector<Tensor> input;
+    input.push_back(image);
+    input.push_back(lb);
+    string encoded_str = encoder->Encode(input);
+    if (writer == nullptr) {
+      writer = new BinFileWriter();
+      writer->Open(outfile, kCreate);
+    }
+    writer->Write(path, encoded_str);
+  }
+  if (writer != nullptr) {
+    writer->Flush();
+    writer->Close();
+    delete writer;
+    writer = nullptr;
+  }
+  LOG(INFO) << "Write " << num_test_images << " images into " << outfile;
+}
+
+void ILSVRC::ReadMean(string path) {
+  BinFileReader bfreader;
+  string key, value;
+  bfreader.Open(path);
+  bfreader.Read(&key, &value);
+  auto ret = decoder->Decode(value);
+  bfreader.Close();
+  mean = ret[0];
+}
+/// A wrapper method to spawn a thread to execute LoadData() method.
+std::thread ILSVRC::AsyncLoadData(int flag, string file, size_t read_size,
+                                  Tensor *x, Tensor *y, size_t *n_read,
+                                  int nthreads) {
+  return std::thread(
+      [=]() { LoadData(flag, file, read_size, x, y, n_read, nthreads); });
+}
+
+size_t ILSVRC::LoadData(int flag, string file, size_t read_size, Tensor *x,
+                        Tensor *y, size_t *n_read, int nthreads) {
+  x->Reshape(Shape{read_size, 3, kCropSize, kCropSize});
+  y->AsType(kInt);
+  y->Reshape(Shape{read_size});
+  if (file != last_read_file) {
+    if (reader != nullptr) {
+      reader->Close();
+      delete reader;
+      reader = nullptr;
+    }
+    reader = new BinFileReader();
+    reader->Open(file, 100 << 20);
+    last_read_file = file;
+  } else if (reader == nullptr) {
+    reader = new BinFileReader();
+    reader->Open(file, 100 << 20);
+  }
+  vector<string *> images;
+  for (size_t i = 0; i < read_size; i++) {
+    string image_path;
+    string *image = new string();
+    bool ret = reader->Read(&image_path, image);
+    if (ret == false) {
+      reader->Close();
+      delete reader;
+      reader = nullptr;
+      break;
+    }
+    images.push_back(image);
+  }
+  int nimg = images.size();
+  *n_read = nimg;
+
+  vector<std::thread> threads;
+  for (int i = 1; i < nthreads; i++) {
+    threads.push_back(AsyncDecodeTransform(flag, i, nthreads, images, x, y));
+  }
+  DecodeTransform(flag, 0, nthreads, images, x, y);
+  for (size_t i = 0; i < threads.size(); i++) threads[i].join();
+  for (int k = 0; k < nimg; k++) delete images.at(k);
+  return nimg;
+}
+
+/// A wrapper method to spawn a thread to execute DecodeTransform() method.
+std::thread ILSVRC::AsyncDecodeTransform(int flag, int thid, int nthreads,
+                                         vector<string *> images, Tensor *x,
+                                         Tensor *y) {
+  return std::thread(
+      [=]() { DecodeTransform(flag, thid, nthreads, images, x, y); });
+}
+
+void ILSVRC::DecodeTransform(int flag, int thid, int nthreads,
+                             vector<string *> images, Tensor *x, Tensor *y) {
+  int nimg = images.size();
+  int start = nimg / nthreads * thid;
+  int end = start + nimg / nthreads;
+  for (int k = start; k < end; k++) {
+    std::vector<Tensor> pair = decoder->Decode(*images.at(k));
+    auto tmp_image = pair[0] - mean;
+    Tensor aug_image = transformer->Apply(flag, tmp_image);
+    CopyDataToFrom(x, aug_image, aug_image.Size(), k * aug_image.Size());
+    CopyDataToFrom(y, pair[1], 1, k);
+  }
+  if (thid == 0) {
+    for (int k = nimg / nthreads * nthreads; k < nimg; k++) {
+      std::vector<Tensor> pair = decoder->Decode(*images.at(k));
+      auto tmp_image = pair[0] - mean;
+      Tensor aug_image = transformer->Apply(flag, tmp_image);
+      CopyDataToFrom(x, aug_image, aug_image.Size(), k * aug_image.Size());
+      CopyDataToFrom(y, pair[1], 1, k);
+    }
+  }
+}
+} // namespace singa
+
+#endif // SINGA_EXAMPLES_IMAGENET_ILSVRC12_H_
+#endif // USE_OPENCV

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/34d3ae68/examples/imagenet/run.sh
----------------------------------------------------------------------
diff --git a/examples/imagenet/run.sh b/examples/imagenet/run.sh
new file mode 100755
index 0000000..5c27b5c
--- /dev/null
+++ b/examples/imagenet/run.sh
@@ -0,0 +1,3 @@
+#!/usr/bin/env sh
+../../build/bin/imagenet -epoch 90 -lr 0.01 -batchsize 256 -filesize 1280 -ntrain 1281167 -ntest 50000 \
+  -data "imagenet_data" -pfreq 100 -nthreads 12

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/34d3ae68/include/singa/io/snapshot.h
----------------------------------------------------------------------
diff --git a/include/singa/io/snapshot.h b/include/singa/io/snapshot.h
index 7545572..0d5aa66 100644
--- a/include/singa/io/snapshot.h
+++ b/include/singa/io/snapshot.h
@@ -49,7 +49,8 @@ class Snapshot {
   /// i.e.
   /// name and shape, one line per parameter.
   /// kRead for reading snapshot, whereas kWrite for dumping out snapshot.
-  Snapshot(const std::string& prefix, Mode mode);
+  /// max_param_size: in MB
+  Snapshot(const std::string& prefix, Mode mode, int max_param_size = 10);
   ~Snapshot() {}
   /// Read parameters saved as tensors from checkpoint file.
   std::vector<std::pair<std::string, Tensor>> Read();
@@ -67,8 +68,9 @@ class Snapshot {
  private:
   std::string prefix_;
   Mode mode_;
-  std::unique_ptr<io::Writer> bin_writer_ptr_, text_writer_ptr_;
-  std::unique_ptr<io::Reader> bin_reader_ptr_;
+  std::unique_ptr<io::BinFileWriter> bin_writer_ptr_;
+  std::unique_ptr<io::Writer> text_writer_ptr_;
+  std::unique_ptr<io::BinFileReader> bin_reader_ptr_;
   /// Check whether parameter name is unique.
   std::unordered_set<std::string> param_names_;
   /// Preload key-parameter tensor pairs for seeking a specified key.

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/34d3ae68/include/singa/utils/timer.h
----------------------------------------------------------------------
diff --git a/include/singa/utils/timer.h b/include/singa/utils/timer.h
index bdd6c5c..1372d3c 100644
--- a/include/singa/utils/timer.h
+++ b/include/singa/utils/timer.h
@@ -11,6 +11,7 @@ class Timer {
   typedef std::chrono::duration<int> Seconds;
   typedef std::chrono::duration<int, std::milli> Milliseconds;
   typedef std::chrono::duration<int, std::ratio<60 * 60>> Hours;
+  typedef std::chrono::duration<int, std::micro> Microseconds;
 
   /// Init the internal time point to the current time
   Timer() { Tick(); }
@@ -23,8 +24,9 @@ class Timer {
   int Elapsed() const {
     static_assert(std::is_same<T, Seconds>::value ||
                       std::is_same<T, Milliseconds>::value ||
-                      std::is_same<T, Hours>::value,
-                  "Template arg must be Seconds | Milliseconds | Hours");
+                      std::is_same<T, Hours>::value ||
+                      std::is_same<T, Microseconds>::value,
+                  "Template arg must be Seconds | Milliseconds | Hours | Microseconds");
     auto now = std::chrono::high_resolution_clock::now();
     return std::chrono::duration_cast<T>(now - last_).count();
   }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/34d3ae68/src/core/tensor/tensor.cc
----------------------------------------------------------------------
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index 4972a86..2951aa9 100644
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -34,12 +34,12 @@ Tensor::Tensor() { device_ = defaultDevice; }
 
 Tensor::Tensor(const Shape &shape, DataType dtype)
     : data_type_(dtype), device_(defaultDevice), shape_(shape) {
-  device_ = defaultDevice;
+  //device_ = defaultDevice;
   block_ = device_->NewBlock(Product(shape_) * SizeOf(data_type_));
 }
 Tensor::Tensor(Shape &&shape, DataType dtype)
     : data_type_(dtype), device_(defaultDevice), shape_(shape) {
-  device_ = defaultDevice;
+  //device_ = defaultDevice;
   block_ = device_->NewBlock(Product(shape_) * SizeOf(data_type_));
 }
 Tensor::Tensor(const Shape &shape, std::shared_ptr<Device> device,
@@ -127,7 +127,9 @@ void Tensor::ToDevice(std::shared_ptr<Device> dst) {
   }
 }
 
-void Tensor::ToHost() { ToDevice(device_->host()); }
+void Tensor::ToHost() {
+  if (device_ != defaultDevice) ToDevice(device_->host());
+}
 
 template <typename DType>
 void Tensor::CopyDataFromHostPtr(const DType *src, const size_t num,

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/34d3ae68/src/io/binfile_reader.cc
----------------------------------------------------------------------
diff --git a/src/io/binfile_reader.cc b/src/io/binfile_reader.cc
index 77e34d8..9b52a5d 100644
--- a/src/io/binfile_reader.cc
+++ b/src/io/binfile_reader.cc
@@ -98,7 +98,7 @@ bool BinFileReader::OpenFile() {
   buf_ = new char[capacity_];
   fdat_.open(path_, std::ios::in | std::ios::binary);
   CHECK(fdat_.is_open()) << "Cannot open file " << path_;
-  return fdat_.is_open();
+  return fdat_.is_open();
 }
 
 bool BinFileReader::ReadField(std::string* content) {
@@ -108,7 +108,9 @@ bool BinFileReader::ReadField(std::string* content) {
   int len = *reinterpret_cast<int*>(buf_ + offset_);
   offset_ += ssize;
   if (!PrepareNextField(len)) return false;
-  for (int i = 0; i < len; ++i) content->push_back(buf_[offset_ + i]);
+  content->reserve(len);
+  content->insert(0, buf_ + offset_, len);
+  //for (int i = 0; i < len; ++i) content->push_back(buf_[offset_ + i]);
   offset_ += len;
   return true;
 }

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/34d3ae68/src/io/jpg_encoder.cc
----------------------------------------------------------------------
diff --git a/src/io/jpg_encoder.cc b/src/io/jpg_encoder.cc
index 9db799d..8335a91 100644
--- a/src/io/jpg_encoder.cc
+++ b/src/io/jpg_encoder.cc
@@ -72,7 +72,7 @@ std::string JPGEncoder::Encode(vector<Tensor>& data) {
   // suppose each image is attached with at most one label
   if (data.size() == 2) {
     const int* label = data[1].data<int>();
-    CHECK_EQ(label[0], 2);
+    //CHECK_EQ(label[0], 2);
     record.add_label(label[0]);
   }
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/34d3ae68/src/io/snapshot.cc
----------------------------------------------------------------------
diff --git a/src/io/snapshot.cc b/src/io/snapshot.cc
index 3b9b8ce..58c7044 100644
--- a/src/io/snapshot.cc
+++ b/src/io/snapshot.cc
@@ -29,17 +29,17 @@
 #include
 
 namespace singa {
-Snapshot::Snapshot(const std::string& prefix, Mode mode)
+Snapshot::Snapshot(const std::string& prefix, Mode mode, int max_param_size /*in MB*/)
     : prefix_(prefix),
       mode_(mode),
       bin_writer_ptr_(mode_ == kWrite ? (new io::BinFileWriter) : nullptr),
       text_writer_ptr_(mode_ == kWrite ? (new io::TextFileWriter) : nullptr),
       bin_reader_ptr_(mode_ == kRead ? (new io::BinFileReader) : nullptr) {
   if (mode_ == kWrite) {
-    bin_writer_ptr_->Open(prefix + ".model", io::kCreate);
+    bin_writer_ptr_->Open(prefix + ".model", io::kCreate, max_param_size << 20);
     text_writer_ptr_->Open(prefix + ".desc", io::kCreate);
   } else if (mode == kRead) {
-    bin_reader_ptr_->Open(prefix + ".model");
+    bin_reader_ptr_->Open(prefix + ".model", max_param_size << 20);
     std::string key, serialized_str;
     singa::TensorProto tp;
     while (bin_reader_ptr_->Read(&key, &serialized_str)) {
@@ -63,6 +63,7 @@ void Snapshot::Write(const std::string& key, const Tensor& param) {
   std::string serialized_str;
   CHECK(tp.SerializeToString(&serialized_str));
   bin_writer_ptr_->Write(key, serialized_str);
+//  bin_writer_ptr_->Flush();
 
   std::string desc_str = "parameter name: " + key;
   Shape shape = param.shape();
@@ -71,6 +72,7 @@ void Snapshot::Write(const std::string& key, const Tensor& param) {
   desc_str += "\tshape:";
   for (size_t s : shape) desc_str += " " + std::to_string(s);
   text_writer_ptr_->Write(key, desc_str);
+  // text_writer_ptr_->Flush();
 }
 
 std::vector<std::pair<std::string, Tensor>> Snapshot::Read() {
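The work split in `ILSVRC::DecodeTransform()` in the ilsvrc12.h diff above is a static block partition: each of the `nthreads` workers takes a contiguous run of `nimg / nthreads` images, and thread 0 additionally decodes the `nimg % nthreads` images left over by the integer division. The standalone sketch below reproduces only that index arithmetic so it can be checked without SINGA or OpenCV; `PartitionRanges` and `CoverageCounts` are hypothetical helpers, not part of the commit.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical helper mirroring the index arithmetic of
// ILSVRC::DecodeTransform(): thread `thid` gets the half-open block
// [chunk * thid, chunk * thid + chunk) with chunk = nimg / nthreads,
// and thread 0 also gets the leftover tail [chunk * nthreads, nimg).
std::vector<std::pair<int, int>> PartitionRanges(int nimg, int nthreads,
                                                 int thid) {
  assert(nthreads > 0 && thid >= 0 && thid < nthreads);
  std::vector<std::pair<int, int>> ranges;
  int chunk = nimg / nthreads;  // integer division; remainder handled below
  int start = chunk * thid;
  ranges.emplace_back(start, start + chunk);
  if (thid == 0) {
    // Thread 0 picks up the nimg % nthreads images left over by the split.
    ranges.emplace_back(chunk * nthreads, nimg);
  }
  return ranges;
}

// Counts how often each image index is assigned across all threads;
// a correct partition assigns every index exactly once.
std::vector<int> CoverageCounts(int nimg, int nthreads) {
  std::vector<int> counts(nimg, 0);
  for (int thid = 0; thid < nthreads; ++thid)
    for (const auto &r : PartitionRanges(nimg, nthreads, thid))
      for (int k = r.first; k < r.second; ++k) ++counts[k];
  return counts;
}
```

For example, a batch of 256 images decoded with 12 threads (the `-batchsize` and `-nthreads` values in run.sh, assuming they reach `LoadData()` unchanged) gives every worker 21 images and thread 0 an extra 4, so thread 0 tends to finish last. Handing the remainder out one image per thread would even this out slightly.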
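`ILSVRC::ReadImage()` converts OpenCV's interleaved pixel layout (height, width, channel) into the planar CHW order that the encoder, decoder and transformer are configured for. The sketch below isolates that index mapping on a plain byte vector, assuming a continuous 8-bit buffer such as a freshly allocated `cv::Mat`; `HWCToCHW` is a hypothetical name. Because the loop copies channel `k` verbatim, the resulting planes keep OpenCV's BGR channel order, since `cv::imread` returns color images as BGR.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Converts an interleaved (HWC) 8-bit image buffer into planar CHW order,
// using the same destination index as ILSVRC::ReadImage():
//   dst[k * H * W + i * W + j] = src[(i * W + j) * C + k]
// The flat-vector interface is illustrative only.
std::vector<unsigned char> HWCToCHW(const std::vector<unsigned char> &src,
                                    std::size_t height, std::size_t width,
                                    std::size_t channels) {
  assert(src.size() == height * width * channels);
  std::vector<unsigned char> dst(src.size());
  for (std::size_t i = 0; i < height; ++i)
    for (std::size_t j = 0; j < width; ++j)
      for (std::size_t k = 0; k < channels; ++k)
        dst[k * height * width + i * width + j] =
            src[(i * width + j) * channels + k];
  return dst;
}
```

With a 1x2 image whose pixels are (1, 2, 3) and (4, 5, 6), the planar result is {1, 4, 2, 5, 3, 6}: one plane per channel, each holding that channel's value for every pixel in row-major order.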
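`ILSVRC::CreateTrainData()` shards the encoded records into binary files of `file_size` images each: record `imageid` is written to file number `imageid / file_size + 1`, and the file count logged at the end is the ceiling of `num_train_images / file_size`. The helpers below restate that arithmetic under hypothetical names. With the hard-coded `num_train_images = 12900` and the default `file_size = 12800`, they give two files, the second holding 100 images.

```cpp
#include <cassert>
#include <cstddef>

// 1-based index of the shard file that record `imageid` is written to,
// matching "train" + std::to_string(imageid / file_size + 1) + ".bin".
std::size_t ShardIndex(std::size_t imageid, std::size_t file_size) {
  assert(file_size > 0);
  return imageid / file_size + 1;
}

// Number of shard files: ceil(num_images / file_size).
std::size_t NumShards(std::size_t num_images, std::size_t file_size) {
  assert(file_size > 0);
  return num_images / file_size + ((num_images % file_size) ? 1 : 0);
}

// Number of records in the final shard; a full shard when num_images is
// an exact multiple of file_size.
std::size_t LastShardSize(std::size_t num_images, std::size_t file_size) {
  assert(file_size > 0);
  std::size_t rem = num_images % file_size;
  return rem ? rem : file_size;
}
```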