Updates div to perform true division (#42907)

Summary:
This PR:

- updates div to perform true division
- makes torch.true_divide an alias of torch.div

This follows up on work in previous PyTorch releases, which first deprecated div performing "integer" or "floor" division and then disallowed it by throwing a runtime error.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/42907

Reviewed By: ngimel

Differential Revision: D23622114

Pulled By: mruberry

fbshipit-source-id: 414c7e3c1a662a6c3c731ad99cc942507d843927
diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index abfd9c2..3ba93d5 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -292,8 +292,6 @@
 _(aten, digamma) \
 _(aten, dim) \
 _(aten, dist) \
-_(aten, div) \
-_(aten, div_) \
 _(aten, dot) \
 _(aten, dropout) \
 _(aten, dstack) \
diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index 394a902..3caac54 100644
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -188,6 +188,10 @@
   _(aten, __isnot__)                 \
   _(aten, copy)                      \
   _(aten, copy_)                     \
+  _(aten, div)                       \
+  _(aten, div_)                      \
+  _(aten, true_divide)               \
+  _(aten, true_divide_)              \
   _(aten, t_)                        \
   _(aten, addbmm_)                   \
   _(aten, addcdiv_)                  \
diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp
index 6d4391f..049c368 100644
--- a/aten/src/ATen/native/BinaryOps.cpp
+++ b/aten/src/ATen/native/BinaryOps.cpp
@@ -122,29 +122,14 @@
 }
 
 Tensor& div_out(Tensor& result, const Tensor& self, const Tensor& other) {
-  if (isIntegralType(result.scalar_type(), /*includeBool=*/ true)) {
-    TORCH_CHECK(false,
-      "Integer division of tensors using div or / is no longer supported, ",
-      "and in a future release div will perform true division as in Python 3. ",
-      "Use true_divide or floor_divide (// in Python) instead.");
-  }
-
-  auto iter = TensorIterator::binary_op(result, self, other);
+  auto iter = TensorIterator::binary_float_op(result, self, other);
   div_stub(iter.device_type(), iter);
   return result;
 }
 
 Tensor div(const Tensor& self, const Tensor& other) {
-  if (isIntegralType(self.scalar_type(), /*includeBool=*/ true)
-      && isIntegralType(other.scalar_type(), /*includeBool=*/ true)) {
-    TORCH_CHECK(false,
-      "Integer division of tensors using div or / is no longer supported, ",
-      "and in a future release div will perform true division as in Python 3. ",
-      "Use true_divide or floor_divide (// in Python) instead.");
-  }
-
   Tensor result;
-  auto iter = TensorIterator::binary_op(result, self, other);
+  auto iter = TensorIterator::binary_float_op(result, self, other);
   div_stub(iter.device_type(), iter);
   return iter.output();
 }
@@ -153,6 +138,20 @@
   return native::div_out(self, self, other);
 }
 
+// WARNING: There doesn't appear to be any testing for this function
+// with sparse self input.
+Tensor div(const Tensor& self, Scalar other) {
+  return self.div(wrapped_scalar_tensor(other)); // redispatch!
+}
+
+// WARNING: This function, with a sparse self, is currently only
+// exercised by DistributedDataParallelTest.test_sparse_gradients
+// (you need to exercise it from C++, because this overload is never
+// used for Python)
+Tensor& div_(Tensor& self, Scalar other) {
+  return self.div_(wrapped_scalar_tensor(other)); // redispatch!
+}
+
 Tensor& remainder_out(Tensor& result, const Tensor& self, const Tensor& other) {
   auto iter = TensorIterator::binary_op(result, self, other);
   remainder_stub(iter.device_type(), iter);
@@ -170,47 +169,25 @@
   return native::remainder_out(self, self, other);
 }
 
+// true_divide, an alias for div
 Tensor& true_divide_out(Tensor& result, const Tensor& self, const Tensor& divisor) {
-  TensorIterator iter = TensorIteratorConfig()
-     .add_output(result)
-     .add_input(self)
-     .add_input(divisor)
-     .allow_cpu_scalars(true)
-     .promote_inputs_to_common_dtype(true)
-     .promote_integer_inputs_to_float(true)
-     .cast_common_dtype_to_outputs(true)
-     .enforce_safe_casting_to_output(true)
-     .build();
-
-  div_stub(iter.device_type(), iter);
-  return result;
+  return native::div_out(result, self, divisor);
 }
 
 Tensor true_divide(const Tensor& self, const Tensor& divisor) {
-  // If both inputs have integral (or bool) types, creates
-  // temporary float copies as new inputs and sets the result's type to
-  // the default scalar type
-  if (isIntegralType(self.scalar_type(), /*includeBool=*/ true)
-   && isIntegralType(divisor.scalar_type(), /*includeBool=*/ true)) {
-    const auto scalar_type = typeMetaToScalarType(c10::get_default_dtype());
-    Tensor result = at::empty({0}, self.options().dtype(scalar_type));
-    auto iter = TensorIterator::binary_op(result,
-                                          self.to(scalar_type),
-                                          divisor.to(scalar_type));
-    div_stub(iter.device_type(), iter);
-    return result;
-  }
-
-  // If at least one input is non-integral (or bool) participates in
-  // type promotion like other binary ufuncs
-  Tensor result;
-  auto iter = TensorIterator::binary_op(result, self, divisor);
-  div_stub(iter.device_type(), iter);
-  return iter.output();
+  return self.div(divisor);
 }
 
 Tensor& true_divide_(Tensor& self, const Tensor& divisor) {
-  return native::true_divide_out(self, self, divisor);
+  return self.div_(divisor);
+}
+
+Tensor true_divide(const Tensor& self, Scalar divisor) {
+  return self.div(divisor);
+}
+
+Tensor& true_divide_(Tensor& self, Scalar divisor) {
+  return self.div_(divisor);
 }
 
 Tensor& floor_divide_out(Tensor& result, const Tensor& self, const Tensor& other) {
@@ -401,20 +378,6 @@
   return native::add_(self, wrapped_scalar_tensor(other), alpha);
 }
 
-// WARNING: There doesn't appear to be any testing for this function
-// with sparse self input.
-Tensor div(const Tensor& self, Scalar other) {
-  return self.div(wrapped_scalar_tensor(other)); // redispatch!
-}
-
-// WARNING: This function, with a sparse self, is currently only
-// exercised by DistributedDataParallelTest.test_sparse_gradients
-// (you need to exercise it from C++, because this overload is never
-// used for Python)
-Tensor& div_(Tensor& self, Scalar other) {
-  return self.div_(wrapped_scalar_tensor(other)); // redispatch!
-}
-
 Tensor remainder(const Tensor& self, Scalar other) {
   Tensor other_tensor = wrapped_scalar_tensor(other);
   // FIXME: 'other' is converted to match the dtype of 'self' to retain
@@ -978,14 +941,6 @@
   return at::nextafter_out(self, self, other);
 }
 
