Revert D25719980: [pytorch][PR] Accept input tensor with 0-dim batch size for MultiLabelMarginLoss

Test Plan: revert-hammer

Differential Revision: D25719980 (https://github.com/pytorch/pytorch/commit/6b56b71e61e14bf4de5b371f0d8f2f2029065b31)

Original commit changeset: 83414bad37c0

fbshipit-source-id: 27eddd711a2b9e0adbc08bfab12100562e63ac21
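
For context, a minimal sketch of the 0-dim batch usage that the reverted change had accepted; it mirrors the removed test_MarginLoss_empty below and is illustrative only, not a guarantee of behavior after this revert:

    import torch
    import torch.nn as nn

    # Empty-batch case exercised by the removed test (MultiLabelMarginLoss shown;
    # the test covered MultiMarginLoss the same way).
    x = torch.randn(0, 10, requires_grad=True)   # batch size 0, 10 classes
    y = torch.ones(0, 10, dtype=torch.long)      # matching empty target
    out = nn.MultiLabelMarginLoss()(x, y)
    out.sum().backward()                          # test asserted x.grad == zeros_like(x)

With this revert, such empty inputs again fail the restored "non-empty vector or matrix expected" shape checks in LossMultiLabelMargin.cpp / LossMultiMargin.cpp and the THCUNN kernels.
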
diff --git a/aten/src/ATen/native/LossMulti.h b/aten/src/ATen/native/LossMulti.h
deleted file mode 100644
index 4282c34..0000000
--- a/aten/src/ATen/native/LossMulti.h
+++ /dev/null
@@ -1,72 +0,0 @@
-#include <ATen/ATen.h>
-#include <ATen/Dispatch.h>
-#include <ATen/AccumulateType.h>
-
-#pragma once
-
-namespace at { namespace native {
-namespace {
-  static void multilabel_margin_loss_shape_check(
-    int64_t& nframe,
-    int64_t& dim,
-    const int64_t& ndims,
-    TensorArg& target_arg,
-    const Tensor& input,
-    const Tensor& target) {
-    bool valid_inputs = (ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0;
-    TORCH_CHECK(
-                valid_inputs,
-                "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
-                input.sizes());
-
-    if (ndims <= 1) {
-      nframe = 1;
-      dim = ndims == 0 ? 1 : input.size(0);
-      TORCH_CHECK(
-                  valid_inputs && target.dim() <= 1 && target.numel() == dim,
-                  "inconsistent size ",
-                  target.sizes(),
-                  " for ",
-                  target_arg);
-    } else {
-      nframe = input.size(0);
-      dim = input.size(1);
-      TORCH_CHECK(
-                  valid_inputs && target.dim() == 2 && target.size(0) == nframe &&
-                  target.size(1) == dim,
-                  "inconsistent size ",
-                  target.sizes(),
-                  " for ",
-                  target_arg);
-    }
-  }
-
-  static void multi_margin_loss_shape_check(
-    int64_t& nframe,
-    int64_t& dim,
-    const int64_t& ndims,
-    TensorArg& target_arg,
-    const Tensor& input,
-    const Tensor& target) {
-    bool valid_inputs = (ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0;
-    if (ndims <= 1) {
-      nframe = 1;
-      dim = ndims == 0 ? 1 : input.size(0);
-    } else {
-      nframe = input.size(0);
-      dim = input.size(1);
-    }
-    
-    TORCH_CHECK(
-                valid_inputs,
-                "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
-                input.sizes());
-    TORCH_CHECK(
-                valid_inputs && target.dim() <= 1 && target.numel() == nframe,
-                "inconsistent target size, got: ",
-                target.sizes());
-  }
-
-
-}  // anonymous namespace
-}} // namespace at::native
diff --git a/aten/src/ATen/native/LossMultiLabelMargin.cpp b/aten/src/ATen/native/LossMultiLabelMargin.cpp
index 3cd0f46..9582bf6 100644
--- a/aten/src/ATen/native/LossMultiLabelMargin.cpp
+++ b/aten/src/ATen/native/LossMultiLabelMargin.cpp
@@ -2,7 +2,6 @@
 #include <ATen/AccumulateType.h>
 #include <ATen/Dispatch.h>
 #include <ATen/TensorUtils.h>
-#include <ATen/native/LossMulti.h>
 
 namespace at {
 namespace native {
@@ -40,7 +39,6 @@
       }
     }
   }
-
   return sum;
 }
 
@@ -102,32 +100,34 @@
     Tensor& is_target,
     int64_t reduction) {
   auto target_arg = TensorArg(target, "target", 2);
+
+  const auto ndims = input.dim();
+
+  TORCH_CHECK(
+      input.numel() > 0 && ndims <= 2,
+      "non-empty vector or matrix expected, got size: ",
+      input.sizes());
+
   int64_t nframe, dim;
-  const int64_t ndims = input.dim();
   if (ndims <= 1) {
     nframe = 1;
     dim = ndims == 0 ? 1 : input.size(0);
-  }
-  else {
+    TORCH_CHECK(
+        target.numel() > 0 && target.dim() <= 1 && target.numel() == dim,
+        "inconsistent size ",
+        target.sizes(),
+        " for ",
+        target_arg);
+  } else {
     nframe = input.size(0);
     dim = input.size(1);
-  }
-  multilabel_margin_loss_shape_check(nframe, dim, ndims, target_arg, input, target);
-
-  // special case target.dim() <= 1: produce scalar output for scalar inputs
-  // even if reduction == Reduction::None
-  if (reduction != Reduction::None || target.dim() <= 1) {
-    output.resize_({});
-  } else {
-    output.resize_({nframe});
-  }
-
-  is_target.resize_as_(target);
-  TORCH_CHECK(is_target.is_contiguous(), "is_target must be contiguous");
-  is_target.zero_();
-
-  if (input.numel() == 0) {
-    return;
+    TORCH_CHECK(
+        target.numel() > 0 && target.dim() == 2 && target.size(0) == nframe &&
+            target.size(1) == dim,
+        "inconsistent size ",
+        target.sizes(),
+        " for ",
+        target_arg);
   }
 
   TORCH_CHECK(
@@ -138,6 +138,18 @@
   auto input_contiguous = input.contiguous();
   auto target_contiguous = target.contiguous();
 
+  is_target.resize_as_(target);
+  TORCH_CHECK(is_target.is_contiguous(), "is_target must be contiguous");
+  is_target.zero_();
+
+  // special case target.dim() <= 1: produce scalar output for scalar inputs
+  // even if reduction == Reduction::None
+  if (reduction != Reduction::None || target.dim() <= 1) {
+    output.resize_({});
+  } else {
+    output.resize_({nframe});
+  }
+
   AT_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "multilabel_margin_loss_forward_out_frame", [&] {
         multilabel_margin_loss_forward_out_frame<scalar_t>(
@@ -220,22 +232,39 @@
     const Tensor& target,
     int64_t reduction,
     const Tensor& is_target) {
-  int64_t nframe, dim;
   CheckedFrom c = "multilabel_margin_loss_backward_cpu_template";
   auto target_arg = TensorArg(target, "target", 3);
   auto is_target_arg = TensorArg(is_target, "is_target", 5);
-  const int64_t ndims = input.dim();
 
-  multilabel_margin_loss_shape_check(nframe, dim, ndims, target_arg, input, target);
-  checkSameSize(c, target_arg, is_target_arg);
+  const auto ndims = input.dim();
 
-  grad_input.resize_as_(input);
-  if (grad_input.numel() == 0) {
-    return;
+  TORCH_CHECK(
+      input.numel() > 0 && ndims <= 2,
+      "non-empty vector or matrix expected, got size: ",
+      input.sizes());
+
+  int64_t nframe, dim;
+  if (ndims <= 1) {
+    nframe = 1;
+    dim = ndims == 0 ? 1 : input.size(0);
+    TORCH_CHECK(
+        target.numel() > 0 && target.dim() <= 1 && target.numel() == dim,
+        "inconsistent size ",
+        target.sizes(),
+        " for ",
+        target_arg);
+  } else {
+    nframe = input.size(0);
+    dim = input.size(1);
+    TORCH_CHECK(
+        target.numel() > 0 && target.dim() == 2 && target.size(0) == nframe &&
+            target.size(1) == dim,
+        "inconsistent size ",
+        target.sizes(),
+        " for ",
+        target_arg);
   }
-
-  TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");
-  grad_input.zero_();
+  checkSameSize(c, target_arg, is_target_arg);
 
   TORCH_CHECK(
       target.min().item<int64_t>() >= -1, target_arg, " is out of range");
@@ -246,6 +275,10 @@
   auto target_contiguous = target.contiguous();
   auto is_target_contiguous = is_target.contiguous();
 
+  grad_input.resize_as_(input);
+  TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");
+  grad_input.zero_();
+
   AT_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "multilabel_margin_loss_backward_out_frame", [&] {
         multilabel_margin_loss_backward_out_frame<scalar_t>(
diff --git a/aten/src/ATen/native/LossMultiMargin.cpp b/aten/src/ATen/native/LossMultiMargin.cpp
index db18d1f..48446a9 100644
--- a/aten/src/ATen/native/LossMultiMargin.cpp
+++ b/aten/src/ATen/native/LossMultiMargin.cpp
@@ -1,7 +1,6 @@
 #include <ATen/ATen.h>
 #include <ATen/Dispatch.h>
 #include <ATen/AccumulateType.h>
-#include <ATen/native/LossMulti.h>
 
 namespace at {
 namespace native {
@@ -94,13 +93,27 @@
     Scalar margin,
     const Tensor& weight,
     int64_t reduction) {
-  int64_t nframe, dim;
   const auto ndims = input.dim();
-  auto target_arg = TensorArg(target, "target", 2);
+  TORCH_CHECK(
+      input.numel() > 0 && ndims <= 2,
+      "non-empty vector or matrix expected, got size: ",
+      input.sizes());
 
   TORCH_CHECK(p == 1 || p == 2, "only p == 1 and p == 2 supported");
 
-  multi_margin_loss_shape_check(nframe, dim, ndims, target_arg, input, target);
+  int64_t nframe, dim;
+  if (ndims <= 1) {
+    nframe = 1;
+    dim = ndims == 0 ? 1 : input.size(0);
+  } else {
+    nframe = input.size(0);
+    dim = input.size(1);
+  }
+
+  TORCH_CHECK(
+      target.numel() > 0 && target.dim() <= 1 && target.numel() == nframe,
+      "inconsistent target size, got: ",
+      target.sizes());
 
   // produce a scalar output for 1d input
   if (reduction == Reduction::None && target.dim() > 0) {
@@ -108,9 +121,6 @@
   } else {
     output.resize_({});
   }
-  if (input.numel() == 0) {
-    return;
-  }
 
   auto input_contiguous = input.contiguous();
   auto target_contiguous = target.contiguous();
@@ -202,13 +212,28 @@
     Scalar margin,
     const Tensor& weight,
     int64_t reduction) {
-  int64_t nframe, dim;
-  auto target_arg = TensorArg(target, "target", 2);
   const auto ndims = input.dim();
+  TORCH_CHECK(
+      input.numel() > 0 && ndims <= 2,
+      "non-empty vector or matrix expected, got size: ",
+      input.sizes());
 
   TORCH_CHECK(p == 1 || p == 2, "only p == 1 and p == 2 supported");
 
-  multi_margin_loss_shape_check(nframe, dim, ndims, target_arg, input, target);
+  int64_t nframe, dim;
+  if (ndims <= 1) {
+    nframe = 1;
+    dim = ndims == 0 ? 1 : input.size(0);
+  } else {
+    nframe = input.size(0);
+    dim = input.size(1);
+  }
+
+  TORCH_CHECK(
+      target.numel() > 0 && target.dim() <= 1 && target.numel() == nframe,
+      "inconsistent target size, got: ",
+      target.sizes());
+
   grad_input.resize_as_(input);
   TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");
 
diff --git a/aten/src/THCUNN/generic/MultiLabelMarginCriterion.cu b/aten/src/THCUNN/generic/MultiLabelMarginCriterion.cu
index 6e8d9bc..ab8d2cb 100644
--- a/aten/src/THCUNN/generic/MultiLabelMarginCriterion.cu
+++ b/aten/src/THCUNN/generic/MultiLabelMarginCriterion.cu
@@ -3,30 +3,21 @@
 #else
 
 static inline void THNN_(MultiLabelMarginCriterion_shapeCheck)(
-  THCState *state,
-  THCTensor *input, THCTensor *target) {
-  int64_t ndims = input->dim();
-  bool valid_inputs = (ndims == 2 && input->size(1) != 0) || (ndims == 1 && input->size(0) != 0) || ndims == 0;
-  TORCH_CHECK(
-    valid_inputs,
-    "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
-    input->sizes());
-
-  if (ndims <= 1) {
+                         THCState *state,
+                         THCTensor *input, THCTensor *target) {
+  if (input->dim() <= 1) {
     int dim = input->dim() == 0 ? 1 : input->size(0);
     int target_size = target->dim() == 0 ? 1 : target->size(0);
-
-    TORCH_CHECK(valid_inputs && target->dim() <= 1 && target->numel() == dim,
-      "inconsistent target size: ", target->sizes(), " for input of size: ", input->sizes());
-  } else if (ndims == 2) {
+    TORCH_CHECK(!target->is_empty() && (target->dim() <= 1) && (target_size == dim),
+                "inconsistent target size: ", target->sizes(), " for input of size: ", input->sizes());
+  } else if (input->dim() == 2) {
     int nframe = input->size(0);
     int dim = input->size(1);
-
-    TORCH_CHECK(
-      valid_inputs && target->dim() == 2 && target->size(0) == nframe && target->size(1) == dim,
-      "inconsistent target size: ", target->sizes(), " for input of size: ", input->sizes());
+    TORCH_CHECK(!target->is_empty() && (target->dim() == 2)
+                && (target->size(0) == nframe) && (target->size(1) == dim),
+                "inconsistent target size: ", target->sizes(), " for input of size: ", input->sizes());
   } else {
-    TORCH_CHECK(false, "Expected input of ndims <= 2, but got ndims: ", ndims);
+    TORCH_CHECK(false, "non-empty vector or matrix expected, got size: ", input->sizes());
   }
 }
 
@@ -40,9 +31,6 @@
            int64_t reduction)
 {
   THNN_(MultiLabelMarginCriterion_shapeCheck)(state, input, target);
-  if (input->numel() == 0) {
-    return;
-  }
   input = THCTensor_(newContiguous)(state, input);
   target = THCIndexTensor_(newContiguous)(state, target);
   istarget = THCTensor_(newContiguous)(state, istarget);
@@ -112,8 +100,7 @@
     }
   }
   else {
-    TORCH_CHECK(false, "Expected 2D input with optional zero batch dim, or 1D input with non-zero dims, but got sizes: ", 
-      input->sizes());
+    TORCH_INTERNAL_ASSERT(false, "non-empty vector or matrix expected (shouldn't get here)");
   }
 
   THCTensor_(free)(state, input);
@@ -130,17 +117,11 @@
             THCTensor *istarget,
             int64_t reduction)
 {
-  THNN_(MultiLabelMarginCriterion_shapeCheck)(state, input, target);
   input = THCTensor_(newContiguous)(state, input);
-  THCTensor_(resizeAs)(state, gradInput, input);
-  if (input->numel() == 0) {
-    THCTensor_(free)(state, input);
-    return;
-  }
-
   target = THCIndexTensor_(newContiguous)(state, target);
   istarget = THCTensor_(newContiguous)(state, istarget);
   gradOutput = THCTensor_(newContiguous)(state, gradOutput);
+  THCTensor_(resizeAs)(state, gradInput, input);
 
   if(gradInput->dim() <= 1)
   {
@@ -168,11 +149,10 @@
   {
     int nframe = gradInput->size(0);
     int dim = gradInput->size(1);
-    THArgCheck((input->size(1) != 0) && (target->dim() == 2) && (target->size(0) == nframe)
+    THArgCheck(!target->is_empty() && (target->dim() == 2) && (target->size(0) == nframe)
                && (target->size(1) == dim), 3, "inconsistent target size");
-    THArgCheck((istarget->dim() == 2) && (istarget->size(0) == nframe)
+    THArgCheck(!istarget->is_empty() && (istarget->dim() == 2) && (istarget->size(0) == nframe)
                && (istarget->size(1) == dim), 3, "inconsistent isTarget size");
-
     dim3 blocks(gradInput->size(0));
     dim3 threads(MULTILABELMARGIN_THREADS);
 
@@ -188,8 +168,7 @@
         reduction != at::Reduction::None);
   }
   else {
-    TORCH_CHECK(false, "Expected 2D input with optional zero batch dim, or 1D input with non-zero dims, but got sizes: ",
-      gradInput->sizes());
+    AT_ERROR("non-empty vector or matrix expected, got size: ", gradInput->sizes());
   }
 
   THCudaCheck(cudaGetLastError());
diff --git a/aten/src/THCUNN/generic/MultiMarginCriterion.cu b/aten/src/THCUNN/generic/MultiMarginCriterion.cu
index 129413f..f2df150 100644
--- a/aten/src/THCUNN/generic/MultiMarginCriterion.cu
+++ b/aten/src/THCUNN/generic/MultiMarginCriterion.cu
@@ -2,30 +2,6 @@
 #define THC_GENERIC_FILE "THCUNN/generic/MultiMarginCriterion.cu"
 #else
 
-static inline void THNN_(MultiMarginCriterion_shapeCheck)(
-                         THCState *state,
-                         THCTensor *input, THCTensor *target) {
-  int64_t nframe, dim;
-  int64_t ndims = input->dim();
-  bool valid_inputs = (ndims == 2 && input->size(1) != 0) || (ndims == 1 && input->size(0) != 0) || ndims == 0;
-  if (ndims <= 1) {
-    nframe = 1;
-    dim = ndims == 0 ? 1 : input->size(0);
-  } else {
-    nframe = input->size(0);
-    dim = input->size(1);
-  }
-
-  TORCH_CHECK(
-    valid_inputs,
-    "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
-    input->sizes());
-  TORCH_CHECK(
-      valid_inputs && target->dim() <= 1 && target->numel() == nframe,
-      "inconsistent target size, got: ",
-      target->sizes());
-}
-
 // TODO: improve error messages
 void THNN_(MultiMarginCriterion_updateOutput)(
            THCState *state,
@@ -37,10 +13,6 @@
            THCTensor *weights,
            accreal margin_)
 {
-  THNN_(MultiMarginCriterion_shapeCheck)(state, input, target);
-  if (input->numel() == 0) {
-    return;
-  }
   scalar_t margin = ScalarConvert<accreal, scalar_t>::to(margin_);
   THCUNN_assertSameGPU(state, 2, input, target);
   input = THCTensor_(newContiguous)(state, input);
@@ -87,8 +59,7 @@
   else if (input->dim() == 2)
   {
     int nframe = input->size(0);
-    // allow zero-dim target for 2D input.
-    THArgCheck((input->size(1) != 0) && (THTensor_nDimensionLegacyNoScalars(target) == 1) && (THTensor_sizeLegacyNoScalars(target, 0) == nframe), 3,
+    THArgCheck(!target->is_empty() && (THTensor_nDimensionLegacyNoScalars(target) == 1) && (THTensor_sizeLegacyNoScalars(target, 0) == nframe), 3,
                "inconsistent target size");
     dim3 blocks(input->size(0));
     dim3 threads(MULTIMARGIN_THREADS);
@@ -159,8 +130,7 @@
   }
   else
   {
-    TORCH_CHECK(false, "Expected 2D input with optional zero batch dim, or 1D input with non-zero dims, but got sizes: ",
-    input->sizes());
+    AT_ERROR("non-empty vector or matrix expected, got sizes: ", input->sizes());
   }
 
   THCTensor_(free)(state, input);
@@ -179,17 +149,11 @@
            THCTensor *weights,
            accreal margin_)
 {
-  THNN_(MultiMarginCriterion_shapeCheck)(state, input, target);
-  input = THCTensor_(newContiguous)(state, input);
-  THCTensor_(resizeAs)(state, gradInput, input);
-  if (input->numel() == 0) {
-    THCTensor_(free)(state, input);
-    return;
-  }
   scalar_t margin = ScalarConvert<accreal, scalar_t>::to(margin_);
   THCUNN_assertSameGPU(state, 3, input, gradInput, target);
+  input = THCTensor_(newContiguous)(state, input);
   gradOutput = THCTensor_(newContiguous)(state, gradOutput);
-
+  THCTensor_(resizeAs)(state, gradInput, input);
   if(weights)
     weights = THCTensor_(newContiguous)(state, weights);
 
@@ -231,7 +195,7 @@
   else if (input->dim() == 2)
   {
     int nframe = gradInput->size(0);
-    THArgCheck((input->size(1) != 0) && (THTensor_nDimensionLegacyNoScalars(target) == 1) && (THTensor_sizeLegacyNoScalars(target, 0) == nframe), 3,
+    THArgCheck(!target->is_empty() && (THTensor_nDimensionLegacyNoScalars(target) == 1) && (THTensor_sizeLegacyNoScalars(target, 0) == nframe), 3,
                "inconsistent target size");
     dim3 blocks(gradInput->size(0));
     dim3 threads(MULTIMARGIN_THREADS);
@@ -268,8 +232,7 @@
   }
   else
   {
-    TORCH_CHECK(false, "Expected 2D input with optional zero batch dim, or 1D input with non-zero dims, but got sizes: ", 
-    input->sizes());
+    AT_ERROR("non-empty vector or matrix expected, got ", input->sizes());
   }
 
   THCTensor_(free)(state, input);
diff --git a/test/test_nn.py b/test/test_nn.py
index ef9ea4c..386ba36 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -10864,35 +10864,6 @@
             inp = torch.randn(3, 0, 10, 10, device=device)
             mod(inp)
 
-
-    @onlyOnCPUAndCUDA
-    @dtypes(torch.float, torch.double)
-    def test_MarginLoss_empty(self, device, dtype):
-        for mod, x, y in [
-                (torch.nn.MultiMarginLoss().to(device),
-                 torch.randn(0, 10, requires_grad=True, device=device, dtype=dtype),
-                 torch.ones(0, device=device).type(torch.long)),
-                (torch.nn.MultiLabelMarginLoss().to(device),
-                 torch.randn(0, 10, requires_grad=True, device=device, dtype=dtype),
-                 torch.ones(0, 10, device=device).type(torch.long))]:
-
-            out = mod(x, y)
-            out.sum().backward()
-
-            self.assertEqual(x, torch.zeros_like(x))
-            self.assertEqual(x.grad, torch.zeros_like(x))
-
-            with self.assertRaisesRegex(RuntimeError, 'Expected'):
-                x = torch.randn(0, requires_grad=True, device=device, dtype=dtype)
-                y = torch.ones(10, device=device).type(torch.long)
-                mod(x, y)
-
-            with self.assertRaisesRegex(RuntimeError, 'Expected'):
-                x = torch.randn(10, 0, requires_grad=True, device=device, dtype=dtype)
-                y = torch.ones(10, 0, device=device).type(torch.long)
-                mod(x, y)
-
-
     @onlyOnCPUAndCUDA
     def test_Unfold_empty(self, device):
         inp = torch.randn(0, 3, 3, 4, device=device)