mxnet-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From GitBox <...@apache.org>
Subject [GitHub] stonedl3 commented on issue #9766: DeepLearning on Imagenet with mxnet issues translating .lst to .rec files
Date Thu, 01 Jan 1970 00:00:00 GMT
stonedl3 commented on issue #9766: DeepLearning on Imagenet with mxnet issues translating .lst
to .rec files
URL: https://github.com/apache/incubator-mxnet/issues/9766#issuecomment-366445508
 
 
   Here is the script used to train alexnet
   
   # USAGE
   # python train_alexnet.py --checkpoints checkpoints --prefix alexnet
   # python train_alexnet.py --checkpoints checkpoints --prefix alexnet --start-epoch 25
   
   # import the necessary packages
   from config import imagenet_alexnet_config as config
   from pyimagesearch.nn.mxconv import MxAlexNet
   import mxnet as mx
   import argparse
   import logging
   import json
   import os
   
   # describe the command-line interface: the checkpoint directory and
   # the model prefix are mandatory, while the starting epoch is an
   # optional integer that defaults to 0
   ap = argparse.ArgumentParser()
   ap.add_argument(
       "-c", "--checkpoints",
       required=True,
       help="path to output checkpoint directory",
   )
   ap.add_argument(
       "-p", "--prefix",
       required=True,
       help="name of model prefix",
   )
   ap.add_argument(
       "-s", "--start-epoch",
       type=int,
       default=0,
       help="epoch to restart training at",
   )

   # expose the parsed namespace as a plain dictionary keyed by option
   # name (e.g. args["start_epoch"])
   parsedArgs = ap.parse_args()
   args = vars(parsedArgs)
   
   # write DEBUG-level (and above) log records to a file whose name
   # embeds the starting epoch; mode "w" overwrites an existing file of
   # the same name
   logFile = "training_{}.log".format(args["start_epoch"])
   logging.basicConfig(filename=logFile, filemode="w", level=logging.DEBUG)
   
   # load the RGB means for the training set; the context manager closes
   # the JSON file as soon as it has been parsed instead of leaving the
   # handle open for the rest of the (long-running) training script
   with open(config.DATASET_MEAN) as meansFile:
       means = json.load(meansFile)

   # determine the batch size: config.BATCH_SIZE scaled by
   # config.NUM_DEVICES
   batchSize = config.BATCH_SIZE * config.NUM_DEVICES
   
   # keyword arguments common to both record iterators: the image shape,
   # the batch size and the per-channel mean values loaded into `means`
   sharedIterArgs = dict(
       data_shape=(3, 224, 224),
       batch_size=batchSize,
       mean_r=means["R"],
       mean_g=means["G"],
       mean_b=means["B"],
   )

   # training iterator: the shared settings plus the random crop, mirror,
   # rotate and shear options and a preprocessing thread count derived
   # from config.NUM_DEVICES
   trainIter = mx.io.ImageRecordIter(
       path_imgrec=config.TRAIN_MX_REC,
       rand_crop=True,
       rand_mirror=True,
       rotate=15,
       max_shear_ratio=0.1,
       preprocess_threads=config.NUM_DEVICES * 2,
       **sharedIterArgs)

   # validation iterator: only the shared settings, none of the
   # augmentation options
   valIter = mx.io.ImageRecordIter(
       path_imgrec=config.VAL_MX_REC,
       **sharedIterArgs)
   
   # SGD optimizer whose gradient rescaling factor is the reciprocal of
   # the batch size
   gradScale = 1.0 / batchSize
   opt = mx.optimizer.SGD(
       learning_rate=1e-2,
       momentum=0.9,
       wd=0.0005,
       rescale_grad=gradScale)

   # checkpoint path prefix: the checkpoint directory and the model
   # prefix joined by the platform path separator
   checkpointsPath = os.path.sep.join((args["checkpoints"], args["prefix"]))

   # the model argument and auxiliary parameters stay unset unless a
   # checkpoint is loaded further down
   argParams = auxParams = None
   
   # if there is no specific model starting epoch supplied, then
   # initialize the network
   if args["start_epoch"] <= 0:
       # build the AlexNet architecture
       print("[INFO] building network...")
       model = MxAlexNet.build(config.NUM_CLASSES)

   # otherwise, a specific checkpoint was supplied
   else:
       # load the network symbol together with its argument and auxiliary
       # parameters for the requested epoch; load_checkpoint returns the
       # three pieces directly, so there is no need to build a deprecated
       # FeedForward wrapper only to pull the same fields back out of it
       print("[INFO] loading epoch {}...".format(args["start_epoch"]))
       model, argParams, auxParams = mx.model.load_checkpoint(
           checkpointsPath, args["start_epoch"])
   
   # compile the model, creating one GPU context per configured device so
   # the device list stays consistent with config.NUM_DEVICES, which
   # already scales the batch size and the iterator's preprocessing
   # threads (previously two GPUs were hard-coded here)
   model = mx.mod.Module(
       context=[mx.gpu(i) for i in range(config.NUM_DEVICES)],
       symbol=model)
   
   # per-batch callback: a Speedometer given the batch size and a
   # frequency of 500; per-epoch callback: a checkpoint writer using the
   # checkpoint path prefix
   batchEndCBs = [mx.callback.Speedometer(batchSize, 500)]
   epochEndCBs = [mx.callback.do_checkpoint(checkpointsPath)]

   # evaluation metrics: accuracy, top-5 accuracy and cross-entropy
   metrics = [
       mx.metric.Accuracy(),
       mx.metric.TopKAccuracy(top_k=5),
       mx.metric.CrossEntropy(),
   ]

   # gather every fit() setting in one place, then start training; the
   # run begins at the requested starting epoch and stops at epoch 65
   print("[INFO] training network...")
   fitArgs = dict(
       train_data=trainIter,
       eval_data=valIter,
       eval_metric=metrics,
       batch_end_callback=batchEndCBs,
       epoch_end_callback=epochEndCBs,
       initializer=mx.initializer.Xavier(),
       arg_params=argParams,
       aux_params=auxParams,
       optimizer=opt,
       num_epoch=65,
       begin_epoch=args["start_epoch"],
   )
   model.fit(**fitArgs)

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Mime
View raw message