Adding check for a single batch in adaptive_avg_pool
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/23137
Test Plan: Imported from OSS
Differential Revision: D16403804
Pulled By: zafartahirov
fbshipit-source-id: df79a8c768ffabeceb4c0044c967a623c5885484
diff --git a/aten/src/ATen/native/AdaptiveAveragePooling.cpp b/aten/src/ATen/native/AdaptiveAveragePooling.cpp
index eec9c75..616cda3 100644
--- a/aten/src/ATen/native/AdaptiveAveragePooling.cpp
+++ b/aten/src/ATen/native/AdaptiveAveragePooling.cpp
@@ -129,10 +129,13 @@
auto osizeW = output_size[1];
/* resize output */
- if (input.ndimension() == 3)
+ if (input.ndimension() == 3 || input.size(-4) == 1)
{
- output.resize_({sizeD, osizeH, osizeW});
-
+ if (input.ndimension() == 3) {
+ output.resize_({sizeD, osizeH, osizeW});
+ } else {
+ output.resize_({1, sizeD, osizeH, osizeW});
+ }
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "adaptive_avg_pool2d_cpu", [&] {
auto input_data = input.data<scalar_t>();
auto output_data = output.data<scalar_t>();
@@ -260,7 +263,7 @@
auto gradOutput = gradOutput_.contiguous();
/* backprop */
- if (input.ndimension() == 3)
+ if (input.ndimension() == 3 || input.size(-4) == 1)
{
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input.scalar_type(), "adaptive_avg_pool2d_backward_cpu", [&] {