[11/N] Fix clang-tidy warnings in jit (#132131)
Follows #132122
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132131
Approved by: https://github.com/Skylion007
diff --git a/test/cpp/tensorexpr/test_base.h b/test/cpp/tensorexpr/test_base.h
index 510cad4..68b96fe 100644
--- a/test/cpp/tensorexpr/test_base.h
+++ b/test/cpp/tensorexpr/test_base.h
@@ -40,6 +40,8 @@
}
#endif // defined(USE_GTEST)
+#include <string>
+#include <vector>
namespace torch {
namespace jit {
diff --git a/test/cpp/tensorexpr/test_loopnest.cpp b/test/cpp/tensorexpr/test_loopnest.cpp
index b04fe80..ab38745 100644
--- a/test/cpp/tensorexpr/test_loopnest.cpp
+++ b/test/cpp/tensorexpr/test_loopnest.cpp
@@ -2292,7 +2292,7 @@
return ordering.str();
}
- void visit(ForPtr v) final {
+ void visit(const ForPtr& v) final {
ordering << v->var()->name_hint() << ",";
IRVisitor::visit(v);
}
diff --git a/torch/csrc/jit/tensorexpr/analysis.h b/torch/csrc/jit/tensorexpr/analysis.h
index 2c6e25a..cabce61 100644
--- a/torch/csrc/jit/tensorexpr/analysis.h
+++ b/torch/csrc/jit/tensorexpr/analysis.h
@@ -7,9 +7,7 @@
#include <utility>
-namespace torch {
-namespace jit {
-namespace tensorexpr {
+namespace torch::jit::tensorexpr {
class HasRand : public IRVisitor {
public:
HasRand(StmtPtr stmt) : stmt_(std::move(stmt)) {
@@ -21,11 +19,11 @@
}
private:
- void visit(IntrinsicsPtr v) override {
+ void visit(const IntrinsicsPtr& v) override {
if (v->op_type() == IntrinsicsOp::kRand) {
has_rand_ = true;
} else {
- IRVisitor::visit(std::move(v));
+ IRVisitor::visit(v);
}
}
StmtPtr stmt_;
@@ -33,10 +31,9 @@
};
template <typename Op>
-// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class NodeFinder : public IRVisitor {
public:
- void visit(NodePtr<Op> v) override {
+ void visit(const NodePtr<Op>& v) override {
nodes.push_back((NodePtr<Op>)v);
IRVisitor::visit(v);
}
@@ -59,9 +56,9 @@
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class VarFinder : public IRVisitor {
public:
- void visit(VarPtr v) override {
+ void visit(const VarPtr& v) override {
vars_.insert(v);
- IRVisitor::visit(std::move(v));
+ IRVisitor::visit(v);
}
static std::unordered_set<VarPtr> find(const StmtPtr& s) {
@@ -86,9 +83,9 @@
class BufFinder : public IRVisitor {
public:
- void visit(BufPtr v) override {
+ void visit(const BufPtr& v) override {
bufs_.insert(v);
- IRVisitor::visit(std::move(v));
+ IRVisitor::visit(v);
}
static std::unordered_set<BufPtr> find(const StmtPtr& s) {
@@ -128,13 +125,13 @@
}
private:
- void visit(StorePtr v) override {
+ void visit(const StorePtr& v) override {
if (v->buf() == target_) {
writes_.push_back(v);
}
}
- void visit(AtomicAddPtr v) override {
+ void visit(const AtomicAddPtr& v) override {
if (v->buf() == target_) {
writes_.push_back(v);
}
@@ -170,25 +167,25 @@
return false;
}
- void visit(StorePtr v) override {
+ void visit(const StorePtr& v) override {
if (readsBuffer(v)) {
reads_.push_back(v);
}
}
- void visit(LetPtr v) override {
+ void visit(const LetPtr& v) override {
if (readsBuffer(v)) {
reads_.push_back(v);
}
}
- void visit(CondPtr v) override {
+ void visit(const CondPtr& v) override {
if (readsBuffer(v)) {
reads_.push_back(v);
}
}
- void visit(AtomicAddPtr v) override {
+ void visit(const AtomicAddPtr& v) override {
if (readsBuffer(v)) {
reads_.push_back(v);
}
@@ -200,10 +197,10 @@
class ExternalAllocBufFinder : public IRVisitor {
public:
- void visit(ExternalCallWithAllocPtr v) override {
+ void visit(const ExternalCallWithAllocPtr& v) override {
const auto& bufs_out = v->buf_out_args();
bufs_.insert(bufs_out.begin(), bufs_out.end());
- IRVisitor::visit(std::move(v));
+ IRVisitor::visit(v);
}
static std::unordered_set<BufPtr> find(const StmtPtr& s) {
@@ -242,36 +239,36 @@
}
private:
- void visit(StorePtr v) override {
+ void visit(const StorePtr& v) override {
if (v->buf()->base_handle() == var_) {
found_ = true;
return;
}
- IRVisitor::visit(std::move(v));
+ IRVisitor::visit(v);
}
- void visit(AtomicAddPtr v) override {
+ void visit(const AtomicAddPtr& v) override {
if (v->buf()->base_handle() == var_) {
found_ = true;
return;
}
- IRVisitor::visit(std::move(v));
+ IRVisitor::visit(v);
}
- void visit(LetPtr v) override {
+ void visit(const LetPtr& v) override {
if (v->var() == var_) {
found_ = true;
return;
}
- IRVisitor::visit(std::move(v));
+ IRVisitor::visit(v);
}
- void visit(ForPtr v) override {
+ void visit(const ForPtr& v) override {
if (v->var() == var_) {
found_ = true;
return;
}
- IRVisitor::visit(std::move(v));
+ IRVisitor::visit(v);
}
VarPtr var_;
@@ -362,7 +359,7 @@
}
}
- void visit(BlockPtr v) override {
+ void visit(const BlockPtr& v) override {
for (const StmtPtr& s : *v) {
curr_index_ += 1;
findAccAndUpdateLiveRange(s);
@@ -384,7 +381,7 @@
}
private:
- void visit(StorePtr v) override {
+ void visit(const StorePtr& v) override {
auto load_node = to<Load>(v->value());
if (load_node) {
auto t_buf = load_node->buf();
@@ -401,6 +398,4 @@
std::unordered_map<std::string, BufPtr> map_input_to_tensor_bufs_;
};
-} // namespace tensorexpr
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::tensorexpr
diff --git a/torch/csrc/jit/tensorexpr/block_codegen.cpp b/torch/csrc/jit/tensorexpr/block_codegen.cpp
index 0d465d5..ef28ec3 100644
--- a/torch/csrc/jit/tensorexpr/block_codegen.cpp
+++ b/torch/csrc/jit/tensorexpr/block_codegen.cpp
@@ -61,16 +61,16 @@
}
}
-void BlockAnalysis::visit(StorePtr v) {
+void BlockAnalysis::visit(const StorePtr& v) {
store_targets_.insert(v->buf());
v->value()->accept(this);
}
-void BlockAnalysis::visit(LoadPtr v) {
+void BlockAnalysis::visit(const LoadPtr& v) {
loads_.insert(v->buf());
}
-void BlockAnalysis::visit(ForPtr v) {
+void BlockAnalysis::visit(const ForPtr& v) {
const LoopOptions& loop_options = v->loop_options();
if (loop_options.is_gpu_block_index()) {
map_input_to_tensor_bufs_ = loop_options.get_buffer_mapping();
@@ -91,21 +91,21 @@
// TODO: When handling fused ops d = a + b + c, the correct
// way would be to mutate the expression to Block version and print.
-void BlockPrinter::visit(AddPtr v) {
+void BlockPrinter::visit(const AddPtr& v) {
emitIndent();
os() << "add(";
v->lhs()->accept(this);
v->rhs()->accept(this);
}
-void BlockPrinter::visit(MulPtr v) {
+void BlockPrinter::visit(const MulPtr& v) {
emitIndent();
os() << "mul(";
v->lhs()->accept(this);
v->rhs()->accept(this);
}
-void BlockPrinter::visit(ForPtr v) {
+void BlockPrinter::visit(const ForPtr& v) {
const LoopOptions& loop_options = v->loop_options();
auto buf_reads = block_analysis_->loads();
@@ -296,16 +296,16 @@
}
}
-void BlockPrinter::visit(LoadPtr v) {
+void BlockPrinter::visit(const LoadPtr& v) {
os() << block_analysis_->getFlatInputName(v->buf()) << ".buffer, ";
}
-void BlockPrinter::visit(StorePtr v) {
+void BlockPrinter::visit(const StorePtr& v) {
emitIndent();
os() << *v->value() << block_analysis_->getFlatInputName(v->buf())
<< ".tensor)" << '\n';
}
-void BlockPrinter::visit(BlockPtr v) {
+void BlockPrinter::visit(const BlockPtr& v) {
os() << "{" << '\n';
indent_++;
for (const StmtPtr& s : v->stmts()) {
diff --git a/torch/csrc/jit/tensorexpr/block_codegen.h b/torch/csrc/jit/tensorexpr/block_codegen.h
index 93728a8..d08c7ee 100644
--- a/torch/csrc/jit/tensorexpr/block_codegen.h
+++ b/torch/csrc/jit/tensorexpr/block_codegen.h
@@ -52,9 +52,9 @@
}
private:
- void visit(StorePtr v) override;
- void visit(LoadPtr v) override;
- void visit(ForPtr v) override;
+ void visit(const StorePtr& v) override;
+ void visit(const LoadPtr& v) override;
+ void visit(const ForPtr& v) override;
std::unordered_map<std::string, BufPtr> map_input_to_tensor_bufs_;
std::unordered_set<BufPtr> store_targets_;
@@ -87,12 +87,12 @@
void PrintDMAs(const std::unordered_set<BufPtr>& bufs);
void PrintAdjustBuffers(const std::unordered_set<BufPtr>& bufs);
- void visit(ForPtr v) override;
- void visit(LoadPtr v) override;
- void visit(StorePtr v) override;
- void visit(BlockPtr v) override;
- void visit(AddPtr v) override;
- void visit(MulPtr v) override;
+ void visit(const ForPtr& v) override;
+ void visit(const LoadPtr& v) override;
+ void visit(const StorePtr& v) override;
+ void visit(const BlockPtr& v) override;
+ void visit(const AddPtr& v) override;
+ void visit(const MulPtr& v) override;
};
class TORCH_API BlockCodeGen : public CodeGen {
diff --git a/torch/csrc/jit/tensorexpr/cpp_codegen.cpp b/torch/csrc/jit/tensorexpr/cpp_codegen.cpp
index 2056fa0..8e05d8e 100644
--- a/torch/csrc/jit/tensorexpr/cpp_codegen.cpp
+++ b/torch/csrc/jit/tensorexpr/cpp_codegen.cpp
@@ -1,5 +1,6 @@
#include <algorithm>
#include <type_traits>
+#include <utility>
#include <vector>
#include <torch/csrc/jit/tensorexpr/cpp_codegen.h>
@@ -14,7 +15,7 @@
// with '_'.
class CppVarNameRewriter : public IRVisitor {
public:
- void visit(VarPtr v) override {
+ void visit(const VarPtr& v) override {
constexpr char kDot = '.';
constexpr char kUnderscore = '_';
if (v->name_hint().find(kDot) == std::string::npos) {
@@ -25,7 +26,7 @@
v->set_name_hint(std::move(name));
}
- void visit(BufPtr v) override {
+ void visit(const BufPtr& v) override {
v->base_handle()->accept(this);
}
};
@@ -47,75 +48,75 @@
CppPrinter::~CppPrinter() = default;
void CppPrinter::printPrologue() {
- os() << "#include <cassert>" << std::endl;
- os() << "#include <cmath>" << std::endl;
- os() << "#include <algorithm>" << std::endl;
- os() << "#include <type_traits>" << std::endl;
- os() << std::endl;
+ os() << "#include <cassert>" << '\n';
+ os() << "#include <cmath>" << '\n';
+ os() << "#include <algorithm>" << '\n';
+ os() << "#include <type_traits>" << '\n';
+ os() << '\n';
- os() << "#define POS_INFINITY INFINITY" << std::endl;
- os() << "#define NEG_INFINITY -INFINITY" << std::endl;
- os() << std::endl;
+ os() << "#define POS_INFINITY INFINITY" << '\n';
+ os() << "#define NEG_INFINITY -INFINITY" << '\n';
+ os() << '\n';
- os() << cpp_intrinsics_definition << std::endl;
- os() << std::endl;
+ os() << cpp_intrinsics_definition << '\n';
+ os() << '\n';
- os() << "namespace torch {" << std::endl;
- os() << "namespace jit {" << std::endl;
- os() << "namespace tensorexpr {" << std::endl;
+ os() << "namespace torch {" << '\n';
+ os() << "namespace jit {" << '\n';
+ os() << "namespace tensorexpr {" << '\n';
for (auto const& it : getNNCFunctionRegistry()) {
- os() << declareExternalFunction(it.first) << std::endl;
+ os() << declareExternalFunction(it.first) << '\n';
}
- os() << "} // namespace tensorexpr" << std::endl;
- os() << "} // namespace jit" << std::endl;
- os() << "} // namespace torch" << std::endl;
- os() << std::endl;
+ os() << "} // namespace tensorexpr" << '\n';
+ os() << "} // namespace jit" << '\n';
+ os() << "} // namespace torch" << '\n';
+ os() << '\n';
- os() << "using namespace torch::jit::tensorexpr;" << std::endl;
- os() << std::endl;
+ os() << "using namespace torch::jit::tensorexpr;" << '\n';
+ os() << '\n';
}
template <typename T>
-inline typename std::enable_if<!std::is_floating_point<T>::value, void>::type
-visit_mod(std::ostream& os, const ExprPtr lhs, const ExprPtr rhs) {
+inline std::enable_if_t<!std::is_floating_point_v<T>, void> visit_mod(
+ std::ostream& os,
+ const ExprPtr lhs,
+ const ExprPtr rhs) {
os << *lhs << " % " << *rhs;
}
template <typename T>
-inline typename std::enable_if<std::is_floating_point<T>::value, void>::type
-visit_mod(std::ostream& os, const ExprPtr lhs, const ExprPtr rhs) {
+inline std::enable_if_t<std::is_floating_point_v<T>, void> visit_mod(
+ std::ostream& os,
+ const ExprPtr lhs,
+ const ExprPtr rhs) {
os << "std::fmod(" << *lhs << ", " << *rhs << ")";
}
template <typename T>
-inline typename std::enable_if<
- std::is_floating_point<T>::value || std::is_integral<T>::value,
- void>::type
-visit_max(std::ostream& os, const ExprPtr lhs, const ExprPtr rhs) {
+inline std::
+ enable_if_t<std::is_floating_point_v<T> || std::is_integral_v<T>, void>
+ visit_max(std::ostream& os, const ExprPtr lhs, const ExprPtr rhs) {
os << "std::max(" << *lhs << ", " << *rhs << ")";
}
template <typename T>
-inline typename std::enable_if<
- !std::is_floating_point<T>::value && !std::is_integral<T>::value,
- void>::type
-visit_max(std::ostream& os, const ExprPtr lhs, const ExprPtr rhs) {
+inline std::
+ enable_if_t<!std::is_floating_point_v<T> && !std::is_integral_v<T>, void>
+ visit_max(std::ostream& os, const ExprPtr lhs, const ExprPtr rhs) {
os << "(" << *lhs << " < " << *rhs << ") ? " << *rhs << " : " << *lhs;
}
template <typename T>
-inline typename std::enable_if<
- std::is_floating_point<T>::value || std::is_integral<T>::value,
- void>::type
-visit_min(std::ostream& os, const ExprPtr lhs, const ExprPtr rhs) {
+inline std::
+ enable_if_t<std::is_floating_point_v<T> || std::is_integral_v<T>, void>
+ visit_min(std::ostream& os, const ExprPtr lhs, const ExprPtr rhs) {
os << "std::min(" << *lhs << ", " << *rhs << ")";
}
template <typename T>
-inline typename std::enable_if<
- !std::is_floating_point<T>::value && !std::is_integral<T>::value,
- void>::type
-visit_min(std::ostream& os, const ExprPtr lhs, const ExprPtr rhs) {
+inline std::
+ enable_if_t<!std::is_floating_point_v<T> && !std::is_integral_v<T>, void>
+ visit_min(std::ostream& os, const ExprPtr lhs, const ExprPtr rhs) {
os << *lhs << " < " << *rhs << " ? " << *lhs << " : " << *rhs;
}
@@ -154,38 +155,38 @@
}
}
-void CppPrinter::visit(RampPtr v) {
+void CppPrinter::visit(const RampPtr& v) {
visit(alloc<Add>(v->base(), alloc<Mul>(alloc<IntImm>(lane_), v->stride())));
}
-void CppPrinter::visit(BroadcastPtr v) {
+void CppPrinter::visit(const BroadcastPtr& v) {
v->value()->accept(this);
}
-void CppPrinter::visit(ModPtr v) {
+void CppPrinter::visit(const ModPtr& v) {
dispatch_binary_op(os(), v.get());
}
-void CppPrinter::visit(MaxPtr v) {
+void CppPrinter::visit(const MaxPtr& v) {
dispatch_binary_op(os(), v.get());
}
-void CppPrinter::visit(MinPtr v) {
+void CppPrinter::visit(const MinPtr& v) {
dispatch_binary_op(os(), v.get());
}
-void CppPrinter::visit(CompareSelectPtr v) {
+void CppPrinter::visit(const CompareSelectPtr& v) {
os() << "((" << *v->lhs() << " "
<< IRPrinter::to_string(v->compare_select_op()) << " " << *v->rhs()
<< ") ? " << *v->ret_val1() << " : " << *v->ret_val2() << ")";
}
-void CppPrinter::visit(IfThenElsePtr v) {
+void CppPrinter::visit(const IfThenElsePtr& v) {
os() << "((" << *v->condition() << ") ? " << *v->true_value() << " : "
<< *v->false_value() << ")";
}
-void CppPrinter::visit(AllocatePtr v) {
+void CppPrinter::visit(const AllocatePtr& v) {
size_t size = v->dtype().byte_size();
for (const auto& dim : v->dims()) {
IntImmPtr d = to<IntImm>(dim);
@@ -199,21 +200,21 @@
emitIndent();
os() << v->dtype().ToCppString() << "* " << (*v->buffer_var())
<< " = static_cast<" << v->dtype().ToCppString() << "*>(malloc(" << size
- << "));" << std::endl;
+ << "));" << '\n';
}
-void CppPrinter::visit(FreePtr v) {
+void CppPrinter::visit(const FreePtr& v) {
emitIndent();
- os() << "free(" << *v->buffer_var() << ");" << std::endl;
+ os() << "free(" << *v->buffer_var() << ");" << '\n';
}
-void CppPrinter::visit(LoadPtr v) {
+void CppPrinter::visit(const LoadPtr& v) {
auto flat_idx =
flatten_index(v->buf()->dims(), v->indices(), v->buf()->strides());
os() << *v->base_handle() << "[" << *flat_idx << "]";
}
-void CppPrinter::visit(StorePtr v) {
+void CppPrinter::visit(const StorePtr& v) {
auto flat_idx =
flatten_index(v->buf()->dims(), v->indices(), v->buf()->strides());
const int lanes = v->value()->dtype().lanes();
@@ -221,21 +222,21 @@
lane_ = lane;
emitIndent();
os() << *v->base_handle() << "[" << *flat_idx << "] = " << *v->value()
- << ";" << std::endl;
+ << ";" << '\n';
}
}
-void CppPrinter::visit(CastPtr v) {
+void CppPrinter::visit(const CastPtr& v) {
os() << "static_cast<" << v->dtype().ToCppString() << ">(" << *v->src_value()
<< ")";
}
-void CppPrinter::visit(BitCastPtr v) {
+void CppPrinter::visit(const BitCastPtr& v) {
os() << "std::bitcast<" << v->src_value()->dtype().ToCppString() << ", "
<< v->dtype().ToCppString() << ">(" << *v->src_value() << ")";
}
-void CppPrinter::visit(IntrinsicsPtr v) {
+void CppPrinter::visit(const IntrinsicsPtr& v) {
if (v->op_type() == kRand || v->op_type() == kSigmoid) {
throw std::runtime_error("kRand and kSigmoid are not supported");
}
@@ -250,7 +251,7 @@
os() << ")";
}
-void CppPrinter::visit(ExternalCallPtr v) {
+void CppPrinter::visit(const ExternalCallPtr& v) {
// The generated code needs to link against functions defined
// in external_functions.cpp.
@@ -271,22 +272,22 @@
};
emitIndent();
- os() << "{" << std::endl;
+ os() << "{" << '\n';
indent_++;
emitIndent();
os() << "void* buf_ptrs[]{";
- for_buf([&](const BufPtr b) { os() << *b->base_handle(); });
- os() << "};" << std::endl;
+ for_buf([&](const BufPtr& b) { os() << *b->base_handle(); });
+ os() << "};" << '\n';
emitIndent();
os() << "int64_t buf_ranks[]{";
- for_buf([&](const BufPtr b) { os() << b->ndim(); });
- os() << "};" << std::endl;
+ for_buf([&](const BufPtr& b) { os() << b->ndim(); });
+ os() << "};" << '\n';
emitIndent();
os() << "int64_t buf_dims[]{";
- for_buf([&](const BufPtr buf) {
+ for_buf([&](const BufPtr& buf) {
for (size_t i = 0; i < buf->ndim(); i++) {
if (i > 0) {
os() << ", ";
@@ -294,14 +295,14 @@
os() << *buf->dim(i);
}
});
- os() << "};" << std::endl;
+ os() << "};" << '\n';
emitIndent();
os() << "int8_t buf_dtypes[]{";
- for_buf([&](const BufPtr buf) {
+ for_buf([&](const BufPtr& buf) {
os() << static_cast<int>(buf->dtype().scalar_type());
});
- os() << "};" << std::endl;
+ os() << "};" << '\n';
emitIndent();
os() << "int64_t extra_args[]{";
@@ -311,41 +312,41 @@
}
os() << *v->args()[i];
}
- os() << "};" << std::endl;
+ os() << "};" << '\n';
emitIndent();
- os() << v->func_name() << "(" << std::endl;
+ os() << v->func_name() << "(" << '\n';
emitIndent();
- os() << " " << bufs.size() << "," << std::endl;
+ os() << " " << bufs.size() << "," << '\n';
emitIndent();
- os() << " buf_ptrs," << std::endl;
+ os() << " buf_ptrs," << '\n';
emitIndent();
- os() << " buf_ranks," << std::endl;
+ os() << " buf_ranks," << '\n';
emitIndent();
- os() << " buf_dims," << std::endl;
+ os() << " buf_dims," << '\n';
emitIndent();
- os() << " buf_dtypes," << std::endl;
+ os() << " buf_dtypes," << '\n';
emitIndent();
- os() << " " << v->args().size() << "," << std::endl;
+ os() << " " << v->args().size() << "," << '\n';
emitIndent();
- os() << " extra_args);" << std::endl;
+ os() << " extra_args);" << '\n';
indent_--;
emitIndent();
- os() << "}" << std::endl;
+ os() << "}" << '\n';
}
-void CppPrinter::visit(LetPtr v) {
+void CppPrinter::visit(const LetPtr& v) {
if (v->var()->dtype().lanes() == 1) {
emitIndent();
os() << v->var()->dtype().ToCppString() << " " << *v->var() << " = "
- << *v->value() << ";" << std::endl;
+ << *v->value() << ";" << '\n';
} else {
vector_vars_[v->var()] = v->value();
}
}
-void CppPrinter::visit(VarPtr v) {
+void CppPrinter::visit(const VarPtr& v) {
if (v->dtype().lanes() == 1) {
os() << name_manager()->get_unique_name(v);
} else {
@@ -358,7 +359,7 @@
const std::vector<BufferArg>& buffer_args,
at::Device device,
const std::string& kernel_func_name)
- : CodeGen(stmt, buffer_args, device, kernel_func_name) {
+ : CodeGen(std::move(stmt), buffer_args, device, kernel_func_name) {
init();
}
@@ -382,7 +383,7 @@
}
os() << ")";
stmt()->accept(printer_.get());
- os() << std::endl;
+ os() << '\n';
}
CppCodeGen::~CppCodeGen() = default;
@@ -390,13 +391,13 @@
void CppCodeGen::call(const std::vector<CallArg>& args) {
// TODO: compile the generated C++ kernel into a library,
// and call the library here.
- os() << "int main() {}" << std::endl;
+ os() << "int main() {}" << '\n';
}
void CppCodeGen::call_raw(const std::vector<void*>& args) {
// TODO: compile the generated C++ kernel into a library,
// and call the library here.
- os() << "int main() {}" << std::endl;
+ os() << "int main() {}" << '\n';
}
RegisterCodeGen<CppCodeGen> cpp_codegen_reg("cpp_codegen");
diff --git a/torch/csrc/jit/tensorexpr/cpp_codegen.h b/torch/csrc/jit/tensorexpr/cpp_codegen.h
index a6d583e..6ae1c9c 100644
--- a/torch/csrc/jit/tensorexpr/cpp_codegen.h
+++ b/torch/csrc/jit/tensorexpr/cpp_codegen.h
@@ -28,35 +28,35 @@
using IRPrinter::visit;
// Binary expressions.
- void visit(ModPtr) override;
- void visit(MaxPtr) override;
- void visit(MinPtr) override;
+ void visit(const ModPtr&) override;
+ void visit(const MaxPtr&) override;
+ void visit(const MinPtr&) override;
// Conditional expressions.
- void visit(CompareSelectPtr) override;
- void visit(IfThenElsePtr) override;
+ void visit(const CompareSelectPtr&) override;
+ void visit(const IfThenElsePtr&) override;
// Tensor operations.
- void visit(AllocatePtr) override;
- void visit(FreePtr) override;
- void visit(LoadPtr) override;
- void visit(StorePtr) override;
+ void visit(const AllocatePtr&) override;
+ void visit(const FreePtr&) override;
+ void visit(const LoadPtr&) override;
+ void visit(const StorePtr&) override;
// Casts.
- void visit(CastPtr) override;
- void visit(BitCastPtr) override;
+ void visit(const CastPtr&) override;
+ void visit(const BitCastPtr&) override;
// Calls.
- void visit(IntrinsicsPtr) override;
- void visit(ExternalCallPtr) override;
+ void visit(const IntrinsicsPtr&) override;
+ void visit(const ExternalCallPtr&) override;
// Vars.
- void visit(LetPtr) override;
- void visit(VarPtr) override;
+ void visit(const LetPtr&) override;
+ void visit(const VarPtr&) override;
// Vector data types.
- void visit(RampPtr) override;
- void visit(BroadcastPtr) override;
+ void visit(const RampPtr&) override;
+ void visit(const BroadcastPtr&) override;
private:
int lane_;
diff --git a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp
index d8f8f1e..7ee0c22 100644
--- a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp
+++ b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp
@@ -23,6 +23,7 @@
#endif
#include <unordered_map>
+#include <utility>
namespace torch::jit::tensorexpr {
@@ -30,7 +31,7 @@
// TODO: move this to a more shared place.
class ScopedVarName {
public:
- ScopedVarName(VarNameMap* mapping, VarPtr var, const std::string& name)
+ ScopedVarName(VarNameMap* mapping, const VarPtr& var, const std::string& name)
: mapping_(mapping), var_(var) {
auto iter = mapping->find(var);
if (iter != mapping->end()) {
@@ -39,7 +40,10 @@
mapping->insert(std::make_pair(var, name));
}
- ScopedVarName(UniqueNameManager* manager, VarPtr var, const std::string& name)
+ ScopedVarName(
+ UniqueNameManager* manager,
+ const VarPtr& var,
+ const std::string& name)
: ScopedVarName(&manager->unique_name_mapping_, var, name) {}
ScopedVarName(const ScopedVarName&) = delete;
@@ -54,7 +58,7 @@
VarPtr var_ = nullptr;
};
-static bool is_zero(ExprPtr expr) {
+static bool is_zero(const ExprPtr& expr) {
auto v = intValue(expr);
return v && *v == 0;
}
@@ -84,19 +88,18 @@
}
}
-void CudaAnalysis::visit(FreePtr v) {
+void CudaAnalysis::visit(const FreePtr& v) {
if (thread_local_bufs_.count(v->buffer_var()) == 0 &&
cross_block_bufs_.count(v->buffer_var()) == 0) {
throw std::runtime_error("Global free not supported yet");
}
}
-void CudaAnalysis::visit(AllocatePtr v) {
+void CudaAnalysis::visit(const AllocatePtr& v) {
StmtPtr p = v->get_parent();
while (p) {
ForPtr for_v = to<For>(p);
if (for_v) {
- // NOLINTNEXTLINE(bugprone-branch-clone)
if (for_v->loop_options().is_gpu_block_index()) {
// TODO: This isn't right if there's a thread index at a higher level
// than this.
@@ -112,11 +115,11 @@
throw std::runtime_error("Global alloc not supported yet");
}
-void CudaAnalysis::visit(PlacementAllocatePtr v) {
+void CudaAnalysis::visit(const PlacementAllocatePtr& v) {
throw std::runtime_error("Memory reuse not supported yet");
}
-void CudaAnalysis::visit(ForPtr v) {
+void CudaAnalysis::visit(const ForPtr& v) {
// Recurse first.
v->body()->accept(this);
@@ -127,7 +130,6 @@
throw std::runtime_error("support only 3D gpu_block_index");
}
ExprPtr prev = nullptr;
- // NOLINTNEXTLINE(bugprone-branch-clone)
if (gpu_block_extents_.size() <= static_cast<size_t>(gpu_block_index)) {
gpu_block_extents_.resize(gpu_block_index + 1);
} else {
@@ -156,7 +158,6 @@
throw std::runtime_error("support only 3D gpu_thread_index");
}
ExprPtr prev = nullptr;
- // NOLINTNEXTLINE(bugprone-branch-clone)
if (gpu_thread_extents_.size() <= static_cast<size_t>(gpu_thread_index)) {
gpu_thread_extents_.resize(gpu_thread_index + 1);
} else {
@@ -182,12 +183,11 @@
}
}
-void CudaPrinter::print_flat_alloc(AllocatePtr alloc) {
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+void CudaPrinter::print_flat_alloc(const AllocatePtr& alloc) {
std::vector<ExprPtr> dims = alloc->dims();
// TODO: this should be merged with the storage flattener.
int64_t flat_size = 1;
- for (auto dim : dims) {
+ for (const auto& dim : dims) {
auto dim_i = intValue(dim);
if (dim_i) {
flat_size *= *dim_i;
@@ -196,10 +196,10 @@
}
}
os() << dtypeToCppString(alloc->dtype()) << " " << (*alloc->buffer_var())
- << "[" << flat_size << "];" << std::endl;
+ << "[" << flat_size << "];" << '\n';
}
-void CudaPrinter::visit(AllocatePtr v) {
+void CudaPrinter::visit(const AllocatePtr& v) {
// TODO: handle dynamic shapes here.
if (cuda_analysis_->cross_block_bufs().count(v->buffer_var()) != 0) {
emitIndent();
@@ -217,15 +217,15 @@
throw std::runtime_error("Encountered Alloc not local to block or thread");
}
-void CudaPrinter::visit(FreePtr v) {
+void CudaPrinter::visit(const FreePtr& v) {
// do nothing
}
-void CudaPrinter::visit(ForPtr v) {
+void CudaPrinter::visit(const ForPtr& v) {
IRPrinter::visit(v);
}
-void CudaPrinter::visit(CastPtr v) {
+void CudaPrinter::visit(const CastPtr& v) {
std::string castFn = v->dtype().scalar_type() == ScalarType::Half
? "__float2half"
: v->dtype().scalar_type() == ScalarType::BFloat16 ? "__float2bfloat16"
@@ -239,7 +239,7 @@
os() << ")";
}
-void CudaPrinter::visit(IntrinsicsPtr v) {
+void CudaPrinter::visit(const IntrinsicsPtr& v) {
if (v->op_type() == IntrinsicsOp::kRand) {
os() << "Uint32ToFloat(" << *rand_func_ << "())";
return;
@@ -275,11 +275,11 @@
os() << ")";
}
-void CudaPrinter::visit(ExternalCallPtr v) {
+void CudaPrinter::visit(const ExternalCallPtr& v) {
throw unimplemented_lowering(v);
}
-void CudaPrinter::visit(LoadPtr v) {
+void CudaPrinter::visit(const LoadPtr& v) {
// TODO: find a better metric in using ldg or not. Support different dtypes.
// Detects whether the load target is also a store target.
// TODO: this is currently too wide. It detects whether a store-target
@@ -307,7 +307,7 @@
// TODO: maybe this should be a more shared location?
// TODO: investigate how "ExprPtr" can be implicitly converted to "ExprHandle"
// as a bool.
-static bool CheckEqual(ExprPtr lhs, ExprPtr rhs) {
+static bool CheckEqual(const ExprPtr& lhs, const ExprPtr& rhs) {
// The fast path. Checks if the pointers are the same.
if (lhs == rhs) {
return true;
@@ -319,7 +319,6 @@
class AtomicAddFuser : public IRMutator {
public:
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
AtomicAddFuser(
const std::unordered_set<VarPtr>& thread_local_bufs,
const GPUMetaVarRewriter& metavars)
@@ -382,9 +381,8 @@
// TODO: this checks that the metavars occur directly as an index, but this
// is pessimistic, blockIdx.x + 1 is fine too if there is no overlapping.
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::unordered_set<VarPtr> vars_to_find = nontrivial_metavars_;
- for (ExprPtr e : v->indices()) {
+ for (const ExprPtr& e : v->indices()) {
if (VarPtr v = to<Var>(e)) {
vars_to_find.erase(v);
}
@@ -399,6 +397,7 @@
}
private:
+ // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const std::unordered_set<VarPtr>& thread_local_bufs_;
struct MetaVarExtent {
ExprPtr expr{nullptr};
@@ -408,7 +407,7 @@
std::unordered_set<VarPtr> nontrivial_metavars_;
};
-void CudaPrinter::visit(StorePtr v) {
+void CudaPrinter::visit(const StorePtr& v) {
emitIndent();
if (v->indices().empty()) {
os() << *v->base_handle() << " = ";
@@ -416,10 +415,10 @@
os() << *v->base_handle() << "[" << *v->flat_index() << "] = ";
}
os() << *v->value() << ";";
- os() << std::endl;
+ os() << '\n';
}
-void CudaPrinter::visit(AtomicAddPtr v) {
+void CudaPrinter::visit(const AtomicAddPtr& v) {
emitIndent();
if (cuda_analysis_->thread_local_bufs().count(v->base_handle()) > 0) {
// atomicAdd only works on global and shared memory
@@ -429,10 +428,10 @@
os() << "atomicAdd(&" << *v->base_handle() << "[" << *v->flat_index() << "]"
<< ", " << *v->value() << ");";
}
- os() << std::endl;
+ os() << '\n';
}
-void CudaPrinter::visit(MaxPtr v) {
+void CudaPrinter::visit(const MaxPtr& v) {
if (v->dtype().is_integral()) {
os() << "max(";
} else {
@@ -444,7 +443,7 @@
os() << ")";
}
-void CudaPrinter::visit(MinPtr v) {
+void CudaPrinter::visit(const MinPtr& v) {
if (v->dtype().is_integral()) {
os() << "min(";
} else {
@@ -456,7 +455,7 @@
os() << ")";
}
-void CudaPrinter::visit(IfThenElsePtr v) {
+void CudaPrinter::visit(const IfThenElsePtr& v) {
os() << "((";
v->condition()->accept(this);
os() << ") ? ";
@@ -466,11 +465,11 @@
os() << ")";
}
-void CudaPrinter::visit(BlockPtr v) {
- os() << "{" << std::endl;
+void CudaPrinter::visit(const BlockPtr& v) {
+ os() << "{" << '\n';
indent_++;
- for (StmtPtr s : v->stmts()) {
+ for (const StmtPtr& s : v->stmts()) {
s->accept(this);
}
@@ -479,15 +478,14 @@
os() << "}";
}
-void CudaPrinter::visit(LetPtr v) {
+void CudaPrinter::visit(const LetPtr& v) {
emitIndent();
os() << dtypeToCppString(v->var()->dtype());
os() << " " << *v->var() << " = ";
v->value()->accept(this);
- os() << ";" << std::endl;
+ os() << ";" << '\n';
}
-// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class PrioritizeLoad : public IRMutator {
public:
ExprPtr mutate(LoadPtr v) override {
@@ -501,7 +499,7 @@
if (thread_local_bufs_.count(v->base_handle()) > 0) {
return IRMutator::mutate(v);
}
- if (v->indices().size() == 0) {
+ if (v->indices().empty()) {
return IRMutator::mutate(v);
}
if (nested_store_) {
@@ -526,7 +524,7 @@
MemLoadList& load_list = load_stack_.back();
VarPtr load_new_var = alloc<Var>("v", v->dtype());
ExprPtr new_value = IRMutator::mutate(v);
- load_list.push_back(std::make_pair(load_new_var, new_value));
+ load_list.emplace_back(load_new_var, new_value);
return load_new_var;
}
@@ -547,9 +545,8 @@
load_list.pop_back();
new_var = alloc<Var>("v", v->dtype());
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
ExprPtr new_value = alloc<Cast>(v->dtype(), pair.second);
- load_list.push_back(std::make_pair(new_var, new_value));
+ load_list.emplace_back(new_var, new_value);
return new_var;
}
@@ -569,9 +566,8 @@
}
StmtPtr mutate(BlockPtr v) override {
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::list<StmtPtr> stmts = v->stmts();
- for (StmtPtr stmt : stmts) {
+ for (const StmtPtr& stmt : stmts) {
PushList();
StmtPtr stmt_new = stmt->accept_mutator(this);
@@ -599,14 +595,14 @@
using MemoryLoadStack = std::vector<MemLoadList>;
void PushList() {
- load_stack_.push_back(MemLoadList());
+ load_stack_.emplace_back();
}
void PopList() {
load_stack_.pop_back();
}
- void AddMemLoadsFromList(BlockPtr block, StmtPtr last) {
+ void AddMemLoadsFromList(const BlockPtr& block, const StmtPtr& last) {
MemLoadList& load_list = load_stack_.back();
if (load_list.empty()) {
return;
@@ -681,7 +677,6 @@
old_reach = current_block_reach_[gpu_block_index];
// Extents must be positive, assume >= 1.
- // NOLINTNEXTLINE(bugprone-branch-clone)
if (old_reach->isConstant() && immediateEquals(old_reach, 1)) {
current_block_reach_[gpu_block_index] = v->stop();
} else {
@@ -689,7 +684,6 @@
IRSimplifier::simplify(alloc<Max>(old_reach, v->stop(), true));
}
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
VarPtr metaVar = gpu_block_vars_[gpu_block_index];
body = Substitute(Stmt::clone(body), {{v->var(), metaVar}});
} else if (loop_options.is_gpu_thread_index()) {
@@ -700,7 +694,6 @@
old_reach = current_thread_reach_[gpu_thread_index];
// Extents must be positive, assume >= 1.
- // NOLINTNEXTLINE(bugprone-branch-clone)
if (old_reach->isConstant() && immediateEquals(old_reach, 1)) {
current_thread_reach_[gpu_thread_index] = v->stop();
} else {
@@ -708,7 +701,6 @@
IRSimplifier::simplify(alloc<Max>(old_reach, v->stop(), true));
}
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
VarPtr metaVar = gpu_thread_vars_[gpu_thread_index];
body = Substitute(Stmt::clone(body), {{v->var(), metaVar}});
}
@@ -717,7 +709,6 @@
body = Stmt::clone(body->accept_mutator(this));
// pop the internal reach off the stack.
- // NOLINTNEXTLINE(bugprone-branch-clone)
if (loop_options.is_gpu_block_index()) {
current_block_reach_[loop_options.gpu_block_index()] = old_reach;
return body;
@@ -730,7 +721,6 @@
}
StmtPtr GPUMetaVarRewriter::mutate(BlockPtr v) {
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<Segment> innerSegments;
Segment current;
@@ -745,7 +735,7 @@
// the same launch reach. Segments are comprised of all statements that aren't
// loops - which are their own segments. Some operations, such as threading
// and memory ops should never be masked and so also get their own segment.
- for (StmtPtr stmt : *v) {
+ for (const StmtPtr& stmt : *v) {
StmtPtr stmt_new = stmt->accept_mutator(this);
if (stmt == stmt_new) {
stmt_new = Stmt::clone(stmt_new);
@@ -776,10 +766,9 @@
// We are max extent in all dimensions, so need no masks at this level.
if (isFullExtent()) {
// flatten inner segments.
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<StmtPtr> stmts;
for (auto& v : innerSegments) {
- for (auto s : v.stmts()) {
+ for (const auto& s : v.stmts()) {
stmts.push_back(s);
}
}
@@ -787,7 +776,6 @@
return alloc<Block>(stmts);
}
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<StmtPtr> stmts;
for (auto& segment : innerSegments) {
bool need_sync = false;
@@ -846,7 +834,7 @@
std::ostream& out,
const std::vector<ExprPtr>& exprs) {
size_t i = 0;
- for (auto expr : exprs) {
+ for (const auto& expr : exprs) {
if (i++ > 0) {
out << ", ";
}
@@ -900,18 +888,18 @@
os() << device_resource_string << shared_resource_string;
if (has_random_) {
- os() << philox_random_string << std::endl;
+ os() << philox_random_string << '\n';
}
if (halfChecker.hasHalf()) {
- os() << fuser::cuda::half_support_literal << std::endl;
+ os() << fuser::cuda::half_support_literal << '\n';
}
if (halfChecker.hasBFloat16()) {
- os() << fuser::cuda::bfloat16_support_literal << std::endl;
+ os() << fuser::cuda::bfloat16_support_literal << '\n';
}
std::string func_name = GetUniqueFuncName(kernel_func_name());
- os() << "extern \"C\" __global__" << std::endl;
+ os() << "extern \"C\" __global__" << '\n';
#if defined(USE_ROCM)
// CUDA has a default limit of threads per block (=flat work group size)
// of 1024, but ROCm uses 256 by default. At the time of writing
@@ -924,7 +912,6 @@
os() << "__attribute__((amdgpu_flat_work_group_size(1, 1024)))" << std::endl;
#endif
os() << "void " << func_name << "(";
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
const std::vector<BufferArg> buffer_args = this->buffer_args();
for (size_t i = 0; i < buffer_args.size(); i++) {
if (i > 0) {
@@ -938,9 +925,7 @@
<< (buffer_arg.isVar() ? " " : "* ")
<< name_manager()->get_unique_name(var);
}
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
VarPtr rand_seed;
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
VarPtr rand_offset;
if (has_random_) {
// TODO: switch to kUint64 when it is available.
@@ -951,17 +936,15 @@
<< *rand_offset;
}
os() << ") {";
- os() << std::endl;
+ os() << '\n';
if (has_random_) {
VarPtr idx = alloc<Var>("idx", kInt);
- os() << "int " << *idx << " = blockIdx.x*blockDim.x + threadIdx.x;"
- << std::endl;
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+ os() << "int " << *idx << " = blockIdx.x*blockDim.x + threadIdx.x;" << '\n';
VarPtr rand_func = printer_->rand_func();
os() << "Philox " << *rand_func << "(" << *rand_seed << ", " << *idx << ", "
- << *rand_offset << ");" << std::endl;
- os() << std::endl;
+ << *rand_offset << ");" << '\n';
+ os() << '\n';
}
stmt_v->accept(cuda_analysis_.get());
@@ -969,7 +952,7 @@
stmt_v = stmt_v->accept_mutator(metavar_rewriter_.get());
AtomicAddFuser atomic_add_fuser(
- cuda_analysis_->thread_local_bufs(), *metavar_rewriter_.get());
+ cuda_analysis_->thread_local_bufs(), *metavar_rewriter_);
stmt_v = stmt_v->accept_mutator(&atomic_add_fuser);
stmt_v = registerize(stmt_v);
@@ -985,7 +968,7 @@
set_stmt(stmt_v);
stmt_v->accept(printer_.get());
- os() << std::endl;
+ os() << '\n';
os() << "}";
// Check that all block extents had been set.
@@ -1003,7 +986,7 @@
auto block_extents = metavar_rewriter_->gpu_block_extents();
auto thread_extents = metavar_rewriter_->gpu_thread_extents();
bool canCallWithNumel =
- !has_random_ && block_extents.size() > 0 && thread_extents.size() > 0;
+ !has_random_ && !block_extents.empty() && !thread_extents.empty();
for (size_t i = 1; i < block_extents.size() && canCallWithNumel; i++) {
canCallWithNumel = canCallWithNumel && block_extents[i]->isConstant() &&
immediateAs<int>(block_extents[i]) == 1;
@@ -1050,8 +1033,7 @@
block_extents_eval_.emplace_back(
ExprEval<LLVMCodeGen>(ExprHandle(be), extents_buffer_args));
#else
- block_extents_eval_.emplace_back(
- ExprEval<SimpleIREvaluator>(ExprHandle(be), extents_buffer_args));
+ block_extents_eval_.emplace_back(ExprHandle(be), extents_buffer_args);
#endif
}
thread_extents_eval_.reserve(thread_extents.size());
@@ -1060,8 +1042,7 @@
thread_extents_eval_.emplace_back(
ExprEval<LLVMCodeGen>(ExprHandle(te), extents_buffer_args));
#else
- thread_extents_eval_.emplace_back(
- ExprEval<SimpleIREvaluator>(ExprHandle(te), extents_buffer_args));
+ thread_extents_eval_.emplace_back(ExprHandle(te), extents_buffer_args);
#endif
}
@@ -1102,7 +1083,6 @@
std::vector<void*> ptr_to_args(buffer_args.size());
for (size_t i = 0; i < buffer_args.size(); i++) {
ptr_to_args[i] =
- // NOLINTNEXTLINE: const_cast
buffer_args[i].isVar() ? args[i] : const_cast<void**>(&args[i]);
}
@@ -1145,10 +1125,8 @@
"cuda_codegen: block or thread extent greater than 3D");
}
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- std::vector<int> gpu_block_extents_v(3, 1);
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- std::vector<int> gpu_thread_extents_v(3, 1);
+ std::vector<int64_t> gpu_block_extents_v(3, 1);
+ std::vector<int64_t> gpu_thread_extents_v(3, 1);
// evaluate all the block/thread extents into values
// TODO: eventually, codegen these calculations and make them part of the
@@ -1187,20 +1165,18 @@
}
// Skip launching the kernel if there are no elements to process.
- for (int extent : gpu_block_extents_v) {
+ for (auto extent : gpu_block_extents_v) {
if (extent == 0) {
return;
}
}
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- int ptr_count = buffer_args.size();
+ auto ptr_count = buffer_args.size();
// If the kernel has a rand call in it, add two extra arguments for random
// seed and offset.
if (has_random_) {
ptr_count += 2;
}
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<void*> ptr_to_args(ptr_count);
// In CUDA we need to pass pointers to pointers for buffers, thus we need to
@@ -1262,7 +1238,6 @@
}
auto const& buffer_args = this->buffer_args();
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<void*> raw_args(buffer_args.size());
for (size_t i = 0; i < buffer_args.size(); i++) {
auto const& bufferArg = buffer_args[i];
@@ -1296,16 +1271,13 @@
}
// Acquires device and NVRTC properties (for compile arch and occupancy
// calculations)
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- int major, minor;
+ int major = 0, minor = 0;
bool compile_to_sass = false;
fuser::cuda::codegenOutputQuery(prop, major, minor, compile_to_sass);
// Creates the NVRTC program
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- nvrtcProgram program;
+ nvrtcProgram program{nullptr};
AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcCreateProgram(
&program, code.c_str(), nullptr, 0, nullptr, nullptr));
@@ -1327,31 +1299,26 @@
"compute_" +
#endif
std::to_string(major) + std::to_string(minor);
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
const std::vector<const char*> args = {
"--std=c++17", compute.c_str(), "-default-device"};
#endif
auto result = nvrtc().nvrtcCompileProgram(program, args.size(), args.data());
if (result != NVRTC_SUCCESS) {
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- size_t logsize;
+ size_t logsize = 0;
AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetProgramLogSize(program, &logsize));
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::vector<char> log(logsize);
AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetProgramLog(program, log.data()));
std::stringstream cu;
- cu << log.data() << std::endl;
- cu << "nvrtc compilation failed: " << std::endl;
- cu << code << std::endl;
+ cu << log.data() << '\n';
+ cu << "nvrtc compilation failed: " << '\n';
+ cu << code << '\n';
throw std::runtime_error(cu.str());
}
ResourceGuard holdProgram(
[&] { AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcDestroyProgram(&program)); });
AT_CUDA_NVRTC_CHECK(result);
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- size_t ptx_size;
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+ size_t ptx_size = 0;
std::vector<char> ptx;
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
// compile_to_sass determines whether we are generating SASS or PTX, hence
@@ -1369,8 +1336,7 @@
ptx.resize(ptx_size);
AT_CUDA_NVRTC_CHECK(getFunc(program, ptx.data()));
- // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
- CUmodule module;
+ CUmodule module{nullptr};
AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleLoadData(&module, ptx.data()));
AT_CUDA_DRIVER_CHECK(
nvrtc().cuModuleGetFunction(&function_, module, func_name.c_str()));
diff --git a/torch/csrc/jit/tensorexpr/cuda_codegen.h b/torch/csrc/jit/tensorexpr/cuda_codegen.h
index 6bf82ff..b1daaa4 100644
--- a/torch/csrc/jit/tensorexpr/cuda_codegen.h
+++ b/torch/csrc/jit/tensorexpr/cuda_codegen.h
@@ -49,14 +49,14 @@
}
private:
- void visit(StorePtr v) override {
+ void visit(const StorePtr& v) override {
store_targets_.insert(v->buf());
}
- void visit(AllocatePtr v) override;
- void visit(FreePtr v) override;
- void visit(PlacementAllocatePtr v) override;
- void visit(ForPtr v) override;
+ void visit(const AllocatePtr& v) override;
+ void visit(const FreePtr& v) override;
+ void visit(const PlacementAllocatePtr& v) override;
+ void visit(const ForPtr& v) override;
std::unordered_set<BufPtr> store_targets_;
std::unordered_set<VarPtr> thread_local_bufs_;
@@ -162,22 +162,22 @@
}
}
- void visit(CastPtr v) override;
- void visit(IntrinsicsPtr v) override;
- void visit(ForPtr v) override;
+ void visit(const CastPtr& v) override;
+ void visit(const IntrinsicsPtr& v) override;
+ void visit(const ForPtr& v) override;
- void visit(LoadPtr v) override;
- void visit(StorePtr v) override;
- void visit(AtomicAddPtr v) override;
- void visit(MaxPtr v) override;
- void visit(MinPtr v) override;
- void visit(IfThenElsePtr v) override;
- void visit(BlockPtr v) override;
- void visit(AllocatePtr v) override;
- void visit(FreePtr v) override;
- void visit(LetPtr v) override;
+ void visit(const LoadPtr& v) override;
+ void visit(const StorePtr& v) override;
+ void visit(const AtomicAddPtr& v) override;
+ void visit(const MaxPtr& v) override;
+ void visit(const MinPtr& v) override;
+ void visit(const IfThenElsePtr& v) override;
+ void visit(const BlockPtr& v) override;
+ void visit(const AllocatePtr& v) override;
+ void visit(const FreePtr& v) override;
+ void visit(const LetPtr& v) override;
- void visit(ExternalCallPtr v) override;
+ void visit(const ExternalCallPtr& v) override;
VarPtr rand_func() const {
return rand_func_;
@@ -192,15 +192,14 @@
VarPtr rand_func_;
const CudaAnalysis* cuda_analysis_;
- void print_flat_alloc(AllocatePtr alloc);
+ void print_flat_alloc(const AllocatePtr& alloc);
};
-// Construct Cuda C from the buffer and tensor input, and invoke the kernel
-// when real arguments are provided.
+// Construct Cuda C from the buffer and tensor input, and invoke the
+// kernel when real arguments are provided.
class TORCH_CUDA_CU_API CudaCodeGen : public CodeGen {
public:
template <typename... Ts>
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
CudaCodeGen(StmtPtr stmt, Ts... ts)
: CodeGen(
stmt,
@@ -272,7 +271,7 @@
std::unique_ptr<GPUMetaVarRewriter> metavar_rewriter_;
std::unordered_set<std::string> taken_func_names;
std::mutex eval_lock_;
- CUfunction function_;
+ CUfunction function_{nullptr};
bool has_random_ = false;
int thread_block_size_ = -1;
diff --git a/torch/csrc/jit/tensorexpr/eval.cpp b/torch/csrc/jit/tensorexpr/eval.cpp
index 0d65165..e95b534 100644
--- a/torch/csrc/jit/tensorexpr/eval.cpp
+++ b/torch/csrc/jit/tensorexpr/eval.cpp
@@ -97,45 +97,45 @@
internal_buffers_.clear();
}
- TORCH_API void visit(AddPtr v) override {
+ TORCH_API void visit(const AddPtr& v) override {
visit_binary_op(v);
}
- TORCH_API void visit(SubPtr v) override {
+ TORCH_API void visit(const SubPtr& v) override {
visit_binary_op(v);
}
- TORCH_API void visit(MulPtr v) override {
+ TORCH_API void visit(const MulPtr& v) override {
visit_binary_op(v);
}
- TORCH_API void visit(DivPtr v) override {
+ TORCH_API void visit(const DivPtr& v) override {
visit_binary_op(v);
}
- TORCH_API void visit(ModPtr v) override {
+ TORCH_API void visit(const ModPtr& v) override {
visit_binary_op(v);
}
- TORCH_API void visit(MaxPtr v) override {
+ TORCH_API void visit(const MaxPtr& v) override {
visit_binary_op(v, v->propagate_nans());
}
- TORCH_API void visit(MinPtr v) override {
+ TORCH_API void visit(const MinPtr& v) override {
visit_binary_op(v, v->propagate_nans());
}
- TORCH_API void visit(AndPtr v) override {
+ TORCH_API void visit(const AndPtr& v) override {
visit_binary_op(v);
}
- TORCH_API void visit(OrPtr v) override {
+ TORCH_API void visit(const OrPtr& v) override {
visit_binary_op(v);
}
- TORCH_API void visit(XorPtr v) override {
+ TORCH_API void visit(const XorPtr& v) override {
visit_binary_op(v);
}
- TORCH_API void visit(LshiftPtr v) override {
+ TORCH_API void visit(const LshiftPtr& v) override {
visit_binary_op(v);
}
- TORCH_API void visit(RshiftPtr v) override {
+ TORCH_API void visit(const RshiftPtr& v) override {
visit_binary_op(v);
}
- void visit(CompareSelectPtr v) override {
+ void visit(const CompareSelectPtr& v) override {
visit_compare_select_op(v, v->compare_select_op());
}
@@ -416,14 +416,14 @@
}
}
-#define IMM_VISIT(Type, Name) \
- TORCH_API void visit(Name##ImmPtr v) override { \
- value_ = InterpValue(v->value()); \
+#define IMM_VISIT(Type, Name) \
+ TORCH_API void visit(const Name##ImmPtr& v) override { \
+ value_ = InterpValue(v->value()); \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT);
#undef IMM_VISIT
- TORCH_API void visit(BlockPtr v) override {
+ TORCH_API void visit(const BlockPtr& v) override {
BlockPtr last = scope_;
scope_ = v;
for (const StmtPtr& s : v->stmts()) {
@@ -441,7 +441,7 @@
scope_ = last;
}
- TORCH_API void visit(VarPtr v) override {
+ TORCH_API void visit(const VarPtr& v) override {
auto iter = eval_context_.find(v);
if (iter == eval_context_.end()) {
throw malformed_input("could not find Var in context", v);
@@ -494,7 +494,7 @@
}
}
- TORCH_API void visit(CastPtr v) override {
+ TORCH_API void visit(const CastPtr& v) override {
ExprPtr src_value = v->src_value();
src_value->accept(this);
Dtype dst_dtype = v->dtype();
@@ -549,7 +549,7 @@
}
}
- TORCH_API void visit(BitCastPtr v) override {
+ TORCH_API void visit(const BitCastPtr& v) override {
ExprPtr src_value = v->src_value();
src_value->accept(this);
Dtype dst_dtype = v->dtype();
@@ -572,7 +572,7 @@
}
}
- TORCH_API void visit(ForPtr v) override {
+ TORCH_API void visit(const ForPtr& v) override {
ExprPtr var_node = v->var();
v->start()->accept(this);
auto dtype = value_.dtype();
@@ -592,14 +592,14 @@
eval_context_.erase(var_node);
}
- TORCH_API void visit(RampPtr v) override {
+ TORCH_API void visit(const RampPtr& v) override {
v->base()->accept(this);
auto base = value().intValue();
v->stride()->accept(this);
auto stride = value().intValue();
int lanes = v->lanes();
- std::vector<int> values(lanes);
+ std::vector<int64_t> values(lanes);
for (const auto i : c10::irange(lanes)) {
values[i] = base + i * stride;
}
@@ -607,7 +607,7 @@
value_ = InterpValue(values);
}
- TORCH_API void visit(BroadcastPtr v) override {
+ TORCH_API void visit(const BroadcastPtr& v) override {
v->value()->accept(this);
InterpValue value = this->value();
int lanes = v->lanes();
@@ -624,7 +624,7 @@
}
}
- TORCH_API void visit(IfThenElsePtr v) override {
+ TORCH_API void visit(const IfThenElsePtr& v) override {
v->condition()->accept(this);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
bool cond_v;
@@ -728,7 +728,7 @@
}
}
- TORCH_API void visit(LoadPtr v) override {
+ TORCH_API void visit(const LoadPtr& v) override {
auto iter = buffer_mapping_.find(v->buf());
if (iter == buffer_mapping_.end()) {
throw malformed_input("could not find base node in Load", v);
@@ -772,7 +772,7 @@
}
}
- TORCH_API void visit(StorePtr v) override {
+ TORCH_API void visit(const StorePtr& v) override {
auto iter = buffer_mapping_.find(v->buf());
if (iter == buffer_mapping_.end()) {
throw malformed_input("could not find base node in Store", v);
@@ -821,7 +821,7 @@
}
}
- void visit(ExternalCallPtr v) override {
+ void visit(const ExternalCallPtr& v) override {
auto& func_registry = getNNCFunctionRegistry();
if (!func_registry.count(v->func_name())) {
throw unimplemented_lowering(v);
@@ -893,7 +893,7 @@
extra_args.data());
}
- void visit(ExternalCallWithAllocPtr v) override {
+ void visit(const ExternalCallWithAllocPtr& v) override {
auto& func_registry = getNNCFunctionRegistry();
if (!func_registry.count(v->func_name())) {
throw unimplemented_lowering(v);
@@ -955,14 +955,12 @@
auto fn_ptr = func_registry.at(v->func_name());
(*fn_ptr)(
- // @lint-ignore CLANGTIDY
bufs_in_size,
buf_ptrs.data(),
buf_ranks.data(),
buf_dims.data(),
buf_strides.data(),
buf_dtypes.data(),
- // @lint-ignore CLANGTIDY
extra_args.size(),
extra_args.data());
@@ -974,7 +972,7 @@
}
template <typename TReturn, typename TInput>
- void visit_intrinsics_helper(IntrinsicsPtr v) {
+ void visit_intrinsics_helper(const IntrinsicsPtr& v) {
std::vector<InterpValue> values(v->nparams());
for (const auto i : c10::irange(v->nparams())) {
v->param(i)->accept(this);
@@ -1009,7 +1007,7 @@
value_ = InterpValue(result);
}
- TORCH_API void visit(IntrinsicsPtr v) override {
+ TORCH_API void visit(const IntrinsicsPtr& v) override {
auto ty = v->dtype().scalar_type();
if (v->op_type() == kIsNan) {
auto inp_dtype = v->params().at(0)->dtype().scalar_type();
@@ -1036,7 +1034,7 @@
}
}
- void visit(AllocatePtr v) override {
+ void visit(const AllocatePtr& v) override {
BufPtr b = v->buf();
std::vector<ExprPtr> dims = b->dims();
int64_t total_byte_size = b->dtype().byte_size();
@@ -1058,14 +1056,14 @@
internal_buffers_.insert(std::make_pair(b, std::move(buffer)));
}
- void visit(PlacementAllocatePtr v) override {
+ void visit(const PlacementAllocatePtr& v) override {
buffer_mapping_[v->buf()] = buffer_mapping_.at(v->buf_to_reuse());
}
- void visit(FreePtr v) override {
+ void visit(const FreePtr& v) override {
BufPtr b = v->buf();
GRAPH_DEBUG("FREE: buf=", v->buf()->name_hint());
- int count = internal_buffers_.erase(b);
+ auto count = internal_buffers_.erase(b);
if (count == 0) {
throw std::runtime_error(
"Free a buffer that is not currently bound: " +
@@ -1074,7 +1072,7 @@
buffer_mapping_.erase(b);
}
- void visit(FreeExtPtr v) override {
+ void visit(const FreeExtPtr& v) override {
const auto& bufs = v->bufs();
const auto bufs_num = bufs.size();
std::vector<void*> buf_ptrs;
@@ -1089,12 +1087,12 @@
nnc_aten_free(bufs_num, buf_ptrs.data());
}
- void visit(LetPtr v) override {
+ void visit(const LetPtr& v) override {
var_by_scope_[scope_].push_back(v->var());
bindVar(v->var(), evaluateExpr(v->value()));
}
- void visit(CondPtr v) override {
+ void visit(const CondPtr& v) override {
v->condition()->accept(this);
if (value().intValue()) {
if (v->true_stmt()) {
diff --git a/torch/csrc/jit/tensorexpr/exceptions.h b/torch/csrc/jit/tensorexpr/exceptions.h
index b5e656f..d696877 100644
--- a/torch/csrc/jit/tensorexpr/exceptions.h
+++ b/torch/csrc/jit/tensorexpr/exceptions.h
@@ -3,7 +3,6 @@
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/tensorexpr/fwd_decls.h>
-#include <sstream>
#include <stdexcept>
// Forward declarations of types
@@ -18,8 +17,8 @@
// Forward declarations of functions
namespace std {
-TORCH_API std::string to_string(const torch::jit::tensorexpr::ExprPtr);
-TORCH_API std::string to_string(const torch::jit::tensorexpr::StmtPtr);
+TORCH_API std::string to_string(const torch::jit::tensorexpr::ExprPtr&);
+TORCH_API std::string to_string(const torch::jit::tensorexpr::StmtPtr&);
} // namespace std
namespace torch {
diff --git a/torch/csrc/jit/tensorexpr/half_support.h b/torch/csrc/jit/tensorexpr/half_support.h
index 8ec41fe..dc46544 100644
--- a/torch/csrc/jit/tensorexpr/half_support.h
+++ b/torch/csrc/jit/tensorexpr/half_support.h
@@ -26,27 +26,27 @@
return hasBFloat16_;
}
- void visit(LoadPtr v) override {
+ void visit(const LoadPtr& v) override {
hasHalf_ |= v->dtype().scalar_type() == ScalarType::Half;
hasBFloat16_ |= v->dtype().scalar_type() == ScalarType::BFloat16;
IRVisitor::visit(v);
}
- void visit(StorePtr v) override {
+ void visit(const StorePtr& v) override {
hasHalf_ |= v->buf()->dtype().scalar_type() == ScalarType::Half;
hasBFloat16_ |= v->buf()->dtype().scalar_type() == ScalarType::BFloat16;
IRVisitor::visit(v);
}
- void visit(HalfImmPtr v) override {
+ void visit(const HalfImmPtr& v) override {
hasHalf_ = true;
}
- void visit(BFloat16ImmPtr v) override {
+ void visit(const BFloat16ImmPtr& v) override {
hasBFloat16_ = true;
}
- void visit(CastPtr v) override {
+ void visit(const CastPtr& v) override {
hasHalf_ |= v->dtype().scalar_type() == ScalarType::Half;
hasBFloat16_ |= v->dtype().scalar_type() == ScalarType::BFloat16;
IRVisitor::visit(v);
diff --git a/torch/csrc/jit/tensorexpr/hash_provider.cpp b/torch/csrc/jit/tensorexpr/hash_provider.cpp
index 9b8513b..e687d78 100644
--- a/torch/csrc/jit/tensorexpr/hash_provider.cpp
+++ b/torch/csrc/jit/tensorexpr/hash_provider.cpp
@@ -26,98 +26,98 @@
return _h != other;
}
-void HashProvider::visit(AddPtr v) {
+void HashProvider::visit(const AddPtr& v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "+", hashOf(v->rhs())));
}
-void HashProvider::visit(SubPtr v) {
+void HashProvider::visit(const SubPtr& v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "-", hashOf(v->rhs())));
}
-void HashProvider::visit(MulPtr v) {
+void HashProvider::visit(const MulPtr& v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "*", hashOf(v->rhs())));
}
-void HashProvider::visit(DivPtr v) {
+void HashProvider::visit(const DivPtr& v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "/", hashOf(v->rhs())));
}
-void HashProvider::visit(ModPtr v) {
+void HashProvider::visit(const ModPtr& v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "%", hashOf(v->rhs())));
}
-void HashProvider::visit(RoundOffPtr v) {
+void HashProvider::visit(const RoundOffPtr& v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "rof", hashOf(v->rhs())));
}
-void HashProvider::visit(MaxPtr v) {
+void HashProvider::visit(const MaxPtr& v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "Mx", hashOf(v->rhs())));
}
-void HashProvider::visit(MinPtr v) {
+void HashProvider::visit(const MinPtr& v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "Mn", hashOf(v->rhs())));
}
-void HashProvider::visit(AndPtr v) {
+void HashProvider::visit(const AndPtr& v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "&", hashOf(v->rhs())));
}
-void HashProvider::visit(OrPtr v) {
+void HashProvider::visit(const OrPtr& v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "|", hashOf(v->rhs())));
}
-void HashProvider::visit(XorPtr v) {
+void HashProvider::visit(const XorPtr& v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "^", hashOf(v->rhs())));
}
-void HashProvider::visit(LshiftPtr v) {
+void HashProvider::visit(const LshiftPtr& v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), "<<", hashOf(v->rhs())));
}
-void HashProvider::visit(RshiftPtr v) {
+void HashProvider::visit(const RshiftPtr& v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
putHash(v, hash_combine(hashOf(v->lhs()), ">>", hashOf(v->rhs())));
}
-void HashProvider::visit(CompareSelectPtr v) {
+void HashProvider::visit(const CompareSelectPtr& v) {
CACHE_GUARD();
v->lhs()->accept(this);
v->rhs()->accept(this);
@@ -133,18 +133,18 @@
hashOf(v->ret_val2())));
}
-void HashProvider::visit(CastPtr v) {
+void HashProvider::visit(const CastPtr& v) {
CACHE_GUARD();
v->src_value()->accept(this);
putHash(v, hash_combine("cast", v->dtype(), hashOf(v->src_value())));
}
-void HashProvider::visit(VarPtr v) {
+void HashProvider::visit(const VarPtr& v) {
CACHE_GUARD();
putHash(v, hash_combine("var", name_manager_.get_unique_name(v)));
}
-void HashProvider::visit(RampPtr v) {
+void HashProvider::visit(const RampPtr& v) {
CACHE_GUARD();
v->base()->accept(this);
v->stride()->accept(this);
@@ -153,7 +153,7 @@
hash_combine("ramp", hashOf(v->base()), hashOf(v->stride()), v->lanes()));
}
-void HashProvider::visit(LoadPtr v) {
+void HashProvider::visit(const LoadPtr& v) {
CACHE_GUARD();
v->base_handle()->accept(this);
SimplifierHashType indices_hash;
@@ -164,7 +164,7 @@
putHash(v, hash_combine("load", hashOf(v->base_handle()), indices_hash));
}
-void HashProvider::visit(StorePtr v) {
+void HashProvider::visit(const StorePtr& v) {
CACHE_GUARD();
v->base_handle()->accept(this);
SimplifierHashType indices_hash;
@@ -179,7 +179,7 @@
"store", hashOf(v->base_handle()), indices_hash, hashOf(v->value())));
}
-void HashProvider::visit(BlockPtr v) {
+void HashProvider::visit(const BlockPtr& v) {
CACHE_GUARD();
SimplifierHashType hash;
@@ -190,7 +190,7 @@
putHash(v, hash);
}
-void HashProvider::visit(ForPtr v) {
+void HashProvider::visit(const ForPtr& v) {
CACHE_GUARD();
v->var()->accept(this);
v->start()->accept(this);
@@ -207,13 +207,13 @@
putHash(v, hash);
}
-void HashProvider::visit(BroadcastPtr v) {
+void HashProvider::visit(const BroadcastPtr& v) {
CACHE_GUARD();
v->value()->accept(this);
putHash(v, hash_combine("broadcast", hashOf(v->value()), v->lanes()));
}
-void HashProvider::visit(IfThenElsePtr v) {
+void HashProvider::visit(const IfThenElsePtr& v) {
CACHE_GUARD();
v->condition()->accept(this);
v->true_value()->accept(this);
@@ -228,7 +228,7 @@
hashOf(v->false_value())));
}
-void HashProvider::visit(IntrinsicsPtr v) {
+void HashProvider::visit(const IntrinsicsPtr& v) {
CACHE_GUARD();
// calls to rand are not symbolic and have a different value each time, they
// should not hash to anything and this is the best we can do.
@@ -247,7 +247,7 @@
putHash(v, hash);
}
-void HashProvider::visit(AllocatePtr v) {
+void HashProvider::visit(const AllocatePtr& v) {
CACHE_GUARD();
VarPtr buffer_var = v->buffer_var();
buffer_var->accept(this);
@@ -263,7 +263,7 @@
putHash(v, hash);
}
-void HashProvider::visit(FreePtr v) {
+void HashProvider::visit(const FreePtr& v) {
CACHE_GUARD();
VarPtr buffer_var = v->buffer_var();
buffer_var->accept(this);
@@ -271,7 +271,7 @@
putHash(v, hash_combine("free", hashOf(buffer_var)));
}
-void HashProvider::visit(CondPtr v) {
+void HashProvider::visit(const CondPtr& v) {
CACHE_GUARD();
ExprPtr condition = v->condition();
StmtPtr true_stmt = v->true_stmt();
@@ -291,7 +291,7 @@
putHash(v, hash);
}
-void HashProvider::visit(TermPtr v) {
+void HashProvider::visit(const TermPtr& v) {
CACHE_GUARD();
v->scalar()->accept(this);
@@ -304,7 +304,7 @@
putHash(v, hash);
}
-void HashProvider::visit(PolynomialPtr v) {
+void HashProvider::visit(const PolynomialPtr& v) {
CACHE_GUARD();
v->scalar()->accept(this);
@@ -317,7 +317,7 @@
putHash(v, hash);
}
-void HashProvider::visit(MaxTermPtr v) {
+void HashProvider::visit(const MaxTermPtr& v) {
CACHE_GUARD();
SimplifierHashType hash = hash_combine("maxterm");
if (v->scalar()) {
@@ -333,7 +333,7 @@
putHash(v, hash);
}
-void HashProvider::visit(MinTermPtr v) {
+void HashProvider::visit(const MinTermPtr& v) {
CACHE_GUARD();
SimplifierHashType hash = hash_combine("minterm");
if (v->scalar()) {
diff --git a/torch/csrc/jit/tensorexpr/hash_provider.h b/torch/csrc/jit/tensorexpr/hash_provider.h
index c160661..69b360c 100644
--- a/torch/csrc/jit/tensorexpr/hash_provider.h
+++ b/torch/csrc/jit/tensorexpr/hash_provider.h
@@ -7,9 +7,7 @@
#include <utility>
-namespace torch {
-namespace jit {
-namespace tensorexpr {
+namespace torch::jit::tensorexpr {
struct TORCH_API SimplifierHashType {
SimplifierHashType() = default;
@@ -24,9 +22,7 @@
size_t _h{0};
};
-} // namespace tensorexpr
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::tensorexpr
namespace std {
template <>
@@ -38,9 +34,7 @@
} // namespace std
-namespace torch {
-namespace jit {
-namespace tensorexpr {
+namespace torch::jit::tensorexpr {
#define CACHE_GUARD() \
if (cachedHash(v)) { \
@@ -61,10 +55,10 @@
return hashOf(e);
}
- bool cachedHash(ExprPtr e) {
+ bool cachedHash(const ExprPtr& e) {
return exprToHash_.find(e) != exprToHash_.end();
}
- bool cachedHash(StmtPtr s) {
+ bool cachedHash(const StmtPtr& s) {
return stmtToHash_.find(s) != stmtToHash_.end();
}
@@ -73,47 +67,46 @@
stmtToHash_.clear();
}
- void visit(AddPtr v) override;
- void visit(SubPtr v) override;
- void visit(MulPtr v) override;
- void visit(DivPtr v) override;
- void visit(ModPtr v) override;
- void visit(RoundOffPtr v) override;
- void visit(MaxPtr v) override;
- void visit(MinPtr v) override;
- void visit(AndPtr v) override;
- void visit(OrPtr v) override;
- void visit(XorPtr v) override;
- void visit(LshiftPtr v) override;
- void visit(RshiftPtr v) override;
- void visit(CompareSelectPtr v) override;
+ void visit(const AddPtr& v) override;
+ void visit(const SubPtr& v) override;
+ void visit(const MulPtr& v) override;
+ void visit(const DivPtr& v) override;
+ void visit(const ModPtr& v) override;
+ void visit(const RoundOffPtr& v) override;
+ void visit(const MaxPtr& v) override;
+ void visit(const MinPtr& v) override;
+ void visit(const AndPtr& v) override;
+ void visit(const OrPtr& v) override;
+ void visit(const XorPtr& v) override;
+ void visit(const LshiftPtr& v) override;
+ void visit(const RshiftPtr& v) override;
+ void visit(const CompareSelectPtr& v) override;
-// NOLINTNEXTLINE
#define IMM_VISIT(Type, Name) \
- void visit(Name##ImmPtr v) override { \
+ void visit(const Name##ImmPtr& v) override { \
CACHE_GUARD(); \
putHash(v, hash_combine(#Name, v->value())); \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT);
#undef IMM_VISIT
- void visit(CastPtr v) override;
- void visit(VarPtr v) override;
- void visit(RampPtr v) override;
- void visit(LoadPtr v) override;
- void visit(StorePtr v) override;
- void visit(BlockPtr v) override;
- void visit(ForPtr v) override;
- void visit(BroadcastPtr v) override;
- void visit(IfThenElsePtr v) override;
- void visit(IntrinsicsPtr v) override;
- void visit(AllocatePtr v) override;
- void visit(FreePtr v) override;
- void visit(CondPtr v) override;
- void visit(TermPtr v) override;
- void visit(PolynomialPtr v) override;
- void visit(MaxTermPtr v) override;
- void visit(MinTermPtr v) override;
+ void visit(const CastPtr& v) override;
+ void visit(const VarPtr& v) override;
+ void visit(const RampPtr& v) override;
+ void visit(const LoadPtr& v) override;
+ void visit(const StorePtr& v) override;
+ void visit(const BlockPtr& v) override;
+ void visit(const ForPtr& v) override;
+ void visit(const BroadcastPtr& v) override;
+ void visit(const IfThenElsePtr& v) override;
+ void visit(const IntrinsicsPtr& v) override;
+ void visit(const AllocatePtr& v) override;
+ void visit(const FreePtr& v) override;
+ void visit(const CondPtr& v) override;
+ void visit(const TermPtr& v) override;
+ void visit(const PolynomialPtr& v) override;
+ void visit(const MaxTermPtr& v) override;
+ void visit(const MinTermPtr& v) override;
template <typename... Types>
SimplifierHashType hash_combine(const Types&... args) {
@@ -189,14 +182,14 @@
_hash_combine(seed, args...);
}
- void putHash(ExprPtr e, SimplifierHashType h) {
+ void putHash(const ExprPtr& e, SimplifierHashType h) {
auto res = exprToHash_.emplace(e, h);
if (res.second == false) {
// This is always a logic bug since we should check the cache first.
throw std::runtime_error("hash collision");
}
}
- void putHash(StmtPtr s, SimplifierHashType h) {
+ void putHash(const StmtPtr& s, SimplifierHashType h) {
auto res = stmtToHash_.emplace(s, h);
if (res.second == false) {
// This is always a logic bug since we should check the cache first.
@@ -254,7 +247,7 @@
if (s < 0)
break;
// NOLINTNEXTLINE(bugprone-signed-char-misuse)
- int64_t c = val.data()[s];
+ int64_t c = val[s];
intval |= (c << (i * 8));
s--;
@@ -299,6 +292,4 @@
}
};
-} // namespace tensorexpr
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::tensorexpr
diff --git a/torch/csrc/jit/tensorexpr/ir_cloner.cpp b/torch/csrc/jit/tensorexpr/ir_cloner.cpp
index 1b2bac1..f45abf0 100644
--- a/torch/csrc/jit/tensorexpr/ir_cloner.cpp
+++ b/torch/csrc/jit/tensorexpr/ir_cloner.cpp
@@ -10,9 +10,9 @@
template <
typename Op,
- typename std::enable_if<std::is_same<
+ std::enable_if_t<std::is_same_v<
decltype(detail::bin_op_deducer(std::declval<Op>())),
- void>::value>::type* = nullptr>
+ void>>* = nullptr>
static ExprPtr mutate_binary_op(
NodePtr<Op> v,
IRCloner* cloner,
diff --git a/torch/csrc/jit/tensorexpr/ir_printer.cpp b/torch/csrc/jit/tensorexpr/ir_printer.cpp
index fc9e811..2300102 100644
--- a/torch/csrc/jit/tensorexpr/ir_printer.cpp
+++ b/torch/csrc/jit/tensorexpr/ir_printer.cpp
@@ -80,43 +80,43 @@
}
}
-void IRPrinter::visit(AddPtr v) {
+void IRPrinter::visit(const AddPtr& v) {
visitBinaryOp(v, "+", this);
}
-void IRPrinter::visit(SubPtr v) {
+void IRPrinter::visit(const SubPtr& v) {
visitBinaryOp(v, "-", this);
}
-void IRPrinter::visit(MulPtr v) {
+void IRPrinter::visit(const MulPtr& v) {
visitBinaryOp(v, "*", this);
}
-void IRPrinter::visit(DivPtr v) {
+void IRPrinter::visit(const DivPtr& v) {
visitBinaryOp(v, "/", this);
}
-void IRPrinter::visit(AndPtr v) {
+void IRPrinter::visit(const AndPtr& v) {
visitBinaryOp(v, "&", this);
}
-void IRPrinter::visit(OrPtr v) {
+void IRPrinter::visit(const OrPtr& v) {
visitBinaryOp(v, "|", this);
}
-void IRPrinter::visit(XorPtr v) {
+void IRPrinter::visit(const XorPtr& v) {
visitBinaryOp(v, "^", this);
}
-void IRPrinter::visit(LshiftPtr v) {
+void IRPrinter::visit(const LshiftPtr& v) {
visitBinaryOp(v, "<<", this);
}
-void IRPrinter::visit(RshiftPtr v) {
+void IRPrinter::visit(const RshiftPtr& v) {
visitBinaryOp(v, ">>", this);
}
-void IRPrinter::visit(ModPtr v) {
+void IRPrinter::visit(const ModPtr& v) {
if (v->dtype().is_integral()) {
visitBinaryOp(v, "%", this);
} else if (v->dtype().is_floating_point()) {
@@ -126,7 +126,7 @@
}
}
-void IRPrinter::visit(MaxPtr v) {
+void IRPrinter::visit(const MaxPtr& v) {
os() << "Max(";
v->lhs()->accept(this);
os() << ", ";
@@ -134,7 +134,7 @@
os() << ", " << (unsigned int)v->propagate_nans() << ")";
}
-void IRPrinter::visit(MinPtr v) {
+void IRPrinter::visit(const MinPtr& v) {
os() << "Min(";
v->lhs()->accept(this);
os() << ", ";
@@ -142,7 +142,7 @@
os() << ", " << (unsigned int)v->propagate_nans() << ")";
}
-void IRPrinter::visit(CompareSelectPtr v) {
+void IRPrinter::visit(const CompareSelectPtr& v) {
CompareSelectOperation cmp_op = v->compare_select_op();
int self_prec = getPrecedence(v->expr_type());
int lhs_prec = getPrecedence(v->lhs()->expr_type());
@@ -222,32 +222,32 @@
}
// NOLINTNEXTLINE
-#define IMM_PRINT_VISIT(Type, Name) \
- void IRPrinter::visit(Name##ImmPtr v) { \
- formatImm(os(), v->value()); \
+#define IMM_PRINT_VISIT(Type, Name) \
+ void IRPrinter::visit(const Name##ImmPtr& v) { \
+ formatImm(os(), v->value()); \
}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_PRINT_VISIT);
#undef IMM_PRINT_VISIT
-void IRPrinter::visit(CastPtr v) {
+void IRPrinter::visit(const CastPtr& v) {
auto dtype = v->dtype();
os() << dtypeToCppString(dtype) << "(";
v->src_value()->accept(this);
os() << ")";
}
-void IRPrinter::visit(BitCastPtr v) {
+void IRPrinter::visit(const BitCastPtr& v) {
auto dtype = v->dtype();
os() << "BitCast<" << dtype.ToCppString() << ">(";
v->src_value()->accept(this);
os() << ")";
}
-void IRPrinter::visit(VarPtr v) {
+void IRPrinter::visit(const VarPtr& v) {
os() << name_manager_.get_unique_name(v);
}
-void IRPrinter::visit(BufPtr v) {
+void IRPrinter::visit(const BufPtr& v) {
auto dtype = v->dtype();
os() << *v->base_handle();
os() << "(dtype=" << dtypeToCppString(dtype);
@@ -281,12 +281,12 @@
os() << ")";
}
-void IRPrinter::visit(RampPtr v) {
+void IRPrinter::visit(const RampPtr& v) {
os() << "Ramp(" << *v->base() << ", " << *v->stride() << ", " << v->lanes()
<< ")";
}
-void IRPrinter::visit(LoadPtr v) {
+void IRPrinter::visit(const LoadPtr& v) {
// TODO: support the mask case
if (v->indices().empty()) {
os() << *v->base_handle();
@@ -306,16 +306,16 @@
}
}
-void IRPrinter::visit(BroadcastPtr v) {
+void IRPrinter::visit(const BroadcastPtr& v) {
os() << "Broadcast(" << *v->value() << ", " << v->lanes() << ")";
}
-void IRPrinter::visit(IfThenElsePtr v) {
+void IRPrinter::visit(const IfThenElsePtr& v) {
os() << "IfThenElse(" << *v->condition() << ", " << *v->true_value() << ", "
<< *v->false_value() << ")";
}
-void IRPrinter::visit(IntrinsicsPtr v) {
+void IRPrinter::visit(const IntrinsicsPtr& v) {
os() << v->func_name() << "(";
for (const auto i : c10::irange(v->nparams())) {
if (i > 0) {
@@ -326,7 +326,7 @@
os() << ")";
}
-void IRPrinter::visit(TermPtr v) {
+void IRPrinter::visit(const TermPtr& v) {
os() << "Term(";
v->scalar()->accept(this);
for (const auto& t : v->variables()) {
@@ -336,7 +336,7 @@
os() << ")";
}
-void IRPrinter::visit(PolynomialPtr v) {
+void IRPrinter::visit(const PolynomialPtr& v) {
bool first = true;
os() << "Polynomial(";
for (const auto& t : v->variables()) {
@@ -354,7 +354,7 @@
os() << ")";
}
-void IRPrinter::visit(RoundOffPtr v) {
+void IRPrinter::visit(const RoundOffPtr& v) {
os() << "RoundOff(";
v->lhs()->accept(this);
os() << ", ";
@@ -362,7 +362,7 @@
os() << ")";
}
-void IRPrinter::visit(MaxTermPtr v) {
+void IRPrinter::visit(const MaxTermPtr& v) {
os() << "MaxTerm(";
if (v->scalar()) {
v->scalar()->accept(this);
@@ -377,7 +377,7 @@
os() << ")";
}
-void IRPrinter::visit(MinTermPtr v) {
+void IRPrinter::visit(const MinTermPtr& v) {
os() << "MinTerm(";
if (v->scalar()) {
v->scalar()->accept(this);
@@ -392,7 +392,7 @@
os() << ")";
}
-void IRPrinter::visit(ReduceOpPtr v) {
+void IRPrinter::visit(const ReduceOpPtr& v) {
os() << "ReduceOp(";
os() << *v->body() << ", ";
@@ -414,7 +414,7 @@
// each statement in a `Block` the printer will insert indentation before
// the statement and a newline after the statement.
-void IRPrinter::visit(StorePtr v) {
+void IRPrinter::visit(const StorePtr& v) {
// TODO: handle the mask
if (v->indices().empty()) {
os() << *v->base_handle() << " = " << *v->value() << ";";
@@ -435,7 +435,7 @@
os() << "] = " << *v->value() << ";";
}
-void IRPrinter::visit(ForPtr v) {
+void IRPrinter::visit(const ForPtr& v) {
VarPtr var = v->var();
VarHandle vv(var);
os() << "for (" << dtypeToCppString(var->dtype()) << " " << vv << " = "
@@ -452,7 +452,7 @@
}
}
-void IRPrinter::visit(BlockPtr v) {
+void IRPrinter::visit(const BlockPtr& v) {
os() << "{\n";
indent_++;
@@ -465,7 +465,7 @@
os() << "}";
}
-void IRPrinter::visit(AllocatePtr v) {
+void IRPrinter::visit(const AllocatePtr& v) {
os() << "Allocate(" << *v->buffer_var()
<< "); // dtype=" << dtypeToCppString(v->dtype());
os() << ", dims=[";
@@ -479,11 +479,11 @@
os() << "]";
}
-void IRPrinter::visit(FreePtr v) {
+void IRPrinter::visit(const FreePtr& v) {
os() << "Free(" << *v->buffer_var() << ");";
}
-void IRPrinter::visit(FreeExtPtr v) {
+void IRPrinter::visit(const FreeExtPtr& v) {
os() << "FreeExt(bufs={";
int i = 0;
for (const auto& buf : v->bufs()) {
@@ -496,17 +496,17 @@
os() << "});";
}
-void IRPrinter::visit(PlacementAllocatePtr v) {
+void IRPrinter::visit(const PlacementAllocatePtr& v) {
os() << "Alias(" << *v->buf()->base_handle() << ","
<< *v->buf_to_reuse()->base_handle() << ");";
}
-void IRPrinter::visit(LetPtr v) {
+void IRPrinter::visit(const LetPtr& v) {
os() << dtypeToCppString(v->var()->dtype()) << " " << *v->var();
os() << " = " << *v->value() << ";";
}
-void IRPrinter::visit(CondPtr v) {
+void IRPrinter::visit(const CondPtr& v) {
ExprPtr cond = v->condition();
StmtPtr true_stmt = v->true_stmt();
StmtPtr false_stmt = v->false_stmt();
@@ -523,7 +523,7 @@
}
}
-void IRPrinter::visit(AtomicAddPtr v) {
+void IRPrinter::visit(const AtomicAddPtr& v) {
os() << "atomicAdd(&" << *v->base_handle() << "[";
size_t i = 0;
for (const ExprPtr& ind : v->indices()) {
@@ -538,11 +538,11 @@
os() << "], " << *v->value() << ");";
}
-void IRPrinter::visit(SyncThreadsPtr v) {
+void IRPrinter::visit(const SyncThreadsPtr& v) {
os() << "__syncthreads();";
}
-void IRPrinter::visit(ExternalCallPtr v) {
+void IRPrinter::visit(const ExternalCallPtr& v) {
os() << *v->buf() << " = " << v->func_name() << "(";
os() << "buf_args={";
@@ -565,7 +565,7 @@
os() << "})";
}
-void IRPrinter::visit(ExternalCallWithAllocPtr v) {
+void IRPrinter::visit(const ExternalCallWithAllocPtr& v) {
int i = 0;
for (const auto& buf_out_arg : v->buf_out_args()) {
if (i++ > 0) {
@@ -670,13 +670,13 @@
} // namespace torch::jit::tensorexpr
namespace std {
-std::string to_string(ExprPtr expr) {
+std::string to_string(const ExprPtr& expr) {
std::ostringstream oss;
oss << *expr;
return oss.str();
}
-std::string to_string(StmtPtr stmt) {
+std::string to_string(const StmtPtr& stmt) {
std::ostringstream oss;
oss << *stmt;
return oss.str();
diff --git a/torch/csrc/jit/tensorexpr/ir_printer.h b/torch/csrc/jit/tensorexpr/ir_printer.h
index 8eff1ab..ad0345e 100644
--- a/torch/csrc/jit/tensorexpr/ir_printer.h
+++ b/torch/csrc/jit/tensorexpr/ir_printer.h
@@ -20,51 +20,51 @@
void print(ExprHandle);
void print(Expr&);
void print(Stmt&);
- void visit(AddPtr v) override;
- void visit(SubPtr v) override;
- void visit(MulPtr v) override;
- void visit(DivPtr v) override;
- void visit(ModPtr v) override;
- void visit(MaxPtr v) override;
- void visit(MinPtr v) override;
- void visit(AndPtr v) override;
- void visit(OrPtr v) override;
- void visit(XorPtr v) override;
- void visit(LshiftPtr v) override;
- void visit(RshiftPtr v) override;
- void visit(CompareSelectPtr v) override;
-#define IMM_PRINT_VISIT(Type, Name) void visit(Name##ImmPtr v) override;
+ void visit(const AddPtr& v) override;
+ void visit(const SubPtr& v) override;
+ void visit(const MulPtr& v) override;
+ void visit(const DivPtr& v) override;
+ void visit(const ModPtr& v) override;
+ void visit(const MaxPtr& v) override;
+ void visit(const MinPtr& v) override;
+ void visit(const AndPtr& v) override;
+ void visit(const OrPtr& v) override;
+ void visit(const XorPtr& v) override;
+ void visit(const LshiftPtr& v) override;
+ void visit(const RshiftPtr& v) override;
+ void visit(const CompareSelectPtr& v) override;
+#define IMM_PRINT_VISIT(Type, Name) void visit(const Name##ImmPtr& v) override;
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_PRINT_VISIT);
#undef IMM_PRINT_VISIT
- void visit(CastPtr v) override;
- void visit(BitCastPtr v) override;
- void visit(VarPtr v) override;
- void visit(BufPtr v) override;
- void visit(RampPtr v) override;
- void visit(LoadPtr v) override;
- void visit(BroadcastPtr v) override;
- void visit(IfThenElsePtr v) override;
- void visit(IntrinsicsPtr v) override;
- void visit(TermPtr v) override;
- void visit(PolynomialPtr v) override;
- void visit(RoundOffPtr v) override;
- void visit(MaxTermPtr v) override;
- void visit(MinTermPtr v) override;
- void visit(ReduceOpPtr v) override;
+ void visit(const CastPtr& v) override;
+ void visit(const BitCastPtr& v) override;
+ void visit(const VarPtr& v) override;
+ void visit(const BufPtr& v) override;
+ void visit(const RampPtr& v) override;
+ void visit(const LoadPtr& v) override;
+ void visit(const BroadcastPtr& v) override;
+ void visit(const IfThenElsePtr& v) override;
+ void visit(const IntrinsicsPtr& v) override;
+ void visit(const TermPtr& v) override;
+ void visit(const PolynomialPtr& v) override;
+ void visit(const RoundOffPtr& v) override;
+ void visit(const MaxTermPtr& v) override;
+ void visit(const MinTermPtr& v) override;
+ void visit(const ReduceOpPtr& v) override;
- void visit(AtomicAddPtr v) override;
- void visit(SyncThreadsPtr v) override;
- void visit(ExternalCallPtr v) override;
- void visit(ExternalCallWithAllocPtr v) override;
- void visit(StorePtr v) override;
- void visit(ForPtr v) override;
- void visit(CondPtr v) override;
- void visit(BlockPtr v) override;
- void visit(AllocatePtr v) override;
- void visit(FreePtr v) override;
- void visit(FreeExtPtr v) override;
- void visit(PlacementAllocatePtr v) override;
- void visit(LetPtr v) override;
+ void visit(const AtomicAddPtr& v) override;
+ void visit(const SyncThreadsPtr& v) override;
+ void visit(const ExternalCallPtr& v) override;
+ void visit(const ExternalCallWithAllocPtr& v) override;
+ void visit(const StorePtr& v) override;
+ void visit(const ForPtr& v) override;
+ void visit(const CondPtr& v) override;
+ void visit(const BlockPtr& v) override;
+ void visit(const AllocatePtr& v) override;
+ void visit(const FreePtr& v) override;
+ void visit(const FreeExtPtr& v) override;
+ void visit(const PlacementAllocatePtr& v) override;
+ void visit(const LetPtr& v) override;
// A child class may have a difference rule for generating dtype
// string, e.g. CUDA needs int64_t to be generated as long long.
@@ -124,7 +124,7 @@
using torch::jit::tensorexpr::StmtPtr;
using torch::jit::tensorexpr::Tensor;
-TORCH_API std::string to_string(ExprPtr expr);
-TORCH_API std::string to_string(StmtPtr stmt);
+TORCH_API std::string to_string(const ExprPtr& expr);
+TORCH_API std::string to_string(const StmtPtr& stmt);
TORCH_API std::string to_string(const Tensor& t);
} // namespace std
diff --git a/torch/csrc/jit/tensorexpr/ir_verifier.cpp b/torch/csrc/jit/tensorexpr/ir_verifier.cpp
index cc75694..109e3a8 100644
--- a/torch/csrc/jit/tensorexpr/ir_verifier.cpp
+++ b/torch/csrc/jit/tensorexpr/ir_verifier.cpp
@@ -16,9 +16,9 @@
template <
typename D,
- typename std::enable_if<std::is_same<
- decltype(detail::deducer(std::declval<D>())),
- void>::value>::type* = nullptr>
+ std::enable_if_t<
+ std::is_same_v<decltype(detail::deducer(std::declval<D>())), void>>* =
+ nullptr>
void verifyBitwiseOp(NodePtr<D> v, IRVerifier* verifier) {
if (!v->lhs()->dtype().is_integral()) {
throw unsupported_dtype();
@@ -28,39 +28,39 @@
}
}
-void IRVerifier::visit(AndPtr v) {
+void IRVerifier::visit(const AndPtr& v) {
verifyBitwiseOp(v, this);
IRVisitor::visit(v);
}
-void IRVerifier::visit(OrPtr v) {
+void IRVerifier::visit(const OrPtr& v) {
verifyBitwiseOp(v, this);
IRVisitor::visit(v);
}
-void IRVerifier::visit(XorPtr v) {
+void IRVerifier::visit(const XorPtr& v) {
verifyBitwiseOp(v, this);
IRVisitor::visit(v);
}
-void IRVerifier::visit(LshiftPtr v) {
+void IRVerifier::visit(const LshiftPtr& v) {
verifyBitwiseOp(v, this);
IRVisitor::visit(v);
}
-void IRVerifier::visit(RshiftPtr v) {
+void IRVerifier::visit(const RshiftPtr& v) {
verifyBitwiseOp(v, this);
IRVisitor::visit(v);
}
-void IRVerifier::visit(ModPtr v) {
+void IRVerifier::visit(const ModPtr& v) {
if (!v->dtype().is_integral() && !v->dtype().is_floating_point()) {
throw std::runtime_error("invalid dtype: " + std::to_string(v->dtype()));
}
IRVisitor::visit(v);
}
-void IRVerifier::visit(CompareSelectPtr v) {
+void IRVerifier::visit(const CompareSelectPtr& v) {
if (v->ret_val1()->dtype() != v->ret_val2()->dtype()) {
throw malformed_ir("bad dtype in CompareSelect");
}
@@ -70,14 +70,14 @@
IRVisitor::visit(v);
}
-void IRVerifier::visit(RampPtr v) {
+void IRVerifier::visit(const RampPtr& v) {
if (v->stride()->dtype() != v->base()->dtype()) {
throw malformed_ir("Bad stride in Ramp");
}
IRVisitor::visit(v);
}
-void IRVerifier::visit(LoadPtr v) {
+void IRVerifier::visit(const LoadPtr& v) {
auto indices = v->indices();
if (!indices.empty() && v->buf()->base_handle()->dtype() != kHandle) {
throw malformed_ir(
@@ -103,7 +103,7 @@
IRVisitor::visit(v);
}
-void IRVerifier::visit(IfThenElsePtr v) {
+void IRVerifier::visit(const IfThenElsePtr& v) {
if (!v->condition()->dtype().is_integral()) {
throw unsupported_dtype();
}
@@ -116,7 +116,7 @@
IRVisitor::visit(v);
}
-void IRVerifier::visit(IntrinsicsPtr v) {
+void IRVerifier::visit(const IntrinsicsPtr& v) {
if (v->op_type() == kIsNan) {
if (v->dtype().scalar_type() != c10::kInt) {
throw malformed_ir("bad dtype in intrinsic arg");
@@ -133,7 +133,7 @@
IRVisitor::visit(v);
}
-void IRVerifier::visit(StorePtr v) {
+void IRVerifier::visit(const StorePtr& v) {
auto indices = v->indices();
if (!indices.empty() && v->buf()->base_handle()->dtype() != kHandle) {
throw malformed_ir(
@@ -162,7 +162,7 @@
IRVisitor::visit(v);
}
-void IRVerifier::visit(ForPtr v) {
+void IRVerifier::visit(const ForPtr& v) {
if (!v->var()) {
throw malformed_ir("nullptr Var in For loop");
} else if (!v->start()) {
@@ -175,7 +175,7 @@
IRVisitor::visit(v);
}
-void IRVerifier::visit(BlockPtr v) {
+void IRVerifier::visit(const BlockPtr& v) {
for (const StmtPtr& s : v->stmts()) {
if (s->get_parent() != v) {
throw malformed_ir("Broken child-parent link inside a Block");
@@ -184,7 +184,7 @@
IRVisitor::visit(v);
}
-void IRVerifier::visit(ExternalCallPtr v) {
+void IRVerifier::visit(const ExternalCallPtr& v) {
IRVisitor::visit(v);
}
diff --git a/torch/csrc/jit/tensorexpr/ir_verifier.h b/torch/csrc/jit/tensorexpr/ir_verifier.h
index 03b6d9a..e1fb4c0 100644
--- a/torch/csrc/jit/tensorexpr/ir_verifier.h
+++ b/torch/csrc/jit/tensorexpr/ir_verifier.h
@@ -31,22 +31,22 @@
public:
IRVerifier() = default;
- void visit(ModPtr v) override;
- void visit(AndPtr v) override;
- void visit(OrPtr v) override;
- void visit(XorPtr v) override;
- void visit(LshiftPtr v) override;
- void visit(RshiftPtr v) override;
- void visit(CompareSelectPtr v) override;
- void visit(RampPtr v) override;
- void visit(LoadPtr v) override;
- void visit(IfThenElsePtr v) override;
- void visit(IntrinsicsPtr v) override;
+ void visit(const ModPtr& v) override;
+ void visit(const AndPtr& v) override;
+ void visit(const OrPtr& v) override;
+ void visit(const XorPtr& v) override;
+ void visit(const LshiftPtr& v) override;
+ void visit(const RshiftPtr& v) override;
+ void visit(const CompareSelectPtr& v) override;
+ void visit(const RampPtr& v) override;
+ void visit(const LoadPtr& v) override;
+ void visit(const IfThenElsePtr& v) override;
+ void visit(const IntrinsicsPtr& v) override;
- void visit(ExternalCallPtr v) override;
- void visit(StorePtr v) override;
- void visit(ForPtr v) override;
- void visit(BlockPtr v) override;
+ void visit(const ExternalCallPtr& v) override;
+ void visit(const StorePtr& v) override;
+ void visit(const ForPtr& v) override;
+ void visit(const BlockPtr& v) override;
};
TORCH_API void verify(StmtPtr);
diff --git a/torch/csrc/jit/tensorexpr/ir_visitor.cpp b/torch/csrc/jit/tensorexpr/ir_visitor.cpp
index 8ba8064..00232fe 100644
--- a/torch/csrc/jit/tensorexpr/ir_visitor.cpp
+++ b/torch/csrc/jit/tensorexpr/ir_visitor.cpp
@@ -11,96 +11,95 @@
template <
typename Op,
- typename std::enable_if<std::is_same<
+ std::enable_if_t<std::is_same_v<
decltype(detail::bin_op_deducer(std::declval<Op>())),
- void>::value>::type* = nullptr>
-static void visit_binary_op(NodePtr<Op> v, IRVisitor* visitor) {
+ void>>* = nullptr>
+static void visit_binary_op(const NodePtr<Op>& v, IRVisitor* visitor) {
v->lhs()->accept(visitor);
v->rhs()->accept(visitor);
}
-void IRVisitor::visit(AddPtr v) {
+void IRVisitor::visit(const AddPtr& v) {
visit_binary_op(v, this);
}
-void IRVisitor::visit(SubPtr v) {
+void IRVisitor::visit(const SubPtr& v) {
visit_binary_op(v, this);
}
-void IRVisitor::visit(MulPtr v) {
+void IRVisitor::visit(const MulPtr& v) {
visit_binary_op(v, this);
}
-void IRVisitor::visit(DivPtr v) {
+void IRVisitor::visit(const DivPtr& v) {
visit_binary_op(v, this);
}
-void IRVisitor::visit(ModPtr v) {
+void IRVisitor::visit(const ModPtr& v) {
visit_binary_op(v, this);
}
-void IRVisitor::visit(MaxPtr v) {
+void IRVisitor::visit(const MaxPtr& v) {
visit_binary_op(v, this);
}
-void IRVisitor::visit(MinPtr v) {
+void IRVisitor::visit(const MinPtr& v) {
visit_binary_op(v, this);
}
-void IRVisitor::visit(AndPtr v) {
+void IRVisitor::visit(const AndPtr& v) {
visit_binary_op(v, this);
}
-void IRVisitor::visit(OrPtr v) {
+void IRVisitor::visit(const OrPtr& v) {
visit_binary_op(v, this);
}
-void IRVisitor::visit(XorPtr v) {
+void IRVisitor::visit(const XorPtr& v) {
visit_binary_op(v, this);
}
-void IRVisitor::visit(LshiftPtr v) {
+void IRVisitor::visit(const LshiftPtr& v) {
visit_binary_op(v, this);
}
-void IRVisitor::visit(RshiftPtr v) {
+void IRVisitor::visit(const RshiftPtr& v) {
visit_binary_op(v, this);
}
-void IRVisitor::visit(CompareSelectPtr v) {
+void IRVisitor::visit(const CompareSelectPtr& v) {
v->lhs()->accept(this);
v->rhs()->accept(this);
v->ret_val1()->accept(this);
v->ret_val2()->accept(this);
}
-// NOLINTNEXTLINE
#define IMM_VISIT(Type, Name) \
- void IRVisitor::visit(Name##ImmPtr v) {}
+ void IRVisitor::visit(const Name##ImmPtr& v) {}
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT);
#undef IMM_VISIT
-void IRVisitor::visit(CastPtr v) {
+void IRVisitor::visit(const CastPtr& v) {
v->src_value()->accept(this);
}
-void IRVisitor::visit(BitCastPtr v) {
+void IRVisitor::visit(const BitCastPtr& v) {
v->src_value()->accept(this);
}
-void IRVisitor::visit(VarPtr v) {}
+void IRVisitor::visit(const VarPtr& v) {}
-void IRVisitor::visit(RampPtr v) {
+void IRVisitor::visit(const RampPtr& v) {
v->base()->accept(this);
v->stride()->accept(this);
}
-void IRVisitor::visit(LoadPtr v) {
+void IRVisitor::visit(const LoadPtr& v) {
v->buf()->accept(this);
for (const ExprPtr& ind : v->indices()) {
ind->accept(this);
}
}
-void IRVisitor::visit(BufPtr v) {
+void IRVisitor::visit(const BufPtr& v) {
v->base_handle()->accept(this);
if (v->qscale()) {
v->qscale()->accept(this);
@@ -110,7 +109,7 @@
}
}
-void IRVisitor::visit(StorePtr v) {
+void IRVisitor::visit(const StorePtr& v) {
v->buf()->accept(this);
for (const ExprPtr& ind : v->indices()) {
ind->accept(this);
@@ -118,7 +117,7 @@
v->value()->accept(this);
}
-void IRVisitor::visit(AtomicAddPtr v) {
+void IRVisitor::visit(const AtomicAddPtr& v) {
v->buf()->accept(this);
for (const ExprPtr& ind : v->indices()) {
ind->accept(this);
@@ -126,9 +125,9 @@
v->value()->accept(this);
}
-void IRVisitor::visit(SyncThreadsPtr v) {}
+void IRVisitor::visit(const SyncThreadsPtr& v) {}
-void IRVisitor::visit(ExternalCallPtr v) {
+void IRVisitor::visit(const ExternalCallPtr& v) {
v->buf()->accept(this);
for (const BufPtr& buf_arg : v->buf_args()) {
buf_arg->accept(this);
@@ -138,7 +137,7 @@
}
}
-void IRVisitor::visit(ExternalCallWithAllocPtr v) {
+void IRVisitor::visit(const ExternalCallWithAllocPtr& v) {
for (const auto& buf_out_arg : v->buf_out_args()) {
buf_out_arg->accept(this);
}
@@ -150,19 +149,19 @@
}
}
-void IRVisitor::visit(FreeExtPtr v) {
+void IRVisitor::visit(const FreeExtPtr& v) {
for (const auto& buf : v->bufs()) {
buf->accept(this);
}
}
-void IRVisitor::visit(BlockPtr v) {
+void IRVisitor::visit(const BlockPtr& v) {
for (const StmtPtr& s : *v) {
s->accept(this);
}
}
-void IRVisitor::visit(ForPtr v) {
+void IRVisitor::visit(const ForPtr& v) {
v->var()->accept(this);
v->start()->accept(this);
v->stop()->accept(this);
@@ -171,23 +170,23 @@
}
}
-void IRVisitor::visit(BroadcastPtr v) {
+void IRVisitor::visit(const BroadcastPtr& v) {
v->value()->accept(this);
}
-void IRVisitor::visit(IfThenElsePtr v) {
+void IRVisitor::visit(const IfThenElsePtr& v) {
v->condition()->accept(this);
v->true_value()->accept(this);
v->false_value()->accept(this);
}
-void IRVisitor::visit(IntrinsicsPtr v) {
+void IRVisitor::visit(const IntrinsicsPtr& v) {
for (const auto i : c10::irange(v->nparams())) {
v->param(i)->accept(this);
}
}
-void IRVisitor::visit(AllocatePtr v) {
+void IRVisitor::visit(const AllocatePtr& v) {
v->buffer_var()->accept(this);
std::vector<ExprPtr> dims = v->dims();
for (const ExprPtr& dim : dims) {
@@ -195,21 +194,21 @@
}
}
-void IRVisitor::visit(FreePtr v) {
+void IRVisitor::visit(const FreePtr& v) {
v->buffer_var()->accept(this);
}
-void IRVisitor::visit(PlacementAllocatePtr v) {
+void IRVisitor::visit(const PlacementAllocatePtr& v) {
v->buf()->accept(this);
v->buf_to_reuse()->accept(this);
}
-void IRVisitor::visit(LetPtr v) {
+void IRVisitor::visit(const LetPtr& v) {
v->var()->accept(this);
v->value()->accept(this);
}
-void IRVisitor::visit(CondPtr v) {
+void IRVisitor::visit(const CondPtr& v) {
ExprPtr condition = v->condition();
StmtPtr true_stmt = v->true_stmt();
StmtPtr false_stmt = v->false_stmt();
@@ -222,26 +221,26 @@
}
}
-void IRVisitor::visit(TermPtr v) {
+void IRVisitor::visit(const TermPtr& v) {
v->scalar()->accept(this);
for (const auto& t : v->variables()) {
t->accept(this);
}
}
-void IRVisitor::visit(PolynomialPtr v) {
+void IRVisitor::visit(const PolynomialPtr& v) {
v->scalar()->accept(this);
for (const auto& t : v->variables()) {
t->accept(this);
}
}
-void IRVisitor::visit(RoundOffPtr v) {
+void IRVisitor::visit(const RoundOffPtr& v) {
v->lhs()->accept(this);
v->rhs()->accept(this);
}
-void IRVisitor::visit(MaxTermPtr v) {
+void IRVisitor::visit(const MaxTermPtr& v) {
if (v->scalar()) {
v->scalar()->accept(this);
}
@@ -250,7 +249,7 @@
}
}
-void IRVisitor::visit(MinTermPtr v) {
+void IRVisitor::visit(const MinTermPtr& v) {
if (v->scalar()) {
v->scalar()->accept(this);
}
@@ -259,7 +258,7 @@
}
}
-void IRVisitor::visit(ReduceOpPtr v) {
+void IRVisitor::visit(const ReduceOpPtr& v) {
v->body()->accept(this);
for (const auto& r : v->reduce_args()) {
diff --git a/torch/csrc/jit/tensorexpr/ir_visitor.h b/torch/csrc/jit/tensorexpr/ir_visitor.h
index 09e6069..d6e87b6 100644
--- a/torch/csrc/jit/tensorexpr/ir_visitor.h
+++ b/torch/csrc/jit/tensorexpr/ir_visitor.h
@@ -3,62 +3,58 @@
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/tensorexpr/fwd_decls.h>
-namespace torch {
-namespace jit {
-namespace tensorexpr {
+namespace torch::jit::tensorexpr {
class TORCH_API IRVisitor {
public:
virtual ~IRVisitor() = default;
- virtual void visit(AddPtr v);
- virtual void visit(SubPtr v);
- virtual void visit(MulPtr v);
- virtual void visit(DivPtr v);
- virtual void visit(ModPtr v);
- virtual void visit(MaxPtr v);
- virtual void visit(MinPtr v);
- virtual void visit(AndPtr v);
- virtual void visit(OrPtr v);
- virtual void visit(XorPtr v);
- virtual void visit(LshiftPtr v);
- virtual void visit(RshiftPtr v);
- virtual void visit(CompareSelectPtr v);
+ virtual void visit(const AddPtr& v);
+ virtual void visit(const SubPtr& v);
+ virtual void visit(const MulPtr& v);
+ virtual void visit(const DivPtr& v);
+ virtual void visit(const ModPtr& v);
+ virtual void visit(const MaxPtr& v);
+ virtual void visit(const MinPtr& v);
+ virtual void visit(const AndPtr& v);
+ virtual void visit(const OrPtr& v);
+ virtual void visit(const XorPtr& v);
+ virtual void visit(const LshiftPtr& v);
+ virtual void visit(const RshiftPtr& v);
+ virtual void visit(const CompareSelectPtr& v);
-#define IMM_PRINT_VISIT(Type, Name) virtual void visit(Name##ImmPtr v);
+#define IMM_PRINT_VISIT(Type, Name) virtual void visit(const Name##ImmPtr& v);
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_PRINT_VISIT)
#undef IMM_PRINT_VISIT
- virtual void visit(CastPtr v);
- virtual void visit(BitCastPtr v);
- virtual void visit(VarPtr v);
- virtual void visit(BufPtr v);
- virtual void visit(RampPtr v);
- virtual void visit(LoadPtr v);
- virtual void visit(ForPtr v);
- virtual void visit(BlockPtr v);
- virtual void visit(StorePtr v);
- virtual void visit(BroadcastPtr v);
- virtual void visit(IfThenElsePtr v);
- virtual void visit(IntrinsicsPtr v);
- virtual void visit(AllocatePtr v);
- virtual void visit(FreePtr v);
- virtual void visit(FreeExtPtr v);
- virtual void visit(PlacementAllocatePtr v);
- virtual void visit(LetPtr v);
- virtual void visit(CondPtr v);
- virtual void visit(TermPtr v);
- virtual void visit(PolynomialPtr v);
- virtual void visit(RoundOffPtr v);
- virtual void visit(MaxTermPtr v);
- virtual void visit(MinTermPtr v);
- virtual void visit(ReduceOpPtr v);
- virtual void visit(AtomicAddPtr v);
- virtual void visit(SyncThreadsPtr v);
- virtual void visit(ExternalCallPtr v);
- virtual void visit(ExternalCallWithAllocPtr v);
+ virtual void visit(const CastPtr& v);
+ virtual void visit(const BitCastPtr& v);
+ virtual void visit(const VarPtr& v);
+ virtual void visit(const BufPtr& v);
+ virtual void visit(const RampPtr& v);
+ virtual void visit(const LoadPtr& v);
+ virtual void visit(const ForPtr& v);
+ virtual void visit(const BlockPtr& v);
+ virtual void visit(const StorePtr& v);
+ virtual void visit(const BroadcastPtr& v);
+ virtual void visit(const IfThenElsePtr& v);
+ virtual void visit(const IntrinsicsPtr& v);
+ virtual void visit(const AllocatePtr& v);
+ virtual void visit(const FreePtr& v);
+ virtual void visit(const FreeExtPtr& v);
+ virtual void visit(const PlacementAllocatePtr& v);
+ virtual void visit(const LetPtr& v);
+ virtual void visit(const CondPtr& v);
+ virtual void visit(const TermPtr& v);
+ virtual void visit(const PolynomialPtr& v);
+ virtual void visit(const RoundOffPtr& v);
+ virtual void visit(const MaxTermPtr& v);
+ virtual void visit(const MinTermPtr& v);
+ virtual void visit(const ReduceOpPtr& v);
+ virtual void visit(const AtomicAddPtr& v);
+ virtual void visit(const SyncThreadsPtr& v);
+ virtual void visit(const ExternalCallPtr& v);
+ virtual void visit(const ExternalCallWithAllocPtr& v);
};
-} // namespace tensorexpr
-} // namespace jit
-} // namespace torch
+} // namespace torch::jit::tensorexpr
diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp
index 3deb03b..af9622c 100644
--- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp
+++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp
@@ -314,43 +314,43 @@
llvm::JITTargetAddress getKernelAddress() const;
std::unique_ptr<llvm::orc::PytorchLLVMJIT> releaseJIT();
- void visit(AddPtr v) override;
- void visit(SubPtr v) override;
- void visit(MulPtr v) override;
- void visit(DivPtr v) override;
- void visit(ModPtr v) override;
- void visit(MaxPtr v) override;
- void visit(MinPtr v) override;
- void visit(AndPtr v) override;
- void visit(OrPtr v) override;
- void visit(XorPtr v) override;
- void visit(LshiftPtr v) override;
- void visit(RshiftPtr v) override;
- void visit(CompareSelectPtr v) override;
+ void visit(const AddPtr& v) override;
+ void visit(const SubPtr& v) override;
+ void visit(const MulPtr& v) override;
+ void visit(const DivPtr& v) override;
+ void visit(const ModPtr& v) override;
+ void visit(const MaxPtr& v) override;
+ void visit(const MinPtr& v) override;
+ void visit(const AndPtr& v) override;
+ void visit(const OrPtr& v) override;
+ void visit(const XorPtr& v) override;
+ void visit(const LshiftPtr& v) override;
+ void visit(const RshiftPtr& v) override;
+ void visit(const CompareSelectPtr& v) override;
-#define IMM_VISIT_DECLARE(_1, Name) void visit(Name##ImmPtr v) override;
+#define IMM_VISIT_DECLARE(_1, Name) void visit(const Name##ImmPtr& v) override;
AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, IMM_VISIT_DECLARE);
#undef IMM_VISIT_DECLARE
- void visit(CastPtr v) override;
- void visit(BitCastPtr v) override;
- void visit(VarPtr v) override;
- void visit(RampPtr v) override;
- void visit(LoadPtr v) override;
- void visit(ForPtr v) override;
- void visit(BlockPtr v) override;
- void visit(StorePtr v) override;
- void visit(BroadcastPtr v) override;
- void visit(IfThenElsePtr v) override;
- void visit(IntrinsicsPtr v) override;
- void visit(AllocatePtr v) override;
- void visit(FreePtr v) override;
- void visit(FreeExtPtr v) override;
- void visit(PlacementAllocatePtr v) override;
- void visit(LetPtr v) override;
- void visit(CondPtr v) override;
- void visit(ExternalCallPtr v) override;
- void visit(ExternalCallWithAllocPtr v) override;
+ void visit(const CastPtr& v) override;
+ void visit(const BitCastPtr& v) override;
+ void visit(const VarPtr& v) override;
+ void visit(const RampPtr& v) override;
+ void visit(const LoadPtr& v) override;
+ void visit(const ForPtr& v) override;
+ void visit(const BlockPtr& v) override;
+ void visit(const StorePtr& v) override;
+ void visit(const BroadcastPtr& v) override;
+ void visit(const IfThenElsePtr& v) override;
+ void visit(const IntrinsicsPtr& v) override;
+ void visit(const AllocatePtr& v) override;
+ void visit(const FreePtr& v) override;
+ void visit(const FreeExtPtr& v) override;
+ void visit(const PlacementAllocatePtr& v) override;
+ void visit(const LetPtr& v) override;
+ void visit(const CondPtr& v) override;
+ void visit(const ExternalCallPtr& v) override;
+ void visit(const ExternalCallWithAllocPtr& v) override;
void emitIsNan(IntrinsicsPtr v);
@@ -759,7 +759,7 @@
// TODO: The binary ops are copypasta.
-void LLVMCodeGenImpl::visit(AddPtr v) {
+void LLVMCodeGenImpl::visit(const AddPtr& v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
@@ -777,7 +777,7 @@
}
}
-void LLVMCodeGenImpl::visit(SubPtr v) {
+void LLVMCodeGenImpl::visit(const SubPtr& v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
@@ -795,7 +795,7 @@
}
}
-void LLVMCodeGenImpl::visit(MulPtr v) {
+void LLVMCodeGenImpl::visit(const MulPtr& v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
@@ -813,7 +813,7 @@
}
}
-void LLVMCodeGenImpl::visit(DivPtr v) {
+void LLVMCodeGenImpl::visit(const DivPtr& v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
@@ -831,7 +831,7 @@
}
}
-void LLVMCodeGenImpl::visit(AndPtr v) {
+void LLVMCodeGenImpl::visit(const AndPtr& v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
@@ -846,7 +846,7 @@
}
}
-void LLVMCodeGenImpl::visit(OrPtr v) {
+void LLVMCodeGenImpl::visit(const OrPtr& v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
@@ -861,7 +861,7 @@
}
}
-void LLVMCodeGenImpl::visit(XorPtr v) {
+void LLVMCodeGenImpl::visit(const XorPtr& v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
@@ -876,7 +876,7 @@
}
}
-void LLVMCodeGenImpl::visit(LshiftPtr v) {
+void LLVMCodeGenImpl::visit(const LshiftPtr& v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
@@ -891,7 +891,7 @@
}
}
-void LLVMCodeGenImpl::visit(RshiftPtr v) {
+void LLVMCodeGenImpl::visit(const RshiftPtr& v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
@@ -910,7 +910,7 @@
}
}
-void LLVMCodeGenImpl::visit(ModPtr v) {
+void LLVMCodeGenImpl::visit(const ModPtr& v) {
v->lhs()->accept(this);
auto lhs = this->value_;
bool lfp = lhs->getType()->isFPOrFPVectorTy();
@@ -925,7 +925,7 @@
}
}
-void LLVMCodeGenImpl::visit(MaxPtr v) {
+void LLVMCodeGenImpl::visit(const MaxPtr& v) {
v->lhs()->accept(this);
auto lhs = this->value_;
v->rhs()->accept(this);
@@ -948,7 +948,7 @@
irb_.CreateFCmp(llvm::FCmpInst::FCMP_OGT, lhs, rhs), lhs, rhs));
}
-void LLVMCodeGenImpl::visit(MinPtr v) {
+void LLVMCodeGenImpl::visit(const MinPtr& v) {
v->lhs()->accept(this);
auto lhs = this->value_;
v->rhs()->accept(this);
@@ -970,7 +970,7 @@
irb_.CreateFCmp(llvm::FCmpInst::FCMP_OLT, lhs, rhs), lhs, rhs));
}
-void LLVMCodeGenImpl::visit(CompareSelectPtr v) {
+void LLVMCodeGenImpl::visit(const CompareSelectPtr& v) {
auto genUnbiased = [this, v]() -> llvm::Value* {
v->lhs()->accept(this);
auto lhs = this->value_;
@@ -1073,21 +1073,21 @@
}
#define IMM_VISIT_DECLARE(Type, Name) \
- void LLVMCodeGenImpl::visit(Name##ImmPtr v) { \
+ void LLVMCodeGenImpl::visit(const Name##ImmPtr& v) { \
value_ = getFromType<Type>(Name##Ty_, v->value()); \
}
AT_FORALL_SCALAR_TYPES(IMM_VISIT_DECLARE);
#undef IMM_VISIT_DECLARE
-void LLVMCodeGenImpl::visit(HalfImmPtr v) {
+void LLVMCodeGenImpl::visit(const HalfImmPtr& v) {
value_ = llvm::ConstantFP::get(HalfTy_, v->value());
}
-void LLVMCodeGenImpl::visit(BFloat16ImmPtr v) {
+void LLVMCodeGenImpl::visit(const BFloat16ImmPtr& v) {
value_ = llvm::ConstantInt::get(ShortTy_, v->value().x);
}
-void LLVMCodeGenImpl::visit(BoolImmPtr v) {
+void LLVMCodeGenImpl::visit(const BoolImmPtr& v) {
value_ = llvm::ConstantInt::get(BoolTy_, v->value());
}
@@ -1099,7 +1099,7 @@
}
}
-void LLVMCodeGenImpl::visit(CastPtr v) {
+void LLVMCodeGenImpl::visit(const CastPtr& v) {
v->src_value()->accept(this);
auto dst_type = v->dtype().scalar_type();
@@ -1246,7 +1246,7 @@
}
}
-void LLVMCodeGenImpl::visit(BitCastPtr v) {
+void LLVMCodeGenImpl::visit(const BitCastPtr& v) {
v->src_value()->accept(this);
llvm::Type* dstType = dtypeToLLVM(v->dtype());
@@ -1265,7 +1265,7 @@
value_ = irb_.CreateBitOrPointerCast(value_, dstType);
}
-void LLVMCodeGenImpl::visit(VarPtr v) {
+void LLVMCodeGenImpl::visit(const VarPtr& v) {
value_ = varToValue(v);
}
@@ -1297,7 +1297,7 @@
}
}
-void LLVMCodeGenImpl::visit(RampPtr v) {
+void LLVMCodeGenImpl::visit(const RampPtr& v) {
v->base()->accept(this);
auto base = this->value_;
v->stride()->accept(this);
@@ -1397,7 +1397,7 @@
return phi;
}
-void LLVMCodeGenImpl::visit(LoadPtr v) {
+void LLVMCodeGenImpl::visit(const LoadPtr& v) {
if (v->dtype().lanes() == 1) {
v->base_handle()->accept(this);
auto base = this->value_;
@@ -1673,7 +1673,7 @@
value_ = llvm::ConstantInt::get(IntTy_, 0);
}
-void LLVMCodeGenImpl::visit(ForPtr v) {
+void LLVMCodeGenImpl::visit(const ForPtr& v) {
if (v->is_parallel()) {
processParallelFor(v);
return;
@@ -1729,7 +1729,7 @@
value_ = llvm::ConstantInt::get(IntTy_, 0);
}
-void LLVMCodeGenImpl::visit(BlockPtr v) {
+void LLVMCodeGenImpl::visit(const BlockPtr& v) {
BlockPtr last = scope_;
scope_ = v;
@@ -1795,7 +1795,7 @@
irb_.SetInsertPoint(tailblock);
}
-void LLVMCodeGenImpl::visit(StorePtr v) {
+void LLVMCodeGenImpl::visit(const StorePtr& v) {
if (v->value()->dtype().lanes() == 1) {
v->base_handle()->accept(this);
auto base = this->value_;
@@ -1858,13 +1858,13 @@
value_ = llvm::ConstantInt::get(IntTy_, 0);
}
-void LLVMCodeGenImpl::visit(BroadcastPtr v) {
+void LLVMCodeGenImpl::visit(const BroadcastPtr& v) {
v->value()->accept(this);
int lanes = v->lanes();
value_ = irb_.CreateVectorSplat(lanes, value_);
}
-void LLVMCodeGenImpl::visit(IfThenElsePtr v) {
+void LLVMCodeGenImpl::visit(const IfThenElsePtr& v) {
v->condition()->accept(this);
llvm::Value* condition = value_;
llvm::Value* c = irb_.CreateICmpNE(
@@ -1991,7 +1991,7 @@
return SimdCallee{callee.getFunctionType(), callee.getCallee(), useSimd};
}
-void LLVMCodeGenImpl::visit(IntrinsicsPtr v) {
+void LLVMCodeGenImpl::visit(const IntrinsicsPtr& v) {
llvm::FunctionType* call_ty = nullptr;
llvm::Value* call_fn = nullptr;
bool call_simd_sleef = false;
@@ -2188,7 +2188,7 @@
varToVal_[buf->base_handle()] = ptr;
}
-void LLVMCodeGenImpl::visit(ExternalCallPtr v) {
+void LLVMCodeGenImpl::visit(const ExternalCallPtr& v) {
auto& func_registry = getNNCFunctionRegistry();
if (!func_registry.count(v->func_name())) {
throw unimplemented_lowering(v);
@@ -2342,7 +2342,7 @@
value_ = llvm::ConstantInt::get(IntTy_, 0);
}
-void LLVMCodeGenImpl::visit(ExternalCallWithAllocPtr v) {
+void LLVMCodeGenImpl::visit(const ExternalCallWithAllocPtr& v) {
auto& func_registry = getNNCFunctionRegistry();
if (!func_registry.count(v->func_name())) {
throw unimplemented_lowering(v);
@@ -2559,7 +2559,7 @@
value_ = llvm::ConstantInt::get(IntTy_, 0);
}
-void LLVMCodeGenImpl::visit(AllocatePtr v) {
+void LLVMCodeGenImpl::visit(const AllocatePtr& v) {
llvm::Value* size =
llvm::ConstantInt::getSigned(LongTy_, v->dtype().byte_size());
for (ExprPtr e : v->dims()) {
@@ -2595,7 +2595,7 @@
varToVal_[v->buffer_var()] = malloc;
}
-void LLVMCodeGenImpl::visit(PlacementAllocatePtr v) {
+void LLVMCodeGenImpl::visit(const PlacementAllocatePtr& v) {
auto buf_to_reuse = v->buf_to_reuse();
auto buf = v->buf();
@@ -2607,7 +2607,7 @@
handleBufReuse(buf, buf_to_reuse);
}
-void LLVMCodeGenImpl::visit(FreePtr v) {
+void LLVMCodeGenImpl::visit(const FreePtr& v) {
value_ = llvm::ConstantInt::get(IntTy_, 0);
llvm::Value* ptr = bufsExtToFreeVal_.count(v->buffer_var())
@@ -2623,7 +2623,7 @@
}
}
-void LLVMCodeGenImpl::visit(FreeExtPtr v) {
+void LLVMCodeGenImpl::visit(const FreeExtPtr& v) {
value_ = llvm::ConstantInt::get(IntTy_, 0);
const auto& bufs = v->bufs();
const auto bufs_num = bufs.size();
@@ -2684,7 +2684,7 @@
value_ = llvm::ConstantInt::get(IntTy_, 0);
}
-void LLVMCodeGenImpl::visit(LetPtr v) {
+void LLVMCodeGenImpl::visit(const LetPtr& v) {
v->value()->accept(this);
if (!varToVal_.count(v->var())) {
varToVal_.emplace(v->var(), value_);
@@ -2694,7 +2694,7 @@
}
}
-void LLVMCodeGenImpl::visit(CondPtr v) {
+void LLVMCodeGenImpl::visit(const CondPtr& v) {
// Even if true_stmt and false_stmt are nullptr,
// in case condition is a function call with side effect,
// we still evaluate it.
diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp
index 62a67af..fb9b760 100644
--- a/torch/csrc/jit/tensorexpr/loopnest.cpp
+++ b/torch/csrc/jit/tensorexpr/loopnest.cpp
@@ -1026,7 +1026,7 @@
}
private:
- void visit(StorePtr v) override {
+ void visit(const StorePtr& v) override {
if (stores_[v->buf()].insert(last_stmt_).second) {
uses_[v->buf()].push_back({(StmtPtr)v, true});
}
@@ -1034,7 +1034,7 @@
IRVisitor::visit(v);
}
- void visit(ExternalCallPtr v) override {
+ void visit(const ExternalCallPtr& v) override {
if (stores_[v->buf()].insert(last_stmt_).second) {
uses_[v->buf()].push_back({(StmtPtr)v, true});
}
@@ -1049,7 +1049,7 @@
IRVisitor::visit(v);
}
- void visit(ExternalCallWithAllocPtr v) override {
+ void visit(const ExternalCallWithAllocPtr& v) override {
for (const auto& out_buf : v->buf_out_args()) {
if (stores_[out_buf].insert(last_stmt_).second) {
uses_[out_buf].push_back({(StmtPtr)v, true});
@@ -1066,7 +1066,7 @@
IRVisitor::visit(v);
}
- void visit(LoadPtr v) override {
+ void visit(const LoadPtr& v) override {
if (loads_[v->buf()].insert(last_stmt_).second) {
uses_[v->buf()].push_back({last_stmt_, false});
}
@@ -1097,19 +1097,19 @@
}
private:
- void visit(StorePtr v) override {
+ void visit(const StorePtr& v) override {
contained_.insert((StmtPtr)v);
IRVisitor::visit(v);
}
- void visit(ExternalCallPtr v) override {
+ void visit(const ExternalCallPtr& v) override {
contained_.insert((StmtPtr)v);
IRVisitor::visit(v);
}
- void visit(ExternalCallWithAllocPtr v) override {
+ void visit(const ExternalCallWithAllocPtr& v) override {
contained_.insert((StmtPtr)v);
IRVisitor::visit(v);
}
- void visit(BlockPtr v) override {
+ void visit(const BlockPtr& v) override {
contained_.insert((StmtPtr)v);
IRVisitor::visit(v);
}
diff --git a/torch/csrc/jit/tensorexpr/mem_dependency_checker.cpp b/torch/csrc/jit/tensorexpr/mem_dependency_checker.cpp
index 1d687a6..12689f7 100644
--- a/torch/csrc/jit/tensorexpr/mem_dependency_checker.cpp
+++ b/torch/csrc/jit/tensorexpr/mem_dependency_checker.cpp
@@ -502,7 +502,7 @@
// Node visitors:
-void MemDependencyChecker::visit(StorePtr v) {
+void MemDependencyChecker::visit(const StorePtr& v) {
StmtPtr last = lastStmt_;
lastStmt_ = v;
v->value()->accept(this);
@@ -534,7 +534,7 @@
currentScope_->accesses_.push_back(info);
}
-void MemDependencyChecker::visit(LoadPtr v) {
+void MemDependencyChecker::visit(const LoadPtr& v) {
// Create a temporary scope to hold any loads that occur within the indices of
// this load.
auto indicesScope =
@@ -675,7 +675,7 @@
return false;
}
-void MemDependencyChecker::visit(ForPtr v) {
+void MemDependencyChecker::visit(const ForPtr& v) {
VarPtr var = v->var();
StmtPtr last = lastStmt_;
@@ -910,7 +910,7 @@
currentScope_ = currentScope_->parent;
}
-void MemDependencyChecker::visit(CondPtr v) {
+void MemDependencyChecker::visit(const CondPtr& v) {
StmtPtr last = lastStmt_;
lastStmt_ = v;
@@ -959,7 +959,7 @@
lastStmt_ = last;
}
-void MemDependencyChecker::visit(IfThenElsePtr v) {
+void MemDependencyChecker::visit(const IfThenElsePtr& v) {
// condition is in enclosing scope.
v->condition()->accept(this);
@@ -995,7 +995,7 @@
currentScope_ = enclosingScope;
}
-void MemDependencyChecker::visit(CompareSelectPtr v) {
+void MemDependencyChecker::visit(const CompareSelectPtr& v) {
// condition is in enclosing scope.
v->lhs()->accept(this);
v->rhs()->accept(this);
@@ -1055,7 +1055,7 @@
}
}
-void MemDependencyChecker::visit(BlockPtr v) {
+void MemDependencyChecker::visit(const BlockPtr& v) {
auto prev_scope = currentScope_;
// handle kernel inputs.
@@ -1091,7 +1091,7 @@
}
}
-void MemDependencyChecker::visit(LetPtr v) {
+void MemDependencyChecker::visit(const LetPtr& v) {
StmtPtr last = lastStmt_;
lastStmt_ = v;
@@ -1110,11 +1110,11 @@
// Don't support AtomicAdd yet, it's a bit more complex since it's both a read
// and a write. It's only inserted during Cuda codegen so this should be okay.
-void MemDependencyChecker::visit(AtomicAddPtr v) {
+void MemDependencyChecker::visit(const AtomicAddPtr& v) {
throw std::runtime_error("MemDependencyChecker AtomicAdd unimplemented");
}
-void MemDependencyChecker::visit(AllocatePtr v) {
+void MemDependencyChecker::visit(const AllocatePtr& v) {
StmtPtr last = lastStmt_;
lastStmt_ = v;
@@ -1146,7 +1146,7 @@
lastStmt_ = last;
}
-void MemDependencyChecker::visit(FreePtr v) {
+void MemDependencyChecker::visit(const FreePtr& v) {
StmtPtr last = lastStmt_;
lastStmt_ = v;
@@ -1290,7 +1290,7 @@
}
private:
- void visit(VarPtr v) override {
+ void visit(const VarPtr& v) override {
auto it = vars_.find(v);
if (it == vars_.end()) {
return;
diff --git a/torch/csrc/jit/tensorexpr/mem_dependency_checker.h b/torch/csrc/jit/tensorexpr/mem_dependency_checker.h
index 3b5bb53..c374739 100644
--- a/torch/csrc/jit/tensorexpr/mem_dependency_checker.h
+++ b/torch/csrc/jit/tensorexpr/mem_dependency_checker.h
@@ -260,17 +260,17 @@
private:
// Node visitors.
- void visit(StorePtr v) override;
- void visit(LoadPtr v) override;
- void visit(ForPtr v) override;
- void visit(CondPtr v) override;
- void visit(IfThenElsePtr v) override;
- void visit(CompareSelectPtr v) override;
- void visit(BlockPtr v) override;
- void visit(LetPtr v) override;
- void visit(AtomicAddPtr v) override;
- void visit(AllocatePtr v) override;
- void visit(FreePtr v) override;
+ void visit(const StorePtr& v) override;
+ void visit(const LoadPtr& v) override;
+ void visit(const ForPtr& v) override;
+ void visit(const CondPtr& v) override;
+ void visit(const IfThenElsePtr& v) override;
+ void visit(const CompareSelectPtr& v) override;
+ void visit(const BlockPtr& v) override;
+ void visit(const LetPtr& v) override;
+ void visit(const AtomicAddPtr& v) override;
+ void visit(const AllocatePtr& v) override;
+ void visit(const FreePtr& v) override;
using BoundRelationship = std::pair<IndexBounds, std::shared_ptr<AccessInfo>>;
diff --git a/torch/csrc/jit/tensorexpr/registerizer.cpp b/torch/csrc/jit/tensorexpr/registerizer.cpp
index 5e57209..25a43dc 100644
--- a/torch/csrc/jit/tensorexpr/registerizer.cpp
+++ b/torch/csrc/jit/tensorexpr/registerizer.cpp
@@ -188,7 +188,7 @@
scope->closeAccess(info);
}
-void RegisterizerAnalysis::visit(ForPtr v) {
+void RegisterizerAnalysis::visit(const ForPtr& v) {
if (v->loop_options().is_gpu_block_index() ||
v->loop_options().is_gpu_thread_index()) {
throw malformed_input(
@@ -272,7 +272,7 @@
mergeCurrentScopeIntoParent();
};
-void RegisterizerAnalysis::visit(CondPtr v) {
+void RegisterizerAnalysis::visit(const CondPtr& v) {
ExprPtr condition = v->condition();
BlockPtr true_stmt = v->true_stmt();
BlockPtr false_stmt = v->false_stmt();
@@ -312,7 +312,7 @@
// IfThenElses are just like Conds except they are not Stmts, which means no
// registerization can occur internally. However, the first reference to an
// access can occur within one if its visible outside the condition.
-void RegisterizerAnalysis::visit(IfThenElsePtr v) {
+void RegisterizerAnalysis::visit(const IfThenElsePtr& v) {
ExprPtr condition = v->condition();
ExprPtr true_value = v->true_value();
ExprPtr false_value = v->false_value();
@@ -347,7 +347,7 @@
}
}
-void RegisterizerAnalysis::visit(LetPtr v) {
+void RegisterizerAnalysis::visit(const LetPtr& v) {
currentScope_->addLocalVar(v->var());
stmtStack_.push_front(v);
@@ -355,7 +355,7 @@
stmtStack_.pop_front();
}
-void RegisterizerAnalysis::visit(BlockPtr v) {
+void RegisterizerAnalysis::visit(const BlockPtr& v) {
auto prev_scope = currentScope_;
if (currentScope_->block() != v) {
currentScope_ = std::make_shared<Scope>(v, prev_scope);
@@ -383,7 +383,7 @@
}
}
-void RegisterizerAnalysis::visit(StorePtr v) {
+void RegisterizerAnalysis::visit(const StorePtr& v) {
stmtStack_.push_front(v);
v->value()->accept(this);
stmtStack_.pop_front();
@@ -437,7 +437,7 @@
}
}
-void RegisterizerAnalysis::visit(LoadPtr v) {
+void RegisterizerAnalysis::visit(const LoadPtr& v) {
if (v->indices().empty()) {
// already a scalar.
return;
diff --git a/torch/csrc/jit/tensorexpr/registerizer.h b/torch/csrc/jit/tensorexpr/registerizer.h
index f73551b..62efb4d 100644
--- a/torch/csrc/jit/tensorexpr/registerizer.h
+++ b/torch/csrc/jit/tensorexpr/registerizer.h
@@ -326,25 +326,25 @@
: currentScope_(std::make_shared<Scope>(nullptr, nullptr, 0)) {}
~RegisterizerAnalysis() override = default;
- void visit(ForPtr v) override;
+ void visit(const ForPtr& v) override;
- void visit(CondPtr v) override;
+ void visit(const CondPtr& v) override;
- void visit(BlockPtr v) override;
+ void visit(const BlockPtr& v) override;
- void visit(StorePtr v) override;
+ void visit(const StorePtr& v) override;
- void visit(LoadPtr v) override;
+ void visit(const LoadPtr& v) override;
- void visit(IfThenElsePtr v) override;
+ void visit(const IfThenElsePtr& v) override;
- void visit(LetPtr v) override;
+ void visit(const LetPtr& v) override;
-#define STMT_ON_STACK(Op) \
- void visit(Op##Ptr v) override { \
- stmtStack_.push_front(v); \
- IRVisitor::visit(v); \
- stmtStack_.pop_front(); \
+#define STMT_ON_STACK(Op) \
+ void visit(const Op##Ptr& v) override { \
+ stmtStack_.push_front(v); \
+ IRVisitor::visit(v); \
+ stmtStack_.pop_front(); \
}
STMT_ON_STACK(AtomicAdd);