// blob: 180aceb11eb5a37bde2e92c61fe0ee0b09ff54b5 [file] [log] [blame]
#include "pybind_state.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "caffe2/core/asan.h"
#include "caffe2/core/db.h"
#include "caffe2/core/predictor.h"
#include "caffe2/utils/mkl_utils.h"
namespace caffe2 {
namespace python {
namespace py = pybind11;
// gWorkspaces allows us to define and switch between multiple workspaces in
// Python. The map owns every Workspace; entries live until on_module_exit()
// clears the map (or the process exits).
static std::map<std::string, std::unique_ptr<Workspace>> gWorkspaces;
// gWorkspace is the pointer to the current workspace. The ownership is kept
// by the gWorkspaces map.
static Workspace* gWorkspace = nullptr;
// Name of the workspace that gWorkspace currently points at.
static std::string gCurrentWorkspaceName;
// Out-of-line destructor definitions anchor the vtables of the fetch/feed
// interfaces in this translation unit.
BlobFetcherBase::~BlobFetcherBase() {}
BlobFeederBase::~BlobFeederBase() {}
// Registries mapping a blob's type id (for fetching into Python) and a device
// type (for feeding from Python) to the converter implementation.
CAFFE_DEFINE_TYPED_REGISTRY(BlobFetcherRegistry, CaffeTypeId, BlobFetcherBase);
CAFFE_DEFINE_TYPED_REGISTRY(BlobFeederRegistry, int, BlobFeederBase);
// CPU tensors are exchanged with Python as numpy arrays.
REGISTER_BLOB_FETCHER((TypeMeta::Id<TensorCPU>()), TensorFetcher<CPUContext>);
REGISTER_BLOB_FEEDER(CPU, TensorFeeder<CPUContext>);
// Fetcher that converts a std::string blob into a Python string object.
class StringFetcher : public BlobFetcherBase {
 public:
  py::object Fetch(const Blob& blob) override {
    const string& value = blob.Get<string>();
    return py::str(value);
  }
};
// std::string blobs are fetched as Python str objects.
REGISTER_BLOB_FETCHER((TypeMeta::Id<string>()), StringFetcher);
// The numpy type maps below use NPY_INT for C++ int; guard that assumption.
static_assert(
    sizeof(int) == sizeof(int32_t),
    "We make an assumption that int is always int32 for numpy "
    "type mapping.");
int CaffeToNumpyType(const TypeMeta& meta) {
static std::map<CaffeTypeId, int> numpy_type_map{
{TypeMeta::Id<bool>(), NPY_BOOL},
{TypeMeta::Id<double>(), NPY_DOUBLE},
{TypeMeta::Id<float>(), NPY_FLOAT},
{TypeMeta::Id<float16>(), NPY_FLOAT16},
{TypeMeta::Id<int>(), NPY_INT},
{TypeMeta::Id<int8_t>(), NPY_INT8},
{TypeMeta::Id<int16_t>(), NPY_INT16},
{TypeMeta::Id<int64_t>(), NPY_LONGLONG},
{TypeMeta::Id<uint8_t>(), NPY_UINT8},
{TypeMeta::Id<uint16_t>(), NPY_UINT16},
{TypeMeta::Id<std::string>(), NPY_OBJECT},
// Note: Add more types here.
};
const auto it = numpy_type_map.find(meta.id());
return it == numpy_type_map.end() ? -1 : it->second;
}
// Translate a numpy type constant into the matching Caffe2 TypeMeta.
// Returns a reference to a default-constructed ("unknown") TypeMeta when the
// numpy type has no Caffe2 equivalent.
const TypeMeta& NumpyTypeToCaffe(int numpy_type) {
  static std::map<int, TypeMeta> caffe_type_map{
      {NPY_BOOL, TypeMeta::Make<bool>()},
      {NPY_DOUBLE, TypeMeta::Make<double>()},
      {NPY_FLOAT, TypeMeta::Make<float>()},
      {NPY_FLOAT16, TypeMeta::Make<float16>()},
      {NPY_INT, TypeMeta::Make<int>()},
      {NPY_INT8, TypeMeta::Make<int8_t>()},
      {NPY_INT16, TypeMeta::Make<int16_t>()},
      {NPY_INT64, TypeMeta::Make<int64_t>()},
      // NPY_LONG is platform-dependent: long may be 32 or 64 bits.
      {NPY_LONG,
       sizeof(long) == sizeof(int) ? TypeMeta::Make<int>()
                                   : TypeMeta::Make<int64_t>()},
      {NPY_LONGLONG, TypeMeta::Make<int64_t>()},
      {NPY_UINT8, TypeMeta::Make<uint8_t>()},
      {NPY_UINT16, TypeMeta::Make<uint16_t>()},
      {NPY_OBJECT, TypeMeta::Make<std::string>()},
      // Note: Add more types here.
  };
  static TypeMeta unknown_type;
  const auto found = caffe_type_map.find(numpy_type);
  if (found != caffe_type_map.end()) {
    return found->second;
  }
  return unknown_type;
}
// Build a callable that, given an operator name, returns the registry's help
// message (used to expose operator definition locations to Python).
template <typename Registry>
std::function<const char*(const string&)> DefinitionGetter(
    const Registry* registry) {
  return [registry](const string& name) -> const char* {
    return registry->HelpMessage(name);
  };
}
// Make the workspace registered under `name` current. If no such workspace
// exists, create it when `create_if_missing` is true; otherwise
// CAFFE_ENFORCE aborts the call.
void switchWorkspaceInternal(const std::string& name, bool create_if_missing) {
  // Single lookup instead of the previous count() + operator[] pair.
  auto it = gWorkspaces.find(name);
  if (it == gWorkspaces.end()) {
    CAFFE_ENFORCE(create_if_missing);
    std::unique_ptr<Workspace> new_workspace(new Workspace());
    it = gWorkspaces.emplace(name, std::move(new_workspace)).first;
  }
  gCurrentWorkspaceName = name;
  gWorkspace = it->second.get();
}
namespace python_detail {
// Python Op implementations.

// A registered Python callable plus whether it expects the workspace as a
// third argument.
struct Func {
  py::object py_func;
  bool needs_workspace;
};

// Token -> Func map. (The "Registery" misspelling is kept because other code
// in this file refers to it.)
using FuncRegistery = std::unordered_map<std::string, Func>;

FuncRegistery& gRegistery() {
  // Always leak the objects registered here.
  static FuncRegistery* r = new FuncRegistery();
  return *r;
}

// Look up the Func registered under `token`; fatal if absent.
const Func& getOpFunc(const std::string& token) {
  // Single find() instead of count() followed by operator[] (which does a
  // second lookup and could default-insert).
  auto it = gRegistery().find(token);
  CAFFE_ENFORCE(
      it != gRegistery().end(),
      "Python operator for ",
      token,
      " is not available. If you use distributed training it probably means "
      "that python implementation has to be registered in each of the workers");
  return it->second;
}

// Gradient functions are registered under "<token>_gradient".
const Func& getGradientFunc(const std::string& token) {
  return getOpFunc(token + "_gradient");
}
}
// Run the registered Python callable for this op: gather the CPU input and
// output tensors, acquire the GIL, and invoke the function as either
// f(inputs, outputs) or f(inputs, outputs, workspace). On a Python exception,
// logs the message and traceback and returns false.
bool PythonOpBase::RunOnDevice() {
  std::vector<TensorCPU*> inputs;
  inputs.reserve(InputSize());
  for (auto i = 0; i < InputSize(); ++i) {
    // const_cast so inputs and outputs share one pybind signature; the Python
    // side is presumably expected not to mutate inputs — TODO confirm.
    inputs.push_back(const_cast<TensorCPU*>(&Input(i)));
  }
  std::vector<TensorCPU*> outputs;
  outputs.reserve(OutputSize());
  for (auto i = 0; i < OutputSize(); ++i) {
    outputs.push_back(Output(i));
  }
  auto& pyFunc = getFunc();
  {
    // Acquire GIL for call to Python runtime.
    py::gil_scoped_acquire g;
    try {
      if (pyFunc.needs_workspace) {
        pyFunc.py_func(inputs, outputs, ws_);
      } else {
        pyFunc.py_func(inputs, outputs);
      }
    } catch (const py::error_already_set& e) {
      LOG(ERROR) << "Exception encountered running PythonOp function: "
                 << e.what() << "\nTraceback: ";
      // Pull the in-flight Python exception so we can walk its traceback.
      PyObject *type = nullptr, *value = nullptr, *trace = nullptr;
      PyErr_Fetch(&type, &value, &trace);
      PyTracebackObject* traceback =
          reinterpret_cast<PyTracebackObject*>(trace);
      // Collect frames so they can be logged innermost-last (like Python's
      // own traceback output, reversed index below).
      vector<PyTracebackObject*> trace_vec;
      while (traceback) {
        trace_vec.push_back(traceback);
        traceback = traceback->tb_next;
      }
      for (int i = trace_vec.size() - 1; i >= 0; --i) {
        int line = trace_vec[i]->tb_lineno;
        // NOTE(review): PyString_AsString is the Python 2 C API; this block
        // presumably targets Python 2 — confirm before porting to Python 3.
        const char* filename =
            PyString_AsString(trace_vec[i]->tb_frame->f_code->co_filename);
        const char* funcname =
            PyString_AsString(trace_vec[i]->tb_frame->f_code->co_name);
        LOG(ERROR) << "  # " << trace_vec.size() - i - 1 << "  " << filename
                   << " (" << line << "): " << funcname;
      }
      // PyErr_Fetch transferred ownership of these references to us.
      Py_XDECREF(type);
      Py_XDECREF(value);
      Py_XDECREF(trace);
      return false;
    }
  }
  return true;
}
// Resolve this op's Python callable via its "token" argument.
const python_detail::Func& PythonOp::getFunc() {
  return python_detail::getOpFunc(
      OperatorBase::GetSingleArgument<std::string>("token", ""));
}
// Resolve this gradient op's Python callable via its "token" argument.
const python_detail::Func& PythonGradientOp::getFunc() {
  return python_detail::getGradientFunc(
      OperatorBase::GetSingleArgument<std::string>("token", ""));
}
// Gradient maker for the Python op: the PythonGradient op receives the
// forward op's inputs, outputs, and output gradients, and produces the
// gradients of the forward inputs.
struct GetPythonGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  std::vector<OperatorDef> GetGradientDefs() override {
    const int num_inputs = def_.input_size();
    const int num_outputs = def_.output_size();
    std::vector<std::string> gradientInputs;
    // Reserve up front: inputs + outputs + output gradients.
    gradientInputs.reserve(num_inputs + 2 * num_outputs);
    // Forward inputs.
    for (int i = 0; i < num_inputs; ++i) {
      gradientInputs.push_back(I(i));
    }
    // Forward outputs.
    for (int i = 0; i < num_outputs; ++i) {
      gradientInputs.push_back(O(i));
    }
    // Gradients of the forward outputs.
    for (int i = 0; i < num_outputs; ++i) {
      gradientInputs.push_back(GO(i));
    }
    // One gradient output per forward input.
    std::vector<std::string> gradientOutputs;
    gradientOutputs.reserve(num_inputs);
    for (int i = 0; i < num_inputs; ++i) {
      gradientOutputs.push_back(GI(i));
    }
    return SingleGradientDef(
        "PythonGradient", "", gradientInputs, gradientOutputs);
  }
};
// Register the Python op, its gradient op, and the gradient maker.
REGISTER_CPU_OPERATOR(Python, PythonOp);
REGISTER_CPU_OPERATOR(PythonGradient, PythonGradientOp);
// Always allow running in-place
OPERATOR_SCHEMA(Python).AllowInplace([](int, int) { return true; });
OPERATOR_SCHEMA(PythonGradient).AllowInplace([](int, int) { return true; });
REGISTER_GRADIENT(Python, GetPythonGradient);
// Register pybind11 bindings for Caffe2 object types: Net, Blob, TensorCPU,
// Workspace, GradientWrapper, the DB classes, OpSchema, and Predictor.
void addObjectMethods(py::module& m) {
  // Net: run() releases the GIL so other Python threads can proceed while the
  // net executes.
  py::class_<NetBase>(m, "Net").def("run", [](NetBase* net) {
    py::gil_scoped_release g;
    CAFFE_ENFORCE(net->Run());
  });
  py::class_<Blob>(m, "Blob")
      .def(
          "serialize",
          [](const Blob& blob, const std::string& name) -> py::bytes {
            return blob.Serialize(name);
          })
      .def(
          "deserialize",
          [](Blob* blob, py::bytes serialized) {
            CAFFE_ENFORCE(blob->Deserialize(serialized));
          })
      .def(
          "fetch",
          // Convert the blob's content to a Python object using the fetcher
          // registered for the blob's type id.
          [](const Blob& blob) {
            auto fetcher = CreateFetcher(blob.meta().id());
            CAFFE_ENFORCE(
                fetcher,
                "Could not fetch for blob of type: ",
                blob.meta().name());
            return fetcher->Fetch(blob);
          })
      .def(
          "tensor",
          // Expose the underlying TensorCPU; reference_internal ties its
          // lifetime to the owning Blob object on the Python side.
          [](Blob* blob) {
            auto t = blob->GetMutable<TensorCPU>();
            return py::cast(t, py::return_value_policy::reference_internal);
          })
      .def(
          "_feed",
          // Feed a numpy array (via the device-appropriate feeder) or a
          // string into this blob.
          [](Blob* blob,
             const py::object& arg,
             const py::object device_option) {
            DeviceOption option;
            if (device_option != py::none()) {
              // If we have a device option passed in, read it.
              CAFFE_ENFORCE(option.ParseFromString(
                  py::bytes(device_option).cast<std::string>()));
            }
            if (PyArray_Check(arg.ptr())) { // numpy array
              PyArrayObject* array =
                  reinterpret_cast<PyArrayObject*>(arg.ptr());
              auto feeder = CreateFeeder(option.device_type());
              CAFFE_ENFORCE(
                  feeder, "Unknown device type encountered in FeedBlob.");
              feeder->Feed(option, array, blob);
              return true;
            }
            // NOTE(review): PyString_Check is the Python 2 C API.
            if (PyString_Check(arg.ptr())) { // string
              *blob->GetMutable<std::string>() = arg.cast<std::string>();
              return true;
            }
            CAFFE_THROW(
                "Unexpected type of argument - only numpy array or string are "
                "supported for feeding");
          },
          "Feed an input array or string, with the (optional) DeviceOption",
          py::arg("arg"),
          py::arg("device_option") = py::none());
  py::class_<TensorCPU>(m, "TensorCPU")
      .def_property_readonly(
          "data",
          [](TensorCPU* t) -> py::object {
            if (t->meta() == TypeMeta{}) {
              // keep this behavior for backward compatibility
              t->mutable_data<float>();
            }
            // false: do not force a copy — share memory when possible.
            auto res = TensorFetcher<CPUContext>().FetchTensor(*t, false);
            return res.obj;
          },
          "Return numpy array pointing to this tensor's data if possible. "
          "Otherwise (e.g. for strings) copies the data (same as fetch).")
      .def(
          "feed",
          [](TensorCPU* t, py::object obj) {
            if (!PyArray_Check(obj.ptr())) {
              CAFFE_THROW(
                  "Unexpected type of argument -- expected numpy array");
            }
            TensorFeeder<CPUContext>().FeedTensor(
                DeviceOption{}, reinterpret_cast<PyArrayObject*>(obj.ptr()), t);
          },
          "Copy data from given numpy array into this tensor.")
      .def(
          "fetch",
          [](TensorCPU* t) {
            // true: always copy into a fresh numpy array.
            auto res = TensorFetcher<CPUContext>().FetchTensor(*t, true);
            return res.obj;
          },
          "Copy data from this tensor into a new numpy array.")
      .def(
          "init",
          [](TensorCPU* t, std::vector<TIndex> dims, int caffe_type) {
            const auto& meta =
                DataTypeToTypeMeta((TensorProto::DataType)caffe_type);
            CAFFE_ENFORCE(
                !TensorFetcher<CPUContext>().NeedsCopy(meta),
                "Cannot init tensor of this type. Use `feed` instead.");
            t->Resize(dims);
            t->raw_mutable_data(meta);
          },
          "Initialize this tensor to given shape and data type. "
          "Fail if the given data type cannot be accessed from python.")
      .def_property_readonly(
          "_shape", [](const TensorCPU& t) { return t.dims(); })
      .def("_reshape", [](TensorCPU* t, std::vector<TIndex> dims) {
        t->Resize(dims);
      });
  py::class_<Workspace>(m, "Workspace")
      .def(py::init<>())
      .def(py::init<Workspace*>())
      // Dict of net name -> bound Net object.
      .def_property_readonly(
          "nets",
          [](Workspace* self) {
            CHECK_NOTNULL(self);
            std::map<std::string, py::object> nets;
            for (const auto& name : self->Nets()) {
              LOG(INFO) << "name: " << name;
              nets[name] = py::cast(
                  self->GetNet(name),
                  py::return_value_policy::reference_internal);
            }
            return nets;
          })
      // Dict of blob name -> bound Blob object.
      .def_property_readonly(
          "blobs",
          [](Workspace* self) {
            CHECK_NOTNULL(self);
            std::map<std::string, py::object> blobs;
            for (const auto& name : self->Blobs()) {
              blobs[name] = py::cast(
                  self->GetBlob(name),
                  py::return_value_policy::reference_internal);
            }
            return blobs;
          })
      .def(
          "_create_net",
          // def: serialized NetDef proto bytes.
          [](Workspace* self, py::bytes def) -> py::object {
            caffe2::NetDef proto;
            CAFFE_ENFORCE(proto.ParseFromString(def));
            auto* net = self->CreateNet(proto);
            CAFFE_ENFORCE(net);
            return py::cast(net, py::return_value_policy::reference_internal);
          })
      .def(
          "create_blob",
          [](Workspace* self, const std::string& name) -> py::object {
            auto* blob = self->CreateBlob(name);
            return py::cast(blob, py::return_value_policy::reference_internal);
          })
      .def(
          "_run_net",
          // Parse then run a one-off NetDef; GIL is released for the run.
          [](Workspace* self, py::bytes def) {
            caffe2::NetDef proto;
            CAFFE_ENFORCE(proto.ParseFromString(def));
            py::gil_scoped_release g;
            CAFFE_ENFORCE(self->RunNetOnce(proto));
          })
      .def(
          "_run_operator",
          [](Workspace* self, py::bytes def) {
            caffe2::OperatorDef proto;
            CAFFE_ENFORCE(proto.ParseFromString(def));
            py::gil_scoped_release g;
            CAFFE_ENFORCE(self->RunOperatorOnce(proto));
          })
      .def(
          "_run_plan",
          [](Workspace* self, py::bytes def) {
            caffe2::PlanDef proto;
            CAFFE_ENFORCE(proto.ParseFromString(def));
            py::gil_scoped_release g;
            CAFFE_ENFORCE(self->RunPlan(proto));
          })
      // Workspace.current: the module-level current workspace (see
      // gCurrentWorkspaceName / gWorkspaces above).
      .def_property_readonly_static("current", [](py::object /* type */) {
        auto ws = gWorkspaces.find(gCurrentWorkspaceName);
        CAFFE_ENFORCE(ws != gWorkspaces.end());
        CAFFE_ENFORCE(ws->second.get());
        return py::cast(ws->second.get(), py::return_value_policy::reference);
      });
  // Gradients
  py::class_<GradientWrapper>(m, "GradientWrapper")
      .def(py::init<>())
      .def_readwrite("dense", &GradientWrapper::dense_)
      .def_readwrite("indices", &GradientWrapper::indices_)
      .def_readwrite("values", &GradientWrapper::values_)
      .def("is_sparse", &GradientWrapper::IsSparse)
      .def("is_dense", &GradientWrapper::IsDense)
      .def("is_empty", &GradientWrapper::IsEmpty);
  // Returns (serialized gradient ops, input gradient wrappers) for the given
  // serialized OperatorDef and its output gradients.
  m.def(
      "get_gradient_defs",
      [](const py::bytes& op_def,
         std::vector<GradientWrapper> output_gradients) {
        OperatorDef def;
        CAFFE_ENFORCE(def.ParseFromString(op_def));
        CAFFE_ENFORCE(caffe2::GradientRegistry()->Has(def.type()));
        const auto& meta = GetGradientForOp(def, output_gradients);
        std::vector<py::bytes> grad_ops;
        for (const auto& op : meta.ops_) {
          grad_ops.push_back(op.SerializeAsString());
        }
        return std::pair<std::vector<py::bytes>, std::vector<GradientWrapper>>{
            grad_ops, meta.g_input_};
      });
  // DB
  py::class_<db::Transaction>(m, "Transaction")
      .def("put", &db::Transaction::Put)
      .def("commit", &db::Transaction::Commit);
  py::class_<db::Cursor>(m, "Cursor")
      .def("supports_seek", &db::Cursor::SupportsSeek)
      .def("seek_to_first", &db::Cursor::SeekToFirst)
      .def("next", &db::Cursor::Next)
      .def("key", [](db::Cursor* self) -> py::bytes { return self->key(); })
      .def("value", [](db::Cursor* self) -> py::bytes { return self->value(); })
      .def("valid", &db::Cursor::Valid);
  py::enum_<db::Mode>(m, "Mode")
      .value("read", db::Mode::READ)
      .value("write", db::Mode::WRITE)
      .value("new", db::Mode::NEW)
      .export_values();
  py::class_<db::DB /*, std::unique_ptr<DB>*/>(m, "DB")
      .def("new_transaction", &db::DB::NewTransaction)
      .def("new_cursor", &db::DB::NewCursor)
      .def("close", &db::DB::Close);
  m.def("create_db", &db::CreateDB);
  // OpSchema
  py::class_<OpSchema>(m, "OpSchema")
      .def_property_readonly("file", &OpSchema::file)
      .def_property_readonly("line", &OpSchema::line)
      .def_property_readonly(
          "doc", &OpSchema::doc, py::return_value_policy::reference)
      .def_property_readonly("arg_desc", &OpSchema::arg_desc)
      .def_property_readonly("input_desc", &OpSchema::input_desc)
      .def_property_readonly("output_desc", &OpSchema::output_desc)
      // Note: this does not work yet, we will need to figure out how to pass
      // protobuf objects.
      .def("infer_tensor", &OpSchema::InferTensor)
      .def_static(
          "get", &OpSchemaRegistry::Schema, py::return_value_policy::reference)
      .def_static(
          "get_cpu_impl",
          DefinitionGetter(CPUOperatorRegistry()),
          py::return_value_policy::reference)
      .def_static(
          "get_cuda_impl",
          DefinitionGetter(CUDAOperatorRegistry()),
          py::return_value_policy::reference)
      .def_static(
          "get_gradient_impl",
          DefinitionGetter(GradientRegistry()),
          py::return_value_policy::reference);
  // Predictor: old-style pybind11 __init__ — placement-new constructs the
  // Predictor in the storage pybind allocated for the instance.
  py::class_<Predictor>(m, "Predictor")
      .def(
          "__init__",
          [](Predictor& instance, py::bytes init_net, py::bytes predict_net) {
            NetDef init_net_, predict_net_;
            CAFFE_ENFORCE(init_net_.ParseFromString(init_net));
            CAFFE_ENFORCE(predict_net_.ParseFromString(predict_net));
            new (&instance) Predictor(init_net_, predict_net_);
          });
}
// Register module-level functions that operate on the current global
// workspace (gWorkspace / gWorkspaces), plus build-feature flags and the
// Python-op registration hooks. Also creates the "default" workspace.
void addGlobalMethods(py::module& m) {
  m.attr("is_asan") = py::bool_(CAFFE2_ASAN_ENABLED);
  m.attr("has_mkldnn") = py::bool_(
#ifdef CAFFE2_HAS_MKL_DNN
      true
#else // CAFFE2_HAS_MKL_DNN
      false
#endif // CAFFE2_HAS_MKL_DNN
  );
  // Forward argv-style arguments to caffe2::GlobalInit.
  m.def("global_init", [](std::vector<std::string> args) -> void {
    int argc = args.size();
    std::vector<char*> argv;
    for (auto& arg : args) {
      argv.push_back(const_cast<char*>(arg.data()));
    }
    char** pargv = argv.data();
    CAFFE_ENFORCE(caffe2::GlobalInit(&argc, &pargv));
  });
  m.def("registered_operators", []() {
    std::set<string> all_keys;
    // CPU operators
    for (const auto& name : caffe2::CPUOperatorRegistry()->Keys()) {
      all_keys.insert(name);
    }
    // CUDA operators
    for (const auto& name : caffe2::CUDAOperatorRegistry()->Keys()) {
      all_keys.insert(name);
    }
    // Ensure we are lexicographically ordered.
    std::vector<std::string> keys;
    for (const auto& key : all_keys) {
      keys.push_back(key);
    }
    return keys;
  });
  // Destroy all workspaces before interpreter teardown.
  m.def("on_module_exit", []() { gWorkspaces.clear(); });
  // create_if_missing is not used here, but the parameter is necessary for
  // pybind to properly do function overloading against the string overload
  // below.
  m.def("switch_workspace", [](Workspace* ws, py::object create_if_missing) {
    // NOTE(review): this overload only swaps the pointer; it does not update
    // gCurrentWorkspaceName or gWorkspaces — confirm this is intentional.
    gWorkspace = ws;
  });
  m.def(
      "switch_workspace",
      [](const std::string& name, const py::object create_if_missing) {
        if (create_if_missing == py::none()) {
          return switchWorkspaceInternal(name, false);
        }
        return switchWorkspaceInternal(name, create_if_missing.cast<bool>());
      },
      "Switch to the specified workspace, creating if necessary",
      py::arg("name"),
      py::arg("create_if_missing") = py::none());
  m.def(
      "reset_workspace",
      // Replace the current workspace with a fresh one, optionally rooted at
      // the given folder.
      [](const py::object& root_folder) {
        VLOG(1) << "Resetting workspace.";
        if (root_folder == py::none()) {
          gWorkspaces[gCurrentWorkspaceName].reset(new Workspace());
        } else {
          gWorkspaces[gCurrentWorkspaceName].reset(
              new Workspace(root_folder.cast<std::string>()));
        }
        gWorkspace = gWorkspaces[gCurrentWorkspaceName].get();
        return true;
      },
      "Reset the workspace",
      py::arg("root_folder") = py::none());
  m.def("root_folder", []() {
    CAFFE_ENFORCE(gWorkspace);
    return gWorkspace->RootFolder();
  });
  m.def("current_workspace", []() { return gCurrentWorkspaceName; });
  // Names of all registered workspaces.
  m.def("workspaces", []() {
    std::vector<std::string> names;
    for (const auto& kv : gWorkspaces) {
      names.push_back(kv.first);
    }
    return names;
  });
  m.def("local_blobs", []() {
    CAFFE_ENFORCE(gWorkspace);
    return gWorkspace->LocalBlobs();
  });
  m.def("blobs", []() {
    CAFFE_ENFORCE(gWorkspace);
    return gWorkspace->Blobs();
  });
  m.def("has_blob", [](const std::string& name) {
    CAFFE_ENFORCE(gWorkspace);
    return gWorkspace->HasBlob(name);
  });
  // Create a named net in the current workspace from serialized NetDef bytes.
  m.def("create_net", [](py::bytes net_def) {
    caffe2::NetDef proto;
    CAFFE_ENFORCE(
        proto.ParseFromString(net_def),
        "Can't parse net proto: ",
        std::string(net_def));
    CAFFE_ENFORCE(
        gWorkspace->CreateNet(proto),
        "Error creating net with proto: ",
        std::string(net_def));
    return true;
  });
  m.def("run_net", [](const std::string& name) {
    CAFFE_ENFORCE(gWorkspace);
    CAFFE_ENFORCE(gWorkspace->GetNet(name), "Can't find net ", name);
    // Release the GIL while the net runs.
    py::gil_scoped_release g;
    CAFFE_ENFORCE(gWorkspace->RunNet(name), "Error running net ", name);
    return true;
  });
  m.def(
      "benchmark_net",
      [](const std::string& name,
         size_t warmup_runs,
         size_t main_runs,
         bool run_individual) {
        CAFFE_ENFORCE(gWorkspace);
        auto* net = gWorkspace->GetNet(name);
        CAFFE_ENFORCE(net);
        py::gil_scoped_release g;
        vector<float> stat =
            net->TEST_Benchmark(warmup_runs, main_runs, run_individual);
        return stat;
      });
  m.def("delete_net", [](const std::string& name) {
    CAFFE_ENFORCE(gWorkspace);
    gWorkspace->DeleteNet(name);
    return true;
  });
  m.def("nets", []() { return gWorkspace->Nets(); });
  // One-off runs of a serialized OperatorDef / NetDef / PlanDef; each
  // releases the GIL for the duration of the run.
  m.def("run_operator_once", [](const py::bytes& op_def) {
    CAFFE_ENFORCE(gWorkspace);
    OperatorDef def;
    CAFFE_ENFORCE(def.ParseFromString(op_def));
    py::gil_scoped_release g;
    CAFFE_ENFORCE(gWorkspace->RunOperatorOnce(def));
    return true;
  });
  m.def("run_net_once", [](const py::bytes& net_def) {
    CAFFE_ENFORCE(gWorkspace);
    NetDef def;
    CAFFE_ENFORCE(def.ParseFromString(net_def));
    py::gil_scoped_release g;
    CAFFE_ENFORCE(gWorkspace->RunNetOnce(def));
    return true;
  });
  m.def("run_plan", [](const py::bytes& plan_def) {
    CAFFE_ENFORCE(gWorkspace);
    PlanDef def;
    CAFFE_ENFORCE(def.ParseFromString(plan_def));
    py::gil_scoped_release g;
    CAFFE_ENFORCE(gWorkspace->RunPlan(def));
    return true;
  });
  m.def("create_blob", [](const std::string& name) {
    CAFFE_ENFORCE(gWorkspace);
    CAFFE_ENFORCE(gWorkspace->CreateBlob(name));
    return true;
  });
  // Fetch a blob's value via the registered fetcher; falls back to a
  // descriptive string for types with no fetcher.
  m.def("fetch_blob", [](const std::string& name) -> py::object {
    CAFFE_ENFORCE(gWorkspace->HasBlob(name), "Can't find blob: ", name);
    const caffe2::Blob& blob = *(gWorkspace->GetBlob(name));
    auto fetcher = CreateFetcher(blob.meta().id());
    if (fetcher) {
      return fetcher->Fetch(blob);
    } else {
      // If there is no fetcher registered, return a metainfo string.
      // If all branches failed, we will return a metainfo string.
      std::stringstream ss;
      ss << caffe2::string(name) << ", a C++ native class of type "
         << blob.TypeName() << ".";
      return py::str(ss.str());
    }
  });
  m.def(
      "feed_blob",
      // Create (or reuse) the named blob and feed a numpy array or string
      // into it; mirrors Blob._feed above but on the current workspace.
      [](const std::string& name, py::object arg, py::object device_option) {
        DeviceOption option;
        if (device_option != py::none()) {
          // If we have a device option passed in, read it.
          CAFFE_ENFORCE(option.ParseFromString(
              py::bytes(device_option).cast<std::string>()));
        }
        auto* blob = gWorkspace->CreateBlob(name);
        if (PyArray_Check(arg.ptr())) { // numpy array
          PyArrayObject* array = reinterpret_cast<PyArrayObject*>(arg.ptr());
          auto feeder = CreateFeeder(option.device_type());
          CAFFE_ENFORCE(feeder, "Unknown device type encountered in FeedBlob.");
          feeder->Feed(option, array, blob);
          return true;
        }
        if (PyString_Check(arg.ptr())) { // string
          *blob->GetMutable<std::string>() = arg.cast<std::string>();
          return true;
        }
        CAFFE_THROW(
            "Unexpected type of argument - only numpy array or string are "
            "supported for feeding");
        return false;
      },
      "",
      py::arg("name"),
      py::arg("arg"),
      py::arg("device_option") = py::none());
  m.def("serialize_blob", [](const std::string& name) {
    CAFFE_ENFORCE(gWorkspace);
    auto* blob = gWorkspace->GetBlob(name);
    CAFFE_ENFORCE(blob);
    return py::bytes(blob->Serialize(name));
  });
  m.def(
      "deserialize_blob",
      [](const std::string& name, const py::bytes& serialized) {
        CAFFE_ENFORCE(gWorkspace);
        auto* blob = gWorkspace->CreateBlob(name);
        CAFFE_ENFORCE(blob->Deserialize(serialized.cast<std::string>()));
      });
  // we support 2 possible signatures of python op: (inputs, outputs) or
  // (inputs, outputs, workspace)
  m.def("register_python_op", [](py::object func, bool pass_workspace) {
    using namespace python_detail;
    CAFFE_ENFORCE(func != py::none());
    const std::string name = func.attr("__name__").cast<std::string>();
    // Unique name since registry is never cleared.
    const std::string token = name + to_string(gRegistery().size());
    // NOTE(review): this checks `name` but the registry is keyed by `token`
    // (name + counter); presumably this was meant to check `token` — confirm.
    CAFFE_ENFORCE(gRegistery().find(name) == gRegistery().end());
    gRegistery()[token] = Func{func, pass_workspace};
    return token;
  });
  m.def(
      "register_python_gradient_op",
      [](const std::string& token, py::object func) {
        using namespace python_detail;
        CAFFE_ENFORCE(func != py::none());
        // The forward op must already be registered under this token.
        CAFFE_ENFORCE(gRegistery().find(token) != gRegistery().end());
        // For global sanity gradient ops shouldn't access workspace
        gRegistery()[token + "_gradient"] = Func{func, false};
      });
// Exposes builtin_cpu_supports_<feature>() predicates (gcc/clang builtin).
#define CAFFE2_CPU_FEATURE_SUPPORT(feature) \
  m.def("builtin_cpu_supports_" #feature, []() { \
    return __builtin_cpu_supports(#feature); \
  })
  CAFFE2_CPU_FEATURE_SUPPORT(avx2);
#undef CAFFE2_CPU_FEATURE_SUPPORT
  auto initialize = [&]() {
    // Initialization of the module
    ([]() {
      // This is a workaround so we can deal with numpy's import_array behavior.
      // Despite the fact that you may think import_array() is a function call,
      // it is defined as a macro (as of 1.10).
      import_array();
    })();
    // Single threaded, so safe
    static bool initialized = false;
    if (initialized) {
      return;
    }
    // We will create a default workspace for us to run stuff.
    switchWorkspaceInternal("default", true);
    gCurrentWorkspaceName = "default";
    initialized = true;
  };
  initialize();
};
// Module entry point: builds the caffe2_pybind11_state extension module and
// registers all global and object-level bindings.
PYBIND11_PLUGIN(caffe2_pybind11_state) {
  py::module m(
      "caffe2_pybind11_state",
      "pybind11 stateful interface to Caffe2 workspaces");
  // addGlobalMethods also runs module initialization (numpy import_array and
  // creation of the "default" workspace).
  addGlobalMethods(m);
  addObjectMethods(m);
  return m.ptr();
}
} // namespace python
} // namespace caffe2