Migrate CPU tril, triu, masked_fill to c10::complex (#37897)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/37897

Test Plan: Imported from OSS

Differential Revision: D21442181

Pulled By: anjali411

fbshipit-source-id: 609af9086da1b622db51694f65eadfebe3970cfd
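
What the change does: these CPU kernels previously dispatched complex dtypes through AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2, which binds scalar_t to std::complex<T> inside the lambda. The _C10_COMPLEX_ variant binds scalar_t to c10::complex<T> instead, moving these ops onto PyTorch's own complex type as part of the std::complex to c10::complex migration. As a rough sketch (illustrative only, not the actual macro definition in ATen/Dispatch.h), such a dispatch macro expands to a switch over the runtime scalar type:

    switch (self.scalar_type()) {
      case at::ScalarType::Float: {
        using scalar_t = float;
        apply_triu_tril<scalar_t, false>(result, self_c, inplace, k);
        break;
      }
      case at::ScalarType::ComplexFloat: {
        // The _C10_COMPLEX_ variant binds c10::complex<float> here;
        // the old macro bound std::complex<float>.
        using scalar_t = c10::complex<float>;
        apply_triu_tril<scalar_t, false>(result, self_c, inplace, k);
        break;
      }
      // ... Double, Half, Bool, ComplexDouble, etc. follow the same pattern.
      default:
        AT_ERROR("tril", " not implemented for '", toString(self.scalar_type()), "'");
    }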
diff --git a/aten/src/ATen/native/TriangularOps.cpp b/aten/src/ATen/native/TriangularOps.cpp
index 1790b4cb..4513782 100644
--- a/aten/src/ATen/native/TriangularOps.cpp
+++ b/aten/src/ATen/native/TriangularOps.cpp
@@ -96,7 +96,7 @@
Tensor self_c;
std::tie(inplace, self_c) = checkTrilTriuBatchContiguous(self, true);
Tensor result = inplace ? self : at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "tril", [&]{
+ AT_DISPATCH_ALL_TYPES_AND_C10_COMPLEX_AND2(at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "tril", [&]{
apply_triu_tril<scalar_t, false>(result, self_c, inplace, k);
});
if (!inplace) self.copy_(result);
@@ -112,7 +112,7 @@
}
Tensor self_c;
std::tie(std::ignore, self_c) = checkTrilTriuBatchContiguous(self, false);
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "tril", [&]{
+ AT_DISPATCH_ALL_TYPES_AND_C10_COMPLEX_AND2(at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "tril", [&]{
apply_triu_tril<scalar_t, false>(result, self_c, false, k);
});
return result;
@@ -132,7 +132,7 @@
Tensor self_c;
std::tie(inplace, self_c) = checkTrilTriuBatchContiguous(self, true);
Tensor result = inplace ? self : at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "triu", [&]{
+ AT_DISPATCH_ALL_TYPES_AND_C10_COMPLEX_AND2(at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "triu", [&]{
apply_triu_tril<scalar_t, true>(result, self_c, inplace, k);
});
if (!inplace) self.copy_(result);
@@ -148,7 +148,7 @@
}
Tensor self_c;
std::tie(std::ignore, self_c) = checkTrilTriuBatchContiguous(self, false);
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "triu", [&]{
+ AT_DISPATCH_ALL_TYPES_AND_C10_COMPLEX_AND2(at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "triu", [&]{
apply_triu_tril<scalar_t, true>(result, self_c, false, k);
});
return result;
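
The same macro swap is applied to the CPU masked_fill kernel below. There scalar_t additionally determines the type the fill Scalar is converted to (value.to<scalar_t>()), so for complex tensors the conversion now targets c10::complex rather than std::complex.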
diff --git a/aten/src/ATen/native/cpu/IndexKernel.cpp b/aten/src/ATen/native/cpu/IndexKernel.cpp
index 919b3ea..91685ce 100644
--- a/aten/src/ATen/native/cpu/IndexKernel.cpp
+++ b/aten/src/ATen/native/cpu/IndexKernel.cpp
@@ -151,7 +151,7 @@
}
void masked_fill_kernel(TensorIterator& iter, Scalar value) {
- AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Bool, at::ScalarType::BFloat16,
+ AT_DISPATCH_ALL_TYPES_AND_C10_COMPLEX_AND2(at::ScalarType::Bool, at::ScalarType::BFloat16,
iter.dtype(), "masked_fill", [&] {
scalar_t scalar_val = value.to<scalar_t>();
auto mask_dtype = iter.input_dtype(0);
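
As a quick sanity check, the migrated ops can be exercised on complex CPU tensors through the ATen API. This is a minimal sketch assuming a build that contains this change; the specific factory calls (randn, eye) are illustrative:

    #include <ATen/ATen.h>
    #include <iostream>

    int main() {
      // complex64 tensor on CPU
      auto x = at::randn({3, 3}, at::kComplexFloat);
      // After this change, these dispatch with scalar_t = c10::complex<float>.
      auto lower = at::tril(x);
      auto upper = at::triu(x, /*diagonal=*/1);
      // masked_fill: the real Scalar 1.0 is converted via value.to<scalar_t>().
      auto mask = at::eye(3).to(at::kBool);
      auto filled = x.masked_fill(mask, 1.0);
      std::cout << lower << "\n" << upper << "\n" << filled << "\n";
      return 0;
    }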