| from code_template import CodeTemplate |
| |
# Top-level template for the whole generated C++ source file.
# ${copy_includes} and ${copy_functions} are filled in by create() below.
#   - ${copy_includes}: the per-destination-type `#include "ATen/<name>.h"` lines.
#   - ${copy_functions}: the `${Type}::copy` definitions rendered by create_one().
# The backslash right after the opening quotes keeps the template from starting
# with a blank line.
# THC/THC.h is only included when AT_CUDA_ENABLED is defined.
# NOTE(review): the `#undef THNN_` just before it presumably avoids a macro
# clash between the TH and THC headers — confirm.
FILE = CodeTemplate("""\
#include "TH/TH.h"
#ifdef AT_CUDA_ENABLED
#undef THNN_
#include "THC/THC.h"
#endif
#include "ATen/Utils.h"
${copy_includes}

namespace at {

${copy_functions}

}
""")
| |
# One `case` arm of the switch emitted inside FUNCTION.
# create_one() renders it once per (destination, source) type pair that it
# does not skip.
# Where the placeholders come from:
#   - ${THTensor}: the destination type's env.
#   - ${src_id}, ${src_scalar_name}, ${src_tensor}: the source type, per case.
#   - ${cuda}: '' or 'Cuda'.
#   - ${state,}: a list that is either empty or ['context->thc_state'].
# NOTE(review): presumably CodeTemplate expands `${state,}` to the list items
# followed by a comma only when the list is non-empty — confirm against
# code_template.
CASE = CodeTemplate("""\
case ${src_id}:
${THTensor}_copy${cuda}${src_scalar_name}(${state,}dst_->tensor,static_cast<${src_tensor}*>(src.pImpl)->tensor);
break;
""")
| |
# Template for the generated `${Type}::copy(src, dst)` method of one
# destination type. The generated method:
#   - casts dst to the concrete ${Tensor} class;
#   - switches on the source type's ID over the ${copy_body} cases that
#     create_one() builds;
#   - raises runtime_error for any source type without a case (for example,
#     the sparse pairs create_one() skips);
#   - finally copies src's scalar flag onto dst.
# NOTE(review): the emitted C++ comment says "generated by function_wrapper",
# but this template is rendered by create_one() in this generator, so that
# text may be stale — confirm before relying on it.
FUNCTION = CodeTemplate("""\
void ${Type}::copy(const Tensor & src, Tensor & dst) {
// code generated by function_wrapper
auto dst_ = checked_cast<${Tensor}>(dst.pImpl,"dst",0,false);
(void) dst_; //silence unused warning
switch(src.type().ID()) {
${copy_body}
default:
runtime_error("copy does not support %s to %s copy.",src.type().toString(),toString());
break;
}
dst.pImpl->setScalar(src.pImpl->isScalar());
}
""")
| |
| |
def create_one(env, all_types):
    """Render the C++ ``${Type}::copy`` definition for one destination type.

    One ``case`` arm (from ``CASE``) is built for every entry of ``all_types``
    that can act as a copy source. The arms are then substituted into
    ``FUNCTION`` as ``copy_body``. A pair where either side is sparse gets no
    arm, so the generated switch falls through to its ``default`` error branch.

    Args:
        env: template environment of the destination type. ``Density`` and
            ``Backend`` are read here, and the whole env is also handed to
            ``CASE``/``FUNCTION`` for their own placeholders.
        all_types: iterable of candidate source-type environments. Each one
            that is not skipped supplies ``ScalarName``, ``TypeID``,
            ``Tensor`` and ``Backend``.

    Returns:
        Whatever ``FUNCTION.substitute`` returns (presumably the rendered C++
        text).
    """
    cases = []
    for src_env in all_types:
        # Sparse copies are not implemented yet, in either direction.
        if env['Density'] == 'Sparse' or src_env['Density'] == 'Sparse':
            continue
        src_on_cuda = src_env['Backend'] == 'CUDA'
        # Pass context->thc_state as a leading argument whenever either side
        # of the copy is on the CUDA backend.
        pass_state = env['Backend'] == 'CUDA' or src_on_cuda
        cases.append(CASE.substitute(
            env,
            src_scalar_name=src_env['ScalarName'],
            src_id=src_env['TypeID'],
            src_tensor=src_env['Tensor'],
            # A CUDA source selects the `_copyCuda<Scalar>` spelling.
            cuda='Cuda' if src_on_cuda else '',
            state=['context->thc_state'] if pass_state else [],
        ))
    return FUNCTION.substitute(env, copy_body=cases)
| |
| |
def create(all_types):
    """Render the complete generated C++ source containing every ``copy`` method.

    Each entry of ``all_types`` is treated as a destination type. For each one,
    this collects two ``#include`` lines (its ``Type`` header, then its
    ``Tensor`` header) and the ``${Type}::copy`` definition rendered by
    ``create_one``. Both lists are then substituted into ``FILE``.

    Args:
        all_types: sequence of type environments (dicts). Each one is used
            here as a destination and, inside ``create_one``, as a candidate
            source.

    Returns:
        Whatever ``FILE.substitute`` returns (presumably the rendered source
        text).
    """
    includes = []
    functions = []
    for dst_env in all_types:
        # Includes go in Type-then-Tensor order, before the function for the
        # same type is rendered.
        for header_key in ('Type', 'Tensor'):
            includes.append('#include "ATen/{}.h"'.format(dst_env[header_key]))
        functions.append(create_one(dst_env, all_types))
    return FILE.substitute({
        'copy_includes': includes,
        'copy_functions': functions,
    })