Revert D26820202: Support mix of int32 and int64 offsets/indices for EmbeddingBag and its variants

Test Plan: revert-hammer

Differential Revision: D26820202 (https://github.com/pytorch/pytorch/commit/f9097c43b95e4fe94ec55071bcb22968bdc40e83)

Original commit changeset: 3e8f09523329
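
This restores the pre-D26820202 behavior in which EmbeddingBag expects input and offsets to share the same integer dtype (both int32 or both int64), rather than promoting a mixed combination to a common type. A minimal sketch of the restored constraint (illustrative only, not part of this diff; the exact error type and message may vary by build):

    import torch
    import torch.nn as nn

    es = nn.EmbeddingBag(5, 2, mode='sum')

    # OK: indices and offsets share the same integer dtype.
    input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=torch.long)
    offsets = torch.tensor([0, 0, 3, 3, 6], dtype=torch.long)
    out = es(input, offsets)

    # With this revert, mixing dtypes (int64 indices with int32 offsets here)
    # is expected to be rejected instead of being promoted.
    mixed_offsets = torch.tensor([0, 0, 3, 3, 6], dtype=torch.int)
    try:
        es(input, mixed_offsets)
    except RuntimeError as e:
        print("mixed indices/offsets dtypes rejected:", e)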

fbshipit-source-id: 5742b69a96ce1c848d75348d0f761cf66a69cbf3
diff --git a/aten/src/ATen/native/EmbeddingBag.cpp b/aten/src/ATen/native/EmbeddingBag.cpp
index fba51e2..5d57fce 100644
--- a/aten/src/ATen/native/EmbeddingBag.cpp
+++ b/aten/src/ATen/native/EmbeddingBag.cpp
@@ -592,22 +592,12 @@
 // See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details
 std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_cpu_impl(
     const Tensor& weight,
-    const Tensor& indices_,
-    const Tensor& offsets_,
+    const Tensor& indices,
+    const Tensor& offsets,
     const int64_t mode,
     const Tensor& per_sample_weights,
     bool include_last_offset,
     bool requires_grad) {
-  Tensor indices = indices_;
-  Tensor offsets = offsets_;
-  const auto commonType =
-      promoteTypes(offsets.scalar_type(), indices.scalar_type());
-  if (indices.scalar_type() != commonType) {
-    indices = indices.toType(commonType);
-  }
-  if (offsets.scalar_type() != commonType) {
-    offsets = offsets.toType(commonType);
-  }
 
   check_arguments(weight, indices, offsets, mode, per_sample_weights, include_last_offset);
 
@@ -695,8 +685,8 @@
 
 // Assumes all input tensors are contiguous.
 // See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details
-Tensor _embedding_bag_backward(const Tensor &grad, const Tensor &indices_,
-                              const Tensor &offsets_,
+Tensor _embedding_bag_backward(const Tensor &grad, const Tensor &indices,
+                              const Tensor &offsets,
                               const Tensor &offset2bag,
                               const Tensor &bag_size_,
                               const Tensor &max_indices_,
@@ -706,17 +696,6 @@
   // See [Note: hacky wrapper removal for optional tensor]
   const Tensor& per_sample_weights = c10::value_or_else(per_sample_weights_opt, [] {return Tensor();});
 
-  Tensor indices = indices_;
-  Tensor offsets = offsets_;
-  const auto commonType =
-      promoteTypes(offsets.scalar_type(), indices.scalar_type());
-  if (indices.scalar_type() != commonType) {
-    indices = indices.toType(commonType);
-  }
-  if (offsets.scalar_type() != commonType) {
-    offsets = offsets.toType(commonType);
-  }
-
   auto indices_arg = TensorArg(indices, "indices", 1);
   checkScalarTypes("embedding_bag", indices_arg, {kLong, kInt});
   checkContiguous("embedding_bag", indices_arg);
@@ -936,8 +915,8 @@
 Tensor _embedding_bag_per_sample_weights_backward_cpu_template(
     const Tensor& grad,
     const Tensor& weight,  // NB: embedding table, not per_sample_weights
-    const Tensor& indices_,
-    const Tensor& offsets_,
+    const Tensor& indices,
+    const Tensor& offsets,
     const Tensor& offset2bag,
     int64_t mode) {
   TORCH_CHECK(
@@ -947,17 +926,6 @@
   AT_ASSERT(grad.dim() == 2);
   auto embedding_features = grad.sizes()[1];
 
-  Tensor indices = indices_;
-  Tensor offsets = offsets_;
-  const auto commonType =
-      promoteTypes(offsets.scalar_type(), indices.scalar_type());
-  if (indices.scalar_type() != commonType) {
-    indices = indices.toType(commonType);
-  }
-  if (offsets.scalar_type() != commonType) {
-    offsets = offsets.toType(commonType);
-  }
-
   AT_ASSERT(indices.dim() == 1);
   auto num_samples = indices.sizes()[0];
 
diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
index 1b98ab3..b192f0b 100644
--- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp
+++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
@@ -479,28 +479,21 @@
   return self.clone(at::MemoryFormat::Preserve).index_copy_(dim, index, source);
 }
 
-Tensor& index_add_cpu_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source_, const Scalar &alpha) {
+
+Tensor& index_add_cpu_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source, const Scalar &alpha) {
   dim = maybe_wrap_dim(dim, self.dim());
 
   auto numel = index.numel();
   TORCH_CHECK_INDEX(index.dim() <= 1, "index_add_(): Index is supposed to be a vector");
   TORCH_CHECK(index.scalar_type() == ScalarType::Long || index.scalar_type() == ScalarType::Int,
           "index_add_(): Expected dtype int32/int64 for index");
-  TORCH_CHECK(dim == 0 || dim < source_.dim(),
+  TORCH_CHECK(self.scalar_type() == source.scalar_type(),
+              "index_add_(): self and source must have the same scalar type");
+  TORCH_CHECK(dim == 0 || dim < source.dim(),
               "index_add_(): Indexing dim ", dim, " is out of bounds of tensor");
-  TORCH_CHECK(numel == (source_.dim() == 0 ? 1 : source_.size(dim)),
+  TORCH_CHECK(numel == (source.dim() == 0 ? 1 : source.size(dim)),
               "index_add_(): Number of indices should be equal to self.size(dim)");
 
-  Tensor source = source_;
-  const auto selfType = self.scalar_type();
-  const auto commonType = promoteTypes(selfType, source.scalar_type());
-  auto promoteToType = [&](Tensor& tensor, ScalarType type) {
-    if (tensor.scalar_type() != type) {
-      tensor = tensor.toType(type);
-    }
-  };
-  promoteToType(source, commonType);
-
   at::assert_no_internal_overlap(self);
   at::assert_no_overlap(self, index);
   at::assert_no_overlap(self, source);
@@ -518,8 +511,6 @@
     if (numel == 0) {
       return self;
     }
-
-    promoteToType(self, commonType);
     auto selfSlice = self.select(dim, 0);
     auto sourceSlice = source.select(dim, 0);
     auto self_stride_bytes = self.stride(dim) * elementSize(self.scalar_type());
@@ -544,7 +535,6 @@
   else {
     TORCH_CHECK(source.dim() <= 1, "source.dim() (", source.dim(), ") must one or zero for given self.dim() (", self.dim(), ")");
 
-    promoteToType(self, commonType);
     // explicitly capture all required variables to work around windows build
     // TODO: fix this when windows can correctly capture variables in nested lambda
     AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16,
@@ -567,7 +557,6 @@
       });
     });
   }
-  promoteToType(self, selfType);
   return self;
 }
 
diff --git a/aten/src/ATen/native/cuda/EmbeddingBag.cu b/aten/src/ATen/native/cuda/EmbeddingBag.cu
index 977affd..ae13be2 100644
--- a/aten/src/ATen/native/cuda/EmbeddingBag.cu
+++ b/aten/src/ATen/native/cuda/EmbeddingBag.cu
@@ -276,24 +276,13 @@
 // Assumes all input tensors are contiguous.
 // See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details
 std::tuple<Tensor, Tensor, Tensor, Tensor>
-_embedding_bag_cuda(const Tensor &weight, const Tensor &indices_,
-                   const Tensor &offsets_, const bool scale_grad_by_freq,
+_embedding_bag_cuda(const Tensor &weight, const Tensor &indices,
+                   const Tensor &offsets, const bool scale_grad_by_freq,
                    const int64_t mode, bool sparse, const c10::optional<Tensor>& per_sample_weights_opt,
                    bool include_last_offset) {
   // See [Note: hacky wrapper removal for optional tensor]
   const Tensor& per_sample_weights = c10::value_or_else(per_sample_weights_opt, [] {return Tensor();});
 
-  Tensor indices = indices_;
-  Tensor offsets = offsets_;
-  const auto commonType =
-      promoteTypes(offsets.scalar_type(), indices.scalar_type());
-  if (indices.scalar_type() != commonType) {
-    indices = indices.toType(commonType);
-  }
-  if (offsets.scalar_type() != commonType) {
-    offsets = offsets.toType(commonType);
-  }
-
   auto indices_arg = TensorArg(indices, "indices", 1);
   checkScalarTypes("embedding_bag_cuda", indices_arg, {kLong, kInt});
   auto offsets_arg = TensorArg(offsets, "offsets", 1);
@@ -447,8 +436,8 @@
 Tensor _embedding_bag_per_sample_weights_backward_cuda(
     const Tensor& grad,
     const Tensor& weight,  // NB: embedding table, not per_sample_weights
-    const Tensor& indices_,
-    const Tensor& offsets_,
+    const Tensor& indices,
+    const Tensor& offsets,
     const Tensor& offset2bag,
     int64_t mode) {
   TORCH_CHECK(
@@ -458,17 +447,6 @@
   AT_ASSERT(grad.dim() == 2);
   auto embedding_features = grad.size(1);
 
-  Tensor indices = indices_;
-  Tensor offsets = offsets_;
-  const auto commonType =
-      promoteTypes(offsets.scalar_type(), indices.scalar_type());
-  if (indices.scalar_type() != commonType) {
-    indices = indices.toType(commonType);
-  }
-  if (offsets.scalar_type() != commonType) {
-    offsets = offsets.toType(commonType);
-  }
-
   AT_ASSERT(indices.dim() == 1);
   auto num_samples = indices.size(0);
 
diff --git a/test/test_nn.py b/test/test_nn.py
index 8b4d431..73ed85c 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -13573,27 +13573,27 @@
         self.assertRaises(RuntimeError, lambda: F.adaptive_max_pool2d(t, []))
         self.assertRaises(RuntimeError, lambda: F.adaptive_max_pool3d(t, []))
 
-    @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long)))
-    def test_embedding_bag_empty_input(self, device, dtypes):
+    @dtypes(torch.int, torch.long)
+    def test_embedding_bag_empty_input(self, device, dtype):
         m = 4
         n = 3
-        x = torch.tensor([], device=device, dtype=dtypes[0])
+        x = torch.tensor([], device=device, dtype=dtype)
         for sparse in [True, False]:
             Embed = torch.nn.EmbeddingBag(m, n, sparse=sparse)
             Embed.to(device)
 
-            output = Embed(input=x, offsets=torch.tensor([0], device=device, dtype=dtypes[1]))
+            output = Embed(input=x, offsets=torch.tensor([0], device=device, dtype=dtype))
             self.assertEqual(output, torch.zeros_like(output))
 
-            output = Embed(input=x, offsets=torch.tensor([0, 0], device=device, dtype=dtypes[1]))
+            output = Embed(input=x, offsets=torch.tensor([0, 0], device=device, dtype=dtype))
             self.assertEqual(output, torch.zeros_like(output))
 
-    @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long)))
-    def test_EmbeddingBag_per_sample_weights_failures(self, device, dtypes):
+    @dtypes(torch.int, torch.long)
+    def test_EmbeddingBag_per_sample_weights_failures(self, device, dtype):
         # Failure 1: mismatched embeddings / per_sample_weights dtype
         es = nn.EmbeddingBag(5, 2, mode='sum').to(dtype=torch.float, device=device)
-        input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=dtypes[0], device=device)
-        offsets = torch.tensor([0, 0, 3, 3, 6], dtype=dtypes[1], device=device)
+        input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=dtype, device=device)
+        offsets = torch.tensor([0, 0, 3, 3, 6], dtype=dtype, device=device)
         per_sample_weights = torch.randn_like(input, dtype=torch.double, device=device)
         if device == 'cpu':
             with self.assertRaisesRegex(RuntimeError, 'have the same type as'):
@@ -13603,14 +13603,14 @@
                 es(input, offsets, per_sample_weights)
 
         # Failure 2.1: input/per_sample_weights have different sizes (1d input)
-        input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=dtypes[0], device=device)
-        offsets = torch.tensor([0, 0, 3, 3, 6], dtype=dtypes[1], device=device)
+        input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=dtype, device=device)
+        offsets = torch.tensor([0, 0, 3, 3, 6], dtype=dtype, device=device)
         per_sample_weights = torch.randn(5, dtype=torch.float, device=device)
         with self.assertRaisesRegex(ValueError, 'same shape as the input'):
             es(input, offsets, per_sample_weights)
 
         # Failure 2.2: input/per_sample_weights have different sizes (2d input)
-        input = torch.randint(5, (7, 3), dtype=dtypes[0], device=device)
+        input = torch.randint(5, (7, 3), dtype=dtype, device=device)
         offsets = None
         per_sample_weights = torch.randn(7 * 3, dtype=torch.float, device=device)
         with self.assertRaisesRegex(ValueError, 'same shape as the input'):
@@ -13620,7 +13620,7 @@
         for unsupported_mode in ('max', 'mean'):
             es = nn.EmbeddingBag(5, 2, mode=unsupported_mode).to(
                 dtype=torch.float, device=device)
-            input = torch.randint(5, (7, 3), dtype=dtypes[0], device=device)
+            input = torch.randint(5, (7, 3), dtype=dtype, device=device)
             offsets = None
             per_sample_weights = torch.randn(7, 3, dtype=torch.float, device=device)
             with self.assertRaisesRegex(NotImplementedError,
@@ -13682,18 +13682,18 @@
                         bags.append(embeddings.narrow(0, offset, length).max(0)[0])
         return torch.stack(bags)
 
-    @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.int, torch.long), (torch.float, torch.double, torch.half)))
-    @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long), (torch.float, torch.double)))
+    @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.float, torch.double, torch.half)))
+    @dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double)))
     def test_EmbeddingBag_empty_per_sample_weights_and_offsets(self, device, dtypes):
         # Test empty input and per sample weight, and backward pass. There was a CUDA
         # invalid configuration bug (more context in #46572)
         def test_per_sample_weights(mode, trainable_scale):
-            es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtypes[2], device=device)
+            es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtypes[1], device=device)
             es.weight.data.copy_(
-                torch.arange(1, 11, device=device, dtype=dtypes[2]).view_as(es.weight))
+                torch.arange(1, 11, device=device, dtype=dtypes[1]).view_as(es.weight))
             input = torch.tensor([], device=device, dtype=dtypes[0])
-            offsets = torch.tensor([0, 0, 0, 0, 0], device=device, dtype=dtypes[1])
-            per_sample_weights = torch.randn_like(input, dtype=dtypes[2]) \
+            offsets = torch.tensor([0, 0, 0, 0, 0], device=device, dtype=dtypes[0])
+            per_sample_weights = torch.randn_like(input, dtype=dtypes[1]) \
                                       .requires_grad_(trainable_scale)
             ref_per_sample_weights = \
                 per_sample_weights.detach().requires_grad_(trainable_scale)
@@ -13702,7 +13702,7 @@
             expected = self._embedding_bag_reference_impl(
                 input, reference_weights, offsets, mode, ref_per_sample_weights)
             result = es(input, offsets, per_sample_weights)
-            self.assertEqual(result, expected, atol=dtype2prec_DONTUSE[dtypes[2]], rtol=0)
+            self.assertEqual(result, expected, atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0)
 
             grad = torch.randn_like(expected)
             result.backward(grad)
@@ -13710,27 +13710,27 @@
             # simply be a zero tensor
             ref_weights_grad = torch.zeros_like(es.weight)
             self.assertEqual(es.weight.grad, ref_weights_grad,
-                             atol=dtype2prec_DONTUSE[dtypes[2]], rtol=0)
+                             atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0)
             if trainable_scale:
                 ref_per_sample_weights_grad = torch.empty_like(per_sample_weights)
                 self.assertEqual(per_sample_weights.grad, ref_per_sample_weights_grad,
-                                 atol=dtype2prec_DONTUSE[dtypes[2]], rtol=0)
+                                 atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0)
 
         modes = ('sum',)
         trainable_scale = (True, False)
         for mode, trainable in itertools.product(modes, trainable_scale):
             test_per_sample_weights(mode, trainable)
 
-    @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.int, torch.long), (torch.float, torch.double, torch.half)))
-    @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long), (torch.float, torch.double)))
+    @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.float, torch.double, torch.half)))
+    @dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double)))
     def test_EmbeddingBag_per_sample_weights_and_offsets(self, device, dtypes):
         def test_per_sample_weights(mode, trainable_scale):
-            es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtypes[2], device=device)
+            es = nn.EmbeddingBag(5, 2, mode=mode).to(dtype=dtypes[1], device=device)
             es.weight.data.copy_(
-                torch.arange(1, 11, device=device, dtype=dtypes[2]).view_as(es.weight))
+                torch.arange(1, 11, device=device, dtype=dtypes[1]).view_as(es.weight))
             input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=dtypes[0])
-            offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=dtypes[1])
-            per_sample_weights = torch.randn_like(input, dtype=dtypes[2]) \
+            offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=dtypes[0])
+            per_sample_weights = torch.randn_like(input, dtype=dtypes[1]) \
                                       .requires_grad_(trainable_scale)
             ref_per_sample_weights = \
                 per_sample_weights.detach().requires_grad_(trainable_scale)
@@ -13739,37 +13739,37 @@
             expected = self._embedding_bag_reference_impl(
                 input, reference_weights, offsets, mode, ref_per_sample_weights)
             result = es(input, offsets, per_sample_weights)
-            self.assertEqual(result, expected, atol=dtype2prec_DONTUSE[dtypes[2]], rtol=0)
+            self.assertEqual(result, expected, atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0)
 
-            grad = torch.randn_like(expected).to(dtype=dtypes[2], device=device)
+            grad = torch.randn_like(expected)
             result.backward(grad)
             expected.backward(grad)
             self.assertEqual(es.weight.grad, reference_weights.grad,
-                             atol=dtype2prec_DONTUSE[dtypes[2]], rtol=0)
+                             atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0)
             if trainable_scale:
                 self.assertEqual(per_sample_weights.grad, ref_per_sample_weights.grad,
-                                 atol=dtype2prec_DONTUSE[dtypes[2]], rtol=0)
+                                 atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0)
 
         modes = ('sum',)
         trainable_scale = (True, False)
         for mode, trainable in itertools.product(modes, trainable_scale):
             test_per_sample_weights(mode, trainable)
 
-    @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.int, torch.long), (torch.float, torch.double, torch.half)))
-    @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long), (torch.float, torch.double)))
+    @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.float, torch.double, torch.half)))
+    @dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double)))
     def test_EmbeddingBag_per_sample_weights_and_new_offsets(self, device, dtypes):
         def test_per_sample_weights_new_offsets(mode, trainable_scale, include_last_offset, has_weight=True):
-            es = nn.EmbeddingBag(5, 2, mode=mode, include_last_offset=include_last_offset).to(dtype=dtypes[2], device=device)
+            es = nn.EmbeddingBag(5, 2, mode=mode, include_last_offset=include_last_offset).to(dtype=dtypes[1], device=device)
             es.weight.data.copy_(
-                torch.arange(1, 11, device=device, dtype=dtypes[2]).view_as(es.weight))
+                torch.arange(1, 11, device=device, dtype=dtypes[1]).view_as(es.weight))
             input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=dtypes[0])
-            offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=dtypes[1])
+            offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=dtypes[0])
 
             if include_last_offset:
-                offsets = torch.cat((offsets, torch.tensor([input.size(0)], device=device, dtype=dtypes[1])), 0)
+                offsets = torch.cat((offsets, torch.tensor([input.size(0)], device=device, dtype=dtypes[0])), 0)
 
             if has_weight:
-                per_sample_weights = torch.randn_like(input, device=device, dtype=dtypes[2]) \
+                per_sample_weights = torch.randn_like(input, device=device, dtype=dtypes[1]) \
                                           .requires_grad_(trainable_scale)
                 ref_per_sample_weights = \
                     per_sample_weights.detach().requires_grad_(trainable_scale)
@@ -13782,16 +13782,16 @@
             expected = self._embedding_bag_reference_impl(
                 input, reference_weights, offsets, mode, ref_per_sample_weights, include_last_offset)
             result = es(input, offsets, per_sample_weights)
-            self.assertEqual(result, expected, atol=dtype2prec_DONTUSE[dtypes[2]], rtol=0)
+            self.assertEqual(result, expected, atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0)
 
             grad = torch.randn_like(expected)
             result.backward(grad)
             expected.backward(grad)
             self.assertEqual(es.weight.grad, reference_weights.grad,
-                             atol=dtype2prec_DONTUSE[dtypes[2]], rtol=0)
+                             atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0)
             if has_weight and trainable_scale:
                 self.assertEqual(per_sample_weights.grad, ref_per_sample_weights.grad,
-                                 atol=dtype2prec_DONTUSE[dtypes[2]], rtol=0)
+                                 atol=dtype2prec_DONTUSE[dtypes[1]], rtol=0)
 
         trainable_scale = (True, False)
         include_last_offset = (True, False)
@@ -13858,7 +13858,7 @@
 
         # We have more floating point error here because we are dealing with larger numbers
         if backward_prec is None:
-            needed_prec = dtype2prec_DONTUSE[wdtype] * 5
+            needed_prec = dtype2prec_DONTUSE[wdtype] * 3
         else:
             needed_prec = backward_prec
 
@@ -13905,21 +13905,12 @@
                     itertools.product(modes, sparsity, trainable_scale):
                 run_tests(mode, sparse, trainable_per_sample_weights)
 
-    def _test_EmbeddingBag(
-        self,
-        device,
-        mode,
-        sparse,
-        wdtype=torch.double,
-        dtype=torch.long,
-        odtype=torch.long,
-        test_backward=True,
-    ):
+    def _test_EmbeddingBag(self, device, mode, sparse, wdtype=torch.double, dtype=torch.long, test_backward=True):
         # check a known test example
         es = nn.EmbeddingBag(5, 2, mode=mode, sparse=sparse).to(device, wdtype)
         es.weight.data.copy_(torch.arange(1, 11, device=device, dtype=wdtype).view_as(es.weight))
         input = torch.tensor([3, 1, 1, 1, 4, 0], device=device, dtype=dtype)
-        offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=odtype)
+        offsets = torch.tensor([0, 0, 3, 3, 6], device=device, dtype=dtype)
 
         grad_output = torch.tensor(
             [1, 2,
@@ -13992,7 +13983,7 @@
         # test all empty bags
         es.zero_grad()
         inputs = torch.tensor([], dtype=dtype, device=device)
-        offsets = torch.tensor([0, 0, 0, 0], dtype=odtype, device=device)
+        offsets = torch.tensor([0, 0, 0, 0], dtype=dtype, device=device)
         es(inputs, offsets).sum().backward()
         dense_grad = es.weight.grad
         if dense_grad.is_sparse:
@@ -14010,7 +14001,7 @@
         # check that giving illegal input combos raises error
         es = nn.EmbeddingBag(10, 20, mode=mode, sparse=sparse)
         input = torch.ones(3, 4, dtype=dtype)
-        offset = torch.arange(0, 3, dtype=odtype)
+        offset = torch.arange(0, 3, dtype=dtype)
         self.assertRaises(ValueError, lambda: es(input, offset))
         self.assertRaises(ValueError, lambda: es(input.view(-1)))
         offset[0] = 1
@@ -14020,51 +14011,35 @@
             offset[-1] = 100
             self.assertRaises(RuntimeError, lambda: es(input.view(-1), offset))
 
-    @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.int, torch.long), (torch.float, torch.double, torch.half)))
-    @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long), (torch.float, torch.double)))
+    @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.float, torch.double, torch.half)))
+    @dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double)))
     def test_embedding_bag_device(self, device, dtypes):
-        self._test_EmbeddingBag(device, 'sum', False, wdtype=dtypes[2], dtype=dtypes[0], odtype=dtypes[1])
-        self._test_EmbeddingBag(device, 'mean', False, wdtype=dtypes[2], dtype=dtypes[0], odtype=dtypes[1])
-        self._test_EmbeddingBag(device, 'max', False, wdtype=dtypes[2], dtype=dtypes[0], odtype=dtypes[1])
+        self._test_EmbeddingBag(device, 'sum', False, wdtype=dtypes[1], dtype=dtypes[0])
+        self._test_EmbeddingBag(device, 'mean', False, wdtype=dtypes[1], dtype=dtypes[0])
+        self._test_EmbeddingBag(device, 'max', False, wdtype=dtypes[1], dtype=dtypes[0])
 
         test_backward = False
         if self.device_type == 'cuda':
             # see 'todo' in test_embedding_bag.
-            test_backward = dtypes[2] is not torch.float16
+            test_backward = dtypes[1] is not torch.float16
         elif self.device_type == 'cpu':
             # TODO: figure out why precision on sparse embeddings isn't the
             # same as for dense.
-            test_backward = dtypes[2] is not torch.float
+            test_backward = dtypes[1] is not torch.float
 
-        self._test_EmbeddingBag(
-            device,
-            'sum',
-            True,
-            wdtype=dtypes[2],
-            dtype=dtypes[0],
-            odtype=dtypes[1],
-            test_backward=test_backward,
-        )
-        self._test_EmbeddingBag(
-            device,
-            'mean',
-            True,
-            wdtype=dtypes[2],
-            dtype=dtypes[0],
-            odtype=dtypes[1],
-            test_backward=test_backward,
-        )
+        self._test_EmbeddingBag(device, 'sum', True, wdtype=dtypes[1], dtype=dtypes[0], test_backward=test_backward)
+        self._test_EmbeddingBag(device, 'mean', True, wdtype=dtypes[1], dtype=dtypes[0], test_backward=test_backward)
 
-    @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.int, torch.long), (torch.float, torch.double, torch.half)))
-    @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long), (torch.float, torch.double)))
+    @dtypesIfCUDA(*itertools.product((torch.int, torch.long), (torch.float, torch.double, torch.half)))
+    @dtypes(*itertools.product((torch.int, torch.long), (torch.float, torch.double)))
     def test_embedding_bag_non_contiguous_weight(self, device, dtypes):
-        weight_tensor = torch.randn(3, 4, dtype=dtypes[2], device=device)
+        weight_tensor = torch.randn(3, 4, dtype=dtypes[1], device=device)
 
         weight_tensor_non_contig = weight_tensor[:, :3]  # This is non-contiguous strided.
         weight_tensor_contig = weight_tensor_non_contig.clone().contiguous()  # Contig-strided.
 
         index = torch.tensor([0, 1, 2], dtype=dtypes[0], device=device)
-        offsets = torch.tensor([0, 2], dtype=dtypes[1], device=device)
+        offsets = torch.tensor([0, 2], dtype=dtypes[0], device=device)
         for mode in ['sum', 'mean', 'max']:
             output_non_contig = F.embedding_bag(
                 input=index,
@@ -14082,10 +14057,10 @@
 
 
     @onlyCUDA
-    @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long)))
-    def test_embedding_bag_bfloat16(self, device, dtypes):
-        self._test_EmbeddingBag(device, 'sum', True, wdtype=torch.bfloat16, dtype=dtypes[0], odtype=dtypes[1], test_backward=True)
-        self._test_EmbeddingBag(device, 'mean', True, wdtype=torch.bfloat16, dtype=dtypes[0], odtype=dtypes[1], test_backward=True)
+    @dtypes(torch.int, torch.long)
+    def test_embedding_bag_bfloat16(self, device, dtype):
+        self._test_EmbeddingBag(device, 'sum', True, wdtype=torch.bfloat16, dtype=dtype, test_backward=True)
+        self._test_EmbeddingBag(device, 'mean', True, wdtype=torch.bfloat16, dtype=dtype, test_backward=True)
 
 
     @onlyCUDA
diff --git a/torch/nn/modules/sparse.py b/torch/nn/modules/sparse.py
index 19487f8..283afb7 100644
--- a/torch/nn/modules/sparse.py
+++ b/torch/nn/modules/sparse.py
@@ -262,6 +262,8 @@
     Inputs: :attr:`input` (IntTensor or LongTensor), :attr:`offsets` (IntTensor or LongTensor, optional), and
         :attr:`per_index_weights` (Tensor, optional)
 
+        - :attr:`input` and :attr:`offsets` have to be of the same type, either int or long
+
         - If :attr:`input` is 2D of shape `(B, N)`,
 
           it will be treated as ``B`` bags (sequences) each of fixed length ``N``, and