| import torch._C as C |
| from torch.testing._internal.common_utils import TestCase, run_tests |
| |
| from collections import namedtuple |
| import itertools |
| import unittest |
| |
| # TODO: Expand the dispatcher API to be a generic API for interfacing with |
| # the dispatcher from Python! |
| # |
| # These are exhaustive tests for commutativity of dispatch behavior. If you're |
| # looking for more usage-info style tests, check op_registration_test.cpp |
| # |
| # Things not tested here: |
| # - Listeners |
| # - Top level namespace registrations |
| # - Fallback |
| # - Exotic overloads of CppFunction/schema |
| # |
| # Things not directly tested here: |
| # - Internal state of Dispatcher makes sense. This is indirectly |
| # tested by the invariant testing |
| |
# Record stored per set of in-effect registrations: `state` is the dumped
# dispatcher state (or an error message), `table` is the dumped computed
# dispatch table, and `provenance` describes how the record was obtained.
Result = namedtuple('Result', ['state', 'table', 'provenance'])
| |
class TestDispatch(TestCase):
    """
    Tests that dispatcher registrations commute and can be undone.

    The helpers `run_ops` and `commute` replay a list of registration
    callables in various constructor/destructor orders and compare the
    dumped dispatcher state across orders; the `test_*` methods feed them
    concrete registration lists and pin the resulting dumps.
    """

    # Bumped on every run_ops call so each run gets a fresh test namespace
    namespace_index = 0

    def test_all_invariants(self):
        # Check that the regular stuff is OK!
        C._dispatch_check_all_invariants()

    # You probably don't want to call this directly; if your constructors
    # don't commute, you can still run commute with a fixed ctor_order
    # so that you can test that the destructors still commute
    def run_ops(self, name, ops, ctor_order=None, dtor_order=None,
                results=None, expect_raises=False):
        """
        Given a list of operator registrations, run the registrations in the
        order specified by ctor_order, and then run the deregistrations in
        dtor_order.

        If results is specified, intermediate results are checked for consistency
        with results stored in results (and stored in results if this is the
        first time we've seen them).  Results are expected to be equivalent
        modulo commutativity and inverses (thus, results is keyed on a frozenset
        of in effect registrations from ops).  Results stores namedtuple
        Result[state, table, provenance], where state is a string that contains
        non-derived kernel registered or error message if it doesn't pass;
        table is a string that contains computed dispatch table entries;
        provenance is a string that describes how exactly we got this string.

        If expect_raises is True, it is not an error to raise an exception.  Instead,
        we'll store the exception string (instead of the dispatcher state)
        in results.  In principle we should flag these differently, but it's
        very obvious when you get an error in one case but not another.

        Args:
            name: operator name (without namespace) whose state is dumped.
            ops: list of callables; each is called with a fresh library
                object and performs one registration on it.
            ctor_order: order (indices into ops) in which to run the
                registrations; defaults to 0..len(ops)-1.
            dtor_order: order in which to drop the library references;
                defaults to the reverse of ctor_order.
            results: dict from frozenset of active op indices to Result,
                shared across calls to compare different orders.
            expect_raises: if True, some registration must raise
                RuntimeError; the test fails if none does.

        Returns:
            The `state` field of the Result recorded for the full set of ops,
            or, if a registration raised, for the set that was active when it
            raised (i.e. the error message).
        """
        # By allocating every test into a fresh namespace, this makes it less
        # likely that a bug in the testing framework will result in tests
        # interfering with each other
        self.__class__.namespace_index += 1
        if results is None:
            results = {}
        if ctor_order is None:
            ctor_order = list(range(len(ops)))
        if dtor_order is None:
            dtor_order = list(reversed(ctor_order))
        # Refs which retain the c10::Module object so we can explicitly control
        # when each deregistration happens (deregistration occurs when the
        # object gets deallocated).
        refs = [None] * len(ops)
        # Keep track of the set "in effect" registrations
        active_ops = set()

        # double underscore to make it less likely we conflict with something
        # else
        test_namespace = "__test{}__".format(self.namespace_index)

        def check_invariants(actual_provenance):
            # NOTE(review): this passes the bare operator name, while the dump
            # calls below use the namespaced "<test_namespace>::<name>" --
            # confirm the invariant check really targets the operator under test.
            C._dispatch_check_invariants(name)
            # Normalize the test namespace so that expected outputs are stable
            actual_state = C._dispatch_dump(
                "{}::{}".format(test_namespace, name)).replace(test_namespace, "test")
            actual_table = C._dispatch_dump_table(
                "{}::{}".format(test_namespace, name)).replace(test_namespace, "test")
            # First visitor of a given active set records the expectation;
            # everyone after that must reproduce it exactly.
            expected_state, expected_table, expected_provenance = results.setdefault(
                frozenset(active_ops),
                Result(actual_state, actual_table, actual_provenance)
            )
            self.assertMultiLineEqual(
                expected_state, actual_state,
                "expected from {}; actual from {}"
                .format(expected_provenance, actual_provenance)
            )
            self.assertMultiLineEqual(
                expected_table, actual_table,
                "expected from {}; actual from {}"
                .format(expected_provenance, actual_provenance)
            )

        results.setdefault(frozenset(), Result("", "", "hardcoded initial state"))
        check_invariants("initial state")
        # In the order specified by ctor_order, run registrations
        set_to_report = frozenset(range(len(ops)))
        # Index (into ctor_order) of the last ctor that succeeded.  Start at -1
        # so the provenance messages in the dtor loop are still well defined
        # when the very first registration raises.
        last_ctor = -1
        for i, op_ix in enumerate(ctor_order):
            # It would be better to DEF here, but because we manage
            # lifetime of multiple registrations with multiple Library
            # references (refs), we can't deal with the strict checking
            # from DEF.
            refs[op_ix] = C._dispatch_library("FRAGMENT", test_namespace, "")
            active_ops.add(op_ix)
            try:
                ops[op_ix](refs[op_ix])
                check_invariants("running ctors {}".format(ctor_order[:i + 1]))
            except RuntimeError as e:
                if not expect_raises:
                    raise
                actual = str(e).replace(test_namespace, "test")
                expected, _, expected_provenance = results.setdefault(
                    frozenset(active_ops),
                    Result(actual, "", "error after running ctors {}".format(ctor_order[:i + 1]))
                )
                self.assertMultiLineEqual(expected, actual, expected_provenance)
                # Report the error recorded for the set that includes the
                # failing op, then drop that op from the active set.
                set_to_report = frozenset(active_ops)
                active_ops.remove(op_ix)
                # NB: this final check asserts that if a registration fails,
                # the dispatcher is left in the same state *that it was before*!
                check_invariants(
                    "running ctors {} and then failing to run ctor {} "
                    "(did this failure leave the dispatcher in a wedged state? "
                    "it shouldn't!)"
                    .format(ctor_order[:i], op_ix))
                break
            last_ctor = i
        if expect_raises and len(active_ops) == len(ops):
            # Destroy references first, as some test frameworks (like pytest)
            # will retain references in the exception raised by the failure! EW!
            refs = None
            self.fail(
                "expected exception to be raised, but nothing was raised "
                "(after running ctors {})".format(ctor_order))
        # In the order specified by dtor_order, run deregistrations
        for i, op_ix in enumerate(dtor_order):
            # Trigger a destruction
            refs[op_ix] = None
            # discard not remove, since we may not have actually deregistered
            # anything if there was an error raised
            if expect_raises:
                active_ops.discard(op_ix)
            else:
                active_ops.remove(op_ix)
            check_invariants(
                "running ctors {}, then running dtors {}"
                .format(ctor_order[:last_ctor + 1], dtor_order[:i + 1])
            )
        return results[set_to_report][0]

    # Operator registrations are commutative (as static initializers can
    # run in any order) and invertible (by deregistration).  (Subject
    # to some caveats: some legacy behavior in the system are not commutative--
    # we want to get rid of these!)
    #
    # So while in principle we could simply test a set of operations
    # by just running them one by one in the order specified by the user,
    # we can get more assurance about these extra properties by doing
    # more work:
    #
    # 1. Don't run the registrations once in a fixed order: run every possible
    #    permutation.  Similarly, run every permutation of deregistration order.
    #
    # 2. Don't just check the end state of the dispatcher: for every
    #    subset of operator registrations, ensure that the computed
    #    intermediate state is path independent.  One thing to note:
    #    in this function, we assume each operation is unique.  In general,
    #    there may be duplicated registrations, but these are usually
    #    idempotent or legacy.  We test for behavior here separately.
    #
    # NB: checking all permutations means this function is exponential in
    # the length of ops!  So don't pass too many ops to this function!
    def commute(self, name, ops, ctor_order=None, expect_raises=False):
        """
        Run `run_ops` over every destructor order and, unless ctor_order is
        given, every constructor order, sharing one results dict so that all
        orders are checked against each other.

        Returns the Result recorded for the set containing every op index.
        """
        results = {}

        def go(ctor_order):
            for dtor_order in itertools.permutations(range(len(ops))):
                self.run_ops(
                    name, ops, ctor_order, dtor_order,
                    results=results, expect_raises=expect_raises)

        if ctor_order is not None:
            go(ctor_order)
        else:
            for ctor_order in itertools.permutations(range(len(ops))):
                go(ctor_order)

        # Return the "full" Result namedtuple after all operations are run.
        # If this KeyErrors, that means that there did not exist any
        # ordering of ctors which got us to the "end".  That's an
        # error in test construction: it means you could have
        # factored the test into two smaller ones.
        return results[frozenset(range(len(ops)))]

    def test_def(self):
        state = self.commute("foo", [
            # m.def("foo(Tensor x) -> Tensor")
            lambda m: m.def_("foo(Tensor x) -> Tensor"),
            # m.impl("foo", [](const Tensor& x) { return x })
            lambda m: m.impl_t_t("foo"),
            # m.impl("foo", kCPU, [](const Tensor& x) { return x })
            lambda m: m.impl_t_t("foo", dispatch="cpu"),
            # m.impl("foo", kAutograd, [](const Tensor& x) { return x })
            lambda m: m.impl_t_t("foo", dispatch="autograd"),
            # m.impl("foo", kAutogradCPU, [](const Tensor& x) { return x })
            lambda m: m.impl_t_t("foo", dispatch="autogradcpu")
        ]).state
        self.assertExpectedInline(state, '''\
name: test::foo
schema: test::foo(Tensor x) -> (Tensor)
debug: registered at /dev/null:0
alias analysis kind: FROM_SCHEMA
CPU: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
AutogradCPU: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
Autograd[alias]: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
catchall: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
''')

    def test_def_impl_schema_mismatch(self):
        # NB: an impl-impl mismatch is not reported eagerly; you'll find out
        # about it because one of them won't match with def
        state = self.commute("foo", [
            # m.def("foo(Tensor x, Tensor y) -> Tensor")
            lambda m: m.def_("foo(Tensor x, Tensor y) -> Tensor"),
            # m.impl("foo", [](const Tensor & x) { return x })
            lambda m: m.impl_t_t("foo"),
        ], expect_raises=True).state
        self.assertExpectedInline(state, '''In registration for test::foo: expected schema of operator to be "test::foo(Tensor x, Tensor y) -> (Tensor)" (registered at /dev/null:0), but got inferred schema "(Tensor _0) -> (Tensor _0)" (impl_t_t). The number of arguments is different. 2 vs 1.''')  # noqa

    def test_def_with_inference(self):
        state = self.commute("foo", [
            # m.def("foo", [](const Tensor & x) { return x })
            lambda m: m.def_name_t_t("foo"),
            # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x })
            lambda m: m.impl_t_t("foo", "cpu"),
            # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x })
            lambda m: m.impl_t_t("foo", "autograd"),
            # m.impl("foo", torch::kAutogradCPU, [](const Tensor & x) { return x })
            lambda m: m.impl_t_t("foo", "autogradcpu")
        ]).state
        self.assertExpectedInline(state, '''\
name: test::foo
schema: test::foo(Tensor _0) -> (Tensor _0)
debug: registered at /dev/null:0
alias analysis kind: CONSERVATIVE
CPU: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
AutogradCPU: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
Autograd[alias]: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
''')

    def test_def_only(self):
        state = self.commute("foo", [
            # m.def("foo(Tensor x, Tensor y) -> Tensor")
            lambda m: m.def_("foo(Tensor x, Tensor y) -> Tensor"),
        ]).state
        self.assertExpectedInline(state, '''\
name: test::foo
schema: test::foo(Tensor x, Tensor y) -> (Tensor)
debug: registered at /dev/null:0
alias analysis kind: FROM_SCHEMA
''')

    def test_impl_only(self):
        state = self.commute("foo", [
            # m.impl("foo", [](const Tensor& x) { return x })
            lambda m: m.impl_t_t("foo"),
            # m.impl("foo", torch::kCPU, [](const Tensor& x) { return x })
            lambda m: m.impl_t_t("foo", "cpu"),
            # m.impl("foo", torch::kAutograd, [](const Tensor& x) { return x })
            lambda m: m.impl_t_t("foo", "autograd"),
            # m.impl("foo", torch::kAutogradCPU, [](const Tensor& x) { return x })
            lambda m: m.impl_t_t("foo", "autogradcpu")
        ]).state
        self.assertExpectedInline(state, '''\
name: test::foo
schema: (none)
CPU: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
AutogradCPU: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
Autograd[alias]: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
catchall: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
''')

    def test_computed_table(self):
        result = self.commute("foo", [
            # m.def("foo", [](const Tensor & x) { return x })
            lambda m: m.def_name_t_t("foo"),
            # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x })
            lambda m: m.impl_t_t("foo", "cpu"),
            # m.impl("foo", torch::kXLA, [](const Tensor & x) { return x })
            lambda m: m.impl_t_t("foo", "xla"),
            # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x })
            lambda m: m.impl_t_t("foo", "autograd"),
            # m.impl("foo", torch::kAutogradCPU, [](const Tensor & x) { return x })
            lambda m: m.impl_t_t("foo", "autogradcpu")
        ])
        state, table = result.state, result.table
        self.assertExpectedInline(state, '''\
name: test::foo
schema: test::foo(Tensor _0) -> (Tensor _0)
debug: registered at /dev/null:0
alias analysis kind: CONSERVATIVE
CPU: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
XLA: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
AutogradCPU: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
Autograd[alias]: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
''')

        def extract_entries(table, dispatch_keys):
            # Keep only the table lines that start with one of dispatch_keys,
            # ordered by key first and by position in the table second.
            table_entries = table.split('\n')
            return ''.join(
                t + '\n'
                for k in dispatch_keys
                for t in table_entries
                if t.startswith(k)
            )

        # computed dispatch table is too big, so we only check on a few entries we're interested in.
        extracted_table = extract_entries(
            table,
            ('CPU', 'CUDA', 'XLA', 'AutogradOther', 'AutogradCPU', 'AutogradCUDA', 'AutogradXLA'))

        self.assertExpectedInline(extracted_table, '''\
CPU: impl_t_t [kernel]
CUDA: default_def_name_t_t [catch all]
XLA: impl_t_t [kernel]
AutogradOther: impl_t_t [autograd kernel]
AutogradCPU: impl_t_t [kernel]
AutogradCUDA: impl_t_t [autograd kernel]
AutogradXLA: impl_t_t [autograd kernel]
''')

    # Can't do this yet for BC reasons
    @unittest.expectedFailure
    def test_multiple_def_error(self):
        state = self.commute("foo", [
            # m.def("foo(Tensor x, Tensor y) -> Tensor")
            lambda m: m.def_("foo(Tensor x, Tensor y) -> Tensor"),
            # m.def("foo(Tensor x, Tensor y) -> Tensor")
            lambda m: m.def_("foo(Tensor x, Tensor y) -> Tensor"),
        ], expect_raises=True).state
        # TODO: fill in the error message here
        # self.assertExpectedInline(state, '''''')

    def test_def_with_explicit_alias(self):
        state = self.commute("foo", [
            # m.def(torch::schema(
            #   "foo(Tensor x, Tensor y) -> Tensor",
            #   AliasAnalysisKind::PURE))
            lambda m: m.def_("foo(Tensor x, Tensor y) -> Tensor",
                             alias="PURE_FUNCTION")
        ]).state
        self.assertExpectedInline(state, '''\
name: test::foo
schema: test::foo(Tensor x, Tensor y) -> (Tensor)
debug: registered at /dev/null:0
alias analysis kind: PURE_FUNCTION
''')

    # TODO: get rid of this test when multiple defs are wrong
    def test_multiple_def_schema_mismatch(self):
        # error message is order dependent
        ops = [
            # m.def("foo(Tensor x, Tensor y) -> Tensor")
            lambda m: m.def_("foo(Tensor x, Tensor y) -> Tensor"),
            # m.def("foo(Tensor x) -> Tensor")
            lambda m: m.def_("foo(Tensor x) -> Tensor"),
        ]
        self.assertExpectedInline(
            self.commute("foo", ops, ctor_order=(0, 1), expect_raises=True).state,
            '''Tried to register multiple operators with the same name and the same overload name but different schemas: test::foo(Tensor x) -> (Tensor) (registered at /dev/null:0) vs test::foo(Tensor x, Tensor y) -> (Tensor) (registered at /dev/null:0)'''  # noqa
        )
        self.assertExpectedInline(
            self.commute("foo", ops, ctor_order=(1, 0), expect_raises=True).state,
            '''Tried to register multiple operators with the same name and the same overload name but different schemas: test::foo(Tensor x, Tensor y) -> (Tensor) (registered at /dev/null:0) vs test::foo(Tensor x) -> (Tensor) (registered at /dev/null:0)'''  # noqa
        )

    def test_multiple_def_alias_defaulting(self):
        # TODO: should be an error in both directions soon
        ops = [
            # m.def(torch::schema("foo(Tensor x) -> Tensor",
            #                     c10::AliasAnalysisKind::PURE_FUNCTION))
            lambda m: m.def_("foo(Tensor x) -> Tensor", alias="PURE_FUNCTION"),
            # RegisterOperators().op("foo(Tensor x) -> Tensor")
            lambda m: m.def_legacy("foo(Tensor x) -> Tensor"),
        ]
        state = self.commute("foo", ops, ctor_order=(0, 1)).state
        self.assertExpectedInline(
            state,
            '''\
name: test::foo
schema: test::foo(Tensor x) -> (Tensor)
debug: registered at /dev/null:0
alias analysis kind: PURE_FUNCTION
'''
        )
        # NB: When run with ctor order (1, 0), the destructors are NOT
        # COMMUTATIVE.  THIS IS A BUG, however we are purposely leaving the bug
        # in as it is very benign (only leaves us in a bad state during
        # destruction, when no useful work is being done), will be fixed when we
        # make alias defaulting a hard error, and is very nontrivial to fix
        # prior to that.

    def test_multiple_def_alias_mismatch(self):
        # error message is order dependent
        ops = [
            # m.def(torch::schema("foo(Tensor x) -> Tensor",
            #                     c10::AliasAnalysisKind::PURE_FUNCTION))
            lambda m: m.def_("foo(Tensor x) -> Tensor", alias="PURE_FUNCTION"),
            # m.def(torch::schema("foo(Tensor x) -> Tensor",
            #                     c10::AliasAnalysisKind::CONSERVATIVE))
            lambda m: m.def_("foo(Tensor x) -> Tensor", alias="CONSERVATIVE"),
        ]
        self.assertExpectedInline(
            self.commute("foo", ops, ctor_order=(0, 1), expect_raises=True).state,
            '''Tried to define the schema for test::foo with different alias analysis kinds: PURE_FUNCTION (registered at /dev/null:0) vs CONSERVATIVE (registered at /dev/null:0)'''  # noqa
        )
        self.assertExpectedInline(
            self.commute("foo", ops, ctor_order=(1, 0), expect_raises=True).state,
            '''Tried to define the schema for test::foo with different alias analysis kinds: CONSERVATIVE (registered at /dev/null:0) vs PURE_FUNCTION (registered at /dev/null:0)'''  # noqa
        )

    def test_multiple_fallback(self):
        # Registering a second fallback for the same dispatch key on the same
        # library must raise; the first registration is expected to succeed.
        global_m = C._dispatch_library("IMPL", "_", "xla")
        global_m.fallback_fallthrough()
        try:
            global_m.fallback_fallthrough()
        except RuntimeError as e:
            self.assertExpectedInline(
                str(e),
                '''Tried to register multiple backend fallbacks for the same dispatch key XLA; previous registration registered at /dev/null:0, new registration registered at /dev/null:0'''  # noqa
            )
        else:
            self.fail(
                "expected RuntimeError when registering a second backend "
                "fallback for the same dispatch key")

    def test_overwrite_catchall(self):
        ops = [
            lambda m: m.impl_t_t("foo", debug="fn1"),
            lambda m: m.impl_t_t("foo", debug="fn2"),
        ]
        # Not commutative
        self.assertExpectedInline(
            self.commute("foo", ops, ctor_order=(0, 1)).state,
            '''\
name: test::foo
schema: (none)
catchall: fn2 :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
catchall (inactive): fn1 :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
'''
        )
| |
# When executed as a script (not imported), hand control to the test runner.
if __name__ == "__main__":
    run_tests()