singa-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wang...@apache.org
Subject [1/2] incubator-singa git commit: SINGA-344 Add a GAN example
Date Mon, 27 Aug 2018 14:31:16 GMT
Repository: incubator-singa
Updated Branches:
  refs/heads/master 8aac80e42 -> f8cd7e384


SINGA-344 Add a GAN example


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/b1610d75
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/b1610d75
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/b1610d75

Branch: refs/heads/master
Commit: b1610d7576cd58cbc0c989af540c6c64c501585c
Parents: 2224d5f
Author: huangwentao <wentaohuang@ruc.edu.cn>
Authored: Fri Aug 24 10:16:37 2018 +0800
Committer: huangwentao <wentaohuang@ruc.edu.cn>
Committed: Fri Aug 24 10:16:37 2018 +0800

----------------------------------------------------------------------
 examples/gan/download_mnist.py |  28 +++++
 examples/gan/lsgan.py          | 213 ++++++++++++++++++++++++++++++++++++
 examples/gan/utils.py          |  67 ++++++++++++
 examples/gan/vanilla.py        | 207 +++++++++++++++++++++++++++++++++++
 4 files changed, 515 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b1610d75/examples/gan/download_mnist.py
----------------------------------------------------------------------
diff --git a/examples/gan/download_mnist.py b/examples/gan/download_mnist.py
new file mode 100644
index 0000000..b042a7c
--- /dev/null
+++ b/examples/gan/download_mnist.py
@@ -0,0 +1,28 @@
+#!/usr/bin/env python
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#     http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# 
+
import argparse

from utils import download_data

