improve EmbeddingBag performance on cuda (#33589)

Summary:
This PR improves performance of EmbeddingBag on cuda by removing 5 kernel launches (2 of those are synchronizing memcopies).
- 2 memcopies are checking values of offsets[0] and offsets[-1] to be in expected range (0 for the former, less than number of indices for the latter). It seems strange to check only those 2 values, if users are providing invalid offsets, invalid values can be anywhere in the array, not only the first and last element. After this PR, the checks are skipped on cuda, the first value is forced to 0, if the last value is larger than expected, cuda kernel will assert. It is less nice than ValueError, but then again, the kernel could have asserted if other offset values were invalid. On the cpu, the checks are moved inside the cpu implementation from functional.py, and will throw RuntimeError instead of ValueError.
- 3 or 4 initializations (depending on the mode) of the output tensors with .zeros() are unnecessary, because every element of those tensors is written to, so their data can be uninitialized on the start.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33589

Reviewed By: jianyuh

Differential Revision: D20078011

Pulled By: ngimel

fbshipit-source-id: 2fb2e2080313af64adc5cf1b9fc6ffbdc6efaf16
diff --git a/aten/src/ATen/native/EmbeddingBag.cpp b/aten/src/ATen/native/EmbeddingBag.cpp
index 7ac3111..29cf2da 100644
--- a/aten/src/ATen/native/EmbeddingBag.cpp
+++ b/aten/src/ATen/native/EmbeddingBag.cpp
@@ -374,6 +374,14 @@
   checkScalarType("embedding_bag", offsets_arg, kLong);
   auto weight_arg = TensorArg(weight, "weight", 1);
   checkScalarTypes("embedding_bag", weight_arg, {kFloat, kDouble});
+  int64_t offset_0 = offsets.data_ptr<int64_t>()[0];
+  int64_t offset_n = offsets.data_ptr<int64_t>()[offsets.size(0)-1];
+  TORCH_CHECK(offset_0 == 0, "offsets[0] has to be 0, i.e., the first sequence "
+                             "in the mini-batch has to start from position 0. "
+                             "However, got ", offsets[0]);
+  TORCH_CHECK(offset_n <= indices.size(0), "offsets[-1] can not "
+               "be greater than input's length ", indices.size(0), " but got offsets[-1] of ",
+               offset_n);
 
   if (per_sample_weights.defined()) {
     TORCH_CHECK(mode == MODE_SUM,
@@ -381,8 +389,8 @@
     auto per_input_weights_arg = TensorArg(
         per_sample_weights,"per_sample_weights", 1);
     checkSameType("embedding_bag", weight_arg, per_input_weights_arg);
-    AT_ASSERT(per_sample_weights.dim() == 1);
-    AT_ASSERT(per_sample_weights.numel() == indices.numel());
+    TORCH_CHECK(per_sample_weights.dim() == 1);
+    TORCH_CHECK(per_sample_weights.numel() == indices.numel());
   }
 
   auto bag_size = make_bag_size(offsets, indices, mode, weight.requires_grad());
diff --git a/aten/src/ATen/native/cuda/EmbeddingBag.cu b/aten/src/ATen/native/cuda/EmbeddingBag.cu
index a138d98..6fdb5c6 100644
--- a/aten/src/ATen/native/cuda/EmbeddingBag.cu
+++ b/aten/src/ATen/native/cuda/EmbeddingBag.cu
@@ -52,7 +52,7 @@
     if (featureDim < featureSize) {
       int64_t bag = chunk / chunksPerBag;
       scalar_t *weightFeat = weight + featureDim * weight_stride1;
-      int64_t begin = offsets[bag];
+      int64_t begin = bag == 0 ? 0 : offsets[bag]; // forces first offset to be 0 instead of asserting on it
       int64_t end = (bag < numBags - 1) ? (offsets[bag + 1]) : numIndices;
       assert(end >= begin);
 
@@ -276,21 +276,21 @@
   }
   int64_t featureSize = weight.size(1);
 
-  auto bag_size = at::zeros(offsets.sizes(), indices.options());
+  auto bag_size = at::empty(offsets.sizes(), indices.options());
   auto offset2bag =
-      at::zeros({indices.size(0)}, indices.options()); // offset2bag = [0 0 0 0 0]
+      at::empty({indices.size(0)}, indices.options()); // offset2bag = [0 0 0 0 0]
 
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
-  auto output = at::zeros({numBags, featureSize}, weight.options());
+  auto output = at::empty({numBags, featureSize}, weight.options());
 
   Tensor max_indices;
 
   if (mode == MODE_MAX) {
-    max_indices = at::zeros({numBags, featureSize}, indices.options());
+    max_indices = at::empty({numBags, featureSize}, indices.options());
   } else {
     // No need to allocate if we aren't doing a backwards pass
-    max_indices = at::zeros({0}, indices.options());
+    max_indices = at::empty({0}, indices.options());
   }
 
 #ifdef __HIP_PLATFORM_HCC__
diff --git a/test/test_nn.py b/test/test_nn.py
index 66bb99c..6a66454 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -10132,7 +10132,6 @@
                  [0, 0],
                  [1, 2],
                  [3, 4]], device=device, dtype=dtype)
-
         output = es(input, offsets)
         output.backward(grad_output_with_empty)
 
@@ -10174,15 +10173,16 @@
 
         # check that giving illegal input combos raises error
         es = nn.EmbeddingBag(10, 20, mode=mode, sparse=sparse)
-        input = torch.ones(3, 4)
+        input = torch.ones(3, 4, dtype=torch.long)
         offset = torch.arange(0, 3)
         self.assertRaises(ValueError, lambda: es(input, offset))
         self.assertRaises(ValueError, lambda: es(input.view(-1)))
         offset[0] = 1
-        self.assertRaises(ValueError, lambda: es(input.view(-1), offset))
-        offset[0] = 0
-        offset[-1] = 100
-        self.assertRaises(ValueError, lambda: es(input.view(-1), offset))
+        if self.device_type == "cpu":
+            self.assertRaises(RuntimeError, lambda: es(input.view(-1), offset))
+            offset[0] = 0
+            offset[-1] = 100
+            self.assertRaises(RuntimeError, lambda: es(input.view(-1), offset))
 
     @dtypesIfCUDA(torch.half, torch.float, torch.double)
     @dtypes(torch.float, torch.double)
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 7361ad1..03a6e2f 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -1820,15 +1820,7 @@
         if offsets is None:
             raise ValueError("offsets has to be a 1D Tensor but got None")
         if offsets.dim() != 1:
-            raise ValueError("offsets has to be a 1D Tensor")
-        if int(offsets[0]) != 0:
-            raise ValueError("offsets[0] has to be 0, i.e., the first sequence "
-                             "in the mini-batch has to start from position 0. "
-                             "However, got {}".format(offsets[0].item()))
-        if int(offsets[-1]) > input.size(0):
-            raise ValueError("offsets[-1] can not be greater than input's length"
-                             " ({}), but got offsets[-1] of {}"
-                             .format(input.size(0), offsets[-1].item()))
+            raise ValueError("offsets has to be a 1D Tensor")        
     else:
         raise ValueError("input has to be 1D or 2D Tensor,"
                          " but got Tensor of dimension {}".format(input.dim()))