# mxnet-commits mailing list archives

From GitBox <...@apache.org>
Subject [GitHub] szha commented on a change in pull request #12750: [MXNET -1030] Cosine Embedding Loss
Date Thu, 25 Oct 2018 23:03:05 GMT
szha commented on a change in pull request #12750: [MXNET -1030] Cosine Embedding Loss
URL: https://github.com/apache/incubator-mxnet/pull/12750#discussion_r228362925

##########
File path: python/mxnet/gluon/loss.py
##########
@@ -767,3 +767,69 @@ def hybrid_forward(self, F, pred, target, sample_weight=None, epsilon=1e-08):
loss += stirling_factor
loss = _apply_weighting(F, loss, self._weight, sample_weight)
return F.mean(loss)
+
+
+class CosineEmbeddingLoss(Loss):
+    r"""For a target label of 1 or -1 and vectors input1 and input2, the function computes the cosine distance
+    between the vectors. This can be interpreted as how similar or dissimilar the two input vectors are.
+
+    .. math::
+
+        L = \sum_i \begin{cases} 1 - {cos\_sim({input1}_i, {input2}_i)} & \text{ if } {label}_i = 1\\
+                         {cos\_sim({input1}_i, {input2}_i)} & \text{ if } {label}_i = -1 \end{cases}\\
+        cos\_sim({input1}_i, {input2}_i) = \frac{{input1}_i \cdot {input2}_i}{||{input1}_i|| \cdot ||{input2}_i||}
+
+    input1, input2 can have arbitrary shape as long as they have the same number of elements.
+
+    Parameters
+    ----------
+    weight : float or None
+        Global scalar weight for loss.
+    batch_axis : int, default 0
+        The axis that represents mini-batch.
+    margin : float
+        Margin of separation between correct and incorrect pairs.
+
+
+    Inputs:
+        - **input1**: a tensor with arbitrary shape
+        - **input2**: another tensor with the same shape as input1, to which input1 is
+          compared for similarity and loss calculation
+        - **sample_weight**: element-wise weighting tensor. Must be broadcastable

Review comment:
This needs to be added to hybrid_forward. Could you also put this after label?
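The reviewer asks for `sample_weight` to be accepted by `hybrid_forward` and to come after `label`. Below is a minimal pure-Python sketch, not the PR's implementation, of a per-sample loss that follows the docstring formula with that argument order. The names `cos_sim` and `cosine_embedding_loss`, the `eps` guard against zero-norm vectors, and the multiplicative handling of `sample_weight` and `weight` are assumptions made for illustration. `margin` is left out because the excerpt does not show how it enters the formula.

```python
import math
from typing import List, Optional, Sequence


def cos_sim(a: Sequence[float], b: Sequence[float], eps: float = 1e-12) -> float:
    """Cosine similarity: dot(a, b) / (||a|| * ||b||)."""
    if len(a) != len(b):
        raise ValueError("input1 and input2 must have the same number of elements")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    # Assumed behaviour: clamp the denominator so all-zero vectors do not divide by zero.
    return dot / max(norm_a * norm_b, eps)


def cosine_embedding_loss(
    input1: Sequence[Sequence[float]],
    input2: Sequence[Sequence[float]],
    label: Sequence[int],
    sample_weight: Optional[Sequence[float]] = None,  # placed after label, as requested
    weight: Optional[float] = None,  # hypothetical global scalar weight
) -> List[float]:
    """Per-sample losses following the docstring formula (margin not applied)."""
    if not (len(input1) == len(input2) == len(label)):
        raise ValueError("input1, input2 and label must have the same batch size")
    losses = []
    for i, (a, b, y) in enumerate(zip(input1, input2, label)):
        sim = cos_sim(a, b)
        if y == 1:
            loss = 1.0 - sim  # similar pair: penalise low similarity
        elif y == -1:
            loss = sim  # dissimilar pair: penalise high similarity
        else:
            raise ValueError("label must be 1 or -1")
        if sample_weight is not None:
            loss *= sample_weight[i]  # element-wise (per-sample) weighting
        if weight is not None:
            loss *= weight  # global scalar weighting
        losses.append(loss)
    return losses
```

Summing the returned list gives L as written in the docstring. For example, with input1 = [[1, 0], [1, 0]], input2 = [[0, 1], [1, 1]], label = [1, -1] and sample_weight = [1.0, 0.5], the first pair is orthogonal (cos_sim = 0, loss 1.0) and the second has cos_sim = 1/sqrt(2), giving a weighted loss of about 0.354.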

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.