Add boolean dispatch for function overloading (#14081)

Summary:
This PR makes it possible to overload a function based on the value of one of its parameters (so long as that value is a compile-time constant). See `max_pool1d` for an example usage.
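
For example, once scripted, both overloads work (mirroring the new `test_bool_dispatch` tests below):

```python
import torch
import torch.nn.functional as F

x = torch.randn(3, 3, 3)
y = F.max_pool1d(x, 1, 1)                                # Tensor
y, indices = F.max_pool1d(x, 1, 1, return_indices=True)  # Tuple[Tensor, Tensor]
```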

This is the first step toward standard-library `max_pool` functions that can return either a `Tensor` or a `Tuple[Tensor, Tensor]` depending on the `return_indices` flag, which gives the JIT results identical to the Python versions of these functions.
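
The core idea, as a minimal standalone sketch (a toy version of the `boolean_dispatch` helper added to `torch/_jit_internal.py`; it omits the weak-script checks and docstring handling):

```python
def boolean_dispatch(arg_name, arg_index, default, if_true, if_false):
    # Wrap two implementations behind one callable and pick one per call,
    # based on a flag found by keyword, by position, or from the default.
    def fn(*args, **kwargs):
        flag = default
        if arg_name in kwargs:
            flag = kwargs[arg_name]
        elif arg_index < len(args):
            flag = args[arg_index]
        return if_true(*args, **kwargs) if flag else if_false(*args, **kwargs)
    return fn
```

In the JIT, the wrapper is instead recognized at compile time (see `BooleanDispatchValue` below), so the flag must be a constant and the call resolves to one of the two functions with no runtime branch.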

Depends on #14232 for `Optional[BroadcastingList[T]]`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14081

Differential Revision: D13192228

Pulled By: driazati

fbshipit-source-id: fce33c400c1fd06e59747d98507c5fdcd8d4c113
diff --git a/test/test_jit.py b/test/test_jit.py
index 34c9ee1..1e12ef5 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -8674,6 +8674,42 @@
             return foo, a, bar
         self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
 
+    def test_bool_dispatch(self):
+        def kwarg_false(x):
+            # type: (Tensor) -> Tensor
+            return F.max_pool1d(x, 1, 1, return_indices=False)
+        self.checkScript(kwarg_false, (torch.randn(3, 3, 3),))
+
+        def kwarg_true(x):
+            # type: (Tensor) -> Tuple[Tensor, Tensor]
+            return F.max_pool1d(x, 1, 1, return_indices=True)
+        self.checkScript(kwarg_true, (torch.randn(3, 3, 3),))
+
+        def full_kwarg_false(x):
+            # type: (Tensor) -> Tensor
+            return F.max_pool1d(x, 1, 1, ceil_mode=False, return_indices=False)
+        self.checkScript(full_kwarg_false, (torch.randn(3, 3, 3),))
+
+        def full_kwarg_true(x):
+            # type: (Tensor) -> Tuple[Tensor, Tensor]
+            return F.max_pool1d(x, 1, 1, ceil_mode=False, return_indices=True)
+        self.checkScript(full_kwarg_true, (torch.randn(3, 3, 3),))
+
+        def use_default(x):
+            # type: (Tensor) -> Tensor
+            return F.max_pool1d(x, 1, 1)
+        self.checkScript(use_default, (torch.randn(3, 3, 3),))
+
+        def arg_false(x):
+            # type: (Tensor) -> Tensor
+            return F.max_pool1d(x, 1, 1, 0, 1, False, False)
+        self.checkScript(arg_false, (torch.randn(3, 3, 3),))
+
+        def arg_true(x):
+            # type: (Tensor) -> Tuple[Tensor, Tensor]
+            return F.max_pool1d(x, 1, 1, 0, 1, False, True)
+        self.checkScript(arg_true, (torch.randn(3, 3, 3),))
+
 
 class MnistNet(nn.Module):
     def __init__(self):
@@ -9277,13 +9313,6 @@
     'test_norm_fro',
     'test_norm_fro_default',
     'test_norm_nuc',
-    # skipped nn functional tests
-    # ops involves sampling which could not test
-
-    'test_nn_adaptive_max_pool1d',
-    'test_nn_adaptive_max_pool2d',
-    'test_nn_adaptive_max_pool3d',
-
 
     # argument has custom behavior
     'test_nn_fractional_max_pool2d',
@@ -9871,6 +9900,7 @@
     ('avg_pool3d', (S, S, S, S, S), (3,)),
     ('fractional_max_pool2d', (S, S, S, S), (3, [2, 3], None)),
     ('max_pool1d', (S, S, S), (2, 1)),
+    ('max_pool1d', (S, S, S), (2, 1, 1, 1, False, True), 'with_indices'),
     ('max_pool2d', (S, S, S, S), (2, 1)),
     ('max_pool3d', (S, S, S, S, S), (2, 1)),
     ('max_unpool1d', torch.tensor([[[2., 4]]]), (torch.tensor([[[1, 3]]]), 2, 2, 0)),
diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py
index 07b301c..f1b736e 100644
--- a/torch/_jit_internal.py
+++ b/torch/_jit_internal.py
@@ -23,6 +23,10 @@
 # Types that have been declared as weak modules
 _weak_types = weakref.WeakKeyDictionary()
 
+# Wrapper functions that can call one of two functions depending on a boolean
+# argument
+_boolean_dispatched = weakref.WeakKeyDictionary()
+
 COMPILATION_PENDING = object()
 COMPILED = object()
 
@@ -104,3 +108,55 @@
         "original_method": fn
     }
     return fn
+
+
+def boolean_dispatch(arg_name, arg_index, default, if_true, if_false):
+    """
+    Dispatches to one of two weak script functions based on a boolean argument.
+    In Torch Script, the boolean argument must be constant so that the correct
+    function to use can be determined at compile time.
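+
+    Example (toy functions ``g``/``h``, for illustration only)::
+
+        f = boolean_dispatch('flag', 1, False, if_true=g, if_false=h)
+        f(x)             # calls h(x)
+        f(x, flag=True)  # calls g(x)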
+    """
+    if _compiled_weak_fns.get(if_true) is None or _compiled_weak_fns.get(if_false) is None:
+        raise RuntimeError("both functions must be weak script")
+
+    def fn(*args, **kwargs):
+        # Find the dispatch flag: check keyword arguments first, then
+        # positional arguments, then fall back to the declared default.
+        dispatch_flag = default
+        if arg_name in kwargs:
+            dispatch_flag = kwargs[arg_name]
+        elif arg_index < len(args):
+            dispatch_flag = args[arg_index]
+
+        if dispatch_flag:
+            return if_true(*args, **kwargs)
+        else:
+            return if_false(*args, **kwargs)
+
+    if if_true.__doc__ is None and if_false.__doc__ is not None:
+        doc = if_false.__doc__
+        if_true.__doc__ = doc
+    elif if_false.__doc__ is None and if_true.__doc__ is not None:
+        doc = if_true.__doc__
+        if_false.__doc__ = doc
+    elif if_true.__doc__ is None and if_false.__doc__ is None:
+        # neither function has a docstring
+        doc = None
+    else:
+        raise RuntimeError("only one function can have a docstring")
+    fn.__doc__ = doc
+
+    _boolean_dispatched[fn] = {
+        "if_true": if_true,
+        "if_false": if_false,
+        "index": arg_index,
+        "default": default,
+        "arg_name": arg_name
+    }
+    return fn
diff --git a/torch/csrc/jit/script/compiler.h b/torch/csrc/jit/script/compiler.h
index 089cc24..9bd1016 100644
--- a/torch/csrc/jit/script/compiler.h
+++ b/torch/csrc/jit/script/compiler.h
@@ -232,6 +232,11 @@
   // if true, emitBuiltinCall will throw an exception if this builtin does not exist,
   // otherwise it will return nullptr if the builtin is not found.
   bool required);
+
+c10::optional<size_t> findInputWithName(
+  const std::string& name,
+  at::ArrayRef<NamedValue> kwargs);
+
 } // namespace script
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index 1eda01a..5d40b4a 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -260,6 +260,62 @@
   std::shared_ptr<Module> module;
 };
 
+struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue {
+  BooleanDispatchValue(py::dict dispatched_fn)
+      : dispatched_fn_(std::move(dispatched_fn)) {}
+
+  std::string kind() const override {
+    return "boolean dispatch";
+  }
+
+  std::vector<NamedValue> removeIndex(
+      at::ArrayRef<NamedValue> arr,
+      size_t index) {
+    auto sliced = arr.vec();
+    sliced.erase(sliced.begin() + index);
+    return sliced;
+  }
+
+  std::shared_ptr<SugaredValue> call(
+      SourceRange loc,
+      Method& caller,
+      at::ArrayRef<NamedValue> inputs,
+      at::ArrayRef<NamedValue> attributes,
+      size_t n_binders) override {
+    c10::optional<bool> result;
+    Graph& graph = *(caller.graph());
+
+    auto index = py::cast<size_t>(dispatched_fn_["index"]);
+    auto arg_name = py::str(dispatched_fn_["arg_name"]);
+
+    if (index < inputs.size()) {
+      // Dispatch flag is in arg list
+      result = constant_as<bool>(inputs.at(index).value(graph));
+    } else if (auto i = findInputWithName(arg_name, attributes)) {
+      // Dispatch flag is in kwargs
+      result = constant_as<bool>(attributes[*i].value(graph));
+    } else {
+      // Didn't find dispatch flag, so use default value
+      result = py::cast<bool>(dispatched_fn_["default"]);
+    }
+
+    if (!result) {
+      throw ErrorReport(loc) << "value for boolean dispatch was not constant";
+    }
+
+    std::shared_ptr<SugaredValue> value;
+    if (*result) {
+      value = toSugaredValue(dispatched_fn_["if_true"], caller, loc);
+    } else {
+      value = toSugaredValue(dispatched_fn_["if_false"], caller, loc);
+    }
+    return value->call(loc, caller, inputs, attributes, n_binders);
+  }
+
+ private:
+  py::dict dispatched_fn_;
+};
+
 std::shared_ptr<SugaredValue> toSugaredValue(
     py::object obj,
     Method& m,
@@ -336,6 +392,12 @@
       return std::make_shared<ModuleValue>(mod);
     }
   }
+
+  py::object dispatched_fn =
+      py::module::import("torch.jit").attr("_try_get_dispatched_fn")(obj);
+  if (!dispatched_fn.is_none()) {
+    return std::make_shared<BooleanDispatchValue>(std::move(dispatched_fn));
+  }
   return std::make_shared<PythonValue>(obj);
 }
 
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index 4fecbd5..c9f7672 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -7,7 +7,7 @@
 from torch._six import raise_from, with_metaclass, get_function_from_type
 from .._jit_internal import createResolutionCallback, _compiled_weak_fns, \
     _weak_script_methods, _weak_modules, _weak_types, COMPILED, \
-    COMPILATION_PENDING
+    COMPILATION_PENDING, _boolean_dispatched
 from ..nn.modules.utils import _single, _pair, _triple, _quadruple, \
     _list_with_default
 import torch.testing
@@ -623,6 +623,10 @@
         return self.module._get_method(attr)
 
 
+def _try_get_dispatched_fn(fn):
+    return _boolean_dispatched.get(fn)
+
+
 def _try_compile_weak_script(fn):
     entry = _compiled_weak_fns.get(fn)
     if entry is None:
@@ -1336,7 +1340,7 @@
     func = getattr(torch.nn.functional, name)
     if func is None:
         return False
-    return func in _compiled_weak_fns
+    return func in _compiled_weak_fns or func in _boolean_dispatched
 
 
 def _unwrap_optional(x):
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 91352fc..9b94bc9 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -4,6 +4,7 @@
 import warnings
 import math
 import types
+from typing import List
 
 import torch
 from torch._C import _infer_size, _add_docstr
@@ -304,6 +305,7 @@
 def fractional_max_pool2d(input, kernel_size, output_size=None,
                           output_ratio=None, return_indices=False,
                           _random_samples=None):
+    # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], Optional[BroadcastingList2[float]], bool, Optional[Tensor]) -> Tuple[Tensor, Tensor]  # noqa
     r"""Applies 2D fractional max pooling over an input signal composed of several input planes.
 
     Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham
@@ -337,47 +339,111 @@
         raise ValueError("fractional_max_pool2d requires specifying either "
                          "an output_size, or a output_ratio")
     if output_size is None:
-        output_ratio = _pair(output_ratio)
-        output_size = (int(input.size(2) * output_ratio[0]),
-                       int(input.size(3) * output_ratio[1]))
+        _output_ratio = _pair(output_ratio)
+        _output_size = (int(input.size(2) * _output_ratio[0]),
+                        int(input.size(3) * _output_ratio[1]))
+    else:
+        _output_size = torch.jit._unwrap_optional(output_size)
 
     if _random_samples is None:
         _random_samples = input.new(input.size(0), input.size(1), 2).uniform_()
-    ret = torch._C._nn.fractional_max_pool2d(input, kernel_size, output_size, _random_samples)
+    ret = torch._C._nn.fractional_max_pool2d(input, kernel_size, _output_size, _random_samples)
     return ret if return_indices else ret[0]
 
 
-def max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1,
-               ceil_mode=False, return_indices=False):
+@torch._jit_internal.weak_script
+def max_pool1d_with_indices(input, kernel_size, stride=None, padding=0,
+                            dilation=1, ceil_mode=False, return_indices=False):
+    # type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], int, int, bool, bool) -> Tuple[Tensor, Tensor]  # noqa
     r"""Applies a 1D max pooling over an input signal composed of several input
     planes.
 
     See :class:`~torch.nn.MaxPool1d` for details.
     """
-    ret = torch.max_pool1d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode)
-    return ret if return_indices else ret[0]
+    if stride is None:
+        _stride = torch.jit.annotate(List[int], [])
+    else:
+        _stride = torch.jit._unwrap_optional(stride)
+    return torch.max_pool1d_with_indices(
+        input, kernel_size, _stride, padding, dilation, ceil_mode)
 
 
-def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1,
-               ceil_mode=False, return_indices=False):
+@torch._jit_internal.weak_script
+def _max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1,
+                ceil_mode=False, return_indices=False):
+    # type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], int, int, bool, bool) -> Tensor
+    return max_pool1d_with_indices(
+        input, kernel_size, stride, padding, dilation, ceil_mode)[0]
+
+max_pool1d = torch._jit_internal.boolean_dispatch(
+    arg_name='return_indices',
+    arg_index=6,
+    default=False,
+    if_true=max_pool1d_with_indices,
+    if_false=_max_pool1d)
+
+
+@torch._jit_internal.weak_script
+def max_pool2d_with_indices(input, kernel_size, stride=None, padding=0, dilation=1,
+                            ceil_mode=False, return_indices=False):
+    # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], int, int, bool, bool) -> Tuple[Tensor, Tensor]  # noqa
     r"""Applies a 2D max pooling over an input signal composed of several input
     planes.
 
     See :class:`~torch.nn.MaxPool2d` for details.
     """
-    ret = torch._C._nn.max_pool2d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode)
-    return ret if return_indices else ret[0]
+    if stride is None:
+        _stride = torch.jit.annotate(List[int], [])
+    else:
+        _stride = torch.jit._unwrap_optional(stride)
+    return torch._C._nn.max_pool2d_with_indices(input, kernel_size, _stride, padding, dilation, ceil_mode)
 
 
-def max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1,
-               ceil_mode=False, return_indices=False):
+@torch._jit_internal.weak_script
+def _max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1,
+                ceil_mode=False, return_indices=False):
+    # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], int, int, bool, bool) -> Tensor
+    return max_pool2d_with_indices(
+        input, kernel_size, stride, padding, dilation, ceil_mode)[0]
+
+max_pool2d = torch._jit_internal.boolean_dispatch(
+    arg_name='return_indices',
+    arg_index=6,
+    default=False,
+    if_true=max_pool2d_with_indices,
+    if_false=_max_pool2d)
+
+
+@torch._jit_internal.weak_script
+def max_pool3d_with_indices(input, kernel_size, stride=None, padding=0,
+                            dilation=1, ceil_mode=False, return_indices=False):
+    # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], int, int, bool, bool) -> Tuple[Tensor, Tensor]  # noqa
     r"""Applies a 3D max pooling over an input signal composed of several input
     planes.
 
     See :class:`~torch.nn.MaxPool3d` for details.
     """
-    ret = torch._C._nn.max_pool3d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode)
-    return ret if return_indices else ret[0]
+    if stride is None:
+        _stride = torch.jit.annotate(List[int], [])
+    else:
+        _stride = torch.jit._unwrap_optional(stride)
+    return torch._C._nn.max_pool3d_with_indices(
+        input, kernel_size, _stride, padding, dilation, ceil_mode)
+
+
+@torch._jit_internal.weak_script
+def _max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1,
+                ceil_mode=False, return_indices=False):
+    # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], int, int, bool, bool) -> Tensor
+    return max_pool3d_with_indices(
+        input, kernel_size, stride, padding, dilation, ceil_mode)[0]
+
+max_pool3d = torch._jit_internal.boolean_dispatch(
+    arg_name='return_indices',
+    arg_index=6,
+    default=False,
+    if_true=max_pool3d_with_indices,
+    if_false=_max_pool3d)
 
 
 def _unpool_output_size(input, kernel_size, stride, padding, output_size):
@@ -488,7 +554,9 @@
     return (torch.sign(out) * relu(torch.abs(out))).mul(kernel_size).pow(1. / norm_type)
 
 
-def adaptive_max_pool1d(input, output_size, return_indices=False):
+@torch._jit_internal.weak_script
+def adaptive_max_pool1d_with_indices(input, output_size, return_indices=False):
+    # type: (Tensor, BroadcastingList1[int], bool) -> Tuple[Tensor, Tensor]
     r"""Applies a 1D adaptive max pooling over an input signal composed of
     several input planes.
 
@@ -498,11 +566,25 @@
         output_size: the target output size (single integer)
         return_indices: whether to return pooling indices. Default: ``False``
     """
-    ret = torch.adaptive_max_pool1d(input, output_size)
-    return ret if return_indices else ret[0]
+    return torch.adaptive_max_pool1d(input, output_size)
 
 
-def adaptive_max_pool2d(input, output_size, return_indices=False):
+@torch._jit_internal.weak_script
+def _adaptive_max_pool1d(input, output_size, return_indices=False):
+    # type: (Tensor, BroadcastingList1[int], bool) -> Tensor
+    return adaptive_max_pool1d_with_indices(input, output_size)[0]
+
+adaptive_max_pool1d = torch._jit_internal.boolean_dispatch(
+    arg_name='return_indices',
+    arg_index=2,
+    default=False,
+    if_true=adaptive_max_pool1d_with_indices,
+    if_false=_adaptive_max_pool1d)
+
+
+@torch._jit_internal.weak_script
+def adaptive_max_pool2d_with_indices(input, output_size, return_indices=False):
+    # type: (Tensor, BroadcastingList2[int], bool) -> Tuple[Tensor, Tensor]
     r"""Applies a 2D adaptive max pooling over an input signal composed of
     several input planes.
 
@@ -514,11 +596,25 @@
         return_indices: whether to return pooling indices. Default: ``False``
     """
     output_size = _list_with_default(output_size, input.size())
-    ret = torch._C._nn.adaptive_max_pool2d(input, output_size)
-    return ret if return_indices else ret[0]
+    return torch._C._nn.adaptive_max_pool2d(input, output_size)
 
 
-def adaptive_max_pool3d(input, output_size, return_indices=False):
+@torch._jit_internal.weak_script
+def _adaptive_max_pool2d(input, output_size, return_indices=False):
+    # type: (Tensor, BroadcastingList2[int], bool) -> Tensor
+    return adaptive_max_pool2d_with_indices(input, output_size)[0]
+
+adaptive_max_pool2d = torch._jit_internal.boolean_dispatch(
+    arg_name='return_indices',
+    arg_index=2,
+    default=False,
+    if_true=adaptive_max_pool2d_with_indices,
+    if_false=_adaptive_max_pool2d)
+
+
+@torch._jit_internal.weak_script
+def adaptive_max_pool3d_with_indices(input, output_size, return_indices=False):
+    # type: (Tensor, BroadcastingList3[int], bool) -> Tuple[Tensor, Tensor]
     r"""Applies a 3D adaptive max pooling over an input signal composed of
     several input planes.
 
@@ -530,8 +626,20 @@
         return_indices: whether to return pooling indices. Default: ``False``
     """
     output_size = _list_with_default(output_size, input.size())
-    ret = torch._C._nn.adaptive_max_pool3d(input, output_size)
-    return ret if return_indices else ret[0]
+    return torch._C._nn.adaptive_max_pool3d(input, output_size)
+
+
+@torch._jit_internal.weak_script
+def _adaptive_max_pool3d(input, output_size, return_indices=False):
+    # type: (Tensor, BroadcastingList3[int], bool) -> Tensor
+    return adaptive_max_pool3d_with_indices(input, output_size)[0]
+
+adaptive_max_pool3d = torch._jit_internal.boolean_dispatch(
+    arg_name='return_indices',
+    arg_index=2,
+    default=False,
+    if_true=adaptive_max_pool3d_with_indices,
+    if_false=_adaptive_max_pool3d)
 
 
 adaptive_avg_pool1d = _add_docstr(torch.adaptive_avg_pool1d, r"""