blob: b986093ed020ade6050f6d4e701fed7e72d63718 [file] [log] [blame]
#pragma once
#include <c10/util/irange.h>
#include <torch/csrc/api/include/torch/imethod.h>
#include <torch/csrc/deploy/interpreter/Optional.hpp>
#include <torch/csrc/deploy/interpreter/interpreter_impl.h>
#include <torch/csrc/deploy/noop_environment.h>
#include <torch/csrc/jit/serialization/import.h>
#include <cassert>
#include <fstream>
#include <iostream>
#include <string>
#include <thread>
#include <vector>
namespace torch {
namespace deploy {
struct ReplicatedObj;
struct InterpreterManager;
struct TORCH_API InterpreterSession {
InterpreterSession(
InterpreterSessionImpl* impl,
InterpreterManager* manager) noexcept
: impl_(impl), manager_(manager) {}
// NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
Obj self; // when retreived from a PythonMovable this will be set.
InterpreterSession(InterpreterSession&&) noexcept = default;
// NOLINTNEXTLINE(bugprone-exception-escape)
~InterpreterSession();
Obj global(const char* module, const char* name) {
TORCH_DEPLOY_TRY
return impl_->global(module, name);
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
Obj fromIValue(at::IValue ivalue) {
TORCH_DEPLOY_TRY
return impl_->fromIValue(std::move(ivalue));
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
ReplicatedObj createMovable(Obj obj);
Obj fromMovable(const ReplicatedObj& obj);
private:
friend struct ReplicatedObj;
friend struct Package;
friend struct InterpreterManager;
friend struct ReplicatedObjImpl;
std::unique_ptr<InterpreterSessionImpl> impl_;
InterpreterManager* manager_; // if created from one
int64_t notifyIdx_ = -1;
};
class TORCH_API Interpreter {
private:
std::string libraryName_;
void* handle_;
std::unique_ptr<InterpreterImpl> pImpl_;
bool customLoader_ = false;
InterpreterManager* manager_; // optional if managed by one
std::shared_ptr<Environment> env_;
public:
Interpreter(InterpreterManager* manager, std::shared_ptr<Environment> env);
InterpreterSession acquireSession() const {
TORCH_DEPLOY_TRY
return InterpreterSession(pImpl_->acquireSession(), manager_);
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
~Interpreter();
Interpreter(Interpreter&& rhs) noexcept
: libraryName_(std::move(rhs.libraryName_)),
handle_(rhs.handle_),
pImpl_(std::move(rhs.pImpl_)),
manager_(rhs.manager_) {
rhs.handle_ = nullptr;
}
Interpreter(const Interpreter&) = delete;
Interpreter& operator=(const Interpreter&) = delete;
Interpreter& operator=(Interpreter&&) = delete;
friend struct InterpreterManager;
};
struct Package;
struct TORCH_API LoadBalancer {
explicit LoadBalancer(size_t n)
: uses_(new uint64_t[8 * n]), allocated_(n), n_(n) {
TORCH_DEPLOY_TRY
// 8*... to avoid false sharing of atomics on the same cache line
memset(uses_.get(), 0, 8 * n_ * sizeof(uint64_t));
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
void setResourceLimit(size_t n) {
TORCH_DEPLOY_TRY
MULTIPY_INTERNAL_ASSERT(n <= allocated_);
n_ = n;
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
int acquire();
void free(int where);
private:
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
std::unique_ptr<uint64_t[]>
uses_; // the approximate count of the number of users of interpreter
size_t allocated_;
size_t n_;
};
struct TORCH_API InterpreterManager {
explicit InterpreterManager(
size_t nInterp = 2,
std::shared_ptr<Environment> env = std::make_shared<NoopEnvironment>());
// get a free model, guarenteed that no other user of acquireOne has the same
// model. It _is_ possible that other users will be using the interpreter.
InterpreterSession acquireOne() {
TORCH_DEPLOY_TRY
int where = resources_.acquire();
InterpreterSession I = instances_[where].acquireSession();
I.notifyIdx_ = where;
return I;
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
// use to make sure something gets run on all interpreters, such as loading or
// unloading a model eagerly
at::ArrayRef<Interpreter> allInstances() {
TORCH_DEPLOY_TRY
return instances_;
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
void debugLimitInterpreters(size_t N) {
TORCH_DEPLOY_TRY
AT_ASSERT(N <= instances_.size());
resources_.setResourceLimit(N);
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
Package loadPackage(const std::string& uri);
Package loadPackage(
std::shared_ptr<caffe2::serialize::ReadAdapterInterface> reader);
// convience function for loading some python source code as a module across
// all interpreters. this can be used for writing tests of deploy that need to
// execute python code, or for small amounts of application logic that are
// best written in Python. For larger amounts of code, prefer creating and
// loading them as packages.
void registerModuleSource(std::string name, std::string src) {
registeredModuleSource_[std::move(name)] = std::move(src);
}
InterpreterManager(const InterpreterManager&) = delete;
InterpreterManager& operator=(const InterpreterManager&) = delete;
InterpreterManager& operator=(InterpreterManager&&) = delete;
private:
friend struct Package;
friend struct InterpreterSession;
size_t nextObjectId_ = 0;
std::vector<Interpreter> instances_;
LoadBalancer resources_;
std::unordered_map<std::string, std::string> registeredModuleSource_;
};
struct TORCH_API ReplicatedObjImpl {
ReplicatedObjImpl(
size_t object_id,
// NOLINTNEXTLINE(modernize-pass-by-value)
PickledObject data,
InterpreterManager* manager)
: objectId_(object_id), data_(data), manager_(manager) {}
// NOLINTNEXTLINE(bugprone-exception-escape)
~ReplicatedObjImpl();
void unload(const Interpreter* onThisInterpreter);
int64_t objectId_;
PickledObject data_;
InterpreterManager* manager_;
};
struct TORCH_API ReplicatedObj {
ReplicatedObj() : pImpl_(nullptr) {}
InterpreterSession acquireSession(
const Interpreter* onThisInterpreter = nullptr) const;
at::IValue operator()(at::ArrayRef<at::IValue> args) const {
TORCH_DEPLOY_TRY
auto I = acquireSession();
return I.self(args).toIValue();
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
[[nodiscard]] at::IValue callKwargs(
std::vector<at::IValue> args,
std::unordered_map<std::string, c10::IValue> kwargs) const {
TORCH_DEPLOY_TRY
auto I = acquireSession();
return I.self.callKwargs(std::move(args), std::move(kwargs)).toIValue();
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
[[nodiscard]] at::IValue callKwargs(
std::unordered_map<std::string, c10::IValue> kwargs) const {
TORCH_DEPLOY_TRY
auto I = acquireSession();
return I.self.callKwargs(std::move(kwargs)).toIValue();
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
[[nodiscard]] bool hasattr(const char* name) const {
TORCH_DEPLOY_TRY
auto I = acquireSession();
return I.self.hasattr(name);
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
void unload(const Interpreter* onThisInterpreter = nullptr);
private:
ReplicatedObj(std::shared_ptr<ReplicatedObjImpl> pImpl)
: pImpl_(std::move(pImpl)) {}
std::shared_ptr<ReplicatedObjImpl> pImpl_;
friend struct Package;
friend struct InterpreterSession;
friend struct InterpreterManager;
};
class PythonMethodWrapper : public torch::IMethod {
// PythonMethodWrapper is a more specific instance of a
// ReplicatedObj which represents a python method, and
// is therefore callable and has argument names accessible.
public:
// TODO(whc) make bound method pickleable, then directly construct from that
PythonMethodWrapper(
torch::deploy::ReplicatedObj model,
std::string methodName)
: model_(std::move(model)), methodName_(std::move(methodName)) {}
const std::string& name() const override {
return methodName_;
}
c10::IValue operator()(
std::vector<c10::IValue> args,
const IValueMap& kwargs = IValueMap()) const override {
// TODO(whc) ideally, pickle the method itself as replicatedobj, to skip
// this lookup each time
auto modelSession = model_.acquireSession();
auto method = modelSession.self.attr(methodName_.c_str());
return method.callKwargs(args, kwargs).toIValue();
}
private:
void setArgumentNames(std::vector<std::string>&) const override;
torch::deploy::ReplicatedObj model_;
std::string methodName_;
};
struct TORCH_API Package {
// shorthand for getting the object as a pickle resource in the package
ReplicatedObj loadPickle(const std::string& module, const std::string& file) {
TORCH_DEPLOY_TRY
auto I = acquireSession();
auto loaded = I.self.attr("load_pickle")({module, file});
return I.createMovable(loaded);
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
InterpreterSession acquireSession() {
TORCH_DEPLOY_TRY
auto I = manager_->acquireOne();
I.self =
I.impl_->createOrGetPackageImporterFromContainerFile(containerFile_);
return I;
TORCH_DEPLOY_SAFE_CATCH_RETHROW
}
private:
Package(
const std::string& uri,
InterpreterManager*
pm) // or really any of the constructors to our zip file format
: manager_(pm),
containerFile_(
std::make_shared<caffe2::serialize::PyTorchStreamReader>(uri)) {}
Package(
std::shared_ptr<caffe2::serialize::ReadAdapterInterface> reader,
InterpreterManager*
pm) // or really any of the constructors to our zip file format
: manager_(pm),
containerFile_(
std::make_shared<caffe2::serialize::PyTorchStreamReader>(reader)) {}
friend struct ReplicatedObj;
friend struct InterpreterManager;
InterpreterManager* manager_;
std::shared_ptr<caffe2::serialize::PyTorchStreamReader> containerFile_;
};
} // namespace deploy
} // namespace torch