[TensorExpr] Correctly print 'bool' dtype in Cuda printer. (#38077)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/38077
Test Plan: Imported from OSS
Differential Revision: D21467298
Pulled By: ZolotukhinM
fbshipit-source-id: 65ac347f097e01aaf1d3ff5d598a402ca619d1f2
diff --git a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp
index 44d1593..298511e 100644
--- a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp
+++ b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp
@@ -107,6 +107,8 @@
std::string cudaDtypeCppString(const Dtype& dtype) {
switch (dtype.scalar_type()) {
+ case ScalarType::Bool:
+ return "bool";
case ScalarType::Half:
return "half";
case ScalarType::Char:
@@ -117,9 +119,9 @@
return "short";
case ScalarType::Long:
return "long";
- default:; /* nothing */
+ default:
+ return dtype.ToCppString();
}
- return dtype.ToCppString();
}
static void print_flat_alloc(std::ostream& os, const Allocate* alloc) {