if __name__ == '__main__':
    # Thin CLI wrapper around utils.download_data: fetch the pre-processed
    # MNIST pickle from a URL into a local gzip file.
    cli = argparse.ArgumentParser(
        description='download the pre-processed MNIST dataset')
    cli.add_argument('gzfile', type=str, help='the dataset path')
    cli.add_argument('url', type=str, help='dataset url')
    parsed = cli.parse_args()
    download_data(parsed.gzfile, parsed.url)
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b1610d75/examples/gan/lsgan.py
----------------------------------------------------------------------
diff --git a/examples/gan/lsgan.py b/examples/gan/lsgan.py
new file mode 100644
index 0000000..dc6582c
--- /dev/null
+++ b/examples/gan/lsgan.py
@@ -0,0 +1,213 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from singa import device
+from singa import initializer
+from singa import layer
+from singa import loss
+from singa import net as ffnet
+from singa import optimizer
+from singa import tensor
+
+import argparse
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+
+from utils import load_data
+from utils import print_log
+
class LSGAN():
    """Least-squares GAN (LSGAN) over flattened MNIST images.

    Generator and discriminator are single-hidden-layer MLPs trained with a
    squared-error loss: the discriminator score is regressed towards 1 for
    real images and -1 for generated ones, while the generator (updated
    through the combined generator+discriminator net) drives the score
    towards 0.
    """

    def __init__(self, dev, rows=28, cols=28, channels=1, noise_size=100,
                 hidden_size=128, batch=128, interval=1000, learning_rate=0.001,
                 epochs=1000000, d_steps=3, g_steps=1,
                 dataset_filepath='mnist.pkl.gz', file_dir='lsgan_images/'):
        self.dev = dev
        self.rows = rows
        self.cols = cols
        self.channels = channels
        self.feature_size = self.rows * self.cols * self.channels
        self.noise_size = noise_size
        self.hidden_size = hidden_size
        self.batch = batch
        self.batch_size = self.batch // 2  # half-batch: real and fake trained separately
        self.interval = interval
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.d_steps = d_steps
        self.g_steps = g_steps
        self.dataset_filepath = dataset_filepath
        self.file_dir = file_dir

        # Generator: noise -> hidden (relu) -> flat image (sigmoid).
        self.g_w0_specs = {'init': 'xavier'}
        self.g_b0_specs = {'init': 'constant', 'value': 0}
        self.g_w1_specs = {'init': 'xavier'}
        self.g_b1_specs = {'init': 'constant', 'value': 0}
        self.gen_net = ffnet.FeedForwardNet(loss.SquaredError())
        self.gen_net_fc_0 = layer.Dense(
            name='g_fc_0', num_output=self.hidden_size, use_bias=True,
            W_specs=self.g_w0_specs, b_specs=self.g_b0_specs,
            input_sample_shape=(self.noise_size,))
        self.gen_net_relu_0 = layer.Activation(
            name='g_relu_0', mode='relu',
            input_sample_shape=(self.hidden_size,))
        self.gen_net_fc_1 = layer.Dense(
            name='g_fc_1', num_output=self.feature_size, use_bias=True,
            W_specs=self.g_w1_specs, b_specs=self.g_b1_specs,
            input_sample_shape=(self.hidden_size,))
        self.gen_net_sigmoid_1 = layer.Activation(
            name='g_relu_1', mode='sigmoid',
            input_sample_shape=(self.feature_size,))
        self.gen_net.add(self.gen_net_fc_0)
        self.gen_net.add(self.gen_net_relu_0)
        self.gen_net.add(self.gen_net_fc_1)
        self.gen_net.add(self.gen_net_sigmoid_1)
        self._init_params(self.gen_net)
        self.gen_net.to_device(self.dev)

        # Discriminator: flat image -> hidden (relu) -> single score.
        self.d_w0_specs = {'init': 'xavier'}
        self.d_b0_specs = {'init': 'constant', 'value': 0}
        self.d_w1_specs = {'init': 'xavier'}
        self.d_b1_specs = {'init': 'constant', 'value': 0}
        self.dis_net = ffnet.FeedForwardNet(loss.SquaredError())
        self.dis_net_fc_0 = layer.Dense(
            name='d_fc_0', num_output=self.hidden_size, use_bias=True,
            W_specs=self.d_w0_specs, b_specs=self.d_b0_specs,
            input_sample_shape=(self.feature_size,))
        self.dis_net_relu_0 = layer.Activation(
            name='d_relu_0', mode='relu',
            input_sample_shape=(self.hidden_size,))
        self.dis_net_fc_1 = layer.Dense(
            name='d_fc_1', num_output=1, use_bias=True,
            W_specs=self.d_w1_specs, b_specs=self.d_b1_specs,
            input_sample_shape=(self.hidden_size,))
        self.dis_net.add(self.dis_net_fc_0)
        self.dis_net.add(self.dis_net_relu_0)
        self.dis_net.add(self.dis_net_fc_1)
        self._init_params(self.dis_net)
        self.dis_net.to_device(self.dev)

        # Combined generator+discriminator net, used for generator updates.
        self.combined_net = ffnet.FeedForwardNet(loss.SquaredError())
        for l in self.gen_net.layers:
            self.combined_net.add(l)
        for l in self.dis_net.layers:
            self.combined_net.add(l)
        self.combined_net.to_device(self.dev)

    def _init_params(self, net):
        # Fill each parameter from its filler spec (gaussian / xavier /
        # constant zero) and log its name, filler type and L1 norm.
        for (p, specs) in zip(net.param_values(), net.param_specs()):
            filler = specs.filler
            if filler.type == 'gaussian':
                p.gaussian(filler.mean, filler.std)
            elif filler.type == 'xavier':
                initializer.xavier(p)
            else:
                p.set_value(0)
            print(specs.name, filler.type, p.l1())

    def train(self):
        """Alternate d_steps discriminator updates and g_steps generator
        updates per epoch; snapshot generated images every `interval` epochs.
        """
        train_data, _, _, _, _, _ = load_data(self.dataset_filepath)
        opt_0 = optimizer.Adam(lr=self.learning_rate)  # optimizer for discriminator
        opt_1 = optimizer.Adam(lr=self.learning_rate)  # optimizer for generator, aka the combined model
        for (p, specs) in zip(self.dis_net.param_names(), self.dis_net.param_specs()):
            opt_0.register(p, specs)
        for (p, specs) in zip(self.gen_net.param_names(), self.gen_net.param_specs()):
            opt_1.register(p, specs)

        for epoch in range(self.epochs):
            for d_step in range(self.d_steps):
                # Sample a half-batch of real images.
                idx = np.random.randint(0, train_data.shape[0], self.batch_size)
                real_imgs = tensor.from_numpy(train_data[idx])
                real_imgs.to_device(self.dev)
                noise = tensor.Tensor((self.batch_size, self.noise_size))
                noise.uniform(-1, 1)
                noise.to_device(self.dev)
                fake_imgs = self.gen_net.forward(flag=False, x=noise)
                # LSGAN discriminator targets: +1 for real, -1 for fake.
                target = tensor.Tensor((self.batch_size, 1))
                target.set_value(1.0)
                target.to_device(self.dev)
                grads, (d_loss_real, _) = self.dis_net.train(real_imgs, target)
                for (s, p, g) in zip(self.dis_net.param_names(),
                                     self.dis_net.param_values(), grads):
                    opt_0.apply_with_lr(epoch, self.learning_rate, g, p, str(s), epoch)
                target.set_value(-1.0)
                grads, (d_loss_fake, _) = self.dis_net.train(fake_imgs, target)
                for (s, p, g) in zip(self.dis_net.param_names(),
                                     self.dis_net.param_values(), grads):
                    opt_0.apply_with_lr(epoch, self.learning_rate, g, p, str(s), epoch)
                d_loss = d_loss_real + d_loss_fake

            for g_step in range(self.g_steps):
                noise = tensor.Tensor((self.batch_size, self.noise_size))
                noise.uniform(-1, 1)
                noise.to_device(self.dev)
                # BUGFIX: the target was previously sized with real_imgs.shape[0],
                # a variable leaked out of the discriminator loop above (NameError
                # if d_steps == 0); self.batch_size is the same value and has no
                # cross-loop dependency. LSGAN generator target is 0.
                target = tensor.Tensor((self.batch_size, 1))
                target.set_value(0.0)
                target.to_device(self.dev)
                grads, (g_loss, _) = self.combined_net.train(noise, target)
                for (s, p, g) in zip(self.gen_net.param_names(),
                                     self.gen_net.param_values(), grads):
                    opt_1.apply_with_lr(epoch, self.learning_rate, g, p, str(s), epoch)

            # NOTE(review): d_loss/g_loss are only bound if d_steps/g_steps >= 1;
            # logging assumes both are positive (the defaults are 3 and 1).
            if epoch % self.interval == 0:
                self.save_image(epoch)
                print_log('The {} epoch, G_LOSS: {}, D_LOSS: {}'.format(epoch, g_loss, d_loss))

    def save_image(self, epoch):
        """Sample a 5x5 grid of generator outputs and save it to
        <file_dir><epoch>.png."""
        grid_rows = 5
        grid_cols = 5
        noise = tensor.Tensor((grid_rows * grid_cols * self.channels, self.noise_size))
        noise.uniform(-1, 1)
        noise.to_device(self.dev)
        gen_imgs = tensor.to_numpy(self.gen_net.forward(flag=False, x=noise))
        show_imgs = np.reshape(
            gen_imgs, (gen_imgs.shape[0], self.rows, self.cols, self.channels))
        fig, axs = plt.subplots(grid_rows, grid_cols)
        cnt = 0
        for r in range(grid_rows):
            for c in range(grid_cols):
                axs[r, c].imshow(show_imgs[cnt, :, :, 0], cmap='gray')
                axs[r, c].axis('off')
                cnt += 1
        fig.savefig("{}{}.png".format(self.file_dir, epoch))
        plt.close()
