|  | #include "caffe2/operators/batch_matmul_op.h" | 
|  |  | 
|  | #include "caffe2/core/context_gpu.h" | 
|  |  | 
|  | namespace caffe2 { | 
|  |  | 
|  | template <> | 
|  | bool BatchMatMulOp<CUDAContext, DefaultEngine>::RunOnDevice() { | 
|  | return DispatchHelper<TensorTypes<float, at::Half>>::call(this, Input(0)); | 
|  | } | 
|  |  | 
|  | REGISTER_CUDA_OPERATOR(BatchMatMul, BatchMatMulOp<CUDAContext>); | 
|  |  | 
|  |  | 
|  | #if !defined(USE_ROCM) | 
|  |  | 
|  | template <> | 
|  | bool BatchMatMulOp<CUDAContext, TensorCoreEngine>::RunOnDevice() { | 
|  | return DispatchHelper<TensorTypes<float, at::Half>>::call(this, Input(0)); | 
|  | } | 
|  |  | 
|  | REGISTER_CUDA_OPERATOR_WITH_ENGINE( | 
|  | BatchMatMul, | 
|  | TENSORCORE, | 
|  | BatchMatMulOp<CUDAContext, TensorCoreEngine>); | 
|  |  | 
|  | #endif | 
|  |  | 
|  | } // namespace caffe2 |