Add flag to temporarily enable first class modules (#21560)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21560
ghimport-source-id: a555ca33fcd3efd1147aaf90f26a8e63da1c1a67
Reviewed By: suo
Differential Revision: D15729502
Pulled By: zdevito
fbshipit-source-id: d6c11472bfc791e2ad1e9aa695b0439d72b79681
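
For reference, a rough usage sketch (hypothetical driver code, not part of
this diff) built only from the pieces added below: the flag is flipped via
the torch._C._jit_set_first_class_mode binding from init.cpp, wrapped in the
same enable_first_class_mode context manager that this PR adds to test_jit.py:

    from contextlib import contextmanager

    import torch
    import torch.nn as nn

    @contextmanager
    def enable_first_class_mode():
        # Same shape as the helper added to test_jit.py below: set the
        # experimental thread-local flag, and restore it on exit.
        torch._C._jit_set_first_class_mode(True)
        try:
            yield
        finally:
            torch._C._jit_set_first_class_mode(False)

    with enable_first_class_mode():
        # Modules compiled inside the block keep first-class semantics;
        # forward can even reassign self.foo, as the new test exercises.
        class Foo(torch.jit.ScriptModule):
            def __init__(self):
                super(Foo, self).__init__()
                self.foo = nn.Parameter(torch.rand(3, 4))

            @torch.jit.script_method
            def forward(self, input):
                self.foo = input
                return self.foo

        print(Foo()(torch.rand(3, 4)))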
diff --git a/test/test_jit.py b/test/test_jit.py
index a97c066..84f245c 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -230,6 +230,14 @@
         s += t.sum()
     return s
 
+@contextmanager
+def enable_first_class_mode():
+    torch._C._jit_set_first_class_mode(True)
+    try:
+        yield
+    finally:
+        torch._C._jit_set_first_class_mode(False)
+
 # helper function to generate test qparam
 def _helper_generate_qparam(script_module, input_data):
     class TestGenQParam:
@@ -2969,6 +2975,21 @@
         self.assertEqual(D()(v), v + v)
 
+    def test_first_class_module(self):
+        with enable_first_class_mode():
+            class Foo(torch.jit.ScriptModule):
+                def __init__(self):
+                    super(Foo, self).__init__()
+                    self.foo = nn.Parameter(torch.rand(3, 4))
+
+                @torch.jit.script_method
+                def forward(self, input):
+                    self.foo = input
+                    return self.foo
+            foo = Foo()
+            input = torch.rand(3, 4)
+            foo.forward(input)
+            self.assertEqual(input, foo.foo)
 
     def test_invalid_prefix_annotation(self):
         with self.assertRaisesRegex(RuntimeError, "annotation prefix in line"):
diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp
index eff1a0b..049e702 100644
--- a/torch/csrc/jit/init.cpp
+++ b/torch/csrc/jit/init.cpp
@@ -333,6 +333,9 @@
"_jit_set_profiling_mode",
[](bool profiling_flag) { getProfilingMode() = profiling_flag; })
.def(
+ "_jit_set_first_class_mode",
+ [](bool enabled) { script::setRunAsFirstClass(enabled); })
+ .def(
"_jit_fuser_get_fused_kernel_code",
[](Graph& g, std::vector<at::Tensor> inps) {
return debugGetFusedKernelCode(g, inps);
diff --git a/torch/csrc/jit/script/module.cpp b/torch/csrc/jit/script/module.cpp
index 8812d86..c35bfab 100644
--- a/torch/csrc/jit/script/module.cpp
+++ b/torch/csrc/jit/script/module.cpp
@@ -12,6 +12,16 @@
 namespace jit {
 namespace script {
 
+// First-class mode runs modules as first-class objects and does not
+// force inlining everywhere. This is experimental while we bring up
+// the system, since it will degrade performance and may introduce
+// bugs. test_jit.py provides a context manager
+// (enable_first_class_mode) that enables it for specific tests.
+thread_local bool experimental_run_as_first_class = false;
+void setRunAsFirstClass(bool enabled) {
+  experimental_run_as_first_class = enabled;
+}
+
 struct RecursiveMethodCallError : public std::exception {};
 void placeholderCreator(Function&) {
   throw RecursiveMethodCallError();
@@ -213,10 +223,40 @@
   return schema.cloneWithArguments(std::move(sliced));
 }
 
-Method::Method(Module* owner, Function* first_class_function)
+Method::Method(
+    Module* owner,
+    const std::shared_ptr<Function>& first_class_function)
     : owner_(owner), schema_(sliceFirst(first_class_function->getSchema())) {
-  std::tie(function_, initial_ivalues_) =
-      owner->lower_first_class_method(first_class_function);
+  if (experimental_run_as_first_class) {
+    function_ = first_class_function;
+    // initial_ivalues_ left blank
+  } else {
+    std::tie(function_, initial_ivalues_) =
+        owner->lower_first_class_method(first_class_function.get());
+  }
+}
+
+void Method::run(Stack& stack) {
+  if (experimental_run_as_first_class) {
+    stack.insert(stack.begin(), owner().module_object());
+  }
+  for (const auto& input : initial_ivalues_) {
+    push(stack, input.value());
+  }
+  function_->run(stack);
+}
+
+IValue Method::operator()(std::vector<IValue> stack, const Kwargs& kwargs) {
+  getSchema().checkAndNormalizeInputs(stack, kwargs);
+  if (experimental_run_as_first_class) {
+    stack.insert(stack.begin(), owner().module_object());
+  }
+  for (const auto& input : initial_ivalues_) {
+    push(stack, input.value());
+  }
+  // use run rather than operator() to skip the second schema check.
+  function_->run(stack);
+  return stack.front();
 }
 
 void Module::define(const std::string& src, const ResolverPtr& resolver) {
diff --git a/torch/csrc/jit/script/module.h b/torch/csrc/jit/script/module.h
index b75dc64..a31688e 100644
--- a/torch/csrc/jit/script/module.h
+++ b/torch/csrc/jit/script/module.h
@@ -57,34 +57,19 @@
     std::function<std::shared_ptr<Module>(const std::vector<std::string>&)>;
 
 struct TORCH_API Method {
-  Method(Module* owner, Function* function);
+  Method(Module* owner, const std::shared_ptr<Function>& function);
   // the module that contains this method.
   Module& owner() const {
     return *owner_;
   }
 
-  void run(Stack& stack) {
-    for (auto input : initial_ivalues_) {
-      push(stack, input.value());
-    }
-    function_->run(stack);
-  }
+  void run(Stack& stack);
 
   void run(Stack&& stack) {
     run(stack);
   }
 
-  IValue operator()(
-      std::vector<IValue> stack,
-      const Kwargs& kwargs = Kwargs()) {
-    getSchema().checkAndNormalizeInputs(stack, kwargs);
-    for (auto input : initial_ivalues_) {
-      push(stack, input.value());
-    }
-    // use run rather than operator() to skip the second schema check.
-    function_->run(std::move(stack));
-    return stack.front();
-  }
+  IValue operator()(std::vector<IValue> stack, const Kwargs& kwargs = Kwargs());
 
   const std::vector<Slot>& initial_ivalues() const {
     return initial_ivalues_;
@@ -315,7 +300,8 @@
       return methods_[*offset].get();
     }
 
-    if (Function* fn = class_compilation_unit().find_function(name).get()) {
+    if (const std::shared_ptr<Function>& fn =
+            class_compilation_unit().find_function(name)) {
       // lock because technically this is marked const,
       // but we have to update the internal Method cache.
       // This can be removed when class_compilation_unit() is the source of
@@ -448,8 +434,7 @@
   void define(const std::string& src, const ResolverPtr& resolver = nullptr);
 
   template <typename... Types>
-  IValue create_class(const c10::QualifiedName& name, Types&&... args)
-      const {
+  IValue create_class(const c10::QualifiedName& name, Types&&... args) const {
     return create_class(name, {IValue(std::forward<Types>(args))...});
   }
@@ -582,6 +567,8 @@
   friend struct Method;
 };
 
+TORCH_API void setRunAsFirstClass(bool enabled);
+
 } // namespace script
 } // namespace jit
 } // namespace torch