# blob: f92b72d6708071de6cecd3419653d183918f9369
import re
import common_with_cwrap
from collections import OrderedDict
# Header type spellings that must be rewritten before they appear in a
# generated declaration. Lookups in this file use .get(t, t), so any type
# without an entry here is emitted unchanged.
TYPE_TRANSLATIONS = dict([
    ('THIndex_t', 'long'),
    ('THCTensor*', 'THTensor*'),
])
def argument_to_declaration(arg):
    """Build the declaration entry for one parsed header argument.

    The argument's type is passed through TYPE_TRANSLATIONS (types without
    an entry are kept as-is) and joined to the argument name with a single
    space to form the ``'arg'`` string. Optional arguments additionally get
    ``'default': 'nullptr'``.

    Args:
        arg: parsed argument object; only its ``type``, ``name`` and
            ``is_optional`` attributes are read.

    Returns:
        dict with an ``'arg'`` key, plus a ``'default'`` key when the
        argument is optional.
    """
    declared_type = TYPE_TRANSLATIONS.get(arg.type, arg.type)
    declaration = {'arg': ' '.join((declared_type, arg.name))}
    if arg.is_optional:
        declaration['default'] = 'nullptr'
    return declaration
def function_to_declaration(func, backend):
    """Build the declaration dict for one parsed header function.

    Args:
        func: parsed function object; its ``name`` and ``arguments`` are read.
        backend: backend label to record (``run`` in this file passes
            ``'CPU'`` or ``'CUDA'``).

    Returns:
        dict describing an ``'NN'``-mode declaration for the ``Float`` and
        ``Double`` types, exposed only as a ``function`` variant, with
        ``backend`` as its single entry in ``'backends'``.
    """
    # The leading header argument is the state argument; it is dropped from
    # the generated declaration.
    declared_args = [argument_to_declaration(a) for a in func.arguments[1:]]
    return dict(
        mode='NN',
        name=func.name,
        types=['Float', 'Double'],
        arguments=declared_args,
        backends=[backend],
        variants=['function'],
    )
# Name filters applied with re.search in run(): a function is kept only if
# include_only matches (the trailing '$' means the name must END with one of
# these suffixes) and exclude matches nowhere in the name.
include_only = ('(updateOutput|updateGradInput'
                '|accGradParameters|backward)$')
exclude = 'LookupTable'
def run(paths):
    """Collect NN declarations from the given header files.

    Each path is parsed with ``common_with_cwrap.parse_header``. Its backend
    is ``'CUDA'`` when the path contains ``'THCU'`` and ``'CPU'`` otherwise.
    A function is kept only when its name matches ``include_only`` and does
    not match ``exclude``. A function name seen in several headers yields a
    single declaration, built from its first occurrence. That declaration's
    ``'backends'`` list accumulates each backend the name was seen under,
    listing each backend once.

    Args:
        paths: iterable of header file paths.

    Returns:
        list of declaration dicts, in the order their names were first seen.
    """
    functions = OrderedDict()
    for path in paths:
        # 'THCU' has no regex metacharacters, so a substring test suffices.
        backend = 'CUDA' if 'THCU' in path else 'CPU'
        for func in common_with_cwrap.parse_header(path):
            if re.search(include_only, func.name) is None:
                continue
            if re.search(exclude, func.name) is not None:
                continue
            declaration = functions.get(func.name)
            if declaration is None:
                functions[func.name] = function_to_declaration(func, backend)
            elif backend not in declaration['backends']:
                # Seeing the same name again under a backend it is already
                # registered for (e.g. a header path listed twice) must not
                # duplicate that backend in the declaration.
                declaration['backends'].append(backend)
    return list(functions.values())