[TensorExpr] add more detail to malformed_input exceptions (#35891)
Summary:
Add an explanation string to malformed_input exceptions thrown inside jit/tensorexpr to aid in debugging issues.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35891
Differential Revision: D20822306
Pulled By: nickgg
fbshipit-source-id: ce153a05218f2a4da5ecf5f1a5dc439070c96e55
diff --git a/torch/csrc/jit/tensorexpr/buffer.h b/torch/csrc/jit/tensorexpr/buffer.h
index 026c4c2..2eb1356 100644
--- a/torch/csrc/jit/tensorexpr/buffer.h
+++ b/torch/csrc/jit/tensorexpr/buffer.h
@@ -11,7 +11,7 @@
Buffer(const BufHandle& data, const Dtype& dtype)
: data_(data.node()), dtype_(dtype) {
if (data.dtype() != kHandle) {
- throw malformed_input();
+ throw malformed_input("Buffer dtype must be Handle");
}
std::vector<ExprHandle> stride_handles(ndim());
@@ -51,6 +51,7 @@
ExprHandle operator()(Args... args) const {
return LoadValue(std::forward<Args>(args)...);
}
+
ExprHandle LoadValue(
const ExprHandle& x,
const ExprHandle& y,
diff --git a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp
index 2119b1e..13fbdf1 100644
--- a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp
+++ b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp
@@ -52,7 +52,8 @@
static int as_int(const Expr* expr) {
auto v = dynamic_cast<const IntImm*>(expr);
if (!v) {
- throw malformed_input(expr);
+ throw malformed_input(
+ "cuda_codegen: non Int expr interpreted as int", expr);
}
return v->value();
@@ -508,7 +509,7 @@
void CudaCodeGen::call(const std::vector<CallArg>& args) {
if (args.size() != buffer_args().size()) {
- throw malformed_input();
+ throw malformed_input("cuda_codegen: wrong number of args in call");
}
// TODO: move as much of this into the constructors.
@@ -517,7 +518,8 @@
const std::vector<const Expr*>& gpu_thread_extents =
printer_->gpu_thread_extents();
if (gpu_block_extents.size() > 3 || gpu_thread_extents.size() > 3) {
- throw malformed_input();
+ throw malformed_input(
+ "cuda_codegen: block or thread extent greater than 3D");
}
std::vector<int> gpu_block_extents_v(3, 1);
diff --git a/torch/csrc/jit/tensorexpr/eval.h b/torch/csrc/jit/tensorexpr/eval.h
index 7bb5813..493b845 100644
--- a/torch/csrc/jit/tensorexpr/eval.h
+++ b/torch/csrc/jit/tensorexpr/eval.h
@@ -129,7 +129,7 @@
TORCH_API void call(const std::vector<CallArg>& args) override {
if (args.size() != buffer_args().size()) {
- throw malformed_input();
+ throw malformed_input("bad args in IREvaluator call");
}
for (size_t i = 0; i < args.size(); i++) {
bind(buffer_args()[i], args[i]);
@@ -346,7 +346,7 @@
v->rhs()->accept(this);
Value rhs_v = value_;
if (lhs_v.dtype() != rhs_v.dtype()) {
- throw malformed_input(v);
+ throw malformed_input("bad dtype in binary op", v);
}
IRNodeType expr_type = v->expr_type();
if (expr_type == IRNodeType::kAnd || expr_type == IRNodeType::kOr ||
@@ -385,7 +385,7 @@
if (lhs_v.dtype() != rhs_v.dtype() ||
ret_val1_v.dtype() != ret_val2_v.dtype()) {
- throw malformed_input(v);
+ throw malformed_input("bad dtype in CompareSelect", v);
}
switch (lhs_v.dtype().scalar_type()) {
@@ -411,7 +411,7 @@
TORCH_API void visit(const Let* v) override {
const Var* var = dynamic_cast<const Var*>(v->var());
if (!var) {
- throw malformed_input(v);
+ throw malformed_input("bad Var in Let", v);
}
v->value()->accept(this);
Value value = value_;
@@ -435,7 +435,7 @@
TORCH_API void visit(const LetStmt* v) override {
const Var* var = v->var();
if (!var) {
- throw malformed_input(v);
+ throw malformed_input("bad Var in LetStmt", v);
}
v->value()->accept(this);
@@ -460,7 +460,7 @@
TORCH_API void visit(const Var* v) override {
auto iter = eval_context_.find(v);
if (iter == eval_context_.end()) {
- throw malformed_input(v);
+ throw malformed_input("could not find Var in context", v);
}
value_ = iter->second;
@@ -499,7 +499,7 @@
Dtype dst_dtype = v->dtype();
Dtype src_dtype = src_value->dtype();
if (src_dtype.lanes() != dst_dtype.lanes()) {
- throw malformed_input(v);
+ throw malformed_input("lane mismatch in Cast", v);
}
if (src_dtype != dst_dtype) {
@@ -523,7 +523,7 @@
v->stop()->accept(this);
int stop = value_.as<int>();
if (eval_context_.count(var_node)) {
- throw malformed_input(v);
+ throw malformed_input("could not find var_node in For context", v);
}
for (int i = start; i < stop; i++) {
@@ -580,7 +580,7 @@
const Var* base_node = v->base_handle();
auto iter = buffer_mapping_.find(base_node);
if (iter == buffer_mapping_.end()) {
- throw malformed_input(v);
+ throw malformed_input("could not find base node in Load", v);
}
void* ptr = iter->second;
@@ -613,7 +613,7 @@
const Var* base_node = v->base_handle();
auto iter = buffer_mapping_.find(base_node);
if (iter == buffer_mapping_.end()) {
- throw malformed_input(v);
+ throw malformed_input("could not find base node in Store", v);
}
void* ptr = iter->second;
@@ -624,25 +624,25 @@
v->mask()->accept(this);
std::vector<int> mask = value().as_vec<int>();
if (index.size() != mask.size()) {
- throw malformed_input(v);
+ throw malformed_input("mask size mismatch in Store", v);
}
ScalarType v_sdtype = v->value()->dtype().scalar_type();
switch (v_sdtype) {
-#define TYPE_CASE(Type, Name) \
- case ScalarType::Name: { \
- v->value()->accept(this); \
- std::vector<Type> value = this->value().as_vec<Type>(); \
- if (index.size() != value.size()) { \
- throw malformed_input(v); \
- } \
- Type* ptr##Name = static_cast<Type*>(ptr); \
- for (size_t i = 0; i < index.size(); i++) { \
- if (mask[i]) { \
- ptr##Name[index[i]] = value[i]; \
- } \
- } \
+#define TYPE_CASE(Type, Name) \
+ case ScalarType::Name: { \
+ v->value()->accept(this); \
+ std::vector<Type> value = this->value().as_vec<Type>(); \
+ if (index.size() != value.size()) { \
+ throw malformed_input("value size mismatch in Store", v); \
+ } \
+ Type* ptr##Name = static_cast<Type*>(ptr); \
+ for (size_t i = 0; i < index.size(); i++) { \
+ if (mask[i]) { \
+ ptr##Name[index[i]] = value[i]; \
+ } \
+ } \
} break;
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE);
#undef TYPE_CASE
@@ -669,7 +669,7 @@
if (values.size() >= 2ULL) {
v2 = values[1].as_vec<float>();
if (v1.size() != v2.size()) {
- throw malformed_input(v);
+ throw malformed_input("value size mismatch in Intrinsics", v);
}
}
@@ -829,7 +829,7 @@
const Var* key_var = entry.first;
const Expr* value = entry.second;
if (!key_var) {
- throw malformed_input();
+ throw malformed_input("missing key in VarSubMutator");
}
var_mapping_[key_var] = value;
}
diff --git a/torch/csrc/jit/tensorexpr/exceptions.h b/torch/csrc/jit/tensorexpr/exceptions.h
index 4734157..3ec6ebd 100644
--- a/torch/csrc/jit/tensorexpr/exceptions.h
+++ b/torch/csrc/jit/tensorexpr/exceptions.h
@@ -56,8 +56,14 @@
: std::runtime_error("MALFORMED INPUT: " + err) {}
explicit malformed_input(const Expr* expr)
: std::runtime_error("MALFORMED INPUT: " + std::to_string(expr)) {}
+ explicit malformed_input(const std::string& err, const Expr* expr)
+ : std::runtime_error(
+ "MALFORMED INPUT: " + err + " - " + std::to_string(expr)) {}
explicit malformed_input(const Stmt* stmt)
: std::runtime_error("MALFORMED INPUT: " + std::to_string(stmt)) {}
+ explicit malformed_input(const std::string& err, const Stmt* stmt)
+ : std::runtime_error(
+ "MALFORMED INPUT: " + err + " - " + std::to_string(stmt)) {}
};
} // namespace tensorexpr
diff --git a/torch/csrc/jit/tensorexpr/function.cpp b/torch/csrc/jit/tensorexpr/function.cpp
index 4a9a9a6..3888e1d 100644
--- a/torch/csrc/jit/tensorexpr/function.cpp
+++ b/torch/csrc/jit/tensorexpr/function.cpp
@@ -41,7 +41,7 @@
const std::vector<DimArg>& dim_args,
const std::function<ExprHandle(const VarHandle&)>& body_func) {
if (dim_args.size() != 1) {
- throw malformed_input();
+ throw malformed_input("mismatch between body and arg size (1)");
}
std::vector<const Expr*> dims;
@@ -59,7 +59,7 @@
const std::function<ExprHandle(const VarHandle&, const VarHandle&)>&
body_func) {
if (dim_args.size() != 2) {
- throw malformed_input();
+ throw malformed_input("mismatch between body and arg size (2)");
}
std::vector<const Expr*> dims;
std::vector<const Var*> args;
@@ -77,7 +77,7 @@
ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)>&
body_func) {
if (dim_args.size() != 3) {
- throw malformed_input();
+ throw malformed_input("mismatch between body and arg size (3)");
}
std::vector<const Expr*> dims;
std::vector<const Var*> args;
@@ -99,7 +99,7 @@
const VarHandle&,
const VarHandle&)>& body_func) {
if (dim_args.size() != 4) {
- throw malformed_input();
+ throw malformed_input("mismatch between body and arg size (4)");
}
std::vector<const Expr*> dims;
std::vector<const Var*> args_nodes;
diff --git a/torch/csrc/jit/tensorexpr/ir.cpp b/torch/csrc/jit/tensorexpr/ir.cpp
index 00726c4..f07859a 100644
--- a/torch/csrc/jit/tensorexpr/ir.cpp
+++ b/torch/csrc/jit/tensorexpr/ir.cpp
@@ -12,12 +12,12 @@
static Dtype dtypeOfIndices(const std::vector<const Expr*>& indices) {
if (!indices.size()) {
- throw malformed_input();
+ throw malformed_input("cant get dtype of empty indices");
}
Dtype dt = indices.at(0)->dtype();
for (size_t i = 1; i < indices.size(); ++i) {
if (indices.at(i)->dtype() != dt) {
- throw malformed_input();
+ throw malformed_input("dtype mismatch in dtypeOfIndices");
}
}
return dt;
@@ -55,14 +55,15 @@
const Expr* mask)
: ExprNodeBase(dtype), buf_(buf), indices_(indices), mask_(mask) {
if (buf->base_handle()->dtype() != kHandle) {
- throw malformed_input();
+ throw malformed_input(
+ "Load base handle dtype must be Handle", buf->base_handle());
}
if (!indicesValid(indices)) {
- throw malformed_input();
+ throw malformed_input("invalid indices in Load");
}
Dtype index_dtype = dtypeOfIndices(indices);
if (index_dtype.lanes() != mask->dtype().lanes()) {
- throw malformed_input();
+ throw malformed_input("lane mismatch in Load mask");
}
}
@@ -89,7 +90,7 @@
const Expr* mask)
: Store(buffer.data(), indices, value, mask) {
if (buffer.dtype().scalar_type() != value->dtype().scalar_type()) {
- throw malformed_input();
+ throw malformed_input("invalid dtype in Store");
}
}
@@ -100,7 +101,7 @@
const Expr* mask)
: buf_(buf), indices_(std::move(indices)), value_(value), mask_(mask) {
if (buf->dtype() != kHandle) {
- throw malformed_input();
+ throw malformed_input("Store base handle must be Handle");
}
/*
TODO: Reenable the checks.
@@ -167,7 +168,7 @@
size_t ndim = dims.size();
if (ndim != indices.size()) {
- throw malformed_input();
+ throw malformed_input("dimensions mismatch in flatten_index");
}
if (ndim == 0) {
return new IntImm(0);
@@ -202,7 +203,7 @@
const std::vector<const Expr*>& params) {
// TODO: check the op_type an dmake a real decision
if (params.size() == 0) {
- throw malformed_input();
+ throw malformed_input("invalid params in Intrinsics");
}
return params[0]->dtype();
diff --git a/torch/csrc/jit/tensorexpr/ir.h b/torch/csrc/jit/tensorexpr/ir.h
index 8a0fc9c..102357f 100644
--- a/torch/csrc/jit/tensorexpr/ir.h
+++ b/torch/csrc/jit/tensorexpr/ir.h
@@ -159,7 +159,7 @@
throw unsupported_dtype();
}
if (lhs->dtype() != rhs->dtype()) {
- throw malformed_input();
+ throw malformed_input("bad dtype in And");
}
}
};
@@ -172,7 +172,7 @@
throw unsupported_dtype();
}
if (lhs->dtype() != rhs->dtype()) {
- throw malformed_input();
+ throw malformed_input("bad dtype in Or");
}
}
};
@@ -185,7 +185,7 @@
throw unsupported_dtype();
}
if (lhs->dtype() != rhs->dtype()) {
- throw malformed_input();
+ throw malformed_input("bad dtype in Xor");
}
}
};
@@ -198,7 +198,7 @@
throw unsupported_dtype();
}
if (lhs->dtype() != rhs->dtype()) {
- throw malformed_input();
+ throw malformed_input("bad dtype in Lshift");
}
}
};
@@ -211,7 +211,7 @@
throw unsupported_dtype();
}
if (lhs->dtype() != rhs->dtype()) {
- throw malformed_input();
+ throw malformed_input("bad dtype in Rshift");
}
}
};
@@ -430,7 +430,7 @@
stride_(stride),
lanes_(lanes) {
if (stride->dtype() != base->dtype()) {
- throw malformed_input();
+ throw malformed_input("Bad stride in Ramp");
}
}
@@ -537,7 +537,7 @@
throw unsupported_dtype();
}
if (t->dtype() != f->dtype()) {
- throw malformed_input();
+ throw malformed_input("Bad dtype in IfThenElse");
}
}
@@ -622,7 +622,7 @@
const ExprHandle& rhs,
CompareSelectOperation cmp_op) {
if (lhs.dtype() != rhs.dtype()) {
- throw malformed_input();
+ throw malformed_input("bad dtype in CompareSelect");
}
return ExprHandle(new CompareSelect(
lhs.node(),
@@ -639,7 +639,7 @@
const ExprHandle& ret_val2,
CompareSelectOperation cmp_op) {
if (lhs.dtype() != rhs.dtype() || ret_val1.dtype() != ret_val2.dtype()) {
- throw malformed_input();
+ throw malformed_input("bad dtype in CompareSelect");
}
return ExprHandle(new CompareSelect(
lhs.node(), rhs.node(), ret_val1.node(), ret_val2.node(), cmp_op));
@@ -665,7 +665,7 @@
ret_val2_(ret_val2),
compare_op_(cmp_op) {
if (ret_val1->dtype() != ret_val2->dtype()) {
- throw malformed_input();
+ throw malformed_input("bad dtype in CompareSelect");
}
}
};
@@ -810,7 +810,7 @@
: BaseClass(IntrinsicsDtype(op_type, dtype), kIntrinsics, {}),
op_type_(op_type) {
if (OpArgCount(op_type) != 0) {
- throw malformed_input();
+ throw malformed_input("bad arg count in Intrinsics");
}
}
@@ -818,7 +818,7 @@
: BaseClass(IntrinsicsDtype(op_type, v1->dtype()), kIntrinsics, {v1}),
op_type_(op_type) {
if (OpArgCount(op_type) != 1) {
- throw malformed_input();
+ throw malformed_input("bad arg count in Intrinsics");
}
}
@@ -829,7 +829,7 @@
{v1, v2}),
op_type_(op_type) {
if (OpArgCount(op_type) != 2) {
- throw malformed_input();
+ throw malformed_input("bad arg count in Intrinsics");
}
}
@@ -837,7 +837,7 @@
: BaseClass(IntrinsicsDtype(op_type, params), kIntrinsics, params),
op_type_(op_type) {
if (OpArgCount(op_type) != nparams()) {
- throw malformed_input();
+ throw malformed_input("bad arg count in Intrinsics");
}
}
diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp
index 70a256b..0e253f8 100644
--- a/torch/csrc/jit/tensorexpr/kernel.cpp
+++ b/torch/csrc/jit/tensorexpr/kernel.cpp
@@ -48,7 +48,7 @@
static std::vector<DimArg> texprDims(const torch::jit::Value* v) {
if (v->type()->kind() != TypeKind::TensorType) {
- throw malformed_input();
+ throw malformed_input("type is not Tensor");
}
auto tt = v->type()->cast<TensorType>();
@@ -87,7 +87,7 @@
}
if (!scalars_.count(v->unique())) {
- throw malformed_input();
+ throw malformed_input("no scalar in Constant");
}
return scalars_.at(v->unique());
@@ -135,7 +135,7 @@
const ExprHandle& e,
const torch::jit::Value* v) {
if (v->type()->kind() != TypeKind::TensorType) {
- throw malformed_input();
+ throw malformed_input("type is not tensor in demoteOutput");
}
auto tt = *v->type()->cast<TensorType>()->scalarType();
@@ -913,7 +913,7 @@
int64_t dim = constant(n->inputs()[1]).AsNode<IntImm>()->value();
if (dim < 0) {
if (axes.size() == 0) {
- throw malformed_input();
+ throw malformed_input("axes are zero handling unsqueeze");
}
dim += axes.size() - 1;
@@ -1303,7 +1303,7 @@
const c10::VaryingStrides& contiguity,
const std::unordered_map<int64_t, VarHandle>& sizeVars) {
if (axes.size() != strides.size()) {
- throw malformed_input();
+ throw malformed_input("axes and strides size mismatch");
}
std::vector<ShapeArg> strideArgs;
@@ -1312,7 +1312,7 @@
ExprHandle index = 0;
if (axes.size() == 0) {
- throw malformed_input();
+ throw malformed_input("axes are zero creating input index");
}
size_t n = axes.size() - 1;
@@ -1332,7 +1332,7 @@
if (sizeVal < 0) {
auto it = sizeVars.find(sizeVal);
if (it == sizeVars.end()) {
- throw malformed_input();
+ throw malformed_input("cannot dind size when creating input index");
}
auto const& v = it->second;
@@ -1458,7 +1458,7 @@
// Move output operands from `tensors_` to `tensorOutputs_`
for (const auto& output : graph_->outputs()) {
if (!tensors_.count(output->unique())) {
- throw malformed_input();
+ throw malformed_input("cannot find output Tensor");
}
tensorOutputs_.emplace_back(tensors_.at(output->unique()));
tensors_.erase(output->unique());
@@ -1527,7 +1527,7 @@
} else {
const IntImm* s = dynamic_cast<const IntImm*>(dim);
if (!s) {
- throw malformed_input(dim);
+ throw malformed_input("output expected Int", dim);
}
tensorSize.push_back(s->value());
}
diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp
index 4e16359..d181e3e 100644
--- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp
+++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp
@@ -185,7 +185,7 @@
void LLVMCodeGen::call(const std::vector<CallArg>& args) {
if (args.size() != buffer_args().size()) {
- throw malformed_input();
+ throw malformed_input("wrong number of args in call");
}
std::vector<void*> argv;
@@ -365,7 +365,7 @@
} else if (!lfp && !rfp) {
value_ = irb_.CreateAdd(lhs, rhs);
} else {
- throw malformed_input(v);
+ throw malformed_input("llvm_codgen: bad type in Add", v);
}
}
@@ -383,7 +383,7 @@
} else if (!lfp && !rfp) {
value_ = irb_.CreateSub(lhs, rhs);
} else {
- throw malformed_input(v);
+ throw malformed_input("llvm_codgen: bad type in Sub", v);
}
}
@@ -401,7 +401,7 @@
} else if (!lfp && !rfp) {
value_ = irb_.CreateMul(lhs, rhs);
} else {
- throw malformed_input(v);
+ throw malformed_input("llvm_codgen: bad type in Mul", v);
}
}
@@ -419,7 +419,7 @@
} else if (!lfp && !rfp) {
value_ = irb_.CreateSDiv(lhs, rhs);
} else {
- throw malformed_input(v);
+ throw malformed_input("llvm_codgen: bad type in Div", v);
}
}
@@ -434,7 +434,7 @@
if (!lfp && !rfp) {
value_ = irb_.CreateAnd(lhs, rhs);
} else {
- throw malformed_input(v);
+ throw malformed_input("llvm_codgen: bad type in And", v);
}
}
@@ -449,7 +449,7 @@
if (!lfp && !rfp) {
value_ = irb_.CreateOr(lhs, rhs);
} else {
- throw malformed_input(v);
+ throw malformed_input("llvm_codgen: bad type in Or", v);
}
}
@@ -464,7 +464,7 @@
if (!lfp && !rfp) {
value_ = irb_.CreateXor(lhs, rhs);
} else {
- throw malformed_input(v);
+ throw malformed_input("llvm_codgen: bad type in Xor", v);
}
}
@@ -479,7 +479,7 @@
if (!lfp && !rfp) {
value_ = irb_.CreateShl(lhs, rhs);
} else {
- throw malformed_input(v);
+ throw malformed_input("llvm_codgen: bad type in Lshift", v);
}
}
@@ -494,7 +494,7 @@
if (!lfp && !rfp) {
value_ = irb_.CreateLShr(lhs, rhs);
} else {
- throw malformed_input(v);
+ throw malformed_input("llvm_codgen: bad type in Rshift", v);
}
}
@@ -699,7 +699,7 @@
void LLVMCodeGenImpl::visit(const Let* v) {
const Var* var = dynamic_cast<const Var*>(v->var());
if (!var) {
- throw malformed_input(v);
+ throw malformed_input("llvm_codgen: bad Var in Let", v);
}
v->value()->accept(this);
@@ -721,7 +721,7 @@
void LLVMCodeGenImpl::visit(const LetStmt* v) {
const Var* var = v->var();
if (!var) {
- throw malformed_input(v);
+ throw malformed_input("llvm_codgen: bad Var in LetStmt", v);
}
v->value()->accept(this);
diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp
index 2ce5ebf..dcce3de 100644
--- a/torch/csrc/jit/tensorexpr/loopnest.cpp
+++ b/torch/csrc/jit/tensorexpr/loopnest.cpp
@@ -631,13 +631,13 @@
if (all_processed) {
result.push_back(t);
if (processed.count(t)) {
- throw malformed_input();
+ throw malformed_input("failure to find all processed Tensors");
}
processed.insert(t);
} else {
if (queued.count(t)) {
- throw malformed_input();
+ throw malformed_input("failure to find all queued Tensors");
}
q.push(t);
@@ -686,7 +686,7 @@
}
if (t->buf()->ndim() == 0) {
- throw malformed_input();
+ throw malformed_input("Tensor lowered to zero dimensions");
}
for (size_t i = 0; i < t->buf()->ndim(); i++) {
@@ -765,9 +765,9 @@
For** tail) {
Block* p = dynamic_cast<Block*>(f->get_parent());
if (!f) {
- throw malformed_input(f);
+ throw malformed_input("splitWithTail attempted on null loop", f);
} else if (!p) {
- throw malformed_input(p);
+ throw malformed_input("splitWithTail attempted on loop with no parent", p);
}
bool tail_is_needed = true;
diff --git a/torch/csrc/jit/tensorexpr/stmt.h b/torch/csrc/jit/tensorexpr/stmt.h
index aa3e3c9..de832bb 100644
--- a/torch/csrc/jit/tensorexpr/stmt.h
+++ b/torch/csrc/jit/tensorexpr/stmt.h
@@ -77,7 +77,7 @@
const ExprHandle& value,
Stmt* body) {
if (body->get_parent()) {
- throw malformed_input(body);
+ throw malformed_input("LetStmt body has existing parent", body);
}
return new LetStmt(var.node(), value.node(), body);
@@ -114,7 +114,7 @@
void prepend_stmt(Stmt* s) {
if (s->get_parent()) {
- throw malformed_input(s);
+ throw malformed_input("Block prepend Stmt with existing parent", s);
}
stmts_.push_front(s);
@@ -122,7 +122,7 @@
}
void append_stmt(Stmt* s) {
if (s->get_parent()) {
- throw malformed_input(s);
+ throw malformed_input("Block append Stmt with existing parent", s);
}
stmts_.push_back(s);
@@ -130,7 +130,8 @@
}
bool replace_stmt(Stmt* old_stmt, Stmt* new_stmt) {
if (new_stmt->get_parent()) {
- throw malformed_input(new_stmt);
+ throw malformed_input(
+ "Block replace Stmt wiith existing parent", new_stmt);
}
auto pos = std::find(stmts_.begin(), stmts_.end(), old_stmt);
@@ -150,7 +151,8 @@
explicit Block(const std::vector<Stmt*>& stmts) {
for (Stmt* s : stmts) {
if (s->get_parent()) {
- throw malformed_input(s);
+ throw malformed_input(
+ "Block creation has Stmt with existing parent", s);
}
stmts_.push_back(s);
@@ -339,7 +341,7 @@
std::string gpu_block_index_str() const {
if (!is_gpu_block_index()) {
- throw malformed_input();
+ throw malformed_input("Has no GPU block index");
}
static const char* kBlockIndexNames[] = {
@@ -350,7 +352,7 @@
};
if (gpu_block_index_ < 0 || gpu_block_index_ >= 4) {
- throw malformed_input();
+ throw malformed_input("invalid GPU block index");
}
return kBlockIndexNames[gpu_block_index_];
@@ -377,14 +379,14 @@
std::string gpu_thread_index_str() const {
if (!is_gpu_thread_index()) {
- throw malformed_input();
+ throw malformed_input("has no GPU thread index");
}
static const char* kThreadIndexNames[] = {
"threadIdx.x", "threadIdx.y", "threadIdx.z", "threadIdx.w"};
if (gpu_thread_index_ < 0 || gpu_thread_index_ >= 4) {
- throw malformed_input();
+ throw malformed_input("invalid GPU thread index");
}
return kThreadIndexNames[gpu_thread_index_];
@@ -457,13 +459,13 @@
For(const Var* var, const Expr* start, const Expr* stop, Stmt* body)
: var_(var), start_(start), stop_(stop) {
if (!var) {
- throw malformed_input(var);
+ throw malformed_input("invalid Var in For loop", var);
} else if (!start) {
- throw malformed_input(start);
+ throw malformed_input("invalid Start in For loop", start);
} else if (!stop) {
- throw malformed_input(stop);
+ throw malformed_input("invalid Stop in For loop", stop);
} else if (!body || body->get_parent()) {
- throw malformed_input(body);
+ throw malformed_input("invalid Body in For loop", body);
}
Block* b = dynamic_cast<Block*>(body);
@@ -481,13 +483,13 @@
const LoopOptions& loop_options)
: var_(var), start_(start), stop_(stop), loop_options_(loop_options) {
if (!var) {
- throw malformed_input(var);
+ throw malformed_input("invalid Var in For loop", var);
} else if (!start) {
- throw malformed_input(start);
+ throw malformed_input("invalid Start in For loop", start);
} else if (!stop) {
- throw malformed_input(stop);
+ throw malformed_input("invalid Stop in For loop", stop);
} else if (!body || body->get_parent()) {
- throw malformed_input(body);
+ throw malformed_input("invalid Body in For loop", body);
}
Block* b = dynamic_cast<Block*>(body);