Adding 1d upsampling (#2846)
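
A minimal usage sketch of the new 1D modes, taken from the tests added below (this assumes the
pre-0.4 Variable API that this patch targets; nn.Upsample and F.upsample are the entry points
touched here, and the all-ones values only illustrate the trivial case from test_nn.py):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.autograd import Variable

    # 3D input: mini-batch x channels x width
    x = Variable(torch.ones(1, 1, 2), requires_grad=True)

    # module interface: nearest-neighbor upsampling along the temporal dimension
    up = nn.Upsample(size=4, mode='nearest')
    y = up(x)                                  # shape (1, 1, 4), all ones

    # functional interface: the new 1D linear mode
    z = F.upsample(x, size=4, mode='linear')   # shape (1, 1, 4), all ones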

diff --git a/test/test_nn.py b/test/test_nn.py
index 3ef89ea..b3ebc1a 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -2800,6 +2800,24 @@
             self.assertEqual(out_cpu, out_cuda)
             self.assertEqual(input_cpu.grad, input_gpu.grad)
 
+    def test_upsamplingNearest1d(self):
+        m = nn.Upsample(size=4, mode='nearest')
+        in_t = torch.ones(1, 1, 2)
+        out_t = m(Variable(in_t))
+        self.assertEqual(torch.ones(1, 1, 4), out_t.data)
+
+        input = Variable(torch.randn(1, 1, 2), requires_grad=True)
+        self.assertTrue(gradcheck(lambda x: F.upsample(x, 4, mode='nearest'), (input,)))
+
+    def test_upsamplingLinear1d(self):
+        m = nn.Upsample(size=4, mode='linear')
+        in_t = torch.ones(1, 1, 2)
+        out_t = m(Variable(in_t))
+        self.assertEqual(torch.ones(1, 1, 4), out_t.data)
+
+        input = Variable(torch.randn(1, 1, 2), requires_grad=True)
+        self.assertTrue(gradcheck(lambda x: F.upsample(x, 4, mode='linear'), (input,)))
+
     def test_upsamplingNearest2d(self):
         m = nn.Upsample(size=4, mode='nearest')
         in_t = torch.ones(1, 1, 2, 2)
@@ -3970,6 +3988,42 @@
     dict(
         module_name='Upsample',
         constructor_args=(12, None, 'nearest'),
+        input_size=(1, 2, 4),
+        desc='nearest_1d',
+    ),
+    dict(
+        module_name='Upsample',
+        constructor_args=((12, ), None, 'nearest'),
+        input_size=(1, 2, 3),
+        desc='nearest_tuple_1d',
+    ),
+    dict(
+        module_name='Upsample',
+        constructor_args=(None, 4, 'nearest'),
+        input_size=(1, 2, 4),
+        desc='nearest_scale_1d',
+    ),
+    dict(
+        module_name='Upsample',
+        constructor_args=(12, None, 'linear'),
+        input_size=(1, 2, 4),
+        desc='linear_1d',
+    ),
+    dict(
+        module_name='Upsample',
+        constructor_args=((4, ), None, 'linear'),
+        input_size=(1, 2, 3),
+        desc='linear_tuple_1d',
+    ),
+    dict(
+        module_name='Upsample',
+        constructor_args=(None, 4, 'linear'),
+        input_size=(1, 2, 4),
+        desc='linear_scale_1d',
+    ),
+    dict(
+        module_name='Upsample',
+        constructor_args=(12, None, 'nearest'),
         input_size=(1, 2, 4, 4),
         desc='nearest_2d',
     ),
diff --git a/torch/lib/THCUNN/TemporalUpSamplingLinear.cu b/torch/lib/THCUNN/TemporalUpSamplingLinear.cu
new file mode 100644
index 0000000..fe59b2a
--- /dev/null
+++ b/torch/lib/THCUNN/TemporalUpSamplingLinear.cu
@@ -0,0 +1,96 @@
+// Adapted from interp.cpp from Caffe util by Pauline Luc
+// Originally developed by George Papandreou
+#include "THCUNN.h"
+#include "common.h"
+#include "THCDeviceTensor.cuh"
+#include "THCDeviceTensorUtils.cuh"
+#include "THCDeviceUtils.cuh"
+#include "THCHalf.h"
+#include "THCHalfAutoNumerics.cuh"
+#include "THCAtomics.cuh"
+
+template<typename Dtype, typename Acctype>
+__global__ void caffe_gpu_interp2_kernel(const int n,
+    const Acctype rwidth,
+    const THCDeviceTensor<Dtype, 3> data1, THCDeviceTensor<Dtype, 3> data2) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  const int batchsize = data1.getSize(0);
+  const int channels = data1.getSize(1);
+  const int width1 = data1.getSize(2);
+  const int width2 = data2.getSize(2);
+
+  if (index < n) {
+    const int w2 = index % width2;
+    // special case: just copy
+    if (width1 == width2) {
+      const int w1 = w2;
+      for (int n = 0; n < batchsize ; n++){
+        for (int c = 0; c < channels; ++c) {
+          const Dtype val = data1[n][c][w1];
+          data2[n][c][w2] = val;
+        }
+      }
+      return;
+    }
+    //
+    const Acctype w1r = rwidth * w2;
+    const int w1 = w1r;
+    const int w1p = (w1 < width1 - 1) ? 1 : 0;
+    const Acctype w1lambda = w1r - w1;
+    const Acctype w0lambda = Acctype(1) - w1lambda;
+    //
+    for (int n = 0; n < batchsize ; n++){
+      for (int c = 0; c < channels; ++c) {
+        const Acctype val = w0lambda * data1[n][c][w1]
+                            + w1lambda * data1[n][c][w1+w1p];
+        data2[n][c][w2] = ScalarConvert<Acctype, Dtype>::to(val);
+      }
+    }
+  }
+}
+
+// Backward (adjoint) operation 1 <- 2 (accumulates)
+template <typename Dtype, typename Acctype>
+__global__ void caffe_gpu_interp2_kernel_backward(const int n,
+    const Acctype rwidth,
+    THCDeviceTensor<Dtype, 3> data1, const THCDeviceTensor<Dtype, 3> data2){
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  const int batchsize = data1.getSize(0);
+  const int channels = data1.getSize(1);
+  const int width1 = data1.getSize(2);
+  const int width2 = data2.getSize(2);
+  if (index < n) {
+    const int w2 = index % width2;
+    // special case: just copy
+    if (width1 == width2) {
+      const int w1 = w2;
+      for (int n = 0; n < batchsize ; n++){
+        for (int c = 0; c < channels; ++c) {
+          const Dtype val = data2[n][c][w1];
+          data1[n][c][w2] += val;
+        }
+      }
+      return;
+    }
+    //
+    const Acctype w1r = rwidth * w2;
+    const int w1 = w1r;
+    const int w1p = (w1 < width1 - 1) ? 1 : 0;
+    const Acctype w1lambda = w1r - w1;
+    const Acctype w0lambda = Acctype(1) - w1lambda;
+    //
+    for (int n = 0; n < batchsize ; n++){
+      for (int c = 0; c < channels; ++c) {
+        const Dtype d2val = data2[n][c][w2];
+        atomicAdd(data1[n][c][w1].data(),
+                  ScalarConvert<Acctype, Dtype>::to(w0lambda * d2val));
+        atomicAdd(data1[n][c][w1+w1p].data(),
+                  ScalarConvert<Acctype, Dtype>::to(w1lambda * d2val));
+      }
+    }
+  }
+}
+
+
+#include "generic/TemporalUpSamplingLinear.cu"
+#include "THCGenerateFloatTypes.h"
diff --git a/torch/lib/THCUNN/TemporalUpSamplingNearest.cu b/torch/lib/THCUNN/TemporalUpSamplingNearest.cu
new file mode 100644
index 0000000..f5dd5d9
--- /dev/null
+++ b/torch/lib/THCUNN/TemporalUpSamplingNearest.cu
@@ -0,0 +1,75 @@
+#include "THCUNN.h"
+#include "common.h"
+
+#include <thrust/transform.h>
+#include <thrust/reduce.h>
+#include <thrust/transform_reduce.h>
+#include <thrust/functional.h>
+
+#include "THCHalf.h"
+#include "THCHalfAutoNumerics.cuh"
+
+/*
+ * Description:
+ */
+
+__device__ int translate_idx(int ii, int d1, int d2, int scale_factor)
+{
+  int x, y, z;
+  z = ii % d2;
+  ii = ii/d2;
+  y = ii % d1;
+  ii = ii/d1;
+  x = ii;
+  z = z/scale_factor;
+  d2 /= scale_factor;
+  return ((x*d1+y)*d2)+z;
+
+}
+__device__ int translate_idx_inv(int ii, int d1, int d2, int scale_factor, int off_x)
+{
+  int x, y, z;
+  z = ii % d2;
+  ii = ii/d2;
+  y = ii % d1;
+  ii = ii/d1;
+  x = ii;
+  z = z*scale_factor+off_x;
+  d2 *= scale_factor;
+  return ((x*d1+y)*d2)+z;
+
+}
+
+template <typename Dtype>
+__global__ void upscale(Dtype *input, Dtype *output, int64_t no_elements,
+                        int scale_factor, int d1, int d2)
+{
+  // output offset:
+  int64_t ii = threadIdx.x + blockDim.x * blockIdx.x;
+  ii += threadIdx.y + blockDim.y * (blockDim.x * gridDim.x) * blockIdx.y;
+  if (ii >= no_elements) return;
+  int ipidx = translate_idx(ii, d1, d2, scale_factor);
+  output[ii]=input[ipidx];
+}
+
+/*
+ * Description:
+ */
+template <typename Dtype, typename Acctype>
+__global__ void downscale(Dtype *gradInput_data, Dtype *gradOutput_data, int64_t no_elements,
+                              int scale_factor, int d1, int d2)
+{
+  // output offset:
+  int64_t ii = threadIdx.x + blockDim.x * blockIdx.x;
+  ii += threadIdx.y + blockDim.y * (blockDim.x * gridDim.x) * blockIdx.y;
+  if (ii >= no_elements) return;
+  Acctype sum = Acctype(0);
+  for (int i=0; i < scale_factor; i++){
+    int ipidx = translate_idx_inv(ii, d1, d2, scale_factor, i);
+    sum += gradOutput_data[ipidx];
+  }
+  gradInput_data[ii] += ScalarConvert<Acctype, Dtype>::to(sum);
+}
+
+#include "generic/TemporalUpSamplingNearest.cu"
+#include "THCGenerateFloatTypes.h"
diff --git a/torch/lib/THCUNN/generic/THCUNN.h b/torch/lib/THCUNN/generic/THCUNN.h
index 2159ed3..9963706 100644
--- a/torch/lib/THCUNN/generic/THCUNN.h
+++ b/torch/lib/THCUNN/generic/THCUNN.h
@@ -1257,6 +1257,34 @@
                   THCTensor *gradInput,
                   int padL, int padR);
 
+TH_API void THNN_(TemporalUpSamplingLinear_updateOutput)(
+                  THCState *state,
+                  THCTensor *input,
+                  THCTensor *output,
+                  int outputWidth);
+
+TH_API void THNN_(TemporalUpSamplingLinear_updateGradInput)(
+                  THCState *state,
+                  THCTensor *gradOutput,
+                  THCTensor *gradInput,
+                  int nbatch,
+                  int nchannels,
+                  int inputWidth,
+                  int outputWidth);
+
+TH_API void THNN_(TemporalUpSamplingNearest_updateGradInput)(
+                  THCState *state,
+                  THCTensor *input,
+                  THCTensor *gradOutput,
+                  THCTensor *gradInput,
+                  int scale_factor);
+
+TH_API void THNN_(TemporalUpSamplingNearest_updateOutput)(
+                  THCState *state,
+                  THCTensor *input,
+                  THCTensor *output,
+                  int scale_factor);
+
 TH_API void THNN_(Threshold_updateOutput)(
                   THCState *state,
                   THCTensor *input,
diff --git a/torch/lib/THCUNN/generic/TemporalUpSamplingLinear.cu b/torch/lib/THCUNN/generic/TemporalUpSamplingLinear.cu
new file mode 100644
index 0000000..194794e
--- /dev/null
+++ b/torch/lib/THCUNN/generic/TemporalUpSamplingLinear.cu
@@ -0,0 +1,97 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/TemporalUpSamplingLinear.cu"
+#else
+
+static inline void THNN_(TemporalUpSamplingLinear_shapeCheck)
+                        (THCState *state,
+                         THCTensor *input, THCTensor *gradOutput,
+                         int nBatch, int nChannels,
+                         int inputWidth,
+                         int outputWidth) {
+  THArgCheck(inputWidth > 0 && outputWidth > 0, 2,
+             "input and output sizes should be greater than 0,"
+             " but got input (W: %d) output (W: %d)",
+             inputWidth, outputWidth);
+  if (input != NULL) {
+     THCUNN_argCheck(state, input->nDimension == 3, 2, input,
+                     "3D input tensor expected but got: %s");
+  }
+
+  if (gradOutput != NULL) {
+    THCUNN_check_dim_size(state, gradOutput, 3, 0, nBatch);
+    THCUNN_check_dim_size(state, gradOutput, 3, 1, nChannels);
+    THCUNN_check_dim_size(state, gradOutput, 3, 2, outputWidth);
+  }
+}
+
+void THNN_(TemporalUpSamplingLinear_updateOutput)(
+           THCState *state,
+           THCTensor *input,
+           THCTensor *output,
+           int outputWidth)
+{
+  int nbatch = THCTensor_(size)(state, input, 0);
+  int channels = THCTensor_(size)(state, input, 1);
+  int inputWidth = THCTensor_(size)(state, input, 2);
+  THNN_(TemporalUpSamplingLinear_shapeCheck)
+       (state, input, NULL,
+        nbatch, channels,
+        inputWidth, outputWidth);
+  input = THCTensor_(newContiguous)(state, input);
+  THCUNN_assertSameGPU(state, 2, input, output);
+  THCTensor_(resize3d)(state, output,
+                       THCTensor_(size)(state, input, 0),
+                       THCTensor_(size)(state, input, 1),
+                       outputWidth);
+  THCTensor_(zero)(state, output);
+  THCDeviceTensor<real, 3> idata = toDeviceTensor<real, 3>(state, input);
+  THCDeviceTensor<real, 3> odata = toDeviceTensor<real, 3>(state, output);
+  THAssert(inputWidth > 0 && outputWidth > 0);
+  const accreal rwidth = (outputWidth > 1) ? (accreal)(inputWidth - 1)/(outputWidth - 1) : accreal(0);
+  const int num_kernels = outputWidth;
+  const int num_threads =
+    THCState_getCurrentDeviceProperties(state)->maxThreadsPerBlock;
+  cudaStream_t stream = THCState_getCurrentStream(state);
+  caffe_gpu_interp2_kernel<real, accreal> <<<THCCeilDiv(num_kernels, num_threads), num_threads,
+    0, stream>>>(num_kernels, rwidth, idata, odata);
+  THCudaCheck(cudaGetLastError());
+  THCTensor_(free)(state, input);
+}
+
+
+void THNN_(TemporalUpSamplingLinear_updateGradInput)(
+           THCState *state,
+           THCTensor *gradOutput,
+           THCTensor *gradInput,
+           int nbatch,
+           int nchannels,
+           int inputWidth,
+           int outputWidth)
+{
+  THNN_(TemporalUpSamplingLinear_shapeCheck)
+       (state, NULL, gradOutput,
+        nbatch, nchannels,
+        inputWidth, outputWidth);
+  gradInput = THCTensor_(newContiguous)(state, gradInput);
+  gradOutput = THCTensor_(newContiguous)(state, gradOutput);
+  THCUNN_assertSameGPU(state, 2, gradOutput, gradInput);
+  THCTensor_(resize3d)(state, gradInput, nbatch, nchannels, inputWidth);
+  THCTensor_(zero)(state, gradInput);
+  THCDeviceTensor<real, 3> data1 = toDeviceTensor<real, 3>(state, gradInput);
+  THCDeviceTensor<real, 3> data2 = toDeviceTensor<real, 3>(state, gradOutput);
+  int width1 = data1.getSize(2);
+  int width2 = data2.getSize(2);
+  assert(width1 > 0 && width2 > 0);
+  const accreal rwidth = (width2 > 1) ? (accreal)(width1 - 1) / (width2 - 1) : accreal(0);
+  const int num_kernels = width2;
+  const int num_threads =
+    THCState_getCurrentDeviceProperties(state)->maxThreadsPerBlock;
+  cudaStream_t stream = THCState_getCurrentStream(state);
+  caffe_gpu_interp2_kernel_backward<real, accreal> <<<THCCeilDiv(num_kernels, num_threads),
+    num_threads, 0, stream>>>(num_kernels, rwidth, data1, data2);
+  THCudaCheck(cudaGetLastError());
+  THCTensor_(free)(state, gradInput);
+  THCTensor_(free)(state, gradOutput);
+}
+
+#endif
diff --git a/torch/lib/THCUNN/generic/TemporalUpSamplingNearest.cu b/torch/lib/THCUNN/generic/TemporalUpSamplingNearest.cu
new file mode 100644
index 0000000..567346c
--- /dev/null
+++ b/torch/lib/THCUNN/generic/TemporalUpSamplingNearest.cu
@@ -0,0 +1,157 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/TemporalUpSamplingNearest.cu"
+#else
+
+#include "../common.h"
+
+static inline void THNN_(TemporalUpSamplingNearest_shapeCheck)
+                        (THCState *state,THCTensor *input, THCTensor *gradOutput,
+                         int scale_factor) {
+  THArgCheck(input != NULL, 2, "2D or 3D input tensor expected but got NULL");
+  THArgCheck(scale_factor > 1, 4,
+             "scale_factor must be greater than 1, but got: %d", scale_factor);
+  THCUNN_argCheck(state, input->nDimension == 2 || input->nDimension == 3, 2, input,
+                  "2D or 3D input tensor expected but got: %s");
+  if (input->nDimension == 2) {
+    int nChannels    = THCTensor_(size)(state, input, 0);
+    int inputWidth   = THCTensor_(size)(state, input, 1);
+    int outputWidth  = inputWidth  * scale_factor;
+    if (gradOutput != NULL) {
+      THCUNN_check_dim_size(state, gradOutput, 2, 0, nChannels);
+      THCUNN_check_dim_size(state, gradOutput, 2, 1, outputWidth);
+    }
+  } else {
+    int nBatch       = THCTensor_(size)(state, input, 0);
+    int nChannels    = THCTensor_(size)(state, input, 1);
+    int inputWidth   = THCTensor_(size)(state, input, 2);
+    int outputWidth  = inputWidth  * scale_factor;
+    if (gradOutput != NULL) {
+      THCUNN_check_dim_size(state, gradOutput, 3, 0, nBatch);
+      THCUNN_check_dim_size(state, gradOutput, 3, 1, nChannels);
+      THCUNN_check_dim_size(state, gradOutput, 3, 2, outputWidth);
+    }
+  }
+}
+
+void THNN_(TemporalUpSamplingNearest_updateOutput)(
+           THCState *state,
+           THCTensor *input,
+           THCTensor *output,
+           int scale_factor)
+{
+  THCTensor_(zero)(state, output);
+
+  THCUNN_assertSameGPU(state, 2, input, output);
+  THNN_(TemporalUpSamplingNearest_shapeCheck)(state, input, NULL, scale_factor);
+  int inputWidth  = THCTensor_(size)(state, input,  input->nDimension-1);
+  int outputWidth = inputWidth * scale_factor;
+
+  if (input->nDimension == 2) {
+    THCTensor_(resize2d)(state, output,
+                         THCTensor_(size)(state, input, 0),
+                         outputWidth);
+  } else {
+    THCTensor_(resize3d)(state, output,
+                         THCTensor_(size)(state, input, 0),
+                         THCTensor_(size)(state, input, 1),
+                         outputWidth);
+  }
+
+  input = THCTensor_(newContiguous)(state, input);
+  // number of output elements (only the temporal dimension is scaled)
+  int64_t no_elements = 1;
+  for(int i = 0; i < input->nDimension; i++){
+    no_elements *= input->size[i];
+  }
+  no_elements *= scale_factor;
+
+  int d1;
+  int d2;
+
+  if (input->nDimension == 2) {
+    d1 = output->size[0];
+    d2 = output->size[1];
+  } else {
+    d1 = output->size[1];
+    d2 = output->size[2];
+  }
+
+  real *input_data = THCTensor_(data)(state, input);
+  real *output_data = THCTensor_(data)(state, output);
+
+  // cuda blocks & threads:
+  int64_t nthreads = 256;
+  // Max number of blocks: http://en.wikipedia.org/wiki/CUDA
+  // 65535 for SM 2.x, 2^32 -1 for >= 3.0
+  // TODO: When we move to SM 3.5 we should update this
+  int64_t n_xblocks = min(max((int)ceil((float)no_elements / nthreads), 1), 65535);
+  int64_t n_yblocks = (int64_t)ceil((float)no_elements / (float)(n_xblocks * nthreads));
+  if (n_yblocks > 65535) {
+    THError("Input size is too large!  aborting");
+  }
+  dim3 blocks(n_xblocks, n_yblocks);
+  dim3 threads(nthreads);
+
+  // kernel:
+  upscale<<<blocks, threads, 0, THCState_getCurrentStream(state)>>> (input_data, output_data, no_elements, scale_factor, d1, d2);
+  THCudaCheck(cudaGetLastError());
+
+  // final cut:
+  THCTensor_(free)(state, input);
+}
+
+void THNN_(TemporalUpSamplingNearest_updateGradInput)(
+           THCState *state,
+           THCTensor *input,
+           THCTensor *gradOutput,
+           THCTensor *gradInput,
+           int scale_factor)
+{
+
+  THCUNN_assertSameGPU(state, 2, gradOutput, gradInput);
+  THNN_(TemporalUpSamplingNearest_shapeCheck)(state, input, gradOutput, scale_factor);
+  gradOutput = THCTensor_(newContiguous)(state, gradOutput);
+  THCTensor_(resizeAs)(state, gradInput, input);
+
+  THCTensor_(zero)(state, gradInput);
+
+  real *gradInput_data = THCTensor_(data)(state, gradInput);
+  real *gradOutput_data = THCTensor_(data)(state, gradOutput);
+
+  int64_t no_elements = 1;
+  for(int i = 0; i < gradInput->nDimension; i++){
+    no_elements *= gradInput->size[i];
+  }
+
+  int d1;
+  int d2;
+
+  if (gradInput->nDimension == 2) {
+    d1 = gradInput->size[0];
+    d2 = gradInput->size[1];
+  } else {
+    d1 = gradInput->size[1];
+    d2 = gradInput->size[2];
+  }
+
+  // cuda blocks & threads:
+  int64_t nthreads = 256;
+  // Max number of blocks: http://en.wikipedia.org/wiki/CUDA
+  // 65535 for SM 2.x, 2^32 -1 for >= 3.0
+  // TODO: When we move to SM 3.5 we should update this
+  int64_t n_xblocks = min(max((int)ceil((float)no_elements / nthreads), 1), 65535);
+  int64_t n_yblocks = (int64_t)ceil((float)no_elements / (float)(n_xblocks * nthreads));
+  if (n_yblocks > 65535) {
+    THError("Input size is too large!  aborting");
+  }
+  dim3 blocks(n_xblocks, n_yblocks);
+  dim3 threads(nthreads);
+
+  // kernel:
+  downscale<real, accreal> <<<blocks, threads, 0, THCState_getCurrentStream(state)>>> (gradInput_data, gradOutput_data, no_elements,
+    scale_factor, d1, d2);
+  THCudaCheck(cudaGetLastError());
+  THCTensor_(free)(state, gradOutput);
+}
+
+#endif
diff --git a/torch/lib/THNN/generic/THNN.h b/torch/lib/THNN/generic/THNN.h
index 8dda789..d7ee0a1 100644
--- a/torch/lib/THNN/generic/THNN.h
+++ b/torch/lib/THNN/generic/THNN.h
@@ -720,6 +720,32 @@
           bool featFirst,
           accreal scale);
 
+TH_API void THNN_(TemporalUpSamplingNearest_updateOutput)(
+          THNNState *state,
+          THTensor *input,
+          THTensor *output,
+          int scale_factor);
+TH_API void THNN_(TemporalUpSamplingNearest_updateGradInput)(
+          THNNState *state,
+          THTensor *input,
+          THTensor *gradOutput,
+          THTensor *gradInput,
+          int scale_factor);
+
+TH_API void THNN_(TemporalUpSamplingLinear_updateOutput)(
+          THNNState *state,
+          THTensor *input,
+          THTensor *output,
+          int outputWidth);
+TH_API void THNN_(TemporalUpSamplingLinear_updateGradInput)(
+          THNNState *state,
+          THTensor *gradOutput,
+          THTensor *gradInput,
+          int nbatch,
+          int nchannels,
+          int inputWidth,
+          int outputWidth);
+
 TH_API void THNN_(BatchNormalization_updateOutput)(
           THNNState *state,
           THTensor *input,
@@ -1203,7 +1229,7 @@
           THNNState *state,
           THTensor *input,
           THTensor *output,
-	  int outputHeight,
+          int outputHeight,
           int outputWidth);
 TH_API void THNN_(SpatialUpSamplingBilinear_updateGradInput)(
           THNNState *state,
diff --git a/torch/lib/THNN/generic/TemporalUpSamplingLinear.c b/torch/lib/THNN/generic/TemporalUpSamplingLinear.c
new file mode 100644
index 0000000..67606d7
--- /dev/null
+++ b/torch/lib/THNN/generic/TemporalUpSamplingLinear.c
@@ -0,0 +1,140 @@
+// Adapted from interp.cpp from Caffe util by Pauline Luc
+// Originally developed by George Papandreou
+
+#ifndef TH_GENERIC_FILE
+#define TH_GENERIC_FILE "generic/TemporalUpSamplingLinear.c"
+#else
+
+static inline void THNN_(TemporalUpSamplingLinear_shapeCheck)
+     (THTensor *input, THTensor *gradOutput,
+      int nBatch, int nChannels,
+      int inputWidth, int outputWidth) {
+  THArgCheck(inputWidth > 0 && outputWidth > 0, 2,
+	     "input and output sizes should be greater than 0,"
+	     " but got input (W: %d) output (W: %d)",
+	     inputWidth, outputWidth);
+  if (input != NULL) {
+    THNN_ARGCHECK(input->nDimension == 3, 2, input,
+		  "3D input tensor expected but got: %s");
+  }
+
+  if (gradOutput != NULL) {
+    THNN_CHECK_DIM_SIZE(gradOutput, 3, 0, nBatch);
+    THNN_CHECK_DIM_SIZE(gradOutput, 3, 1, nChannels);
+    THNN_CHECK_DIM_SIZE(gradOutput, 3, 2, outputWidth);
+  }
+}
+
+void THNN_(TemporalUpSamplingLinear_updateOutput)(
+    THNNState *state,
+    THTensor *input,
+    THTensor *output,
+    int outputWidth){
+
+  int nbatch = THTensor_(size)(input, 0);
+  int channels = THTensor_(size)(input, 1);
+  int inputWidth = THTensor_(size)(input, 2);
+
+  THNN_(TemporalUpSamplingLinear_shapeCheck)
+    (input, NULL,
+     nbatch, channels,
+     inputWidth, outputWidth);
+
+  input = THTensor_(newContiguous)(input);
+  THTensor_(resize3d)(output,
+                      THTensor_(size)(input, 0),
+                      THTensor_(size)(input, 1),
+                      outputWidth);
+  THTensor_(zero)(output);
+  real *idata = THTensor_(data)(input);
+  real *odata = THTensor_(data)(output);
+  channels = nbatch * channels;
+  THAssert(inputWidth > 0 && outputWidth > 0);
+  // special case: just copy
+  if (inputWidth == outputWidth) {
+    for (int w2 = 0; w2 < outputWidth; ++w2) {
+      const int w1 = w2;
+      const real* pos1 = &idata[w1];
+      real* pos2 = &odata[w2];
+      for (int c = 0; c < channels; ++c) {
+        pos2[0] = pos1[0];
+        pos1 += inputWidth;
+        pos2 += outputWidth;
+      }
+    }
+    THTensor_(free)(input);
+    return;
+  }
+  const float rwidth = (outputWidth > 1) ? (float)(inputWidth - 1) / (outputWidth - 1) : 0.f;
+  for (int w2 = 0; w2 < outputWidth; ++w2) {
+    const float w1r = rwidth * w2;
+    const int w1 = w1r;
+    const int w1p = (w1 < inputWidth - 1) ? 1 : 0;
+    const real w1lambda = w1r - w1;
+    const real w0lambda = (real)1. - w1lambda;
+    const real* pos1 = &idata[w1];
+    real* pos2 = &odata[w2];
+    for (int c = 0; c < channels; ++c) {
+      pos2[0] = w0lambda * pos1[0] + w1lambda * pos1[w1p];
+      pos1 += inputWidth;
+      pos2 += outputWidth;
+    }
+  }
+  THTensor_(free)(input);
+}
+
+void THNN_(TemporalUpSamplingLinear_updateGradInput)(
+    THNNState *state,
+    THTensor *gradOutput,
+    THTensor *gradInput,
+    int nbatch,
+    int channels,
+    int inputWidth,
+    int outputWidth){
+
+  THNN_(TemporalUpSamplingLinear_shapeCheck)
+    (NULL, gradOutput,
+     nbatch, channels,
+     inputWidth,
+     outputWidth);
+
+  THTensor_(resize3d)(gradInput, nbatch, channels, inputWidth);
+  THTensor_(zero)(gradInput);
+  gradOutput = THTensor_(newContiguous)(gradOutput);
+  real *data1 = THTensor_(data)(gradInput);
+  real *data2 = THTensor_(data)(gradOutput);
+  channels = nbatch * channels;
+
+  // special case: same-size matching grids
+  if (inputWidth == outputWidth) {
+    for (int w2 = 0; w2 < outputWidth; ++w2) {
+      const int w1 = w2;
+      real* pos1 = &data1[w1];
+      const real* pos2 = &data2[w2];
+      for (int c = 0; c < channels; ++c) {
+        pos1[0] += pos2[0];
+        pos1 += inputWidth;
+        pos2 += outputWidth;
+      }
+    }
+    THTensor_(free)(gradOutput);
+    return;
+  }
+  const float rwidth = (outputWidth > 1) ? (float)(inputWidth - 1)/(outputWidth - 1) : 0.f;
+  for (int w2 = 0; w2 < outputWidth; ++w2) {
+    const float w1r = rwidth * w2;
+    const int w1 = w1r;
+    const int w1p = (w1 < inputWidth - 1) ? 1 : 0;
+    const real w1lambda = w1r - w1;
+    const real w0lambda = (real)1. - w1lambda;
+    real* pos1 = &data1[w1];
+    const real* pos2 = &data2[w2];
+    for (int c = 0; c < channels; ++c) {
+      pos1[0] += w0lambda * pos2[0];
+      pos1[w1p] += w1lambda * pos2[0];
+      pos1 += inputWidth;
+      pos2 += outputWidth;
+    }
+  }
+  THTensor_(free)(gradOutput);
+}
+
+#endif
diff --git a/torch/lib/THNN/generic/TemporalUpSamplingNearest.c b/torch/lib/THNN/generic/TemporalUpSamplingNearest.c
new file mode 100644
index 0000000..bc8d3a8
--- /dev/null
+++ b/torch/lib/THNN/generic/TemporalUpSamplingNearest.c
@@ -0,0 +1,173 @@
+#ifndef TH_GENERIC_FILE
+#define TH_GENERIC_FILE "generic/TemporalUpSamplingNearest.c"
+#else
+
+
+static inline void THNN_(TemporalUpSamplingNearest_shapeCheck)
+     (THTensor *input, THTensor *gradOutput,
+      int scale_factor) {
+  THArgCheck(input != NULL, 2, "2D or 3D input tensor expected but got NULL");
+  THArgCheck(scale_factor > 1, 4,
+	     "scale_factor must be greater than 1, but got: %d", scale_factor);
+  THNN_ARGCHECK(input->nDimension == 2 || input->nDimension == 3, 2, input,
+		"2D or 3D input tensor expected but got: %s");
+  if (input->nDimension == 2) {
+    int nChannels    = THTensor_(size)(input, 0);
+    int inputWidth   = THTensor_(size)(input, 1);
+    int outputWidth  = inputWidth  * scale_factor;
+    if (gradOutput != NULL) {
+      THNN_CHECK_DIM_SIZE(gradOutput, 2, 0, nChannels);
+      THNN_CHECK_DIM_SIZE(gradOutput, 2, 1, outputWidth);
+    }
+  } else {
+    int nBatch       = THTensor_(size)(input, 0);
+    int nChannels    = THTensor_(size)(input, 1);
+    int inputWidth   = THTensor_(size)(input, 2);
+    int outputWidth  = inputWidth  * scale_factor;
+    if (gradOutput != NULL) {
+      THNN_CHECK_DIM_SIZE(gradOutput, 3, 0, nBatch);
+      THNN_CHECK_DIM_SIZE(gradOutput, 3, 1, nChannels);
+      THNN_CHECK_DIM_SIZE(gradOutput, 3, 2, outputWidth);
+    }
+  }
+}
+
+void THNN_(TemporalUpSamplingNearest_updateOutput)(
+    THNNState *state,
+    THTensor *input,
+    THTensor *output,
+    int scale_factor)
+{
+  THNN_(TemporalUpSamplingNearest_shapeCheck)(input, NULL, scale_factor);
+  int inputWidth  = THTensor_(size)(input,  input->nDimension-1);
+  int outputWidth = inputWidth * scale_factor;
+
+  if (input->nDimension == 2) {
+    THTensor_(resize2d)(output,
+                        THTensor_(size)(input, 0),
+                        outputWidth);
+  } else {
+    THTensor_(resize3d)(output,
+                        THTensor_(size)(input, 0),
+                        THTensor_(size)(input, 1),
+                        outputWidth);
+  }
+
+  int dW = scale_factor;
+  int xDim = input->nDimension-1;
+
+  // dims
+  int idim = input->nDimension;
+  int osz0 = output->size[0];
+  int osz1 = output->size[1];
+  int osz2 = 1;  
+  if (idim > 2) {
+    osz2 = output->size[2];
+  }
+
+  // get strides
+  int64_t *is = input->stride;
+  int64_t *os = output->stride;
+
+  // get raw pointers
+  real *pin = THTensor_(data)(input);
+  real *pout = THTensor_(data)(output);
+
+  // perform the upsampling
+  int i0, i1, i2, isrc, idst;
+  int iout[3];  // Output indices
+  int iin[3];  // Input indices
+
+  for (i0 = 0; i0 < osz0; i0++) {
+    iout[0] = i0;
+    iin[0] = i0;
+    for (i1 = 0; i1 < osz1; i1++) {
+      iout[1] = i1;
+      iin[1] = i1;
+      for (i2 = 0; i2 < osz2; i2++) {
+        iout[2] = i2;
+        iin[2] = i2;
+
+        // set the indices for the upsampled dimensions
+        iin[xDim] = iout[xDim] / dW;
+
+        idst = i0*os[0] + i1*os[1];
+        isrc = iin[0]*is[0] + iin[1]*is[1];
+        if (idim > 2) {
+          idst += i2*os[2];
+          isrc += iin[2]*is[2];
+        }
+
+        pout[idst] = pin[isrc];
+      }
+    }
+  }
+}
+
+void THNN_(TemporalUpSamplingNearest_updateGradInput)(
+    THNNState *state,
+    THTensor *input,
+    THTensor *gradOutput,
+    THTensor *gradInput,
+    int scale_factor)
+{
+  THNN_(TemporalUpSamplingNearest_shapeCheck)(input, gradOutput, scale_factor);
+  THTensor_(resizeAs)(gradInput, input);
+
+  int dW = scale_factor;
+  int xDim = gradInput->nDimension-1;
+
+  // dims
+  int idim = gradInput->nDimension;  // Guaranteed to be between 2 and 3
+  int isz0 = gradInput->size[0];
+  int isz1 = gradInput->size[1];
+  int isz2 = 1;
+  if (idim > 2) {
+    isz2 = gradInput->size[2];
+  }
+
+  // get strides
+  int64_t *is = gradInput->stride;
+  int64_t *os = gradOutput->stride;
+
+  // get raw pointers
+  real *pin = THTensor_(data)(gradInput);
+  real *pout = THTensor_(data)(gradOutput);
+
+  // perform the upsampling
+  int i0, i1, i2, isrc, idst, x;
+  int iin[3];  // Input indices
+  int iout[3];  // Output indices
+
+  THTensor_(zero)(gradInput);
+
+  for (i0 = 0; i0 < isz0; i0++) {
+    iin[0] = i0;
+    iout[0] = i0;
+    for (i1 = 0; i1 < isz1; i1++) {
+      iin[1] = i1;
+      iout[1] = i1;
+      for (i2 = 0; i2 < isz2; i2++) {
+        iin[2] = i2;
+        iout[2] = i2;
+
+        idst = i0*is[0] + i1*is[1];
+        if (idim > 2) {
+          idst += i2*is[2];
+        }
+
+        // Now accumulate the gradients from gradOutput
+        for (x = 0; x < dW; x++) {
+          iout[xDim] = dW * iin[xDim] + x;
+          isrc = iout[0]*os[0] + iout[1]*os[1];
+          if (idim > 2) {
+            isrc += iout[2]*os[2];
+          }
+          pin[idst] += pout[isrc];
+        }
+      }
+    }
+  }
+}
+
+#endif
diff --git a/torch/lib/THNN/init.c b/torch/lib/THNN/init.c
index ed28119..d0e6ba6 100644
--- a/torch/lib/THNN/init.c
+++ b/torch/lib/THNN/init.c
@@ -179,6 +179,12 @@
 #include "generic/TemporalRowConvolution.c"
 #include "THGenerateFloatTypes.h"
 
+#include "generic/TemporalUpSamplingNearest.c"
+#include "THGenerateFloatTypes.h"
+
+#include "generic/TemporalUpSamplingLinear.c"
+#include "THGenerateFloatTypes.h"
+
 #include "generic/FeatureLPPooling.c"
 #include "THGenerateFloatTypes.h"
 
diff --git a/torch/nn/_functions/thnn/upsampling.py b/torch/nn/_functions/thnn/upsampling.py
index 9ab14a1..48ecaeb 100644
--- a/torch/nn/_functions/thnn/upsampling.py
+++ b/torch/nn/_functions/thnn/upsampling.py
@@ -4,7 +4,7 @@
 from torch._thnn import type2backend
 
 from . import _all_functions
-from ...modules.utils import _pair, _triple
+from ...modules.utils import _single, _pair, _triple
 
 
 def _check_size_scale_factor(size, scale_factor):
@@ -14,6 +14,158 @@
         raise ValueError('scale_factor must be of integer type or a tuple of integer types')
 
 
+def _check_linear_scale_factor(scale_factor, dim=2):
+    if dim == 1:
+        scale_factor = _single(scale_factor)
+    elif dim == 2:
+        scale_factor = _pair(scale_factor)
+    elif dim == 3:
+        scale_factor = _triple(scale_factor)
+    else:
+        raise ValueError("dim has to be 1, 2 or 3")
+
+    try:
+        assert len(scale_factor) == 1 or len(scale_factor) == 2 or len(scale_factor) == 3
+        assert all(isinstance(s, Integral) and s >= 1 for s in scale_factor)
+    except AssertionError as e:
+        raise ValueError('scale_factor must be a positive integer, '
+                         'or a tuple of positive integers for linear, bilinear and trilinear upsampling, but got: '
+                         '{}'.format(scale_factor))
+    return scale_factor
+
+
+class UpsamplingNearest1d(Function):
+
+    @staticmethod
+    def forward(ctx, input, size=None, scale_factor=None):
+        assert input.dim() == 3
+
+        _check_size_scale_factor(size, scale_factor)
+
+        ctx.size = size
+        ctx.scale_factor = scale_factor
+
+        if ctx.scale_factor is not None and not isinstance(ctx.scale_factor, Integral):
+            raise ValueError('scale_factor must be a single Integer value for nearest neighbor sampling')
+
+        if ctx.scale_factor is None:
+            if (ctx.size[0] % input.size(2) != 0):
+                raise RuntimeError("output size specified in UpsamplingNearest "
+                                   "({}) has to be divisible by the input size, but got: "
+                                   "{}".format('x'.join(map(str, ctx.size)),
+                                               'x'.join(map(str, input.size()))))
+            ctx.scale_factor = ctx.size[0] // input.size(2)
+
+        output = input.new()
+        backend = type2backend[type(input)]
+        ctx.save_for_backward(input)
+        backend.TemporalUpSamplingNearest_updateOutput(
+            backend.library_state,
+            input,
+            output,
+            ctx.scale_factor
+        )
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, = ctx.saved_variables
+        grad_input = UpsamplingNearest1dBackward.apply(input, grad_output, ctx.scale_factor)
+        return grad_input, None, None
+
+
+class UpsamplingNearest1dBackward(Function):
+
+    @staticmethod
+    def forward(ctx, input, grad_output, scale_factor):
+        assert grad_output.dim() == 3
+        ctx.scale_factor = scale_factor
+
+        grad_input = grad_output.new()
+        backend = type2backend[type(input)]
+        backend.TemporalUpSamplingNearest_updateGradInput(
+            backend.library_state,
+            input,
+            grad_output,
+            grad_input,
+            ctx.scale_factor
+        )
+        return grad_input
+
+    @staticmethod
+    def backward(ctx, ggI):
+        gI = None
+        ggO = UpsamplingNearest1d.apply(ggI, None, ctx.scale_factor)
+
+        return gI, ggO, None
+
+
+class UpsamplingLinear1d(Function):
+
+    @staticmethod
+    def forward(ctx, input, size=None, scale_factor=None):
+        assert input.dim() == 3
+
+        ctx.size = size
+        ctx.scale_factor = scale_factor
+
+        if ctx.scale_factor is not None:
+            ctx.scale_factor = _check_linear_scale_factor(ctx.scale_factor, dim=1)
+
+        if ctx.scale_factor is not None:
+            ctx.output_size = (
+                input.size(2) * ctx.scale_factor[0],
+            )
+        else:
+            ctx.output_size = ctx.size
+
+        ctx.input_size = input.size()
+        output = input.new()
+        backend = type2backend[type(input)]
+        backend.TemporalUpSamplingLinear_updateOutput(
+            backend.library_state,
+            input,
+            output,
+            ctx.output_size[0]
+        )
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_input = UpsamplingLinear1dBackward.apply(grad_output, ctx.input_size, ctx.output_size)
+        return grad_input, None, None
+
+
+class UpsamplingLinear1dBackward(Function):
+
+    @staticmethod
+    def forward(ctx, grad_output, input_size, output_size):
+        assert grad_output.dim() == 3
+
+        ctx.input_size = input_size
+        ctx.output_size = output_size
+
+        grad_output = grad_output.contiguous()
+        grad_input = grad_output.new()
+        backend = type2backend[type(grad_output)]
+        backend.TemporalUpSamplingLinear_updateGradInput(
+            backend.library_state,
+            grad_output,
+            grad_input,
+            ctx.input_size[0],
+            ctx.input_size[1],
+            ctx.input_size[2],
+            ctx.output_size[0],
+        )
+        return grad_input
+
+    @staticmethod
+    def backward(ctx, ggI):
+        ggO = UpsamplingLinear1d.apply(ggI, ctx.output_size, None)
+
+        return ggO, None, None
+
+
 class UpsamplingNearest2d(Function):
 
     @staticmethod
@@ -84,24 +236,6 @@
         return gI, ggO, None
 
 
-def _check_linear_scale_factor(scale_factor, dim=2):
-    if dim == 2:
-        scale_factor = _pair(scale_factor)
-    elif dim == 3:
-        scale_factor = _triple(scale_factor)
-    else:
-        raise ValueError("dim has to be 2 or 3")
-
-    try:
-        assert len(scale_factor) == 2 or len(scale_factor) == 3
-        assert all(isinstance(s, Integral) and s >= 1 for s in scale_factor)
-    except AssertionError as e:
-        raise ValueError('scale_factor must be a non-negative integer, '
-                         'or a tuple of non-negative integers for bilinear and trilinear upsampling, but got: '
-                         '{}'.format(scale_factor))
-    return scale_factor
-
-
 class UpsamplingBilinear2d(Function):
 
     @staticmethod
@@ -311,6 +445,10 @@
         return ggO, None, None
 
 
+_all_functions.append(UpsamplingNearest1d)
+_all_functions.append(UpsamplingNearest1dBackward)
+_all_functions.append(UpsamplingLinear1d)
+_all_functions.append(UpsamplingLinear1dBackward)
 _all_functions.append(UpsamplingNearest2d)
 _all_functions.append(UpsamplingNearest2dBackward)
 _all_functions.append(UpsamplingBilinear2d)
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 3ff9f0f..7469db5 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -940,38 +940,50 @@
 
     The algorithm used for upsampling is determined by :attr:`mode`.
 
-    Currently spatial and volumetric upsampling are supported, i.e.
-    expected inputs are 4-D or 5-D in shape.
+    Currently temporal, spatial and volumetric upsampling are supported, i.e.
+    expected inputs are 3-D, 4-D or 5-D in shape.
 
     The input dimensions are interpreted in the form:
-    `mini-batch x channels x [depth] x height x width`
+    `mini-batch x channels x [depth] x [height] x width`
 
-    The modes available for upsampling are: `nearest`, `bilinear` (4D-only),
-    `trilinear` (5D-only)
+    The modes available for upsampling are: `nearest`, `linear` (3D-only),
+    `bilinear` (4D-only), `trilinear` (5D-only)
 
     Args:
         input (Variable): input
-        size (int or Tuple[int, int] or Tuple[int, int, int]):
+        size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]):
             output spatial size.
         scale_factor (int): multiplier for spatial size. Has to be an integer.
         mode (string): algorithm used for upsampling:
-            'nearest' | 'bilinear' | 'trilinear'. Default: 'nearest'
+            'nearest' | 'linear' | 'bilinear' | 'trilinear'. Default: 'nearest'
     """
-    if input.dim() == 4 and mode == 'nearest':
+    if input.dim() == 3 and mode == 'nearest':
+        return _functions.thnn.UpsamplingNearest1d.apply(input, _single(size), scale_factor)
+    elif input.dim() == 4 and mode == 'nearest':
         return _functions.thnn.UpsamplingNearest2d.apply(input, _pair(size), scale_factor)
     elif input.dim() == 5 and mode == 'nearest':
         return _functions.thnn.UpsamplingNearest3d.apply(input, _triple(size), scale_factor)
+    elif input.dim() == 3 and mode == 'linear':
+        return _functions.thnn.UpsamplingLinear1d.apply(input, _single(size), scale_factor)
+    elif input.dim() == 3 and mode == 'bilinear':
+        raise NotImplementedError("Got 3D input, but bilinear mode needs 4D input")
+    elif input.dim() == 3 and mode == 'trilinear':
+        raise NotImplementedError("Got 3D input, but trilinear mode needs 5D input")
+    elif input.dim() == 4 and mode == 'linear':
+        raise NotImplementedError("Got 4D input, but linear mode needs 3D input")
     elif input.dim() == 4 and mode == 'bilinear':
         return _functions.thnn.UpsamplingBilinear2d.apply(input, _pair(size), scale_factor)
     elif input.dim() == 4 and mode == 'trilinear':
         raise NotImplementedError("Got 4D input, but trilinear mode needs 5D input")
+    elif input.dim() == 5 and mode == 'linear':
+        raise NotImplementedError("Got 5D input, but linear mode needs 3D input")
     elif input.dim() == 5 and mode == 'bilinear':
         raise NotImplementedError("Got 5D input, but bilinear mode needs 4D input")
     elif input.dim() == 5 and mode == 'trilinear':
         return _functions.thnn.UpsamplingTrilinear3d.apply(input, _triple(size), scale_factor)
     else:
-        raise NotImplementedError("Input Error: Only 4D and 5D input Tensors supported"
-                                  " (got {}D) for the modes: nearest | bilinear | trilinear"
+        raise NotImplementedError("Input Error: Only 3D, 4D and 5D input Tensors supported"
+                                  " (got {}D) for the modes: nearest | linear | bilinear | trilinear"
                                   " (got {})".format(input.dim(), mode))
 
 
diff --git a/torch/nn/modules/upsampling.py b/torch/nn/modules/upsampling.py
index d26bd9d..b979404 100644
--- a/torch/nn/modules/upsampling.py
+++ b/torch/nn/modules/upsampling.py
@@ -3,30 +3,30 @@
 
 from .module import Module
 from .. import functional as F
-from .utils import _pair, _triple
 
 
 class Upsample(Module):
     """
-    Upsamples a given multi-channel 2D (spatial) or 3D (volumetric) data.
+    Upsamples given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
 
-    The input data is assumed to be of the form `minibatch x channels x [depth] x height x width`.
+    The input data is assumed to be of the form `minibatch x channels x [depth] x [height] x width`.
     Hence, for spatial inputs, we expect a 4D Tensor and for volumetric inputs, we expect a 5D Tensor.
 
-    The algorithms available for upsampling are nearest neighbor, bilinear and trilinear upsampling,
-    with bilinear only available for 4D Tensor inputs and trilinear for 4D Tensor inputs.
+    The algorithms available for upsampling are nearest neighbor, as well as linear, bilinear and
+    trilinear upsampling for 3D, 4D and 5D input Tensors, respectively.
 
     One can either give a :attr:`scale_factor` or the target output :attr:`size` to
     calculate the output size. (You cannot give both, as it is ambiguous)
 
     Args:
-        size (tuple, optional): a tuple of ints ([D_out], H_out, W_out) output sizes
+        size (tuple, optional): a tuple of ints ([D_out], [H_out], W_out) output sizes
         scale_factor (int / tuple of ints, optional): the multiplier for the image height / width / depth
-        mode (string, optional): the upsampling algorithm: nearest | bilinear | trilinear. Default: nearest
+        mode (string, optional): the upsampling algorithm: nearest | linear | bilinear | trilinear. Default: nearest
 
     Shape:
-        - Input: :math:`(N, C, H_{in}, W_{in})` or :math:`(N, C, D_{in}, H_{in}, W_{in})`
-        - Output: :math:`(N, C, H_{out}, W_{out})` or :math:`(N, C, D_{out}, H_{out}, W_{out})` where
+        - Input: :math:`(N, C, W_{in})`, :math:`(N, C, H_{in}, W_{in})` or :math:`(N, C, D_{in}, H_{in}, W_{in})`
+        - Output: :math:`(N, C, W_{out})`, :math:`(N, C, H_{out}, W_{out})`
+          or :math:`(N, C, D_{out}, H_{out}, W_{out})` where
           :math:`D_{out} = floor(D_{in} * scale\_factor)` or `size[-3]`
           :math:`H_{out} = floor(H_{in} * scale\_factor)` or `size[-2]`
           :math:`W_{out} = floor(W_{in}  * scale\_factor)` or `size[-1]`