Update script API to take example inputs (#55376)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/55376
Test Plan: Imported from OSS
Reviewed By: driazati, gmagogsfm
Differential Revision: D27897350
Pulled By: nikithamalgifb
fbshipit-source-id: 4f63235b9eae898c8f4ccaec3fcf64b4b29c860e
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index 72f2490..af2422e 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -19,6 +19,7 @@
)
from torch.jit._script import (
script,
+ _script_pdt,
Attribute,
ScriptModule,
script_method,
diff --git a/torch/jit/_script.py b/torch/jit/_script.py
index 6f1772d..eeb110d 100644
--- a/torch/jit/_script.py
+++ b/torch/jit/_script.py
@@ -13,7 +13,7 @@
import copy
import pickle
import warnings
-from typing import Any, Dict
+from typing import Any, Dict, List, Tuple, Optional
import torch
@@ -36,6 +36,14 @@
from torch.overrides import (
has_torch_function, has_torch_function_unary, has_torch_function_variadic)
+from torch.jit._monkeytype_config import (
+ monkeytype_trace,
+ JitTypeTraceConfig ,
+ JitTypeTraceStore
+)
+
+type_trace_db = JitTypeTraceStore() # DB to hold all call traces from MonkeyType
+
torch._C.ScriptMethod.graph_for = _graph_for # type: ignore
torch._C.ScriptFunction.graph_for = _graph_for # type: ignore
ScriptFunction = torch._C.ScriptFunction
@@ -106,6 +114,9 @@
Returns `value`
"""
+def _get_type_trace_db():
+ # This is a private API. Use of this for external purposes is discouraged.
+ return type_trace_db
# Gets a function from the name of a method on a type
def _get_function_from_type(cls, name):
@@ -840,7 +851,43 @@
memo: Dict[int, torch.nn.Module] = {}
return call_prepare_scriptable_func_impl(obj, memo)
-def script(obj, optimize=None, _frames_up: int = 0, _rcb=None):
+def _script_pdt(obj, optimize=None, _frames_up=0, _rcb=None, example_inputs: Optional[List[Tuple]] = None):
+ # This is a private API, intended for internal use only. Usage of this API is only for experimental
+ # purposes only and is highly discouraged.
+ global type_trace_db
+ if not _enabled:
+ return obj
+
+ if optimize is not None:
+ warnings.warn(
+ "`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead"
+ )
+
+ # No-op for modules and functions that are already scripted
+ if isinstance(obj, ScriptModule):
+ return obj
+ if isinstance(obj, ScriptFunction):
+ return obj
+
+ qualified_name = _qualified_name(obj)
+
+ # If MonkeyType is installed, enable profile directed type annotation
+ # Check if example_inputs are defined and generate call traces
+ # for the method by running eager mode version of the method with
+ # the provide example inputs. This logs all the traces in type_trace_db
+ type_trace_db = JitTypeTraceStore()
+ if monkeytype_trace:
+ monkeytype_config = JitTypeTraceConfig(type_trace_db)
+ with monkeytype_trace(monkeytype_config):
+ for example_input in example_inputs: # type: ignore[union-attr]
+ obj(*example_input)
+ else:
+ warnings.warn("Warning: monkeytype is not installed. Please install https://github.com/Instagram/MonkeyType "
+ "to enable Profile-Directed Typing in TorchScript. Refer to "
+ "https://github.com/Instagram/MonkeyType/blob/master/README.rst to install MonkeyType. ")
+ return script(obj, optimize, _frames_up, _rcb)
+
+def script(obj, optimize=None, _frames_up=0, _rcb=None):
r"""
Scripting a function or ``nn.Module`` will inspect the source code, compile
it as TorchScript code using the TorchScript compiler, and return a :class:`ScriptModule` or