fp16 support for FullyConnected op (Fixed)
Summary: This diff resolves the issues in the reverted PR246 and re-lands fp16 support for the FullyConnected op. It also adds fp16 paths for the CUDA Add, Sum, and WeightedSum operators, introduces caffe2/utils/conversions.h for float<->float16 conversion, and extends math::Gemm/Gemv with a math_type argument (with float alpha/beta) so fp16 tensors can be computed with fp32 math.
Differential Revision: D4911821
fbshipit-source-id: 0a6fa47f4c2405475697e40fb926758c534f8ef7
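
For reviewers, a minimal sketch (not part of this diff, and assuming a Caffe2 build so the header is available) of how the new convert::To helpers from caffe2/utils/conversions.h are meant to be used: the fp16 operator paths below store data as float16 but do their arithmetic in fp32 through these conversions.

    // Hypothetical standalone example; only convert::To and caffe2::float16
    // come from this diff, everything else is illustrative.
    #include <iostream>
    #include "caffe2/utils/conversions.h"  // convert::To, pulls in caffe2::float16

    int main() {
      using namespace caffe2;

      // Store values in half precision ...
      float16 ha = convert::To<float, float16>(1.5f);
      float16 hb = convert::To<float, float16>(0.25f);

      // ... but accumulate in fp32, mirroring the MATH template parameter
      // used by the fp16 Add/Sum/FC paths in this diff.
      float sum = convert::To<float16, float>(ha) + convert::To<float16, float>(hb);
      float16 hsum = convert::To<float, float16>(sum);

      std::cout << convert::To<float16, float>(hsum) << std::endl;  // prints 1.75
      return 0;
    }
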
diff --git a/caffe2/contrib/nervana/nervana_fc_op_gpu.cc b/caffe2/contrib/nervana/nervana_fc_op_gpu.cc
index b232850..8d33a7c 100644
--- a/caffe2/contrib/nervana/nervana_fc_op_gpu.cc
+++ b/caffe2/contrib/nervana/nervana_fc_op_gpu.cc
@@ -5,8 +5,11 @@
namespace caffe2 {
REGISTER_CUDA_OPERATOR_WITH_ENGINE(
- FC, NERVANA, FullyConnectedOp<float, CUDAContext, NervanaEngine>);
+ FC,
+ NERVANA,
+ FullyConnectedOp<CUDAContext, NervanaEngine>);
REGISTER_CUDA_OPERATOR_WITH_ENGINE(
- FCGradient, NERVANA,
- FullyConnectedGradientOp<float, CUDAContext, NervanaEngine>);
+ FCGradient,
+ NERVANA,
+ FullyConnectedGradientOp<CUDAContext, NervanaEngine>);
} // namespace caffe2
diff --git a/caffe2/contrib/nervana/nervana_fc_op_gpu_test.cc b/caffe2/contrib/nervana/nervana_fc_op_gpu_test.cc
index a3ae3bb..3eb0fc3 100644
--- a/caffe2/contrib/nervana/nervana_fc_op_gpu_test.cc
+++ b/caffe2/contrib/nervana/nervana_fc_op_gpu_test.cc
@@ -49,7 +49,7 @@
AddConstInput(std::vector<int>{6, 10}, 1., "W", &ws);
AddConstInput(std::vector<int>{6}, 0.1, "B", &ws);
unique_ptr<OperatorBase> op(
- new FullyConnectedOp<float, CUDAContext, NervanaEngine>(def, &ws));
+ new FullyConnectedOp<CUDAContext, NervanaEngine>(def, &ws));
EXPECT_NE(nullptr, op.get());
EXPECT_TRUE(op->Run());
Blob* Yblob = ws.GetBlob("Y");
diff --git a/caffe2/contrib/nervana/nervana_math_gpu.cc b/caffe2/contrib/nervana/nervana_math_gpu.cc
index f3010b9..09c70e4 100644
--- a/caffe2/contrib/nervana/nervana_math_gpu.cc
+++ b/caffe2/contrib/nervana/nervana_math_gpu.cc
@@ -11,10 +11,18 @@
// limitation that the data has to be contiguous in memory.
template <>
void Gemm<float, CUDAContext, NervanaEngine>(
- const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
- const int M, const int N, const int K, const float alpha, const float* A,
- const float* B, const float beta, float* C, CUDAContext* context) {
-
+ const CBLAS_TRANSPOSE TransA,
+ const CBLAS_TRANSPOSE TransB,
+ const int M,
+ const int N,
+ const int K,
+ const float alpha,
+ const float* A,
+ const float* B,
+ const float beta,
+ float* C,
+ CUDAContext* context,
+ TensorProto::DataType math_type) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (TransA == CblasNoTrans) ? K : M;
diff --git a/caffe2/operators/elementwise_op.cu b/caffe2/operators/elementwise_op.cu
index 016a555..e4c89a3 100644
--- a/caffe2/operators/elementwise_op.cu
+++ b/caffe2/operators/elementwise_op.cu
@@ -5,6 +5,7 @@
#include "caffe2/core/common_gpu.h"
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/elementwise_op.h"
+#include "caffe2/utils/conversions.h"
namespace caffe2 {
@@ -62,9 +63,6 @@
name, BinaryElementwiseOp< \
input_type, CUDAContext, Cuda##name##Functor, output_type>)
-#define CUDA_ADD(x, y) ((x) + (y))
-CUDA_FUNCTOR(Add, CUDA_ADD, NumericTypes, SameTypeAsInput);
-#undef CUDA_ADD
#define CUDA_SUB(x, y) ((x) - (y))
CUDA_FUNCTOR(Sub, CUDA_SUB, NumericTypes, SameTypeAsInput);
#undef CUDA_SUB
@@ -264,4 +262,165 @@
REGISTER_CUDA_OPERATOR(SumReduceLike, SumReduceLikeOp<CUDAContext>);
+namespace {
+
+template <bool is_scaler, typename T, typename M>
+__global__ void binary_add_kernel(const int N, const T* a, const T* b, T* r) {
+ CUDA_1D_KERNEL_LOOP(idx, N) {
+ r[idx] = convert::To<M, T>(
+ convert::To<T, M>(a[idx]) +
+ convert::To<T, M>(is_scaler ? b[0] : b[idx]));
+ }
+}
+
+template <bool no_post, typename T, typename M>
+__global__ void binary_add_kernel_broadcast(
+ const T* a,
+ const T* b,
+ T* r,
+ const int pre,
+ const int post,
+ const int n) {
+ CUDA_1D_KERNEL_LOOP(idx, no_post ? pre * n : pre * post * n) {
+ r[idx] = convert::To<M, T>(
+ convert::To<T, M>(a[idx]) +
+ convert::To<T, M>(no_post ? b[idx % n] : b[(idx / post) % n]));
+ }
+}
+} // namespace
+
+// Actual Add operator, because the above macros are read-only.
+class CUDAAddOp final : public Operator<CUDAContext> {
+ public:
+ CUDAAddOp(const OperatorDef& operator_def, Workspace* ws)
+ : Operator<CUDAContext>(operator_def, ws),
+ OP_SINGLE_ARG(bool, "broadcast", enable_broadcast_, 0),
+ OP_SINGLE_ARG(int, "axis", axis_, -1),
+ OP_SINGLE_ARG(string, "axis_str", axis_str_, ""),
+ OP_SINGLE_ARG(string, "order", order_, "NCHW") {
+ // Figure out the correct axis to use.
+ if (enable_broadcast_) {
+ if (axis_ != -1) {
+ // Get axis from an explicit axis argument.
+ CAFFE_ENFORCE_EQ(
+ axis_str_.size(),
+ 0,
+ "Args axis and axis_str cannot be used simultaneously.");
+ } else if (axis_str_.size()) {
+ // Get the axis index semantically.
+ CAFFE_ENFORCE_EQ(
+ axis_str_.size(), 1, "Unsupported axis string", axis_str_);
+ size_t semantic_axis_ = order_.find(axis_str_);
+ CAFFE_ENFORCE_NE(
+ semantic_axis_,
+ string::npos,
+ "Unrecognizable axis string ",
+ axis_str_,
+ " from order string ",
+ order_);
+ axis_ = semantic_axis_;
+ }
+ } else {
+ CAFFE_ENFORCE(
+ axis_ == -1 && axis_str_.size() == 0,
+ "Do not specify axis or axis_str if broadcast is not enabled.");
+ }
+ }
+
+ ~CUDAAddOp() {}
+
+ template <typename T, typename M>
+ bool DoRunWithType() {
+ auto& X0 = Input(0);
+ auto& X1 = Input(1);
+ auto* output = Output(0);
+
+ output->ResizeLike(X0);
+
+ const T* X0data = X0.template data<T>();
+ const T* X1data = X1.template data<T>();
+ T* outputData = output->template mutable_data<T>();
+
+ if (!enable_broadcast_) {
+ CAFFE_ENFORCE_EQ(
+ X0.dims(),
+ X1.dims(),
+ "Dimension mismatch - did you forget to set broadcast=1?");
+ binary_add_kernel<false, T, M><<<
+ CAFFE_GET_BLOCKS(X0.size()),
+ CAFFE_CUDA_NUM_THREADS,
+ 0,
+ context_.cuda_stream()>>>(X0.size(), X0data, X1data, outputData);
+ } else if (X1.size() == 1) {
+ binary_add_kernel<true, T, M><<<
+ CAFFE_GET_BLOCKS(X0.size()),
+ CAFFE_CUDA_NUM_THREADS,
+ 0,
+ context_.cuda_stream()>>>(X0.size(), X0data, X1data, outputData);
+ } else {
+ CAFFE_ENFORCE_GT(
+ X0.ndim(),
+ X1.ndim(),
+ "If you are doing broadcasting, input1 should have "
+ "a smaller number of dimensions.");
+ const int axis = (axis_ == -1 ? X0.ndim() - X1.ndim() : axis_);
+ CAFFE_ENFORCE(
+ axis >= 0 && axis < X0.ndim(),
+ "Broadcast axis should be in the range of the number "
+ "of dimensions of the first input.");
+ size_t pre = 1, n = 1, post = 1;
+ for (int i = 0; i < axis; ++i) {
+ pre *= X0.dim(i);
+ }
+ for (int i = 0; i < X1.ndim(); ++i) {
+ CAFFE_ENFORCE_EQ(
+ X0.dim(i + axis), X1.dim(i), "Broadcast dimension mismatch.");
+ n *= X1.dim(i);
+ }
+ for (int i = axis + X1.ndim(); i < X0.ndim(); ++i) {
+ post *= X0.dim(i);
+ }
+
+ if (post == 1) {
+ binary_add_kernel_broadcast<true, T, M><<<
+ CAFFE_GET_BLOCKS(pre * n),
+ CAFFE_CUDA_NUM_THREADS,
+ 0,
+ context_.cuda_stream()>>>(X0data, X1data, outputData, pre, post, n);
+ } else {
+ binary_add_kernel_broadcast<false, T, M><<<
+ CAFFE_GET_BLOCKS(pre * post * n),
+ CAFFE_CUDA_NUM_THREADS,
+ 0,
+ context_.cuda_stream()>>>(X0data, X1data, outputData, pre, post, n);
+ }
+ }
+ return true;
+ }
+
+ bool RunOnDevice() override {
+ if (Input(0).IsType<float>()) {
+ return DoRunWithType<float, float>();
+ } else if (Input(0).IsType<float16>()) {
+ return DoRunWithType<float16, float>();
+ } else if (Input(0).IsType<int32_t>()) {
+ return DoRunWithType<int32_t, int32_t>();
+ } else if (Input(0).IsType<int64_t>()) {
+ return DoRunWithType<int64_t, int64_t>();
+ } else {
+ return false;
+ }
+ }
+
+ private:
+ bool enable_broadcast_;
+ int axis_;
+ string axis_str_;
+ string order_;
+};
+
+namespace {
+REGISTER_CUDA_OPERATOR(Add, CUDAAddOp);
+} // namespace
+
} // namespace caffe2
diff --git a/caffe2/operators/fully_connected_op.cc b/caffe2/operators/fully_connected_op.cc
index 7a0e0bc..c00f199 100644
--- a/caffe2/operators/fully_connected_op.cc
+++ b/caffe2/operators/fully_connected_op.cc
@@ -3,8 +3,8 @@
namespace caffe2 {
namespace {
-REGISTER_CPU_OPERATOR(FC, FullyConnectedOp<float, CPUContext>);
-REGISTER_CPU_OPERATOR(FCGradient, FullyConnectedGradientOp<float, CPUContext>);
+REGISTER_CPU_OPERATOR(FC, FullyConnectedOp<CPUContext>);
+REGISTER_CPU_OPERATOR(FCGradient, FullyConnectedGradientOp<CPUContext>);
OPERATOR_SCHEMA(FC)
.NumInputs(3)
diff --git a/caffe2/operators/fully_connected_op.h b/caffe2/operators/fully_connected_op.h
index 6e24c91..45adc0a 100644
--- a/caffe2/operators/fully_connected_op.h
+++ b/caffe2/operators/fully_connected_op.h
@@ -3,12 +3,13 @@
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
+#include "caffe2/utils/conversions.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
// This is Caffe's InnerProductOp, with a name that fits its purpose better.
-template <typename T, class Context, class Engine = DefaultEngine>
+template <class Context, class Engine = DefaultEngine>
class FullyConnectedOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
@@ -17,7 +18,13 @@
axis_(OperatorBase::GetSingleArgument<int32_t>("axis", 1)) {}
~FullyConnectedOp() {}
- bool RunOnDevice() override {
+ template <
+ typename T_X,
+ typename T_W,
+ typename T_B,
+ typename T_Y,
+ typename MATH>
+ bool DoRunWithType() {
const auto& X = Input(0);
const auto& W = Input(1);
const auto& b = Input(2);
@@ -63,44 +70,53 @@
Y->Resize(Y_shape_cache_);
CAFFE_ENFORCE(M * N == Y->size(), dimErrorString());
- // X * W^T
- math::Gemm<T, Context, Engine>(
+ // W * x
+ math::Gemm<T_X, Context, Engine>(
CblasNoTrans,
CblasTrans,
M,
N,
K,
1,
- X.template data<T>(),
- W.template data<T>(),
+ X.template data<T_X>(),
+ W.template data<T_W>(),
0,
- Y->template mutable_data<T>(),
+ Y->template mutable_data<T_Y>(),
&context_);
// Add bias term
if (bias_multiplier_.size() != M) {
// If the helper bias multiplier is not M, reshape and fill it with one.
bias_multiplier_.Resize(M);
- math::Set<T, Context>(
+ math::Set<T_B, Context>(
M,
- static_cast<T>(1),
- bias_multiplier_.template mutable_data<T>(),
+ convert::To<float, T_B>(1),
+ bias_multiplier_.template mutable_data<T_B>(),
&context_);
}
- math::Gemm<T, Context, Engine>(
+ math::Gemm<T_B, Context, Engine>(
CblasNoTrans,
CblasNoTrans,
M,
N,
1,
1,
- bias_multiplier_.template data<T>(),
- b.template data<T>(),
+ bias_multiplier_.template data<T_B>(),
+ b.template data<T_B>(),
1,
- Y->template mutable_data<T>(),
+ Y->template mutable_data<T_Y>(),
&context_);
return true;
}
+ bool RunOnDevice() override {
+ return DoRunWithType<
+ float, // X
+ float, // W
+ float, // B
+ float, // Y
+ float>(); // Math
+ }
+
protected:
size_t axis_{1};
// A local vector to cache the output shape so we don't need to recreate
@@ -109,7 +125,7 @@
Tensor<Context> bias_multiplier_;
};
-template <typename T, class Context, class Engine = DefaultEngine>
+template <class Context, class Engine = DefaultEngine>
class FullyConnectedGradientOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
@@ -118,7 +134,16 @@
axis_(OperatorBase::GetSingleArgument<int32_t>("axis", 1)) {}
~FullyConnectedGradientOp() {}
- bool RunOnDevice() override {
+ template <
+ typename T_X,
+ typename T_W,
+ typename T_DY,
+ typename T_B,
+ typename T_DX,
+ typename T_DW,
+ typename T_DB,
+ typename MATH>
+ bool DoRunWithType() {
const auto& X = Input(0);
const auto& W = Input(1);
const auto& dY = Input(2);
@@ -137,60 +162,72 @@
db->Resize(N);
// Compute dW
- math::Gemm<T, Context, Engine>(
+ math::Gemm<T_DY, Context, Engine>(
CblasTrans,
CblasNoTrans,
N,
K,
M,
- 1,
- dY.template data<T>(),
- X.template data<T>(),
- 0,
- dW->template mutable_data<T>(),
+ convert::To<float, MATH>(1),
+ dY.template data<T_DY>(),
+ X.template data<T_X>(),
+ convert::To<float, MATH>(0),
+ dW->template mutable_data<T_DW>(),
&context_);
if (bias_multiplier_.size() != M) {
// If the helper bias multiplier is not M, reshape and fill it
// with one.
bias_multiplier_.Resize(M);
- math::Set<T, Context>(
+ math::Set<T_B, Context>(
M,
- static_cast<T>(1),
- bias_multiplier_.template mutable_data<T>(),
+ convert::To<float, T_B>(1),
+ bias_multiplier_.template mutable_data<T_B>(),
&context_);
}
// Compute dB
- math::Gemv<T, Context>(
+ math::Gemv<T_DY, Context>(
CblasTrans,
M,
N,
- 1,
- dY.template data<T>(),
- bias_multiplier_.template data<T>(),
- 0,
- db->template mutable_data<T>(),
+ convert::To<float, MATH>(1),
+ dY.template data<T_DY>(),
+ bias_multiplier_.template data<T_B>(),
+ convert::To<float, MATH>(0),
+ db->template mutable_data<T_DB>(),
&context_);
// Compute dX
if (OutputSize() == 3) {
auto* dX = Output(2);
dX->ResizeLike(X);
- math::Gemm<T, Context, Engine>(
+ math::Gemm<T_DX, Context, Engine>(
CblasNoTrans,
CblasNoTrans,
M,
K,
N,
- 1,
- dY.template data<T>(),
- W.template data<T>(),
- 0,
- dX->template mutable_data<T>(),
+ convert::To<float, MATH>(1),
+ dY.template data<T_DY>(),
+ W.template data<T_W>(),
+ convert::To<float, MATH>(0),
+ dX->template mutable_data<T_DX>(),
&context_);
}
return true;
}
+ bool RunOnDevice() override {
+ return DoRunWithType<
+ float, // X
+ float, // W
+ float, // dY
+ float, // B
+ float, // dX
+ float, // dW
+ float, // dB
+ float>(); // Math
+ }
+
protected:
size_t axis_{1};
Tensor<Context> bias_multiplier_;
diff --git a/caffe2/operators/fully_connected_op_gpu.cc b/caffe2/operators/fully_connected_op_gpu.cc
index 8ee67ac..0743186 100644
--- a/caffe2/operators/fully_connected_op_gpu.cc
+++ b/caffe2/operators/fully_connected_op_gpu.cc
@@ -2,9 +2,60 @@
#include "caffe2/operators/fully_connected_op.h"
namespace caffe2 {
+
+template <>
+bool FullyConnectedOp<CUDAContext>::RunOnDevice() {
+ if (Input(0).IsType<float>()) {
+ return DoRunWithType<
+ float, // X
+ float, // W
+ float, // B
+ float, // Y
+ float>(); // Math
+ } else if (Input(0).IsType<float16>()) {
+ return DoRunWithType<
+ float16, // X
+ float16, // W
+ float16, // B
+ float16, // Y
+ float>(); // Math
+ } else {
+ CAFFE_THROW("Unsupported type");
+ }
+ return false;
+}
+
+template <>
+bool FullyConnectedGradientOp<CUDAContext>::RunOnDevice() {
+ if (Input(0).IsType<float>()) {
+ return DoRunWithType<
+ float, // X
+ float, // W
+ float, // dY
+ float, // B
+ float, // dX
+ float, // dW
+ float, // dB
+ float>(); // Math
+ } else if (Input(0).IsType<float16>()) {
+ return DoRunWithType<
+ float16, // X
+ float16, // W
+ float16, // dY
+ float16, // B
+ float16, // dX
+ float16, // dW
+ float16, // dB
+ float>(); // Math
+ } else {
+ CAFFE_THROW("Unsupported type");
+ }
+ return false;
+}
+
namespace {
-REGISTER_CUDA_OPERATOR(FC, FullyConnectedOp<float, CUDAContext>);
-REGISTER_CUDA_OPERATOR(FCGradient,
- FullyConnectedGradientOp<float, CUDAContext>);
+
+REGISTER_CUDA_OPERATOR(FC, FullyConnectedOp<CUDAContext>);
+REGISTER_CUDA_OPERATOR(FCGradient, FullyConnectedGradientOp<CUDAContext>);
} // namespace
} // namespace caffe2
diff --git a/caffe2/operators/sparse_to_dense_op.h b/caffe2/operators/sparse_to_dense_op.h
index d48b617..439d96c 100644
--- a/caffe2/operators/sparse_to_dense_op.h
+++ b/caffe2/operators/sparse_to_dense_op.h
@@ -50,7 +50,6 @@
return DispatchHelper<
TensorTypes2<
float,
- double,
int32_t,
int64_t,
GenericTensorImplementation>,
diff --git a/caffe2/operators/square_root_divide_op.h b/caffe2/operators/square_root_divide_op.h
index 644c2bd..df018bf 100644
--- a/caffe2/operators/square_root_divide_op.h
+++ b/caffe2/operators/square_root_divide_op.h
@@ -17,7 +17,7 @@
: Operator<Context>(operator_def, ws) {}
bool RunOnDevice() override {
- return DispatchHelper<TensorTypes<float, double>>::call(this, Input(DATA));
+ return DispatchHelper<TensorTypes<float>>::call(this, Input(DATA));
}
private:
diff --git a/caffe2/operators/utility_ops.cc b/caffe2/operators/utility_ops.cc
index 771da9c..99f0f20 100644
--- a/caffe2/operators/utility_ops.cc
+++ b/caffe2/operators/utility_ops.cc
@@ -3,6 +3,12 @@
#include <cmath>
namespace caffe2 {
+
+template <>
+bool WeightedSumOp<CPUContext>::RunOnDevice() {
+ return DoRunWithType<float>();
+}
+
namespace {
REGISTER_CPU_OPERATOR(WallClockTime, WallClockTimeOp<CPUContext>);
@@ -12,10 +18,9 @@
REGISTER_CPU_OPERATOR(Alias, AliasOp<CPUContext>);
REGISTER_CPU_OPERATOR(ResizeLike, ResizeLikeOp<CPUContext>);
-REGISTER_CPU_OPERATOR(Sum, SumOp<float, CPUContext>);
-REGISTER_CPU_OPERATOR(SumInt, SumOp<int, CPUContext>);
-
-REGISTER_CPU_OPERATOR(WeightedSum, WeightedSumOp<float, CPUContext>);
+REGISTER_CPU_OPERATOR(Sum, SumOp<CPUContext>);
+REGISTER_CPU_OPERATOR(SumInt, SumOp<CPUContext>);
+REGISTER_CPU_OPERATOR(WeightedSum, WeightedSumOp<CPUContext>);
REGISTER_CPU_OPERATOR(
ScatterWeightedSum,
ScatterWeightedSumOp<float, CPUContext>);
diff --git a/caffe2/operators/utility_ops.h b/caffe2/operators/utility_ops.h
index 1722e14..28321b2 100644
--- a/caffe2/operators/utility_ops.h
+++ b/caffe2/operators/utility_ops.h
@@ -250,13 +250,14 @@
}
};
-template <typename T, class Context>
+template <class Context>
class SumOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
USE_SIMPLE_CTOR_DTOR(SumOp);
- bool RunOnDevice() override {
+ template <typename T, typename M>
+ bool DoRunWithType() {
auto& input0 = Input(0);
auto* output = Output(0);
if (InputSize() == 1) {
@@ -297,6 +298,16 @@
}
return true;
}
+
+ bool RunOnDevice() override {
+ if (Input(0).template IsType<float>()) {
+ return DoRunWithType<float, float>();
+ } else if (Input(0).template IsType<int>()) {
+ return DoRunWithType<int, int>();
+ } else {
+ return false;
+ }
+ }
};
// WeightedSumOp computes the weighted sum of several tensors. The input should
@@ -304,13 +315,14 @@
// shape, and weight_i are size 1 tensors that specifies the weight of each
// vector. Note that if one wants to do in-place computation, it could only be
// done with X_0 also as the output, but not other X_i.
-template <typename T, class Context>
+template <class Context>
class WeightedSumOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
USE_SIMPLE_CTOR_DTOR(WeightedSumOp);
- bool RunOnDevice() override {
+ template <typename DstType>
+ bool DoRunWithType() {
DCHECK_EQ(InputSize() % 2, 0);
auto& X0 = Input(0);
auto& weight0 = Input(1);
@@ -319,11 +331,11 @@
int size = X0.size();
auto* output = Output(0);
output->ResizeLike(X0);
- math::Scale<T, Context>(
+ math::Scale<DstType, Context>(
size,
- weight0.template data<T>(),
- X0.template data<T>(),
- output->template mutable_data<T>(),
+ weight0.template data<float>(),
+ X0.template data<DstType>(),
+ output->template mutable_data<DstType>(),
&context_);
for (int i = 2; i < InputSize(); i += 2) {
auto& X = Input(i);
@@ -338,15 +350,16 @@
auto& weight = Input(i + 1);
DCHECK_EQ(X.size(), size);
DCHECK_EQ(weight.size(), 1);
- math::Axpy<T, Context>(
+ math::Axpy<DstType, Context>(
size,
- weight.template data<T>(),
- X.template data<T>(),
- output->template mutable_data<T>(),
+ weight.template data<float>(),
+ X.template data<DstType>(),
+ output->template mutable_data<DstType>(),
&context_);
}
return true;
}
+ bool RunOnDevice() override;
};
/**
diff --git a/caffe2/operators/utility_ops_gpu.cc b/caffe2/operators/utility_ops_gpu.cc
index b3df226..7d41fa2 100644
--- a/caffe2/operators/utility_ops_gpu.cc
+++ b/caffe2/operators/utility_ops_gpu.cc
@@ -5,6 +5,30 @@
namespace caffe2 {
template <>
+bool WeightedSumOp<CUDAContext>::RunOnDevice() {
+ if (Input(0).IsType<float>()) {
+ return DoRunWithType<float>();
+ } else if (Input(0).IsType<float16>()) {
+ return DoRunWithType<float16>();
+ } else {
+ CAFFE_THROW("Unsupported inputs");
+ }
+ return false;
+}
+
+template <>
+bool SumOp<CUDAContext>::RunOnDevice() {
+ if (Input(0).IsType<float>()) {
+ return DoRunWithType<float, float>();
+ } else if (Input(0).IsType<float16>()) {
+ return DoRunWithType<float16, float16>();
+ } else {
+ CAFFE_THROW("Unsupported inputs");
+ }
+ return false;
+}
+
+template <>
class CopyOnDeviceLikeOp<CUDAContext, CUDAContext, CUDAContext>
: public Operator<CUDAContext> {
public:
@@ -35,9 +59,8 @@
REGISTER_CUDA_OPERATOR(ExpandDims, ExpandDimsOp<CUDAContext>);
REGISTER_CUDA_OPERATOR(Alias, AliasOp<CUDAContext>);
REGISTER_CUDA_OPERATOR(ResizeLike, ResizeLikeOp<CUDAContext>);
-REGISTER_CUDA_OPERATOR(Sum, SumOp<float, CUDAContext>);
-
-REGISTER_CUDA_OPERATOR(WeightedSum, WeightedSumOp<float, CUDAContext>);
+REGISTER_CUDA_OPERATOR(Sum, SumOp<CUDAContext>);
+REGISTER_CUDA_OPERATOR(WeightedSum, WeightedSumOp<CUDAContext>);
REGISTER_CUDA_OPERATOR(Shape, ShapeOp<CUDAContext>);
// From whatever the current context, ensure the output is TensorCPU
REGISTER_CUDA_OPERATOR(
diff --git a/caffe2/utils/conversions.h b/caffe2/utils/conversions.h
new file mode 100644
index 0000000..0c6c323
--- /dev/null
+++ b/caffe2/utils/conversions.h
@@ -0,0 +1,182 @@
+#pragma once
+
+#include <caffe2/core/types.h>
+
+#ifdef __CUDA_ARCH__
+#include <cuda_fp16.h>
+#endif
+
+#ifdef __CUDA_ARCH__
+#define CONVERSIONS_DECL __host__ __device__ inline
+#else
+#define CONVERSIONS_DECL inline
+#endif
+
+namespace caffe2 {
+
+namespace convert {
+
+namespace {
+inline float16 cpu_float2half_rn(float f) {
+ float16 ret;
+
+ static_assert(
+ sizeof(unsigned int) == sizeof(float),
+ "Programming error sizeof(unsigned int) != sizeof(float)");
+
+ unsigned* xp = reinterpret_cast<unsigned int*>(&f);
+ unsigned x = *xp;
+ unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1;
+ unsigned sign, exponent, mantissa;
+
+ // Get rid of +NaN/-NaN case first.
+ if (u > 0x7f800000) {
+ ret.x = 0x7fffU;
+ return ret;
+ }
+
+ sign = ((x >> 16) & 0x8000);
+
+ // Get rid of +Inf/-Inf, +0/-0.
+ if (u > 0x477fefff) {
+ ret.x = sign | 0x7c00U;
+ return ret;
+ }
+ if (u < 0x33000001) {
+ ret.x = (sign | 0x0000);
+ return ret;
+ }
+
+ exponent = ((u >> 23) & 0xff);
+ mantissa = (u & 0x7fffff);
+
+ if (exponent > 0x70) {
+ shift = 13;
+ exponent -= 0x70;
+ } else {
+ shift = 0x7e - exponent;
+ exponent = 0;
+ mantissa |= 0x800000;
+ }
+ lsb = (1 << shift);
+ lsb_s1 = (lsb >> 1);
+ lsb_m1 = (lsb - 1);
+
+ // Round to nearest even.
+ remainder = (mantissa & lsb_m1);
+ mantissa >>= shift;
+ if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) {
+ ++mantissa;
+ if (!(mantissa & 0x3ff)) {
+ ++exponent;
+ mantissa = 0;
+ }
+ }
+
+ ret.x = (sign | (exponent << 10) | mantissa);
+
+ return ret;
+}
+
+inline float cpu_half2float(float16 h) {
+ unsigned sign = ((h.x >> 15) & 1);
+ unsigned exponent = ((h.x >> 10) & 0x1f);
+ unsigned mantissa = ((h.x & 0x3ff) << 13);
+
+ if (exponent == 0x1f) { /* NaN or Inf */
+ mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0);
+ exponent = 0xff;
+ } else if (!exponent) { /* Denorm or Zero */
+ if (mantissa) {
+ unsigned int msb;
+ exponent = 0x71;
+ do {
+ msb = (mantissa & 0x400000);
+ mantissa <<= 1; /* normalize */
+ --exponent;
+ } while (!msb);
+ mantissa &= 0x7fffff; /* 1.mantissa is implicit */
+ }
+ } else {
+ exponent += 0x70;
+ }
+
+ int temp = ((sign << 31) | (exponent << 23) | mantissa);
+
+ unsigned* rp = reinterpret_cast<unsigned*>(&temp);
+ return *rp;
+}
+
+}; // anonymous
+// general version: defer to static_cast
+template <typename IN, typename OUT>
+CONVERSIONS_DECL OUT To(const IN in) {
+ return static_cast<OUT>(in);
+}
+
+#if __CUDA_ARCH__
+__device__ __inline__ __half inf_clip(__half h) {
+ int isi = __hisinf(h);
+ if (isi > 0) {
+ // Exponent all ones except LSB (0x1e), mantissa is all ones (0x3ff)
+ h.x = 0x7bffU;
+ } else if (isi < 0) {
+ // As above, negated
+ h.x = 0x7bffU ^ 0x8000;
+ }
+ return h;
+}
+#endif
+
+// explicit for fp16
+template <>
+CONVERSIONS_DECL float16 To(const float in) {
+#if __CUDA_ARCH__
+ // hacky interface between C2 fp16 and CUDA
+ float16 ret;
+ __half r;
+ // r.x = __float2half_rn(in);
+ // ret.x = inf_clip(r).x;
+ ret.x = __float2half(in).x;
+ return ret;
+#else
+ return cpu_float2half_rn(in);
+#endif
+}
+
+template <>
+CONVERSIONS_DECL float To(const float16 in) {
+#if __CUDA_ARCH__
+ __half tmp;
+ tmp.x = in.x;
+ return __half2float(tmp);
+#else
+ return cpu_half2float(in);
+#endif
+};
+
+template <>
+CONVERSIONS_DECL float To(const float in) {
+ return in;
+}
+
+template <typename OUT, typename IN>
+CONVERSIONS_DECL OUT Get(IN x) {
+ return static_cast<OUT>(x);
+}
+
+template <>
+CONVERSIONS_DECL float Get(float16 x) {
+ return To<float16, float>(x);
+}
+
+template <>
+CONVERSIONS_DECL float16 Get(float x) {
+ return To<float, float16>(x);
+}
+
+}; // namespace convert
+
+}; // namespace caffe2
+
+#undef CONVERSIONS_DECL
diff --git a/caffe2/utils/math-detail.h b/caffe2/utils/math-detail.h
index 35a880a..07a1f99 100644
--- a/caffe2/utils/math-detail.h
+++ b/caffe2/utils/math-detail.h
@@ -11,8 +11,12 @@
template<typename T, class Context, int FixedSize>
struct ScaleImpl {
- inline void
- operator()(const int N, const T alpha, const T* x, T* y, Context* context) {
+ inline void operator()(
+ const int N,
+ const float alpha,
+ const T* x,
+ T* y,
+ Context* context) {
Scale(N, alpha, x, y, context);
}
};
@@ -22,7 +26,7 @@
struct ScaleImpl<T, CPUContext, 1> {
inline void operator()(
const int N,
- const T alpha,
+ const float alpha,
const T* x,
T* y,
CPUContext* context) {
@@ -33,8 +37,12 @@
template<typename T, class Context, int FixedSize>
struct AxpyImpl {
- inline void
- operator()(const int N, const T alpha, const T* x, T* y, Context* context) {
+ inline void operator()(
+ const int N,
+ const float alpha,
+ const T* x,
+ T* y,
+ Context* context) {
Axpy(N, alpha, x, y, context);
}
};
@@ -44,7 +52,7 @@
struct AxpyImpl<T, CPUContext, 1> {
inline void operator()(
const int N,
- const T alpha,
+ const float alpha,
const T* x,
T* y,
CPUContext* context) {
@@ -57,14 +65,22 @@
} // namespace detail
template <typename T, class Context, int FixedSize>
-inline void
-ScaleFixedSize(const int N, const T alpha, const T* x, T* y, Context* context) {
+inline void ScaleFixedSize(
+ const int N,
+ const float alpha,
+ const T* x,
+ T* y,
+ Context* context) {
detail::ScaleImpl<T, Context, FixedSize>()(N, alpha, x, y, context);
}
template <typename T, class Context, int FixedSize>
-inline void
-AxpyFixedSize(const int N, const T alpha, const T* x, T* y, Context* context) {
+inline void AxpyFixedSize(
+ const int N,
+ const float alpha,
+ const T* x,
+ T* y,
+ Context* context) {
detail::AxpyImpl<T, Context, FixedSize>()(N, alpha, x, y, context);
}
diff --git a/caffe2/utils/math.h b/caffe2/utils/math.h
index a2472c0..105cb19 100644
--- a/caffe2/utils/math.h
+++ b/caffe2/utils/math.h
@@ -141,10 +141,20 @@
// Decaf gemm provides a simpler interface to the gemm functions, with the
// limitation that the data has to be contiguous in memory.
-template <typename T, class Context, class Engine=DefaultEngine>
-void Gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
- const int M, const int N, const int K, const T alpha, const T* A,
- const T* B, const T beta, T* C, Context* context);
+template <typename T, class Context, class Engine = DefaultEngine>
+void Gemm(
+ const CBLAS_TRANSPOSE TransA,
+ const CBLAS_TRANSPOSE TransB,
+ const int M,
+ const int N,
+ const int K,
+ const float alpha,
+ const T* A,
+ const T* B,
+ const float beta,
+ T* C,
+ Context* context,
+ TensorProto::DataType math_type = TensorProto_DataType_FLOAT);
// We also provide a gemm that has explicit lda, ldb and ldc specified.
// In most cases you probably want to use the function above, though.
@@ -169,10 +179,18 @@
// to Trans, the output is:
// CblasNoTrans: x is an N dim vector and y is an M dim vector.
// CblasTrans: x is an M dim vector and y is an N dim vector.
-template <typename T, class Context, class Engine=DefaultEngine>
-void Gemv(const CBLAS_TRANSPOSE TransA, const int M, const int N,
- const T alpha, const T* A, const T* x, const T beta,
- T* y, Context* context);
+template <typename T, class Context, class Engine = DefaultEngine>
+void Gemv(
+ const CBLAS_TRANSPOSE TransA,
+ const int M,
+ const int N,
+ const float alpha,
+ const T* A,
+ const T* x,
+ const float beta,
+ T* y,
+ Context* context,
+ TensorProto::DataType math_type = TensorProto_DataType_FLOAT);
template <typename T, class Context>
void Set(const TIndex N, const T alpha, T* X, Context* context);
@@ -218,28 +236,31 @@
Context* context);
template <typename T, class Context>
-void Scale(const int N, const T alpha, const T* x, T* y, Context* context);
+void Scale(const int N, const float alpha, const T* x, T* y, Context* context);
// Different from the Scale function above, if alpha is passed in
// as a pointer, we will assume that it lives on the Context device,
// for example on GPU.
template <typename T, class Context>
-void Scale(const int N, const T* alpha, const T* x, T* y,
- Context* context);
+void Scale(const int N, const float* alpha, const T* x, T* y, Context* context);
template <typename T, class Context>
-void Axpy(const int N, const T alpha, const T* x, T* y, Context* context);
+void Axpy(const int N, const float alpha, const T* x, T* y, Context* context);
// Different from the Axpy function above, if alpha is passed in
// as a pointer, we will assume that it lives on the Context device,
// for example on GPU.
template <typename T, class Context>
-void Axpy(const int N, const T* alpha, const T* x, T* y,
- Context* context);
+void Axpy(const int N, const float* alpha, const T* x, T* y, Context* context);
template <typename T, class Context>
-void Axpby(const int N, const T alpha, const T* x, const T b, T* y,
- Context* context);
+void Axpby(
+ const int N,
+ const float alpha,
+ const T* x,
+ const T b,
+ T* y,
+ Context* context);
template <typename T, class Context, int order>
void Im2colNd(
diff --git a/caffe2/utils/math_cpu.cc b/caffe2/utils/math_cpu.cc
index 5cac0c8..e4340df 100644
--- a/caffe2/utils/math_cpu.cc
+++ b/caffe2/utils/math_cpu.cc
@@ -58,9 +58,18 @@
// CblasTrans, respectively, for each of A and B.
template <>
void Gemm<float, CPUContext>(
- const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
- const int M, const int N, const int K, const float alpha, const float* A,
- const float* B, const float beta, float* C, CPUContext* context) {
+ const CBLAS_TRANSPOSE TransA,
+ const CBLAS_TRANSPOSE TransB,
+ const int M,
+ const int N,
+ const int K,
+ const float alpha,
+ const float* A,
+ const float* B,
+ const float beta,
+ float* C,
+ CPUContext* context,
+ TensorProto::DataType math_type) {
auto C_mat = EigenMatrixMap<float>(C, N, M);
if (beta == 0) {
C_mat.setZero();
@@ -178,7 +187,8 @@
const float* x,
const float beta,
float* y,
- CPUContext* context) {
+ CPUContext* context,
+ TensorProto::DataType math_type) {
EigenVectorMap<float> y_vec(y, TransA == CblasNoTrans ? M : N);
if (beta == 0) {
// In Caffe2 we often do a lazy initialization, which may contain NaNs in
@@ -205,19 +215,22 @@
}
}
-#define CAFFE2_SPECIALIZED_SCALE(T) \
- template <> \
- void Scale<T, CPUContext>( \
- const int n, const T alpha, const T* x, T* y, CPUContext* context) { \
- EigenVectorMap<T>(y, n) = ConstEigenVectorMap<T>(x, n) * alpha; \
- } \
- template <> \
- void Scale<T, CPUContext>( \
- const int n, const T* alpha, const T* x, T* y, CPUContext* context) { \
- EigenVectorMap<T>(y, n) = ConstEigenVectorMap<T>(x, n) * (*alpha); \
+#define CAFFE2_SPECIALIZED_SCALE(T) \
+ template <> \
+ void Scale<T, CPUContext>( \
+ const int n, const float alpha, const T* x, T* y, CPUContext* context) { \
+ EigenVectorMap<T>(y, n) = ConstEigenVectorMap<T>(x, n) * alpha; \
+ } \
+ template <> \
+ void Scale<T, CPUContext>( \
+ const int n, \
+ const float* alpha, \
+ const T* x, \
+ T* y, \
+ CPUContext* context) { \
+ EigenVectorMap<T>(y, n) = ConstEigenVectorMap<T>(x, n) * (*alpha); \
}
CAFFE2_SPECIALIZED_SCALE(float)
-CAFFE2_SPECIALIZED_SCALE(double)
#undef CAFFE2_SPECIALIZED_SCALE
#define CAFFE2_SPECIALIZED_DOT(T) \
@@ -228,7 +241,6 @@
*y = ConstEigenVectorMap<T>(a, N).dot(ConstEigenVectorMap<T>(b, N)); \
}
CAFFE2_SPECIALIZED_DOT(float)
-CAFFE2_SPECIALIZED_DOT(double)
#undef CAFFE2_SPECIALIZED_DOT
#define CAFFE2_SPECIALIZED_AXPY(T) \
@@ -243,7 +255,6 @@
EigenVectorMap<T>(Y, N) += ConstEigenVectorMap<T>(x, N) * (*alpha); \
}
CAFFE2_SPECIALIZED_AXPY(float)
-CAFFE2_SPECIALIZED_AXPY(double)
#undef CAFFE2_SPECIALIZED_AXPY
#define CAFFE2_SPECIALIZED_AXPBY(T) \
@@ -254,16 +265,24 @@
y_vec = y_vec * beta + ConstEigenVectorMap<T>(x, N) * alpha; \
}
CAFFE2_SPECIALIZED_AXPBY(float)
-CAFFE2_SPECIALIZED_AXPBY(double)
#undef CAFFE2_SPECIALIZED_AXPBY
#else // CAFFE2_USE_EIGEN_FOR_BLAS
template <>
void Gemm<float, CPUContext>(
- const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
- const int M, const int N, const int K, const float alpha, const float* A,
- const float* B, const float beta, float* C, CPUContext* context) {
+ const CBLAS_TRANSPOSE TransA,
+ const CBLAS_TRANSPOSE TransB,
+ const int M,
+ const int N,
+ const int K,
+ const float alpha,
+ const float* A,
+ const float* B,
+ const float beta,
+ float* C,
+ CPUContext* context,
+ TensorProto::DataType math_type) {
int lda = (TransA == CblasNoTrans) ? K : M;
int ldb = (TransB == CblasNoTrans) ? N : K;
cblas_sgemm(CblasRowMajor, TransA, TransB, M, N, K, alpha, A, lda, B, ldb,
@@ -292,29 +311,39 @@
template <>
void Gemv<float, CPUContext>(
- const CBLAS_TRANSPOSE TransA, const int M, const int N, const float alpha,
- const float* A, const float* x, const float beta, float* y,
- CPUContext* context) {
+ const CBLAS_TRANSPOSE TransA,
+ const int M,
+ const int N,
+ const float alpha,
+ const float* A,
+ const float* x,
+ const float beta,
+ float* y,
+ CPUContext* context,
+ TensorProto::DataType math_type) {
cblas_sgemv(CblasRowMajor, TransA, M, N, alpha, A, N, x, 1, beta, y, 1);
}
-#define CAFFE2_SPECIALIZED_SCALE(T, prefix) \
- template <> \
- void Scale<T, CPUContext>( \
- const int n, const T alpha, const T* x, T* y, CPUContext* context) { \
- if (y != x) \
- cblas_##prefix##copy(n, x, 1, y, 1); \
- cblas_##prefix##scal(n, alpha, y, 1); \
- } \
- template <> \
- void Scale<T, CPUContext>( \
- const int n, const T* alpha, const T* x, T* y, CPUContext* context) { \
- if (y != x) \
- cblas_##prefix##copy(n, x, 1, y, 1); \
- cblas_##prefix##scal(n, *alpha, y, 1); \
+#define CAFFE2_SPECIALIZED_SCALE(T, prefix) \
+ template <> \
+ void Scale<T, CPUContext>( \
+ const int n, const float alpha, const T* x, T* y, CPUContext* context) { \
+ if (y != x) \
+ cblas_##prefix##copy(n, x, 1, y, 1); \
+ cblas_##prefix##scal(n, static_cast<float>(alpha), y, 1); \
+ } \
+ template <> \
+ void Scale<T, CPUContext>( \
+ const int n, \
+ const float* alpha, \
+ const T* x, \
+ T* y, \
+ CPUContext* context) { \
+ if (y != x) \
+ cblas_##prefix##copy(n, x, 1, y, 1); \
+ cblas_##prefix##scal(n, static_cast<float>(*alpha), y, 1); \
}
CAFFE2_SPECIALIZED_SCALE(float, s)
-CAFFE2_SPECIALIZED_SCALE(double, d)
#undef CAFFE2_SPECIALIZED_SCALE
#define CAFFE2_SPECIALIZED_DOT(T, prefix) \
@@ -325,7 +354,6 @@
*y = cblas_##prefix##dot(N, a, 1, b, 1); \
}
CAFFE2_SPECIALIZED_DOT(float, s)
-CAFFE2_SPECIALIZED_DOT(double, d)
#undef CAFFE2_SPECIALIZED_DOT
#define CAFFE2_SPECIALIZED_AXPY(T, prefix) \
@@ -340,7 +368,6 @@
cblas_##prefix##axpy(N, *alpha, x, 1, y, 1); \
}
CAFFE2_SPECIALIZED_AXPY(float, s)
-CAFFE2_SPECIALIZED_AXPY(double, d)
#undef CAFFE2_SPECIALIZED_AXPY
// cblas_[sd]axpby is not a standard blas function, and if MKL is not present,
@@ -362,7 +389,6 @@
}
#endif // CAFFE2_USE_MKL
CAFFE2_SPECIALIZED_AXPBY(float, s)
-CAFFE2_SPECIALIZED_AXPBY(double, d)
#undef CAFFE2_SPECIALIZED_AXPBY
#endif // CAFFE2_USE_EIGEN_FOR_BLAS
@@ -436,11 +462,8 @@
EigenVectorMap<T>(y, N) = ConstEigenVectorMap<T>(x, N).array().expr(); \
}
DELEGATE_SIMPLE_UNARY_FUNCTION(float, Exp, exp)
-DELEGATE_SIMPLE_UNARY_FUNCTION(double, Exp, exp)
DELEGATE_SIMPLE_UNARY_FUNCTION(float, Log, log)
-DELEGATE_SIMPLE_UNARY_FUNCTION(double, Log, log)
DELEGATE_SIMPLE_UNARY_FUNCTION(float, Sqr, square)
-DELEGATE_SIMPLE_UNARY_FUNCTION(double, Sqr, square)
#undef DELEGATE_SIMPLE_UNARY_FUNCTION
#define DELEGATE_POWX_FUNCTION(T) \
@@ -450,7 +473,6 @@
EigenVectorMap<T>(y, N) = ConstEigenVectorMap<T>(a, N).array().pow(b); \
}
DELEGATE_POWX_FUNCTION(float)
-DELEGATE_POWX_FUNCTION(double)
#undef DELEGATE_POWX_FUNCTION
#endif // CAFFE2_USE_MKL
@@ -476,7 +498,6 @@
#define DEFINE_SIMPLE_BINARY_FUNCTION(Funcname, expr) \
EIGEN_SIMPLE_BINARY_FUNCTION(float, Funcname, expr) \
-EIGEN_SIMPLE_BINARY_FUNCTION(double, Funcname, expr) \
EIGEN_SIMPLE_BINARY_FUNCTION(int32_t, Funcname, expr) \
EIGEN_SIMPLE_BINARY_FUNCTION(int64_t, Funcname, expr)
@@ -546,7 +567,6 @@
DELEGATE_BROADCAST_BINARY_FUNCTION(int32_t, name, op) \
DELEGATE_BROADCAST_BINARY_FUNCTION(int64_t, name, op) \
DELEGATE_BROADCAST_BINARY_FUNCTION(float, name, op) \
- DELEGATE_BROADCAST_BINARY_FUNCTION(double, name, op)
DEFINE_BROADCAST_BINARY_FUNCTION(Add, +)
DEFINE_BROADCAST_BINARY_FUNCTION(Sub, -)
@@ -602,7 +622,6 @@
#define CAFFE2_DEFINE_BINARY_OP(name, op) \
CAFFE2_INSTANTIATE_BINARY_OP(name, op, float) \
- CAFFE2_INSTANTIATE_BINARY_OP(name, op, double) \
CAFFE2_INSTANTIATE_BINARY_OP(name, op, int32_t) \
CAFFE2_INSTANTIATE_BINARY_OP(name, op, int64_t)
@@ -644,7 +663,6 @@
}
CAFFE2_SPECIALIZED_CPU_ADD_STRIPED_BATCH(float);
-CAFFE2_SPECIALIZED_CPU_ADD_STRIPED_BATCH(double);
#undef CAFFE2_SPECIALIZED_CPU_ADD_STRIPED_BATCH
template <>
@@ -717,7 +735,6 @@
}
CAFFE2_SPECIALIZED_SUM(float);
-CAFFE2_SPECIALIZED_SUM(double);
CAFFE2_SPECIALIZED_SUM(int32_t);
CAFFE2_SPECIALIZED_SUM(int64_t);
diff --git a/caffe2/utils/math_gpu.cu b/caffe2/utils/math_gpu.cu
index cfbe91b..46c5bc0 100644
--- a/caffe2/utils/math_gpu.cu
+++ b/caffe2/utils/math_gpu.cu
@@ -5,8 +5,9 @@
#include <thrust/system/cuda/detail/par.h>
#include <thrust/version.h>
-#include "caffe2/utils/math.h"
#include "caffe2/core/context_gpu.h"
+#include "caffe2/utils/conversions.h"
+#include "caffe2/utils/math.h"
#if THRUST_VERSION >= 100800
#define THRUST_SUPPORTS_PER_THREAD
@@ -32,33 +33,30 @@
}
DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Exp, expf);
-DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Exp, exp);
DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Log, logf);
-DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Log, log);
__device__ float cuda_sqrf(const float x) { return x * x; }
-__device__ double cuda_sqr(const double x) { return x * x; }
DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sqr, cuda_sqrf);
-DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Sqr, cuda_sqr);
#undef DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION
-#define DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(T, Funcname, expr) \
- __global__ void _Kernel_##T##_##Funcname( \
- const int N, const T* a, const T* b, T* y) { \
- CUDA_1D_KERNEL_LOOP(i, N) { \
- y[i] = a[i] expr b[i]; \
- } \
- } \
- template <> \
- void Funcname<T, CUDAContext>( \
- const int N, const T* a, const T* b, T* y, CUDAContext* context) { \
- _Kernel_##T##_##Funcname<<< \
- CAFFE_GET_BLOCKS(N), \
- CAFFE_CUDA_NUM_THREADS, \
- 0, \
- context->cuda_stream()>>>(N, a, b, y); \
+#define DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(T, Funcname, expr) \
+ __global__ void _Kernel_##T##_##Funcname( \
+ const int N, const T* a, const T* b, T* y) { \
+ CUDA_1D_KERNEL_LOOP(i, N) { \
+ float r = convert::To<T, float>(a[i]) expr convert::To<T, float>(b[i]); \
+ y[i] = convert::To<float, T>(r); \
+ } \
+ } \
+ template <> \
+ void Funcname<T, CUDAContext>( \
+ const int N, const T* a, const T* b, T* y, CUDAContext* context) { \
+ _Kernel_##T##_##Funcname<<< \
+ CAFFE_GET_BLOCKS(N), \
+ CAFFE_CUDA_NUM_THREADS, \
+ 0, \
+ context->cuda_stream()>>>(N, a, b, y); \
}
DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Add, +);
@@ -66,13 +64,27 @@
DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Mul, *);
DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Div, /);
+DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float16, Add, +);
+DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float16, Sub, -);
+DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float16, Mul, *);
+DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float16, Div, /);
+
// Caffe2 gemm provides a simpler interface to the gemm functions, with the
// limitation that the data has to be contiguous in memory.
template <>
void Gemm<float, CUDAContext>(
- const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
- const int M, const int N, const int K, const float alpha, const float* A,
- const float* B, const float beta, float* C, CUDAContext* context) {
+ const CBLAS_TRANSPOSE TransA,
+ const CBLAS_TRANSPOSE TransB,
+ const int M,
+ const int N,
+ const int K,
+ const float alpha,
+ const float* A,
+ const float* B,
+ const float beta,
+ float* C,
+ CUDAContext* context,
+ TensorProto::DataType math_type) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
int lda = (TransA == CblasNoTrans) ? K : M;
@@ -99,11 +111,91 @@
}
template <>
+void Gemm<float16, CUDAContext>(
+ const CBLAS_TRANSPOSE TransA,
+ const CBLAS_TRANSPOSE TransB,
+ const int M,
+ const int N,
+ const int K,
+ const float alpha,
+ const float16* A,
+ const float16* B,
+ const float beta,
+ float16* C,
+ CUDAContext* context,
+ TensorProto::DataType math_type) {
+ // Note that cublas follows fortran order, so the order is different from
+ // the cblas convention.
+ int lda = (TransA == CblasNoTrans) ? K : M;
+ int ldb = (TransB == CblasNoTrans) ? N : K;
+ cublasOperation_t cuTransA =
+ (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+ cublasOperation_t cuTransB =
+ (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+ if (math_type == TensorProto_DataType_FLOAT) {
+ CUBLAS_CHECK(cublasSgemmEx(
+ context->cublas_handle(),
+ cuTransB,
+ cuTransA,
+ N,
+ M,
+ K,
+ &alpha,
+ B,
+ CUDA_R_16F,
+ ldb,
+ A,
+ CUDA_R_16F,
+ lda,
+ &beta,
+ C,
+ CUDA_R_16F,
+ N));
+
+ } else if (math_type == TensorProto_DataType_FLOAT16) {
+ // convert alpha, beta from caffe2::float16 -> __half
+ __half alpha_fp16;
+ alpha_fp16.x = convert::To<float, float16>(alpha).x;
+ __half beta_fp16;
+ beta_fp16.x = convert::To<float, float16>(beta).x;
+ // call cublasHgemm
+ CUBLAS_CHECK(cublasHgemm(
+ context->cublas_handle(),
+ cuTransB,
+ cuTransA,
+ N,
+ M,
+ K,
+ &alpha_fp16,
+ (const __half*)B,
+ ldb,
+ (const __half*)A,
+ lda,
+ &beta_fp16,
+ (__half*)C,
+ N));
+ } else {
+ // fail
+ CAFFE_THROW("Unsupported math type");
+ }
+}
+
+template <>
void GemmEx<float, CUDAContext>(
- const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB,
- const int M, const int N, const int K, const float alpha, const float* A,
- const int lda, const float* B, const int ldb, const float beta, float* C,
- const int ldc, CUDAContext* context) {
+ const CBLAS_TRANSPOSE TransA,
+ const CBLAS_TRANSPOSE TransB,
+ const int M,
+ const int N,
+ const int K,
+ const float alpha,
+ const float* A,
+ const int lda,
+ const float* B,
+ const int ldb,
+ const float beta,
+ float* C,
+ const int ldc,
+ CUDAContext* context) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
cublasOperation_t cuTransA =
@@ -129,40 +221,19 @@
template <>
void Gemv<float, CUDAContext>(
- const CBLAS_TRANSPOSE TransA, const int M, const int N, const float alpha,
- const float* A, const float* x, const float beta, float* y,
- CUDAContext* context) {
- cublasOperation_t cuTransA =
- (TransA == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
- CUBLAS_ENFORCE(cublasSgemv(
- context->cublas_handle(),
- cuTransA,
- N,
- M,
- &alpha,
- A,
- N,
- x,
- 1,
- &beta,
- y,
- 1));
-}
-
-template <>
-void Gemv<double, CUDAContext>(
const CBLAS_TRANSPOSE TransA,
const int M,
const int N,
- const double alpha,
- const double* A,
- const double* x,
- const double beta,
- double* y,
- CUDAContext* context) {
+ const float alpha,
+ const float* A,
+ const float* x,
+ const float beta,
+ float* y,
+ CUDAContext* context,
+ TensorProto::DataType math_type) {
cublasOperation_t cuTransA =
(TransA == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
- CUBLAS_ENFORCE(cublasDgemv(
+ CUBLAS_ENFORCE(cublasSgemv(
context->cublas_handle(),
cuTransA,
N,
@@ -216,6 +287,73 @@
CAFFE2_SPECIALIZED_CUDA_ADD_STRIPED_BATCH(double);
#undef CAFFE2_SPECIALIZED_CUDA_ADD_STRIPED_BATCH
+template <>
+void Gemv<float16, CUDAContext>(
+ const CBLAS_TRANSPOSE TransA,
+ const int M,
+ const int N,
+ const float alpha,
+ const float16* A,
+ const float16* x,
+ const float beta,
+ float16* y,
+ CUDAContext* context,
+ TensorProto::DataType math_type) {
+ cublasOperation_t cuTransA =
+ (TransA == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
+
+ // sort out what we need to call cublasSgemmEx / cublasHgemm
+ int m = (cuTransA == CUBLAS_OP_N) ? N : M;
+ int k = (cuTransA == CUBLAS_OP_N) ? M : N;
+ int LDA = (cuTransA == CUBLAS_OP_N) ? m : k;
+ int LDC = m;
+
+ if (math_type == TensorProto_DataType_FLOAT) {
+ CUBLAS_CHECK(cublasSgemmEx(
+ context->cublas_handle(),
+ cuTransA,
+ CUBLAS_OP_N,
+ m,
+ 1,
+ k,
+ &alpha,
+ A,
+ CUDA_R_16F,
+ LDA,
+ x,
+ CUDA_R_16F,
+ k,
+ &beta,
+ y,
+ CUDA_R_16F,
+ LDC));
+ } else if (math_type == TensorProto_DataType_FLOAT16) {
+ __half alpha_fp16;
+ alpha_fp16.x = convert::To<float, float16>(alpha).x;
+ __half beta_fp16;
+ beta_fp16.x = convert::To<float, float16>(beta).x;
+
+ CUBLAS_CHECK(cublasHgemm(
+ context->cublas_handle(),
+ cuTransA,
+ CUBLAS_OP_N,
+ m,
+ 1,
+ k,
+ &alpha_fp16,
+ (const __half*)A,
+ LDA,
+ (const __half*)x,
+ k,
+ &beta_fp16,
+ (__half*)y,
+ LDC));
+ } else {
+ // fail
+ CAFFE_THROW("Unsupported math type");
+ }
+}
+
namespace {
template <typename T>
__global__ void SetKernel(const int N, const T alpha, T* Y) {
@@ -238,6 +376,7 @@
CAFFE2_SPECIALIZED_CUDA_SET(bool);
CAFFE2_SPECIALIZED_CUDA_SET(int8_t);
CAFFE2_SPECIALIZED_CUDA_SET(int16_t);
+CAFFE2_SPECIALIZED_CUDA_SET(float16);
CAFFE2_SPECIALIZED_CUDA_SET(int);
CAFFE2_SPECIALIZED_CUDA_SET(int64_t);
CAFFE2_SPECIALIZED_CUDA_SET(char);
@@ -247,11 +386,11 @@
namespace {
template <typename T>
-__global__ void UniformShift(const int N, const T min, const T max,
- T* x) {
- T scale = max - min;
+__global__ void
+UniformShift(const int N, const float min, const float max, T* x) {
+ float scale = max - min;
CUDA_1D_KERNEL_LOOP(i, N) {
- x[i] = x[i] * scale + min;
+ x[i] = convert::To<float, T>(convert::To<T, float>(x[i]) * scale + min);
}
}
@@ -336,7 +475,6 @@
context->curand_generator(), r, even_n, mean, std));
}
-
template<>
void Dot<float, CUDAContext>(
const int n, const float* a, const float* b, float* y,
@@ -346,13 +484,28 @@
context->Copy<float, CPUContext, CUDAContext>(1, &result, y);
}
-template<>
-void Dot<double, CUDAContext>(
- const int n, const double* a, const double* b, double* y,
+template <>
+void Dot<float16, CUDAContext>(
+ const int n,
+ const float16* a,
+ const float16* b,
+ float16* y,
CUDAContext* context) {
- double result;
- CUBLAS_ENFORCE(cublasDdot(context->cublas_handle(), n, a, 1, b, 1, y));
- context->Copy<double, CPUContext, CUDAContext>(1, &result, y);
+ float16 result;
+ // execute with 32-bit math
+ CUBLAS_CHECK(cublasDotEx(
+ context->cublas_handle(),
+ n,
+ a,
+ CUDA_R_16F,
+ 1,
+ b,
+ CUDA_R_16F,
+ 1,
+ &result,
+ CUDA_R_16F,
+ CUDA_R_32F));
+ context->Copy<float16, CPUContext, CUDAContext>(1, &result, y);
}
// A previous version of caffe2 used Thrust but it turns out that thrust
@@ -363,7 +516,7 @@
template <typename T>
__global__ void SumKernel(const int N, const T* X, T* Y, bool square) {
const int idx = threadIdx.x;
- __shared__ T reduction_buffer[SUM_KERNEL_NTHREADS];
+ __shared__ float reduction_buffer[SUM_KERNEL_NTHREADS];
reduction_buffer[idx] = 0;
@@ -371,11 +524,12 @@
// N -> 128
if (!square) {
for (int i = idx; i < N; i += SUM_KERNEL_NTHREADS) {
- reduction_buffer[idx] += X[i];
+ reduction_buffer[idx] += convert::To<T, float>(X[i]);
}
} else {
for (int i = idx; i < N; i += SUM_KERNEL_NTHREADS) {
- reduction_buffer[idx] += X[i] * X[i];
+ float Xi = convert::To<T, float>(X[i]);
+ reduction_buffer[idx] += Xi * Xi;
}
}
__syncthreads();
@@ -393,7 +547,7 @@
for (int i = 0; i < 32; ++i) {
tmp += reduction_buffer[i];
}
- *Y = tmp;
+ *Y = convert::To<float, T>(tmp);
}
}
@@ -406,7 +560,7 @@
}
CAFFE2_MATH_SUM_FUNC(float)
-CAFFE2_MATH_SUM_FUNC(double)
+CAFFE2_MATH_SUM_FUNC(float16)
#undef CAFFE2_MATH_SUM_FUNC
#define CAFFE2_MATH_SUMSQR_FUNC(T) \
@@ -438,18 +592,33 @@
0, context->cuda_stream()>>>(N, D, x, idx, y);
}
+template <>
+void Select<float16, CUDAContext>(
+ const int N,
+ const int D,
+ const float16* x,
+ const int* idx,
+ float16* y,
+ CUDAContext* context) {
+ SelectKernel<float16><<<
+ CAFFE_GET_BLOCKS(N),
+ CAFFE_CUDA_NUM_THREADS,
+ 0,
+ context->cuda_stream()>>>(N, D, x, idx, y);
+}
+
namespace {
template <typename T>
-__global__ void ScaleKernel(
- const int n, const T alpha, const T* x, T* y) {
+__global__ void ScaleKernel(const int n, const float alpha, const T* x, T* y) {
CUDA_1D_KERNEL_LOOP(i, n) {
- y[i] = x[i] * alpha;
+ // y[i] = convert::To<float,T>(convert::To<T, float>(x[i]) * alpha);
+ y[i] = convert::Get<T>(convert::Get<float>(x[i]) * alpha);
}
}
template <typename T>
-__global__ void ScaleKernelDeviceAlpha(
- const int n, const T* alpha, const T* x, T* y) {
+__global__ void
+ScaleKernelDeviceAlpha(const int n, const float* alpha, const T* x, T* y) {
CUDA_1D_KERNEL_LOOP(i, n) {
y[i] = x[i] * (*alpha);
}
@@ -461,6 +630,20 @@
y[i] = powf(x[i], exponent);
}
}
+
+// fp16 specialization
+template <>
+__global__ void ScaleKernelDeviceAlpha(
+ const int n,
+ const float* alpha,
+ const float16* x,
+ float16* y) {
+ CUDA_1D_KERNEL_LOOP(i, n) {
+ y[i] = convert::To<float, float16>(
+ convert::To<float16, float>(x[i]) * (*alpha));
+ }
+}
+
} // namespace
template <>
@@ -489,12 +672,17 @@
}
template <>
-void Scale<double, CUDAContext>(
- const int n, const double alpha, const double *x, double* y,
+void Scale<float16, CUDAContext>(
+ const int n,
+ const float alpha,
+ const float16* x,
+ float16* y,
CUDAContext* context) {
- ScaleKernel<double><<<
- CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS, 0, context->cuda_stream()>>>(
- n, alpha, x, y);
+ ScaleKernel<float16><<<
+ CAFFE_GET_BLOCKS(n),
+ CAFFE_CUDA_NUM_THREADS,
+ 0,
+ context->cuda_stream()>>>(n, alpha, x, y);
}
template <>
@@ -507,11 +695,17 @@
}
template <>
-void Scale<double, CUDAContext>(
- const int n, const double* alpha, const double *x, double* y,
+void Scale<float16, CUDAContext>(
+ const int n,
+ const float* alpha,
+ const float16* x,
+ float16* y,
CUDAContext* context) {
- ScaleKernelDeviceAlpha<double><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS,
- 0, context->cuda_stream()>>>(n, alpha, x, y);
+ ScaleKernelDeviceAlpha<float16><<<
+ CAFFE_GET_BLOCKS(n),
+ CAFFE_CUDA_NUM_THREADS,
+ 0,
+ context->cuda_stream()>>>(n, alpha, x, y);
}
template <>
@@ -527,18 +721,42 @@
template <>
void Axpy<double, CUDAContext>(
const int N,
- const double alpha,
+ const float alpha,
const double* X,
double* Y,
CUDAContext* context) {
- CUBLAS_ENFORCE(cublasDaxpy(context->cublas_handle(), N, &alpha, X, 1, Y, 1));
+ double alpha_d{alpha};
+ CUBLAS_ENFORCE(
+ cublasDaxpy(context->cublas_handle(), N, &alpha_d, X, 1, Y, 1));
+}
+
+template <>
+void Axpy<float16, CUDAContext>(
+ const int N,
+ const float alpha,
+ const float16* X,
+ float16* Y,
+ CUDAContext* context) {
+ CUBLAS_CHECK(cublasAxpyEx(
+ context->cublas_handle(),
+ N,
+ &alpha,
+ CUDA_R_16F,
+ X,
+ CUDA_R_16F,
+ 1,
+ Y,
+ CUDA_R_16F,
+ 1,
+ CUDA_R_32F));
}
namespace {
template <typename T>
-__global__ void AxpyKernel(const int n, const T* a, const T* x, T* y) {
+__global__ void AxpyKernel(const int n, const float* a, const T* x, T* y) {
CUDA_1D_KERNEL_LOOP(index, n) {
- y[index] += x[index] * (*a);
+ y[index] = convert::Get<T>(
+ convert::Get<float>(x[index]) * (*a) + convert::Get<float>(y[index]));
}
}
} // namespace
@@ -552,14 +770,19 @@
}
template <>
-void Axpy<double, CUDAContext>(
- const int n, const double* alpha, const double* X,
- double* Y, CUDAContext* context) {
- AxpyKernel<double><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS,
- 0, context->cuda_stream()>>>(n, alpha, X, Y);
+void Axpy<float16, CUDAContext>(
+ const int n,
+ const float* alpha,
+ const float16* X,
+ float16* Y,
+ CUDAContext* context) {
+ AxpyKernel<float16><<<
+ CAFFE_GET_BLOCKS(n),
+ CAFFE_CUDA_NUM_THREADS,
+ 0,
+ context->cuda_stream()>>>(n, alpha, X, Y);
}
-
namespace {
template <typename T>
__global__ void AxpbyKernel(const int n, const T a, const T* x,
@@ -578,14 +801,6 @@
0, context->cuda_stream()>>>(n, a, x, b, y);
}
-template <>
-void Axpby<double, CUDAContext>(
- const int n, const double a, const double* x, const double b, double* y,
- CUDAContext* context) {
- AxpbyKernel<double><<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS,
- 0, context->cuda_stream()>>>(n, a, x, b, y);
-}
-
namespace {
template <typename T>
diff --git a/caffe2/utils/math_gpu_test.cc b/caffe2/utils/math_gpu_test.cc
index 2ceeddd..b1f930b 100644
--- a/caffe2/utils/math_gpu_test.cc
+++ b/caffe2/utils/math_gpu_test.cc
@@ -67,61 +67,4 @@
}
}
-#define TEST_GEMV_WITH_TYPE(field_name) \
- TEST(MathUtilGPUTest, testGemv_##field_name) { \
- if (!HasCudaGPU()) \
- return; \
- Workspace ws; \
- DeviceOption option; \
- option.set_device_type(CUDA); \
- CUDAContext context(option); \
- Blob* blobx = ws.CreateBlob("X"); \
- Blob* bloby = ws.CreateBlob("Y"); \
- Blob* blobz = ws.CreateBlob("Z"); \
- Blob* bloby_host = ws.CreateBlob("Y_host"); \
- \
- vector<int> shapex{64, 128}; \
- vector<int> shapey{64}; \
- vector<int> shapez{128}; \
- \
- auto* tensorx = blobx->GetMutable<Tensor<CUDAContext>>(); \
- tensorx->Resize(shapex); \
- math::Set<field_name, CUDAContext>( \
- 64 * 128, \
- (field_name)1.0, \
- tensorx->mutable_data<field_name>(), \
- &context); \
- \
- auto* tensory = bloby->GetMutable<Tensor<CUDAContext>>(); \
- tensory->Resize(shapey); \
- math::Set<field_name, CUDAContext>( \
- 64, (field_name)1.0, tensory->mutable_data<field_name>(), &context); \
- \
- auto* tensorz = blobz->GetMutable<Tensor<CUDAContext>>(); \
- tensorz->Resize(shapez); \
- \
- math::Gemv<field_name, CUDAContext>( \
- CblasTrans, \
- 64, \
- 128, \
- 1.0, \
- tensorx->template data<field_name>(), \
- tensory->mutable_data<field_name>(), \
- 0.0, \
- tensorz->template mutable_data<field_name>(), \
- &context); \
- context.FinishDeviceComputation(); \
- \
- auto* tensory_host = bloby_host->GetMutable<Tensor<CPUContext>>(); \
- tensory_host->CopyFrom<CUDAContext, CUDAContext>(*tensorz, &context); \
- context.FinishDeviceComputation(); \
- \
- for (int i = 0; i < 128; i++) { \
- EXPECT_EQ(tensory_host->data<field_name>()[i], 64.0); \
- } \
- }
-
-TEST_GEMV_WITH_TYPE(float);
-TEST_GEMV_WITH_TYPE(double);
-
} // namespace caffe2