-Tensor true_divide(const Tensor& self, Scalar divisor) {
-  return self.true_divide(wrapped_scalar_tensor(divisor)); // redispatch!
-}
-
-Tensor& true_divide_(Tensor& self, Scalar divisor) {
-  return self.true_divide_(wrapped_scalar_tensor(divisor)); // redispatch!
-}
-
 // Note: this function is only for testing.
 // It is undocumented and should not be used outside of tests.
 Tensor _test_serialization_subcmul(const Tensor& self, const Tensor& other, Scalar alpha) {
diff --git a/aten/src/ATen/native/PointwiseOps.cpp b/aten/src/ATen/native/PointwiseOps.cpp
index 72d2719..3c8dc15 100644
--- a/aten/src/ATen/native/PointwiseOps.cpp
+++ b/aten/src/ATen/native/PointwiseOps.cpp
@@ -73,11 +73,12 @@
     TORCH_CHECK(false,
       "Integer division with addcdiv is no longer supported, and in a future  ",
       "release addcdiv will perform a true division of tensor1 and tensor2. ",
-      "The historic addcdiv behavior can be implemented using floor_divide ",
-      "for integral inputs (self + value * tensor1 // tensor2) and ",
-      "division for float inputs (self + value * tensor1 / tensor2). ",
-      "The future addcdiv behavior can be implemented with true_divide ",
-      "(self + value * torch.true_divide(tensor1, tensor2).");
+      "The historic addcdiv behavior can be implemented as ",
+      "(input + value * torch.trunc(tensor1 / tensor2)).to(input.dtype) ",
+      "for integer inputs and as ",
+      "(input + value * tensor1 / tensor2) for float inputs. ",
+      "The future addcdiv behavior is just the latter implementation: ",
+      "(input + value * tensor1 / tensor2), for all dtypes.");
   }
   checkBackend("addcdiv_cpu", result, self.options().backend());
   auto iter = at::TensorIteratorConfig()
diff --git a/aten/src/ATen/native/TensorIterator.cpp b/aten/src/ATen/native/TensorIterator.cpp
index e58ce5d..71b5527 100644
--- a/aten/src/ATen/native/TensorIterator.cpp
+++ b/aten/src/ATen/native/TensorIterator.cpp
@@ -829,6 +829,22 @@
      .build();
 }
 
+// Helper to construct a binary op that promotes integer inputs to float.
+TensorIterator TensorIterator::binary_float_op(Tensor& out, const Tensor& a,
+    const Tensor& b) {
+  return TensorIteratorConfig()
+     .set_check_mem_overlap(true)
+     .add_output(out)
+     .add_input(a)
+     .add_input(b)
+     .allow_cpu_scalars(true)
+     .promote_inputs_to_common_dtype(true)
+     .cast_common_dtype_to_outputs(true)
+     .enforce_safe_casting_to_output(true)
+     .promote_integer_inputs_to_float(true)
+     .build();
+}
+
 TensorIterator TensorIterator::comparison_op(Tensor& out, const Tensor& a,
     const Tensor& b) {
   return TensorIteratorConfig()
diff --git a/aten/src/ATen/native/TensorIterator.h b/aten/src/ATen/native/TensorIterator.h
index 118fd08..f15af0b 100644
--- a/aten/src/ATen/native/TensorIterator.h
+++ b/aten/src/ATen/native/TensorIterator.h
@@ -156,6 +156,7 @@
 
   void foreach_reduced_elt(loop_subiter_t loop, bool parallelize=true);
 
+  static TensorIterator binary_float_op(Tensor& out, const Tensor& a, const Tensor& b);
   static TensorIterator binary_op(Tensor& out, const Tensor& a, const Tensor& b);
   static TensorIterator comparison_op(Tensor& out, const Tensor& a, const Tensor& b);
   static TensorIterator unary_op(Tensor& out, const Tensor& a);
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 6625529..d26d1d9 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1241,6 +1241,25 @@
   use_c10_dispatcher: full
   variants: method
 
+  # true_divide, an alias for div
+- func: true_divide.Tensor(Tensor self, Tensor other) -> Tensor
+  use_c10_dispatcher: full
+  variants: function, method
+
+- func: true_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  use_c10_dispatcher: full
+  variants: method
+
+- func: true_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+
+- func: true_divide.Scalar(Tensor self, Scalar other) -> Tensor
+  use_c10_dispatcher: full
+  variants: function, method
+
+- func: true_divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
+  use_c10_dispatcher: full
+  variants: method
+
 - func: dot(Tensor self, Tensor tensor) -> Tensor
   use_c10_dispatcher: full
   variants: function, method
@@ -3268,33 +3287,6 @@
 - func: triplet_margin_loss(Tensor anchor, Tensor positive, Tensor negative, float margin=1.0, float p=2, float eps=1e-06, bool swap=False, int reduction=Mean) -> Tensor
   use_c10_dispatcher: full
 
-- func: true_divide.Tensor(Tensor self, Tensor other) -> Tensor
-  use_c10_dispatcher: full
-  variants: function, method
-  dispatch:
-    CPU, CUDA: true_divide
-    SparseCPU, SparseCUDA: true_divide_sparse
-
-- func: true_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
-  use_c10_dispatcher: full
-  variants: method
-  dispatch:
-    CPU, CUDA: true_divide_
-    SparseCPU, SparseCUDA: true_divide_sparse_
-
-- func: true_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
-  dispatch:
-    CPU, CUDA: true_divide_out
-    SparseCPU, SparseCUDA: true_divide_out_sparse_zerodim
-
-- func: true_divide.Scalar(Tensor self, Scalar other) -> Tensor
-  use_c10_dispatcher: full
-  variants: function, method
-
-- func: true_divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
-  use_c10_dispatcher: full
-  variants: method
-
 - func: trunc(Tensor self) -> Tensor
   use_c10_dispatcher: full
   variants: function, method
diff --git a/aten/src/ATen/native/sparse/SparseTensorMath.cpp b/aten/src/ATen/native/sparse/SparseTensorMath.cpp
index 4cedc8a..9eee5e0 100644
--- a/aten/src/ATen/native/sparse/SparseTensorMath.cpp
+++ b/aten/src/ATen/native/sparse/SparseTensorMath.cpp
@@ -207,6 +207,9 @@
 
 Tensor div_sparse(const Tensor& self, const Tensor& value) {
   auto commonDtype = at::result_type(self, value);
+  if (c10::isIntegralType(commonDtype, /*include_bool=*/true)) {
+    commonDtype = typeMetaToScalarType(at::get_default_dtype());
+  }
   Tensor result = at::empty({0}, self.options().dtype(commonDtype));
   return div_out_sparse_zerodim(result, self, value);
 }
@@ -220,64 +223,6 @@
 }
 
 // --------------------------------------------------------------------
