mxnet-commits mailing list archives

From GitBox <...@apache.org>
Subject [GitHub] [incubator-mxnet] kedarbellare commented on a change in pull request #15023: Extend Clojure BERT example
Date Sun, 02 Jun 2019 16:38:23 GMT
kedarbellare commented on a change in pull request #15023: Extend Clojure BERT example
URL: https://github.com/apache/incubator-mxnet/pull/15023#discussion_r289649517
 
 

 ##########
 File path: contrib/clojure-package/examples/bert/src/bert/bert_sentence_classification.clj
 ##########
 @@ -157,4 +183,42 @@
 (comment
 
   (train (context/cpu 0) 3)
-  (m/save-checkpoint model {:prefix "fine-tune-sentence-bert" :epoch 3}))
+
+  (m/save-checkpoint model {:prefix fine-tuned-prefix :epoch 3})
+
+  
+  ;;;; Explore results from the fine-tuned model
+
+  ;; We need a predictor with a batch size of 1, so we can feed the
+  ;; model a single sentence pair.
+  (def fine-tuned-predictor
+    (infer/create-predictor (infer/model-factory fine-tuned-prefix
+                                                 [{:name "data0" :shape [1 seq-length] :dtype dtype/FLOAT32 :layout layout/NT}
+                                                  {:name "data1" :shape [1 seq-length] :dtype dtype/FLOAT32 :layout layout/NT}
+                                                  {:name "data2" :shape [1]            :dtype dtype/FLOAT32 :layout layout/N}])
+                            {:epoch 3}))
+  
+  ;; Get the fine-tuned model's opinion on whether two sentences are equivalent:
+  (defn predict-equivalence
+    [predictor sentence1 sentence2]
+    (let [vocab (bert.util/get-vocab)
+          processed-test-data (mapv #(pre-processing (:idx->token vocab)
+                                                     (:token->idx vocab) %)
+                                    [[sentence1 sentence2]])
+          prediction (infer/predict-with-ndarray predictor
+                                                 [(ndarray/array (slice-inputs-data processed-test-data 0) [1 seq-length])
+                                                  (ndarray/array (slice-inputs-data processed-test-data 1) [1 seq-length])
+                                                  (ndarray/array (slice-inputs-data processed-test-data 2) [1])])]
+      (ndarray/->vec (first prediction))))
+
+  ;; Modify an existing sentence pair to test:
+  ;; ["1"
+  ;;  "69773"
+  ;;  "69792"
+  ;;  "Cisco pared spending to compensate for sluggish sales ."
+  ;;  "In response to sluggish sales , Cisco pared spending ."]
+  (predict-equivalence fine-tuned-predictor
 
 Review comment:
   did you want to add a test for this?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services
