ctakes-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From dlig...@apache.org
Subject svn commit: r1762139 - in /ctakes/trunk/ctakes-temporal/scripts/nn: classify.sh lstm_classify.py lstm_train.py train.sh
Date Sat, 24 Sep 2016 16:09:18 GMT
Author: dligach
Date: Sat Sep 24 16:09:18 2016
New Revision: 1762139

URL: http://svn.apache.org/viewvc?rev=1762139&view=rev
Log:
Now trying an LSTM on POS tags

Added:
    ctakes/trunk/ctakes-temporal/scripts/nn/lstm_classify.py
    ctakes/trunk/ctakes-temporal/scripts/nn/lstm_train.py
Modified:
    ctakes/trunk/ctakes-temporal/scripts/nn/classify.sh
    ctakes/trunk/ctakes-temporal/scripts/nn/train.sh

Modified: ctakes/trunk/ctakes-temporal/scripts/nn/classify.sh
URL: http://svn.apache.org/viewvc/ctakes/trunk/ctakes-temporal/scripts/nn/classify.sh?rev=1762139&r1=1762138&r2=1762139&view=diff
==============================================================================
--- ctakes/trunk/ctakes-temporal/scripts/nn/classify.sh (original)
+++ ctakes/trunk/ctakes-temporal/scripts/nn/classify.sh Sat Sep 24 16:09:18 2016
@@ -1,7 +1,7 @@
 #!/bin/bash
 
 source $(dirname $0)/env/bin/activate
-python $(dirname $0)/cnn_classify.py $*
+python $(dirname $0)/lstm_classify.py $*
 ret=$?
 deactivate
 exit $ret

Added: ctakes/trunk/ctakes-temporal/scripts/nn/lstm_classify.py
URL: http://svn.apache.org/viewvc/ctakes/trunk/ctakes-temporal/scripts/nn/lstm_classify.py?rev=1762139&view=auto
==============================================================================
--- ctakes/trunk/ctakes-temporal/scripts/nn/lstm_classify.py (added)
+++ ctakes/trunk/ctakes-temporal/scripts/nn/lstm_classify.py Sat Sep 24 16:09:18 2016
@@ -0,0 +1,63 @@
+#!python
+
+from keras.models import Sequential, model_from_json
+import numpy as np
+import et_cleartk_io as ctk_io
+import sys
+import os.path
+import pickle
+from keras.preprocessing.sequence import pad_sequences
+
+def main(args):
+  if len(args) < 1:
+      sys.stderr.write("Error - one required argument: <model directory>\n")
+      sys.exit(-1)
+  working_dir = args[0]
+
+  target_dir = 'ctakes-temporal/target/eval/thyme/train_and_test/event-time/'
+  model_dir = os.path.join(os.environ['CTAKES_ROOT'], target_dir)
+  maxlen   = pickle.load(open(os.path.join(model_dir, "maxlen.p"), "rb"))
+  word2int = pickle.load(open(os.path.join(model_dir, "word2int.p"), "rb"))
+  label2int = pickle.load(open(os.path.join(model_dir, "label2int.p"), "rb"))
+  model = model_from_json(open(os.path.join(model_dir, "model_0.json")).read())
+  model.load_weights(os.path.join(model_dir, "model_0.h5"))
+
+  int2label = {}
+  for label, integer in label2int.items():
+    int2label[integer] = label
+
+  while True:
+      try:
+          line = sys.stdin.readline().rstrip()
+          if not line:
+              break
+
+          feats = []
+          for unigram in line.rstrip().split():
+              if unigram in word2int:
+                  feats.append(word2int[unigram])
+              else:
+                  # TODO: 'none' is not in vocabulary!
+                  feats.append(word2int['none'])
+                    
+          if len(feats) > maxlen:
+              feats=feats[0:maxlen]
+          test_x = pad_sequences([feats], maxlen=maxlen)
+          out = model.predict(test_x, batch_size=50)[0]
+
+      except KeyboardInterrupt:
+          sys.stderr.write("Caught keyboard interrupt\n")
+          break
+
+      if line == '':
+          sys.stderr.write("Encountered empty string so exiting\n")
+          break
+
+      out_str = int2label[out.argmax()]
+      print out_str
+      sys.stdout.flush()
+
+  sys.exit(0)
+
+if __name__ == "__main__":
+  main(sys.argv[1:])

