fix zero-batch handling in convtranspose (#24341)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/24341
ConvTransposeOp doesn't crash on zero-batch input, but it also doesn't modify the output blob. This leads to buggy behavior, especially when running the same network twice with different inputs, or when backpropagating during training.
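A minimal Python sketch of the failure mode (shapes, blob names, and conv parameters are illustrative, not taken from the test plan):

    from caffe2.python import core, workspace
    import numpy as np

    op = core.CreateOperator(
        "ConvTranspose", ["X", "w", "b"], ["Y"],
        kernel=2, stride=2, pad=0)
    workspace.FeedBlob("w", np.random.randn(4, 4, 2, 2).astype(np.float32))
    workspace.FeedBlob("b", np.zeros(4, dtype=np.float32))

    # First run, batch size 2: Y is allocated with shape (2, 4, 14, 14).
    workspace.FeedBlob("X", np.random.randn(2, 4, 7, 7).astype(np.float32))
    workspace.RunOperatorOnce(op)

    # Second run, batch size 0: before this fix the op returned early
    # without touching Y, so Y kept the stale (2, 4, 14, 14) shape.
    workspace.FeedBlob("X", np.zeros((0, 4, 7, 7), dtype=np.float32))
    workspace.RunOperatorOnce(op)
    assert workspace.FetchBlob("Y").shape == (0, 4, 14, 14)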
`ConvTransposeUnpoolBase<Context>::GetOutputSize` already works for zero-batch input, so relax its `input.numel() > 0` check to only require non-empty non-batch dimensions (`input.size_from_dim(1) > 0`), and reshape the output blob before the early return.
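For reference, the output shape depends on the batch size only through its leading dimension; a simplified Python sketch of what `GetOutputSize` computes for NCHW, assuming symmetric padding and no output adjustment:

    def conv_transpose_output_size(N, C_out, H, W, kernel, stride, pad):
        # Spatial dims are independent of N, so N == 0 is well-defined:
        # the result is simply (0, C_out, H_out, W_out).
        H_out = (H - 1) * stride - 2 * pad + kernel
        W_out = (W - 1) * stride - 2 * pad + kernel
        return [N, C_out, H_out, W_out]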
For CudnnConvTransposeGradientOp, explicitly setting `dfilter` and `dbias` in a zero-batch branch is verbose, and cuDNN appears to handle the empty case itself, so simply remove the `X.numel() == 0` branch. In the CPU gradient path, guard the zero-fill of `dbias` so it only runs when a bias gradient was actually requested.
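The zero-batch backward path can be spot-checked directly; a sketch using the same illustrative shapes as above, assuming the generated gradient blobs follow caffe2's `<blob>_grad` naming convention:

    from caffe2.python import core, workspace
    import numpy as np

    op = core.CreateOperator(
        "ConvTranspose", ["X", "w", "b"], ["Y"],
        kernel=2, stride=2, pad=0)
    workspace.FeedBlob("X", np.zeros((0, 4, 7, 7), dtype=np.float32))
    workspace.FeedBlob("w", np.zeros((4, 4, 2, 2), dtype=np.float32))
    workspace.FeedBlob("b", np.zeros(4, dtype=np.float32))
    workspace.RunOperatorOnce(op)  # forward: Y now has shape (0, 4, 14, 14)

    grad_ops, _ = core.GradientRegistry.GetGradientForOp(op, ["Y_grad"])
    workspace.FeedBlob("Y_grad", np.zeros((0, 4, 14, 14), dtype=np.float32))
    for g in grad_ops:
        workspace.RunOperatorOnce(g)
    # With zero batch, the CPU gradient op zero-fills dbias (when a bias
    # gradient is requested) and returns early, so b_grad is all zeros.
    assert not workspace.FetchBlob("b_grad").any()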
Test Plan: buck test mode/dev-nosan caffe2/caffe2/python/operator_test:conv_transpose_test -- --run-disabled
Reviewed By: BIT-silence
Differential Revision: D16807606
fbshipit-source-id: 0d72c5bd8f2e03c34465e7b530cca548d9bdd5e1
diff --git a/caffe2/operators/conv_transpose_op_cudnn.cc b/caffe2/operators/conv_transpose_op_cudnn.cc
index 1446404..c00e43e 100644
--- a/caffe2/operators/conv_transpose_op_cudnn.cc
+++ b/caffe2/operators/conv_transpose_op_cudnn.cc
@@ -186,14 +186,14 @@
LOG(FATAL) << "Unknown storage order: " << order_;
}
+ auto sizes = ConvTransposeUnpoolBase<CUDAContext>::GetOutputSize(X, C);
+ auto* Y = Output(0, sizes, at::dtype<T>());
+
if (X.numel() == 0) {
VLOG(2) << "Number on elements is 0 in CudnnConvTransposeOp";
return true;
}
- auto sizes = ConvTransposeUnpoolBase<CUDAContext>::GetOutputSize(X, C);
- auto* Y = Output(0, sizes, at::dtype<T>());
-
int N = 0, M = 0, H = 0, W = 0, H_out = 0, W_out = 0;
switch (order_) {
case StorageOrder::NHWC:
@@ -468,11 +468,6 @@
LOG(FATAL) << "Unknown storage order: " << order_;
}
- if (X.numel() == 0) {
- VLOG(2) << "Number of elements is 0 in CudnnConvTransposeOp";
- return true;
- }
-
int N = 0, M = 0, H = 0, W = 0, H_out = 0, W_out = 0;
switch (order_) {
case StorageOrder::NHWC:
diff --git a/caffe2/operators/conv_transpose_op_impl.h b/caffe2/operators/conv_transpose_op_impl.h
index 7caba35..e3be8a8 100644
--- a/caffe2/operators/conv_transpose_op_impl.h
+++ b/caffe2/operators/conv_transpose_op_impl.h
@@ -42,13 +42,13 @@
filter.dim32(3),
this->kernel_w(),
"filter width must be equal to kernel width");
+ const std::vector<std::int64_t> Y_dims =
+ ConvTransposeUnpoolBase<Context>::GetOutputSize(X, C);
+ auto* Y = Output(0, Y_dims, at::dtype<T>());
if (X.numel() == 0) {
VLOG(2) << "Number of elements is 0 in ConvTrasposeOp";
return true;
}
- const std::vector<std::int64_t> Y_dims =
- ConvTransposeUnpoolBase<Context>::GetOutputSize(X, C);
- auto* Y = Output(0, Y_dims, at::dtype<T>());
const int K_HxW = kernel_h() * kernel_w();
const int kernel_dim = C / G * K_HxW;
@@ -196,13 +196,13 @@
kernel_w(),
"filter width must be equal to kernel width");
+ const std::vector<std::int64_t> Y_dims =
+ ConvTransposeUnpoolBase<Context>::GetOutputSize(X, C);
+ auto* Y = Output(0, Y_dims, at::dtype<T>());
if (X.numel() == 0) {
VLOG(2) << "Number of elements is 0 in ConvTrasposeOp";
return true;
}
- const std::vector<std::int64_t> Y_dims =
- ConvTransposeUnpoolBase<Context>::GetOutputSize(X, C);
- auto* Y = Output(0, Y_dims, at::dtype<T>());
const int K_HxW = kernel_h() * kernel_w();
const int kernel_dim = C / G * K_HxW;
@@ -362,7 +362,9 @@
if (X.numel() == 0) {
VLOG(2) << "Number of elements is 0 in ConvTrasposeOp";
- math::Set<T, Context>(C, T(0), dbias_data, &context_);
+ if (dbias_data != nullptr) {
+ math::Set<T, Context>(C, T(0), dbias_data, &context_);
+ }
return true;
}
@@ -525,7 +527,9 @@
if (X.numel() == 0) {
VLOG(2) << "Number of elements is 0 in ConvTrasposeOp";
- math::Set<T, Context>(C, T(0), dbias_data, &context_);
+ if (dbias_data != nullptr) {
+ math::Set<T, Context>(C, T(0), dbias_data, &context_);
+ }
return true;
}
diff --git a/caffe2/operators/conv_transpose_op_mobile_impl.h b/caffe2/operators/conv_transpose_op_mobile_impl.h
index b9ada87..45fc78c 100644
--- a/caffe2/operators/conv_transpose_op_mobile_impl.h
+++ b/caffe2/operators/conv_transpose_op_mobile_impl.h
@@ -548,14 +548,14 @@
"bias dimension must be equal to output channel number");
}
+ auto sizes = ConvTransposeUnpoolBase<Context>::GetOutputSize(X, C);
+ Tensor* Y = Output(0, sizes, at::dtype<T>());
+
if (X.numel() == 0) {
VLOG(2) << "Number of elements is 0 in ConvTrasposeOp";
return true;
}
- auto sizes = ConvTransposeUnpoolBase<Context>::GetOutputSize(X, C);
- Tensor* Y = Output(0, sizes, at::dtype<T>());
-
const int outputH = Y->dim32(2);
const int outputW = Y->dim32(3);
const int outputPlaneSize = outputH * outputW;
diff --git a/caffe2/operators/conv_transpose_unpool_op_base.h b/caffe2/operators/conv_transpose_unpool_op_base.h
index c98b3ba..331a0ac 100644
--- a/caffe2/operators/conv_transpose_unpool_op_base.h
+++ b/caffe2/operators/conv_transpose_unpool_op_base.h
@@ -136,7 +136,7 @@
// Gets the output size. The output channel is manually specified.
std::vector<int64_t> GetOutputSize(const Tensor& input, int output_channel) {
CAFFE_ENFORCE(4 == input.dim());
- CAFFE_ENFORCE(input.numel() > 0);
+ CAFFE_ENFORCE_GT(input.size_from_dim(1), 0);
int N = input.dim32(0);
bool channel_first = false; // initialized to suppress compiler warning.
int H = 0, W = 0; // initialized to suppress compiler warning.
diff --git a/caffe2/python/operator_test/conv_transpose_test.py b/caffe2/python/operator_test/conv_transpose_test.py
index 272ac3a..c2e3ef0 100644
--- a/caffe2/python/operator_test/conv_transpose_test.py
+++ b/caffe2/python/operator_test/conv_transpose_test.py
@@ -20,7 +20,7 @@
size=st.integers(7, 10),
input_channels=st.integers(1, 8),
output_channels=st.integers(1, 8),
- batch_size=st.integers(1, 3),
+ batch_size=st.integers(0, 3),
engine=st.sampled_from(["", "CUDNN", "BLOCK"]),
shared_buffer=st.booleans(),
use_bias=st.booleans(),
@@ -90,7 +90,7 @@
size=st.integers(7, 10),
input_channels=st.integers(1, 8),
output_channels=st.integers(1, 8),
- batch_size=st.integers(1, 3),
+ batch_size=st.integers(0, 3),
engine=st.sampled_from(["", "CUDNN", "BLOCK"]),
shared_buffer=st.booleans(),
use_bias=st.booleans(),
@@ -166,7 +166,7 @@
size=st.integers(7, 10),
input_channels=st.integers(1, 8),
output_channels=st.integers(1, 8),
- batch_size=st.integers(1, 3),
+ batch_size=st.integers(0, 3),
engine=st.sampled_from(["", "BLOCK"]),
use_bias=st.booleans(),
**hu.gcs)
@@ -235,7 +235,7 @@
size=st.integers(7, 10),
input_channels=st.integers(1, 8),
output_channels=st.integers(1, 8),
- batch_size=st.integers(1, 3),
+ batch_size=st.integers(0, 3),
order=st.sampled_from(["NCHW", "NHWC"]),
engine=st.sampled_from(["", "CUDNN", "BLOCK"]),
use_bias=st.booleans(),
@@ -303,7 +303,7 @@
size=st.integers(7, 10),
input_channels=st.integers(1, 8),
output_channels=st.integers(1, 8),
- batch_size=st.integers(1, 3),
+ batch_size=st.integers(0, 3),
order=st.sampled_from(["NCHW", "NHWC"]),
engine=st.sampled_from(["", "BLOCK"]),
use_bias=st.booleans(),
@@ -368,7 +368,7 @@
size=st.integers(7, 10),
input_channels=st.integers(1, 8),
output_channels=st.integers(1, 8),
- batch_size=st.integers(1, 4),
+ batch_size=st.integers(0, 4),
group=st.integers(1, 4),
order=st.sampled_from(["NCHW", "NHWC"]),
engine=st.sampled_from(["", "CUDNN", "BLOCK"]),