Adding 1d upsampling (#2846)
diff --git a/test/test_nn.py b/test/test_nn.py
index 3ef89ea..b3ebc1a 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -2800,6 +2800,24 @@
self.assertEqual(out_cpu, out_cuda)
self.assertEqual(input_cpu.grad, input_gpu.grad)
+ def test_upsamplingNearest1d(self):
+ m = nn.Upsample(size=4, mode='nearest')
+ in_t = torch.ones(1, 1, 2)
+ out_t = m(Variable(in_t))
+ self.assertEqual(torch.ones(1, 1, 4), out_t.data)
+
+ input = Variable(torch.randn(1, 1, 2), requires_grad=True)
+ self.assertTrue(gradcheck(lambda x: F.upsample(x, 4, mode='nearest'), (input,)))
+
+ def test_upsamplingLinear1d(self):
+ m = nn.Upsample(size=4, mode='linear')
+ in_t = torch.ones(1, 1, 2)
+ out_t = m(Variable(in_t))
+ self.assertEqual(torch.ones(1, 1, 4), out_t.data)
+
+ input = Variable(torch.randn(1, 1, 2), requires_grad=True)
+ self.assertTrue(gradcheck(lambda x: F.upsample(x, 4, mode='linear'), (input,)))
+
def test_upsamplingNearest2d(self):
m = nn.Upsample(size=4, mode='nearest')
in_t = torch.ones(1, 1, 2, 2)
@@ -3970,6 +3988,42 @@
dict(
module_name='Upsample',
constructor_args=(12, None, 'nearest'),
+ input_size=(1, 2, 4),
+ desc='nearest_1d',
+ ),
+ dict(
+ module_name='Upsample',
+ constructor_args=((12, ), None, 'nearest'),
+ input_size=(1, 2, 3),
+ desc='nearest_tuple_1d',
+ ),
+ dict(
+ module_name='Upsample',
+ constructor_args=(None, 4, 'nearest'),
+ input_size=(1, 2, 4),
+ desc='nearest_scale_1d',
+ ),
+ dict(
+ module_name='Upsample',
+ constructor_args=(12, None, 'linear'),
+ input_size=(1, 2, 4),
+ desc='linear_1d',
+ ),
+ dict(
+ module_name='Upsample',
+ constructor_args=((4, ), None, 'linear'),
+ input_size=(1, 2, 3),
+ desc='linear_tuple_1d',
+ ),
+ dict(
+ module_name='Upsample',
+ constructor_args=(None, 4, 'linear'),
+ input_size=(1, 2, 4),
+ desc='linear_scale_1d',
+ ),
+ dict(
+ module_name='Upsample',
+ constructor_args=(12, None, 'nearest'),
input_size=(1, 2, 4, 4),
desc='nearest_2d',
),
diff --git a/torch/lib/THCUNN/TemporalUpSamplingLinear.cu b/torch/lib/THCUNN/TemporalUpSamplingLinear.cu
new file mode 100644
index 0000000..fe59b2a
--- /dev/null
+++ b/torch/lib/THCUNN/TemporalUpSamplingLinear.cu
@@ -0,0 +1,96 @@
+// Adapted from interp.cpp from Caffe util by Pauline Luc
+// Originally developed by George Papandreou
+#include "THCUNN.h"
+#include "common.h"
+#include "THCDeviceTensor.cuh"
+#include "THCDeviceTensorUtils.cuh"
+#include "THCDeviceUtils.cuh"
+#include "THCHalf.h"
+#include "THCHalfAutoNumerics.cuh"
+#include "THCAtomics.cuh"
+
+template<typename Dtype, typename Acctype>
+__global__ void caffe_gpu_interp2_kernel(const int n,
+ const Acctype rwidth,
+ const THCDeviceTensor<Dtype, 3> data1, THCDeviceTensor<Dtype, 3> data2) {
+ int index = threadIdx.x + blockIdx.x * blockDim.x;
+ const int batchsize = data1.getSize(0);
+ const int channels = data1.getSize(1);
+ const int width1 = data1.getSize(2);
+ const int width2 = data2.getSize(2);
+
+ if (index < n) {
+ const int w2 = index % width2;
+ // special case: just copy
+ if (width1 == width2) {
+ const int w1 = w2;
+ for (int n = 0; n < batchsize ; n++){
+ for (int c = 0; c < channels; ++c) {
+ const Dtype val = data1[n][c][w1];
+ data2[n][c][w2] = val;
+ }
+ }
+ return;
+ }
+ //
+ const Acctype w1r = rwidth * w2;
+ const int w1 = w1r;
+ const int w1p = (w1 < width1 - 1) ? 1 : 0;
+ const Acctype w1lambda = w1r - w1;
+ const Acctype w0lambda = Acctype(1) - w1lambda;
+ //
+ for (int n = 0; n < batchsize ; n++){
+ for (int c = 0; c < channels; ++c) {
+ const Acctype val = w0lambda * data1[n][c][w1]
+ + w1lambda * data1[n][c][w1+w1p];
+ data2[n][c][w2] = ScalarConvert<Acctype, Dtype>::to(val);
+ }
+ }
+ }
+}
+
+// Backward (adjoint) operation 1 <- 2 (accumulates)
+template <typename Dtype, typename Acctype>
+__global__ void caffe_gpu_interp2_kernel_backward(const int n,
+ const Acctype rwidth,
+ THCDeviceTensor<Dtype, 3> data1, const THCDeviceTensor<Dtype, 3> data2){
+ int index = threadIdx.x + blockIdx.x * blockDim.x;
+ const int batchsize = data1.getSize(0);
+ const int channels = data1.getSize(1);
+ const int width1 = data1.getSize(2);
+ const int width2 = data2.getSize(2);
+ if (index < n) {
+ const int w2 = index % width2;
+ // special case: just copy
+ if (width1 == width2) {
+ const int w1 = w2;
+ for (int n = 0; n < batchsize ; n++){
+ for (int c = 0; c < channels; ++c) {
+ const Dtype val = data2[n][c][w1];
+ data1[n][c][w2] += val;
+ }
+ }
+ return;
+ }
+ //
+ const Acctype w1r = rwidth * w2;
+ const int w1 = w1r;
+ const int w1p = (w1 < width1 - 1) ? 1 : 0;
+ const Acctype w1lambda = w1r - w1;
+ const Acctype w0lambda = Acctype(1) - w1lambda;
+ //
+ for (int n = 0; n < batchsize ; n++){
+ for (int c = 0; c < channels; ++c) {
+ const Dtype d2val = data2[n][c][w2];
+ atomicAdd(data1[n][c][w1].data(),
+ ScalarConvert<Acctype, Dtype>::to(w0lambda * d2val));
+ atomicAdd(data1[n][c][w1+w1p].data(),
+ ScalarConvert<Acctype, Dtype>::to(w1lambda * d2val));
+ }
+ }
+ }
+}
+
+
+#include "generic/TemporalUpSamplingLinear.cu"
+#include "THCGenerateFloatTypes.h"
diff --git a/torch/lib/THCUNN/TemporalUpSamplingNearest.cu b/torch/lib/THCUNN/TemporalUpSamplingNearest.cu
new file mode 100644
index 0000000..f5dd5d9
--- /dev/null
+++ b/torch/lib/THCUNN/TemporalUpSamplingNearest.cu
@@ -0,0 +1,75 @@
+#include "THCUNN.h"
+#include "common.h"
+
+#include <thrust/transform.h>
+#include <thrust/reduce.h>
+#include <thrust/transform_reduce.h>
+#include <thrust/functional.h>
+
+#include "THCHalf.h"
+#include "THCHalfAutoNumerics.cuh"
+
+/*
+ * Description:
+ */
+
+__device__ int translate_idx(int ii, int d1, int d2, int scale_factor)
+{
+ int x, y, z;
+ z = ii % d2;
+ ii = ii/d2;
+ y = ii % d1;
+ ii = ii/d1;
+ x = ii;
+ z = z/scale_factor;
+ d2 /= scale_factor;
+ return ((x*d1+y)*d2)+z;
+
+}
+__device__ int translate_idx_inv(int ii, int d1, int d2, int scale_factor, int off_x)
+{
+ int x, y, z;
+ z = ii % d2;
+ ii = ii/d2;
+ y = ii % d1;
+ ii = ii/d1;
+ x = ii;
+ z = z*scale_factor+off_x;
+ d2 *= scale_factor;
+ return ((x*d1+y)*d2)+z;
+
+}
+
+template <typename Dtype>
+__global__ void upscale(Dtype *input, Dtype *output, int64_t no_elements,
+ int scale_factor, int d1, int d2)
+{
+ // output offset:
+ int64_t ii = threadIdx.x + blockDim.x * blockIdx.x;
+ ii += threadIdx.y + blockDim.y * (blockDim.x * gridDim.x) * blockIdx.y;
+ if (ii >= no_elements) return;
+ int ipidx = translate_idx(ii, d1, d2, scale_factor);
+ output[ii]=input[ipidx];
+}
+
+/*
+ * Description:
+ */
+template <typename Dtype, typename Acctype>
+__global__ void downscale(Dtype *gradInput_data, Dtype *gradOutput_data, int64_t no_elements,
+ int scale_factor, int d1, int d2)
+{
+ // output offset:
+ int64_t ii = threadIdx.x + blockDim.x * blockIdx.x;
+ ii += threadIdx.y + blockDim.y * (blockDim.x * gridDim.x) * blockIdx.y;
+ if (ii >= no_elements) return;
+ Acctype sum = Acctype(0);
+ for (int i=0; i < scale_factor; i++){
+ int ipidx = translate_idx_inv(ii, d1, d2, scale_factor, i);
+ sum += gradOutput_data[ipidx];
+ }
+ gradInput_data[ii] += ScalarConvert<Acctype, Dtype>::to(sum);
+}
+
+#include "generic/TemporalUpSamplingNearest.cu"
+#include "THCGenerateFloatTypes.h"
diff --git a/torch/lib/THCUNN/generic/THCUNN.h b/torch/lib/THCUNN/generic/THCUNN.h
index 2159ed3..9963706 100644
--- a/torch/lib/THCUNN/generic/THCUNN.h
+++ b/torch/lib/THCUNN/generic/THCUNN.h
@@ -1257,6 +1257,34 @@
THCTensor *gradInput,
int padL, int padR);
+TH_API void THNN_(TemporalUpSamplingLinear_updateOutput)(
+ THCState *state,
+ THCTensor *input,
+ THCTensor *output,
+ int outputWidth);
+
+TH_API void THNN_(TemporalUpSamplingLinear_updateGradInput)(
+ THCState *state,
+ THCTensor *gradOutput,
+ THCTensor *gradInput,
+ int nbatch,
+ int nchannels,
+ int inputWidth,
+ int outputWidth);
+
+TH_API void THNN_(TemporalUpSamplingNearest_updateGradInput)(
+ THCState *state,
+ THCTensor *input,
+ THCTensor *gradOutput,
+ THCTensor *gradInput,
+ int scale_factor);
+
+TH_API void THNN_(TemporalUpSamplingNearest_updateOutput)(
+ THCState *state,
+ THCTensor *input,
+ THCTensor *output,
+ int scale_factor);
+
TH_API void THNN_(Threshold_updateOutput)(
THCState *state,
THCTensor *input,
diff --git a/torch/lib/THCUNN/generic/TemporalUpSamplingLinear.cu b/torch/lib/THCUNN/generic/TemporalUpSamplingLinear.cu
new file mode 100644
index 0000000..194794e
--- /dev/null
+++ b/torch/lib/THCUNN/generic/TemporalUpSamplingLinear.cu
@@ -0,0 +1,97 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/TemporalUpSamplingLinear.cu"
+#else
+
+static inline void THNN_(TemporalUpSamplingLinear_shapeCheck)
+ (THCState *state,
+ THCTensor *input, THCTensor *gradOutput,
+ int nBatch, int nChannels,
+ int inputWidth,
+ int outputWidth) {
+ THArgCheck(inputWidth > 0 && outputWidth > 0, 2,
+ "input and output sizes should be greater than 0,"
+ " but got input (W: %d) output (W: %d)",
+ inputWidth, outputWidth);
+ if (input != NULL) {
+ THCUNN_argCheck(state, input->nDimension == 3, 2, input,
+ "3D input tensor expected but got: %s");
+ }
+
+ if (gradOutput != NULL) {
+ THCUNN_check_dim_size(state, gradOutput, 3, 0, nBatch);
+ THCUNN_check_dim_size(state, gradOutput, 3, 1, nChannels);
+ THCUNN_check_dim_size(state, gradOutput, 3, 2, outputWidth);
+ }
+}
+
+void THNN_(TemporalUpSamplingLinear_updateOutput)(
+ THCState *state,
+ THCTensor *input,
+ THCTensor *output,
+ int outputWidth)
+{
+ int nbatch = THCTensor_(size)(state, input, 0);
+ int channels = THCTensor_(size)(state, input, 1);
+ int inputWidth = THCTensor_(size)(state, input, 2);
+ THNN_(TemporalUpSamplingLinear_shapeCheck)
+ (state, input, NULL,
+ nbatch, channels,
+ inputWidth, outputWidth);
+ input = THCTensor_(newContiguous)(state, input);
+ THCUNN_assertSameGPU(state, 2, input, output);
+ THCTensor_(resize3d)(state, output,
+ THCTensor_(size)(state, input, 0),
+ THCTensor_(size)(state, input, 1),
+ outputWidth);
+ THCTensor_(zero)(state, output);
+ THCDeviceTensor<real, 3> idata = toDeviceTensor<real, 3>(state, input);
+ THCDeviceTensor<real, 3> odata = toDeviceTensor<real, 3>(state, output);
+ THAssert(inputWidth > 0 && outputWidth > 0);
+ const accreal rwidth = (outputWidth > 1) ? (accreal)(inputWidth - 1)/(outputWidth - 1) : accreal(0);
+ const int num_kernels = outputWidth;
+ const int num_threads =
+ THCState_getCurrentDeviceProperties(state)->maxThreadsPerBlock;
+ cudaStream_t stream = THCState_getCurrentStream(state);
+ caffe_gpu_interp2_kernel<real, accreal> <<<THCCeilDiv(num_kernels, num_threads), num_threads ,
+ 0 , stream>>>(num_kernels, rwidth, idata, odata);
+ THCudaCheck(cudaGetLastError());
+ THCTensor_(free)(state, input);
+}
+
+
+void THNN_(TemporalUpSamplingLinear_updateGradInput)(
+ THCState *state,
+ THCTensor *gradOutput,
+ THCTensor *gradInput,
+ int nbatch,
+ int nchannels,
+ int inputWidth,
+ int outputWidth)
+{
+ THNN_(TemporalUpSamplingLinear_shapeCheck)
+ (state, NULL, gradOutput,
+ nbatch, nchannels,
+ inputWidth, outputWidth);
+ gradInput = THCTensor_(newContiguous)(state, gradInput);
+ gradOutput = THCTensor_(newContiguous)(state, gradOutput);
+ THCUNN_assertSameGPU(state, 2, gradOutput, gradInput);
+ THCTensor_(resize3d)(state, gradInput, nbatch, nchannels, inputWidth);
+ THCTensor_(zero)(state, gradInput);
+ THCDeviceTensor<real, 3> data1 = toDeviceTensor<real, 3>(state, gradInput);
+ THCDeviceTensor<real, 3> data2 = toDeviceTensor<real, 3>(state, gradOutput);
+ int width1 = data1.getSize(2);
+ int width2 = data2.getSize(2);
+ assert(width1 > 0 && width2 > 0);
+ const accreal rwidth = (width2 > 1) ? (accreal)(width1 - 1) / (width2 - 1) : accreal(0);
+ const int num_kernels = width2;
+ const int num_threads =
+ THCState_getCurrentDeviceProperties(state)->maxThreadsPerBlock;
+ cudaStream_t stream = THCState_getCurrentStream(state);
+ caffe_gpu_interp2_kernel_backward<real ,accreal> <<<THCCeilDiv(num_kernels, num_threads),
+ num_threads, 0, stream>>>(num_kernels, rwidth, data1, data2);
+ THCudaCheck(cudaGetLastError());
+ THCTensor_(free)(state, gradInput);
+ THCTensor_(free)(state, gradOutput);
+}
+
+#endif
diff --git a/torch/lib/THCUNN/generic/TemporalUpSamplingNearest.cu b/torch/lib/THCUNN/generic/TemporalUpSamplingNearest.cu
new file mode 100644
index 0000000..567346c
--- /dev/null
+++ b/torch/lib/THCUNN/generic/TemporalUpSamplingNearest.cu
@@ -0,0 +1,157 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/TemporalUpSamplingNearest.cu"
+#else
+
+#include "../common.h"
+
+static inline void THNN_(TemporalUpSamplingNearest_shapeCheck)
+ (THCState *state,THCTensor *input, THCTensor *gradOutput,
+ int scale_factor) {
+ THArgCheck(input != NULL, 2, "3D input tensor expected but got NULL");
+ THArgCheck(scale_factor > 1, 4,
+ "scale_factor must be greater than 1, but got: %d", scale_factor);
+ THCUNN_argCheck(state, input->nDimension == 2 || input->nDimension == 3, 2, input,
+ "2D or 3D input tensor expected but got: %s");
+ if (input->nDimension == 2) {
+ int nChannels = THCTensor_(size)(state, input, 0);
+ int inputWidth = THCTensor_(size)(state, input, 1);
+ int outputWidth = inputWidth * scale_factor;
+ if (gradOutput != NULL) {
+ THCUNN_check_dim_size(state, gradOutput, 2, 0, nChannels);
+ THCUNN_check_dim_size(state, gradOutput, 2, 1, outputWidth);
+ }
+ } else {
+ int nBatch = THCTensor_(size)(state, input, 0);
+ int nChannels = THCTensor_(size)(state, input, 1);
+ int inputWidth = THCTensor_(size)(state, input, 2);
+ int outputWidth = inputWidth * scale_factor;
+ if (gradOutput != NULL) {
+ THCUNN_check_dim_size(state, gradOutput, 3, 0, nBatch);
+ THCUNN_check_dim_size(state, gradOutput, 3, 1, nChannels);
+ THCUNN_check_dim_size(state, gradOutput, 3, 2, outputWidth);
+ }
+ }
+}
+
+void THNN_(TemporalUpSamplingNearest_updateOutput)(
+ THCState *state,
+ THCTensor *input,
+ THCTensor *output,
+ int scale_factor)
+{
+ THCTensor_(zero)(state, output);
+
+ THCUNN_assertSameGPU(state, 2, input, output);
+ THNN_(TemporalUpSamplingNearest_shapeCheck)(state, input, NULL, scale_factor);
+ int inputWidth = THCTensor_(size)(state, input, input->nDimension-1);
+ int outputWidth = inputWidth * scale_factor;
+
+ if (input->nDimension == 2) {
+ THCTensor_(resize2d)(state, output,
+ THCTensor_(size)(state, input, 0),
+ outputWidth);
+ } else {
+ THCTensor_(resize3d)(state, output,
+ THCTensor_(size)(state, input, 0),
+ THCTensor_(size)(state, input, 1),
+ outputWidth);
+ }
+
+ input = THCTensor_(newContiguous)(state, input);
+ // This is for allocating output Tensor
+ int64_t no_elements = 1;
+ for(int i = 0; i < input->nDimension; i++){
+ no_elements *= input->size[i];
+ }
+ no_elements *= scale_factor * scale_factor;
+
+ int d1;
+ int d2;
+
+ if (input->nDimension == 2) {
+ d1 = output->size[0];
+ d2 = output->size[1];
+ } else {
+ d1 = output->size[1];
+ d2 = output->size[2];
+ }
+
+ real *input_data = THCTensor_(data)(state, input);
+ real *output_data = THCTensor_(data)(state, output);
+
+ // cuda blocks & threads:
+ int64_t nthreads = 256;
+ // Max number of blocks: http://en.wikipedia.org/wiki/CUDA
+ // 65535 for SM 2.x, 2^32 -1 for >= 3.0
+ // TODO: When we move to SM 3.5 we should update this
+ int64_t n_xblocks = min(max((int)ceil((float)no_elements / nthreads), 1), 65535);
+ int64_t n_yblocks = (int64_t)ceil((float)no_elements / (float)(n_xblocks * nthreads));
+ if (n_yblocks > 65535) {
+ THError("Input size is too large! aborting");
+ }
+ dim3 blocks(n_xblocks, n_yblocks);
+ dim3 threads(nthreads);
+
+ // kernel:
+ upscale<<<blocks, threads, 0, THCState_getCurrentStream(state)>>> (input_data, output_data, no_elements, scale_factor, d1, d2);
+ THCudaCheck(cudaGetLastError());
+
+ // final cut:
+ THCTensor_(free)(state, input);
+}
+
+void THNN_(TemporalUpSamplingNearest_updateGradInput)(
+ THCState *state,
+ THCTensor *input,
+ THCTensor *gradOutput,
+ THCTensor *gradInput,
+ int scale_factor)
+{
+
+ THCUNN_assertSameGPU(state, 2, gradOutput, gradInput);
+ THNN_(TemporalUpSamplingNearest_shapeCheck)(state, input, gradOutput, scale_factor);
+ gradOutput = THCTensor_(newContiguous)(state, gradOutput);
+ THCTensor_(resizeAs)(state, gradInput, input);
+
+ THCTensor_(zero)(state, gradInput);
+
+ real *gradInput_data = THCTensor_(data)(state, gradInput);
+ real *gradOutput_data = THCTensor_(data)(state, gradOutput);
+
+ int64_t no_elements = 1;
+ for(int i = 0; i < gradInput->nDimension; i++){
+ no_elements *= gradInput->size[i];
+ }
+
+ int d1;
+ int d2;
+
+ if (gradInput->nDimension == 2) {
+ d1 = gradInput->size[0];
+ d2 = gradInput->size[1];
+ } else {
+ d1 = gradInput->size[1];
+ d2 = gradInput->size[2];
+ }
+
+ // cuda blocks & threads:
+ int64_t nthreads = 256;
+ // Max number of blocks: http://en.wikipedia.org/wiki/CUDA
+ // 65535 for SM 2.x, 2^32 -1 for >= 3.0
+ // TODO: When we move to SM 3.5 we should update this
+ int64_t n_xblocks = min(max((int)ceil((float)no_elements / nthreads), 1), 65535);
+ int64_t n_yblocks = (int64_t)ceil((float)no_elements / (float)(n_xblocks * nthreads));
+ if (n_yblocks > 65535) {
+ THError("Input size is too large! aborting");
+ }
+ dim3 blocks(n_xblocks, n_yblocks);
+ dim3 threads(nthreads);
+
+ // kernel:
+ downscale<real ,accreal> <<<blocks, threads, 0, THCState_getCurrentStream(state)>>> (gradInput_data, gradOutput_data, no_elements,
+ scale_factor, d1, d2);
+ THCudaCheck(cudaGetLastError());
+ THCTensor_(free)(state, gradOutput);
+}
+
+#endif
diff --git a/torch/lib/THNN/generic/THNN.h b/torch/lib/THNN/generic/THNN.h
index 8dda789..d7ee0a1 100644
--- a/torch/lib/THNN/generic/THNN.h
+++ b/torch/lib/THNN/generic/THNN.h
@@ -720,6 +720,32 @@
bool featFirst,
accreal scale);
+TH_API void THNN_(TemporalUpSamplingNearest_updateOutput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *output,
+ int scale_factor);
+TH_API void THNN_(TemporalUpSamplingNearest_updateGradInput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *gradOutput,
+ THTensor *gradInput,
+ int scale_factor);
+
+TH_API void THNN_(TemporalUpSamplingLinear_updateOutput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *output,
+ int outputWidth);
+TH_API void THNN_(TemporalUpSamplingLinear_updateGradInput)(
+ THNNState *state,
+ THTensor *gradOutput,
+ THTensor *gradInput,
+ int nbatch,
+ int nchannels,
+ int inputWidth,
+ int outputWidth);
+
TH_API void THNN_(BatchNormalization_updateOutput)(
THNNState *state,
THTensor *input,
@@ -1203,7 +1229,7 @@
THNNState *state,
THTensor *input,
THTensor *output,
- int outputHeight,
+ int outputHeight,
int outputWidth);
TH_API void THNN_(SpatialUpSamplingBilinear_updateGradInput)(
THNNState *state,
diff --git a/torch/lib/THNN/generic/TemporalUpSamplingLinear.c b/torch/lib/THNN/generic/TemporalUpSamplingLinear.c
new file mode 100644
index 0000000..67606d7
--- /dev/null
+++ b/torch/lib/THNN/generic/TemporalUpSamplingLinear.c
@@ -0,0 +1,140 @@
+// Adapted from interp.cpp from Caffe util by Pauline Luc
+// Originally developed by George Papandreou
+
+#ifndef TH_GENERIC_FILE
+#define TH_GENERIC_FILE "generic/TemporalUpSamplingLinear.c"
+#else
+
+static inline void THNN_(TemporalUpSamplingLinear_shapeCheck)
+ (THTensor *input, THTensor *gradOutput,
+ int nBatch, int nChannels,
+ int inputWidth, int outputWidth) {
+ THArgCheck(inputWidth > 0 && outputWidth > 0, 2,
+ "input and output sizes should be greater than 0,"
+ " but got input (W: %d) output (W: %d)",
+ inputWidth, outputWidth);
+ if (input != NULL) {
+ THNN_ARGCHECK(input->nDimension == 3, 2, input,
+ "3D input tensor expected but got: %s");
+ }
+
+ if (gradOutput != NULL) {
+ THNN_CHECK_DIM_SIZE(gradOutput, 3, 0, nBatch);
+ THNN_CHECK_DIM_SIZE(gradOutput, 3, 1, nChannels);
+ THNN_CHECK_DIM_SIZE(gradOutput, 3, 2, outputWidth);
+ }
+}
+
+void THNN_(TemporalUpSamplingLinear_updateOutput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *output,
+ int outputWidth){
+
+ int nbatch = THTensor_(size)(input, 0);
+ int channels = THTensor_(size)(input, 1);
+ int inputWidth = THTensor_(size)(input, 2);
+
+ THNN_(TemporalUpSamplingLinear_shapeCheck)
+ (input, NULL,
+ nbatch, channels,
+ inputWidth, outputWidth);
+
+ input = THTensor_(newContiguous)(input);
+ THTensor_(resize3d)(output,
+ THTensor_(size)(input, 0),
+ THTensor_(size)(input, 1),
+ outputWidth);
+ THTensor_(zero)(output);
+ real *idata = THTensor_(data)(input);
+ real *odata = THTensor_(data)(output);
+ channels = nbatch * channels;
+ THAssert(inputWidth > 0 && outputWidth > 0);
+ // special case: just copy
+ if (inputWidth == outputWidth) {
+ for (int w2 = 0; w2 < outputWidth; ++w2) {
+ const int w1 = w2;
+ const real* pos1 = &idata[w1];
+ real* pos2 = &odata[w2];
+ for (int c = 0; c < channels; ++c) {
+ pos2[0] = pos1[0];
+ pos1 += inputWidth;
+ pos2 += outputWidth;
+ }
+ }
+ return;
+ }
+ const float rwidth = (outputWidth > 1) ? (float)(inputWidth - 1) / (outputWidth - 1) : 0.f;
+ for (int w2 = 0; w2 < outputWidth; ++w2) {
+ const float w1r = rwidth * w2;
+ const int w1 = w1r;
+ const int w1p = (w1 < inputWidth - 1) ? 1 : 0;
+ const real w1lambda = w1r - w1;
+ const real w0lambda = (real)1. - w1lambda;
+ const real* pos1 = &idata[w1];
+ real* pos2 = &odata[w2];
+ for (int c = 0; c < channels; ++c) {
+ pos2[0] = w0lambda * pos1[0] + w1lambda * pos1[w1p];
+ pos1 += inputWidth;
+ pos2 += outputWidth;
+ }
+ }
+ THTensor_(free)(input);
+}
+
+void THNN_(TemporalUpSamplingLinear_updateGradInput)(
+ THNNState *state,
+ THTensor *gradOutput,
+ THTensor *gradInput,
+ int nbatch,
+ int channels,
+ int inputWidth,
+ int outputWidth){
+
+ THNN_(TemporalUpSamplingLinear_shapeCheck)
+ (NULL, gradOutput,
+ nbatch, channels,
+ inputWidth,
+ outputWidth);
+
+ THTensor_(resize3d)(gradInput, nbatch, channels, inputWidth);
+ THTensor_(zero)(gradInput);
+ gradOutput = THTensor_(newContiguous)(gradOutput);
+ real *data1 = THTensor_(data)(gradInput);
+ real *data2 = THTensor_(data)(gradOutput);
+ channels = nbatch * channels;
+
+ // special case: same-size matching grids
+ if (inputWidth == outputWidth) {
+ for (int w2 = 0; w2 < outputWidth; ++w2) {
+ const int w1 = w2;
+ real* pos1 = &data1[w1];
+ const real* pos2 = &data2[w2];
+ for (int c = 0; c < channels; ++c) {
+ pos1[0] += pos2[0];
+ pos1 += inputWidth;
+ pos2 += outputWidth;
+ }
+ }
+ return;
+ }
+ const float rwidth = (outputWidth > 1) ? (float)(inputWidth - 1)/(outputWidth - 1) : 0.f;
+ for (int w2 = 0; w2 < outputWidth; ++w2) {
+ const float w1r = rwidth * w2;
+ const int w1 = w1r;
+ const int w1p = (w1 < inputWidth - 1) ? 1 : 0;
+ const real w1lambda = w1r - w1;
+ const real w0lambda = (real)1. - w1lambda;
+ real* pos1 = &data1[w1];
+ const real* pos2 = &data2[w2];
+ for (int c = 0; c < channels; ++c) {
+ pos1[0] += w0lambda * pos2[0];
+ pos1[w1p] += w1lambda * pos2[0];
+ pos1 += inputWidth;
+ pos2 += outputWidth;
+ }
+ }
+ THTensor_(free)(gradOutput);
+}
+
+#endif
diff --git a/torch/lib/THNN/generic/TemporalUpSamplingNearest.c b/torch/lib/THNN/generic/TemporalUpSamplingNearest.c
new file mode 100644
index 0000000..bc8d3a8
--- /dev/null
+++ b/torch/lib/THNN/generic/TemporalUpSamplingNearest.c
@@ -0,0 +1,173 @@
+#ifndef TH_GENERIC_FILE
+#define TH_GENERIC_FILE "generic/TemporalUpSamplingNearest.c"
+#else
+
+
+static inline void THNN_(TemporalUpSamplingNearest_shapeCheck)
+ (THTensor *input, THTensor *gradOutput,
+ int scale_factor) {
+ THArgCheck(input != NULL, 2, "3D input tensor expected but got NULL");
+ THArgCheck(scale_factor > 1, 4,
+ "scale_factor must be greater than 1, but got: %d", scale_factor);
+ THNN_ARGCHECK(input->nDimension == 2 || input->nDimension == 3, 2, input,
+ "2D or 3D input tensor expected but got: %s");
+ if (input->nDimension == 2) {
+ int nChannels = THTensor_(size)(input, 0);
+ int inputWidth = THTensor_(size)(input, 1);
+ int outputWidth = inputWidth * scale_factor;
+ if (gradOutput != NULL) {
+ THNN_CHECK_DIM_SIZE(gradOutput, 3, 0, nChannels);
+ THNN_CHECK_DIM_SIZE(gradOutput, 3, 1, outputWidth);
+ }
+ } else {
+ int nBatch = THTensor_(size)(input, 0);
+ int nChannels = THTensor_(size)(input, 1);
+ int inputWidth = THTensor_(size)(input, 2);
+ int outputWidth = inputWidth * scale_factor;
+ if (gradOutput != NULL) {
+ THNN_CHECK_DIM_SIZE(gradOutput, 3, 0, nBatch);
+ THNN_CHECK_DIM_SIZE(gradOutput, 3, 1, nChannels);
+ THNN_CHECK_DIM_SIZE(gradOutput, 3, 2, outputWidth);
+ }
+ }
+}
+
+void THNN_(TemporalUpSamplingNearest_updateOutput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *output,
+ int scale_factor)
+{
+ THNN_(TemporalUpSamplingNearest_shapeCheck)(input, NULL, scale_factor);
+ int inputWidth = THTensor_(size)(input, input->nDimension-1);
+ int outputWidth = inputWidth * scale_factor;
+
+ if (input->nDimension == 2) {
+ THTensor_(resize2d)(output,
+ THTensor_(size)(input, 0),
+ outputWidth);
+ } else {
+ THTensor_(resize3d)(output,
+ THTensor_(size)(input, 0),
+ THTensor_(size)(input, 1),
+ outputWidth);
+ }
+
+ int dW = scale_factor;
+ int xDim = input->nDimension-1;
+
+ // dims
+ int idim = input->nDimension;
+ int osz0 = output->size[0];
+ int osz1 = output->size[1];
+ int osz2 = 1;
+ if (idim > 2) {
+ osz2 = output->size[2];
+ }
+
+ // get strides
+ int64_t *is = input->stride;
+ int64_t *os = output->stride;
+
+ // get raw pointers
+ real *pin = THTensor_(data)(input);
+ real *pout = THTensor_(data)(output);
+
+ // perform the upsampling
+ int i0, i1, i2, isrc, idst;
+ int iout[3]; // Output indices
+ int iin[3]; // Input indices
+
+ for (i0 = 0; i0 < osz0; i0++) {
+ iout[0] = i0;
+ iin[0] = i0;
+ for (i1 = 0; i1 < osz1; i1++) {
+ iout[1] = i1;
+ iin[1] = i1;
+ for (i2 = 0; i2 < osz2; i2++) {
+ iout[2] = i2;
+ iin[2] = i2;
+
+ // set the indices for the upsampled dimensions
+ iin[xDim] = iout[xDim] / dW;
+
+ idst = i0*os[0] + i1*os[1];
+ isrc = iin[0]*is[0] + iin[1]*is[1];
+ if (idim > 2) {
+ idst += i2*os[2];
+ isrc += iin[2]*is[2];
+ }
+
+ pout[idst] = pin[isrc];
+ }
+ }
+ }
+}
+
+void THNN_(TemporalUpSamplingNearest_updateGradInput)(
+ THNNState *state,
+ THTensor *input,
+ THTensor *gradOutput,
+ THTensor *gradInput,
+ int scale_factor)
+{
+ THNN_(TemporalUpSamplingNearest_shapeCheck)(input, gradOutput, scale_factor);
+ THTensor_(resizeAs)(gradInput, input);
+
+ int dW = scale_factor;
+ int xDim = gradInput->nDimension-1;
+
+ // dims
+ int idim = gradInput->nDimension; // Guaranteed to be between 2 and 4
+ int isz0 = gradInput->size[0];
+ int isz1 = gradInput->size[1];
+ int isz2 = 1;
+ if (idim > 2) {
+ isz2 = gradInput->size[2];
+ }
+
+ // get strides
+ int64_t *is = gradInput->stride;
+ int64_t *os = gradOutput->stride;
+
+ // get raw pointers
+ real *pin = THTensor_(data)(gradInput);
+ real *pout = THTensor_(data)(gradOutput);
+
+ // perform the upsampling
+ int i0, i1, i2, isrc, idst, x, y;
+ int iin[3]; // Input indices
+ int iout[3]; // Output indices
+
+ THTensor_(zero)(gradInput);
+
+ for (i0 = 0; i0 < isz0; i0++) {
+ iin[0] = i0;
+ iout[0] = i0;
+ for (i1 = 0; i1 < isz1; i1++) {
+ iin[1] = i1;
+ iout[1] = i1;
+ for (i2 = 0; i2 < isz2; i2++) {
+ iin[2] = i2;
+ iout[2] = i2;
+
+ idst = i0*is[0] + i1*is[1];
+ if (idim > 2) {
+ idst += i2*is[2];
+ }
+
+ // Now accumulate the gradients from gradOutput
+ for (x = 0; x < dW; x++) {
+ iout[xDim] = dW * iin[xDim] + x;
+ isrc = iout[0]*os[0] + iout[1]*os[1];
+ if (idim > 2) {
+ isrc += iout[2]*os[2];
+ }
+ pin[idst] += pout[isrc];
+ }
+ }
+ }
+ }
+}
+
+#endif
diff --git a/torch/lib/THNN/init.c b/torch/lib/THNN/init.c
index ed28119..d0e6ba6 100644
--- a/torch/lib/THNN/init.c
+++ b/torch/lib/THNN/init.c
@@ -179,6 +179,12 @@
#include "generic/TemporalRowConvolution.c"
#include "THGenerateFloatTypes.h"
+#include "generic/TemporalUpSamplingNearest.c"
+#include "THGenerateFloatTypes.h"
+
+#include "generic/TemporalUpSamplingLinear.c"
+#include "THGenerateFloatTypes.h"
+
#include "generic/FeatureLPPooling.c"
#include "THGenerateFloatTypes.h"
diff --git a/torch/nn/_functions/thnn/upsampling.py b/torch/nn/_functions/thnn/upsampling.py
index 9ab14a1..48ecaeb 100644
--- a/torch/nn/_functions/thnn/upsampling.py
+++ b/torch/nn/_functions/thnn/upsampling.py
@@ -4,7 +4,7 @@
from torch._thnn import type2backend
from . import _all_functions
-from ...modules.utils import _pair, _triple
+from ...modules.utils import _single, _pair, _triple
def _check_size_scale_factor(size, scale_factor):
@@ -14,6 +14,158 @@
raise ValueError('scale_factor must be of integer type or a tuple of integer types')
+def _check_linear_scale_factor(scale_factor, dim=2):
+ if dim == 1:
+ scale_factor = _single(scale_factor)
+ elif dim == 2:
+ scale_factor = _pair(scale_factor)
+ elif dim == 3:
+ scale_factor = _triple(scale_factor)
+ else:
+ raise ValueError("dim has to be 1, 2 or 3")
+
+ try:
+ assert len(scale_factor) == 1 or len(scale_factor) == 2 or len(scale_factor) == 3
+ assert all(isinstance(s, Integral) and s >= 1 for s in scale_factor)
+ except AssertionError as e:
+ raise ValueError('scale_factor must be a non-negative integer, '
+ 'or a tuple of non-negative integers for linear, bilinear and trilinear upsampling, but got: '
+ '{}'.format(scale_factor))
+ return scale_factor
+
+
+class UpsamplingNearest1d(Function):
+
+ @staticmethod
+ def forward(ctx, input, size=None, scale_factor=None):
+ assert input.dim() == 3
+
+ _check_size_scale_factor(size, scale_factor)
+
+ ctx.size = size
+ ctx.scale_factor = scale_factor
+
+ if ctx.scale_factor is not None and not isinstance(ctx.scale_factor, Integral):
+ raise ValueError('scale_factor must be a single Integer value for nearest neighbor sampling')
+
+ if ctx.scale_factor is None:
+ if (ctx.size[0] % input.size(2) != 0):
+ raise RuntimeError("output size specified in UpsamplingNearest "
+ "({}) has to be divisible by the input size, but got: "
+ "{}".format('x'.join(map(str, ctx.size)),
+ 'x'.join(map(str, input.size()))))
+ ctx.scale_factor = ctx.size[0] // input.size(2)
+
+ output = input.new()
+ backend = type2backend[type(input)]
+ ctx.save_for_backward(input)
+ backend.TemporalUpSamplingNearest_updateOutput(
+ backend.library_state,
+ input,
+ output,
+ ctx.scale_factor
+ )
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, = ctx.saved_variables
+ grad_input = UpsamplingNearest1dBackward.apply(input, grad_output, ctx.scale_factor)
+ return grad_input, None, None
+
+
+class UpsamplingNearest1dBackward(Function):
+
+ @staticmethod
+ def forward(ctx, input, grad_output, scale_factor):
+ assert grad_output.dim() == 3
+ ctx.scale_factor = scale_factor
+
+ grad_input = grad_output.new()
+ backend = type2backend[type(input)]
+ backend.TemporalUpSamplingNearest_updateGradInput(
+ backend.library_state,
+ input,
+ grad_output,
+ grad_input,
+ ctx.scale_factor
+ )
+ return grad_input
+
+ @staticmethod
+ def backward(ctx, ggI):
+ gI = None
+ ggO = UpsamplingNearest1d.apply(ggI, None, ctx.scale_factor)
+
+ return gI, ggO, None
+
+
+class UpsamplingLinear1d(Function):
+
+ @staticmethod
+ def forward(ctx, input, size=None, scale_factor=None):
+ assert input.dim() == 3
+
+ ctx.size = size
+ ctx.scale_factor = scale_factor
+
+ if ctx.scale_factor is not None:
+ ctx.scale_factor = _check_linear_scale_factor(ctx.scale_factor, dim=1)
+
+ if ctx.scale_factor is not None:
+ ctx.output_size = (
+ input.size(2) * ctx.scale_factor[0],
+ )
+ else:
+ ctx.output_size = ctx.size
+
+ ctx.input_size = input.size()
+ output = input.new()
+ backend = type2backend[type(input)]
+ backend.TemporalUpSamplingLinear_updateOutput(
+ backend.library_state,
+ input,
+ output,
+ ctx.output_size[0]
+ )
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ grad_input = UpsamplingLinear1dBackward.apply(grad_output, ctx.input_size, ctx.output_size)
+ return grad_input, None, None
+
+
+class UpsamplingLinear1dBackward(Function):
+
+ @staticmethod
+ def forward(ctx, grad_output, input_size, output_size):
+ assert grad_output.dim() == 3
+
+ ctx.input_size = input_size
+ ctx.output_size = output_size
+
+ grad_output = grad_output.contiguous()
+ grad_input = grad_output.new()
+ backend = type2backend[type(grad_output)]
+ backend.TemporalUpSamplingLinear_updateGradInput(
+ backend.library_state,
+ grad_output,
+ grad_input,
+ ctx.input_size[0],
+ ctx.input_size[1],
+ ctx.input_size[2],
+ ctx.output_size[0],
+ )
+ return grad_input
+
+ @staticmethod
+ def backward(ctx, ggI):
+ ggO = UpsamplingLinear1d.apply(ggI, ctx.output_size, None)
+
+ return ggO, None, None
+
+
class UpsamplingNearest2d(Function):
@staticmethod
@@ -84,24 +236,6 @@
return gI, ggO, None
-def _check_linear_scale_factor(scale_factor, dim=2):
- if dim == 2:
- scale_factor = _pair(scale_factor)
- elif dim == 3:
- scale_factor = _triple(scale_factor)
- else:
- raise ValueError("dim has to be 2 or 3")
-
- try:
- assert len(scale_factor) == 2 or len(scale_factor) == 3
- assert all(isinstance(s, Integral) and s >= 1 for s in scale_factor)
- except AssertionError as e:
- raise ValueError('scale_factor must be a non-negative integer, '
- 'or a tuple of non-negative integers for bilinear and trilinear upsampling, but got: '
- '{}'.format(scale_factor))
- return scale_factor
-
-
class UpsamplingBilinear2d(Function):
@staticmethod
@@ -311,6 +445,10 @@
return ggO, None, None
+_all_functions.append(UpsamplingNearest1d)
+_all_functions.append(UpsamplingNearest1dBackward)
+_all_functions.append(UpsamplingLinear1d)
+_all_functions.append(UpsamplingLinear1dBackward)
_all_functions.append(UpsamplingNearest2d)
_all_functions.append(UpsamplingNearest2dBackward)
_all_functions.append(UpsamplingBilinear2d)
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 3ff9f0f..7469db5 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -940,38 +940,50 @@
The algorithm used for upsampling is determined by :attr:`mode`.
- Currently spatial and volumetric upsampling are supported, i.e.
- expected inputs are 4-D or 5-D in shape.
+ Currently temporal, spatial and volumetric upsampling are supported, i.e.
+ expected inputs are 3-D, 4-D or 5-D in shape.
The input dimensions are interpreted in the form:
- `mini-batch x channels x [depth] x height x width`
+ `mini-batch x channels x [depth] x [height] x width`
- The modes available for upsampling are: `nearest`, `bilinear` (4D-only),
- `trilinear` (5D-only)
+ The modes available for upsampling are: `nearest`, `linear` (3D-only),
+ `bilinear` (4D-only), `trilinear` (5D-only)
Args:
input (Variable): input
- size (int or Tuple[int, int] or Tuple[int, int, int]):
+ size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]):
output spatial size.
scale_factor (int): multiplier for spatial size. Has to be an integer.
mode (string): algorithm used for upsampling:
- 'nearest' | 'bilinear' | 'trilinear'. Default: 'nearest'
+ 'nearest' | 'linear' | 'bilinear' | 'trilinear'. Default: 'nearest'
"""
- if input.dim() == 4 and mode == 'nearest':
+ if input.dim() == 3 and mode == 'nearest':
+ return _functions.thnn.UpsamplingNearest1d.apply(input, _single(size), scale_factor)
+ elif input.dim() == 4 and mode == 'nearest':
return _functions.thnn.UpsamplingNearest2d.apply(input, _pair(size), scale_factor)
elif input.dim() == 5 and mode == 'nearest':
return _functions.thnn.UpsamplingNearest3d.apply(input, _triple(size), scale_factor)
+ elif input.dim() == 3 and mode == 'linear':
+ return _functions.thnn.UpsamplingLinear1d.apply(input, _single(size), scale_factor)
+ elif input.dim() == 3 and mode == 'bilinear':
+ raise NotImplementedError("Got 3D input, but bilinear mode needs 4D input")
+ elif input.dim() == 3 and mode == 'trilinear':
+ raise NotImplementedError("Got 3D input, but trilinear mode needs 5D input")
+ elif input.dim() == 4 and mode == 'linear':
+ raise NotImplementedError("Got 4D input, but linear mode needs 3D input")
elif input.dim() == 4 and mode == 'bilinear':
return _functions.thnn.UpsamplingBilinear2d.apply(input, _pair(size), scale_factor)
elif input.dim() == 4 and mode == 'trilinear':
raise NotImplementedError("Got 4D input, but trilinear mode needs 5D input")
+ elif input.dim() == 5 and mode == 'linear':
+ raise NotImplementedError("Got 5D input, but linear mode needs 3D input")
elif input.dim() == 5 and mode == 'bilinear':
raise NotImplementedError("Got 5D input, but bilinear mode needs 4D input")
elif input.dim() == 5 and mode == 'trilinear':
return _functions.thnn.UpsamplingTrilinear3d.apply(input, _triple(size), scale_factor)
else:
- raise NotImplementedError("Input Error: Only 4D and 5D input Tensors supported"
- " (got {}D) for the modes: nearest | bilinear | trilinear"
+ raise NotImplementedError("Input Error: Only 3D, 4D and 5D input Tensors supported"
+ " (got {}D) for the modes: nearest | linear | bilinear | trilinear"
" (got {})".format(input.dim(), mode))
diff --git a/torch/nn/modules/upsampling.py b/torch/nn/modules/upsampling.py
index d26bd9d..b979404 100644
--- a/torch/nn/modules/upsampling.py
+++ b/torch/nn/modules/upsampling.py
@@ -3,30 +3,30 @@
from .module import Module
from .. import functional as F
-from .utils import _pair, _triple
class Upsample(Module):
"""
- Upsamples a given multi-channel 2D (spatial) or 3D (volumetric) data.
+ Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
- The input data is assumed to be of the form `minibatch x channels x [depth] x height x width`.
+ The input data is assumed to be of the form `minibatch x channels x [depth] x [height] x width`.
Hence, for spatial inputs, we expect a 4D Tensor and for volumetric inputs, we expect a 5D Tensor.
- The algorithms available for upsampling are nearest neighbor, bilinear and trilinear upsampling,
- with bilinear only available for 4D Tensor inputs and trilinear for 4D Tensor inputs.
+ The algorithms available for upsampling are nearest neighbor and linear, bilinear and trilinear
+ for 3D, 4D and 5D input Tensor, respectively.
One can either give a :attr:`scale_factor` or the target output :attr:`size` to
calculate the output size. (You cannot give both, as it is ambiguous)
Args:
- size (tuple, optional): a tuple of ints ([D_out], H_out, W_out) output sizes
+ size (tuple, optional): a tuple of ints ([D_out], [H_out], W_out) output sizes
scale_factor (int / tuple of ints, optional): the multiplier for the image height / width / depth
- mode (string, optional): the upsampling algorithm: nearest | bilinear | trilinear. Default: nearest
+ mode (string, optional): the upsampling algorithm: nearest | linear | bilinear | trilinear. Default: nearest
Shape:
- - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(N, C, D_{in}, H_{in}, W_{in})`
- - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(N, C, D_{out}, H_{out}, W_{out})` where
+ - Input: :math:`(N, C, W_{in})`, :math:`(N, C, H_{in}, W_{in})` or :math:`(N, C, D_{in}, H_{in}, W_{in})`
+ - Output: :math:`(N, C, W_{out})`, :math:`(N, C, H_{out}, W_{out})`
+ or :math:`(N, C, D_{out}, H_{out}, W_{out})` where
:math:`D_{out} = floor(D_{in} * scale\_factor)` or `size[-3]`
:math:`H_{out} = floor(H_{in} * scale\_factor)` or `size[-2]`
:math:`W_{out} = floor(W_{in} * scale\_factor)` or `size[-1]`