-// true_divide(SparseTensor, Scalar)
-// --------------------------------------------------------------------
-
-SparseTensor& true_divide_out_sparse_zerodim(
-    SparseTensor& result,
-    const SparseTensor& dividend,
-    const Tensor& divisor) {
-  TORCH_CHECK(divisor.dim() == 0, "Sparse true division requires a scalar or ",
-    "zero-dim dense tensor divisor (got shape ", divisor.sizes(), " for divisor)");
-  TORCH_CHECK(!divisor.is_sparse(), "Sparse true division requires a scalar or ",
-    "zero-dim dense tensor divisor (got a sparse divisor)");
-
-  AT_ASSERT(result.is_sparse());
-  AT_ASSERT(dividend.is_sparse());
-
-  // Short-circuits if result and dividend are the same tensor
-  if (is_same_tensor(result, dividend)) {
-    Tensor result_values = result._values();
-    at::true_divide_out(result_values, result_values, divisor);
-  } else {
-    Tensor dividend_tmp = dividend;
-    result.resize_as_(dividend_tmp);
-    auto indices = result._indices();
-    indices.resize_as_(dividend_tmp._indices());
-    indices.copy_(dividend_tmp._indices());
-    Tensor result_values = result._values();
-    at::true_divide_out(result_values, dividend_tmp._values(), divisor);
-    get_sparse_impl(result)->set_nnz_and_narrow(dividend_tmp._nnz());
-    result._coalesced_(dividend_tmp.is_coalesced());
-  }
-
-  return result;
-}
-
-Tensor true_divide_sparse(const Tensor& self, const Tensor& value) {
-  auto commonDtype = at::result_type(self, value);
-
-  // Ensures floating dtype
-  if (isIntegralType(commonDtype, /*includeBool=*/ true)) {
-    commonDtype = typeMetaToScalarType(c10::get_default_dtype());
-  }
-
-  Tensor result = at::empty({0}, self.options().dtype(commonDtype));
-  return true_divide_out_sparse_zerodim(result, self, value);
-}
-
-SparseTensor& true_divide_out_sparse_scalar(
-    SparseTensor& result,
-    const SparseTensor& dividend,
-    Scalar divisor) {
-  return true_divide_out_sparse_zerodim(result, dividend, wrapped_scalar_tensor(divisor));
-}
-
-Tensor& true_divide_sparse_(Tensor& self, const Tensor& divisor) {
-  return true_divide_out_sparse_zerodim(self, self, divisor);
-}
-
-// --------------------------------------------------------------------
 // floor_divide(SparseTensor, Scalar)
 // --------------------------------------------------------------------
 
@@ -385,7 +330,7 @@
 
 Tensor mv_sparse(const SparseTensor& self, const Tensor& vec)
 {
-  TORCH_CHECK(self.ndimension() == 2 && 
+  TORCH_CHECK(self.ndimension() == 2 &&
               vec.ndimension() == 1,
               "mv: two tensor dim should be 2 and 1, but got ",
               "SparseTensor Dim: ", self.ndimension(), "Tensor Dim: ", vec.ndimension());
diff --git a/test/jit/test_save_load.py b/test/jit/test_save_load.py
index c1ea2b0..9f11731 100644
--- a/test/jit/test_save_load.py
+++ b/test/jit/test_save_load.py
@@ -128,8 +128,7 @@
         except Exception as e:
             self.skipTest("Failed to load fixture!")
 
-        self._verify_no("aten::div", v3_module)
-        self._verify_count("aten::true_divide", v3_module, 3)
+        self._verify_count("aten::div", v3_module, 3)  # true_divide aliases to div
         self._verify_count("aten::floor_divide", v3_module, 3)
 
         current_module = self._save_load_module(MyModule)
@@ -172,8 +171,7 @@
         except Exception as e:
             self.skipTest("Failed to load fixture!")
 
-        self._verify_no("aten::div", v3_module)
-        self._verify_count("aten::true_divide", v3_module, 1)
+        self._verify_count("aten::div", v3_module, 1)  # true_divide aliases to div
         self._verify_count("aten::floor_divide", v3_module, 1)
 
         current_module = self._save_load_module(MyModule)
@@ -218,8 +216,7 @@
         except Exception as e:
             self.skipTest("Failed to load fixture!")
 
-        self._verify_no("aten::div", v3_module)
-        self._verify_count("aten::true_divide", v3_module, 1)
+        self._verify_count("aten::div", v3_module, 1)  # true_divide aliases to div
         self._verify_count("aten::floor_divide", v3_module, 1)
 
         current_module = self._save_load_module(MyModule)
@@ -278,8 +275,7 @@
             self.skipTest("Failed to load fixture!")
 
         for m in (v3_module_float, v3_module_int):
-            self._verify_no("aten::div", m)
-            self._verify_count("aten::true_divide", m, 1)
+            self._verify_count("aten::div", m, 1)  # true_divide aliases to div
             self._verify_count("aten::floor_divide", m, 1)
 
         current_module_float = self._save_load_module(MyModuleFloat)
@@ -414,8 +410,7 @@
             self.skipTest("Failed to load fixture!")
 
         for m in (v3_module_float, v3_module_int):
-            self._verify_no("aten::div", m)
-            self._verify_count("aten::true_divide", m, 1)
+            self._verify_count("aten::div", m, 1)  # true_divide aliases to div
             self._verify_count("aten::floor_divide", m, 1)
 
         current_module_float = self._save_load_module(MyModuleFloat)
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 9e96b25..f4aae94 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -870,25 +870,24 @@
         x = torch.randn(2, 3, 4)
         self.run_test(FloordivModule(), (x,))
 
-    def test_true_div(self):
-        class TrueDivModule(torch.nn.Module):
+    def test_div(self):
+        class DivModule(torch.nn.Module):
             def forward(self, x, y):
-                return torch.true_divide(x, y)
+                return x / y
 
         x = torch.randn(2, 3, 4).to(torch.int)
         y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)
-        self.run_test(TrueDivModule(), (x, y))
-        self.run_test(TrueDivModule(), (x.float(), y))
-        self.run_test(TrueDivModule(), (x.to(torch.short), y.to(torch.short)))
+        self.run_test(DivModule(), (x, y))
+        self.run_test(DivModule(), (x.float(), y.float()))
 
-    # Note: true_divide cannot (generally) be exported via scripting
+    # Note: div cannot (generally) be exported via scripting
     # since its type promotion logic is dependent on knowing the scalar types
     # of the input tensors. That is, the ONNX graph is dependent on the
     # data type of the inputs. This makes it appropriate for tracing only.
-    def test_true_div_trace(self):
-        class TrueDivModule(torch.nn.Module):
+    def test_div_promotion_trace(self):
+        class DivModule(torch.nn.Module):
             def forward(self, x, y):
-                return torch.true_divide(x, y)
+                return x / y
 
         x = torch.randn(2, 3, 4).to(torch.int)
         y = torch.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4).to(torch.int)
