| #include <torch/csrc/jit/script/init.h> |
| |
| #include <torch/csrc/Device.h> |
| #include <torch/csrc/jit/import.h> |
| #include <torch/csrc/jit/script/compiler.h> |
| #include <torch/csrc/jit/script/module.h> |
| #include <torch/csrc/jit/script/module_python.h> |
| #include <torch/csrc/jit/script/python_sugared_value.h> |
| #include <torch/csrc/jit/script/sugared_value.h> |
| #include <torch/csrc/jit/testing/file_check.h> |
| |
| #include <torch/csrc/jit/constants.h> |
| #include <torch/csrc/jit/graph_executor.h> |
| #include <torch/csrc/jit/hooks_for_testing.h> |
| #include <torch/csrc/jit/import_source.h> |
| #include <torch/csrc/jit/irparser.h> |
| #include <torch/csrc/jit/passes/python_print.h> |
| #include <torch/csrc/jit/pybind_utils.h> |
| #include <torch/csrc/jit/python_tracer.h> |
| #include <torch/csrc/jit/script/logging.h> |
| #include <torch/csrc/jit/script/parser.h> |
| #include <torch/csrc/jit/tracer.h> |
| #include <torch/csrc/jit/export.h> |
| |
| #include <torch/csrc/api/include/torch/ordered_dict.h> |
| |
| #include <ATen/ATen.h> |
| #include <ATen/core/function_schema.h> |
| #include <ATen/core/qualified_name.h> |
| |
| #include <pybind11/functional.h> |
| #include <pybind11/pybind11.h> |
| #include <pybind11/stl.h> |
| #include <pybind11/stl_bind.h> |
| #include <chrono> |
| #include <cstddef> |
| #include <memory> |
| #include <sstream> |
| #include <string> |
| #include <tuple> |
| #include <utility> |
| #include <vector> |
| |
| PYBIND11_MAKE_OPAQUE(torch::jit::script::ExtraFilesMap); |
| |
| namespace torch { |
| namespace jit { |
| namespace script { |
| |
| using ::c10::Argument; |
| using ::c10::FunctionSchema; |
| |
| using ResolutionCallback = std::function<py::function(std::string)>; |
| using FunctionDefaults = std::unordered_map<std::string, py::object>; |
| |
| namespace { |
| |
| // A resolver that will inspect the outer Python scope to find `name`. |
| struct PythonResolver : public Resolver { |
| explicit PythonResolver(ResolutionCallback rcb) : rcb_(std::move(rcb)) {} |
| |
| /** |
| * While compiling classes, the class type we're compiling will not be |
| * available in Python, since we haven't fowner_ defining the class yet. So |
| * in order to make the class type available to its own methods, we need to |
| * explicitly resolve it. |
| * |
| * @param rcb Python function to resolve a name to its Python object in the |
| * enclosing scope |
| * @param classname The unqualified classname of the class currently being |
| * compiled. |
| * @param classType The class's type. |
| */ |
| explicit PythonResolver( |
| ResolutionCallback rcb, |
| std::string classname, |
| ClassTypePtr classType) |
| : rcb_(std::move(rcb)), |
| classname_(std::move(classname)), |
| classType_(std::move(classType)) {} |
| |
| std::shared_ptr<SugaredValue> resolveValue( |
| const std::string& name, |
| Function& m, |
| const SourceRange& loc) override { |
| pybind11::gil_scoped_acquire ag; |
| py::object obj = rcb_(name); |
| if (obj.is(py::none())) { |
| return nullptr; |
| } |
| return toSugaredValue(obj, m, loc); |
| } |
| |
| static bool isNamedTupleClass(py::object obj) { |
| auto tuple_type = reinterpret_cast<PyObject*>(&PyTuple_Type); |
| return PyObject_IsSubclass(obj.ptr(), tuple_type) && |
| py::hasattr(obj, "_fields"); |
| } |
| |
| TypePtr resolveType(const std::string& name, const SourceRange& loc) |
| override { |
| if (classType_ && name == classname_) { |
| return classType_; |
| } |
| pybind11::gil_scoped_acquire ag; |
| py::object obj = rcb_(name); |
| if (obj.is(py::none())) { |
| return nullptr; |
| } |
| py::bool_ isClass = py::module::import("inspect").attr("isclass")(obj); |
| if (!py::cast<bool>(isClass)) { |
| return nullptr; |
| } |
| |
| auto qualifiedName = c10::QualifiedName(py::cast<std::string>( |
| py::module::import("torch.jit").attr("_qualified_name")(obj))); |
| |
| if (isNamedTupleClass(obj)) { |
| // Currently don't support default values |
| if (py::hasattr(obj, "_field_defaults")) { |
| auto default_dict = py::cast<std::map<std::string, py::object>>( |
| py::getattr(obj, "_field_defaults")); |
| if (default_dict.size()) { |
| std::string error_msg = |
| "Default values are currently not supported" |
| " on NamedTuple fields in TorchScript. Fields " |
| "with default values: ["; |
| bool first = true; |
| for (const auto& kv : default_dict) { |
| if (!first) { |
| error_msg += ", "; |
| } |
| error_msg += kv.first; |
| } |
| error_msg += "]"; |
| throw ErrorReport(loc) << error_msg; |
| } |
| } |
| |
| py::object props = py::module::import("torch.jit") |
| .attr("_get_named_tuple_properties")(obj); |
| std::string unqualName; |
| std::vector<std::string> fields; |
| std::vector<TypePtr> annotations; |
| std::tie(unqualName, fields, annotations) = py::cast< |
| std::tuple<std::string, decltype(fields), decltype(annotations)>>( |
| props); |
| |
| auto tt = TupleType::createNamed(qualifiedName, fields, annotations); |
| if (auto type = get_python_cu()->get_type(qualifiedName)) { |
| TORCH_CHECK( |
| type->isSubtypeOf(tt), |
| "Can't to redefine NamedTuple: ", |
| tt->python_str()); |
| return type; |
| } |
| get_python_cu()->register_type(tt); |
| return tt; |
| } |
| return get_python_cu()->get_type(qualifiedName); |
| } |
| |
| private: |
| ResolutionCallback rcb_; |
| std::string classname_; |
| ClassTypePtr classType_; |
| }; |
| |
| std::shared_ptr<PythonResolver> pythonResolver(ResolutionCallback rcb) { |
| return std::make_shared<PythonResolver>(rcb); |
| } |
| std::shared_ptr<PythonResolver> pythonResolver( |
| ResolutionCallback rcb, |
| std::string classname, |
| ClassTypePtr classType) { |
| return std::make_shared<PythonResolver>( |
| rcb, std::move(classname), std::move(classType)); |
| } |
| |
| void checkOverloadDecl(const Decl& new_decl, const Decl& old_decl) { |
| const auto& new_params = new_decl.params(); |
| const auto& old_params = old_decl.params(); |
| |
| // TODO. same number of parameters not strictly necessary. |
| TORCH_INTERNAL_ASSERT( |
| new_params.size() == old_params.size(), |
| "Overload must have same number of parameters\n", |
| new_decl.range(), |
| old_decl.range()); |
| for (size_t i = 0; i < new_decl.params().size(); ++i) { |
| TORCH_INTERNAL_ASSERT( |
| new_params[i].ident().name() == old_params[i].ident().name(), |
| "Overload parameters must have the same names\n", |
| new_params[i].ident(), |
| old_params[i].ident()); |
| } |
| } |
| |
| c10::optional<IValue> tryCalculateDefaultParam( |
| const Argument& arg, |
| const py::object& def_value) { |
| auto n = arg.N(); |
| auto list_type = arg.type()->cast<ListType>(); |
| try { |
| if (n && *n > 0 && list_type) { |
| // BroadcastingList, allow default values T for arg types List[T] |
| return toIValue(def_value, list_type->getElementType()); |
| } else { |
| return toIValue(def_value, arg.type()); |
| } |
| } catch (...) { |
| return c10::nullopt; |
| } |
| } |
| |
| // An overloaded function may have a default that does not subtype all overloads |
| // @overload |
| // def foo(x: str) |
| // def foo(x=1) |
| FunctionDefaults calcOverloadedFunctionDefaults( |
| const FunctionSchema& schema, |
| const FunctionDefaults& defaults) { |
| FunctionDefaults updated_defaults; |
| for (const auto& arg : schema.arguments()) { |
| const std::string& arg_name = arg.name(); |
| auto value = defaults.find(arg_name); |
| if (value == defaults.end()) { |
| continue; |
| } |
| auto maybe_ivalue = tryCalculateDefaultParam(arg, value->second); |
| if (maybe_ivalue) { |
| updated_defaults[arg_name] = value->second; |
| } |
| } |
| return updated_defaults; |
| } |
| |
| } // namespace |
| |
| bool checkMutableFunctionDefault(const py::object& def_arg) { |
| if (py::isinstance<py::list>(def_arg) || py::isinstance<py::dict>(def_arg)) { |
| return true; |
| } |
| if (py::isinstance<py::tuple>(def_arg)) { |
| auto pytuple = def_arg.cast<py::tuple>(); |
| for (py::handle t : pytuple) { |
| py::object obj = py::reinterpret_borrow<py::object>(t); |
| if (checkMutableFunctionDefault(obj)) { |
| return true; |
| } |
| } |
| } |
| return false; |
| } |
| |
| void checkMutableFunctionDefault( |
| const SourceRange& range, |
| const Argument& arg, |
| const py::object& def_arg) { |
| if (checkMutableFunctionDefault(def_arg) || arg.type()->cast<ClassType>()) { |
| throw ErrorReport(range) |
| << "Mutable default parameters are not supported because Python binds them to the function" |
| << " and they persist across function calls.\n As a workaround, make the default None and instantiate" |
| << " the default parameter within the body of the function. Found " |
| << def_arg.get_type() << " on parameter " << arg.name(); |
| } |
| } |
| |
| FunctionSchema getSchemaWithNameAndDefaults( |
| const SourceRange& range, |
| const FunctionSchema& schema, |
| const at::optional<std::string>& new_name, |
| const FunctionDefaults& default_args) { |
| std::vector<Argument> new_args; |
| for (auto& arg : schema.arguments()) { |
| auto it = default_args.find(arg.name()); |
| if (it != default_args.end()) { |
| checkMutableFunctionDefault(range, arg, it->second); |
| c10::optional<IValue> value = tryCalculateDefaultParam(arg, it->second); |
| if (!value) { |
| throw ErrorReport(range) |
| << "Expected a default value of type " << arg.type()->python_str() |
| << " on parameter \"" << arg.name() << "\""; |
| } |
| new_args.emplace_back( |
| arg.name(), arg.type(), arg.N(), *value, arg.kwarg_only()); |
| } else { |
| new_args.push_back(arg); |
| } |
| } |
| return FunctionSchema( |
| new_name.value_or(schema.name()), |
| schema.overload_name(), |
| new_args, |
| schema.returns(), |
| schema.is_vararg(), |
| schema.is_varret()); |
| } |
| |
| static Decl mergeDefaultsAndExtraParametersToOverloadDecl( |
| const Decl& overload_decl, |
| const Decl& impl_decl, |
| const FunctionDefaults& defaults) { |
| std::vector<Param> adjusted_params; |
| const auto& overload_params = overload_decl.params(); |
| const auto& impl_params = impl_decl.params(); |
| |
| // following PEP specification that the following should work: |
| // @overload |
| // def mouse_event(x1: int, y1: int) -> ClickEvent: ... |
| // ... |
| // def mouse_event(x1: int, y1: int, x2: Optional[int] = None, y2: |
| // Optional[int] = None) |
| TORCH_CHECK( |
| overload_params.size() <= impl_params.size(), |
| "Overload should not have more parameters than implementation function", |
| overload_decl.range(), |
| impl_decl.range()); |
| |
| for (size_t i = 0; i < overload_params.size(); ++i) { |
| auto overload_name = overload_params[i].ident().name(); |
| auto impl_name = impl_params[i].ident().name(); |
| if (overload_name != impl_name) { |
| throw ErrorReport(overload_decl.range()) |
| << "Overload parameters must have the same names. " |
| << "Found " << overload_name << " and " << impl_name |
| << " on argument " << i; |
| } |
| adjusted_params.push_back(overload_params[i]); |
| } |
| for (size_t i = overload_params.size(); i < impl_params.size(); ++i) { |
| if (!defaults.count(impl_params[i].ident().name())) { |
| throw ErrorReport(impl_decl.range()) |
| << "Expected to find default parameter on argument" |
| << impl_params[i].ident().name() |
| << " because it is not defined on the overloaded declaration"; |
| } |
| if (!impl_params[i].type().present()) { |
| throw ErrorReport(impl_decl.range()) |
| << "Parameters not specified on the overloaded declaration must have a type annotation in the implementation function." |
| << " Did not find type for param " << impl_params[i].ident().name(); |
| } |
| adjusted_params.push_back(impl_params[i]); |
| } |
| return Decl::create( |
| overload_decl.range(), |
| List<Param>::create(overload_decl.range(), adjusted_params), |
| overload_decl.return_type()); |
| } |
| |
| static StrongFunctionPtr script_compile_overloaded_function( |
| const c10::QualifiedName& name, |
| const Decl& overload_decl, |
| const Def& implementation_def, |
| ResolutionCallback rcb, |
| const FunctionDefaults& implementation_defaults, |
| const py::object& signature) { |
| if (signature.is(py::none())) { |
| throw ErrorReport(overload_decl.range()) |
| << "Must explicitly add type annotations to overloaded functions"; |
| } |
| |
| auto adjusted_decl = mergeDefaultsAndExtraParametersToOverloadDecl( |
| overload_decl, implementation_def.decl(), implementation_defaults); |
| auto new_def = implementation_def.withDecl(adjusted_decl); |
| auto cu = get_python_cu(); |
| auto defined_functions = cu->define( |
| QualifiedName(name.prefix()), |
| {new_def}, |
| {pythonResolver(std::move(rcb))}, |
| nullptr, |
| true); |
| TORCH_INTERNAL_ASSERT(defined_functions.size() == 1); |
| auto& defined = defined_functions[0]; |
| FunctionDefaults updated_defaults = calcOverloadedFunctionDefaults( |
| defined->getSchema(), implementation_defaults); |
| defined->setSchema(getSchemaWithNameAndDefaults( |
| new_def.range(), |
| defined->getSchema(), |
| new_def.name().name(), |
| updated_defaults)); |
| StrongFunctionPtr ret(std::move(cu), defined); |
| didFinishEmitFunction(ret); |
| return ret; |
| } |
| |
| static StrongFunctionPtr script_compile_function( |
| const c10::QualifiedName& name, |
| const Def& def, |
| const FunctionDefaults& defaults, |
| ResolutionCallback rcb) { |
| auto cu = get_python_cu(); |
| auto defined_functions = cu->define( |
| QualifiedName(name.prefix()), |
| {def}, |
| {pythonResolver(std::move(rcb))}, |
| nullptr, |
| true); |
| TORCH_INTERNAL_ASSERT(defined_functions.size() == 1); |
| auto& defined = defined_functions[0]; |
| defined->setSchema(getSchemaWithNameAndDefaults( |
| def.range(), defined->getSchema(), def.name().name(), defaults)); |
| StrongFunctionPtr ret(std::move(cu), defined); |
| didFinishEmitFunction(ret); |
| return ret; |
| } |
| |
| struct VISIBILITY_HIDDEN ModuleSelf : public Self { |
| ModuleSelf(std::shared_ptr<ConcreteModuleType> concreteType) |
| : Self(), concreteType_(std::move(concreteType)) {} |
| |
| std::shared_ptr<SugaredValue> makeSugared(Value* v) const override { |
| v->setType(getClassType()); |
| return std::make_shared<ModuleValue>(v, concreteType_); |
| } |
| |
| ClassTypePtr getClassType() const override { |
| return concreteType_->getJitType()->expect<ClassType>(); |
| } |
| |
| private: |
| std::shared_ptr<ConcreteModuleType> concreteType_; |
| }; |
| |
| static TypePtr getTensorType(const at::Tensor& t, bool complete) { |
| auto r = TensorType::create(t); |
| if (!complete) { |
| r = r->dimensionedOnly(); |
| } |
| return r; |
| } |
| |
| static TupleTypePtr getTupleTensorType( |
| const Stack::const_iterator& s_iter, |
| const Stack::const_iterator& s_iter_end, |
| const TypePtr& tupleType, |
| bool complete) { |
| AT_ASSERT(tupleType->kind() == TupleType::Kind); |
| AT_ASSERT(s_iter != s_iter_end); |
| |
| std::vector<TypePtr> types; |
| for (const auto& subType : tupleType->containedTypes()) { |
| if (subType->kind() == TupleType::Kind) { |
| types.push_back( |
| getTupleTensorType(s_iter + 1, s_iter_end, subType, complete)); |
| } else { |
| types.push_back(getTensorType(s_iter->toTensor(), complete)); |
| } |
| } |
| return TupleType::create(types); |
| } |
| |
| static void setInputTensorTypes(Graph& g, const Stack& stack, bool complete) { |
| at::ArrayRef<Value*> input_values = g.inputs(); |
| auto s_iter = stack.begin(); |
| for (auto v : input_values) { |
| AT_ASSERT(s_iter != stack.end()); |
| if (v->type()->kind() == TupleType::Kind) { |
| AT_ASSERT(v->node()->kind() == prim::Param); |
| v->setType(getTupleTensorType(s_iter, stack.end(), v->type(), complete)); |
| } else { |
| v->setType(getTensorType(s_iter->toTensor(), complete)); |
| s_iter++; |
| } |
| } |
| } |
| |
| static std::shared_ptr<Graph> _propagate_shapes( |
| Graph& graph, |
| std::vector<at::Tensor> inputs, |
| bool with_grad = false) { |
| Stack stack(inputs.begin(), inputs.end()); |
| auto retval = graph.copy(); |
| setInputTensorTypes(*retval, stack, /*complete=*/false); |
| PropagateInputShapes(retval); |
| return retval; |
| } |
| |
| static std::shared_ptr<Graph> _propagate_and_assign_input_shapes( |
| Graph& graph, |
| const std::vector<at::Tensor>& inputs, |
| bool with_grad = false, |
| bool propagate = true) { |
| auto retval = graph.copy(); |
| setInputTensorTypes(*retval, fmap<IValue>(inputs), /*complete=*/true); |
| if (propagate) { |
| PropagateInputShapes(retval); |
| } |
| return retval; |
| } |
| |
| static std::shared_ptr<Graph> _assign_output_shapes( |
| Graph& graph, |
| std::vector<at::Tensor> outputs) { |
| auto retval = graph.copy(); |
| AT_ASSERT(retval->outputs().size() == outputs.size()); |
| for (size_t i = 0; i < outputs.size(); ++i) { |
| auto scalar_type = outputs[i].scalar_type(); |
| auto sizes = outputs[i].sizes(); |
| auto type = |
| torch::jit::TensorType::createContiguous(scalar_type, at::kCPU, sizes); |
| retval->outputs()[i]->setType(type); |
| } |
| return retval; |
| } |
| |
| void addFunctionToModule(Module& module, const StrongFunctionPtr& func) { |
| // Make a graph with a fake self argument |
| auto graph = func.function_->graph()->copy(); |
| auto v = graph->insertInput(0, "self"); |
| v->setType(module._ivalue()->type()); |
| const auto name = QualifiedName(*module.type()->name(), "forward"); |
| auto method = |
| module._ivalue()->compilation_unit()->create_function(name, graph); |
| module.type()->addMethod(method); |
| } |
| |
| // this is used in our test suite to check that we correctly preserved type tags |
| bool ivalue_tags_match(const Module& lhs, const Module& rhs) { |
| struct Work { |
| IValue a; |
| IValue b; |
| }; |
| std::unordered_set<const void*> visited; |
| std::vector<Work> work = {{lhs._ivalue(), rhs._ivalue()}}; |
| while (!work.empty()) { |
| Work item = work.back(); |
| work.pop_back(); |
| if (item.a.isPtrType()) { |
| // uncomment to debug type matching errors |
| // std::cout << "MATCHING " << /*item.a <<*/ "(" << *item.a.type() << ") " |
| // << item.a.internalToPointer() << " " << /*item.b <<*/ " (" |
| // << *item.b.type() << ") " << item.b.internalToPointer() << |
| // "\n"; |
| |
| if (visited.count(item.a.internalToPointer())) { |
| continue; |
| } |
| visited.emplace(item.a.internalToPointer()); |
| } |
| if (*unshapedType(item.a.type()) != *unshapedType(item.b.type())) { |
| return false; |
| } |
| // check tags for objects that contain subobjects |
| if (item.a.isObject()) { |
| auto ao = item.a.toObject(); |
| auto bo = item.b.toObject(); |
| for (size_t i = 0; i < ao->slots().size(); ++i) { |
| work.emplace_back(Work{ao->slots().at(i), bo->slots().at(i)}); |
| } |
| } else if (item.a.isTuple()) { |
| auto at = item.a.toTuple(); |
| auto bt = item.b.toTuple(); |
| for (size_t i = 0; i < at->elements().size(); ++i) { |
| work.emplace_back(Work{at->elements().at(i), bt->elements().at(i)}); |
| } |
| } else if (item.a.isList()) { |
| auto al = item.a.toList(); |
| auto bl = item.b.toList(); |
| for (size_t i = 0; i < al.size(); ++i) { |
| work.emplace_back(Work{al.get(i), bl.get(i)}); |
| } |
| } else if (item.a.isGenericDict()) { |
| auto ad = item.a.toGenericDict(); |
| auto bd = item.b.toGenericDict(); |
| for (auto& item : ad) { |
| // Dictionaory keys cannot contain List/Dicts that require tags |
| // so we do not have to check them. |
| // Furthermore without ordered dicts it is expensive to find the |
| // equivalent key |
| work.emplace_back(Work{item.value(), bd.at(item.key())}); |
| } |
| } else if (item.a.isFuture()) { |
| auto af = item.a.toFuture(); |
| auto bf = item.b.toFuture(); |
| af->wait(); |
| bf->wait(); |
| work.emplace_back(Work{af->value(), bf->value()}); |
| } |
| } |
| |
| return true; |
| } |
| |
| // helper used to implement ._parameters, ._buffers, ._modules dicts |
| // inside of script nn.Module |
| template <typename Policy> |
| struct slot_dict_impl { |
| slot_dict_impl(script::ModulePtr module) : module_(std::move(module)) {} |
| bool contains(const std::string& name) const { |
| if (auto slot = module_->type()->findAttributeSlot(name)) { |
| if (Policy::valid(module_->type(), *slot)) { |
| return true; |
| } |
| } |
| return false; |
| } |
| |
| std::vector<std::pair<std::string, py::object>> items() const { |
| std::vector<std::pair<std::string, py::object>> result; |
| for (size_t i = 0, N = module_->type()->numAttributes(); i < N; ++i) { |
| if (Policy::valid(module_->type(), i)) { |
| result.emplace_back( |
| module_->type()->getAttributeName(i), |
| toPyObject(module_->getSlot(i))); |
| } |
| } |
| return result; |
| } |
| |
| void setattr(const std::string& name, py::object value) { |
| const TypePtr& type = module_->type()->getAttribute(name); |
| script::Module(module_).setattr(name, toIValue(std::move(value), type)); |
| } |
| |
| py::object getattr(const std::string& name) { |
| return toPyObject(script::Module(module_).attr(name)); |
| } |
| |
| static void bind(const py::module& m, const char* name) { |
| py::class_<slot_dict_impl<Policy>>(m, name) |
| .def(py::init( |
| [](Module& m) { return slot_dict_impl<Policy>(m._ivalue()); })) |
| .def("contains", &slot_dict_impl<Policy>::contains) |
| .def("items", &slot_dict_impl<Policy>::items) |
| .def("setattr", &slot_dict_impl<Policy>::setattr) |
| .def("getattr", &slot_dict_impl<Policy>::getattr); |
| } |
| private: |
| script::ModulePtr module_; |
| }; |
| |
| template <typename T> |
| py::list debugMakeList(const T& list) { |
| py::list result; |
| for (auto elem : list) { |
| result.append(py::cast(elem)); |
| } |
| return result; |
| } |
| template <typename T> |
| py::list debugMakeNamedList(const T& list) { |
| py::list result; |
| for (auto elem : list) { |
| result.append(py::cast(std::make_pair(elem.name, elem.value))); |
| } |
| return result; |
| } |
| |
| static py::dict _jit_debug_module_iterators(Module& module) { |
| py::dict result; |
| result["children"] = debugMakeList(module.children()); |
| result["named_children"] = debugMakeNamedList(module.named_children()); |
| result["modules"] = debugMakeList(module.modules()); |
| result["named_modules"] = debugMakeNamedList(module.named_modules()); |
| |
| result["parameters"] = debugMakeList(module.parameters(false)); |
| result["named_parameters"] = |
| debugMakeNamedList(module.named_parameters(false)); |
| result["parameters_r"] = debugMakeList(module.parameters(true)); |
| result["named_parameters_r"] = |
| debugMakeNamedList(module.named_parameters(true)); |
| |
| result["buffers"] = debugMakeList(module.buffers(false)); |
| result["named_buffers"] = debugMakeNamedList(module.named_buffers(false)); |
| result["buffers_r"] = debugMakeList(module.buffers(true)); |
| result["named_buffers_r"] = debugMakeNamedList(module.named_buffers(true)); |
| |
| result["named_attributes"] = |
| debugMakeNamedList(module.named_attributes(false)); |
| result["named_attributes_r"] = |
| debugMakeNamedList(module.named_attributes(true)); |
| return result; |
| } |
| |
| |
| void initJitScriptBindings(PyObject* module) { |
| auto m = py::handle(module).cast<py::module>(); |
| |
| // STL containers are not mutable by default and hence we need to bind as |
| // follows. |
| py::bind_map<ExtraFilesMap>(m, "ExtraFilesMap"); |
| |
| py::class_<Object>(m, "ScriptObject") |
| .def("_type", [](Module& m) { return m.type(); }) |
| .def( |
| "_get_method", |
| [](Object& self, const std::string& name) -> Method { |
| return self.get_method(name); |
| }, |
| py::keep_alive<0, 1>()) |
| .def( |
| "setattr", |
| [](Object& self, const std::string& name, py::object value) { |
| if (self.type()->hasConstant(name)) { |
| TORCH_CHECK( |
| false, |
| "Can't set constant '", |
| name, |
| "' which has value:", |
| self.type()->getConstant(name)); |
| } |
| TypePtr type = self.type()->getAttribute(name); |
| auto ivalue = toIValue(std::move(value), type); |
| self.setattr(name, ivalue); |
| }) |
| .def( |
| "getattr", |
| [](Object& self, const std::string& name) { |
| return toPyObject(self.attr(name)); |
| }) |
| .def( |
| "__getattr__", |
| [](Object& self, const std::string& name) { |
| if (auto method = self.find_method(name)) { |
| return py::cast(*method); |
| } |
| return toPyObject(self.attr(name)); |
| }) |
| .def( |
| "hasattr", |
| [](Object& self, const std::string& name) { |
| return self.hasattr(name); |
| }) |
| .def( |
| "_has_method", |
| [](Object& self, const std::string& name) { |
| return bool(self.find_method(name)); |
| }) |
| .def( |
| "_method_names", [](Object& self) { |
| return fmap(self.get_methods(), [](const Method& method) { |
| return method.name(); |
| }); |
| }); |
| |
| // torch.jit.ScriptModule is a subclass of this C++ object. |
| // Methods here are prefixed with _ since they should not be |
| // public. |
| py::class_<Module, Object>(m, "ScriptModule") |
| .def(py::init<std::string, std::shared_ptr<CompilationUnit>, bool>()) |
| .def( |
| "save", |
| [](Module& m, |
| const std::string& filename, |
| const ExtraFilesMap& _extra_files = ExtraFilesMap()) { |
| m.save(filename, _extra_files); |
| }, |
| py::arg("filename"), |
| py::arg("_extra_files") = ExtraFilesMap()) |
| .def( |
| "save_to_buffer", |
| [](Module& m, const ExtraFilesMap& _extra_files = ExtraFilesMap()) { |
| std::ostringstream buf; |
| m.save(buf, _extra_files); |
| return py::bytes(buf.str()); |
| }, |
| py::arg("_extra_files") = ExtraFilesMap()) |
| .def("_set_optimized", &Module::set_optimized) |
| .def( |
| "dump", |
| &Module::dump, |
| py::arg("code") = true, |
| py::arg("attrs") = true, |
| py::arg("params") = true) |
| .def( |
| "dump_to_str", |
| &Module::dump_to_str, |
| py::arg("code") = true, |
| py::arg("attrs") = true, |
| py::arg("params") = true, |
| py::arg("indent") = 0) |
| .def( |
| "_replicate_for_data_parallel", |
| [](Module& module) { |
| const ModulePtr& obj = module._ivalue(); |
| auto copy = c10::ivalue::Object::create( |
| c10::StrongTypePtr(obj->compilation_unit(), obj->type()), |
| obj->slots().size()); |
| for (size_t i = 0; i < obj->slots().size(); ++i) { |
| copy->setSlot(i, obj->getSlot(i)); |
| } |
| return Module(std::move(copy)); |
| }) |
| .def( |
| "get_debug_state", |
| [](Module& self) { |
| if (auto m = self.find_method("forward")) { |
| return m->get_executor().getDebugState(); |
| } |
| throw std::runtime_error( |
| "Attempted to call get_debug_state on a Module without a compiled forward()"); |
| }) |
| .def( |
| "_define", |
| [](Module& m, |
| std::shared_ptr<ConcreteModuleType> concreteType, |
| const std::string& script, |
| ResolutionCallback rcb) { |
| const auto self = ModuleSelf(std::move(concreteType)); |
| m._ivalue()->compilation_unit()->define( |
| *m.type()->name(), script, pythonResolver(rcb), &self); |
| didFinishEmitModule(m); |
| }) |
| .def( |
| "_create_method_from_trace", |
| [](Module& self, |
| const std::string& name, |
| py::function func, |
| py::tuple input_tuple, |
| py::function var_lookup_fn, |
| bool force_outplace) { |
| // prereq: Module's buffers and parameters are unique |
| // this was ensured in python before calling this function |
| auto typed_inputs = toTraceableStack(input_tuple); |
| |
| std::shared_ptr<Graph> graph = std::get<0>(tracer::createGraphByTracing( |
| func, typed_inputs, var_lookup_fn, force_outplace, &self)); |
| const auto method_name = QualifiedName(*self.type()->name(), name); |
| auto fn = self._ivalue()->compilation_unit()->create_function( |
| method_name, graph); |
| self.type()->addMethod(fn); |
| didFinishEmitModule(self); |
| }) |
| .def_property_readonly( |
| "code", |
| [](Module& self) { |
| std::vector<at::Tensor> tensors; |
| std::vector<c10::NamedTypePtr> deps; |
| PythonPrint pp(tensors, deps, false); |
| pp.printNamedType(self.type()); |
| return pp.str(); |
| }) |
| .def("apply", &Module::apply) |
| .def("_clone", &Module::clone) |
| .def("_clone_instance", &Module::clone_instance); |
| |
| slot_dict_impl<script::detail::ParameterPolicy>::bind(m, "ParameterDict"); |
| slot_dict_impl<script::detail::BufferPolicy>::bind(m, "BufferDict"); |
| slot_dict_impl<script::detail::ModulePolicy>::bind(m, "ModuleDict"); |
| |
| py::class_<ErrorReport, std::shared_ptr<ErrorReport>>(m, "ErrorReport") |
| .def(py::init<SourceRange>()) |
| .def("what", &ErrorReport::what); |
| |
| py::class_<CompilationUnit, std::shared_ptr<CompilationUnit>>( |
| m, "CompilationUnit") |
| .def(py::init<>()) |
| .def( |
| "find_function", |
| [](std::shared_ptr<CompilationUnit> self, const std::string& name) { |
| auto& fn = self->get_function(QualifiedName(name)); |
| return StrongFunctionPtr(std::move(self), &fn); |
| }) |
| .def("set_optimized", &CompilationUnit::set_optimized) |
| .def( |
| "define", |
| [](CompilationUnit& cu, |
| const std::string& src, |
| ResolutionCallback rcb) { |
| cu.define(c10::nullopt, src, pythonResolver(rcb), nullptr); |
| }); |
| |
| py::class_<StrongFunctionPtr>(m, "ScriptFunction", py::dynamic_attr()) |
| .def( |
| "__call__", |
| [](py::args args, py::kwargs kwargs) { |
| HANDLE_TH_ERRORS |
| // see: [pybind11 varargs] |
| auto strongPtr = py::cast<StrongFunctionPtr>(args[0]); |
| Function& callee = *strongPtr.function_; |
| bool tracing = tracer::isTracing(); |
| py::object result = invokeScriptFunctionFromPython( |
| callee, tuple_slice(std::move(args), 1), std::move(kwargs)); |
| return result; |
| END_HANDLE_TH_ERRORS_PYBIND |
| }) |
| .def( |
| "save", |
| [](const StrongFunctionPtr& self, |
| const std::string& filename, |
| const ExtraFilesMap& _extra_files = ExtraFilesMap()) { |
| Module module("__torch__.PlaceholderModule"); |
| // [issue 27343] |
| // Modules have 'training' attributes by default, but due to |
| // https://github.com/pytorch/pytorch/issues/27343, functions end |
| // up having a training attribute when they are loaded. This adds |
| // a fake 'training' attribute that shouldn't be used, but prevents |
| // jitter on saving and loading. Once that issue is fixed this can |
| // be deleted. |
| module.register_attribute("training", BoolType::get(), true); |
| addFunctionToModule(module, self); |
| module.save(filename, _extra_files); |
| }, |
| py::arg("filename"), |
| py::arg("_extra_files") = ExtraFilesMap()) |
| .def( |
| "save_to_buffer", |
| [](const StrongFunctionPtr& self, |
| const ExtraFilesMap& _extra_files = ExtraFilesMap()) { |
| std::ostringstream buf; |
| Module module("__torch__.PlaceholderModule"); |
| // see [issue 27343] |
| module.register_attribute("training", BoolType::get(), true); |
| addFunctionToModule(module, self); |
| module.save(buf, _extra_files); |
| return py::bytes(buf.str()); |
| }, |
| py::arg("_extra_files") = ExtraFilesMap()) |
| .def_property_readonly( |
| "graph", |
| [](const StrongFunctionPtr& self) { return self.function_->graph(); }) |
| .def_property_readonly( |
| "schema", |
| [](const StrongFunctionPtr& self) { |
| return self.function_->getSchema(); |
| }) |
| .def_property_readonly( |
| "code", |
| [](const StrongFunctionPtr& self) { |
| std::vector<at::Tensor> tensors; |
| std::vector<c10::NamedTypePtr> deps; |
| PythonPrint pp(tensors, deps, false); |
| pp.printFunction(*self.function_); |
| return pp.str(); |
| }) |
| .def( |
| "get_debug_state", |
| [](const StrongFunctionPtr& self) { |
| return self.function_->get_executor().getDebugState(); |
| }) |
| .def_property_readonly( |
| "name", |
| [](const StrongFunctionPtr& self) { return self.function_->name(); }) |
| .def_property_readonly( |
| "qualified_name", [](const StrongFunctionPtr& self) { |
| return self.function_->qualname().qualifiedName(); |
| }); |
| |
| py::class_<Method>(m, "ScriptMethod", py::dynamic_attr()) |
| .def( |
| "__call__", |
| [](py::args args, py::kwargs kwargs) { |
| // see: [pybind11 varargs] |
| Method& method = py::cast<Method&>(args[0]); |
| return invokeScriptMethodFromPython( |
| method, tuple_slice(std::move(args), 1), std::move(kwargs)); |
| }) |
| .def_property_readonly("graph", &Method::graph) |
| .def_property_readonly( |
| "schema", [](Method& m) { return m.function().getSchema(); }) |
| .def_property_readonly("name", &Method::name) |
| .def_property_readonly("code", [](Method& self) { |
| std::vector<at::Tensor> tensors; |
| std::vector<c10::NamedTypePtr> deps; |
| PythonPrint pp(tensors, deps, false); |
| pp.printMethod(self.function()); |
| return pp.str(); |
| }); |
| m.def( |
| "_jit_script_compile", |
| [](const std::string& qualname, |
| const Def& def, |
| ResolutionCallback rcb, |
| const FunctionDefaults& defaults) { |
| C10_LOG_API_USAGE_ONCE("torch.script.compile"); |
| const auto name = c10::QualifiedName(qualname); |
| TORCH_INTERNAL_ASSERT(name.name() == def.name().name()); |
| return script_compile_function(name, def, defaults, std::move(rcb)); |
| }); |
| m.def( |
| "_jit_script_compile_overload", |
| [](const std::string& qualname, |
| const Decl& overload_decl, |
| const Def& implementation_def, |
| ResolutionCallback rcb, |
| const FunctionDefaults& implementation_defaults, |
| const py::object& signature) { |
| const auto name = c10::QualifiedName(qualname); |
| return script_compile_overloaded_function( |
| name, |
| overload_decl, |
| implementation_def, |
| std::move(rcb), |
| implementation_defaults, |
| signature); |
| }); |
| m.def( |
| "_replace_overloaded_method_decl", |
| [](const Decl& overload_decl, |
| const Def& implementation_def, |
| const std::string& new_name) { |
| checkOverloadDecl(overload_decl, implementation_def.decl()); |
| return implementation_def.withDecl(overload_decl).withName(new_name); |
| }); |
| m.def( |
| "_create_function_from_trace", |
| [](std::string qualname, |
| py::function func, |
| py::tuple input_tuple, |
| py::function var_lookup_fn, |
| bool force_outplace) { |
| auto typed_inputs = toTraceableStack(input_tuple); |
| std::shared_ptr<Graph> graph = std::get<0>(tracer::createGraphByTracing( |
| func, typed_inputs, var_lookup_fn, force_outplace)); |
| auto cu = get_python_cu(); |
| auto name = c10::QualifiedName(qualname); |
| auto result = cu->create_function( |
| std::move(name), std::move(graph), /*shouldMangle=*/true); |
| StrongFunctionPtr ret(std::move(cu), result); |
| didFinishEmitFunction(ret); |
| return ret; |
| }); |
| |
| m.def( |
| "_jit_script_class_compile", |
| [](const std::string& qualifiedName, |
| const ClassDef& classDef, |
| ResolutionCallback rcb) { |
| C10_LOG_API_USAGE_ONCE("torch.script.class"); |
| if (classDef.superclass().present()) { |
| throw ErrorReport(classDef.range()) |
| << "Torchscript does not support class inheritance."; |
| } |
| auto cu = get_python_cu(); |
| const auto classname = c10::QualifiedName(qualifiedName); |
| auto classType = ClassType::create(classname, cu); |
| cu->register_type(classType); |
| std::vector<ResolverPtr> rcbs; |
| std::vector<Def> methodDefs; |
| for (const auto& def : classDef.body()) { |
| if (def.kind() != TK_DEF) { |
| throw ErrorReport(def.range()) |
| << "Currently class bodies can only contain method " |
| "definitions. File an issue on Github if you want " |
| "something else!"; |
| } |
| methodDefs.emplace_back(Def(def)); |
| rcbs.push_back( |
| pythonResolver(rcb, classDef.name().name(), classType)); |
| } |
| const auto self = SimpleSelf(classType); |
| cu->define(classname, methodDefs, rcbs, &self); |
| }); |
| m.def( |
| "_jit_script_interface_compile", |
| [](const std::string& qualifiedName, |
| const ClassDef& classDef, |
| ResolutionCallback rcb, |
| bool is_module) { |
| get_python_cu()->define_interface( |
| c10::QualifiedName(qualifiedName), |
| classDef, |
| pythonResolver(std::move(rcb)), |
| is_module); |
| }); |
| |
| m.def("_parse_source_def", [](const std::string& src) { |
| Parser p(std::make_shared<Source>(src)); |
| return Def(p.parseFunction(/*is_method=*/true)); |
| }); |
| m.def("parse_type_comment", [](const std::string& comment) { |
| Parser p(std::make_shared<Source>(comment)); |
| return Decl(p.parseTypeComment()); |
| }); |
| |
| m.def("merge_type_from_type_comment", &mergeTypesFromTypeComment); |
| m.def( |
| "import_ir_module", |
| [](std::shared_ptr<CompilationUnit> cu, |
| const std::string& filename, |
| py::object map_location, |
| ExtraFilesMap& extra_files) { |
| c10::optional<at::Device> optional_device; |
| if (!map_location.is(py::none())) { |
| AT_ASSERT(THPDevice_Check(map_location.ptr())); |
| optional_device = |
| reinterpret_cast<THPDevice*>(map_location.ptr())->device; |
| } |
| return import_ir_module( |
| std::move(cu), filename, optional_device, extra_files); |
| }); |
| m.def( |
| "import_ir_module_from_buffer", |
| [](std::shared_ptr<CompilationUnit> cu, |
| const std::string& buffer, |
| py::object map_location, |
| ExtraFilesMap& extra_files) { |
| std::istringstream in(buffer); |
| c10::optional<at::Device> optional_device; |
| if (!map_location.is(py::none())) { |
| AT_ASSERT(THPDevice_Check(map_location.ptr())); |
| optional_device = |
| reinterpret_cast<THPDevice*>(map_location.ptr())->device; |
| } |
| return import_ir_module( |
| std::move(cu), in, optional_device, extra_files); |
| }); |
| |
| m.def("_jit_set_emit_hooks", setEmitHooks); |
| m.def("_jit_get_emit_hooks", getEmitHooks); |
| m.def("_jit_clear_class_registry", []() { |
| get_python_cu()->_clear_python_cu(); |
| }); |
| m.def( |
| "_debug_set_autodiff_subgraph_inlining", |
| debugSetAutodiffSubgraphInlining); |
| m.def("_propagate_shapes", _propagate_shapes); |
| m.def( |
| "_propagate_and_assign_input_shapes", |
| _propagate_and_assign_input_shapes); |
| m.def("_assign_output_shapes", _assign_output_shapes); |
| m.def( |
| "_last_executed_optimized_graph", |
| []() { return lastExecutedOptimizedGraph(); }, |
| "Retrieve the optimized graph that was run the last time the graph executor ran on this thread"); |
| m.def( |
| "_create_function_from_graph", |
| [](const std::string& qualname, std::shared_ptr<Graph> graph) { |
| // TODO this should go in the global Python CU |
| auto cu = std::make_shared<CompilationUnit>(); |
| c10::QualifiedName name(qualname); |
| auto fn = cu->create_function(std::move(name), graph); |
| return StrongFunctionPtr(std::move(cu), fn); |
| }); |
| m.def("_ivalue_tags_match", ivalue_tags_match); |
| m.def("_jit_debug_module_iterators", _jit_debug_module_iterators); |
| |
| py::class_<testing::FileCheck>(m, "FileCheck") |
| .def(py::init<>()) |
| .def("check", &testing::FileCheck::check) |
| .def("check_not", &testing::FileCheck::check_not) |
| .def("check_same", &testing::FileCheck::check_same) |
| .def("check_next", &testing::FileCheck::check_next) |
| .def("check_count", &testing::FileCheck::check_count) |
| .def("check_dag", &testing::FileCheck::check_dag) |
| .def("check_count", &testing::FileCheck::check_count) |
| .def( |
| "check_count", |
| [](testing::FileCheck& f, |
| const std::string& str, |
| size_t count, |
| bool exactly) { return f.check_count(str, count, exactly); }, |
| "Check Count", |
| py::arg("str"), |
| py::arg("count"), |
| py::arg("exactly") = false) |
| .def( |
| "run", |
| [](testing::FileCheck& f, const std::string& str) { |
| return f.run(str); |
| }) |
| .def( |
| "run", [](testing::FileCheck& f, const Graph& g) { return f.run(g); }) |
| .def( |
| "run", |
| [](testing::FileCheck& f, |
| const std::string& input, |
| const std::string& output) { return f.run(input, output); }, |
| "Run", |
| py::arg("checks_file"), |
| py::arg("test_file")) |
| .def( |
| "run", |
| [](testing::FileCheck& f, const std::string& input, const Graph& g) { |
| return f.run(input, g); |
| }, |
| "Run", |
| py::arg("checks_file"), |
| py::arg("graph")); |
| |
| m.def( |
| "_logging_set_logger", |
| [](logging::LoggerBase* logger) { return logging::setLogger(logger); }, |
| py::return_value_policy::reference); |
| m.def("_set_graph_executor_optimize", [](bool optimize) { |
| setGraphExecutorOptimize(optimize); |
| }); |
| |
| m.def("_get_graph_executor_optimize", &torch::jit::getGraphExecutorOptimize); |
| |
| m.def("_create_module_with_type", [](const ClassTypePtr& type) { |
| return Module(get_python_cu(), type); |
| }); |
| |
| m.def("_export_opnames", |
| [](script::Module& sm) {return debugMakeList(torch::jit::export_opnames(sm));}); |
| |
| py::class_<ConcreteModuleTypeBuilder, std::shared_ptr<ConcreteModuleTypeBuilder>>( |
| m, "ConcreteModuleTypeBuilder") |
| .def(py::init<py::object>()) |
| .def("add_constant", &ConcreteModuleTypeBuilder::addConstant) |
| .def("add_attribute", &ConcreteModuleTypeBuilder::addAttribute) |
| .def( |
| "add_function_attribute", |
| &ConcreteModuleTypeBuilder::addFunctionAttribute) |
| .def( |
| "add_builtin_function", |
| &ConcreteModuleTypeBuilder::addBuiltinFunction) |
| .def("add_module", &ConcreteModuleTypeBuilder::addModule) |
| .def("add_overload", &ConcreteModuleTypeBuilder::addOverload) |
| .def("set_poisoned", &ConcreteModuleTypeBuilder::setPoisoned) |
| .def("add_failed_attribute", &ConcreteModuleTypeBuilder::addFailedAttribute) |
| .def( |
| "set_module_dict", |
| [](ConcreteModuleTypeBuilder& self) { |
| self.setIterableModuleKind(IterableModuleKind::DICT); |
| }) |
| .def("build", &ConcreteModuleTypeBuilder::build) |
| .def( |
| "equals", |
| [](const ConcreteModuleTypeBuilder& self, |
| const ConcreteModuleTypeBuilder& other) { return self.equals(other); }) |
| .def("set_module_list", [](ConcreteModuleTypeBuilder& self) { |
| self.setIterableModuleKind(IterableModuleKind::LIST); |
| }); |
| |
| py::class_<ConcreteModuleType, std::shared_ptr<ConcreteModuleType>>( |
| m, "ConcreteModuleType") |
| .def_property_readonly("py_class", &ConcreteModuleType::getPyClass) |
| .def_property_readonly("jit_type", &ConcreteModuleType::getJitType) |
| .def_static("from_jit_type", &ConcreteModuleType::fromJitType) |
| .def("get_constants", &ConcreteModuleType::getConstantsPy) |
| .def("get_attributes", &ConcreteModuleType::getAttributesPy) |
| .def("get_modules", &ConcreteModuleType::getModulesPy) |
| .def("dump", &ConcreteModuleType::dump) |
| .def( |
| "equals", |
| [](const ConcreteModuleType& self, const ConcreteModuleType& other) { |
| return self.equals(other); |
| }) |
| .def( |
| "equals", |
| [](const ConcreteModuleType& self, |
| const ConcreteModuleTypeBuilder& other) { |
| return self.equals(other); |
| }) |
| .def( |
| "_create_methods", |
| [](std::shared_ptr<ConcreteModuleType> concreteType, |
| const std::vector<Def>& defs, |
| const std::vector<ResolutionCallback>& rcbs, |
| const std::vector<FunctionDefaults>& defaults) { |
| TORCH_INTERNAL_ASSERT(defs.size() == rcbs.size()); |
| std::vector<ResolverPtr> resolvers; |
| resolvers.reserve(rcbs.size()); |
| for (auto& callback : rcbs) { |
| resolvers.push_back(pythonResolver(callback)); |
| } |
| const auto& selfType = |
| concreteType->getJitType()->expect<ClassType>(); |
| const auto& prefix = selfType->name().value(); |
| const auto self = ModuleSelf(std::move(concreteType)); |
| auto cu = selfType->compilation_unit(); |
| cu->define(prefix, defs, resolvers, &self); |
| // Stitch in default arguments for each Def if provided |
| auto defaults_it = defaults.begin(); |
| auto defs_it = defs.begin(); |
| while (defs_it != defs.end()) { |
| const auto method_name = |
| QualifiedName(prefix, (*defs_it).name().name()); |
| auto& method = cu->get_function(method_name); |
| method.setSchema(getSchemaWithNameAndDefaults( |
| defs_it->range(), |
| method.getSchema(), |
| at::nullopt, |
| *defaults_it)); |
| ++defs_it; |
| ++defaults_it; |
| } |
| }); |
| |
| m.def( |
| "_resolve_type", |
| [](const std::string& name, SourceRange range, ResolutionCallback rcb) { |
| return pythonResolver(rcb)->resolveType(name, range); |
| }); |
| |
| m.def( |
| "_run_emit_module_hook", [](const Module& m) { didFinishEmitModule(m); }); |
| |
| py::class_<logging::LoggerBase, std::shared_ptr<logging::LoggerBase>>( |
| m, "LoggerBase"); |
| py::enum_<logging::LockingLogger::AggregationType>(m, "AggregationType") |
| .value("SUM", logging::LockingLogger::AggregationType::SUM) |
| .value("AVG", logging::LockingLogger::AggregationType::AVG) |
| .export_values(); |
| py::class_< |
| logging::LockingLogger, |
| logging::LoggerBase, |
| std::shared_ptr<logging::LockingLogger>>(m, "LockingLogger") |
| .def(py::init<>()) |
| .def("set_aggregation_type", &logging::LockingLogger::setAggregationType) |
| .def("get_counter_val", &logging::LockingLogger::getCounterValue); |
| py::class_< |
| logging::NoopLogger, |
| logging::LoggerBase, |
| std::shared_ptr<logging::NoopLogger>>(m, "NoopLogger") |
| .def(py::init<>()); |
| } |
| } // namespace script |
| } // namespace jit |
| } // namespace torch |