Added: ctakes/trunk/ctakes-temporal/scripts/nn/lstm_train.py
URL: http://svn.apache.org/viewvc/ctakes/trunk/ctakes-temporal/scripts/nn/lstm_train.py?rev=1762139&view=auto
==============================================================================
--- ctakes/trunk/ctakes-temporal/scripts/nn/lstm_train.py (added)
+++ ctakes/trunk/ctakes-temporal/scripts/nn/lstm_train.py Sat Sep 24 16:09:18 2016
@@ -0,0 +1,75 @@
+#!/usr/bin/env python
+
+import sklearn as sk
+import numpy as np
+np.random.seed(1337)
+import et_cleartk_io as ctk_io
+import nn_models
+import sys
+import os.path
+import dataset
+import keras as k
+from keras.utils.np_utils import to_categorical
+from keras.optimizers import RMSprop
+from keras.preprocessing.sequence import pad_sequences
+from keras.models import Sequential
+from keras.layers.core import Dense, Dropout, Activation
+from keras.layers.embeddings import Embedding
+from keras.layers import LSTM
+import pickle
+
+def main(args):
+  if len(args) < 1:
+      sys.stderr.write("Error - one required argument: <data directory>\n")
+      sys.exit(-1)
+  working_dir = args[0]
+  data_file = os.path.join(working_dir, 'training-data.liblinear')
+
+  # learn alphabet from training data
+  provider = dataset.DatasetProvider(data_file)
+  # now load training examples and labels
+  train_x, train_y = provider.load(data_file)
+  # turn x and y into numpy array among other things
+  maxlen = max([len(seq) for seq in train_x])
+  classes = len(set(train_y))
+
+  train_x = pad_sequences(train_x, maxlen=maxlen)
+  train_y = to_categorical(np.array(train_y), classes)
+
+  pickle.dump(maxlen, open(os.path.join(working_dir, 'maxlen.p'),"wb"))
+  pickle.dump(provider.word2int, open(os.path.join(working_dir, 'word2int.p'),"wb"))
+  pickle.dump(provider.label2int, open(os.path.join(working_dir, 'label2int.p'),"wb"))
+
+  print 'train_x shape:', train_x.shape
+  print 'train_y shape:', train_y.shape
+
+  model = Sequential()
+    
+  model.add(Embedding(len(provider.word2int),
+                      300,
+                      input_length=maxlen,
+                      dropout=0.25))
+  model.add(LSTM(128,
+                 dropout_W = 0.20,
+                 dropout_U = 0.20))
+  model.add(Dense(classes))
+  model.add(Activation('softmax'))
+
+  optimizer = RMSprop(lr=0.001, rho=0.9, epsilon=1e-08)
+  model.compile(loss='categorical_crossentropy',
+                optimizer=optimizer,
+                metrics=['accuracy'])
+  model.fit(train_x,
+            train_y,
+            nb_epoch=1,
+            batch_size=50,
+            verbose=1,
+            validation_split=0.1)
+
+  json_string = model.to_json()
+  open(os.path.join(working_dir, 'model_0.json'), 'w').write(json_string)
+  model.save_weights(os.path.join(working_dir, 'model_0.h5'), overwrite=True)
+  sys.exit(0)
+
+if __name__ == "__main__":
+  main(sys.argv[1:])

Modified: ctakes/trunk/ctakes-temporal/scripts/nn/train.sh
URL: http://svn.apache.org/viewvc/ctakes/trunk/ctakes-temporal/scripts/nn/train.sh?rev=1762139&r1=1762138&r2=1762139&view=diff
==============================================================================
--- ctakes/trunk/ctakes-temporal/scripts/nn/train.sh (original)
+++ ctakes/trunk/ctakes-temporal/scripts/nn/train.sh Sat Sep 24 16:09:18 2016
@@ -1,7 +1,7 @@
 #!/bin/bash
 
 source $(dirname $0)/env/bin/activate
-python $(dirname $0)/cnn_train.py $*
+python $(dirname $0)/lstm_train.py $*
 ret=$?
 deactivate
 exit $ret



Mime
View raw message