+
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train GAN over MNIST')
    parser.add_argument('filepath', type=str, help='the dataset path')
    parser.add_argument('--use_gpu', action='store_true')
    args = parser.parse_args()

    if args.use_gpu:
        print('Using GPU')
        dev = device.create_cuda_gpu()
        layer.engine = 'cudnn'
    else:
        print('Using CPU')
        dev = device.get_default_device()
        layer.engine = 'singacpp'

    if not os.path.exists('lsgan_images/'):
        os.makedirs('lsgan_images/')

    # BUGFIX: the required `filepath` CLI argument was parsed but ignored —
    # the dataset path was hard-coded to 'mnist.pkl.gz'. Honor it now.
    lsgan = LSGAN(dev, rows=28, cols=28, channels=1, noise_size=100,
                  hidden_size=128, batch=128, interval=1000,
                  learning_rate=0.001, epochs=1000000, d_steps=3, g_steps=1,
                  dataset_filepath=args.filepath, file_dir='lsgan_images/')
    lsgan.train()
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b1610d75/examples/gan/utils.py
----------------------------------------------------------------------
diff --git a/examples/gan/utils.py b/examples/gan/utils.py
new file mode 100644
index 0000000..050d184
--- /dev/null
+++ b/examples/gan/utils.py
@@ -0,0 +1,67 @@
+#!/usr/bin/env python
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#     http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# 
+
+import gzip
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+import pickle
+import sys
+import time
+
+try:
+	import urllib.request as ul_request
+except ImportError:
+	import urllib as ul_request
+
def print_log(s):
    """Print *s* prefixed with the current wall-clock time."""
    print('[{}]{}'.format(time.ctime(), s))
