Split Type into TypeExtendedInterface and Type (#11520)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11520
Previously, we had Type which was a catch all interface for all
functions and methods we could possibly want to do dynamic dispatch
on. However, we want to check in a non-autogenerated Tensor class
to ATen/core, and to do this, we must also check in a non-autogenerated
Type class which we can do dispatch on. In principle, we could
put the full Type interface in ATen/core, but this would be
a bad developer experience, since any time you add a new free
function, you'd have to regenerate the checked in Type header.
For a better dev experience, we split Type into two parts,
Type, which will be checked in (though not in this diff), and
TypeExtendedInterface, which will NOT be checked in. Type contains
just enough methods to let Tensor be defined, and leaves the
rest to TypeExtendedInterface.
Some complications:
- We (very unfortunately) have overloaded virtual methods. Because
of C++'s rules, we cannot move one overload without doing some
extra work to make sure that overload in a superclass and an
overload in a subclass resolve together. I've chosen to resolve
this problem simply by moving ALL overloads of a method which
occurs in Tensor to Type.
- There are some places where we take a type() object and call
a method on it, which is not a Tensor base method. I've eliminated
some where possible, but in other cases calling the method on type
is the ONLY way to invoke it; in that case, I've just inserted
a cast. Further refactoring is necessary.
Reviewed By: gchanan
Differential Revision: D9771708
fbshipit-source-id: c59d39fe919cd6f42be6dca699d474346ea3c614
diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py
index b830aa3..287b789 100644
--- a/aten/src/ATen/function_wrapper.py
+++ b/aten/src/ATen/function_wrapper.py
@@ -107,6 +107,10 @@
# NB: As far as ezyang can tell, we don't *have* to codegen this,
# because we will inherit it from the TYPE_METHOD_DEFINITION_CONCRETE in
# the superclass. But it doesn't seem to be harmful.
+#
+# TODO: self_ty is a hack to make things work for native methods which need to
+# take a dtype, but also need to dispatch differently for different types.
+# Eliminate it at some point.
TYPE_DERIVED_DEFINITION_NATIVE = CodeTemplate("""\
${return_type} ${Type}::${api_name}(${type_method_formals}) const {
${device_guard_declaration}
@@ -173,7 +177,7 @@
# the same name (but different signature) already
ZERO_DIM_CHECK = CodeTemplate("""\
if (${check_name}.dim() == 0) {
- return static_cast<const Type*>(this)->${api_name}(${zero_dim_actuals});
+ return static_cast<const TypeExtendedInterface*>(this)->${api_name}(${zero_dim_actuals});
}""")
ZERO_DIM_ONLY = CodeTemplate("""\
@@ -183,7 +187,7 @@
SPARSE_CHECK = CodeTemplate("""\
if(${check_name}.type().is_sparse()) {
- return static_cast<const Type*>(this)->${api_name}(${sparse_actuals});
+ return static_cast<const TypeExtendedInterface*>(this)->${api_name}(${sparse_actuals});
}""")
BUFFER_DEFINITION = CodeTemplate("""\
@@ -390,6 +394,7 @@
'type_registrations': List[str],
'type_headers': List[str],
'pure_virtual_type_method_declarations': List[str],
+ 'pure_virtual_extended_type_method_declarations': List[str],
'type_method_declarations': List[str],
'type_method_definitions': List[str],
'type_method_inline_definitions': List[str],
@@ -490,6 +495,9 @@
'formals': List[str],
'inferred_type': str,
'inplace': bool,
+ # This controls whether or not we generate the interface in Type or
+ # TypeExtendedInterface
+ 'extended_method': bool,
'method_actuals': List[str],
'method_formals_with_defaults': List[str],
'method_formals': List[str],
@@ -836,8 +844,12 @@
# NN function with no _forward/_backward suffix don't have cimpls.
# They call the _forward function and discard any buffer returns
abstract = False
- top_env['pure_virtual_type_method_declarations'].append(
- PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
+ if option['extended_method']:
+ top_env['pure_virtual_extended_type_method_declarations'].append(
+ PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
+ else:
+ top_env['pure_virtual_type_method_declarations'].append(
+ PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
top_env['type_method_declarations'].append(
TYPE_METHOD_DECLARATION_CONCRETE.substitute(env))
body = emit_nn_body(option)
@@ -845,17 +857,27 @@
TYPE_METHOD_DEFINITION_CONCRETE.substitute(
env, type_definition_body=body))
elif broadcast_arg is None:
- top_env['pure_virtual_type_method_declarations'].append(
- PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
+ if option['extended_method']:
+ top_env['pure_virtual_extended_type_method_declarations'].append(
+ PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
+ else:
+ top_env['pure_virtual_type_method_declarations'].append(
+ PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
top_env['type_method_declarations'].append(
TYPE_METHOD_DECLARATION_ABSTRACT.substitute(env))
top_env['type_method_definitions'].append(
TYPE_METHOD_DEFINITION_ABSTRACT.substitute(env))
else:
- top_env['pure_virtual_type_method_declarations'].append(
- PURE_VIRTUAL_TYPE_METHOD_DECLARATION_BROADCAST.substitute(env))
- top_env['pure_virtual_type_method_declarations'].append(
- PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
+ if option['extended_method']:
+ top_env['pure_virtual_extended_type_method_declarations'].append(
+ PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
+ top_env['pure_virtual_extended_type_method_declarations'].append(
+ PURE_VIRTUAL_TYPE_METHOD_DECLARATION_BROADCAST.substitute(env))
+ else:
+ top_env['pure_virtual_type_method_declarations'].append(
+ PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
+ top_env['pure_virtual_type_method_declarations'].append(
+ PURE_VIRTUAL_TYPE_METHOD_DECLARATION_BROADCAST.substitute(env))
top_env['type_method_declarations'].append(
TYPE_METHOD_DECLARATION_BROADCAST.substitute(env))
top_env['type_method_declarations'].append(
@@ -888,7 +910,7 @@
method_of.append('Tensor')
if is_namespace_function:
- option['inferred_type'] = 'infer_type({})'.format(dispatch_tensor)
+ option['inferred_type'] = 'detail::infer_type({})'.format(dispatch_tensor)
top_env['function_declarations'].append(
FUNCTION_DECLARATION.substitute(env))
top_env['function_definitions'].append(
@@ -1060,11 +1082,21 @@
# Factory methods are not dispatched over `Type`.
if not is_factory_method:
if option['deprecated']:
+ # Deprecated functions are always non-extended,
+ # because they need to be made available from Type
+ # (the public interface) so that code like
+ # tensor.type().arange(...) keeps working. Once
+ # we remove the deprecated functions, we can eliminate
+ # these methods entirely.
top_env['pure_virtual_type_method_declarations'].append(
DEPRECATED_PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
else:
- top_env['pure_virtual_type_method_declarations'].append(
- PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
+ if option['extended_method']:
+ top_env['pure_virtual_extended_type_method_declarations'].append(
+ PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
+ else:
+ top_env['pure_virtual_type_method_declarations'].append(
+ PURE_VIRTUAL_TYPE_METHOD_DECLARATION.substitute(env))
top_env['type_method_declarations'].append(TYPE_METHOD_DECLARATION_CONCRETE.substitute(env))
dispatch = option['type_method_definition_dispatch']
option['native_type_method_dispatch'] = dispatch
@@ -1116,12 +1148,12 @@
if is_namespace_function:
if dispatch_type:
- option['inferred_type'] = dispatch_type['name']
+ option['inferred_type'] = 'static_cast<const TypeExtendedInterface&>({})'.format(dispatch_type['name'])
elif dispatch_tensor:
- option['inferred_type'] = 'infer_type({})'.format(dispatch_tensor)
+ option['inferred_type'] = 'detail::infer_type({})'.format(dispatch_tensor)
else:
# doesn't depend on a specific type, use undefined float
- option['inferred_type'] = 'at::getNonVariableType(at::Backend::Undefined, at::ScalarType::Float)'
+ option['inferred_type'] = 'detail::non_specific_type()'
declaration = DEPRECATED_FUNCTION_DECLARATION if option['deprecated'] else FUNCTION_DECLARATION
top_env['function_declarations'].append(declaration.substitute(env))
if is_factory_method:
diff --git a/aten/src/ATen/gen.py b/aten/src/ATen/gen.py
index 025faff..5497a75 100644
--- a/aten/src/ATen/gen.py
+++ b/aten/src/ATen/gen.py
@@ -107,6 +107,7 @@
SPARSE_TYPE_DERIVED_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/SparseTypeDerived.cpp")
TYPE_DERIVED_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeDerived.h")
TYPE_H = CodeTemplate.from_file(TEMPLATE_PATH + "/Type.h")
+TYPE_EXTENDED_INTERFACE_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeExtendedInterface.h")
TYPE_DEFAULT_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeDefault.h")
TYPE_DEFAULT_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeDefault.cpp")
@@ -165,6 +166,7 @@
'cuda_type_registrations': [],
'cuda_type_headers': [],
'pure_virtual_type_method_declarations': [],
+ 'pure_virtual_extended_type_method_declarations': [],
'type_method_declarations': [],
'type_method_definitions': [],
'type_method_inline_definitions': [],
@@ -330,7 +332,7 @@
# so that the script runs quickly when we are just querying the
# outputs
def declare_outputs():
- files = ['Declarations.yaml', 'Type.h', 'TypeDefault.cpp', 'TypeDefault.h', 'Tensor.h',
+ files = ['Declarations.yaml', 'Type.h', 'TypeExtendedInterface.h', 'TypeDefault.cpp', 'TypeDefault.h', 'Tensor.h',
'TensorMethods.h', 'Functions.h',
'CPUCopy.cpp', 'NativeFunctions.h',
'RegisterCPU.cpp', 'RegisterCPU.h']
@@ -400,6 +402,7 @@
backend, density, scalar_type, declarations))
file_manager.write('Type.h', TYPE_H, top_env)
+ file_manager.write('TypeExtendedInterface.h', TYPE_EXTENDED_INTERFACE_H, top_env)
file_manager.write('TypeDefault.h', TYPE_DEFAULT_H, top_env)
file_manager.write('TypeDefault.cpp', TYPE_DEFAULT_CPP, top_env)
diff --git a/aten/src/ATen/native/LegacyBridge.cpp b/aten/src/ATen/native/LegacyBridge.cpp
index 07d7e46..1364c0c 100644
--- a/aten/src/ATen/native/LegacyBridge.cpp
+++ b/aten/src/ATen/native/LegacyBridge.cpp
@@ -144,34 +144,34 @@
Tensor tensor(const Type& dtype) {
if (_type_has_native(dtype)) {
- return dtype.native_tensor();
+ return static_cast<const TypeExtendedInterface&>(dtype).native_tensor();
} else {
- return dtype.th_tensor();
+ return static_cast<const TypeExtendedInterface&>(dtype).th_tensor();
}
}
Tensor tensor(const Type& dtype, ArrayRef<int64_t> size) {
if (_type_has_native(dtype)) {
- return dtype.native_tensor(size);
+ return static_cast<const TypeExtendedInterface&>(dtype).native_tensor(size);
} else {
- return dtype.th_tensor(size);
+ return static_cast<const TypeExtendedInterface&>(dtype).th_tensor(size);
}
}
Tensor sparse_coo_tensor(const Type& dtype, ArrayRef<int64_t> size) {
- return dtype.toSparse().native_sparse_coo_tensor(size);
+ return static_cast<const TypeExtendedInterface&>(dtype.toSparse()).native_sparse_coo_tensor(size);
}
Tensor sparse_coo_tensor(const Tensor& indices, const Tensor& values) {
- return values.type().toSparse().native_sparse_coo_tensor(indices, values);
+ return static_cast<const TypeExtendedInterface&>(values.type().toSparse()).native_sparse_coo_tensor(indices, values);
}
Tensor sparse_coo_tensor(const Tensor& indices, const Tensor& values, ArrayRef<int64_t> size) {
- return values.type().toSparse().native_sparse_coo_tensor(indices, values, size);
+ return static_cast<const TypeExtendedInterface&>(values.type().toSparse()).native_sparse_coo_tensor(indices, values, size);
}
Tensor _sparse_coo_tensor_unsafe(const Tensor& indices, const Tensor& values, ArrayRef<int64_t> size) {
- return values.type().toSparse()._native_sparse_coo_tensor_unsafe(indices, values, size);
+ return static_cast<const TypeExtendedInterface&>(values.type().toSparse())._native_sparse_coo_tensor_unsafe(indices, values, size);
}
int64_t get_device(const Tensor& self) {
diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp
index 1a12549..20211ae 100644
--- a/aten/src/ATen/native/TensorFactories.cpp
+++ b/aten/src/ATen/native/TensorFactories.cpp
@@ -59,6 +59,10 @@
window_length);
}
+const TypeExtendedInterface& getFactoryType(const TensorOptions& options) {
+ return static_cast<const TypeExtendedInterface&>(at::getType(options));
+}
+
} // namespace
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ arange ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -73,7 +77,7 @@
Scalar step,
const TensorOptions& options) {
// Note [Native bindings for legacy TH factory functions]
- return at::getType(options)._arange(start, end, step);
+ return getFactoryType(options)._arange(start, end, step);
}
Tensor& arange_out(Tensor& result, Scalar start, Scalar end) {
@@ -86,7 +90,7 @@
Tensor arange(Scalar end, const TensorOptions& options) {
// Note [Native bindings for legacy TH factory functions]
- return at::getType(options)._arange(end);
+ return getFactoryType(options)._arange(end);
}
Tensor& arange_out(Tensor& result, Scalar end) {
@@ -94,7 +98,7 @@
}
Tensor _dim_arange(const Tensor& like, int64_t dim) {
- return like.type().toScalarType(at::kLong)._arange(like.size(dim));
+ return static_cast<const TypeExtendedInterface&>(like.type().toScalarType(at::kLong))._arange(like.size(dim));
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ empty ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -102,7 +106,7 @@
Tensor empty(IntList size, const TensorOptions& options) {
// Note [Native bindings for legacy TH factory functions]
// Can't call a factory function, because the buck stops with us!
- return at::getType(options).tensor(size);
+ return getFactoryType(options).tensor(size);
}
Tensor& empty_out(Tensor& result, IntList size) {
@@ -218,7 +222,7 @@
int64_t steps,
const TensorOptions& options) {
// Note [Native bindings for legacy TH factory functions]
- return at::getType(options)._linspace(start, end, steps);
+ return getFactoryType(options)._linspace(start, end, steps);
}
Tensor& linspace_out(Tensor& result, Scalar start, Scalar end) {
@@ -241,7 +245,7 @@
int64_t steps,
const TensorOptions& options) {
// Note [Native bindings for legacy TH factory functions]
- return at::getType(options)._logspace(start, end, steps);
+ return getFactoryType(options)._logspace(start, end, steps);
}
Tensor& logspace_out(Tensor& result, Scalar start, Scalar end) {
@@ -475,7 +479,7 @@
Scalar step,
const TensorOptions& options) {
// Note [Native bindings for legacy TH factory functions]
- return at::getType(options)._range(start, end, step);
+ return getFactoryType(options)._range(start, end, step);
}
Tensor& range_out(Tensor& result, Scalar start, Scalar end) {
diff --git a/aten/src/ATen/preprocess_declarations.py b/aten/src/ATen/preprocess_declarations.py
index 173ac43..98b22c7 100644
--- a/aten/src/ATen/preprocess_declarations.py
+++ b/aten/src/ATen/preprocess_declarations.py
@@ -217,8 +217,20 @@
(raw_args - filtered_args)]
+def is_extended_method(option):
+ if 'method' in option['variants']:
+ return False
+ elif option.get('deprecated', False):
+ return False
+ elif not option['variants']:
+ return False
+ else:
+ return True
+
+
def run(declarations):
declarations = [d for d in declarations if not exclude(d)]
+ non_extended_methods = set()
for declaration in declarations:
common_with_cwrap.set_declaration_defaults(declaration)
declaration['options'] = [deepcopy(o) for o in declaration['options']]
@@ -237,6 +249,20 @@
sanitize_return(option)
process_types_and_backends(option)
add_variants(option)
+ if not is_extended_method(option):
+ non_extended_methods.add(option['api_name'])
declaration['options'] = handle_outputs_taken_as_arguments(
declaration['options'])
+
+ # We (very unfortunately) have overloaded virtual methods. Because
+ # of C++'s rules, we cannot move one overload without doing some
+ # extra work to make sure that overload in a superclass and an
+ # overload in a subclass resolve together. I've chosen to resolve
+ # this problem simply by moving ALL overloads of a method which
+ # occurs in Tensor to Type. This is why we have to first compute
+ # which methods *names* go on type, and then move ALL overloads
+ # of this name to Type.
+ for declaration in declarations:
+ for option in declaration['options']:
+ option['extended_method'] = option['api_name'] not in non_extended_methods
return declarations
diff --git a/aten/src/ATen/templates/Functions.h b/aten/src/ATen/templates/Functions.h
index b4a2e05..7e2b658 100644
--- a/aten/src/ATen/templates/Functions.h
+++ b/aten/src/ATen/templates/Functions.h
@@ -4,6 +4,7 @@
#include "ATen/core/Scalar.h"
#include "ATen/Type.h"
+#include "ATen/TypeExtendedInterface.h"
#include "ATen/Tensor.h"
#include "ATen/core/Storage.h"
#include "ATen/core/Generator.h"
@@ -20,14 +21,22 @@
${function_declarations}
-static inline Type & infer_type(const Tensor & t) {
+namespace detail {
+
+static inline TypeExtendedInterface & infer_type(const Tensor & t) {
AT_CHECK(t.defined(), "undefined Tensor");
- return t.type();
+ return static_cast<TypeExtendedInterface&>(t.type());
}
-static inline Type & infer_type(const TensorList & tl) {
+static inline TypeExtendedInterface & infer_type(const TensorList & tl) {
AT_CHECK(tl.size() > 0, "expected a non-empty list of Tensors");
- return tl[0].type();
+ return static_cast<TypeExtendedInterface&>(tl[0].type());
}
+static inline TypeExtendedInterface & non_specific_type() {
+ return static_cast<TypeExtendedInterface&>(at::getNonVariableType(at::Backend::Undefined, at::ScalarType::Float));
+}
+
+} // namespace detail
+
// function definitions are all static inline because
// they are one-line statically dispatched functions that
// invoke the actual dynamic dispatch on the correct argument
diff --git a/aten/src/ATen/templates/TypeDefault.h b/aten/src/ATen/templates/TypeDefault.h
index 0f315e5..e4a75ab 100644
--- a/aten/src/ATen/templates/TypeDefault.h
+++ b/aten/src/ATen/templates/TypeDefault.h
@@ -2,13 +2,13 @@
// ${generated_comment}
-#include "ATen/Type.h"
+#include "ATen/TypeExtendedInterface.h"
namespace at {
-struct AT_API TypeDefault : public Type {
+struct AT_API TypeDefault : public TypeExtendedInterface {
explicit TypeDefault(TensorTypeId type_id, bool is_variable, bool is_undefined)
- : Type(type_id, is_variable, is_undefined) {}
+ : TypeExtendedInterface(type_id, is_variable, is_undefined) {}
// Make sure overload resolution considers the nullary virtual method.
// (A single argument overload is generated in the list.)
diff --git a/aten/src/ATen/templates/TypeExtendedInterface.h b/aten/src/ATen/templates/TypeExtendedInterface.h
new file mode 100644
index 0000000..82cb658
--- /dev/null
+++ b/aten/src/ATen/templates/TypeExtendedInterface.h
@@ -0,0 +1,12 @@
+#pragma once
+#include <ATen/Type.h>
+
+namespace at {
+
+struct AT_API TypeExtendedInterface : public Type {
+ explicit TypeExtendedInterface(TensorTypeId type_id, bool is_variable, bool is_undefined)
+ : Type(type_id, is_variable, is_undefined) {}
+ ${pure_virtual_extended_type_method_declarations}
+};
+
+} // namespace at
diff --git a/caffe2/contrib/aten/aten_op_template.h b/caffe2/contrib/aten/aten_op_template.h
index 9f327fd..c573891 100644
--- a/caffe2/contrib/aten/aten_op_template.h
+++ b/caffe2/contrib/aten/aten_op_template.h
@@ -214,10 +214,10 @@
DEFINE_IF(int64, Long)
CAFFE_THROW("unsupported type annotation: ", name);
}
- at::Type & stringToType(const std::string & name) {
- return at::getNonVariableType(backend(), stringToScalarType(name));
+ at::TypeExtendedInterface & stringToType(const std::string & name) {
+ return static_cast<at::TypeExtendedInterface&>(at::getNonVariableType(backend(), stringToScalarType(name)));
}
- at::Type * readTypeAttribute(const std::string & name) {
+ at::TypeExtendedInterface * readTypeAttribute(const std::string & name) {
CAFFE_ENFORCE(OperatorBase::HasSingleArgumentOfType<std::string>(name));
return &stringToType(OperatorBase::GetSingleArgument<std::string>(name, ""));
}
diff --git a/caffe2/contrib/aten/gen_op.py b/caffe2/contrib/aten/gen_op.py
index 18a3db4..bc75ac9 100755
--- a/caffe2/contrib/aten/gen_op.py
+++ b/caffe2/contrib/aten/gen_op.py
@@ -278,7 +278,7 @@
# first tensor input is used to define the output type.
defined_inferred_type = True
env['statements'].append(
- 'auto inferred_type = &({}.type());'.format(
+ 'auto inferred_type = &(static_cast<at::TypeExtendedInterface&>({}.type()));'.format(
arg['name']))
else:
init = CT(ARGUMENT_MAP[arg['type']]).substitute(env, arg=arg['name'])
diff --git a/tools/autograd/templates/VariableType.cpp b/tools/autograd/templates/VariableType.cpp
index d4a9a4e..589bbf8 100644
--- a/tools/autograd/templates/VariableType.cpp
+++ b/tools/autograd/templates/VariableType.cpp
@@ -44,7 +44,7 @@
VariableType::VariableType(Context* context, Type* baseType)
: TypeDefault(baseType->type_id(), /*is_variable=*/true, /*is_undefined=*/false)
- , baseType(baseType)
+ , baseType(static_cast<TypeExtendedInterface*>(baseType))
, id_(context->freshTypeID()) {
str = std::string("Variable[") + baseType->toString() + "]";
}
diff --git a/tools/autograd/templates/VariableType.h b/tools/autograd/templates/VariableType.h
index fe3e57f..b9d84ad 100644
--- a/tools/autograd/templates/VariableType.h
+++ b/tools/autograd/templates/VariableType.h
@@ -72,7 +72,7 @@
static at::Tensor unpack_opt(const Tensor & t, const char * name, int pos);
static std::vector<at::Tensor> unpack(at::TensorList tl, const char *name, int pos);
- at::Type* baseType;
+ at::TypeExtendedInterface* baseType;
std::string str;
size_t id_;
};