Simplify some TH codegen by moving code out of the switch and killing dead code. (#32888)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32888
This kills ~1500 lines of generated code by doing the following:
1) Stop binding _th_clone, which isn't used anymore.
2) Move allocation code out of the switch, because it doesn't need to be there. For example:
After:
```
auto dispatch_scalar_type = infer_scalar_type(self);
auto result_ = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(c10::Storage(scalarTypeToTypeMeta(dispatch_scalar_type), 0, allocator(), true),DispatchKey::CPUTensorId).release();
auto result = Tensor(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(result_));
switch (dispatch_scalar_type) {
case ScalarType::Bool: {
...
case ScalarType::Byte: {
...
```
Before:
```
auto dispatch_scalar_type = infer_scalar_type(self);
switch(dispatch_scalar_type) {
case ScalarType::Bool: {
auto result_ = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(c10::Storage(caffe2::TypeMeta::Make<bool>(), 0, allocator(), true),DispatchKey::CPUTensorId).release();
auto result = Tensor(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(result_));
case ScalarType::Byte: {
auto result_ = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(c10::Storage(caffe2::TypeMeta::Make<uint8_t>(), 0, allocator(), true),DispatchKey::CPUTensorId).release();
auto result = Tensor(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(result_));
```
Note that this adds one extra runtime lookup from ScalarType to TypeMeta, but that can go away once we are able to put everything inside a dispatch macro.
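For illustration only (not code from this PR): inside a dispatch macro such as AT_DISPATCH_ALL_TYPES_AND, the lambda sees a compile-time scalar_t, so the allocation could use caffe2::TypeMeta::Make<scalar_t>() directly and the runtime scalarTypeToTypeMeta lookup disappears. The "alloc_result" name is a placeholder:
```
// Hypothetical sketch; the allocation mirrors the generated code above.
AT_DISPATCH_ALL_TYPES_AND(ScalarType::Bool, dispatch_scalar_type, "alloc_result", [&] {
  auto result_ = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(
      c10::Storage(caffe2::TypeMeta::Make<scalar_t>(), 0, allocator(), true),
      DispatchKey::CPUTensorId).release();
  auto result = Tensor(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(result_));
  // ... rest of the per-type body, which can now use scalar_t directly ...
});
```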
3) Prepare for more moves out of the switch by using the runtime dispatch_scalar_type where the generated code previously used an explicit ScalarType::Name.
Further moves are currently blocked by "real" arguments, which need a mapping from scalar_type to a C++ type. Dispatch macros can solve that, but the actual TH calls will first need to be wrapped in templates so the entire thing can be done via dispatch.
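As a rough sketch of that wrapping (not part of this PR; th_index_select and its specializations are hypothetical names, using _th_index_select from the diff below as the example and assuming the TH tensor handles are TensorImpl*, as in the generated wrappers):
```
// Hypothetical sketch: a thin template over the per-type TH entry points so a
// dispatch macro can map dispatch_scalar_type to a C++ type at the call site.
template <typename scalar_t>
void th_index_select(TensorImpl* result, TensorImpl* self, int64_t dim, TensorImpl* index);

template <>
void th_index_select<float>(TensorImpl* result, TensorImpl* self, int64_t dim, TensorImpl* index) {
  THFloatTensor_indexSelect(result, self, dim, index);
}
// ... one specialization per supported scalar type, after which the generated
// wrapper body could collapse to:
//   AT_DISPATCH_ALL_TYPES(dispatch_scalar_type, "_th_index_select", [&] {
//     th_index_select<scalar_t>(result_, self_, dim, index_);
//   });
```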
4) Kill some codegen that isn't used anymore: ALLOC_WRAP and is_actual_return_long.
Test Plan: Imported from OSS
Differential Revision: D19672613
Pulled By: gchanan
fbshipit-source-id: 753f480842d11757e10182e43b471bd3abaa5446
diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap
index 57f0791..af0603a 100644
--- a/aten/src/ATen/Declarations.cwrap
+++ b/aten/src/ATen/Declarations.cwrap
@@ -135,20 +135,6 @@
- THTensor* self
]]
[[
- name: _th_clone
- cname: newClone
- return: THTensor*
- variants:
- - function
- cpu_half: True
- cpu_bool: True
- cuda_bool: True
- cpu_bfloat16: True
- cuda_bfloat16: True
- arguments:
- - THTensor* self
-]]
-[[
name: _th_index_select
cuda_bool: True
cname: indexSelect
diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py
index 59c3e95..3c9d06b 100644
--- a/aten/src/ATen/function_wrapper.py
+++ b/aten/src/ATen/function_wrapper.py
@@ -65,6 +65,7 @@
""")
LEGACY_TH_DEFINITION_SWITCH_STATEMENT = CodeTemplate("""\
${dispatch_scalar_type_declaration}
+${switch_prologue}
switch (dispatch_scalar_type) {
${cases}
default:
@@ -302,7 +303,7 @@
CodeTemplate(
'checked_dense_tensor_unwrap('
'${arg_name}, "${arg_name}", ${arg_pos}, "${api_name}", ${null_okay}, '
- 'DeviceType::${DeviceType}, ScalarType::${ScalarName})'),
+ 'DeviceType::${DeviceType}, ${scalar_type})'),
'THByteTensor*':
CodeTemplate(
'checked_dense_tensor_unwrap('
@@ -324,14 +325,14 @@
'${arg_name}, "${arg_name}", ${arg_pos}, '
# We're punning here (Backend and DeviceType constructors coincide)
# but DeviceType is the correct way to classify storages
- 'DeviceType::${Backend}, at::scalarTypeToTypeMeta(ScalarType::${ScalarName}))'),
+ 'DeviceType::${Backend}, at::scalarTypeToTypeMeta(${scalar_type}))'),
# This is a cast done via direct-construction
'IntArrayRefStride': CodeTemplate('at::IntArrayRef ${result_name} = get_intlist_stride_th(${arg_name});'),
'real': CodeTemplate('${arg_name}.to${ScalarName}()'),
'accreal': CodeTemplate('${arg_name}.to${AccScalarName}()'),
'TensorList': CodeTemplate(
'checked_tensor_list_unwrap(${arg_name},"${arg_name}",${arg_pos}, '
- 'Backend::${Backend}, ScalarType::${ScalarName})'),
+ 'Backend::${Backend}, ${scalar_type})'),
'IntArrayRef': CodeTemplate('check_intlist<${size}>(${arg_name}, "${arg_name}", ${arg_pos})')
}
@@ -348,7 +349,7 @@
ALLOC_NOARGS_WRAP = {
'THTensor*': 'c10::make_intrusive<TensorImpl, UndefinedTensorImpl>'
- '(c10::Storage(caffe2::TypeMeta::Make<${ScalarType}>(), 0, allocator(), true),'
+ '(c10::Storage(scalarTypeToTypeMeta(${ScalarName}), 0, allocator(), true),'
'DispatchKey::${Backend}TensorId).release()',
'THByteTensor*': 'c10::make_intrusive<TensorImpl, UndefinedTensorImpl>'
'(c10::Storage(scalarTypeToTypeMeta(ScalarType::Byte), 0, allocator(), true),'
@@ -361,13 +362,6 @@
'DispatchKey::${Backend}TensorId).release()',
}
-ALLOC_WRAP = {
- 'THTensor*': '${arguments}',
- 'THByteTensor*': '${arguments}',
- 'THBoolTensor*': '${arguments}',
- 'THIndexTensor*': '${arguments}',
-}
-
# Replacements for constants when calling into TH
CONSTANT_REPLACEMENTS = [
('AS_REAL', '${ScalarType}'),
@@ -1356,16 +1350,6 @@
return [get_argument(env, argument, option)
for argument in arguments]
- def is_actual_return_long(env, ret):
- # type: (Environment, ReturnDecl) -> bool
- if ret['type'] == 'long':
- return True
- if ret['type'] == 'real':
- return env['ScalarName'] == 'Long'
- if ret['type'] == 'accreal':
- return env['AccScalarName'] == 'Long'
- return False
-
def handle_zero_dim(env, option):
# type: (Environment, FunctionOption) -> List[str]
zero_dim_dispatch = option.get('zero_dim_dispatch_when_scalar', '')
@@ -1381,10 +1365,10 @@
for arg in option['formals_list']]
return [ZERO_DIM_CHECK.substitute(env, check_name=zero_dim_dispatch, zero_dim_actuals=zero_dim_actuals)]
- def allocate_arg(env, arg, output_count):
- # type: (Environment, THFormal, int) -> List[str]
+ def allocate_arg(arg, output_count, backend, scalar_name):
+ # type: (THFormal, int, str, str) -> List[str]
name = arg['name']
- allocation = CodeTemplate(ALLOC_NOARGS_WRAP[arg['type']]).substitute(env)
+ allocation = CodeTemplate(ALLOC_NOARGS_WRAP[arg['type']]).substitute(Backend=backend, ScalarName=scalar_name)
tensor_arg = '{}_'.format(name)
if arg.get('mask', False):
allocation = 'output_mask[{}] ? {} : nullptr'.format(output_count, allocation)
@@ -1437,16 +1421,23 @@
body = [] # type: List[str]
body += handle_zero_dim(env, option)
+ switch_prologue = [] # type: List[str]
+ output_count = 0
cases = []
+
+ for arg in option['arguments']:
+ # make a new allocation of TensorImpl, then wrap a Tensor around it.
+ if arg.get('allocate', False):
+ switch_prologue += allocate_arg(arg, output_count, env['Backend'], 'dispatch_scalar_type')
+ output_count += 1
+
for scalar_name, c_type, accreal, _ in scalar_types:
if scalar_name in scalar_type_cases:
case_body = [] # type: List[str]
# arguments are potentially duplicated because of one argument
# referencing another
seen_names = set() # type: Set[str]
- seen_tensorlists = set() # type: Set[str]
count = 0
- output_count = 0
case_env = {
'Backend': env['Backend'],
@@ -1466,28 +1457,23 @@
for arg in option['arguments']:
if is_real_argument_to_wrapper(arg):
count += 1
- if arg['type'] == 'TensorList':
- seen_tensorlists.add(arg['name'])
# only generated checked casts the first time we see it
if arg['name'] not in seen_names and requires_checked_cast(arg):
seen_names.add(arg['name'])
# make a new allocation of TensorImpl, then wrap a Tensor around it.
- if arg.get('allocate', False):
- case_body += allocate_arg(case_env, arg, output_count)
- output_count += 1
- # extract the TensorImpl from an existing tensor (or Storage, etc.)
- else:
+ if not arg.get('allocate', False):
# special case where we allow undefined Tensors, and thus
# the checked cast succeeds even if the Tensor is not
# defined
null_okay = 'true' if nullable_argument(arg) else 'false'
+ # extract the TensorImpl from an existing tensor (or Storage, etc.)
check_cast = CHECKED_CAST[arg['type']].substitute(
case_env, arg_name=arg['name'], arg_pos=count,
api_name=option['api_name'], null_okay=null_okay,
- size=arg.get('size'))
+ size=arg.get('size'), scalar_type='dispatch_scalar_type')
case_body.append("auto {}_ = {};".format(
arg['name'], check_cast))
@@ -1536,32 +1522,20 @@
assert len(calls) == 1
call = calls[0]
- if ret['type'] in ALLOC_WRAP.keys():
- wrapped_tensor = CodeTemplate(ALLOC_WRAP[ret['type']]).substitute(
- case_env, arguments=[call])
- return_tensor = (
- "return Tensor(" +
- "c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(" +
- "(${wrapped_tensor})));")
- case_body.append(CodeTemplate(return_tensor).substitute(
- case_env, wrapped_tensor=wrapped_tensor))
# return the same underlying Tensor type for both real and accreal; this ensures
# e.g. x.sum(0) and x.sum() return the same type. We explicitly cast to the
# ScalarType before constructing the scalar_tensor to avoid overflow checking.
- elif ret['type'] == 'accreal' or ret['type'] == 'real':
+ if ret['type'] == 'accreal' or ret['type'] == 'real':
return_scalar = ('return at::scalar_tensor(convert<${ScalarType}>(${call}), '
'options(ScalarType::${ScalarName}));')
case_body.append(CodeTemplate(return_scalar).substitute(case_env, call=call))
else:
- # we using int64_t for long in the API, so correct it here...
- if is_actual_return_long(case_env, ret):
- call = "static_cast<int64_t>({})".format(call)
case_body.append("return {};".format(call))
else:
raise Exception("NYI - return handling")
cases.append(LEGACY_TH_DEFINITION_CASE.substitute(case_env, case_body=case_body))
- body.append(LEGACY_TH_DEFINITION_SWITCH_STATEMENT.substitute(env, cases=cases))
+ body.append(LEGACY_TH_DEFINITION_SWITCH_STATEMENT.substitute(env, cases=cases, switch_prologue=switch_prologue))
return body
def process_legacy_th_option(option):