Update script API to take example inputs (#55376)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/55376

Test Plan: Imported from OSS

Reviewed By: driazati, gmagogsfm

Differential Revision: D27897350

Pulled By: nikithamalgifb

fbshipit-source-id: 4f63235b9eae898c8f4ccaec3fcf64b4b29c860e
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index 72f2490..af2422e 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -19,6 +19,7 @@
 )
 from torch.jit._script import (
     script,
+    _script_pdt,
     Attribute,
     ScriptModule,
     script_method,
diff --git a/torch/jit/_script.py b/torch/jit/_script.py
index 6f1772d..eeb110d 100644
--- a/torch/jit/_script.py
+++ b/torch/jit/_script.py
@@ -13,7 +13,7 @@
 import copy
 import pickle
 import warnings
-from typing import Any, Dict
+from typing import Any, Dict, List, Tuple, Optional
 
 
 import torch
@@ -36,6 +36,14 @@
 from torch.overrides import (
     has_torch_function, has_torch_function_unary, has_torch_function_variadic)
 
+from torch.jit._monkeytype_config import (
+    monkeytype_trace,
+    JitTypeTraceConfig ,
+    JitTypeTraceStore
+)
+
+type_trace_db = JitTypeTraceStore()  # DB to hold all call traces from MonkeyType
+
 torch._C.ScriptMethod.graph_for = _graph_for  # type: ignore
 torch._C.ScriptFunction.graph_for = _graph_for  # type: ignore
 ScriptFunction = torch._C.ScriptFunction
@@ -106,6 +114,9 @@
         Returns `value`
 """
 
+def _get_type_trace_db():
+    # This is a private API. Use of this for external purposes is discouraged.
+    return type_trace_db
 
 # Gets a function from the name of a method on a type
 def _get_function_from_type(cls, name):
@@ -840,7 +851,43 @@
     memo: Dict[int, torch.nn.Module] = {}
     return call_prepare_scriptable_func_impl(obj, memo)
 
-def script(obj, optimize=None, _frames_up: int = 0, _rcb=None):
+def _script_pdt(obj, optimize=None, _frames_up=0, _rcb=None, example_inputs: Optional[List[Tuple]] = None):
+    # This is a private API, intended for internal use only. Usage of this API is only for experimental
+    # purposes only and is highly discouraged.
+    global type_trace_db
+    if not _enabled:
+        return obj
+
+    if optimize is not None:
+        warnings.warn(
+            "`optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead"
+        )
+
+    # No-op for modules and functions that are already scripted
+    if isinstance(obj, ScriptModule):
+        return obj
+    if isinstance(obj, ScriptFunction):
+        return obj
+
+    qualified_name = _qualified_name(obj)
+
+    # If MonkeyType is installed, enable profile directed type annotation
+    # Check if example_inputs are defined and generate call traces
+    # for the method by running eager mode version of the method with
+    # the provide example inputs. This logs all the traces in type_trace_db
+    type_trace_db = JitTypeTraceStore()
+    if monkeytype_trace:
+        monkeytype_config = JitTypeTraceConfig(type_trace_db)
+        with monkeytype_trace(monkeytype_config):
+            for example_input in example_inputs:  # type: ignore[union-attr]
+                obj(*example_input)
+    else:
+        warnings.warn("Warning: monkeytype is not installed. Please install https://github.com/Instagram/MonkeyType "
+                      "to enable Profile-Directed Typing in TorchScript. Refer to "
+                      "https://github.com/Instagram/MonkeyType/blob/master/README.rst to install MonkeyType. ")
+    return script(obj, optimize, _frames_up, _rcb)
+
+def script(obj, optimize=None, _frames_up=0, _rcb=None):
     r"""
     Scripting a function or ``nn.Module`` will inspect the source code, compile
     it as TorchScript code using the TorchScript compiler, and return a :class:`ScriptModule` or