replace `AT_ERROR(...)` with `TORCH_CHECK(false, ...)` (#104534)
Merely cosmetic for `AT_ERROR` I found by chance, following https://github.com/pytorch/pytorch/blob/e9d2d74f0abe4b0e0f238e11b537c64041b3f9a7/c10/util/Exception.h#L622
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104534
Approved by: https://github.com/soulitzer
diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp
index 8c90bac..1e66b4f 100644
--- a/torch/csrc/autograd/engine.cpp
+++ b/torch/csrc/autograd/engine.cpp
@@ -798,7 +798,7 @@
std::stringstream ss;
ss << "invalid number of gradients - expected ";
ss << edges.size() << ", but got " << grads.size();
- AT_ERROR(format_error(ss.str()));
+ TORCH_CHECK(false, format_error(ss.str()));
}
for (const auto i : c10::irange(grads.size())) {
const auto& edge = edges[i];
@@ -811,7 +811,7 @@
// FIXME: TestJit.test_ge_optimized fails this assertion.
// std::stringstream ss;
// ss << "undefined gradient at index " << i;
- // AT_ERROR(format_error(ss.str()));
+ // TORCH_CHECK(false, format_error(ss.str()));
continue;
}
@@ -820,7 +820,7 @@
grad = metadata.reduce_grad(grad);
} else {
const auto message = metadata.incompatible_shape_error_message(i, grad);
- AT_ERROR(format_error(message.str()));
+ TORCH_CHECK(false, format_error(message.str()));
}
}
@@ -839,7 +839,7 @@
std::stringstream ss;
ss << "invalid gradient at index " << i << " - expected dtype ";
ss << metadata.dtype() << " but got " << grad.dtype();
- AT_ERROR(format_error(ss.str()));
+ TORCH_CHECK(false, format_error(ss.str()));
}
if (grad.layout() != metadata.layout()) {
// TODO: Currently we only support (*, Sparse) combination for
@@ -856,7 +856,7 @@
std::stringstream ss;
ss << "invalid gradient at index " << i << " - expected layout ";
ss << metadata.layout() << " but got " << grad.layout();
- AT_ERROR(format_error(ss.str()));
+ TORCH_CHECK(false, format_error(ss.str()));
}
}
@@ -871,7 +871,7 @@
std::stringstream ss;
ss << "invalid gradient at index " << i << " - expected device ";
ss << metadata.device() << " but got " << grad.device();
- AT_ERROR(format_error(ss.str()));
+ TORCH_CHECK(false, format_error(ss.str()));
}
}
}