+
def load_data(filepath):
    """Load a gzipped, pickled MNIST-style dataset.

    The pickle must hold a (train, valid, test) triple where each element is
    a (data, label) pair of numpy arrays.

    Returns:
        (traindata, trainlabel, validdata, validlabel, testdata, testlabel),
        each cast to float32.
    """
    with gzip.open(filepath, 'rb') as f:
        try:
            # Python 3: 'bytes' encoding is needed to read pickles written
            # under Python 2.
            train_set, valid_set, test_set = pickle.load(f, encoding='bytes')
        except TypeError:
            # Python 2: pickle.load takes no `encoding` argument. This keeps
            # the module's py2 support (see the urllib import fallback above)
            # intact.
            train_set, valid_set, test_set = pickle.load(f)
    traindata = train_set[0].astype(np.float32)
    trainlabel = train_set[1].astype(np.float32)
    validdata = valid_set[0].astype(np.float32)
    validlabel = valid_set[1].astype(np.float32)
    testdata = test_set[0].astype(np.float32)
    testlabel = test_set[1].astype(np.float32)
    return traindata, trainlabel, validdata, validlabel, testdata, testlabel
+
def download_data(gzfile, url):
    """Fetch `url` into local file `gzfile`, skipping the download when the
    file already exists.

    BUGFIX: this helper previously called sys.exit(0) when the file existed,
    which raises SystemExit in any program that imports and calls it; it now
    simply returns.
    """
    if os.path.exists(gzfile):
        print('Downloaded already!')
        return
    print('Downloading data %s' % (url))
    ul_request.urlretrieve(url, gzfile)
    print('Finished!')
