[torch::deploy] remove asserts from deploy (#73456)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73456
Replaces `at::Error` with a simpler, self-contained exception implementation in order to reduce the dependency of torch::deploy on torch.
Note: Internal Testing / changes are still needed
Test Plan: Imported from OSS
Reviewed By: samdow
Differential Revision: D34868005
Pulled By: PaliC
fbshipit-source-id: c8bb1f7a2b169b5a8e3b63a697e0ced748a0524c
(cherry picked from commit 51b3763d16e74458a5cfb8e4d660806dea897617)
diff --git a/torch/csrc/deploy/Exception.h b/torch/csrc/deploy/Exception.h
new file mode 100644
index 0000000..64cfe1c
--- /dev/null
+++ b/torch/csrc/deploy/Exception.h
@@ -0,0 +1,46 @@
+#ifndef MULTIPY_EXCEPTION_H
+#define MULTIPY_EXCEPTION_H
+
+#include <stdexcept>
+#include <string>
+
+// Internal assertion: throws std::runtime_error (instead of at::Error, to
+// keep torch::deploy decoupled from torch) with condition/function/file/line.
+#define MULTIPY_INTERNAL_ASSERT_WITH_MESSAGE(condition, message)             \
+  do {                                                                       \
+    if (!(condition)) {                                                      \
+      throw std::runtime_error(                                              \
+          "Internal Assertion failed: (" + std::string(#condition) + "), " + \
+          "function " + __FUNCTION__ + ", file " + __FILE__ + ", line " +    \
+          std::to_string(__LINE__) + ".\n" +                                 \
+          "Please report bug to Pytorch.\n" + message + "\n");               \
+    }                                                                        \
+  } while (0)
+
+// Forwards `condition` unstringified so the check actually runs;
+// stringification happens inside MULTIPY_INTERNAL_ASSERT_WITH_MESSAGE.
+#define MULTIPY_INTERNAL_ASSERT_NO_MESSAGE(condition) \
+  MULTIPY_INTERNAL_ASSERT_WITH_MESSAGE(condition, "")
+
+// Arity dispatcher: picks the WITH_MESSAGE or NO_MESSAGE form below.
+#define MULTIPY_INTERNAL_ASSERT_(x, condition, message, FUNC, ...) FUNC
+
+#define MULTIPY_INTERNAL_ASSERT(...)                        \
+  MULTIPY_INTERNAL_ASSERT_(                                 \
+      ,                                                     \
+      ##__VA_ARGS__,                                        \
+      MULTIPY_INTERNAL_ASSERT_WITH_MESSAGE(__VA_ARGS__),    \
+      MULTIPY_INTERNAL_ASSERT_NO_MESSAGE(__VA_ARGS__))
+
+// User-facing check: throws std::runtime_error with context plus `message`.
+#define MULTIPY_CHECK(condition, message)                                 \
+  do {                                                                    \
+    if (!(condition)) {                                                   \
+      throw std::runtime_error(                                           \
+          "Check failed: (" + std::string(#condition) + "), " +           \
+          "function " + __FUNCTION__ + ", file " + __FILE__ + ", line " + \
+          std::to_string(__LINE__) + ".\n" + message + "\n");             \
+    }                                                                     \
+  } while (0)
+
+#endif // MULTIPY_EXCEPTION_H
diff --git a/torch/csrc/deploy/deploy.cpp b/torch/csrc/deploy/deploy.cpp
index 647c9a4..7ce21b6 100644
--- a/torch/csrc/deploy/deploy.cpp
+++ b/torch/csrc/deploy/deploy.cpp
@@ -1,4 +1,3 @@
-#include <c10/util/Exception.h>
#include <torch/csrc/deploy/deploy.h>
#include <torch/csrc/deploy/elf_file.h>
#include <torch/cuda.h>
@@ -46,7 +45,7 @@
};
static bool writeDeployInterpreter(FILE* dst) {
- TORCH_INTERNAL_ASSERT(dst);
+ MULTIPY_INTERNAL_ASSERT(dst);
const char* payloadStart = nullptr;
size_t size = 0;
bool customLoader = false;
@@ -83,7 +82,7 @@
payloadStart = libStart;
}
size_t written = fwrite(payloadStart, 1, size, dst);
- TORCH_INTERNAL_ASSERT(size == written, "expected written == size");
+ MULTIPY_INTERNAL_ASSERT(size == written, "expected written == size");
return customLoader;
}
@@ -214,9 +213,9 @@
// function.
static dlopen_t find_real_dlopen() {
void* libc = dlopen("libdl.so.2", RTLD_NOLOAD | RTLD_LAZY | RTLD_LOCAL);
- TORCH_INTERNAL_ASSERT(libc);
+ MULTIPY_INTERNAL_ASSERT(libc);
auto dlopen_ = (dlopen_t)dlsym(libc, "dlopen");
- TORCH_INTERNAL_ASSERT(dlopen_);
+ MULTIPY_INTERNAL_ASSERT(dlopen_);
return dlopen_;
}
@@ -227,7 +226,7 @@
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
char libraryName[] = "/tmp/torch_deployXXXXXX";
int fd = mkstemp(libraryName);
- TORCH_INTERNAL_ASSERT(fd != -1, "failed to create temporary file");
+ MULTIPY_INTERNAL_ASSERT(fd != -1, "failed to create temporary file");
libraryName_ = libraryName;
FILE* dst = fdopen(fd, "wb");
@@ -343,12 +342,12 @@
return;
}
- TORCH_INTERNAL_ASSERT(iArgumentNames.isList());
+ MULTIPY_INTERNAL_ASSERT(iArgumentNames.isList());
auto argumentNames = iArgumentNames.toListRef();
argumentNamesOut.reserve(argumentNames.size());
for (auto& argumentName : argumentNames) {
- TORCH_INTERNAL_ASSERT(argumentName.isString());
+ MULTIPY_INTERNAL_ASSERT(argumentName.isString());
argumentNamesOut.push_back(argumentName.toStringRef());
}
}
diff --git a/torch/csrc/deploy/deploy.h b/torch/csrc/deploy/deploy.h
index c6a4794a..3ac0c16 100644
--- a/torch/csrc/deploy/deploy.h
+++ b/torch/csrc/deploy/deploy.h
@@ -1,6 +1,5 @@
#pragma once
#include <c10/util/Optional.h>
-#include <c10/util/irange.h>
#include <torch/csrc/api/include/torch/imethod.h>
#include <torch/csrc/deploy/interpreter/interpreter_impl.h>
#include <torch/csrc/deploy/noop_environment.h>
diff --git a/torch/csrc/deploy/elf_file.cpp b/torch/csrc/deploy/elf_file.cpp
index 85eaaa1..69b6e89 100644
--- a/torch/csrc/deploy/elf_file.cpp
+++ b/torch/csrc/deploy/elf_file.cpp
@@ -1,5 +1,6 @@
-#include <c10/util/irange.h>
+#include <torch/csrc/deploy/Exception.h>
#include <torch/csrc/deploy/elf_file.h>
+#include <c10/util/irange.h>
namespace torch {
namespace deploy {
@@ -13,7 +14,7 @@
shdrList_ = (Elf64_Shdr*)(fileData + ehdr_->e_shoff);
auto strtabSecNo = ehdr_->e_shstrndx;
- TORCH_CHECK(
+ MULTIPY_CHECK(
strtabSecNo >= 0 && strtabSecNo < numSections_,
"e_shstrndx out of range");
@@ -26,7 +27,7 @@
}
at::optional<Section> ElfFile::findSection(const char* name) const {
- TORCH_CHECK(name != nullptr, "Null name");
+ MULTIPY_CHECK(name != nullptr, "Null name");
at::optional<Section> found = at::nullopt;
for (const auto& section : sections_) {
if (strcmp(name, section.name) == 0) {
@@ -40,13 +41,13 @@
void ElfFile::checkFormat() const {
// check the magic numbers
- TORCH_CHECK(
+ MULTIPY_CHECK(
(ehdr_->e_ident[EI_MAG0] == ELFMAG0) &&
(ehdr_->e_ident[EI_MAG1] == ELFMAG1) &&
(ehdr_->e_ident[EI_MAG2] == ELFMAG2) &&
(ehdr_->e_ident[EI_MAG3] == ELFMAG3),
"Unexpected magic numbers");
- TORCH_CHECK(
+ MULTIPY_CHECK(
ehdr_->e_ident[EI_CLASS] == ELFCLASS64, "Only support 64bit ELF file");
}
diff --git a/torch/csrc/deploy/elf_file.h b/torch/csrc/deploy/elf_file.h
index e27750c..b6e0349 100644
--- a/torch/csrc/deploy/elf_file.h
+++ b/torch/csrc/deploy/elf_file.h
@@ -2,6 +2,7 @@
#include <c10/util/Optional.h>
#include <elf.h>
+#include <torch/csrc/deploy/Exception.h>
#include <torch/csrc/deploy/mem_file.h>
#include <vector>
@@ -40,7 +41,7 @@
const char* name = "";
if (strtabSection_) {
- TORCH_CHECK(nameOff >= 0 && nameOff < strtabSection_.len);
+ MULTIPY_CHECK(nameOff >= 0 && nameOff < strtabSection_.len, "");
name = strtabSection_.start + nameOff;
}
const char* start = memFile_.data() + shOff;
@@ -48,7 +49,7 @@
}
[[nodiscard]] const char* str(size_t off) const {
- TORCH_CHECK(off < strtabSection_.len, "String table index out of range");
+ MULTIPY_CHECK(off < strtabSection_.len, "String table index out of range");
return strtabSection_.start + off;
}
void checkFormat() const;
diff --git a/torch/csrc/deploy/example/benchmark.cpp b/torch/csrc/deploy/example/benchmark.cpp
index e1d5b6f..38c2964 100644
--- a/torch/csrc/deploy/example/benchmark.cpp
+++ b/torch/csrc/deploy/example/benchmark.cpp
@@ -2,8 +2,6 @@
#include <ATen/ATen.h>
#include <ATen/TypeDefault.h>
-#include <c10/util/irange.h>
-
#include <torch/script.h>
#include <pthread.h>
diff --git a/torch/csrc/deploy/interpreter/builtin_registry.cpp b/torch/csrc/deploy/interpreter/builtin_registry.cpp
index a34768c..6599489 100644
--- a/torch/csrc/deploy/interpreter/builtin_registry.cpp
+++ b/torch/csrc/deploy/interpreter/builtin_registry.cpp
@@ -1,6 +1,6 @@
#include <Python.h>
-#include <c10/util/Exception.h>
#include <fmt/format.h>
+#include <torch/csrc/deploy/Exception.h>
#include <torch/csrc/deploy/interpreter/builtin_registry.h>
namespace torch {
@@ -55,10 +55,10 @@
}
void BuiltinRegistry::runPreInitialization() {
- TORCH_INTERNAL_ASSERT(!Py_IsInitialized());
+ MULTIPY_INTERNAL_ASSERT(!Py_IsInitialized());
sanityCheck();
PyImport_FrozenModules = BuiltinRegistry::getAllFrozenModules();
- TORCH_INTERNAL_ASSERT(PyImport_FrozenModules != nullptr);
+ MULTIPY_INTERNAL_ASSERT(PyImport_FrozenModules != nullptr);
appendCPythonInittab();
}
@@ -83,7 +83,7 @@
)PYTHON";
void BuiltinRegistry::runPostInitialization() {
- TORCH_INTERNAL_ASSERT(Py_IsInitialized());
+ MULTIPY_INTERNAL_ASSERT(Py_IsInitialized());
std::string metaPathSetupScript(metaPathSetupTemplate);
std::string replaceKey = "<<<DEPLOY_BUILTIN_MODULES_CSV>>>";
auto itr = metaPathSetupScript.find(replaceKey);
@@ -91,7 +91,7 @@
metaPathSetupScript.replace(itr, replaceKey.size(), getBuiltinModulesCSV());
}
int r = PyRun_SimpleString(metaPathSetupScript.c_str());
- TORCH_INTERNAL_ASSERT(r == 0);
+ MULTIPY_INTERNAL_ASSERT(r == 0);
}
void BuiltinRegistry::registerBuiltin(
@@ -152,14 +152,14 @@
auto* cpythonInternalFrozens = getItem("cpython_internal");
// Num frozen builtins shouldn't change (unless modifying the underlying
// cpython version)
- TORCH_INTERNAL_ASSERT(
+ MULTIPY_INTERNAL_ASSERT(
cpythonInternalFrozens != nullptr &&
cpythonInternalFrozens->numModules == NUM_FROZEN_PY_BUILTIN_MODULES,
"Missing python builtin frozen modules");
auto* frozenpython = getItem("frozenpython");
#ifdef FBCODE_CAFFE2
- TORCH_INTERNAL_ASSERT(
+ MULTIPY_INTERNAL_ASSERT(
frozenpython != nullptr, "Missing frozen python modules");
#else
auto* frozentorch = getItem("frozentorch");
@@ -167,7 +167,7 @@
// and frozentorch contains stdlib+torch, while in fbcode they are separated
// due to thirdparty2 frozenpython. No fixed number of torch modules to check
// for, but there should be at least one.
- TORCH_INTERNAL_ASSERT(
+ MULTIPY_INTERNAL_ASSERT(
frozenpython != nullptr && frozentorch != nullptr &&
frozenpython->numModules + frozentorch->numModules >
NUM_FROZEN_PY_STDLIB_MODULES + 1,
diff --git a/torch/csrc/deploy/interpreter/interpreter_impl.h b/torch/csrc/deploy/interpreter/interpreter_impl.h
index 10a1489..a16c90e 100644
--- a/torch/csrc/deploy/interpreter/interpreter_impl.h
+++ b/torch/csrc/deploy/interpreter/interpreter_impl.h
@@ -15,8 +15,8 @@
the client application.
It is safe to throw exception types that are defined once in
- the context of the client application, such as c10::Error, which is defined
- in libtorch, which isn't duplicated in torch::deploy interpreters.
+ the context of the client application, such as std::runtime_error,
+ which isn't duplicated in torch::deploy interpreters.
==> Use TORCH_DEPLOY_TRY, _SAFE_CATCH_RETHROW around _ALL_ torch::deploy APIs
@@ -30,20 +30,16 @@
*/
#define TORCH_DEPLOY_TRY try {
-#define TORCH_DEPLOY_SAFE_CATCH_RETHROW \
- } \
- catch (std::exception & err) { \
- throw c10::Error( \
- std::string( \
- "Exception Caught inside torch::deploy embedded library: \n") + \
- err.what(), \
- ""); \
- } \
- catch (...) { \
- throw c10::Error( \
- std::string( \
- "Unknown Exception Caught inside torch::deploy embedded library"), \
- ""); \
+#define TORCH_DEPLOY_SAFE_CATCH_RETHROW \
+ } \
+ catch (std::exception & err) { \
+ throw std::runtime_error( \
+ "Exception Caught inside torch::deploy embedded library: \n" + \
+ std::string(err.what())); \
+ } \
+ catch (...) { \
+ throw std::runtime_error( \
+ "Unknown Exception Caught inside torch::deploy embedded library"); \
}
namespace torch {
namespace deploy {
diff --git a/torch/csrc/deploy/mem_file.h b/torch/csrc/deploy/mem_file.h
index c50889f..0c3273c 100644
--- a/torch/csrc/deploy/mem_file.h
+++ b/torch/csrc/deploy/mem_file.h
@@ -1,9 +1,9 @@
#pragma once
-#include <c10/util/Exception.h>
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
+#include <torch/csrc/deploy/Exception.h>
#include <unistd.h>
#include <cerrno>
#include <cstdio>
@@ -20,18 +20,21 @@
struct MemFile {
explicit MemFile(const char* filename_) : fd_(0), mem_(nullptr), n_bytes_(0) {
fd_ = open(filename_, O_RDONLY);
- TORCH_CHECK(fd_ != -1, "failed to open {}: {}", filename_, strerror(errno));
+ MULTIPY_CHECK(
+ fd_ != -1, "failed to open " + filename_ + ": " + strerror(errno));
// NOLINTNEXTLINE
struct stat s;
if (-1 == fstat(fd_, &s)) {
close(fd_); // destructors don't run during exceptions
- TORCH_CHECK(false, "failed to stat {}: {}", filename_, strerror(errno));
+ MULTIPY_CHECK(
+ false, "failed to stat " + filename_ + ": " + strerror(errno));
}
n_bytes_ = s.st_size;
mem_ = mmap(nullptr, n_bytes_, PROT_READ, MAP_SHARED, fd_, 0);
if (MAP_FAILED == mem_) {
close(fd_);
- TORCH_CHECK(false, "failed to mmap {}: {}", filename_, strerror(errno));
+ MULTIPY_CHECK(
+ false, "failed to mmap " + filename_ + ": " + strerror(errno));
}
}
MemFile(const MemFile&) = delete;
diff --git a/torch/csrc/deploy/test_deploy.cpp b/torch/csrc/deploy/test_deploy.cpp
index 840720c..b5bb96c 100644
--- a/torch/csrc/deploy/test_deploy.cpp
+++ b/torch/csrc/deploy/test_deploy.cpp
@@ -2,7 +2,6 @@
#include <gtest/gtest.h>
#include <cstring>
-#include <c10/util/irange.h>
#include <libgen.h>
#include <torch/csrc/deploy/deploy.h>
#include <torch/script.h>
@@ -182,13 +181,14 @@
auto obj = session1.fromMovable(replicatedObj);
// should throw an error when trying to access obj from different session
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
- EXPECT_THROW(session2.createMovable(obj), c10::Error);
+ EXPECT_THROW(session2.createMovable(obj), std::runtime_error);
try {
session2.createMovable(obj);
- } catch (c10::Error& error) {
+ } catch (std::runtime_error& error) {
EXPECT_TRUE(
- error.msg().find(
- "Cannot create movable from an object that lives in different session") !=
+ std::string(error.what())
+ .find(
+ "Cannot create movable from an object that lives in different session") !=
std::string::npos);
}
}
@@ -197,15 +197,15 @@
// See explanation in deploy.h
torch::deploy::InterpreterManager manager(3);
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
- EXPECT_THROW(manager.loadPackage("some garbage path"), c10::Error);
+ EXPECT_THROW(manager.loadPackage("some garbage path"), std::runtime_error);
torch::deploy::Package p = manager.loadPackage(path("SIMPLE", simple));
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
- EXPECT_THROW(p.loadPickle("some other", "garbage path"), c10::Error);
+ EXPECT_THROW(p.loadPickle("some other", "garbage path"), std::runtime_error);
auto model = p.loadPickle("model", "model.pkl");
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
- EXPECT_THROW(model(at::IValue("unexpected input")), c10::Error);
+ EXPECT_THROW(model(at::IValue("unexpected input")), std::runtime_error);
}
TEST(TorchpyTest, AcquireMultipleSessionsInTheSamePackage) {
@@ -238,7 +238,7 @@
auto t = obj.toIValue().toTensor();
// try to feed it to the other interpreter, should error
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
- ASSERT_THROW(I1.global("torch", "sigmoid")({t}), c10::Error);
+ ASSERT_THROW(I1.global("torch", "sigmoid")({t}), std::runtime_error);
}
TEST(TorchpyTest, TaggingRace) {
@@ -259,7 +259,7 @@
try {
I.fromIValue(t);
success++;
- } catch (const c10::Error& e) {
+ } catch (const std::runtime_error& e) {
failed++;
}
}
@@ -279,7 +279,7 @@
torch::deploy::InterpreterManager m(1);
auto I = m.acquireOne();
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
- ASSERT_THROW(I.fromIValue(t), c10::Error); // NOT a segfault
+ ASSERT_THROW(I.fromIValue(t), std::runtime_error); // NOT a segfault
}
TEST(TorchpyTest, RegisterModule) {
@@ -379,7 +379,7 @@
try {
I.global("libtest_deploy_lib", "raise_exception")(no_args);
ASSERT_TRUE(false); // raise_exception did not throw?
- } catch (std::exception& err) {
+ } catch (std::runtime_error& err) {
ASSERT_TRUE(std::string(err.what()).find("yet") != std::string::npos);
}
in_another_module = 6;
diff --git a/torch/csrc/deploy/test_deploy_gpu.cpp b/torch/csrc/deploy/test_deploy_gpu.cpp
index 8fa154b..9cb6292 100644
--- a/torch/csrc/deploy/test_deploy_gpu.cpp
+++ b/torch/csrc/deploy/test_deploy_gpu.cpp
@@ -78,7 +78,7 @@
auto makeModel = p.loadPickle("make_trt_module", "model.pkl");
{
auto I = makeModel.acquireSession();
- auto model = I.self(at::ArrayRef<at::IValue>{});
+ auto model = I.self(c10::ArrayRef<at::IValue>{});
auto input = at::ones({1, 2, 3}).cuda();
auto output = input * 2;
ASSERT_TRUE(
@@ -91,7 +91,7 @@
#if HAS_NUMPY
TEST(TorchpyTest, TestNumpy) {
torch::deploy::InterpreterManager m(2);
- auto noArgs = at::ArrayRef<torch::deploy::Obj>();
+ auto noArgs = c10::ArrayRef<torch::deploy::Obj>();
auto I = m.acquireOne();
auto mat35 = I.global("numpy", "random").attr("rand")({3, 5});
auto mat58 = I.global("numpy", "random").attr("rand")({5, 8});
diff --git a/torch/csrc/deploy/unity/tests/test_unity_simple_model.cpp b/torch/csrc/deploy/unity/tests/test_unity_simple_model.cpp
index 3987340..57beb7e 100644
--- a/torch/csrc/deploy/unity/tests/test_unity_simple_model.cpp
+++ b/torch/csrc/deploy/unity/tests/test_unity_simple_model.cpp
@@ -18,7 +18,7 @@
auto I = m.acquireOne();
- auto noArgs = at::ArrayRef<Obj>();
+ auto noArgs = c10::ArrayRef<Obj>();
auto input = I.global("torch", "randn")({32, 256});
auto model = I.global("simple_model", "SimpleModel")(noArgs);