fix zero-batch handling in ConvTranspose (#24341)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/24341

ConvTransposeOp doesn't crash on zero-batch input, but it also doesn't touch the output blob. This leads to buggy behaviour, especially when running the same network twice with different inputs, or during backprop in training: the output blob silently keeps whatever shape and contents a previous run left in it.
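
To make the failure mode concrete, here is a minimal repro sketch, assuming the standard caffe2 Python workspace API (shapes and hyperparameters are arbitrary, chosen only for illustration):

```python
import numpy as np
from caffe2.python import core, workspace

op = core.CreateOperator(
    "ConvTranspose", ["X", "w", "b"], ["Y"],
    kernel=2, stride=2, pad=0, order="NCHW",
)
# Filter layout for NCHW ConvTranspose: (input_channels, output_channels, kH, kW).
w = np.random.randn(4, 4, 2, 2).astype(np.float32)
b = np.zeros(4, dtype=np.float32)
workspace.FeedBlob("w", w)
workspace.FeedBlob("b", b)

# First run with a non-empty batch creates Y with shape (2, 4, 10, 10).
workspace.FeedBlob("X", np.random.randn(2, 4, 5, 5).astype(np.float32))
workspace.RunOperatorOnce(op)

# Second run with a zero batch: before this fix the op returned early without
# touching Y, so Y silently kept the stale (2, 4, 10, 10) shape instead of
# being reshaped to (0, 4, 10, 10).
workspace.FeedBlob("X", np.zeros((0, 4, 5, 5), dtype=np.float32))
workspace.RunOperatorOnce(op)
print(workspace.FetchBlob("Y").shape)
```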

`ConvTransposeUnpoolBase<Context>::GetOutputSize` turns out to work for zero-batch input, so I relaxed its `input.numel() > 0` check to only require the non-batch dimensions to be positive, and now reshape the output blob before returning early.
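
The reason `GetOutputSize` is zero-batch-safe is that the spatial output dimensions depend only on H, W and the conv parameters; the batch dimension N passes through untouched. A sketch of the arithmetic (illustrative Python mirroring the standard transposed-convolution formula, not the actual C++ helper):

```python
def conv_transpose_output_size(n, c_out, h, w, kernel, stride, pad, adj=0):
    # (H - 1) * stride - 2 * pad + kernel (+ adj) is independent of N, so a
    # zero batch simply yields a zero-batch output shape.
    h_out = (h - 1) * stride - 2 * pad + kernel + adj
    w_out = (w - 1) * stride - 2 * pad + kernel + adj
    return [n, c_out, h_out, w_out]

assert conv_transpose_output_size(0, 4, 5, 5, kernel=2, stride=2, pad=0) == [0, 4, 10, 10]
assert conv_transpose_output_size(2, 4, 5, 5, kernel=2, stride=2, pad=0) == [2, 4, 10, 10]
```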

For `CudnnConvTransposeGradientOp`, explicitly zeroing `dfilter` and `dbias` would be a bit verbose, and cuDNN appears to handle the zero-batch case on its own, so the `X.numel() == 0` branch is simply removed.

Test Plan: buck test mode/dev-nosan caffe2/caffe2/python/operator_test:conv_transpose_test -- --run-disabled

Reviewed By: BIT-silence

Differential Revision: D16807606

fbshipit-source-id: 0d72c5bd8f2e03c34465e7b530cca548d9bdd5e1
diff --git a/caffe2/operators/conv_transpose_op_cudnn.cc b/caffe2/operators/conv_transpose_op_cudnn.cc
index 1446404..c00e43e 100644
--- a/caffe2/operators/conv_transpose_op_cudnn.cc
+++ b/caffe2/operators/conv_transpose_op_cudnn.cc
@@ -186,14 +186,14 @@
       LOG(FATAL) << "Unknown storage order: " << order_;
   }
 
+  auto sizes = ConvTransposeUnpoolBase<CUDAContext>::GetOutputSize(X, C);
+  auto* Y = Output(0, sizes, at::dtype<T>());
+
   if (X.numel() == 0) {
     VLOG(2) << "Number on elements is 0 in CudnnConvTransposeOp";
     return true;
   }
 
-  auto sizes = ConvTransposeUnpoolBase<CUDAContext>::GetOutputSize(X, C);
-  auto* Y = Output(0, sizes, at::dtype<T>());
-
   int N = 0, M = 0, H = 0, W = 0, H_out = 0, W_out = 0;
   switch (order_) {
     case StorageOrder::NHWC:
@@ -468,11 +468,6 @@
       LOG(FATAL) << "Unknown storage order: " << order_;
   }
 
-  if (X.numel() == 0) {
-    VLOG(2) << "Number of elements is 0 in CudnnConvTransposeOp";
-    return true;
-  }
-
   int N = 0, M = 0, H = 0, W = 0, H_out = 0, W_out = 0;
   switch (order_) {
     case StorageOrder::NHWC:
diff --git a/caffe2/operators/conv_transpose_op_impl.h b/caffe2/operators/conv_transpose_op_impl.h
index 7caba35..e3be8a8 100644
--- a/caffe2/operators/conv_transpose_op_impl.h
+++ b/caffe2/operators/conv_transpose_op_impl.h
@@ -42,13 +42,13 @@
       filter.dim32(3),
       this->kernel_w(),
       "filter width must be equal to kernel width");
+  const std::vector<std::int64_t> Y_dims =
+      ConvTransposeUnpoolBase<Context>::GetOutputSize(X, C);
+  auto* Y = Output(0, Y_dims, at::dtype<T>());
   if (X.numel() == 0) {
     VLOG(2) << "Number of elements is 0 in ConvTrasposeOp";
     return true;
   }
-  const std::vector<std::int64_t> Y_dims =
-      ConvTransposeUnpoolBase<Context>::GetOutputSize(X, C);
-  auto* Y = Output(0, Y_dims, at::dtype<T>());
 
   const int K_HxW = kernel_h() * kernel_w();
   const int kernel_dim = C / G * K_HxW;
@@ -196,13 +196,13 @@
       kernel_w(),
       "filter width must be equal to kernel width");
 
+  const std::vector<std::int64_t> Y_dims =
+      ConvTransposeUnpoolBase<Context>::GetOutputSize(X, C);
+  auto* Y = Output(0, Y_dims, at::dtype<T>());
   if (X.numel() == 0) {
     VLOG(2) << "Number of elements is 0 in ConvTrasposeOp";
     return true;
   }
-  const std::vector<std::int64_t> Y_dims =
-      ConvTransposeUnpoolBase<Context>::GetOutputSize(X, C);
-  auto* Y = Output(0, Y_dims, at::dtype<T>());
 
   const int K_HxW = kernel_h() * kernel_w();
   const int kernel_dim = C / G * K_HxW;
@@ -362,7 +362,9 @@
 
   if (X.numel() == 0) {
     VLOG(2) << "Number of elements is 0 in ConvTrasposeOp";
-    math::Set<T, Context>(C, T(0), dbias_data, &context_);
+    if (dbias_data != nullptr) {
+      math::Set<T, Context>(C, T(0), dbias_data, &context_);
+    }
     return true;
   }
 
@@ -525,7 +527,9 @@
 
   if (X.numel() == 0) {
     VLOG(2) << "Number of elements is 0 in ConvTrasposeOp";
-    math::Set<T, Context>(C, T(0), dbias_data, &context_);
+    if (dbias_data != nullptr) {
+      math::Set<T, Context>(C, T(0), dbias_data, &context_);
+    }
     return true;
   }
 
diff --git a/caffe2/operators/conv_transpose_op_mobile_impl.h b/caffe2/operators/conv_transpose_op_mobile_impl.h
index b9ada87..45fc78c 100644
--- a/caffe2/operators/conv_transpose_op_mobile_impl.h
+++ b/caffe2/operators/conv_transpose_op_mobile_impl.h
@@ -548,14 +548,14 @@
         "bias dimension must be equal to output channel number");
   }
 
+  auto sizes = ConvTransposeUnpoolBase<Context>::GetOutputSize(X, C);
+  Tensor* Y = Output(0, sizes, at::dtype<T>());
+
   if (X.numel() == 0) {
     VLOG(2) << "Number of elements is 0 in ConvTrasposeOp";
     return true;
   }
 
-  auto sizes = ConvTransposeUnpoolBase<Context>::GetOutputSize(X, C);
-  Tensor* Y = Output(0, sizes, at::dtype<T>());
-
   const int outputH = Y->dim32(2);
   const int outputW = Y->dim32(3);
   const int outputPlaneSize = outputH * outputW;
diff --git a/caffe2/operators/conv_transpose_unpool_op_base.h b/caffe2/operators/conv_transpose_unpool_op_base.h
index c98b3ba..331a0ac 100644
--- a/caffe2/operators/conv_transpose_unpool_op_base.h
+++ b/caffe2/operators/conv_transpose_unpool_op_base.h
@@ -136,7 +136,7 @@
   // Gets the output size. The output channel is manually specified.
   std::vector<int64_t> GetOutputSize(const Tensor& input, int output_channel) {
     CAFFE_ENFORCE(4 == input.dim());
-    CAFFE_ENFORCE(input.numel() > 0);
+    CAFFE_ENFORCE_GT(input.size_from_dim(1), 0);
     int N = input.dim32(0);
     bool channel_first = false; // initialized to suppress compiler warning.
     int H = 0, W = 0; // initialized to suppress compiler warning.
diff --git a/caffe2/python/operator_test/conv_transpose_test.py b/caffe2/python/operator_test/conv_transpose_test.py
index 272ac3a..c2e3ef0 100644
--- a/caffe2/python/operator_test/conv_transpose_test.py
+++ b/caffe2/python/operator_test/conv_transpose_test.py
@@ -20,7 +20,7 @@
            size=st.integers(7, 10),
            input_channels=st.integers(1, 8),
            output_channels=st.integers(1, 8),
-           batch_size=st.integers(1, 3),
+           batch_size=st.integers(0, 3),
            engine=st.sampled_from(["", "CUDNN", "BLOCK"]),
            shared_buffer=st.booleans(),
            use_bias=st.booleans(),
@@ -90,7 +90,7 @@
            size=st.integers(7, 10),
            input_channels=st.integers(1, 8),
            output_channels=st.integers(1, 8),
-           batch_size=st.integers(1, 3),
+           batch_size=st.integers(0, 3),
            engine=st.sampled_from(["", "CUDNN", "BLOCK"]),
            shared_buffer=st.booleans(),
            use_bias=st.booleans(),
@@ -166,7 +166,7 @@
            size=st.integers(7, 10),
            input_channels=st.integers(1, 8),
            output_channels=st.integers(1, 8),
-           batch_size=st.integers(1, 3),
+           batch_size=st.integers(0, 3),
            engine=st.sampled_from(["", "BLOCK"]),
            use_bias=st.booleans(),
            **hu.gcs)
@@ -235,7 +235,7 @@
            size=st.integers(7, 10),
            input_channels=st.integers(1, 8),
            output_channels=st.integers(1, 8),
-           batch_size=st.integers(1, 3),
+           batch_size=st.integers(0, 3),
            order=st.sampled_from(["NCHW", "NHWC"]),
            engine=st.sampled_from(["", "CUDNN", "BLOCK"]),
            use_bias=st.booleans(),
@@ -303,7 +303,7 @@
            size=st.integers(7, 10),
            input_channels=st.integers(1, 8),
            output_channels=st.integers(1, 8),
-           batch_size=st.integers(1, 3),
+           batch_size=st.integers(0, 3),
            order=st.sampled_from(["NCHW", "NHWC"]),
            engine=st.sampled_from(["", "BLOCK"]),
            use_bias=st.booleans(),
@@ -368,7 +368,7 @@
            size=st.integers(7, 10),
            input_channels=st.integers(1, 8),
            output_channels=st.integers(1, 8),
-           batch_size=st.integers(1, 4),
+           batch_size=st.integers(0, 4),
            group=st.integers(1, 4),
            order=st.sampled_from(["NCHW", "NHWC"]),
            engine=st.sampled_from(["", "CUDNN", "BLOCK"]),