More fixes so the generator can handle a larger share of the cwrap declarations
diff --git a/aten/src/aten/function_wrapper.py b/aten/src/aten/function_wrapper.py
index bec1fd7..fa09d81 100644
--- a/aten/src/aten/function_wrapper.py
+++ b/aten/src/aten/function_wrapper.py
@@ -49,22 +49,36 @@
self.reason = reason
TYPE_FORMAL_GENERIC = {
- 'THTensor*' : 'Tensor &'
+ 'THTensor*' : 'Tensor &',
+ 'THBoolTensor*': 'Tensor &',
+ 'THStorage*' : 'Storage &',
+ 'THIndexTensor*' : 'Tensor &',
+ 'THGenerator*': 'Generator &',
+ 'accreal' : 'Scalar',
+ 'real' : 'Scalar',
}
TYPE_RETURN = {
'THTensor*' : 'Tensor *',
+ 'THIndexTensor*' : 'Tensor *',
+ 'THBoolTensor*' : 'Tensor *',
'real': 'Scalar',
- 'accreal': 'Scalar'
+ 'accreal': 'Scalar',
}
-TYPE_ARGUMENT = {
- 'THTensor*': CodeTemplate('checked_cast<${THTensor}>(${arg_name})'),
+CHECKED_CAST = {
+ 'THTensor*': CodeTemplate('checked_cast<${Tensor}>(&${arg_name})'),
+ 'THIndexTensor*' : CodeTemplate('checked_cast<${THIndexTensor}>(&${arg_name})')
}
RETURN_WRAP = {
'THTensor*': 'new ${Tensor}(context,${returned})'
}
+CONSTANT_REPLACEMENTS = [
+ ('AS_REAL','${ScalarType}'),
+ ('THPDefaultGenerator->cdata','dynamic_cast<${Processor}Generator*>(context->defaultGenerator(processor()))->generator'),
+]
+
class nested_dict(object):
def __init__(self,base,parent):
self.base, self.parent = base,parent
@@ -76,6 +90,10 @@
def create_generic(top_env, declarations):
+ def is_real_argument_to_wrapper(argument):
+ return not argument.get('output',False) and\
+ argument['type'] != 'CONSTANT' and\
+ argument['type'] != 'argument'
def get_formals(option):
seen = set()
result = []
@@ -84,14 +102,14 @@
seen.add(argument['name'])
result.append(argument)
for argument in option['arguments']:
- if not argument.get('output',False):
+ if is_real_argument_to_wrapper(argument):
insert(argument)
for argument in option['arguments']:
- if not argument.get('allocate',False):
+ if argument.get('output') and not argument.get('allocate',False):
insert(argument)
return result
def format_formal(argument):
- type_str = TYPE_FORMAL_GENERIC.get(argument['type'],"NYIType")
+ type_str = TYPE_FORMAL_GENERIC.get(argument['type'],argument['type'])
return '{} {}'.format(type_str,argument['name'])
def format_return_type(option):
@@ -112,8 +130,8 @@
return argument['name']
return None
def process_option(option):
- if option['cname'] != 'neg':
- raise NYIError("all not implemented")
+ #if option['name'] != 'lt':
+ # raise NYIError("NYI")
formals = get_formals(option)
option['formals_list'] = formals
@@ -157,23 +175,33 @@
type_object_declarations = []
type_object_definitions = []
def requires_checked_cast(argument):
- return argument['type'] == 'THTensor*'
- def get_argument(argument):
+ return argument['type'] in CHECKED_CAST
+ def get_argument(argument,option):
if requires_checked_cast(argument):
return "{}_->tensor".format(argument['name'])
+ elif argument['type'] == "CONSTANT":
+ v = str(argument['name'])
+ for pattern,replacement in CONSTANT_REPLACEMENTS:
+ v = re.sub(pattern, replacement, v)
+ return CodeTemplate(v).substitute(processor_type_env)
+ # e.g. argument 0, i.e. repeat the 0th argument in this position...
+ elif argument['type'] == 'argument':
+ index = int(argument['name'])
+ return get_argument(option['arguments'][index],option)
else:
return argument['name']
-
+ def drop_argument(argument):
+ return argument['type'] == 'THGenerator*' and processor_type_env['Processor'] == 'CUDA'
def get_arguments(option):
- return [get_argument(argument) for argument in option['arguments']]
+ return [get_argument(argument,option)
+ for argument in option['arguments'] if not drop_argument(argument)]
def emit_body(env,option):
body = []
for arg in option['formals_list']:
if requires_checked_cast(arg):
- body.append(
- CodeTemplate("auto ${arg_name}_ = checked_cast<${Tensor}>(&${arg_name});").substitute(
- env,arg_name=arg['name']))
+ check_cast = CHECKED_CAST[arg['type']].substitute(env,arg_name=arg['name'])
+ body.append("auto {}_ = {};".format(arg['name'],check_cast))
for arg in option['arguments']:
if arg.get('allocate',False):
body.append(
@@ -191,8 +219,9 @@
if ret['type'] == 'THTensor*':
body.append(CodeTemplate("return new ${Tensor}(context,${arg_name});").substitute(env,arg_name=call))
else:
- body.append("return {};").format(call)
- assert(False and "NYI - return handling")
+ body.append("return {};".format(call))
+ else:
+ raise Exception("NYI - return handling")
return body
def process_option(option):
diff --git a/aten/src/aten/gen.py b/aten/src/aten/gen.py
index d073466..11da8ad 100644
--- a/aten/src/aten/gen.py
+++ b/aten/src/aten/gen.py
@@ -99,6 +99,7 @@
sname = '' if scalar_name == "Float" else scalar_name
env['THStorage'] = 'THCuda{}Storage'.format(sname)
env['THTensor'] = 'THCuda{}Tensor'.format(sname)
+ env['THIndexTensor'] = 'THCudaLongTensor'
env['state'] = ['context->thc_state']
env['isCUDA'] = 'true'
env['storage_device'] = 'return storage->device;'
@@ -106,6 +107,7 @@
env['th_header'] = "TH/TH.h"
env['THStorage'] = "TH{}Storage".format(scalar_name)
env['THTensor'] = 'TH{}Tensor'.format(scalar_name)
+ env['THIndexTensor'] = 'THLongTensor'
env['state'] = []
env['isCUDA'] = 'false'
env['storage_device'] = 'throw std::runtime_error("CPU storage has no device");'