Mild refactor of native_functions.yaml dispatch parsing (#66109)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66109
This refactor is no longer necessary for ufunc codegen, as I changed
the format of ufuncs to not directly be inserted into the 'dispatch'
key, but I think the refactored code here is better. The basic concept
is to directly construct BackendMetadata as we are parsing entries of
the dispatch dictionary, rather than post facto creating them later.
This centralizes the compute and means that the creation of the backend index
is just a simple reindexing by operator name (nothing nontrivial).
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS
Reviewed By: bdhirsh
Differential Revision: D31385760
Pulled By: ezyang
fbshipit-source-id: 4fcb491ba025d2aa6fd356586b57affb97a507fc
(cherry picked from commit 21c93d41996120697f81168650b4f4b999d6902a)
diff --git a/tools/codegen/model.py b/tools/codegen/model.py
index 6bc0d7df..a5ae3a3 100644
--- a/tools/codegen/model.py
+++ b/tools/codegen/model.py
@@ -355,20 +355,27 @@
raw_dispatch = e.pop('dispatch', None)
assert raw_dispatch is None or isinstance(raw_dispatch, dict), e
- dispatch: Dict[DispatchKey, str] = {}
+ dispatch: Dict[DispatchKey, BackendMetadata] = {}
if raw_dispatch is not None:
assert not manual_kernel_registration, \
"cannot specify both manual_kernel_registration and dispatch; with " \
"manual registration, dispatch has no effect!"
+ redundant_composite_implicit_autograd = False
for ks, v in raw_dispatch.items():
if ks == '__line__':
continue # not worth tracking line numbers for dispatch entries
assert isinstance(ks, str), e
- assert isinstance(v, str), e
for k in ks.split(","):
dispatch_key = DispatchKey.parse(k.strip())
- dispatch[dispatch_key] = v
- assert dispatch != {DispatchKey.CompositeImplicitAutograd: cpp.name(func)}, \
+ # Why is 'structured' included? External backends (e.g.
+ # XLA) opt into which ops are structured independently
+ # of which in-tree ops are structured
+ dispatch[dispatch_key] = BackendMetadata(
+ v, structured=structured and is_structured_dispatch_key(dispatch_key))
+ if dispatch_key is DispatchKey.CompositeImplicitAutograd and v == cpp.name(func):
+ redundant_composite_implicit_autograd = True
+
+ assert not (len(dispatch) == 1 and redundant_composite_implicit_autograd), \
"unnecessary dispatch table for this function; just delete the dispatch " \
"key entirely"
# if a function is a structured delegate, deleting the dispatch
@@ -378,7 +385,7 @@
f"but got {dispatch[DispatchKey.CompositeImplicitAutograd]}. Rename your implementation to the expected " \
"name, then delete the dispatch table"
elif not structured and structured_delegate is None:
- dispatch[DispatchKey.CompositeImplicitAutograd] = cpp.name(func)
+ dispatch[DispatchKey.CompositeImplicitAutograd] = BackendMetadata(cpp.name(func), structured=False)
assert not (DispatchKey.CompositeExplicitAutograd in dispatch and DispatchKey.CompositeImplicitAutograd in dispatch), \
"cannot specify both CompositeExplicitAutograd and CompositeImplicitAutograd on a single kernel; each " \
@@ -394,12 +401,11 @@
has_composite_implicit_autograd_kernel = DispatchKey.CompositeImplicitAutograd in dispatch.keys()
has_composite_explicit_autograd_kernel = DispatchKey.CompositeExplicitAutograd in dispatch.keys()
- # BackendMetadata is used to store any information about a NativeFunction that is backend dependent.
- # The most obvious information is the kernel name, which usually contains the name of the backend in it for cpu/cuda.
- # Why is 'structured' included? External backends (e.g. XLA) opt into which ops are structured
- # independently of which in-tree ops are structured
- backend_metadata = {k: {func.name: BackendMetadata(
- kernel=v, structured=structured and is_structured_dispatch_key(k))} for k, v in dispatch.items()}
+ # We aren't going to store dispatch metadata inline in NativeFunctions;
+ # instead it is separately indexed by backend (so other backends can
+ # add more dispatch entries after the fact). Reindex the individual
+ # metadata by OperatorName!
+ backend_metadata = {k: {func.name: v} for k, v in dispatch.items()}
# don't care if it exists or not; make it easier to use this function
# with other yaml parsers that aren't setting __line__ in the dict