#pragma once

#include <c10/core/ScalarType.h>
#include <c10/macros/Macros.h>
#include <c10/util/Load.h>
#include <c10/util/TypeCast.h>

namespace c10 {

// Dynamic type casting utils:
// - fetch_and_cast
// - cast_and_store
//
// fetch_and_cast fetches a value whose dynamic type is specified by a
// ScalarType from a void pointer and casts it to a static type.
//
// cast_and_store casts a statically typed value into the dynamic type
// specified by a ScalarType and stores it into a void pointer.
//
// NOTE:
//
// Dynamic casting allows us to support type promotion without blowing up
// the combination space: For example, without dynamic cast, in order to
// implement `add_` with type promotion, we would need something like
//
// AT_DISPATCH_ALL_TYPES(output.dtype(),
//    AT_DISPATCH_ALL_TYPES(input1.dtype(),
//       AT_DISPATCH_ALL_TYPES(input2.dtype(),
//           [](arg0_t a, arg1_t b) -> out_t { return a + b; }
//       )
//    )
// )
//
// If we support N dtypes, the above code would generate the a+b kernel for
// all N * N * N combinations of supported types, and the compilation time
// and binary size would become horrible. With dynamic casting we only compile
// the kernel once; see the sketch at the end of this comment block.
//
// Dynamic casting might sound like a bad idea in terms of performance,
// especially if you do it inside a loop: you end up doing billions of type
// checks. But in practice it is not as bad as it might look:
//
// - on CPU, this is a branch that always has the same outcome, therefore
//   hopefully the branch predictor could do the job pretty well
// - on GPU, these branches will not diverge, so we could still have the same
//   warp executing the same line of code
// - most kernels, like `add`, are bandwidth bound; adding a few clock cycles
//   to check an integer does not hurt performance much, because the ALUs
//   would be waiting on load instructions anyway.
//
// For the discussion and benchmark, refer to:
// - https://github.com/pytorch/pytorch/pull/28343
// - https://github.com/pytorch/pytorch/pull/28344
// - https://github.com/pytorch/pytorch/pull/28345
//
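// As a rough sketch only (not the actual PyTorch implementation; the helper
// name add_loop and the use of float as the computation type are purely
// illustrative), a single type-promoting loop built on fetch_and_cast and
// cast_and_store below could look like:
//
//   void add_loop(char* out, const char* in1, const char* in2, int64_t n,
//                 ScalarType out_t, ScalarType in1_t, ScalarType in2_t) {
//     for (int64_t i = 0; i < n; i++) {
//       float a = fetch_and_cast<float>(in1_t, in1 + i * elementSize(in1_t));
//       float b = fetch_and_cast<float>(in2_t, in2 + i * elementSize(in2_t));
//       cast_and_store<float>(out_t, out + i * elementSize(out_t), a + b);
//     }
//   }
//
// Only this one loop needs to be compiled; the N * N * N dtype combinations
// are handled by the runtime switches inside fetch_and_cast and
// cast_and_store.
//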

#ifdef C10_HOST_DEVICE
#define ERROR_UNSUPPORTED_CAST CUDA_KERNEL_ASSERT(false);
#else
#define ERROR_UNSUPPORTED_CAST TORCH_CHECK(false, "Unexpected scalar type");
#endif

// Fetch a value with dynamic type src_type from ptr, and cast it to static type
// dest_t.
#define FETCH_AND_CAST_CASE(type, scalartype) \
  case ScalarType::scalartype:                \
    return c10::convert<dest_t>(c10::load<type>(ptr));
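// For example, FETCH_AND_CAST_CASE(uint32_t, UInt32) expands to
//
//   case ScalarType::UInt32:
//     return c10::convert<dest_t>(c10::load<uint32_t>(ptr));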

template <typename dest_t>
C10_HOST_DEVICE inline dest_t fetch_and_cast(
    const ScalarType src_type,
    const void* ptr) {
  switch (src_type) {
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(FETCH_AND_CAST_CASE)
    FETCH_AND_CAST_CASE(uint16_t, UInt16)
    FETCH_AND_CAST_CASE(uint32_t, UInt32)
    FETCH_AND_CAST_CASE(uint64_t, UInt64)
    default:
      ERROR_UNSUPPORTED_CAST
  }
  return dest_t(0); // just to avoid compiler warning
}

// Cast a value with static type src_t into dynamic dest_type, and store it to
// ptr.
#define CAST_AND_STORE_CASE(type, scalartype) \
  case ScalarType::scalartype:                \
    *(type*)ptr = c10::convert<type>(value);  \
    return;
template <typename src_t>
C10_HOST_DEVICE inline void cast_and_store(
    const ScalarType dest_type,
    void* ptr,
    src_t value) {
  switch (dest_type) {
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(CAST_AND_STORE_CASE)
    CAST_AND_STORE_CASE(uint16_t, UInt16)
    CAST_AND_STORE_CASE(uint32_t, UInt32)
    CAST_AND_STORE_CASE(uint64_t, UInt64)
    default:;
  }
  ERROR_UNSUPPORTED_CAST
}
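// Illustrative round trip through the two helpers above (the variable names
// and the choice of Half are made up for this example): store a float into a
// buffer whose dtype is only known at runtime, then read it back as double.
//
//   c10::Half storage;
//   c10::cast_and_store<float>(c10::ScalarType::Half, &storage, 1.5f);
//   double x = c10::fetch_and_cast<double>(c10::ScalarType::Half, &storage);
//   // x == 1.5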

#define DEFINE_UNCASTABLE(T, scalartype_)                     \
  template <>                                                 \
  C10_HOST_DEVICE inline T fetch_and_cast<T>(                 \
      const ScalarType src_type, const void* ptr) {           \
    CUDA_KERNEL_ASSERT(ScalarType::scalartype_ == src_type);  \
    return c10::load<T>(ptr);                                 \
  }                                                           \
  template <>                                                 \
  C10_HOST_DEVICE inline void cast_and_store<T>(              \
      const ScalarType dest_type, void* ptr, T value) {       \
    CUDA_KERNEL_ASSERT(ScalarType::scalartype_ == dest_type); \
    *(T*)ptr = value;                                         \
  }

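// Quantized types cannot be dynamically cast: the specializations generated
// here only allow loading/storing the matching type and assert that the
// runtime ScalarType agrees.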
AT_FORALL_QINT_TYPES(DEFINE_UNCASTABLE)

#undef FETCH_AND_CAST_CASE
#undef CAST_AND_STORE_CASE
#undef DEFINE_UNCASTABLE
#undef ERROR_UNSUPPORTED_CAST

} // namespace c10