+
def show_images(filepath):
    """Display a pickled image batch in a 5x5 grid.

    The pickle is indexed as imgs[n, row, col, channel]; only channel 0 is
    shown, in grayscale.
    """
    with open(filepath, 'rb') as f:
        imgs = pickle.load(f)
    r, c = 5, 5
    # CONSISTENCY FIX: the grid size was hard-coded as subplots(5, 5),
    # ignoring the r and c variables defined for exactly this purpose.
    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i, j].imshow(imgs[cnt, :, :, 0], cmap='gray')
            axs[i, j].axis('off')
            cnt += 1
    plt.show()
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/b1610d75/examples/gan/vanilla.py
----------------------------------------------------------------------
diff --git a/examples/gan/vanilla.py b/examples/gan/vanilla.py
new file mode 100644
index 0000000..ce5e048
--- /dev/null
+++ b/examples/gan/vanilla.py
@@ -0,0 +1,207 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from singa import device
+from singa import initializer
+from singa import layer
+from singa import loss
+from singa import net as ffnet
+from singa import optimizer
+from singa import tensor
+
+import argparse
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+
+from utils import load_data
+from utils import print_log
+
class VANILLA():
    """Vanilla GAN over flattened MNIST images.

    Generator and discriminator are single-hidden-layer MLPs trained with a
    sigmoid cross-entropy loss; the generator is updated through the
    combined generator+discriminator network using "real" labels on its
    fakes.
    """

    def __init__(self, dev, rows=28, cols=28, channels=1, noise_size=100,
                 hidden_size=128, batch=128, interval=1000, learning_rate=0.001,
                 epochs=1000000, dataset_filepath='mnist.pkl.gz',
                 file_dir='vanilla_images/'):
        self.dev = dev
        self.rows = rows
        self.cols = cols
        self.channels = channels
        self.feature_size = self.rows * self.cols * self.channels
        self.noise_size = noise_size
        self.hidden_size = hidden_size
        self.batch = batch
        self.batch_size = self.batch // 2  # half-batch: real and fake trained separately
        self.interval = interval
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.dataset_filepath = dataset_filepath
        self.file_dir = file_dir

        # Generator: noise -> hidden (relu) -> flat image (sigmoid).
        self.g_w0_specs = {'init': 'xavier'}
        self.g_b0_specs = {'init': 'constant', 'value': 0}
        self.g_w1_specs = {'init': 'xavier'}
        self.g_b1_specs = {'init': 'constant', 'value': 0}
        self.gen_net = ffnet.FeedForwardNet(loss.SigmoidCrossEntropy())
        self.gen_net_fc_0 = layer.Dense(
            name='g_fc_0', num_output=self.hidden_size, use_bias=True,
            W_specs=self.g_w0_specs, b_specs=self.g_b0_specs,
            input_sample_shape=(self.noise_size,))
        self.gen_net_relu_0 = layer.Activation(
            name='g_relu_0', mode='relu',
            input_sample_shape=(self.hidden_size,))
        self.gen_net_fc_1 = layer.Dense(
            name='g_fc_1', num_output=self.feature_size, use_bias=True,
            W_specs=self.g_w1_specs, b_specs=self.g_b1_specs,
            input_sample_shape=(self.hidden_size,))
        self.gen_net_sigmoid_1 = layer.Activation(
            name='g_relu_1', mode='sigmoid',
            input_sample_shape=(self.feature_size,))
        for lyr in (self.gen_net_fc_0, self.gen_net_relu_0,
                    self.gen_net_fc_1, self.gen_net_sigmoid_1):
            self.gen_net.add(lyr)
        self._init_params(self.gen_net)
        self.gen_net.to_device(self.dev)

        # Discriminator: flat image -> hidden (relu) -> single logit.
        self.d_w0_specs = {'init': 'xavier'}
        self.d_b0_specs = {'init': 'constant', 'value': 0}
        self.d_w1_specs = {'init': 'xavier'}
        self.d_b1_specs = {'init': 'constant', 'value': 0}
        self.dis_net = ffnet.FeedForwardNet(loss.SigmoidCrossEntropy())
        self.dis_net_fc_0 = layer.Dense(
            name='d_fc_0', num_output=self.hidden_size, use_bias=True,
            W_specs=self.d_w0_specs, b_specs=self.d_b0_specs,
            input_sample_shape=(self.feature_size,))
        self.dis_net_relu_0 = layer.Activation(
            name='d_relu_0', mode='relu',
            input_sample_shape=(self.hidden_size,))
        self.dis_net_fc_1 = layer.Dense(
            name='d_fc_1', num_output=1, use_bias=True,
            W_specs=self.d_w1_specs, b_specs=self.d_b1_specs,
            input_sample_shape=(self.hidden_size,))
        for lyr in (self.dis_net_fc_0, self.dis_net_relu_0, self.dis_net_fc_1):
            self.dis_net.add(lyr)
        self._init_params(self.dis_net)
        self.dis_net.to_device(self.dev)

        # Combined generator+discriminator net, used for generator updates.
        self.combined_net = ffnet.FeedForwardNet(loss.SigmoidCrossEntropy())
        for lyr in self.gen_net.layers:
            self.combined_net.add(lyr)
        for lyr in self.dis_net.layers:
            self.combined_net.add(lyr)
        self.combined_net.to_device(self.dev)

    def _init_params(self, net):
        # Fill each parameter from its filler spec (gaussian / xavier /
        # constant zero) and log its name, filler type and L1 norm.
        for (p, specs) in zip(net.param_values(), net.param_specs()):
            filler = specs.filler
            if filler.type == 'gaussian':
                p.gaussian(filler.mean, filler.std)
            elif filler.type == 'xavier':
                initializer.xavier(p)
            else:
                p.set_value(0)
            print(specs.name, filler.type, p.l1())

    def train(self):
        """Run the adversarial training loop; one discriminator update and
        one generator update per epoch, snapshotting images every `interval`
        epochs."""
        train_data, _, _, _, _, _ = load_data(self.dataset_filepath)
        d_opt = optimizer.Adam(lr=self.learning_rate)  # discriminator updates
        g_opt = optimizer.Adam(lr=self.learning_rate)  # generator updates, via the combined net
        for (name, specs) in zip(self.dis_net.param_names(), self.dis_net.param_specs()):
            d_opt.register(name, specs)
        for (name, specs) in zip(self.gen_net.param_names(), self.gen_net.param_specs()):
            g_opt.register(name, specs)

        for epoch in range(self.epochs):
            # --- discriminator: half-batch of real vs half-batch of fake ---
            idx = np.random.randint(0, train_data.shape[0], self.batch_size)
            real_imgs = tensor.from_numpy(train_data[idx])
            real_imgs.to_device(self.dev)
            noise = tensor.Tensor((self.batch_size, self.noise_size))
            noise.uniform(-1, 1)
            noise.to_device(self.dev)
            fake_imgs = self.gen_net.forward(flag=False, x=noise)
            real_labels = tensor.Tensor((self.batch_size, 1))
            real_labels.set_value(1.0)
            real_labels.to_device(self.dev)
            fake_labels = tensor.Tensor((self.batch_size, 1))
            fake_labels.set_value(0.0)
            fake_labels.to_device(self.dev)
            grads, (d_loss_real, _) = self.dis_net.train(real_imgs, real_labels)
            for (name, p, g) in zip(self.dis_net.param_names(),
                                    self.dis_net.param_values(), grads):
                d_opt.apply_with_lr(epoch, self.learning_rate, g, p, str(name), epoch)
            grads, (d_loss_fake, _) = self.dis_net.train(fake_imgs, fake_labels)
            for (name, p, g) in zip(self.dis_net.param_names(),
                                    self.dis_net.param_values(), grads):
                d_opt.apply_with_lr(epoch, self.learning_rate, g, p, str(name), epoch)
            d_loss = d_loss_real + d_loss_fake

            # --- generator: push fakes towards the "real" label ---
            noise = tensor.Tensor((self.batch_size, self.noise_size))
            noise.uniform(-1, 1)
            noise.to_device(self.dev)
            real_labels = tensor.Tensor((self.batch_size, 1))
            real_labels.set_value(1.0)
            real_labels.to_device(self.dev)
            grads, (g_loss, _) = self.combined_net.train(noise, real_labels)
            for (name, p, g) in zip(self.gen_net.param_names(),
                                    self.gen_net.param_values(), grads):
                g_opt.apply_with_lr(epoch, self.learning_rate, g, p, str(name), epoch)

            if epoch % self.interval == 0:
                self.save_image(epoch)
                print_log('The {} epoch, G_LOSS: {}, D_LOSS: {}'.format(epoch, g_loss, d_loss))

    def save_image(self, epoch):
        """Sample a 5x5 grid of generator outputs and save it to
        <file_dir><epoch>.png."""
        grid_rows = 5
        grid_cols = 5
        noise = tensor.Tensor((grid_rows * grid_cols * self.channels, self.noise_size))
        noise.uniform(-1, 1)
        noise.to_device(self.dev)
        imgs = tensor.to_numpy(self.gen_net.forward(flag=False, x=noise))
        imgs = np.reshape(imgs, (imgs.shape[0], self.rows, self.cols, self.channels))
        fig, axs = plt.subplots(grid_rows, grid_cols)
        cnt = 0
        for r in range(grid_rows):
            for c in range(grid_cols):
                axs[r, c].imshow(imgs[cnt, :, :, 0], cmap='gray')
                axs[r, c].axis('off')
                cnt += 1
        fig.savefig("{}{}.png".format(self.file_dir, epoch))
        plt.close()
+
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train GAN over MNIST')
    parser.add_argument('filepath', type=str, help='the dataset path')
    parser.add_argument('--use_gpu', action='store_true')
    args = parser.parse_args()

    if args.use_gpu:
        print('Using GPU')
        dev = device.create_cuda_gpu()
        layer.engine = 'cudnn'
    else:
        print('Using CPU')
        dev = device.get_default_device()
        layer.engine = 'singacpp'

    if not os.path.exists('vanilla_images/'):
        os.makedirs('vanilla_images/')

    # BUGFIX: the required `filepath` CLI argument was parsed but ignored —
    # the dataset path was hard-coded to 'mnist.pkl.gz'. Honor it now.
    vanilla = VANILLA(dev, rows=28, cols=28, channels=1, noise_size=100,
                      hidden_size=128, batch=128, interval=1000,
                      learning_rate=0.001, epochs=1000000,
                      dataset_filepath=args.filepath, file_dir='vanilla_images/')
    vanilla.train()
\ No newline at end of file


Mime
View raw message