remove extra computations for input usage check
diff --git a/torch/csrc/autograd/functions/convolution.cpp b/torch/csrc/autograd/functions/convolution.cpp
index a5ea954..426ff7c 100644
--- a/torch/csrc/autograd/functions/convolution.cpp
+++ b/torch/csrc/autograd/functions/convolution.cpp
@@ -448,15 +448,12 @@
auto input_size = ggI->data->sizes();
std::vector<long> input_shape(input_size.begin() + 2, input_size.end());
for(size_t i=0; i<gw_conv_params.padding.size(); ++i) {
- // Formula for conv output size before the floor operation
- auto out_size = float(input_shape[i] + 2 * gw_conv_params.padding[i] -
- gw_conv_params.dilation[i] * (kernel_size[i] - 1) - 1) /
- gw_conv_params.stride[i] + 1;
- auto exact_out_size = floorf(out_size);
- if (exact_out_size != out_size) {
- auto used_input_size = (exact_out_size - 1) * gw_conv_params.stride[i] + 1 +
- gw_conv_params.dilation[i] * (kernel_size[i] - 1) -
- 2 * gw_conv_params.padding[i];
+ // Check if whole input has been used or not
+ auto numerator = 2 * gw_conv_params.padding[i] -
+ gw_conv_params.dilation[i] * (kernel_size[i] - 1) - 1;
+ auto remainder = (input_shape[i] + numerator) % gw_conv_params.stride[i];
+ if (remainder != 0) {
+ auto used_input_size = input_shape[i] - remainder;
ggI = std::make_shared<Narrow>(i+2, 0, used_input_size)->apply({ggI})[0];
}
}