# blob: e749942533eeb2cead0c79787662ef6bc241f59a [file] [log] [blame]
from code_template import CodeTemplate
# Template for the complete generated C++ source file.
#   ${copy_includes}  - list of `#include "ATen/<name>.h"` lines, one pair
#                       (Type header, Tensor header) per destination type;
#                       filled in by create().
#   ${copy_functions} - list of rendered FUNCTION bodies (one `${Type}::copy`
#                       per destination type), placed inside `namespace at`.
# THC is only included when AT_CUDA_ENABLED is defined.
# NOTE(review): the `#undef THNN_` before including THC presumably avoids a
# macro clash between the TH and THC headers — confirm.
FILE = CodeTemplate("""\
#include "TH/TH.h"
#ifdef AT_CUDA_ENABLED
#undef THNN_
#include "THC/THC.h"
#endif
#include "ATen/Utils.h"
${copy_includes}
namespace at {
${copy_functions}
}
""")
# Template for one `case` of the source-type switch inside FUNCTION.
# Rendered by create_one() with the destination type's env plus:
#   ${src_id}          - the source type's TypeID (the switch label)
#   ${cuda}            - 'Cuda' when the source backend is CUDA, else ''
#   ${src_scalar_name} - the source type's ScalarName
#   ${src_tensor}      - the source's C++ tensor class used for static_cast
#   ${state,}          - list that is ['context->thc_state'] when either side
#                        is CUDA, else empty; presumably expands to a leading
#                        argument plus comma only when non-empty (CodeTemplate
#                        `${x,}` convention) — confirm in code_template.
# ${THTensor} comes from the destination env, so the call has the form
# <dst THTensor>_copy[Cuda]<SrcScalar>(...).
CASE = CodeTemplate("""\
case ${src_id}:
${THTensor}_copy${cuda}${src_scalar_name}(${state,}dst_->tensor,static_cast<${src_tensor}*>(src.pImpl)->tensor);
break;
""")
# Template for the generated `${Type}::copy(src, dst)` method of one
# destination type. It casts dst to that type's ${Tensor} class with
# checked_cast, then dispatches on src.type().ID() through the rendered CASE
# entries in ${copy_body}. Unsupported source types reach the `default`
# branch, which calls runtime_error. The scalar flag of src is copied onto
# dst afterwards.
# `(void) dst_;` matters when ${copy_body} is empty: create_one() emits no
# cases when either side is sparse, which would leave dst_ unused.
FUNCTION = CodeTemplate("""\
void ${Type}::copy(const Tensor & src, Tensor & dst) {
// code generated by function_wrapper
auto dst_ = checked_cast<${Tensor}>(dst.pImpl,"dst",0,false);
(void) dst_; //silence unused warning
switch(src.type().ID()) {
${copy_body}
default:
runtime_error("copy does not support %s to %s copy.",src.type().toString(),toString());
break;
}
dst.pImpl->setScalar(src.pImpl->isScalar());
}
""")
def create_one(env, all_types):
    """Render the ``${Type}::copy`` C++ definition for one destination type.

    Args:
        env: template environment of the destination type. Its 'Density' and
            'Backend' keys are read here; the CASE/FUNCTION templates also
            read keys such as 'THTensor', 'Type' and 'Tensor' from it.
        all_types: every type environment. Each one is a candidate source
            type and contributes one ``case`` to the generated switch.

    Returns:
        The text produced by ``FUNCTION.substitute``.

    Sparse copies are not implemented yet. A (dst, src) pair is therefore
    skipped when either side is sparse, which leaves an empty switch body
    for a sparse destination.
    """
    def render_case(src_type):
        # Copies whose *source* lives on the GPU use the ..._copyCuda<Scalar>
        # entry points.
        cuda_suffix = 'Cuda' if src_type['Backend'] == 'CUDA' else ''
        # Any copy that touches a CUDA tensor receives the THC state first.
        touches_cuda = 'CUDA' in (env['Backend'], src_type['Backend'])
        state_args = ['context->thc_state'] if touches_cuda else []
        return CASE.substitute(
            env,
            src_scalar_name=src_type['ScalarName'],
            src_id=src_type['TypeID'],
            src_tensor=src_type['Tensor'],
            cuda=cuda_suffix,
            state=state_args,
        )

    cases = [
        render_case(candidate)
        for candidate in all_types
        if env['Density'] != 'Sparse' and candidate['Density'] != 'Sparse'
    ]
    return FUNCTION.substitute(env, copy_body=cases)
def create(all_types):
    """Render the complete generated C++ file implementing ``copy``.

    For each destination type in ``all_types`` (in order), this adds an
    include of its 'Type' header followed by its 'Tensor' header, and a
    ``${Type}::copy`` definition produced by ``create_one``. The results
    are substituted into the FILE template.

    Returns:
        The text produced by ``FILE.substitute``.
    """
    include_lines = []
    copy_defs = []
    for dst in all_types:
        include_lines.extend(
            '#include "ATen/{}.h"'.format(dst[header_key])
            for header_key in ('Type', 'Tensor'))
        copy_defs.append(create_one(dst, all_types))
    return FILE.substitute({
        'copy_includes': include_lines,
        'copy_functions': copy_defs,
    })