SumNeuron opened a new issue #8751: Distributed Training has inverse results when imported (8 GPUs is slower than 1!)
URL: https://github.com/apache/incubator-mxnet/issues/8751

This issue is demonstrated with 3 files:

- *file 1*: distributed training in a single file
- *file 2*: a file which loads a net from another file and runs the same distributed training
- *file 3*: the net from file 1, exported as a module

There are also 3 outputs:

- *out 1*: file 1 run on 1, 2, 4, and 8 GPUs
- *out 2*: file 2 running the net from file 3 with 8 GPUs
- *out 3*: file 2 running the net from file 3 with 1 GPU

The important things to note here are:

1. I modified the `get_mnist()` function to load from a local directory.
2. I am not running any evaluation metric in file 2 (validation is not included in the reported training time in either case).
3. Using 8 GPUs takes about 18 seconds per epoch for file 2, but about 2.7 seconds for file 1, whereas 1 GPU takes 2.3-3 seconds per epoch in file 1 and about 4 seconds in file 2. This makes no sense.
4. The distributed training is an adaptation of the multi-GPU chapter from the straight dope.

Lastly, file 2 can be called with something like:

`python3 distributed.py -v -i test_net -g 1 -e 10`

where `distributed.py` is file 2 and `test_net.py` is file 3.

So what is going on?
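For orientation before the full listings: file 2 expects the module named with `-i` to export (at least) `net`, `loss`, and the two data iterators; anything else it cannot find in the module falls back to its argparse defaults. Below is a minimal, hypothetical sketch of such a module (it is not one of the three files of this issue, just an illustration of the interface file 2 imports via `getattr(__import__(...), ...)`):

```
# Hypothetical minimal module usable with file 2's "-i" flag (illustration only,
# not one of the issue's files). file 2 imports these names by attribute lookup;
# anything missing (gpus, epochs, seed, learning_rate, ...) falls back to its
# argparse defaults.
import numpy as np
from mxnet import gluon
from mxnet.io import NDArrayIter

net = gluon.nn.Sequential()
with net.name_scope():
    net.add(gluon.nn.Dense(10))

loss = gluon.loss.SoftmaxCrossEntropyLoss()
batch_size = 64
training_method = 'sgd'

# stand-in iterators; file 3 below uses the real MNIST arrays instead
_x = np.zeros((256, 1, 28, 28), dtype=np.float32)
_y = np.zeros((256,), dtype=np.float32)
training_data = NDArrayIter(_x, _y, batch_size)
test_data = NDArrayIter(_x[:64], _y[:64], batch_size)
```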
# File 1
```
from __future__ import absolute_import, print_function, division
import mxnet as mx
from mxnet import nd, gluon, autograd
from mxnet.context import Context
from mxnet.ndarray import array
from mxnet.symbol import Symbol
from mxnet.io import NDArrayIter
from time import time
from mxnet.test_utils import get_mnist
import json
import traceback
import numbers
import subprocess
import os
import errno
import logging
import numpy as np
import numpy.testing as npt
import gzip
import struct
# import mxnet as mx
# from .context import Context
# from .ndarray import array
# from .symbol import Symbol
try:
    import requests
except ImportError:
    # in rare cases requests may be not installed
    pass
#


def get_mnist():
    """Download and load the MNIST dataset

    Returns
    -------
    dict
        A dict containing the data
    """
    def read_data(label_url, image_url):
        with gzip.open(label_url) as flbl:
            struct.unpack(">II", flbl.read(8))
            label = np.fromstring(flbl.read(), dtype=np.int8)
        with gzip.open(image_url, 'rb') as fimg:
            _, _, rows, cols = struct.unpack(">IIII", fimg.read(16))
            image = np.fromstring(fimg.read(), dtype=np.uint8).reshape(len(label), rows, cols)
            image = image.reshape(image.shape[0], 1, 28, 28).astype(np.float32)/255
        return (label, image)

    # changed to mxnet.io for more stable hosting
    # path = 'http://yann.lecun.com/exdb/mnist/'
    path = 'http://data.mxnet.io/data/mnist/'
    path = '/workspace/mxnet/MNIST_data/'
    (train_lbl, train_img) = read_data(
        path+'train-labels-idx1-ubyte.gz', path+'train-images-idx3-ubyte.gz')
    (test_lbl, test_img) = read_data(
        path+'t10k-labels-idx1-ubyte.gz', path+'t10k-images-idx3-ubyte.gz')
    return {'train_data':train_img, 'train_label':train_lbl,
            'test_data':test_img, 'test_label':test_lbl}


print('Defining network')
net = gluon.nn.Sequential(prefix="cnn_")
with net.name_scope():
    net.add(gluon.nn.Conv2D(channels=20, kernel_size=3, activation='relu'))
    net.add(gluon.nn.MaxPool2D(pool_size=(2,2), strides=(2,2)))
    net.add(gluon.nn.Conv2D(channels=50, kernel_size=5, activation='relu'))
    net.add(gluon.nn.MaxPool2D(pool_size=(2,2), strides=(2,2)))
    net.add(gluon.nn.Flatten())
    net.add(gluon.nn.Dense(128, activation='relu'))
    net.add(gluon.nn.Dense(10))
print(net)

loss = gluon.loss.SoftmaxCrossEntropyLoss()


def forward_backward(net, data, label):
    with autograd.record():
        losses = [loss(net(X), Y) for X, Y in zip(data, label)]
    for l in losses:
        l.backward()


def train_batch(batch, ctx, net, trainer):
    data = gluon.utils.split_and_load(batch.data[0], ctx)
    label = gluon.utils.split_and_load(batch.label[0], ctx)
    # compute gradient
    forward_backward(net, data, label)
    # update parameters
    trainer.step(batch.data[0].shape[0])


def valid_batch(batch, ctx, net):
    data = batch.data[0].as_in_context(ctx[0])
    pred = nd.argmax(net(data), axis=1)
    return nd.sum(pred == batch.label[0].as_in_context(ctx[0])).asscalar()


def run(num_gpus, batch_size, lr, epochs=5):
    # the list of GPUs will be used
    durations = []
    accuracies = []
    ctx = [mx.gpu(i) for i in range(num_gpus)]
    print('Running on {}'.format(ctx))

    # data iterator
    print('loading mnist')
    mnist = get_mnist()
    print('mnist loaded')
    train_data = NDArrayIter(mnist["train_data"], mnist["train_label"], batch_size)
    valid_data = NDArrayIter(mnist["test_data"], mnist["test_label"], batch_size)
    print('Batch size is {}'.format(batch_size))

    print('initalizing parameters')
    net.collect_params().initialize(force_reinit=True, ctx=ctx)
    print('initalizing trainer')
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': lr})

    for epoch in range(epochs):
        # train
        start = time()
        train_data.reset()
        for batch in train_data:
            train_batch(batch, ctx, net, trainer)
        nd.waitall()  # wait until all computations are finished to benchmark the time
        epoch_train_time = time()-start
        durations.append(epoch_train_time)
        print('Epoch %d, training time = %.1f sec'%(epoch, epoch_train_time))

        # validating
        valid_data.reset()
        correct, num = 0.0, 0.0
        for batch in valid_data:
            correct += valid_batch(batch, ctx, net)
            num += batch.data[0].shape[0]
        print('\t\tValidation accuracy = %.4f'%(correct/num))
        accuracies.append(correct / num)

    return {'duration': durations, 'accuracy': accuracies}


data = {}
for gpus in [1, 2, 4, 8]:
    print('running with {} GPUs'.format(gpus))
    da = run(gpus, 64*gpus, .3, 10)
    data[str(gpus)] = da

print(json.dumps(data))
```
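As an aside on the timing methodology used above (and in file 2): MXNet executes operators asynchronously, so the `nd.waitall()` before reading the clock is what makes the reported epoch times meaningful. A minimal illustration of this, my own sketch rather than anything from the issue:

```
# Sketch: without nd.waitall(), time() is read while work may still be queued on
# the backend, so the measured duration would mostly reflect enqueueing the ops.
import mxnet as mx
from mxnet import nd
from time import time

a = nd.ones((1024, 1024))
start = time()
for _ in range(50):
    a = nd.dot(a, a)          # enqueued asynchronously
print('enqueue only: %.3f sec' % (time() - start))
nd.waitall()                  # block until every queued operation has finished
print('with waitall: %.3f sec' % (time() - start))
```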
# File 2
```
import argparse, sys

parser = argparse.ArgumentParser()
# parser.add_argument('in', help='the directory where the crs files are stored. e.g.\n\tdata.txt (values of data)\n\tindices.txt (index of those values)\n\tindptr.txt (start of rows)\n\tshape.txt (matrix shape).')
# parser.add_argument('out', help='the full path to where the dense matrix is stored.')
parser.add_argument('-v', '--verbosity', help='increase output statements', action='count', default=0)
parser.add_argument('-g', '--gpus', help='set the number of GPUs to run on', default=1)
parser.add_argument('-d', '--delim', help='the delim for the out file', default=',')
parser.add_argument('-i', '--input', help='the file defining the network')
parser.add_argument('-o', '--output', help='the directory for the output files', default='./')
parser.add_argument('-s', '--seed', help='set the seed', default=3)
parser.add_argument('-bs', '--batch-size', help='set the batch size', default=64)
parser.add_argument('-lr', '--learning-rate', help='set the learning rate', default=0.001)
parser.add_argument('-e', '--epochs', help='set the epochs', default=100)
parser.add_argument('-tm', '--training-method', help='set the training method', default='sgd')

args = parser.parse_args()

v = args.verbosity
d = args.delim
# g = args.gpus
# s = args.seed
# bs = args.batch_size
# lr = args.learning_rate
# e = args.epochs
# tm = args.training_method
inp = args.input
out = args.output

_global_functions = {
    'loss': None,
    'net': None
}

_global_data = {
    'training_data': None,
    'test_data': None
}

_global_args = {
    "batch_size": None,
    "learning_rate": None,
    "epochs": None,
    'seed': None,
    'gpus': None,
    'training_method': None
}


def extract_mod_attr(var_dict, mod, attr, v=v, try_defaults=False):
    if v: print('[EXTRACT]:\t{} from {}'.format(attr, mod))
    try:
        var_dict[attr] = getattr(__import__(mod, attr), attr)
    except AttributeError as err:
        print('[ERROR]:\t{} could not be imported from {}'.format(attr, mod))

    if var_dict[attr] is None and try_defaults:
        if v: print('[TRY]:\tlooking for default value for {}'.format(attr))
        try:
            var_dict[attr] = eval("args." + attr)
        except Exception as excep:
            print('[FATAL]:\t{} is not defined in {} or defaults'.format(attr, mod))
            sys.exit()

    if var_dict[attr] is None:
        print('[FATAL]:\t{} could not be imported from {}'.format(attr, mod))
        sys.exit()

    if v: print('[EXTRACTED]:\t{}:\t{}'.format(attr, var_dict[attr]))


# Extract Global Functions (e.g. net, loss, etc)
for k in _global_functions:
    extract_mod_attr(_global_functions, inp, k, v, try_defaults=False)

# Extract Global Arguments (e.g. learning rate, batch size, etc)
for k in _global_args:
    extract_mod_attr(_global_args, inp, k, v, try_defaults=True)

# Extract Global Data (e.g. training, validation and test data sets)
for k in _global_data:
    extract_mod_attr(_global_data, inp, k, v, try_defaults=False)

if v: print('[IMPORT]:\timporting dependencies')
import mxnet as mx
from mxnet import nd, autograd
from mxnet import gluon
import numpy as np
from time import time
import json

mx.random.seed(_global_args['seed'])


def parallel_forward_backward_pass(net, loss, data, label):
    with autograd.record():
        losses = [loss(net(X), Y) for X, Y in zip(data, label)]
    for loss in losses:
        loss.backward()


def parallel_batch(batch, ctx, net, trainer, loss):
    data = gluon.utils.split_and_load(batch.data[0], ctx)
    label = gluon.utils.split_and_load(batch.label[0], ctx)
    # compute gradient
    parallel_forward_backward_pass(net, loss, data, label)
    # update parameters
    trainer.step(batch.data[0].shape[0])


def parallel_run(number_of_gpus, net, training_method, loss, batch_size,
                 learning_rate, epochs, training_data, test_data,
                 validation_data=None):
    durations = []
    accuracies = []
    ctx = [mx.gpu(i) for i in range(number_of_gpus)]
    if v:
        print('Running on {} gpus'.format(number_of_gpus))
        print(ctx)

    if v: print('[INIT]:\tnet parameters')
    net.collect_params().initialize(force_reinit=True, ctx=ctx)
    if v: print('[INIT]:\ttrainer')
    trainer = gluon.Trainer(net.collect_params(), training_method, {'learning_rate': learning_rate})

    for epoch in range(epochs):
        start = time()
        training_data.reset()
        for batch in training_data:
            parallel_batch(batch, ctx, net, trainer, loss)
        nd.waitall()
        epoch_training_time = time()-start
        durations.append(epoch_training_time)
        if v: print('Epoch %d, training time = %.1f sec'%(epoch, epoch_training_time))

        test_data.reset()
        acc = 0
        # validation function
        print('\t\tValidation Accuracy = %.4f'%(acc))
        accuracies.append(acc)

    return {'duration':durations, 'accuracy': accuracies}


# data = {}
g = int(_global_args['gpus'])
net = _global_functions['net']
loss = _global_functions['loss']
tm = _global_args["training_method"]
bs = int(_global_args['batch_size'])
lr = float(_global_args['learning_rate'])
e = int(_global_args['epochs'])
train = _global_data['training_data']
test = _global_data['test_data']

print('______________________________')
print('[GO]:\tParallel Run')
parallel_run(
    number_of_gpus=g,
    net=net,
    training_method=tm,
    loss=loss,
    batch_size=bs,
    learning_rate=lr,
    epochs=e,
    training_data=train,
    test_data=test,
    validation_data=None
)
print('[END]:\tParallel Run')
print('______________________________')
```
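Note that file 2 stubs out validation (`acc = 0`), which is why out 2 and out 3 below report `Validation Accuracy = 0.0000`. If validation were wired in, it could simply reuse file 1's approach; a sketch of that (not part of the submitted script, and assuming `net`, `ctx`, and the freshly reset `test_data` iterator from `parallel_run()` are in scope):

```
# Sketch only: filling in file 2's stubbed-out validation, mirroring file 1's
# valid_batch(). Runs the whole validation set on the first device.
from mxnet import nd

def validate(test_data, ctx, net):
    correct, total = 0.0, 0.0
    for batch in test_data:
        data = batch.data[0].as_in_context(ctx[0])
        pred = nd.argmax(net(data), axis=1)
        correct += nd.sum(pred == batch.label[0].as_in_context(ctx[0])).asscalar()
        total += batch.data[0].shape[0]
    return correct / total
```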
# File 3
```
import mxnet as mx
from mxnet import nd, gluon, autograd
from mxnet.io import NDArrayIter
from mxnet.test_utils import get_mnist
import traceback
import numbers
import subprocess
import os
import errno
import logging
import numpy as np
import numpy.testing as npt
import gzip
import struct

net = gluon.nn.Sequential(prefix="cnn_")
with net.name_scope():
    net.add(gluon.nn.Conv2D(channels=20, kernel_size=3, activation='relu'))
    net.add(gluon.nn.MaxPool2D(pool_size=(2,2), strides=(2,2)))
    net.add(gluon.nn.Conv2D(channels=50, kernel_size=5, activation='relu'))
    net.add(gluon.nn.MaxPool2D(pool_size=(2,2), strides=(2,2)))
    net.add(gluon.nn.Flatten())
    net.add(gluon.nn.Dense(128, activation='relu'))
    net.add(gluon.nn.Dense(10))

loss = gluon.loss.SoftmaxCrossEntropyLoss()

batch_size = 64
# learning_rate = 0.001
# epochs = 100
training_method = 'sgd'


def get_mnist():
    """Download and load the MNIST dataset

    Returns
    -------
    dict
        A dict containing the data
    """
    def read_data(label_url, image_url):
        with gzip.open(label_url) as flbl:
            struct.unpack(">II", flbl.read(8))
            label = np.fromstring(flbl.read(), dtype=np.int8)
        with gzip.open(image_url, 'rb') as fimg:
            _, _, rows, cols = struct.unpack(">IIII", fimg.read(16))
            image = np.fromstring(fimg.read(), dtype=np.uint8).reshape(len(label), rows, cols)
            image = image.reshape(image.shape[0], 1, 28, 28).astype(np.float32)/255
        return (label, image)

    # changed to mxnet.io for more stable hosting
    # path = 'http://yann.lecun.com/exdb/mnist/'
    path = 'http://data.mxnet.io/data/mnist/'
    path = '/workspace/mxnet/MNIST_data/'
    (train_lbl, train_img) = read_data(
        path+'train-labels-idx1-ubyte.gz', path+'train-images-idx3-ubyte.gz')
    (test_lbl, test_img) = read_data(
        path+'t10k-labels-idx1-ubyte.gz', path+'t10k-images-idx3-ubyte.gz')
    return {'train_data':train_img, 'train_label':train_lbl,
            'test_data':test_img, 'test_label':test_lbl}


# trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': learning_rate})

mnist = get_mnist()
training_data = NDArrayIter(mnist["train_data"], mnist["train_label"], batch_size)
test_data = NDArrayIter(mnist["test_data"], mnist["test_label"], batch_size)
```
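One difference between the two setups that may matter for the timing comparison (an observation from reading the code above, not something I have isolated): file 1 scales the global batch size with the GPU count (`run(gpus, 64*gpus, ...)`), while file 3 fixes `batch_size = 64` no matter how many GPUs file 2 is asked to use, so with 8 GPUs `split_and_load` hands each device only 8 samples per step. A small sketch of that split, using CPU contexts as stand-ins for GPUs:

```
# Sketch: per-device shard sizes under the two batch-size conventions.
import mxnet as mx
from mxnet import nd, gluon

ctx = [mx.cpu(i) for i in range(8)]

scaled = gluon.utils.split_and_load(nd.zeros((64 * 8, 1, 28, 28)), ctx)  # file 1 style
fixed = gluon.utils.split_and_load(nd.zeros((64, 1, 28, 28)), ctx)       # file 2 + file 3 style
print([p.shape[0] for p in scaled])  # [64, 64, 64, 64, 64, 64, 64, 64]
print([p.shape[0] for p in fixed])   # [8, 8, 8, 8, 8, 8, 8, 8]
```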
# Out 1
```
Defining network
Sequential(
  (0): Conv2D(20, kernel_size=(3, 3), stride=(1, 1))
  (1): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False)
  (2): Conv2D(50, kernel_size=(5, 5), stride=(1, 1))
  (3): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False)
  (4): Flatten
  (5): Dense(128, Activation(relu))
  (6): Dense(10, linear)
)
running with 1 GPUs
Running on [gpu(0)]
loading mnist
mnist loaded
Batch size is 64
initalizing parameters
initalizing trainer
[13:19:10] src/operator/././cudnn_algoreg-inl.h:112: Running performance tests to find the best convolution algorithm, this can take a while... (setting env variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
Epoch 0, training time = 3.5 sec        Validation accuracy = 0.9723
Epoch 1, training time = 2.3 sec        Validation accuracy = 0.9824
Epoch 2, training time = 2.3 sec        Validation accuracy = 0.9859
Epoch 3, training time = 2.2 sec        Validation accuracy = 0.9868
Epoch 4, training time = 2.3 sec        Validation accuracy = 0.9870
Epoch 5, training time = 2.3 sec        Validation accuracy = 0.9874
Epoch 6, training time = 2.2 sec        Validation accuracy = 0.9866
Epoch 7, training time = 2.3 sec        Validation accuracy = 0.9866
Epoch 8, training time = 2.3 sec        Validation accuracy = 0.9865
Epoch 9, training time = 2.3 sec        Validation accuracy = 0.9861
running with 2 GPUs
Running on [gpu(0), gpu(1)]
loading mnist
mnist loaded
Batch size is 128
initalizing parameters
initalizing trainer
Epoch 0, training time = 3.5 sec        Validation accuracy = 0.9497
Epoch 1, training time = 2.4 sec        Validation accuracy = 0.9712
Epoch 2, training time = 2.4 sec        Validation accuracy = 0.9818
Epoch 3, training time = 2.4 sec        Validation accuracy = 0.9850
Epoch 4, training time = 2.4 sec        Validation accuracy = 0.9866
Epoch 5, training time = 2.4 sec        Validation accuracy = 0.9870
Epoch 6, training time = 2.4 sec        Validation accuracy = 0.9861
Epoch 7, training time = 2.4 sec        Validation accuracy = 0.9852
Epoch 8, training time = 2.4 sec        Validation accuracy = 0.9887
Epoch 9, training time = 2.4 sec        Validation accuracy = 0.9876
running with 4 GPUs
Running on [gpu(0), gpu(1), gpu(2), gpu(3)]
loading mnist
mnist loaded
Batch size is 256
initalizing parameters
initalizing trainer
Epoch 0, training time = 2.3 sec        Validation accuracy = 0.9521
Epoch 1, training time = 2.2 sec        Validation accuracy = 0.9672
Epoch 2, training time = 2.3 sec        Validation accuracy = 0.9822
Epoch 3, training time = 2.3 sec        Validation accuracy = 0.9843
Epoch 4, training time = 2.3 sec        Validation accuracy = 0.9856
Epoch 5, training time = 2.3 sec        Validation accuracy = 0.9873
Epoch 6, training time = 2.3 sec        Validation accuracy = 0.9869
Epoch 7, training time = 2.3 sec        Validation accuracy = 0.9876
Epoch 8, training time = 2.3 sec        Validation accuracy = 0.9877
Epoch 9, training time = 2.3 sec        Validation accuracy = 0.9865
running with 8 GPUs
Running on [gpu(0), gpu(1), gpu(2), gpu(3), gpu(4), gpu(5), gpu(6), gpu(7)]
loading mnist
mnist loaded
Batch size is 512
initalizing parameters
initalizing trainer
[13:20:40] src/kvstore/././comm.h:327: only 32 out of 56 GPU pairs are enabled direct access. It may affect the performance. You can set MXNET_ENABLE_GPU_P2P=0 to turn it off
[13:20:40] src/kvstore/././comm.h:336: .vvvv...
[13:20:40] src/kvstore/././comm.h:336: v.vv.v..
[13:20:40] src/kvstore/././comm.h:336: vv.v..v.
[13:20:40] src/kvstore/././comm.h:336: vvv....v
[13:20:40] src/kvstore/././comm.h:336: v....vvv
[13:20:40] src/kvstore/././comm.h:336: .v..v.vv
[13:20:40] src/kvstore/././comm.h:336: ..v.vv.v
[13:20:40] src/kvstore/././comm.h:336: ...vvvv.
Epoch 0, training time = 2.5 sec        Validation accuracy = 0.9287
Epoch 1, training time = 2.3 sec        Validation accuracy = 0.9634
Epoch 2, training time = 2.3 sec        Validation accuracy = 0.9754
Epoch 3, training time = 2.3 sec        Validation accuracy = 0.9794
Epoch 4, training time = 2.3 sec        Validation accuracy = 0.9805
Epoch 5, training time = 2.3 sec        Validation accuracy = 0.9824
Epoch 6, training time = 2.3 sec        Validation accuracy = 0.9834
Epoch 7, training time = 2.3 sec        Validation accuracy = 0.9835
Epoch 8, training time = 2.3 sec        Validation accuracy = 0.9847
Epoch 9, training time = 2.3 sec        Validation accuracy = 0.9846
{"4": {"accuracy": [0.9521484375, 0.9671875, 0.9822265625, 0.98427734375, 0.98564453125, 0.9873046875, 0.9869140625, 0.98759765625, 0.9876953125, 0.9865234375], "duration": [2.305682897567749, 2.2148516178131104, 2.2893030643463135, 2.3267581462860107, 2.328000545501709, 2.2684123516082764, 2.274890899658203, 2.296605348587036, 2.258661985397339, 2.288602352142334]}, "8": {"accuracy": [0.9287109375, 0.96337890625, 0.975390625, 0.97939453125, 0.98046875, 0.982421875, 0.9833984375, 0.98349609375, 0.98466796875, 0.9845703125], "duration": [2.457437753677368, 2.306643009185791, 2.316239833831787, 2.2657582759857178, 2.2696518898010254, 2.2852253913879395, 2.303246021270752, 2.298490285873413, 2.2846038341522217, 2.290074110031128]}, "2": {"accuracy": [0.9496637658227848, 0.9712223101265823, 0.9818037974683544, 0.9849683544303798, 0.9865506329113924, 0.9870450949367089, 0.986056170886076, 0.9851661392405063, 0.9887262658227848, 0.9876384493670886], "duration": [3.522899866104126, 2.3761959075927734, 2.4048850536346436, 2.403083562850952, 2.4359405040740967, 2.3811116218566895, 2.4005138874053955, 2.3995509147644043, 2.4216606616973877, 2.4278368949890137]}, "1": {"accuracy": [0.9723328025477707, 0.9823845541401274, 0.9858678343949044, 0.9867635350318471, 0.9869625796178344, 0.9873606687898089, 0.9865644904458599, 0.9865644904458599, 0.9864649681528662, 0.9860668789808917], "duration": [3.4593358039855957, 2.284571647644043, 2.304379940032959, 2.244168281555176, 2.3261559009552, 2.3148741722106934, 2.232365846633911, 2.2623894214630127, 2.2599427700042725, 2.273547410964966]}}
```
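The JSON blob at the end of out 1 is simply the `data` dict dumped by file 1. If it helps, per-GPU-count summaries can be pulled out of it with something like this (illustration only; `out1.json` is a hypothetical file holding that blob):

```
# Illustration: summarising the JSON printed at the end of out 1.
import json

with open('out1.json') as f:   # hypothetical file containing the JSON blob above
    data = json.load(f)

for gpus in sorted(data, key=int):
    durs = data[gpus]['duration']
    print('%s GPU(s): mean epoch time %.2f sec, final accuracy %.4f'
          % (gpus, sum(durs) / len(durs), data[gpus]['accuracy'][-1]))
```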
# Out 2
```
[EXTRACT]:    loss from test_net
[EXTRACTED]:  loss:  SoftmaxCrossEntropyLoss(batch_axis=0, w=None)
[EXTRACT]:    net from test_net
[EXTRACTED]:  net:  Sequential(
  (0): Conv2D(20, kernel_size=(3, 3), stride=(1, 1))
  (1): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False)
  (2): Conv2D(50, kernel_size=(5, 5), stride=(1, 1))
  (3): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False)
  (4): Flatten
  (5): Dense(128, Activation(relu))
  (6): Dense(10, linear)
)
[EXTRACT]:    batch_size from test_net
[EXTRACTED]:  batch_size:  64
[EXTRACT]:    learning_rate from test_net
[ERROR]:      learning_rate could not be imported from test_net
[TRY]:        looking for default value for learning_rate
[EXTRACTED]:  learning_rate:  0.001
[EXTRACT]:    training_method from test_net
[EXTRACTED]:  training_method:  sgd
[EXTRACT]:    gpus from test_net
[ERROR]:      gpus could not be imported from test_net
[TRY]:        looking for default value for gpus
[EXTRACTED]:  gpus:  8
[EXTRACT]:    epochs from test_net
[ERROR]:      epochs could not be imported from test_net
[TRY]:        looking for default value for epochs
[EXTRACTED]:  epochs:  10
[EXTRACT]:    seed from test_net
[ERROR]:      seed could not be imported from test_net
[TRY]:        looking for default value for seed
[EXTRACTED]:  seed:  3
[EXTRACT]:    training_data from test_net
[EXTRACTED]:  training_data:
[EXTRACT]:    test_data from test_net
[EXTRACTED]:  test_data:
[IMPORT]:     importing dependencies
______________________________
[GO]:  Parallel Run
Running on 8 gpus
[INIT]:  net parameters
[INIT]:  trainer
[10:08:09] src/operator/././cudnn_algoreg-inl.h:112: Running performance tests to find the best convolution algorithm, this can take a while... (setting env variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
[10:08:09] src/kvstore/././comm.h:327: only 32 out of 56 GPU pairs are enabled direct access. It may affect the performance. You can set MXNET_ENABLE_GPU_P2P=0 to turn it off
[10:08:09] src/kvstore/././comm.h:336: .vvvv...
[10:08:09] src/kvstore/././comm.h:336: v.vv.v..
[10:08:09] src/kvstore/././comm.h:336: vv.v..v.
[10:08:09] src/kvstore/././comm.h:336: vvv....v
[10:08:09] src/kvstore/././comm.h:336: v....vvv
[10:08:09] src/kvstore/././comm.h:336: .v..v.vv
[10:08:09] src/kvstore/././comm.h:336: ..v.vv.v
[10:08:09] src/kvstore/././comm.h:336: ...vvvv.
Epoch 0, training time = 18.8 sec        Validation Accuracy = 0.0000
Epoch 1, training time = 18.2 sec        Validation Accuracy = 0.0000
Epoch 2, training time = 18.2 sec        Validation Accuracy = 0.0000
Epoch 3, training time = 18.2 sec        Validation Accuracy = 0.0000
Epoch 4, training time = 18.2 sec        Validation Accuracy = 0.0000
Epoch 5, training time = 18.2 sec        Validation Accuracy = 0.0000
Epoch 6, training time = 18.0 sec        Validation Accuracy = 0.0000
Epoch 7, training time = 18.2 sec        Validation Accuracy = 0.0000
Epoch 8, training time = 18.1 sec        Validation Accuracy = 0.0000
Epoch 9, training time = 17.9 sec        Validation Accuracy = 0.0000
[END]:  Parallel Run
```

# Out 3
```
[EXTRACT]:    loss from test_net
[EXTRACTED]:  loss:  SoftmaxCrossEntropyLoss(batch_axis=0, w=None)
[EXTRACT]:    net from test_net
[EXTRACTED]:  net:  Sequential(
  (0): Conv2D(20, kernel_size=(3, 3), stride=(1, 1))
  (1): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False)
  (2): Conv2D(50, kernel_size=(5, 5), stride=(1, 1))
  (3): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False)
  (4): Flatten
  (5): Dense(128, Activation(relu))
  (6): Dense(10, linear)
)
[EXTRACT]:    learning_rate from test_net
[ERROR]:      learning_rate could not be imported from test_net
[TRY]:        looking for default value for learning_rate
[EXTRACTED]:  learning_rate:  0.001
[EXTRACT]:    gpus from test_net
[ERROR]:      gpus could not be imported from test_net
[TRY]:        looking for default value for gpus
[EXTRACTED]:  gpus:  1
[EXTRACT]:    epochs from test_net
[ERROR]:      epochs could not be imported from test_net
[TRY]:        looking for default value for epochs
[EXTRACTED]:  epochs:  10
[EXTRACT]:    batch_size from test_net
[EXTRACTED]:  batch_size:  64
[EXTRACT]:    seed from test_net
[ERROR]:      seed could not be imported from test_net
[TRY]:        looking for default value for seed
[EXTRACTED]:  seed:  3
[EXTRACT]:    training_method from test_net
[EXTRACTED]:  training_method:  sgd
[EXTRACT]:    test_data from test_net
[EXTRACTED]:  test_data:
[EXTRACT]:    training_data from test_net
[EXTRACTED]:  training_data:
[IMPORT]:     importing dependencies
______________________________
[GO]:  Parallel Run
Running on 1 gpus
[gpu(0)]
[INIT]:  net parameters
[INIT]:  trainer
[13:11:13] src/operator/././cudnn_algoreg-inl.h:112: Running performance tests to find the best convolution algorithm, this can take a while... (setting env variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
Epoch 0, training time = 3.5 sec        Validation Accuracy = 0.0000
Epoch 1, training time = 2.3 sec        Validation Accuracy = 0.0000
Epoch 2, training time = 2.3 sec        Validation Accuracy = 0.0000
Epoch 3, training time = 2.3 sec        Validation Accuracy = 0.0000
Epoch 4, training time = 2.3 sec        Validation Accuracy = 0.0000
Epoch 5, training time = 2.3 sec        Validation Accuracy = 0.0000
Epoch 6, training time = 2.3 sec        Validation Accuracy = 0.0000
Epoch 7, training time = 2.3 sec        Validation Accuracy = 0.0000
Epoch 8, training time = 2.4 sec        Validation Accuracy = 0.0000
Epoch 9, training time = 2.3 sec        Validation Accuracy = 0.0000
[END]:  Parallel Run
```