add assert for labels for spatial case
Differential Revision: D4570726
fbshipit-source-id: fe73c7f0dfa3b5d5ad50b2a1ed651f520e609985
diff --git a/caffe2/operators/softmax_with_loss_op.cc b/caffe2/operators/softmax_with_loss_op.cc
index 5b880b9..b7e3309 100644
--- a/caffe2/operators/softmax_with_loss_op.cc
+++ b/caffe2/operators/softmax_with_loss_op.cc
@@ -121,7 +121,7 @@
for (int i = 0; i < N; ++i) {
CAFFE_ENFORCE(
- label_data[i] < D,
+ label_data[i] < D && label_data[i] >= 0,
"Label seems incorrect: label value larger than number of classes: ",
label_data[i],
" vs ",
@@ -214,6 +214,12 @@
int label_idx = i * H * W + y * W + x;
int label = label_data[label_idx];
if (label != DONT_CARE) {
+ CAFFE_ENFORCE(
+ label < D && label >= 0,
+ "Label seems incorrect: label value larger than number of classes: ",
+ label_data[i],
+ " vs ",
+ D);
int idx = i * (H * W * D) + label * (H * W) + y * W + x;
float w = weights ? weights[label_idx] : 1.0;
total_weight += w;
diff --git a/caffe2/operators/softmax_with_loss_op.cu b/caffe2/operators/softmax_with_loss_op.cu
index 2d67097..16118e4 100644
--- a/caffe2/operators/softmax_with_loss_op.cu
+++ b/caffe2/operators/softmax_with_loss_op.cu
@@ -140,6 +140,7 @@
const int label = static_cast<int>(label_data[index]);
if (label != DONTCARE) {
+ CUDA_KERNEL_ASSERT(label >= 0 && label < D);
float weight = (weights == NULL ? 1.0 : weights[index]);
loss_data[index] = -log(max(
Pdata[i * W * H * D + label * W * H + y * W + x], 1e-20f)) * weight;