blob: d2ba120f02e972b744ed25116badd5e422fd397d [file] [log] [blame]
from . import CWrapPlugin
from string import Template
class AssertNDim(CWrapPlugin):
PRE_CODE_TEMPLATE = Template(
"""if(THTensor_(nDimension)(LIBRARY_STATE ${arg_op}) != ${dim_value}) {
THError("Expected argument %s to have %d dimension(s), but has %d",
"${op}", ${dim_value}, THTensor_(nDimension)(LIBRARY_STATE ${arg_op}));
}
""")
def process_option_code_template(self, template, option):
new_code_pre = []
for _, arg in enumerate(option['arguments']):
if 'assert_ndim' not in arg:
continue
dim_value = arg.get('assert_ndim')
op = arg.get('assign_name', arg['name'])
arg_op = "arg_" + op
new_code_pre.append(self.PRE_CODE_TEMPLATE.substitute(op=op,
arg_op=arg_op,
dim_value=dim_value))
template = new_code_pre + template
return template