blob: 48bff73bd2fddb1d03608b9e7987ae4b25dff045 [file] [log] [blame]
#pragma once
// NOLINTNEXTLINE(modernize-deprecated-headers)
#include <assert.h>
#include <c10/util/irange.h>
#include <torch/csrc/api/include/torch/imethod.h>
#include <torch/csrc/deploy/interpreter/interpreter_impl.h>
#include <torch/csrc/jit/serialization/import.h>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>
namespace torch {
namespace deploy {
struct ReplicatedObj;
struct InterpreterManager;
struct TORCH_API InterpreterSession {
InterpreterSession(
InterpreterSessionImpl* impl,
InterpreterManager* manager) noexcept
: impl_(impl), manager_(manager) {}
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
Obj self; // when retreived from a PythonMovable this will be set.
InterpreterSession(InterpreterSession&&) noexcept = default;
~InterpreterSession();
Obj global(const char* module, const char* name) {
TORCH_DEPLOY_TRY
return impl_->global(module, name);
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
Obj from_ivalue(at::IValue ivalue) {
TORCH_DEPLOY_TRY
return impl_->from_ivalue(std::move(ivalue));
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
ReplicatedObj create_movable(Obj obj);
Obj from_movable(const ReplicatedObj& obj);
private:
friend struct ReplicatedObj;
friend struct Package;
friend struct InterpreterManager;
friend struct ReplicatedObjImpl;
std::unique_ptr<InterpreterSessionImpl> impl_;
InterpreterManager* manager_; // if created from one
int64_t notify_idx_ = -1;
};
// One embedded Python interpreter. It owns the implementation object (pImpl_)
// and a raw handle_ associated with library_name_. Presumably handle_ is the
// loaded-library handle released by the out-of-line destructor — confirm.
// Move-constructible only; the moved-from object's handle_ is nulled.
class TORCH_API Interpreter {
 private:
  std::string library_name_;
  void* handle_ = nullptr;
  std::unique_ptr<InterpreterImpl> pImpl_;
  InterpreterManager* manager_ = nullptr; // optional if managed by one

 public:
  // `explicit`: an InterpreterManager* must not silently convert into a whole
  // new interpreter.
  explicit Interpreter(InterpreterManager* manager);

  // Start a new session on this interpreter, tagged with our manager.
  InterpreterSession acquire_session() const {
    TORCH_DEPLOY_TRY
    return InterpreterSession(pImpl_->acquire_session(), manager_);
    TORCH_DEPLOY_SAFE_CATCH_RETHROW
  }

  ~Interpreter();

  // Transfers everything from rhs. rhs.handle_ is reset to nullptr so rhs no
  // longer refers to it.
  Interpreter(Interpreter&& rhs) noexcept
      : library_name_(std::move(rhs.library_name_)),
        handle_(std::exchange(rhs.handle_, nullptr)),
        pImpl_(std::move(rhs.pImpl_)),
        manager_(rhs.manager_) {}

  Interpreter(const Interpreter&) = delete;
  Interpreter& operator=(const Interpreter&) = delete;
  Interpreter& operator=(Interpreter&&) = delete;
  friend struct InterpreterManager;
};
struct Package;
// Keeps an approximate count of the users of each interpreter. acquire() and
// free() are defined out of line; presumably acquire() returns the index of a
// lightly used interpreter and free() releases it — confirm in the .cpp.
struct TORCH_API LoadBalancer {
  // `n` is the number of interpreters available for balancing.
  explicit LoadBalancer(size_t n)
      // 8*... to avoid false sharing of atomics on the same cache line: each
      // interpreter gets an 8 x uint64_t (64-byte) slot. The trailing `()`
      // value-initializes, i.e. zeroes, every element of the array.
      : uses_(new uint64_t[8 * n]()), allocated_(n), n_(n) {}

  // Change how many interpreters are considered. `n` must not exceed the
  // count passed to the constructor, because the storage is never regrown.
  void setResourceLimit(size_t n) {
    TORCH_DEPLOY_TRY
    TORCH_INTERNAL_ASSERT(n <= allocated_);
    n_ = n;
    TORCH_DEPLOY_SAFE_CATCH_RETHROW
  }

  int acquire();
  void free(int where);

 private:
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  std::unique_ptr<uint64_t[]>
      uses_; // the approximate count of the number of users of interpreter
  size_t allocated_; // number of interpreters the storage was sized for
  size_t n_; // current limit, always <= allocated_
};
struct TORCH_API InterpreterManager {
InterpreterManager(size_t n_interp = 2) : resources_(n_interp) {
TORCH_DEPLOY_TRY
for (const auto i : c10::irange(n_interp)) {
instances_.emplace_back(this);
auto I = instances_.back().acquire_session();
// make torch.version.interp be the interpreter id
// can be used for balancing work across GPUs
I.global("torch", "version").attr("__setattr__")({"interp", int(i)});
// std::cerr << "Interpreter " << i << " initialized\n";
instances_.back().pImpl_->set_find_module(
[this](const std::string& name) -> at::optional<std::string> {
auto it = registered_module_sources_.find(name);
if (it != registered_module_sources_.end()) {
return it->second;
} else {
return at::nullopt;
}
});
}
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
// get a free model, guarenteed that no other user of acquire_one has the same
// model. It _is_ possible that other users will be using the interpreter.
InterpreterSession acquire_one() {
TORCH_DEPLOY_TRY
int where = resources_.acquire();
InterpreterSession I = instances_[where].acquire_session();
I.notify_idx_ = where;
return I;
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
// use to make sure something gets run on all interpreters, such as loading or
// unloading a model eagerly
at::ArrayRef<Interpreter> all_instances() {
TORCH_DEPLOY_TRY
return instances_;
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
void debugLimitInterpreters(size_t N) {
TORCH_DEPLOY_TRY
AT_ASSERT(N <= instances_.size());
resources_.setResourceLimit(N);
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
Package load_package(const std::string& uri);
Package load_package(
std::shared_ptr<caffe2::serialize::ReadAdapterInterface> reader);
// convience function for loading some python source code as a module across
// all interpreters. this can be used for writing tests of deploy that need to
// execute python code, or for small amounts of application logic that are
// best written in Python. For larger amounts of code, prefer creating and
// loading them as packages.
void register_module_source(std::string name, std::string src) {
registered_module_sources_[std::move(name)] = std::move(src);
}
InterpreterManager(const InterpreterManager&) = delete;
InterpreterManager& operator=(const InterpreterManager&) = delete;
InterpreterManager& operator=(InterpreterManager&&) = delete;
private:
friend struct Package;
friend struct InterpreterSession;
size_t next_object_id_ = 0;
std::vector<Interpreter> instances_;
LoadBalancer resources_;
std::unordered_map<std::string, std::string> registered_module_sources_;
};
// Shared state behind a ReplicatedObj: the object's id, its pickled form, and
// the manager it was created through.
struct TORCH_API ReplicatedObjImpl {
  // `data` is taken by value and moved into place, which avoids a second copy
  // of the pickled payload.
  ReplicatedObjImpl(
      size_t object_id,
      PickledObject data,
      InterpreterManager* manager)
      : object_id_(object_id), data_(std::move(data)), manager_(manager) {}
  ~ReplicatedObjImpl();
  // Defined out of line; see the .cpp for how a null interpreter is treated.
  void unload(const Interpreter* on_this_interpreter);
  int64_t object_id_;
  PickledObject data_;
  InterpreterManager* manager_;
};
struct TORCH_API ReplicatedObj {
ReplicatedObj() : pImpl_(nullptr) {}
InterpreterSession acquire_session(
const Interpreter* on_this_interpreter = nullptr) const;
at::IValue operator()(at::ArrayRef<at::IValue> args) const {
TORCH_DEPLOY_TRY
auto I = acquire_session();
return I.self(args).toIValue();
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
[[nodiscard]] at::IValue call_kwargs(
std::vector<at::IValue> args,
std::unordered_map<std::string, c10::IValue> kwargs) const {
TORCH_DEPLOY_TRY
auto I = acquire_session();
return I.self.call_kwargs(std::move(args), std::move(kwargs)).toIValue();
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
[[nodiscard]] at::IValue call_kwargs(
std::unordered_map<std::string, c10::IValue> kwargs) const {
TORCH_DEPLOY_TRY
auto I = acquire_session();
return I.self.call_kwargs(std::move(kwargs)).toIValue();
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
void unload(const Interpreter* on_this_interpreter = nullptr);
private:
ReplicatedObj(std::shared_ptr<ReplicatedObjImpl> pImpl)
: pImpl_(std::move(pImpl)) {}
std::shared_ptr<ReplicatedObjImpl> pImpl_;
friend struct Package;
friend struct InterpreterSession;
friend struct InterpreterManager;
};
// PythonMethodWrapper is a more specific instance of a ReplicatedObj that
// represents a Python method. It is therefore callable through the IMethod
// interface. Argument-name lookup is not implemented yet.
class PythonMethodWrapper : public torch::IMethod {
 public:
  // `model` is taken by value. Pass an rvalue to hand it over, or an lvalue to
  // keep using it: a ReplicatedObj copy only shares its std::shared_ptr impl.
  // (Previously this took `ReplicatedObj&` and moved from it, which silently
  // emptied the caller's object.)
  PythonMethodWrapper(
      torch::deploy::ReplicatedObj model,
      std::string method_name)
      : model_(std::move(model)), method_name_(std::move(method_name)) {}

  // Look up `method_name_` on the model in a freshly acquired session and
  // call it.
  c10::IValue operator()(
      std::vector<c10::IValue> args,
      const IValueMap& kwargs = IValueMap()) override {
    // TODO(whc) ideally, pickle the method itself as replicatedobj, to skip
    // this lookup each time
    auto model_session = model_.acquire_session();
    auto method = model_session.self.attr(method_name_.c_str());
    // `args` is our own by-value copy and is not used again, so move it.
    return method.call_kwargs(std::move(args), kwargs).toIValue();
  }

  std::vector<std::string> getArgumentNames() override {
    throw std::runtime_error("getArgumentNames not yet implemented");
  }

 private:
  torch::deploy::ReplicatedObj model_;
  std::string method_name_;
};
// A package archive opened through an InterpreterManager. Sessions acquired
// from it have `self` set to that interpreter's package importer for this
// archive.
struct TORCH_API Package {
  // Shorthand for getting the object as a pickle resource in the package.
  ReplicatedObj load_pickle(
      const std::string& module,
      const std::string& file) {
    TORCH_DEPLOY_TRY
    auto I = acquire_session();
    auto loaded = I.self.attr("load_pickle")({module, file});
    return I.create_movable(std::move(loaded));
    TORCH_DEPLOY_SAFE_CATCH_RETHROW
  }

  // Acquire a session from the manager and bind `self` to the package
  // importer for container_file_ in that interpreter.
  InterpreterSession acquire_session() {
    TORCH_DEPLOY_TRY
    auto I = manager_->acquire_one();
    I.self = I.impl_->create_or_get_package_importer_from_container_file(
        container_file_);
    return I;
    TORCH_DEPLOY_SAFE_CATCH_RETHROW
  }

 private:
  // These constructors mirror those of our zip file format reader
  // (caffe2::serialize::PyTorchStreamReader). Any of its constructors could
  // be exposed the same way.
  Package(const std::string& uri, InterpreterManager* pm)
      : manager_(pm),
        container_file_(
            std::make_shared<caffe2::serialize::PyTorchStreamReader>(uri)) {}
  Package(
      std::shared_ptr<caffe2::serialize::ReadAdapterInterface> reader,
      InterpreterManager* pm)
      : manager_(pm),
        // Move the by-value shared_ptr instead of paying another atomic
        // refcount increment/decrement.
        container_file_(
            std::make_shared<caffe2::serialize::PyTorchStreamReader>(
                std::move(reader))) {}
  friend struct ReplicatedObj;
  friend struct InterpreterManager;
  InterpreterManager* manager_;
  std::shared_ptr<caffe2::serialize::PyTorchStreamReader> container_file_;
};
} // namespace deploy
} // namespace torch