@@ -896,10 +895,10 @@
         prev_default = torch.get_default_dtype()
 
         torch.set_default_dtype(torch.float)
-        self.run_test(torch.jit.trace(TrueDivModule(), (x, y)), (x, y))
+        self.run_test(torch.jit.trace(DivModule(), (x, y)), (x, y))
 
         torch.set_default_dtype(torch.double)
-        self.run_test(torch.jit.trace(TrueDivModule(), (x, y)), (x, y))
+        self.run_test(torch.jit.trace(DivModule(), (x, y)), (x, y))
 
         torch.set_default_dtype(prev_default)
 
diff --git a/test/test_foreach.py b/test/test_foreach.py
index be9d5f7..8369ba5 100644
--- a/test/test_foreach.py
+++ b/test/test_foreach.py
@@ -91,7 +91,7 @@
         if device == 'cpu':
             if dtype == torch.half:
                 with self.assertRaisesRegex(RuntimeError, r"\"addcmul_cpu_out\" not implemented for \'Half\'"):
-                    self._test_pointwise_op(device, dtype, torch._foreach_addcmul, 
+                    self._test_pointwise_op(device, dtype, torch._foreach_addcmul,
                                             torch._foreach_addcmul_, torch.addcmul)
                 return
 
@@ -100,7 +100,7 @@
     @dtypes(*torch.testing.get_all_dtypes(include_bfloat16=False, include_bool=False, include_complex=False))
     def test_addcdiv(self, device, dtype):
         if dtype in [torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8]:
-            with self.assertRaisesRegex(RuntimeError, 
+            with self.assertRaisesRegex(RuntimeError,
                                         "Integer division with addcdiv is no longer supported, and in a future"):
                 self._test_pointwise_op(device, dtype, torch._foreach_addcdiv, torch._foreach_addcdiv_, torch.addcdiv)
             return
@@ -108,7 +108,7 @@
         if device == 'cpu':
             if dtype == torch.half:
                 with self.assertRaisesRegex(RuntimeError, r"\"addcdiv_cpu_out\" not implemented for \'Half\'"):
-                    self._test_pointwise_op(device, dtype, torch._foreach_addcdiv, 
+                    self._test_pointwise_op(device, dtype, torch._foreach_addcdiv,
                                             torch._foreach_addcdiv_, torch.addcdiv)
                 return
         self._test_pointwise_op(device, dtype, torch._foreach_addcdiv, torch._foreach_addcdiv_, torch.addcdiv)
@@ -131,7 +131,7 @@
         self.assertEqual(res, expected)
 
         if dtype in [torch.bool]:
-            with self.assertRaisesRegex(RuntimeError, 
+            with self.assertRaisesRegex(RuntimeError,
                                         "result type Long can't be cast to the desired output type Bool"):
                 torch._foreach_add_(tensors, int_scalar)
         else:
@@ -144,7 +144,7 @@
         float_scalar = 1.
 
         # float scalar + integral tensor will result in float tensor
-        if dtype in [torch.uint8, torch.int8, torch.int16, 
+        if dtype in [torch.uint8, torch.int8, torch.int16,
                      torch.int32, torch.int64, torch.bool]:
             expected = [torch.ones(10, 10, device=device, dtype=torch.float32) for _ in range(10)]
         else:
@@ -153,7 +153,7 @@
         res = torch._foreach_add(tensors, float_scalar)
         self.assertEqual(res, expected)
 
-        if dtype in [torch.uint8, torch.int8, torch.int16, 
+        if dtype in [torch.uint8, torch.int8, torch.int16,
                      torch.int32, torch.int64, torch.bool]:
             self.assertRaises(RuntimeError, lambda: torch._foreach_add_(tensors, float_scalar))
         else:
@@ -169,7 +169,7 @@
         expected = [torch.add(complex_scalar, torch.zeros(10, 10, device=device, dtype=dtype)) for _ in range(10)]
 
         if dtype in [torch.float16, torch.float32, torch.float64, torch.bfloat16] and device == 'cuda:0':
-            # value cannot be converted to dtype without overflow: 
+            # value cannot be converted to dtype without overflow:
             self.assertRaises(RuntimeError, lambda: torch._foreach_add_(tensors, complex_scalar))
             self.assertRaises(RuntimeError, lambda: torch._foreach_add(tensors, complex_scalar))
             return
@@ -198,7 +198,7 @@
 
     @dtypes(*torch.testing.get_all_dtypes())
     def test_add_with_different_size_tensors(self, device, dtype):
-        if dtype == torch.bool: 
+        if dtype == torch.bool:
             return
         tensors = [torch.zeros(10 + n, 10 + n, device=device, dtype=dtype) for n in range(10)]
         expected = [torch.ones(10 + n, 10 + n, device=device, dtype=dtype) for n in range(10)]
@@ -222,14 +222,14 @@
         expected = [torch.tensor([[[2, 2, 2]], [[2, 2, 2]]], dtype=dtype, device=device)]
 
         # bool tensor + 1 will result in int64 tensor
-        if dtype == torch.bool: 
+        if dtype == torch.bool:
             expected[0] = expected[0].to(torch.int64).add(1)
 
         res = torch._foreach_add(tensors, 1)
         self.assertEqual(res, expected)
 
     def test_bin_op_scalar_with_different_tensor_dtypes(self, device):
-        tensors = [torch.tensor([1.1], dtype=torch.float, device=device), 
+        tensors = [torch.tensor([1.1], dtype=torch.float, device=device),
                    torch.tensor([1], dtype=torch.long, device=device)]
         self.assertRaises(RuntimeError, lambda: torch._foreach_add(tensors, 1))
 
@@ -279,7 +279,7 @@
             with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
                 torch._foreach_add_([tensor1], [tensor2])
 
-        # Coresponding tensors with different sizes 
+        # Coresponding tensors with different sizes
         tensors1 = [torch.zeros(10, 10, device=device) for _ in range(10)]
         tensors2 = [torch.ones(11, 11, device=device) for _ in range(10)]
         with self.assertRaisesRegex(RuntimeError, "Corresponding tensors in lists must have the same size"):
@@ -311,8 +311,11 @@
     @dtypes(*torch.testing.get_all_dtypes())
     def test_div_list(self, device, dtype):
         if dtype in torch.testing.integral_types_and(torch.bool):
-            with self.assertRaisesRegex(RuntimeError, "Integer division of tensors using div or / is no longer"):
-                self._test_bin_op_list(device, dtype, torch._foreach_div, torch._foreach_div_, torch.div)
+            if self.device_type == 'cpu':
+                with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to the desired output type"):
+                    self._test_bin_op_list(device, dtype, torch._foreach_div, torch._foreach_div_, torch.div)
+            else:
+                self.skipTest("Skipped! See https://github.com/pytorch/pytorch/issues/44489")
             return
 
         self._test_bin_op_list(device, dtype, torch._foreach_div, torch._foreach_div_, torch.div)
@@ -321,7 +324,7 @@
         tensors1 = []
         tensors2 = []
 
-        for bin_op in self.bin_ops: 
+        for bin_op in self.bin_ops:
             # Empty lists
             with self.assertRaises(RuntimeError):
                 bin_op(tensors1, tensors2)
@@ -364,7 +367,7 @@
         torch._foreach_add_([tensor1], [tensor2])
         self.assertEqual(res, [tensor1])
 
-        # non contiguous 
+        # non contiguous
         tensor1 = torch.randn(5, 2, 1, 3, device=device)[:, 0]
         tensor2 = torch.randn(5, 2, 1, 3, device=device)[:, 0]
         self.assertFalse(tensor1.is_contiguous())
diff --git a/test/test_jit.py b/test/test_jit.py
index fdfd29f..fdb4690 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -15355,6 +15355,11 @@
     'test_nn_max_pool1d_with_indices',
 }
 
+EXCLUDE_ALIAS = {
+    # aliases, which may appear in method_tests but are tested elsewhere
+    'true_divide',
+}
+
 def check_alias_annotation(method_name, args, kwargs):
     formals, tensors, actuals = get_script_args(args)
     call = get_call(method_name, 'method', actuals, kwargs)
@@ -15525,6 +15530,10 @@
     if 'complex' in variant_name:
         return
 
+    # Skips aliases, which are tested in test_op_aliases.py
+    if name in EXCLUDE_ALIAS:
+        return
+
     basic_test_name = 'test_' + name
     if variant_name != '':
         basic_test_name += '_' + variant_name
diff --git a/test/test_op_aliases.py b/test/test_op_aliases.py
index c48fed5..738b4d1 100644
--- a/test/test_op_aliases.py
+++ b/test/test_op_aliases.py
@@ -128,6 +128,14 @@
               lambda d: torch.randn(20, device=d),
               get_args=lambda d: (torch.randn(20, device=d),),
               decorators=(onlyCPU,)),
+    # NOTE: only runs on CPU because it leaks CUDA memory
+    #   (see https://github.com/pytorch/pytorch/issues/43119)
+    AliasInfo('true_divide', torch.true_divide, 'div', torch.div,
+              lambda d: torch.randn(20, device=d), get_args=lambda d: (torch.rand(20, device=d) + .1,),
+              decorators=(onlyCPU,)),
+    AliasInfo('true_divide_', torch.Tensor.true_divide_, 'div_', torch.Tensor.div_,
+              lambda d: torch.randn(20, device=d), get_args=lambda d: (torch.rand(20, device=d) + .1,),
+              decorators=(onlyCPU,)),
 )
 
 # Placeholder test class for validating that aliases are correctly
diff --git a/test/test_sparse.py b/test/test_sparse.py
index 338cf4a..1057282 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -1256,7 +1256,7 @@
             ({'dtype': torch.double, 'p': 'fro'},
              ValueError, r'dtype argument is not supported in frobenius norm'),
             ({'dtype': torch.double, 'p': 0},
-             RuntimeError, r"norm_sparse currently does not support 'dtype' argument") 
+             RuntimeError, r"norm_sparse currently does not support 'dtype' argument")
         ]
         x = self._gen_sparse(3, 10, 100)[0]
         for kwargs, err, msg in kwarg_error_pairs:
@@ -1376,14 +1376,6 @@
         self.assertEqual(self.safeToDense(y1), expected)
         self.assertEqual(self.safeToDense(y2), expected)
 
-        y1 = torch.true_divide(x1, 37.5)
-        y2 = x1.clone()
-        if y2.dtype.is_floating_point or y2.dtype.is_complex:
-            y2.true_divide_(37.5)
-        expected = torch.true_divide(self.safeToDense(x1), 37.5)
-        self.assertEqual(self.safeToDense(y1), expected)
-        self.assertEqual(self.safeToDense(y2), expected)
-
         y1 = x1 // 37.5
         y2 = x1.clone()
         y2.floor_divide_(37.5)
@@ -2367,15 +2359,6 @@
                                lambda: torch.tensor(1., device=self.device).to_sparse()
                                / torch.tensor(1., device=self.device).to_sparse())
 
-    def test_true_divide_by_sparse_error(self):
-        def fn():
-            x = torch.tensor(1., device=self.device).to_sparse()
-            y = torch.tensor(1., device=self.device).to_sparse()
-            torch.true_divide(x, y)
-
-        self.assertRaisesRegex(RuntimeError, 'Sparse true division requires',
-                               fn)
-
     def test_floor_divide_by_sparse_error(self):
         self.assertRaisesRegex(RuntimeError, 'Sparse floor division requires',
                                lambda: torch.tensor(1., device=self.device).to_sparse()
diff --git a/test/test_torch.py b/test/test_torch.py
index c70c4ae..860f1b3 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -16795,24 +16795,26 @@
 
         self.compare_with_numpy(torch.reciprocal, np.reciprocal, vals, device, dtype)
 
-    @onlyCPU
     @dtypes(torch.bfloat16, torch.float)
     def test_div(self, device, dtype):
-        m1 = torch.randn(10, 10, dtype=torch.float, device=device).to(dtype=dtype)
-        res1 = m1.clone()
-        res1[:, 3].div_(2)
-        res2 = m1.clone()
-        for i in range(m1.size(0)):
-            res2[i, 3] = res2[i, 3] / 2
-        self.assertEqual(res1, res2)
+        for op, method, inplace in ((torch.div, torch.Tensor.div, torch.Tensor.div_),
+                                    (torch.true_divide, torch.Tensor.true_divide,
+                                     torch.Tensor.true_divide_)):
+            m1 = torch.randn(10, 10, dtype=torch.float, device=device).to(dtype=dtype)
+            res1 = m1.clone()
+            inplace(res1[:, 3], 2)  # in-place variant applied to column 3 of res1
+            res2 = m1.clone()
+            for i in range(m1.size(0)):
+                res2[i, 3] = res2[i, 3] / 2  # reference: same column, one element at a time
+            self.assertEqual(res1, res2)
 
-        if dtype == torch.bfloat16:
-            a1 = torch.tensor([4.2, 6.2], dtype=dtype, device=device)
-            a2 = torch.tensor([2., 2.], dtype=dtype, device=device)
-            self.assertEqual(a1 / a2,
-                             torch.tensor([2.1, 3.1], dtype=dtype, device=device),
-                             atol=0.01, rtol=0)
-            self.assertEqual(a1.div(a2), a1 / a2)
+            if dtype == torch.bfloat16:  # bfloat16 only: check known quotients with an absolute tolerance
+                a1 = torch.tensor([4.2, 6.2], dtype=dtype, device=device)
+                a2 = torch.tensor([2., 2.], dtype=dtype, device=device)
+                self.assertEqual(op(a1, a2),
+                                 torch.tensor([2.1, 3.1], dtype=dtype, device=device),
+                                 atol=0.01, rtol=0)
+                self.assertEqual(method(a1, a2), op(a1, a2))  # method form must agree with function form
 
     @onlyCUDA
     @dtypes(torch.half)
@@ -17864,7 +17866,7 @@
 
     @onlyCPU
     @dtypes(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
-    def test_div_zero(self, device, dtype):
+    def test_floor_divide_zero(self, device, dtype):
         a = torch.tensor([0, 1], dtype=dtype, device=device)
         b = torch.tensor([0, 1], dtype=dtype, device=device)
         with self.assertRaisesRegex(RuntimeError, 'ZeroDivisionError'):
diff --git a/test/test_type_promotion.py b/test/test_type_promotion.py
index d0f5225..9ee90c7 100644
--- a/test/test_type_promotion.py
+++ b/test/test_type_promotion.py
@@ -256,7 +256,7 @@
         # supported dtype
         dtypes1 = torch.testing.get_all_math_dtypes('cuda')
         dtypes2 = torch.testing.get_all_math_dtypes(device)
-        ops = [torch.add, torch.sub, torch.mul, torch.true_divide, torch.rsub]
+        ops = [torch.add, torch.sub, torch.mul, torch.div, torch.rsub]
         for dt1, dt2 in itertools.product(dtypes1, dtypes2):
             for op, non_contiguous in itertools.product(ops, [True, False]):
                 common_dtype = torch.promote_types(dt1, dt2)
@@ -590,58 +590,61 @@
 
     @dtypes(torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
     @float_double_default_dtype
-    def test_true_divide(self, device, dtype):
-        dividend = (torch.randn(5, device=device) * 100).to(dtype)
-        divisor = torch.arange(1, 6, device=device).to(dtype)
+    def test_div_promotion(self, device, dtype):
+        for op in (torch.div, torch.true_divide):
+            dividend = (torch.randn(5, device=device) * 100).to(dtype)
+            divisor = torch.arange(1, 6, device=device).to(dtype)
 
-        # Tests tensor/tensor division
-        casting_result = dividend.to(torch.get_default_dtype()) / divisor.to(torch.get_default_dtype())
-        self.assertEqual(casting_result, torch.true_divide(dividend, divisor))
+            # Tensor / tensor: op must equal dividing after casting both operands to the default dtype
+            casting_result = dividend.to(torch.get_default_dtype()) / divisor.to(torch.get_default_dtype())
+            self.assertEqual(casting_result, op(dividend, divisor))
 
-        # Tests tensor/scalar division
-        casting_result = dividend.to(torch.get_default_dtype()) / 2
-        self.assertEqual(casting_result, torch.true_divide(dividend, 2.))
+            # Tensor / scalar: op with a Python float must equal the cast dividend divided by 2
+            casting_result = dividend.to(torch.get_default_dtype()) / 2
+            self.assertEqual(casting_result, op(dividend, 2.))
 
     @onlyOnCPUAndCUDA
     @dtypes(torch.float, torch.double,
             torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
-    def test_true_divide_out(self, device, dtype):
-        dividend = (torch.randn(5, device=device) * 100).to(dtype)
-        divisor = torch.arange(1, 6, device=device).to(dtype)
+    def test_div_promotion_out(self, device, dtype):
+        for op in (torch.div, torch.true_divide):
+            dividend = (torch.randn(5, device=device) * 100).to(dtype)
+            divisor = torch.arange(1, 6, device=device).to(dtype)
 
-        # Tests that requests for an integer quotient fail
-        if not dtype.is_floating_point:
-            integral_quotient = torch.empty(5, device=device, dtype=dtype)
-            with self.assertRaises(RuntimeError):
-                torch.true_divide(dividend, divisor, out=integral_quotient)
-            with self.assertRaises(RuntimeError):
-                torch.true_divide(dividend, 2, out=integral_quotient)
-        else:
-            # Tests that requests for a floating quotient succeed
-            floating_quotient = torch.empty(5, device=device, dtype=dtype)
-            div_result = dividend / divisor
-            self.assertEqual(div_result,
-                             torch.true_divide(dividend, divisor, out=floating_quotient))
-            self.assertEqual(dividend / 2,
-                             torch.true_divide(dividend, 2, out=floating_quotient))
+            # A non-floating (bool or integer) `out` tensor must raise for tensor and scalar divisors
+            if not dtype.is_floating_point:
+                integral_quotient = torch.empty(5, device=device, dtype=dtype)
+                with self.assertRaises(RuntimeError):
+                    op(dividend, divisor, out=integral_quotient)
+                with self.assertRaises(RuntimeError):
+                    op(dividend, 2, out=integral_quotient)
+            else:
+                # With a floating `out` tensor, op must return the same quotient as `/`
+                floating_quotient = torch.empty(5, device=device, dtype=dtype)
+                div_result = dividend / divisor
+                self.assertEqual(div_result,
+                                 op(dividend, divisor, out=floating_quotient))
+                self.assertEqual(dividend / 2,
+                                 op(dividend, 2, out=floating_quotient))
 
     @dtypes(torch.float, torch.double,
             torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
-    def test_true_divide_inplace(self, device, dtype):
-        dividend = (torch.randn(5, device=device) * 100).to(dtype)
-        divisor = torch.arange(1, 6, device=device).to(dtype)
+    def test_div_promotion_inplace(self, device, dtype):
+        for op in (torch.Tensor.div_, torch.Tensor.true_divide_):
+            dividend = (torch.randn(5, device=device) * 100).to(dtype)
+            divisor = torch.arange(1, 6, device=device).to(dtype)
 
-        # Tests that requests for an integer quotient fail
-        if not dtype.is_floating_point:
-            with self.assertRaises(RuntimeError):
-                dividend.true_divide_(divisor)
-            with self.assertRaises(RuntimeError):
-                dividend.true_divide_(2)
-        else:
-            # Tests that requests for a floating quotient succeed
-            div_result = dividend.clone().div_(divisor)
-            self.assertEqual(div_result, dividend.clone().true_divide_(divisor))
-            self.assertEqual(dividend.clone().div_(2), dividend.clone().true_divide_(2))
+            # In-place division on a non-floating tensor must raise for tensor and scalar divisors
+            if not dtype.is_floating_point:
+                with self.assertRaises(RuntimeError):
+                    op(dividend, divisor)
+                with self.assertRaises(RuntimeError):
+                    op(dividend, 2)
+            else:
+                # Floating tensors: compare each in-place op to the out-of-place `/` reference
+                div_result = dividend / divisor
+                self.assertEqual(div_result, op(dividend.clone(), divisor))
+                self.assertEqual(dividend / 2, op(dividend.clone(), 2))
 
     def _test_sparse_op_input_tensors(self, device, dtype, coalesced, zeros=True):
         t = self._get_test_tensor(device, dtype, not zeros)
@@ -777,29 +780,13 @@
     @onlyOnCPUAndCUDA
     @dtypes(torch.bool, torch.short, torch.uint8, torch.int, torch.long)
     @float_double_default_dtype
-    def test_sparse_true_divide(self, device, dtype):
-        dividend = torch.randn(5, device=device).to(dtype)
-        divisor = 2
-        dividend_sparse = dividend.to_sparse()
-        casting_result = dividend.to(torch.get_default_dtype()) / 2
-        self.assertEqual(casting_result, torch.true_divide(dividend_sparse, 2).to_dense())
-
-    @onlyOnCPUAndCUDA
-    @dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64)
-    def test_integer_div_deprecated(self, device, dtype):
-        a = torch.tensor(1, device=device, dtype=dtype)
-        b = torch.tensor(1, device=device, dtype=dtype)
-        o = torch.empty(1, device=device, dtype=dtype)
-
-        # Tests div (including /) deprecation
-        with self.assertRaisesRegex(RuntimeError, '^Integer division.+is no longer supported.+'):
-            c = a / b
-        with self.assertRaisesRegex(RuntimeError, '^Integer division.+is no longer supported.+'):
-            c = torch.div(a, b)
-        with self.assertRaisesRegex(RuntimeError, '^Integer division.+is no longer supported.+'):
-            torch.div(a, b, out=o)
-        with self.assertRaisesRegex(RuntimeError, '^Integer division.+is no longer supported.+'):
-            a.div_(b)
+    def test_sparse_div_promotion(self, device, dtype):
+        for op in (torch.div, torch.true_divide):
+            dividend = torch.randn(5, device=device).to(dtype)
+            divisor = 2  # scalar divisor shared by the dense reference and the sparse op
+            dividend_sparse = dividend.to_sparse()
+            casting_result = dividend.to(torch.get_default_dtype()) / divisor
+            self.assertEqual(casting_result, op(dividend_sparse, divisor).to_dense())
 
     @onlyOnCPUAndCUDA
     @dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64)
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 7f85660..485911e 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -1061,15 +1061,6 @@
 - name: triu(Tensor self, int diagonal=0) -> Tensor
   self: grad.triu(diagonal)
 
-  # Note: true_divide uses the division operator for backward
-  # since grad is always a floating tensor.
-- name: true_divide.Tensor(Tensor self, Tensor other) -> Tensor
-  self: grad / other
-  other: -grad * self / (other * other)
-
-- name: true_divide.Scalar(Tensor self, Scalar other) -> Tensor
-  self: grad / other
-
 - name: trunc(Tensor self) -> Tensor
   self: zeros_like(grad, at::MemoryFormat::Preserve)
 
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index d0a476f..27d13cc 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -311,24 +311,20 @@
             [ -3.8202,   4.3691,   1.0943,  -1.1109,   5.4730]])
 """.format(**common_args))
 
-add_docstr(torch.addcdiv,
-           r"""
+add_docstr(torch.addcdiv, r"""
 addcdiv(input, tensor1, tensor2, *, value=1, out=None) -> Tensor
 
 Performs the element-wise division of :attr:`tensor1` by :attr:`tensor2`,
 multiply the result by the scalar :attr:`value` and add it to :attr:`input`.
 
 .. warning::
-    Integer division with addcdiv is no longer supported, and in a future release
-    addcdiv will perform a true division of :attr:`tensor1` and :attr:`tensor2`.
-    The historic addcdiv behavior can be implemented using :func:`floor_divide`
-    for integral inputs
-    (:attr:`input` + :attr:`value` * :attr:`tensor1` // :attr:`tensor2`)
-    and :func:`div` for float inputs
-    (:attr:`input` + :attr:`value` * :attr:`tensor1` / :attr:`tensor2`).
-    The future addcdiv behavior can be implemented with :func:`true_divide`
-    (:attr:`input` + :attr:`value` * torch.true_divide(:attr:`tensor1`,
-    :attr:`tensor2`).
+    Integer division with addcdiv is no longer supported, and in a future
+    release addcdiv will perform a true division of tensor1 and tensor2.
+    The historic addcdiv behavior can be implemented as
+    (input + value * torch.trunc(tensor1 / tensor2)).to(input.dtype)
+    for integer inputs and as (input + value * tensor1 / tensor2) for float inputs.
+    The future addcdiv behavior is just the latter implementation:
+    (input + value * tensor1 / tensor2), for all dtypes.
 
 .. math::
     \text{out}_i = \text{input}_i + \text{value} \times \frac{\text{tensor1}_i}{\text{tensor2}_i}
@@ -2371,68 +2367,37 @@
     tensor(2.6537)
 """.format(**common_args))
 
-add_docstr(torch.div,
-           r"""
+add_docstr(torch.div, r"""
 div(input, other, *, out=None) -> Tensor
 
-Divides each element of the input ``input`` with the scalar ``other`` and
-returns a new resulting tensor.
-
-.. warning::
-    Integer division using div is no longer supported, and in a future release
-    div will perform true division as in Python 3. Use :func:`torch.true_divide`
-    or :func:`torch.floor_divide` (// in Python), instead.
+Divides each element of the input ``input`` by the corresponding element of
+:attr:`other`.
 
 .. math::
-    \text{{out}}_i = \frac{{\text{{input}}_i}}{{\text{{other}}}}
+    \text{{out}}_i = \frac{{\text{{input}}_i}}{{\text{{other}}_i}}
 
-If the :class:`torch.dtype` of ``input`` and ``other`` differ, the
-:class:`torch.dtype` of the result tensor is determined following rules
-described in the type promotion :ref:`documentation <type-promotion-doc>`. If
-``out`` is specified, the result must be :ref:`castable <type-promotion-doc>`
-to the :class:`torch.dtype` of the specified output tensor. Integral division
-by zero leads to undefined behavior.
+.. note::
+    Performs a "true" division like Python 3. See :func:`torch.floor_divide`
+    for floor division.
+
+Supports :ref:`broadcasting to a common shape <broadcasting-semantics>`,
+:ref:`type promotion <type-promotion-doc>`, and integer, float, and complex inputs.
+Always promotes integer types to the default scalar type.
 
 Args:
-    {input}
-    other (Number): the number to be divided to each element of ``input``
+    input (Tensor): the dividend
+    other (Tensor or Number): the divisor
 
 Keyword args:
     {out}
 
-Example::
+Examples::
 
     >>> a = torch.randn(5)
     >>> a
     tensor([ 0.3810,  1.2774, -0.2972, -0.3719,  0.4637])
     >>> torch.div(a, 0.5)
     tensor([ 0.7620,  2.5548, -0.5944, -0.7439,  0.9275])
-
-.. function:: div(input, other, *, out=None) -> Tensor
-
-Each element of the tensor ``input`` is divided by each element of the tensor
-``other``. The resulting tensor is returned.
-
-.. math::
-    \text{{out}}_i = \frac{{\text{{input}}_i}}{{\text{{other}}_i}}
-
-The shapes of ``input`` and ``other`` must be :ref:`broadcastable
-<broadcasting-semantics>`. If the :class:`torch.dtype` of ``input`` and
-``other`` differ, the :class:`torch.dtype` of the result tensor is determined
-following rules described in the type promotion :ref:`documentation
-<type-promotion-doc>`. If ``out`` is specified, the result must be
-:ref:`castable <type-promotion-doc>` to the :class:`torch.dtype` of the
-specified output tensor. Integral division by zero leads to undefined behavior.
-
-Args:
-    input (Tensor): the numerator tensor
-    other (Tensor): the denominator tensor
-
-Keyword args:
-    {out}
-
-Example::
-
     >>> a = torch.randn(4, 4)
     >>> a
     tensor([[-0.3711, -1.9353, -0.4605, -0.2917],
@@ -7862,29 +7827,7 @@
 add_docstr(torch.true_divide, r"""
 true_divide(dividend, divisor, *, out) -> Tensor
 
-Performs "true division" that always computes the division
-in floating point. Analogous to division in Python 3 and equivalent to
-:func:`torch.div` except when both inputs have bool or integer scalar types,
-in which case they are cast to the default (floating) scalar type before the division.
-
-.. math::
-    \text{{out}}_i = \frac{{\text{{dividend}}_i}}{{\text{{divisor}}}}
-
-Args:
-    dividend (Tensor): the dividend
-    divisor (Tensor or Scalar): the divisor
-
-Keyword args:
-    {out}
-
-Example::
-
-    >>> dividend = torch.tensor([5, 3], dtype=torch.int)
-    >>> divisor = torch.tensor([3, 2], dtype=torch.int)
-    >>> torch.true_divide(dividend, divisor)
-    tensor([1.6667, 1.5000])
-    >>> torch.true_divide(dividend, 2)
-    tensor([2.5000, 1.5000])
+Alias for :func:`torch.div`.
 """.format(**common_args))
 
 add_docstr(torch.trunc,
diff --git a/torch/csrc/jit/passes/normalize_ops.cpp b/torch/csrc/jit/passes/normalize_ops.cpp
index c20ed3e..b99e242 100644
--- a/torch/csrc/jit/passes/normalize_ops.cpp
+++ b/torch/csrc/jit/passes/normalize_ops.cpp
@@ -25,6 +25,7 @@
     {aten::less_equal, aten::le},    {aten::less_equal_, aten::le_},
     {aten::less, aten::lt},          {aten::less_, aten::lt_},
     {aten::not_equal, aten::ne},     {aten::not_equal_, aten::ne_},
+    {aten::true_divide, aten::div},  {aten::true_divide_, aten::div_},
 };
 
 void replaceNodeWithNewSymbol(Node* node, Symbol new_symbol) {
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 3cf03a3..5779b41 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -532,9 +532,18 @@
         ('div', (S, S, S), (uniform_scalar(0.1),), 'scalar_broadcast_rhs', (True,)),
         ('div', (), (uniform_scalar(0.1),), 'scalar_broadcast_lhs', (True,)),
         ('div', torch.rand(S, S, S) + 1e-1, (3.14,), 'constant', (True,)),
+        ('div', uniform_scalar(1e-1, requires_grad=True), (3.14,), 'scalar_constant', (True,)),
+        ('true_divide', (S, S, S), (torch.rand(S, S, S) + 0.1,), '', (True,)),
+        ('true_divide', (S, S, S), (torch.rand(S, S) + 0.1,), 'broadcast_rhs', (True,)),
+        ('true_divide', (S, S), (torch.rand(S, S, S) + 0.1,), 'broadcast_lhs', (True,)),
+        ('true_divide', (S, 1, S), (torch.rand(M, S) + 0.1,), 'broadcast_all', (True,)),
+        ('true_divide', (), (uniform_scalar(0.1),), 'scalar', (True,)),
+        ('true_divide', (S, S, S), (uniform_scalar(0.1),), 'scalar_broadcast_rhs', (True,)),
+        ('true_divide', (), (uniform_scalar(0.1),), 'scalar_broadcast_lhs', (True,)),
+        ('true_divide', torch.rand(S, S, S) + 1e-1, (3.14,), 'constant', (True,)),
+        ('true_divide', uniform_scalar(1e-1, requires_grad=True), (3.14,), 'scalar_constant', (True,)),
         ('__rdiv__', torch.rand(S, S, S) + 1e-1, (3.14,), 'constant',
             (True, [], ['aten::mul', 'aten::reciprocal'])),
-        ('div', uniform_scalar(1e-1, requires_grad=True), (3.14,), 'scalar_constant', (True,)),
         ('__rdiv__', uniform_scalar(1e-1, requires_grad=True), (3.14,), 'scalar_constant',
             (True, [], ['aten::mul', 'aten::reciprocal'])),
         ('pow', torch.rand(S, S, S) + 1e-3, (torch.rand(S, S, S) + 0.1,), '', (True,)),