blob: 94f30c99175a0a8b9275bc30e0d27e45f248a7aa [file] [log] [blame]
#pragma once
// NOLINTNEXTLINE(modernize-deprecated-headers)
#include <assert.h>
#include <torch/csrc/deploy/interpreter/interpreter_impl.h>
#include <torch/csrc/jit/serialization/import.h>
#include <fstream>
#include <iostream>
#include <string>
#include <thread>
#include <vector>
/* Torch Deploy intentionally embeds multiple copies of c++ libraries
providing python bindings necessary for torch::deploy users in the same
process space in order to provide a multi-python environment. As a result,
any exception types defined by these duplicated libraries can't be safely
caught or handled outside of the originating dynamic library (.so).
In practice this means that you must either
catch these exceptions inside the torch::deploy API boundary or risk crashing
the client application.
It is safe to throw exception types that are defined once in
the context of the client application, such as c10::Error, which is defined
in libtorch, which isn't duplicated in torch::deploy interpreters.
==> Use TORCH_DEPLOY_TRY, _SAFE_CATCH_RETHROW around _ALL_ torch::deploy APIs
For more information, see
https://gcc.gnu.org/wiki/Visibility (section on c++ exceptions)
or https://stackoverflow.com/a/14364055
or
https://stackoverflow.com/questions/14268736/symbol-visibility-exceptions-runtime-error
note- this may be only a serious problem on versions of gcc prior to 4.0,
but still seems worth sealing off.
*/
#define TORCH_DEPLOY_TRY try {
#define TORCH_DEPLOY_SAFE_CATCH_RETHROW \
} \
catch (std::exception & err) { \
throw c10::Error( \
std::string( \
"Exception Caught inside torch::deploy embedded library: \n") + \
err.what(), \
""); \
} \
catch (...) { \
throw c10::Error( \
std::string( \
"Unknown Exception Caught inside torch::deploy embedded library"), \
""); \
}
namespace torch {
namespace deploy {
struct ReplicatedObj;
struct InterpreterManager;
struct TORCH_API InterpreterSession {
InterpreterSession(
InterpreterSessionImpl* impl,
InterpreterManager* manager) noexcept
: impl_(impl), manager_(manager) {}
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
Obj self; // when retreived from a PythonMovable this will be set.
InterpreterSession(InterpreterSession&&) noexcept = default;
~InterpreterSession();
Obj global(const char* module, const char* name) {
TORCH_DEPLOY_TRY
return impl_->global(module, name);
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
Obj from_ivalue(at::IValue ivalue) {
TORCH_DEPLOY_TRY
return impl_->from_ivalue(std::move(ivalue));
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
ReplicatedObj create_movable(Obj obj);
Obj from_movable(const ReplicatedObj& obj);
private:
friend struct ReplicatedObj;
friend struct Package;
friend struct InterpreterManager;
friend struct ReplicatedObjImpl;
std::unique_ptr<InterpreterSessionImpl> impl_;
InterpreterManager* manager_; // if created from one
int64_t notify_idx_ = -1;
};
class TORCH_API Interpreter {
private:
std::string library_name_;
void* handle_;
std::unique_ptr<InterpreterImpl> pImpl_;
InterpreterManager* manager_; // optional if managed by one
public:
Interpreter(InterpreterManager* manager);
InterpreterSession acquire_session() const {
TORCH_DEPLOY_TRY
return InterpreterSession(pImpl_->acquire_session(), manager_);
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
~Interpreter();
Interpreter(Interpreter&& rhs) noexcept
: library_name_(std::move(rhs.library_name_)),
handle_(rhs.handle_),
pImpl_(std::move(rhs.pImpl_)),
manager_(rhs.manager_) {
rhs.handle_ = nullptr;
}
Interpreter(const Interpreter&) = delete;
Interpreter& operator=(const Interpreter&) = delete;
Interpreter& operator=(Interpreter&&) = delete;
friend struct InterpreterManager;
};
struct Package;
struct TORCH_API LoadBalancer {
explicit LoadBalancer(size_t n)
: uses_(new uint64_t[8 * n]), allocated_(n), n_(n) {
TORCH_DEPLOY_TRY
// 8*... to avoid false sharing of atomics on the same cache line
memset(uses_.get(), 0, 8 * n_ * sizeof(uint64_t));
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
void setResourceLimit(size_t n) {
TORCH_DEPLOY_TRY
TORCH_INTERNAL_ASSERT(n <= allocated_);
n_ = n;
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
int acquire();
void free(int where);
private:
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
std::unique_ptr<uint64_t[]>
uses_; // the approximate count of the number of users of interpreter
size_t allocated_;
size_t n_;
};
struct TORCH_API InterpreterManager {
InterpreterManager(size_t n_interp = 2) : resources_(n_interp) {
TORCH_DEPLOY_TRY
for (size_t i = 0; i < n_interp; ++i) {
instances_.emplace_back(this);
auto I = instances_.back().acquire_session();
// make torch.version.interp be the interpreter id
// can be used for balancing work across GPUs
I.global("torch", "version").attr("__setattr__")({"interp", int(i)});
// std::cerr << "Interpreter " << i << " initialized\n";
}
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
// get a free model, guarenteed that no other user of acquire_one has the same
// model. It _is_ possible that other users will be using the interpreter.
InterpreterSession acquire_one() {
TORCH_DEPLOY_TRY
int where = resources_.acquire();
InterpreterSession I = instances_[where].acquire_session();
I.notify_idx_ = where;
return I;
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
// use to make sure something gets run on all interpreters, such as loading or
// unloading a model eagerly
at::ArrayRef<Interpreter> all_instances() {
TORCH_DEPLOY_TRY
return instances_;
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
void debugLimitInterpreters(size_t N) {
TORCH_DEPLOY_TRY
AT_ASSERT(N <= instances_.size());
resources_.setResourceLimit(N);
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
Package load_package(const std::string& uri);
Package load_package(
std::shared_ptr<caffe2::serialize::ReadAdapterInterface> reader);
InterpreterManager(const InterpreterManager&) = delete;
InterpreterManager& operator=(const InterpreterManager&) = delete;
InterpreterManager& operator=(InterpreterManager&&) = delete;
private:
friend struct Package;
friend struct InterpreterSession;
size_t next_object_id_ = 0;
std::vector<Interpreter> instances_;
LoadBalancer resources_;
};
struct TORCH_API ReplicatedObjImpl {
ReplicatedObjImpl(
size_t object_id,
// NOLINTNEXTLINE(modernize-pass-by-value)
PickledObject data,
InterpreterManager* manager)
: object_id_(object_id), data_(data), manager_(manager) {}
~ReplicatedObjImpl();
void unload(const Interpreter* on_this_interpreter);
int64_t object_id_;
PickledObject data_;
InterpreterManager* manager_;
};
struct TORCH_API ReplicatedObj {
ReplicatedObj() : pImpl_(nullptr) {}
InterpreterSession acquire_session(
const Interpreter* on_this_interpreter = nullptr) const;
at::IValue operator()(at::ArrayRef<at::IValue> args) const {
TORCH_DEPLOY_TRY
auto I = acquire_session();
return I.self(args).toIValue();
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
[[nodiscard]] at::IValue call_kwargs(
std::vector<at::IValue> args,
std::unordered_map<std::string, c10::IValue> kwargs) const {
TORCH_DEPLOY_TRY
auto I = acquire_session();
return I.self.call_kwargs(std::move(args), std::move(kwargs)).toIValue();
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
[[nodiscard]] at::IValue call_kwargs(
std::unordered_map<std::string, c10::IValue> kwargs) const {
TORCH_DEPLOY_TRY
auto I = acquire_session();
return I.self.call_kwargs(std::move(kwargs)).toIValue();
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
void unload(const Interpreter* on_this_interpreter = nullptr);
private:
ReplicatedObj(std::shared_ptr<ReplicatedObjImpl> pImpl)
: pImpl_(std::move(pImpl)) {}
std::shared_ptr<ReplicatedObjImpl> pImpl_;
friend struct Package;
friend struct InterpreterSession;
friend struct InterpreterManager;
};
struct TORCH_API Package {
// shorthand for getting the object as a pickle resource in the package
ReplicatedObj load_pickle(
const std::string& module,
const std::string& file) {
TORCH_DEPLOY_TRY
auto I = acquire_session();
auto loaded = I.self.attr("load_pickle")({module, file});
return I.create_movable(loaded);
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
std::string load_text(const std::string& module, const std::string& file) {
TORCH_DEPLOY_TRY
auto I = acquire_session();
auto loaded = I.self.attr("load_text")({module, file});
return loaded.toIValue().toStringRef();
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
InterpreterSession acquire_session() {
TORCH_DEPLOY_TRY
auto I = manager_->acquire_one();
I.self = I.impl_->create_or_get_package_importer_from_container_file(
container_file_);
return I;
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
private:
Package(
const std::string& uri,
InterpreterManager*
pm) // or really any of the constructors to our zip file format
: manager_(pm),
container_file_(
std::make_shared<caffe2::serialize::PyTorchStreamReader>(uri)) {}
Package(
std::shared_ptr<caffe2::serialize::ReadAdapterInterface> reader,
InterpreterManager*
pm) // or really any of the constructors to our zip file format
: manager_(pm),
container_file_(
std::make_shared<caffe2::serialize::PyTorchStreamReader>(reader)) {}
friend struct ReplicatedObj;
friend struct InterpreterManager;
InterpreterManager* manager_;
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> container_file_;
};
} // namespace deploy
} // namespace torch