# git blob: e239c65d0ae972e1b6fed4cca6cdd350b7c6385e
import re
# Patterns for TH type strings such as "THTensor*", "THLongTensor*",
# "THStorage*" or "THIntStorage*". Group 1 captures the optional scalar-type
# prefix (e.g. "Long"); it is empty for the generic forms. "\Z" anchors the
# match at the end of the string so that e.g. "THTensor**" is not accepted
# as a plain tensor. Compiled once here rather than on every call.
_TENSOR_RE = re.compile(r"TH([a-zA-Z]*)Tensor\*\Z")
_STORAGE_RE = re.compile(r"TH([a-zA-Z]*)Storage\*\Z")


def to_environment_type(env, arg_string):
    """Convert a TH type string to a (container, backend, scalar type) tuple.

    Handles strings such as ``THTensor*``, ``THStorage*``, ``THLongTensor*``
    and ``THSize*``. The result is a tuple of:

    1. the "Container" - ``'Tensor'`` or ``'Storage'``;
    2. the backend - e.g. ``'CPU'`` or ``'CUDA'``, taken from ``env``
       (except for ``THSize*``, which is always ``'CPU'``);
    3. the "Scalar Type" - e.g. ``'Long'`` or ``'Int'``, taken from
       ``arg_string`` when it names one, otherwise from ``env`` (the
       generic case).

    Args:
        env: mapping providing ``'Backend'`` and ``'ScalarName'``; each key
            is read only when the result actually needs it.
        arg_string: the complete type string. It must match exactly, so
            trailing text (e.g. ``'THTensor**'``) is not recognised.

    Returns:
        The ``(container, backend, scalar_type)`` tuple, or ``None`` if
        ``arg_string`` is not a recognised type.

    Raises:
        KeyError: if ``env`` lacks a key that the result needs.
    """
    for container, pattern in (('Tensor', _TENSOR_RE),
                               ('Storage', _STORAGE_RE)):
        match = pattern.match(arg_string)
        if match is not None:
            backend = env['Backend']
            # An explicit scalar prefix wins; an empty one means "generic",
            # so fall back to the environment's scalar type.
            scalar_type = match.group(1) or env['ScalarName']
            return (container, backend, scalar_type)
    # THSize is a size array, always represented as a CPU Long storage,
    # independent of env.
    # TODO: should the scalar type be 'Long' or int64_t?
    if arg_string == 'THSize*':
        return ('Storage', 'CPU', 'Long')
    # Unrecognised type string; callers decide what the failure case is.
    return None
# Example environments: the backend name, plus the scalar type used when a
# TH type string is generic (has no explicit scalar prefix).
cpu_env = {
    'Backend': 'CPU',
    'ScalarName': 'Float',
}
cuda_env = {
    'Backend': 'CUDA',
    'ScalarName': 'Half',
}

if __name__ == '__main__':
    # Demo conversions; guarded so that importing this module prints nothing.
    print(to_environment_type(cpu_env, 'THTensor*'))
    print(to_environment_type(cpu_env, 'THLongStorage*'))
    print(to_environment_type(cuda_env, 'THIntTensor*'))
    print(to_environment_type(cuda_env, 'THSize*'))