torch.tensor can infer complex dtype now (#33361)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/33361
Test Plan: Imported from OSS
Differential Revision: D19943477
Pulled By: anjali411
fbshipit-source-id: ff6d7d2a6fdb6c58390f33bdd8be2f3fa182518b
diff --git a/test/test_torch.py b/test/test_torch.py
index a737997..aa685cd 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -1353,6 +1353,7 @@
self.assertIs(default_dtype, torch.tensor(((7, 5), (9, 5.))).dtype)
self.assertIs(default_dtype, torch.tensor(((5., 5), (3, 5))).dtype)
self.assertIs(torch.int64, torch.tensor(((5, 3), (3, 5))).dtype)
+ self.assertIs(torch.complex128, torch.tensor(((5, 3 + 2j), (3, 5 + 4j))).dtype)
if TEST_NUMPY:
self.assertIs(torch.float64, torch.tensor(np.array(())).dtype)
diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp
index 8d30e6a..04357ba 100644
--- a/torch/csrc/utils/tensor_new.cpp
+++ b/torch/csrc/utils/tensor_new.cpp
@@ -169,6 +169,9 @@
if (PyBool_Check(obj)) {
return ScalarType::Bool;
}
+ if (PyComplex_Check(obj)) {
+ return ScalarType::ComplexDouble;
+ }
if (THPVariable_Check(obj)) {
auto var = reinterpret_cast<THPVariable*>(obj)->cdata;
return var.scalar_type();