Switch TLIB_DISPATCH_* macros to a template-based tlib::dispatch<F>() function
diff --git a/aten/src/aten/CMakeLists.txt b/aten/src/aten/CMakeLists.txt
index 47904b2..0c92a88 100644
--- a/aten/src/aten/CMakeLists.txt
+++ b/aten/src/aten/CMakeLists.txt
@@ -59,7 +59,6 @@
FIND_PACKAGE(CUDA 5.5)
IF(CUDA_FOUND)
INCLUDE_DIRECTORIES(${CUDA_INCLUDE_DIRS})
- INCLUDE_DIRECTORIES("${CUDA_SDK_ROOT_DIR}/common/inc")
IF(NOT THC_LIBRARIES)
SET(THC_LIBRARIES "THC")
@@ -81,7 +80,6 @@
ENDIF(NOT THCUNN_LIBRARIES)
MESSAGE(STATUS "THCUNN_LIBRARIES: ${THCUNN_LIBRARIES}")
-
ENDIF()
ENDIF()
@@ -142,7 +140,6 @@
if (NOT RETURN_VALUE EQUAL 0)
message(STATUS ${generated_cpp})
message(FATAL_ERROR "Failed to get generated_cpp list")
-
endif()
FILE(GLOB_RECURSE all_templates "templates/*")
diff --git a/aten/src/aten/Context.cpp b/aten/src/aten/Context.cpp
index b2b7ad1..90056a0 100644
--- a/aten/src/aten/Context.cpp
+++ b/aten/src/aten/Context.cpp
@@ -1,6 +1,7 @@
#include "Context.h"
#include <thread>
#include <mutex>
+#include <sstream>
#ifdef TENSORLIB_CUDA_ENABLED
#include "THC/THC.h"
@@ -13,10 +14,16 @@
static inline void errorHandler(const char * msg, void * data) {
throw std::runtime_error(msg);
}
+static inline void argErrorHandler(int arg, const char * msg, void * data) {
+ std::stringstream new_error;
+ new_error << "invalid argument " << arg << ": " << msg;
+ throw std::runtime_error(new_error.str());
+}
Context::Context() {
THSetDefaultErrorHandler(errorHandler,nullptr);
+ THSetDefaultArgErrorHandler(argErrorHandler,nullptr);
#ifdef TENSORLIB_CUDA_ENABLED
thc_state = THCState_alloc();
diff --git a/aten/src/aten/dispatch_macros.py b/aten/src/aten/dispatch_macros.py
index 16c5d4b..d3393fd 100644
--- a/aten/src/aten/dispatch_macros.py
+++ b/aten/src/aten/dispatch_macros.py
@@ -2,39 +2,32 @@
CASE_TEMPLATE = CodeTemplate("""\
case ${TypeID}:
- the_function<${specializations}>(the_type,__VA_ARGS__);
- break;
+ return F<${ScalarType}>::${Backend}(the_type,std::forward<Args>(args)...);
""")
MACRO_TEMPLATE = CodeTemplate("""\
-#define ${macro_name}(the_type,the_function,...)
+#pragma once
+
+namespace tlib {
+
+template<template <typename> class F, typename ... Args>
+auto dispatch(const Type & the_type, Args&&... args)
+ -> decltype(F<double>::CPU(the_type,std::forward<Args>(args)...)) {
switch(the_type.ID()) {
${cases}
}
+}
+
+}
""")
-def create_dispatch(all_types, include_type, include_backend):
+def create_dispatch(all_types):
cases = []
- macro_name = "TLIB_DISPATCH"
- if include_type:
- macro_name += "_TYPE"
- if include_backend:
- macro_name += "_PROCESSOR"
for typ in all_types:
- specializations = []
- if include_type:
- specializations.append(typ['ScalarType'])
- if include_backend:
- specializations.append('tlib::{}Tag'.format(typ['Backend']))
- cases.append(CASE_TEMPLATE.substitute(
- typ, specializations=specializations))
- the_macro = MACRO_TEMPLATE.substitute(macro_name=macro_name, cases=cases)
- # end lines in backslashes to make defines
- return '\\\n'.join(the_macro.split('\n')) + '\n'
+ cases.append(CASE_TEMPLATE.substitute(typ))
+ return MACRO_TEMPLATE.substitute(cases=cases)
def create(all_types):
- return "#pragma once\n\n" + (create_dispatch(all_types, True, False) +
- create_dispatch(all_types, False, True) +
- create_dispatch(all_types, True, True))
+ return create_dispatch(all_types)
diff --git a/aten/src/aten/templates/Tensor.h b/aten/src/aten/templates/Tensor.h
index 5a1fe27..3aaeea5 100644
--- a/aten/src/aten/templates/Tensor.h
+++ b/aten/src/aten/templates/Tensor.h
@@ -94,6 +94,9 @@
Tensor toBackend(Backend b) {
return toType(type().toBackend(b));
}
+ int64_t dim() const {
+ return ndimension();
+ }
template<typename T>
T * data() const;
diff --git a/aten/src/aten/test/scalar_test.cpp b/aten/src/aten/test/scalar_test.cpp
index 1402134..9d4118f 100644
--- a/aten/src/aten/test/scalar_test.cpp
+++ b/aten/src/aten/test/scalar_test.cpp
@@ -9,13 +9,20 @@
constexpr auto Double = ScalarType::Float;
template<typename scalar_type>
-void foo(const Type & t, Tensor a, Tensor b) {
- scalar_type s = 1;
- cout << "hello, dispatch: " << t.toString() << s << "\n";
- auto data = (scalar_type*)a.data_ptr();
-}
+struct Foo {
+ static void CPU(const Type & t, Tensor a, Tensor b) {
+ scalar_type s = 1;
+ cout << "hello, dispatch: " << t.toString() << s << "\n";
+ auto data = (scalar_type*)a.data_ptr();
+ }
+ static void CUDA(const Type & t, Tensor a, Tensor b) {
+ }
+};
template<>
-void foo<Half>(const Type & t, Tensor a, Tensor b) {}
+struct Foo<Half> {
+ static void CPU(const Type & t, Tensor a, Tensor b) {}
+ static void CUDA(const Type & t, Tensor a, Tensor b) {}
+};
int main() {
Scalar what = 257;
@@ -59,8 +66,7 @@
cout << r << "\n";
cout << T.randn({10,10,2}) << "\n";
- TLIB_DISPATCH_TYPE(x.type(),foo,x,prev_h);
-
+ dispatch<Foo>(x.type(),x,prev_h);
return 0;
}