Add boolean dispatch for function overloading (#14081)
Summary:
This PR makes it possible to overload a function based on the value of one of its parameters (as long as that value is a constant). See `max_pool1d` for an example usage.
This is the first step toward `max_pool` functions in the standard library that can return either `Tensor` or `Tuple[Tensor, Tensor]` depending on the `return_indices` flag, giving the JIT results identical to the Python versions of these functions.
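For example, with this change both of the following script functions compile, each getting its return type from the constant value of `return_indices` at the call site (a minimal sketch, with illustrative function names, mirroring the new tests in `test_jit.py`):

```python
import torch
import torch.nn.functional as F
from typing import Tuple

@torch.jit.script
def pool(x):
    # type: (Tensor) -> Tensor
    return F.max_pool1d(x, 1, 1, return_indices=False)

@torch.jit.script
def pool_with_indices(x):
    # type: (Tensor) -> Tuple[Tensor, Tensor]
    return F.max_pool1d(x, 1, 1, return_indices=True)
```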
Depends on #14232 for `Optional[BroadcastingList[T]]`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14081
Differential Revision: D13192228
Pulled By: driazati
fbshipit-source-id: fce33c400c1fd06e59747d98507c5fdcd8d4c113
diff --git a/test/test_jit.py b/test/test_jit.py
index 34c9ee1..1e12ef5 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -8674,6 +8674,42 @@
return foo, a, bar
self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
+ def test_bool_dispatch(self):
+ def kwarg_false(x):
+ # type: (Tensor) -> Tensor
+ return F.max_pool1d(x, 1, 1, return_indices=False)
+ self.checkScript(kwarg_false, (torch.randn(3, 3, 3),))
+
+ def kwarg_true(x):
+ # type: (Tensor) -> Tuple[Tensor, Tensor]
+ return F.max_pool1d(x, 1, 1, return_indices=True)
+ self.checkScript(kwarg_true, (torch.randn(3, 3, 3),))
+
+ def full_kwarg_false(x):
+ # type: (Tensor) -> Tensor
+ return F.max_pool1d(x, 1, 1, ceil_mode=False, return_indices=False)
+ self.checkScript(full_kwarg_false, (torch.randn(3, 3, 3),))
+
+ def full_kwarg_true(x):
+ # type: (Tensor) -> Tuple[Tensor, Tensor]
+ return F.max_pool1d(x, 1, 1, ceil_mode=False, return_indices=True)
+ self.checkScript(full_kwarg_true, (torch.randn(3, 3, 3),))
+
+ def use_default(x):
+ # type: (Tensor) -> Tensor
+ return F.max_pool1d(x, 1, 1)
+ self.checkScript(use_default, (torch.randn(3, 3, 3),))
+
+ def arg_false(x):
+ # type: (Tensor) -> Tensor
+ return F.max_pool1d(x, 1, 1, 0, 1, False, False)
+ self.checkScript(arg_false, (torch.randn(3, 3, 3),))
+
+ def arg_true(x):
+ # type: (Tensor) -> Tuple[Tensor, Tensor]
+ return F.max_pool1d(x, 1, 1, 0, 1, False, True)
+ self.checkScript(arg_true, (torch.randn(3, 3, 3),))
+
class MnistNet(nn.Module):
def __init__(self):
@@ -9277,13 +9313,6 @@
'test_norm_fro',
'test_norm_fro_default',
'test_norm_nuc',
- # skipped nn functional tests
- # ops involves sampling which could not test
-
- 'test_nn_adaptive_max_pool1d',
- 'test_nn_adaptive_max_pool2d',
- 'test_nn_adaptive_max_pool3d',
-
# argument has custom behavior
'test_nn_fractional_max_pool2d',
@@ -9871,6 +9900,7 @@
('avg_pool3d', (S, S, S, S, S), (3,)),
('fractional_max_pool2d', (S, S, S, S), (3, [2, 3], None)),
('max_pool1d', (S, S, S), (2, 1)),
+ ('max_pool1d', (S, S, S), (2, 1, 1, 1, False, True), 'with_indices'),
('max_pool2d', (S, S, S, S), (2, 1)),
('max_pool3d', (S, S, S, S, S), (2, 1)),
('max_unpool1d', torch.tensor([[[2., 4]]]), (torch.tensor([[[1, 3]]]), 2, 2, 0)),
diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py
index 07b301c..f1b736e 100644
--- a/torch/_jit_internal.py
+++ b/torch/_jit_internal.py
@@ -23,6 +23,10 @@
# Types that have been declared as weak modules
_weak_types = weakref.WeakKeyDictionary()
+# Wrapper functions that can call one of two functions depending on a boolean
+# argument
+_boolean_dispatched = weakref.WeakKeyDictionary()
+
COMPILATION_PENDING = object()
COMPILED = object()
@@ -104,3 +108,44 @@
"original_method": fn
}
return fn
+
+
+def boolean_dispatch(arg_name, arg_index, default, if_true, if_false):
+ """
+    Dispatches to one of two weak script functions based on a boolean argument.
+ In Torch Script, the boolean argument must be constant so that the correct
+ function to use can be determined at compile time.
+ """
+ if _compiled_weak_fns.get(if_true) is None or _compiled_weak_fns.get(if_false) is None:
+ raise RuntimeError("both functions must be weak script")
+
+ def fn(*args, **kwargs):
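+        # Eager-mode fallback: the flag is resolved at call time, checking
+        # kwargs first, then positional args, then the declared default.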
+        dispatch_flag = default
+ if arg_name in kwargs:
+ dispatch_flag = kwargs[arg_name]
+ elif arg_index < len(args):
+ dispatch_flag = args[arg_index]
+
+ if dispatch_flag:
+ return if_true(*args, **kwargs)
+ else:
+ return if_false(*args, **kwargs)
+
+ if if_true.__doc__ is None and if_false.__doc__ is not None:
+ doc = if_false.__doc__
+ if_true.__doc__ = doc
+ elif if_false.__doc__ is None and if_true.__doc__ is not None:
+ doc = if_true.__doc__
+ if_false.__doc__ = doc
+ else:
+ raise RuntimeError("only one function can have a docstring")
+ fn.__doc__ = doc
+
+ _boolean_dispatched[fn] = {
+ "if_true": if_true,
+ "if_false": if_false,
+ "index": arg_index,
+ "default": default,
+ "arg_name": arg_name
+ }
+ return fn
diff --git a/torch/csrc/jit/script/compiler.h b/torch/csrc/jit/script/compiler.h
index 089cc24..9bd1016 100644
--- a/torch/csrc/jit/script/compiler.h
+++ b/torch/csrc/jit/script/compiler.h
@@ -232,6 +232,11 @@
// if true, emitBuiltinCall will throw an exception if this builtin does not exist,
// otherwise it will return nullptr if the builtin is not found.
bool required);
+
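+// Returns the index of the entry in `kwargs` named `name`, or nullopt if
+// there is no such entry.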
+c10::optional<size_t> findInputWithName(
+ const std::string& name,
+ at::ArrayRef<NamedValue> kwargs);
+
} // namespace script
} // namespace jit
} // namespace torch
diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index 1eda01a..5d40b4a 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -260,6 +260,62 @@
std::shared_ptr<Module> module;
};
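+// A SugaredValue wrapping a Python function created by
+// torch._jit_internal.boolean_dispatch; calls are forwarded to either the
+// `if_true` or `if_false` function based on the constant dispatch flag.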
+struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue {
+ BooleanDispatchValue(py::dict dispatched_fn)
+ : dispatched_fn_(std::move(dispatched_fn)) {}
+
+ std::string kind() const override {
+ return "boolean dispatch";
+ }
+
+ std::shared_ptr<SugaredValue> call(
+ SourceRange loc,
+ Method& caller,
+ at::ArrayRef<NamedValue> inputs,
+ at::ArrayRef<NamedValue> attributes,
+ size_t n_binders) override {
+ c10::optional<bool> result;
+ Graph& graph = *(caller.graph());
+
+ auto index = py::cast<size_t>(dispatched_fn_["index"]);
+ auto arg_name = py::str(dispatched_fn_["arg_name"]);
+
+ if (index < inputs.size()) {
+ // Dispatch flag is in arg list
+ result = constant_as<bool>(inputs.at(index).value(graph));
+ } else if (auto i = findInputWithName(arg_name, attributes)) {
+ // Dispatch flag is in kwargs
+ result = constant_as<bool>(attributes[*i].value(graph));
+ } else {
+ // Didn't find dispatch flag, so use default value
+ result = py::cast<bool>(dispatched_fn_["default"]);
+ }
+
+ if (!result) {
+ throw ErrorReport(loc) << "value for boolean dispatch was not constant";
+ }
+
+ std::shared_ptr<SugaredValue> value;
+ if (*result) {
+ value = toSugaredValue(dispatched_fn_["if_true"], caller, loc);
+ } else {
+ value = toSugaredValue(dispatched_fn_["if_false"], caller, loc);
+ }
+ return value->call(loc, caller, inputs, attributes, n_binders);
+ }
+
+ private:
+ py::dict dispatched_fn_;
+};
+
std::shared_ptr<SugaredValue> toSugaredValue(
py::object obj,
Method& m,
@@ -336,6 +392,12 @@
return std::make_shared<ModuleValue>(mod);
}
}
+
+ py::object dispatched_fn =
+ py::module::import("torch.jit").attr("_try_get_dispatched_fn")(obj);
+ if (!dispatched_fn.is_none()) {
+ return std::make_shared<BooleanDispatchValue>(std::move(dispatched_fn));
+ }
return std::make_shared<PythonValue>(obj);
}
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index 4fecbd5..c9f7672 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -7,7 +7,7 @@
from torch._six import raise_from, with_metaclass, get_function_from_type
from .._jit_internal import createResolutionCallback, _compiled_weak_fns, \
_weak_script_methods, _weak_modules, _weak_types, COMPILED, \
- COMPILATION_PENDING
+ COMPILATION_PENDING, _boolean_dispatched
from ..nn.modules.utils import _single, _pair, _triple, _quadruple, \
_list_with_default
import torch.testing
@@ -623,6 +623,10 @@
return self.module._get_method(attr)
+def _try_get_dispatched_fn(fn):
+ return _boolean_dispatched.get(fn)
+
+
def _try_compile_weak_script(fn):
entry = _compiled_weak_fns.get(fn)
if entry is None:
@@ -1336,7 +1340,7 @@
func = getattr(torch.nn.functional, name)
if func is None:
return False
- return func in _compiled_weak_fns
+ return func in _compiled_weak_fns or func in _boolean_dispatched
def _unwrap_optional(x):
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 91352fc..9b94bc9 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -4,6 +4,7 @@
import warnings
import math
import types
+from typing import List
import torch
from torch._C import _infer_size, _add_docstr
@@ -304,6 +305,7 @@
def fractional_max_pool2d(input, kernel_size, output_size=None,
output_ratio=None, return_indices=False,
_random_samples=None):
+ # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList1[int]], float, bool, Tensor) -> Tuple[Tensor, Tensor] # noqa
r"""Applies 2D fractional max pooling over an input signal composed of several input planes.
Fractional MaxPooling is described in detail in the paper `Fractional MaxPooling`_ by Ben Graham
@@ -337,47 +339,111 @@
raise ValueError("fractional_max_pool2d requires specifying either "
"an output_size, or a output_ratio")
if output_size is None:
- output_ratio = _pair(output_ratio)
- output_size = (int(input.size(2) * output_ratio[0]),
- int(input.size(3) * output_ratio[1]))
+ _output_ratio = _pair(output_ratio)
+ _output_size = (int(input.size(2) * _output_ratio[0]),
+ int(input.size(3) * _output_ratio[1]))
+ else:
+ _output_size = torch.jit._unwrap_optional(output_size)
if _random_samples is None:
_random_samples = input.new(input.size(0), input.size(1), 2).uniform_()
- ret = torch._C._nn.fractional_max_pool2d(input, kernel_size, output_size, _random_samples)
+ ret = torch._C._nn.fractional_max_pool2d(input, kernel_size, _output_size, _random_samples)
return ret if return_indices else ret[0]
-def max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1,
- ceil_mode=False, return_indices=False):
+@torch._jit_internal.weak_script
+def max_pool1d_with_indices(input, kernel_size, stride=None, padding=0,
+ dilation=1, ceil_mode=False, return_indices=False):
+ # type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], int, int, bool, bool) -> Tuple[Tensor, Tensor] # noqa
r"""Applies a 1D max pooling over an input signal composed of several input
planes.
See :class:`~torch.nn.MaxPool1d` for details.
"""
- ret = torch.max_pool1d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode)
- return ret if return_indices else ret[0]
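+    # Script mode needs an explicit List[int] annotation for the empty list;
+    # an empty stride tells the ATen op to default the stride to kernel_size.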
+ if stride is None:
+ _stride = torch.jit.annotate(List[int], [])
+ else:
+ _stride = torch.jit._unwrap_optional(stride)
+ return torch.max_pool1d_with_indices(
+ input, kernel_size, _stride, padding, dilation, ceil_mode)
-def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1,
- ceil_mode=False, return_indices=False):
+@torch._jit_internal.weak_script
+def _max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1,
+ ceil_mode=False, return_indices=False):
+ # type: (Tensor, BroadcastingList1[int], Optional[BroadcastingList1[int]], int, int, bool, bool) -> Tensor
+ return max_pool1d_with_indices(
+ input, kernel_size, stride, padding, dilation, ceil_mode)[0]
+
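+# In eager mode `max_pool1d` resolves the overload at call time; in script
+# mode the compiler selects `max_pool1d_with_indices` or `_max_pool1d` at
+# compile time from the constant `return_indices` argument (positional
+# index 6).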
+max_pool1d = torch._jit_internal.boolean_dispatch(
+ arg_name='return_indices',
+ arg_index=6,
+ default=False,
+ if_true=max_pool1d_with_indices,
+ if_false=_max_pool1d)
+
+
+@torch._jit_internal.weak_script
+def max_pool2d_with_indices(input, kernel_size, stride=None, padding=0, dilation=1,
+ ceil_mode=False, return_indices=False):
+ # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], int, int, bool, bool) -> Tuple[Tensor, Tensor] # noqa
r"""Applies a 2D max pooling over an input signal composed of several input
planes.
See :class:`~torch.nn.MaxPool2d` for details.
"""
- ret = torch._C._nn.max_pool2d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode)
- return ret if return_indices else ret[0]
+ if stride is None:
+ _stride = torch.jit.annotate(List[int], [])
+ else:
+ _stride = torch.jit._unwrap_optional(stride)
+ return torch._C._nn.max_pool2d_with_indices(input, kernel_size, _stride, padding, dilation, ceil_mode)
-def max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1,
- ceil_mode=False, return_indices=False):
+@torch._jit_internal.weak_script
+def _max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1,
+ ceil_mode=False, return_indices=False):
+ # type: (Tensor, BroadcastingList2[int], Optional[BroadcastingList2[int]], int, int, bool, bool) -> Tensor
+ return max_pool2d_with_indices(
+ input, kernel_size, stride, padding, dilation, ceil_mode)[0]
+
+max_pool2d = torch._jit_internal.boolean_dispatch(
+ arg_name='return_indices',
+ arg_index=6,
+ default=False,
+ if_true=max_pool2d_with_indices,
+ if_false=_max_pool2d)
+
+
+@torch._jit_internal.weak_script
+def max_pool3d_with_indices(input, kernel_size, stride=None, padding=0,
+ dilation=1, ceil_mode=False, return_indices=False):
+ # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], int, int, bool, bool) -> Tuple[Tensor, Tensor] # noqa
r"""Applies a 3D max pooling over an input signal composed of several input
planes.
See :class:`~torch.nn.MaxPool3d` for details.
"""
- ret = torch._C._nn.max_pool3d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode)
- return ret if return_indices else ret[0]
+ if stride is None:
+ _stride = torch.jit.annotate(List[int], [])
+ else:
+ _stride = torch.jit._unwrap_optional(stride)
+ return torch._C._nn.max_pool3d_with_indices(
+ input, kernel_size, _stride, padding, dilation, ceil_mode)
+
+
+@torch._jit_internal.weak_script
+def _max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1,
+ ceil_mode=False, return_indices=False):
+ # type: (Tensor, BroadcastingList3[int], Optional[BroadcastingList3[int]], int, int, bool, bool) -> Tensor
+ return max_pool3d_with_indices(
+ input, kernel_size, stride, padding, dilation, ceil_mode)[0]
+
+max_pool3d = torch._jit_internal.boolean_dispatch(
+ arg_name='return_indices',
+ arg_index=6,
+ default=False,
+ if_true=max_pool3d_with_indices,
+ if_false=_max_pool3d)
def _unpool_output_size(input, kernel_size, stride, padding, output_size):
@@ -488,7 +554,9 @@
return (torch.sign(out) * relu(torch.abs(out))).mul(kernel_size).pow(1. / norm_type)
-def adaptive_max_pool1d(input, output_size, return_indices=False):
+@torch._jit_internal.weak_script
+def adaptive_max_pool1d_with_indices(input, output_size, return_indices=False):
+ # type: (Tensor, BroadcastingList1[int], bool) -> Tuple[Tensor, Tensor]
r"""Applies a 1D adaptive max pooling over an input signal composed of
several input planes.
@@ -498,11 +566,25 @@
output_size: the target output size (single integer)
return_indices: whether to return pooling indices. Default: ``False``
"""
- ret = torch.adaptive_max_pool1d(input, output_size)
- return ret if return_indices else ret[0]
+ return torch.adaptive_max_pool1d(input, output_size)
-def adaptive_max_pool2d(input, output_size, return_indices=False):
+@torch._jit_internal.weak_script
+def _adaptive_max_pool1d(input, output_size, return_indices=False):
+ # type: (Tensor, BroadcastingList1[int], bool) -> Tensor
+ return adaptive_max_pool1d_with_indices(input, output_size)[0]
+
+adaptive_max_pool1d = torch._jit_internal.boolean_dispatch(
+ arg_name='return_indices',
+ arg_index=2,
+ default=False,
+ if_true=adaptive_max_pool1d_with_indices,
+ if_false=_adaptive_max_pool1d)
+
+
+@torch._jit_internal.weak_script
+def adaptive_max_pool2d_with_indices(input, output_size, return_indices=False):
+ # type: (Tensor, BroadcastingList1[int], bool) -> Tuple[Tensor, Tensor]
r"""Applies a 2D adaptive max pooling over an input signal composed of
several input planes.
@@ -514,11 +596,25 @@
return_indices: whether to return pooling indices. Default: ``False``
"""
output_size = _list_with_default(output_size, input.size())
- ret = torch._C._nn.adaptive_max_pool2d(input, output_size)
- return ret if return_indices else ret[0]
+ return torch._C._nn.adaptive_max_pool2d(input, output_size)
-def adaptive_max_pool3d(input, output_size, return_indices=False):
+@torch._jit_internal.weak_script
+def _adaptive_max_pool2d(input, output_size, return_indices=False):
+ # type: (Tensor, BroadcastingList1[int], bool) -> Tensor
+ return adaptive_max_pool2d_with_indices(input, output_size)[0]
+
+adaptive_max_pool2d = torch._jit_internal.boolean_dispatch(
+ arg_name='return_indices',
+ arg_index=2,
+ default=False,
+ if_true=adaptive_max_pool2d_with_indices,
+ if_false=_adaptive_max_pool2d)
+
+
+@torch._jit_internal.weak_script
+def adaptive_max_pool3d_with_indices(input, output_size, return_indices=False):
+ # type: (Tensor, BroadcastingList1[int], bool) -> Tuple[Tensor, Tensor]
r"""Applies a 3D adaptive max pooling over an input signal composed of
several input planes.
@@ -530,8 +626,20 @@
return_indices: whether to return pooling indices. Default: ``False``
"""
output_size = _list_with_default(output_size, input.size())
- ret = torch._C._nn.adaptive_max_pool3d(input, output_size)
- return ret if return_indices else ret[0]
+ return torch._C._nn.adaptive_max_pool3d(input, output_size)
+
+
+@torch._jit_internal.weak_script
+def _adaptive_max_pool3d(input, output_size, return_indices=False):
+ # type: (Tensor, BroadcastingList1[int], bool) -> Tensor
+ return adaptive_max_pool3d_with_indices(input, output_size)[0]
+
+adaptive_max_pool3d = torch._jit_internal.boolean_dispatch(
+ arg_name='return_indices',
+ arg_index=2,
+ default=False,
+ if_true=adaptive_max_pool3d_with_indices,
+ if_false=_adaptive_max_pool3d)
adaptive_avg_pool1d = _add_docstr(torch.adaptive_avg_pool1d, r"""