blob: 9e5a7fd4571de8d0c33b4782f6168e3ec3c8d85d [file] [log] [blame]
#pragma once
#include <dlfcn.h>
#include <elf.h>
#include <torch/csrc/deploy/interpreter/Optional.hpp>
#include <memory>
namespace torch {
namespace deploy {
struct DeployLoaderError : public std::runtime_error {
using std::runtime_error::runtime_error;
};
struct TLSIndex {
size_t module_id; // if module_id & TLS_LOCAL_FLAG, then module_id &
// ~TLS_LOCAL_FLAG is a TLSMemory*;
size_t offset;
};
struct SymbolProvider {
SymbolProvider() = default;
virtual multipy::optional<Elf64_Addr> sym(const char* name) const = 0;
virtual multipy::optional<TLSIndex> tls_sym(const char* name) const = 0;
SymbolProvider(const SymbolProvider&) = delete;
SymbolProvider& operator=(const SymbolProvider&) = delete;
virtual ~SymbolProvider() = default;
};
// RAII wrapper around dlopen
struct SystemLibrary : public SymbolProvider {
// create a wrapper around an existing handle returned from dlopen
// if steal == true, then this will dlclose the handle when it is destroyed.
static std::shared_ptr<SystemLibrary> create(
void* handle = RTLD_DEFAULT,
bool steal = false);
static std::shared_ptr<SystemLibrary> create(const char* path, int flags);
};
struct CustomLibrary : public SymbolProvider {
static std::shared_ptr<CustomLibrary> create(
const char* filename,
int argc = 0,
const char** argv = nullptr);
virtual void add_search_library(std::shared_ptr<SymbolProvider> lib) = 0;
virtual void load() = 0;
};
using SystemLibraryPtr = std::shared_ptr<SystemLibrary>;
using CustomLibraryPtr = std::shared_ptr<CustomLibrary>;
} // namespace deploy
} // namespace torch