# blob: 2bb906120c06c791cea8402f20bd8febfa12d50d
from code_template import CodeTemplate
# C++ text for one `case` label of the generated dispatch switch.
# create_dispatch() calls substitute() on this template once per non-sparse
# type dict; that dict is expected to provide the ${TypeID}, ${ScalarType}
# and ${Backend} placeholders (exact substitution rules live in CodeTemplate).
# The backslash right after the opening quotes is a line continuation, so the
# template text starts at "case" rather than with an empty line.
CASE_TEMPLATE = CodeTemplate("""\
case ${TypeID}:
return F<${ScalarType}>::${Backend}(the_type,std::forward<Args>(args)...);
""")
# C++ header text declaring `dispatch` inside namespace `at`.
# Despite the name, the generated code is a function template (not a
# preprocessor macro): it switches on `the_type.ID()` and each case forwards
# the arguments to `F<ScalarType>::Backend(...)`. The declared return type is
# taken from `F<double>::CPU(...)`.
# ${cases} is filled by create_dispatch() with the list of rendered
# CASE_TEMPLATE entries; any type ID without a case reaches `default:` and
# calls `runtime_error` with the type's string name.
MACRO_TEMPLATE = CodeTemplate("""\
#pragma once
namespace at {
template<template <typename> class F, typename ... Args>
auto dispatch(const Type & the_type, Args&&... args)
-> decltype(F<double>::CPU(the_type,std::forward<Args>(args)...)) {
switch(the_type.ID()) {
${cases}
default:
runtime_error("dispatch() not implemented for '%s'",the_type.toString());
}
}
}
""")
def create_dispatch(all_types):
    """Render the dispatch header text for every non-sparse type.

    Args:
        all_types: iterable of type-description mappings. Each one must have
            a 'Density' key; entries whose density is 'Sparse' are skipped,
            and the rest are passed to CASE_TEMPLATE.substitute().

    Returns:
        Whatever MACRO_TEMPLATE.substitute() produces when given the list of
        rendered cases as ``cases``.
    """
    dense_cases = [
        CASE_TEMPLATE.substitute(type_info)
        for type_info in all_types
        if type_info['Density'] != 'Sparse'
    ]
    return MACRO_TEMPLATE.substitute(cases=dense_cases)
def create(all_types):
    """Build the generated dispatch source for ``all_types``.

    Thin entry point: hands ``all_types`` to create_dispatch() and returns
    its result unchanged.
    """
    dispatch_source = create_dispatch(all_types)
    return dispatch_source