Extract some utility operators into their own source files to reduce build size.
Summary: Extract some operators from utility_ops and normalize_op into their own source files to reduce the build-size impact of depending on these files.
Reviewed By: Maratyszcza
Differential Revision: D6616741
fbshipit-source-id: 1757b6b8a3ce4e2a248deee61322344e5095e940
diff --git a/caffe2/operators/flatten_op.cc b/caffe2/operators/flatten_op.cc
new file mode 100644
index 0000000..48f6573
--- /dev/null
+++ b/caffe2/operators/flatten_op.cc
@@ -0,0 +1,77 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "caffe2/operators/flatten_op.h"
+
+namespace caffe2 {
+
+REGISTER_CPU_OPERATOR(Flatten, FlattenOp<CPUContext>);
+
+OPERATOR_SCHEMA(Flatten)
+ .NumInputs(1)
+ .NumOutputs(1)
+ .TensorInferenceFunction([](const OperatorDef& def,
+ const vector<TensorShape>& in) {
+ ArgumentHelper helper(def);
+ const int axis = helper.GetSingleArgument<int>("axis", 1);
+ vector<TensorShape> out(1);
+ TIndex outer = 1;
+ TIndex inner = 1;
+ std::size_t index = 0;
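+      // Dims before axis fold into "outer"; the rest fold into "inner".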
+ for (auto d : in[0].dims()) {
+      if (index < static_cast<std::size_t>(axis)) {
+ outer *= d;
+ } else {
+ inner *= d;
+ }
+ ++index;
+ }
+ out[0].set_data_type(in[0].data_type());
+ out[0].add_dims(outer);
+ out[0].add_dims(inner);
+ return out;
+ })
+ .SetDoc(R"DOC(
+Flattens the input tensor into a 2D matrix. If the input tensor has shape
+(d_0, d_1, ..., d_n) then the output will have shape
+(d_0 X d_1 ... X d_(axis-1), d_axis X d_(axis+1) ... X d_n).
+)DOC")
+ .Input(0, "input", "A tensor of rank >= axis.")
+ .Output(
+ 0,
+ "output",
+ "A 2D tensor with the contents of the input tensor, "
+ "with input dimensions up to axis flattened to the outer dimension "
+ "of the output and remaining input dimensions flattened into the inner "
+ "dimension of the output.")
+    .Arg(
+        "axis",
+        "(Defaults to 1) All input dimensions up to axis (exclusive) are "
+        "flattened into the outer dimension of the output");
+
+class GetFlattenGradient : public GradientMakerBase {
+ using GradientMakerBase::GradientMakerBase;
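+  // Flatten is a reshape: the gradient is GO(0) resized to the shape of I(0).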
+ vector<OperatorDef> GetGradientDefs() override {
+ return SingleGradientDef(
+ "ResizeLike", "", vector<string>{GO(0), I(0)}, vector<string>{GI(0)});
+ }
+};
+
+REGISTER_GRADIENT(Flatten, GetFlattenGradient);
+
+} // namespace caffe2
diff --git a/caffe2/operators/flatten_op.h b/caffe2/operators/flatten_op.h
new file mode 100644
index 0000000..db0a1b7
--- /dev/null
+++ b/caffe2/operators/flatten_op.h
@@ -0,0 +1,54 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef CAFFE2_OPERATORS_FLATTEN_OP_H_
+#define CAFFE2_OPERATORS_FLATTEN_OP_H_
+
+#include "caffe2/core/operator.h"
+
+namespace caffe2 {
+
+template <class Context>
+class FlattenOp : public Operator<Context> {
+ public:
+ USE_OPERATOR_CONTEXT_FUNCTIONS;
+
+ FlattenOp(const OperatorDef& operator_def, Workspace* ws)
+ : Operator<Context>(operator_def, ws),
+ axis_(OperatorBase::GetSingleArgument<int>("axis", 1)) {}
+
+ bool RunOnDevice() override {
+ auto& input = Input(0);
+ auto* output = Output(0);
+ CAFFE_ENFORCE_GE(
+ input.dims().size(), axis_, "The rank of the tensor must be >= axis.");
+ output->Resize(input.size_to_dim(axis_), input.size_from_dim(axis_));
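+    // Copy the data unchanged; flattening only alters the shape metadata.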
+ context_.template CopyItems<Context, Context>(
+ input.meta(),
+ input.size(),
+ input.raw_data(),
+ output->raw_mutable_data(input.meta()));
+ return true;
+ }
+
+ private:
+ int axis_;
+};
+
+} // namespace caffe2
+
+#endif // CAFFE2_OPERATORS_FLATTEN_OP_H_
diff --git a/caffe2/operators/minmax_gradient_ops.cc b/caffe2/operators/minmax_gradient_ops.cc
new file mode 100644
index 0000000..4eb091b
--- /dev/null
+++ b/caffe2/operators/minmax_gradient_ops.cc
@@ -0,0 +1,85 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "caffe2/operators/minmax_ops.h"
+
+namespace caffe2 {
+
+REGISTER_CPU_OPERATOR(MaxGradient, MaxGradientOp<float, CPUContext>);
+REGISTER_CPU_OPERATOR(MinGradient, MinGradientOp<float, CPUContext>);
+
+OPERATOR_SCHEMA(MaxGradient).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX);
+OPERATOR_SCHEMA(MinGradient).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX);
+
+template <typename T, class Context>
+bool SelectGradientOpBase<T, Context>::RunOnDevice() {
+ auto& output = Input(0);
+ auto& grad_output = Input(1);
+ const int kInputStartOffset = 2;
+
+ ConstEigenArrayMap<T> output_array(
+ output.template data<T>(), 1, output.size());
+ ConstEigenArrayMap<T> grad_out_array(
+ grad_output.template data<T>(), 1, grad_output.size());
+
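+  // Inputs are (op output, output gradient, original inputs...). Pass the
+  // gradient through only where an input element equals the selected output.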
+ for (int i = 0; i < OutputSize(); i++) {
+ auto& input = Input(i + kInputStartOffset);
+ ConstEigenArrayMap<T> input_array(
+ input.template data<T>(), 1, input.size());
+
+ auto* grad_input = Output(i);
+ grad_input->ResizeLike(input);
+ EigenArrayMap<T> grad_in_array(
+ grad_input->template mutable_data<T>(), 1, grad_input->size());
+ grad_in_array = grad_out_array *
+ input_array.cwiseEqual(output_array).template cast<T>();
+ }
+ return true;
+}
+
+class GetMaxGradient : public GradientMakerBase {
+ using GradientMakerBase::GradientMakerBase;
+ vector<OperatorDef> GetGradientDefs() override {
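+    // The gradient op consumes (max output, d(max output), all original
+    // inputs) and emits one gradient per input; GetMinGradient is identical.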
+ auto gradInputs = vector<string>();
+ auto inputs = vector<string>{O(0), GO(0)};
+ for (int i = 0; i < def_.input_size(); i++) {
+ gradInputs.push_back(GI(i));
+ inputs.push_back(I(i));
+ }
+ return SingleGradientDef("MaxGradient", "", inputs, gradInputs);
+ }
+};
+REGISTER_GRADIENT(Max, GetMaxGradient);
+
+class GetMinGradient : public GradientMakerBase {
+ using GradientMakerBase::GradientMakerBase;
+ vector<OperatorDef> GetGradientDefs() override {
+ auto gradInputs = vector<string>();
+ auto inputs = vector<string>{O(0), GO(0)};
+ for (int i = 0; i < def_.input_size(); i++) {
+ gradInputs.push_back(GI(i));
+ inputs.push_back(I(i));
+ }
+ return SingleGradientDef("MinGradient", "", inputs, gradInputs);
+ }
+};
+REGISTER_GRADIENT(Min, GetMinGradient);
+
+} // namespace caffe2
diff --git a/caffe2/operators/minmax_ops.cc b/caffe2/operators/minmax_ops.cc
new file mode 100644
index 0000000..4f47ce4
--- /dev/null
+++ b/caffe2/operators/minmax_ops.cc
@@ -0,0 +1,85 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "caffe2/operators/minmax_ops.h"
+
+namespace caffe2 {
+
+REGISTER_CPU_OPERATOR(Max, MaxOp<float, CPUContext>);
+REGISTER_CPU_OPERATOR(Min, MinOp<float, CPUContext>);
+
+OPERATOR_SCHEMA(Max)
+ .NumInputs(1, INT_MAX)
+ .NumOutputs(1)
+ .IdenticalTypeAndShapeOfInput(0)
+ .AllowInplace({{0, 0}})
+ .SetDoc(R"DOC(
+Element-wise max of each of the input tensors. The first input tensor can be
+used in-place as the output tensor, in which case the max will be done in
+place and results will be accumulated in input0. All inputs and outputs must
+have the same shape and data type.
+)DOC")
+    .Input(0, "data_0", "First of the input tensors. Can be in-place.")
+ .Output(0, "max", "Output tensor. Same dimension as inputs.");
+
+OPERATOR_SCHEMA(Min)
+ .NumInputs(1, INT_MAX)
+ .NumOutputs(1)
+ .IdenticalTypeAndShapeOfInput(0)
+ .AllowInplace({{0, 0}})
+ .SetDoc(R"DOC(
+Element-wise min of each of the input tensors. The first input tensor can be
+used in-place as the output tensor, in which case the min will be done in
+place and results will be accumulated in input0. All inputs and outputs must
+have the same shape and data type.
+)DOC")
+    .Input(0, "data_0", "First of the input tensors. Can be in-place.")
+ .Output(0, "min", "Output tensor. Same dimension as inputs.");
+
+template <typename T, class Context>
+bool MaxOp<T, Context>::Compute() {
+ auto& input0 = Input(0);
+ const int N = input0.size();
+ T* output_data = Output(0)->template mutable_data<T>();
+
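+  // Output(0) already holds a copy of Input(0), written by
+  // MaxMinOpBase::RunOnDevice, so fold in the remaining inputs from index 1.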
+ for (int i = 1; i < InputSize(); i++) {
+ auto input_data = Input(i).template data<T>();
+ EigenVectorMap<T> output_vec(output_data, N);
+ output_vec = output_vec.cwiseMax(ConstEigenVectorMap<T>(input_data, N));
+ }
+
+ return true;
+}
+
+template <typename T, class Context>
+bool MinOp<T, Context>::Compute() {
+ auto& input0 = Input(0);
+ const int N = input0.size();
+ T* output_data = Output(0)->template mutable_data<T>();
+
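+  // As in MaxOp::Compute, Output(0) already holds Input(0); fold in the rest.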
+ for (int i = 1; i < InputSize(); i++) {
+ auto input_data = Input(i).template data<T>();
+ EigenVectorMap<T> output_vec(output_data, N);
+ output_vec = output_vec.cwiseMin(ConstEigenVectorMap<T>(input_data, N));
+ }
+
+ return true;
+}
+
+} // namespace caffe2
diff --git a/caffe2/operators/minmax_ops.h b/caffe2/operators/minmax_ops.h
new file mode 100644
index 0000000..c4cc642
--- /dev/null
+++ b/caffe2/operators/minmax_ops.h
@@ -0,0 +1,113 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef CAFFE2_OPERATORS_MINMAX_OPS_H_
+#define CAFFE2_OPERATORS_MINMAX_OPS_H_
+
+#include "caffe2/core/common_omp.h"
+#include "caffe2/core/context.h"
+#include "caffe2/core/logging.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/core/types.h"
+#include "caffe2/utils/math.h"
+
+namespace caffe2 {
+
+template <typename T, class Context>
+class MaxMinOpBase : public Operator<Context> {
+ public:
+ USE_OPERATOR_CONTEXT_FUNCTIONS;
+ USE_SIMPLE_CTOR_DTOR(MaxMinOpBase)
+
+ bool RunOnDevice() override {
+ auto& input0 = Input(0);
+ auto* output = Output(0);
+
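+    // Seed the output with the first input; Compute() folds in the rest.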
+ output->ResizeLike(input0);
+ output->CopyFrom(input0, &context_);
+
+ if (InputSize() == 1) {
+ return true;
+ }
+
+ // Dimension checking
+ for (int i = 1; i < InputSize(); ++i) {
+ CAFFE_ENFORCE_EQ(
+ output->dims(),
+ Input(i).dims(),
+ "Description: Input #",
+ i,
+ ", input dimension:",
+ Input(i).dims(),
+ " should match output dimension: ",
+ output->dims());
+ }
+
+ return this->Compute();
+ }
+
+ virtual bool Compute() = 0;
+};
+
+template <typename T, class Context>
+class MaxOp : public MaxMinOpBase<T, Context> {
+ public:
+ USE_OPERATOR_CONTEXT_FUNCTIONS;
+ MaxOp(const OperatorDef& operator_def, Workspace* ws)
+ : MaxMinOpBase<T, Context>(operator_def, ws) {}
+ virtual ~MaxOp() noexcept {}
+ bool Compute() override;
+};
+
+template <typename T, class Context>
+class SelectGradientOpBase : public Operator<Context> {
+ public:
+ USE_OPERATOR_CONTEXT_FUNCTIONS;
+ USE_SIMPLE_CTOR_DTOR(SelectGradientOpBase)
+
+ bool RunOnDevice() override;
+};
+
+template <typename T, class Context>
+class MaxGradientOp : public SelectGradientOpBase<T, Context> {
+ public:
+ MaxGradientOp(const OperatorDef& operator_def, Workspace* ws)
+ : SelectGradientOpBase<T, Context>(operator_def, ws) {}
+ virtual ~MaxGradientOp() noexcept {}
+};
+
+template <typename T, class Context>
+class MinOp : public MaxMinOpBase<T, Context> {
+ public:
+ USE_OPERATOR_CONTEXT_FUNCTIONS;
+ MinOp(const OperatorDef& operator_def, Workspace* ws)
+ : MaxMinOpBase<T, Context>(operator_def, ws) {}
+ virtual ~MinOp() noexcept {}
+ bool Compute() override;
+};
+
+template <typename T, class Context>
+class MinGradientOp : public SelectGradientOpBase<T, Context> {
+ public:
+ MinGradientOp(const OperatorDef& operator_def, Workspace* ws)
+ : SelectGradientOpBase<T, Context>(operator_def, ws) {}
+ virtual ~MinGradientOp() noexcept {}
+};
+
+} // namespace caffe2
+
+#endif // CAFFE2_OPERATORS_MINMAX_OPS_H_
diff --git a/caffe2/operators/normalize_l1_op.cc b/caffe2/operators/normalize_l1_op.cc
new file mode 100644
index 0000000..13d9ee4
--- /dev/null
+++ b/caffe2/operators/normalize_l1_op.cc
@@ -0,0 +1,58 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "caffe2/operators/normalize_l1_op.h"
+
+#include "caffe2/core/tensor.h"
+
+namespace caffe2 {
+
+template <typename T, class Context>
+void NormalizeL1Op<T, Context>::DoNormalize(
+ const T* xData,
+ T* yData,
+ const int m,
+ const int n,
+ const int sf) {
+ using InnerStride = Eigen::InnerStride<Eigen::Dynamic>;
+ using StridedVec =
+ Eigen::Map<Eigen::Matrix<T, 1, Eigen::Dynamic>, 0, InnerStride>;
+ using ConstStridedVec =
+ Eigen::Map<const Eigen::Matrix<T, 1, Eigen::Dynamic>, 0, InnerStride>;
+
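+  // Normalize each of the n slices of length m; consecutive elements of a
+  // slice are sf apart in memory (the stride across the normalized axis).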
+ for (int i = 0; i < n; ++i) {
+ auto base = (i / sf) * sf * m + (i % sf);
+ ConstStridedVec xVec(xData + base, 1, m, InnerStride(sf));
+ auto norm = xVec.template lpNorm<1>();
+ if (norm != 0) {
+ StridedVec yVec(yData + base, 1, m, InnerStride(sf));
+ yVec = xVec / norm;
+ }
+ }
+}
+
+REGISTER_CPU_OPERATOR(NormalizeL1, NormalizeL1Op<float, CPUContext>);
+OPERATOR_SCHEMA(NormalizeL1)
+ .NumInputs(1)
+ .NumOutputs(1)
+ .Arg("axis", "axis to normalize")
+ .SetDoc(R"DOC(
+Given a matrix, apply L1-normalization along the specified axis.
+)DOC");
+
+} // namespace caffe2
diff --git a/caffe2/operators/normalize_l1_op.h b/caffe2/operators/normalize_l1_op.h
new file mode 100644
index 0000000..bffff04
--- /dev/null
+++ b/caffe2/operators/normalize_l1_op.h
@@ -0,0 +1,57 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef CAFFE2_OPERATORS_NORMALIZE_L1_OP_H_
+#define CAFFE2_OPERATORS_NORMALIZE_L1_OP_H_
+
+#include "caffe2/core/context.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/utils/math.h"
+
+namespace caffe2 {
+
+template <typename T, class Context>
+class NormalizeL1Op final : public Operator<Context> {
+ public:
+ USE_OPERATOR_CONTEXT_FUNCTIONS;
+ USE_SIMPLE_CTOR_DTOR(NormalizeL1Op)
+
+ bool RunOnDevice() override {
+ const auto& x = Input(0);
+ auto* y = Output(0);
+ const auto* xData = x.template data<T>();
+ y->ResizeLike(x);
+ auto* yData = y->template mutable_data<T>();
+
+ const auto canonical_axis = x.canonical_axis_index(
+ OperatorBase::GetSingleArgument<int>("axis", -1));
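+    // m: length along the normalized axis; n: number of slices;
+    // sf: stride between consecutive elements of a slice.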
+ const int m = x.dim32(canonical_axis);
+ const int n = x.size() / m;
+ const int sf = x.size_from_dim(canonical_axis + 1);
+ DoNormalize(xData, yData, m, n, sf);
+ return true;
+ }
+
+ private:
+ void
+ DoNormalize(const T* xData, T* yData, const int m, const int n, const int sf);
+};
+
+} // namespace caffe2
+
+#endif // CAFFE2_OPERATORS_NORMALIZE_L1_OP_H_
diff --git a/caffe2/operators/normalize_op.cc b/caffe2/operators/normalize_op.cc
index f75be3e..dcd010d 100644
--- a/caffe2/operators/normalize_op.cc
+++ b/caffe2/operators/normalize_op.cc
@@ -17,7 +17,6 @@
#include "caffe2/operators/normalize_op.h"
#include "caffe2/core/tensor.h"
-#include "caffe2/utils/math.h"
namespace caffe2 {
@@ -105,37 +104,4 @@
};
REGISTER_GRADIENT(Normalize, GetNormalizeGradient);
-template <typename T, class Context>
-void NormalizeL1Op<T, Context>::DoNormalize(
- const T* xData,
- T* yData,
- const int m,
- const int n,
- const int sf) {
- using InnerStride = Eigen::InnerStride<Eigen::Dynamic>;
- using StridedVec =
- Eigen::Map<Eigen::Matrix<T, 1, Eigen::Dynamic>, 0, InnerStride>;
- using ConstStridedVec =
- Eigen::Map<const Eigen::Matrix<T, 1, Eigen::Dynamic>, 0, InnerStride>;
-
- for (int i = 0; i < n; ++i) {
- auto base = (i / sf) * sf * m + (i % sf);
- ConstStridedVec xVec(xData + base, 1, m, InnerStride(sf));
- auto norm = xVec.template lpNorm<1>();
- if (norm != 0) {
- StridedVec yVec(yData + base, 1, m, InnerStride(sf));
- yVec = xVec / norm;
- }
- }
-};
-
-REGISTER_CPU_OPERATOR(NormalizeL1, NormalizeL1Op<float, CPUContext>);
-OPERATOR_SCHEMA(NormalizeL1)
- .NumInputs(1)
- .NumOutputs(1)
- .Arg("axis", "axis to normalize")
- .SetDoc(R"DOC(
-Given a matrix, apply L1-normalization along the specified axis.
-)DOC");
-
} // namespace caffe2
diff --git a/caffe2/operators/normalize_op.h b/caffe2/operators/normalize_op.h
index 11ffb14..7e92d5a 100644
--- a/caffe2/operators/normalize_op.h
+++ b/caffe2/operators/normalize_op.h
@@ -90,33 +90,6 @@
OUTPUT_TAGS(GRAD_IN);
};
-template <typename T, class Context>
-class NormalizeL1Op final : public Operator<Context> {
- public:
- USE_OPERATOR_CONTEXT_FUNCTIONS;
- USE_SIMPLE_CTOR_DTOR(NormalizeL1Op)
-
- bool RunOnDevice() override {
- const auto& x = Input(0);
- auto* y = Output(0);
- const auto* xData = x.template data<T>();
- y->ResizeLike(x);
- auto* yData = y->template mutable_data<T>();
-
- const auto canonical_axis = x.canonical_axis_index(
- OperatorBase::GetSingleArgument<int>("axis", -1));
- const int m = x.dim32(canonical_axis);
- const int n = x.size() / m;
- const int sf = x.size_from_dim(canonical_axis + 1);
- DoNormalize(xData, yData, m, n, sf);
- return true;
- }
-
- private:
- void
- DoNormalize(const T* xData, T* yData, const int m, const int n, const int sf);
-};
-
} // namespace caffe2
#endif // CAFFE2_OPERATORS_NORMALIZE_OP_H_
diff --git a/caffe2/operators/normalize_op.cu b/caffe2/operators/normalize_ops.cu
similarity index 98%
rename from caffe2/operators/normalize_op.cu
rename to caffe2/operators/normalize_ops.cu
index a97bba5..9fb1926 100644
--- a/caffe2/operators/normalize_op.cu
+++ b/caffe2/operators/normalize_ops.cu
@@ -17,6 +17,7 @@
#include <cub/block/block_reduce.cuh>
#include "caffe2/core/context_gpu.h"
+#include "caffe2/operators/normalize_l1_op.h"
#include "caffe2/operators/normalize_op.h"
namespace caffe2 {
@@ -183,4 +184,4 @@
NormalizeGradient,
NormalizeGradientOp<float, CUDAContext>);
REGISTER_CUDA_OPERATOR(NormalizeL1, NormalizeL1Op<float, CUDAContext>);
-} // namespace
+} // namespace caffe2
diff --git a/caffe2/operators/utility_ops.cc b/caffe2/operators/utility_ops.cc
index ea097a8..8ac5d17 100644
--- a/caffe2/operators/utility_ops.cc
+++ b/caffe2/operators/utility_ops.cc
@@ -74,9 +74,7 @@
REGISTER_CPU_OPERATOR(WallClockTime, WallClockTimeOp<CPUContext>);
REGISTER_CPU_OPERATOR(Print, PrintOp<CPUContext>);
-REGISTER_CPU_OPERATOR(Flatten, FlattenOp<CPUContext>);
REGISTER_CPU_OPERATOR(FlattenToVec, FlattenToVecOp<CPUContext>);
-
REGISTER_CPU_OPERATOR(Alias, AliasOp<CPUContext>);
REGISTER_CPU_OPERATOR(ResizeLike, ResizeLikeOp<CPUContext>);
REGISTER_CPU_OPERATOR(SumInt, SumOp<CPUContext>);
@@ -85,10 +83,6 @@
REGISTER_CPU_OPERATOR(
ScatterWeightedSum,
ScatterWeightedSumOp<float, CPUContext>);
-REGISTER_CPU_OPERATOR(Max, MaxOp<float, CPUContext>);
-REGISTER_CPU_OPERATOR(MaxGradient, MaxGradientOp<float, CPUContext>);
-REGISTER_CPU_OPERATOR(Min, MinOp<float, CPUContext>);
-REGISTER_CPU_OPERATOR(MinGradient, MinGradientOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(ScatterAssign, ScatterAssignOp<CPUContext>);
// From whatever the current context, ensure the output is TensorCPU
REGISTER_CPU_OPERATOR(
@@ -140,48 +134,6 @@
OPERATOR_SCHEMA(LengthsToShape).NumInputs(1).NumOutputs(1);
-OPERATOR_SCHEMA(Flatten)
- .NumInputs(1)
- .NumOutputs(1)
- .TensorInferenceFunction([](const OperatorDef& def,
- const vector<TensorShape>& in) {
- ArgumentHelper helper(def);
- const int axis = helper.GetSingleArgument<int>("axis", 1);
- vector<TensorShape> out(1);
- TIndex outer = 1;
- TIndex inner = 1;
- std::size_t index = 0;
- for (auto d : in[0].dims()) {
- if (index < axis) {
- outer *= d;
- } else {
- inner *= d;
- }
- ++index;
- }
- out[0].set_data_type(in[0].data_type());
- out[0].add_dims(outer);
- out[0].add_dims(inner);
- return out;
- })
- .SetDoc(R"DOC(
-Flattens the input tensor into a 2D matrix. If input tensor has shape
-(d_0, d_1, ... d_n) then the output will have shape
-(d_0 X d_1 ... d_(axis-1), d_axis X d_(axis+1) ... X dn)
-)DOC")
- .Input(0, "input", "A tensor of rank >= axis.")
- .Output(
- 0,
- "output",
- "A 2D tensor with the contents of the input tensor, "
- "with input dimensions up to axis flattened to the outer dimension "
- "of the output and remaining input dimensions flattened into the inner "
- "dimension of the output.")
- .Arg(
- "axis",
- "(Default to 1) Indicate up to which input dimensions "
- "(exclusive) should be flattened to the outer dimension of the output");
-
OPERATOR_SCHEMA(FlattenToVec)
.NumInputs(1)
.NumOutputs(1)
@@ -319,38 +271,6 @@
.Output(0, "X_0", "Has to be exactly the same tensor as the input 0")
.EnforceInplace({{0, 0}});
-OPERATOR_SCHEMA(Max)
- .NumInputs(1, INT_MAX)
- .NumOutputs(1)
- .IdenticalTypeAndShapeOfInput(0)
- .AllowInplace({{0, 0}})
- .SetDoc(R"DOC(
-Element-wise max of each of the input tensors. The first input tensor can be
-used in-place as the output tensor, in which case the max will be done in
-place and results will be accumulated in input0. All inputs and outputs must
-have the same shape and data type.
-)DOC")
- .Input(0, "data_0", "First of the input tensors. Can be inplace.")
- .Output(0, "max", "Output tensor. Same dimension as inputs.");
-
-OPERATOR_SCHEMA(MaxGradient).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX);
-
-OPERATOR_SCHEMA(Min)
- .NumInputs(1, INT_MAX)
- .NumOutputs(1)
- .IdenticalTypeAndShapeOfInput(0)
- .AllowInplace({{0, 0}})
- .SetDoc(R"DOC(
-Element-wise min of each of the input tensors. The first input tensor can be
-used in-place as the output tensor, in which case the min will be done in
-place and results will be accumulated in input0. All inputs and outputs must
-have the same shape and data type.
-)DOC")
- .Input(0, "data_0", "First of the input tensors. Can be inplace.")
- .Output(0, "min", "Output tensor. Same dimension as inputs.");
-
-OPERATOR_SCHEMA(MinGradient).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX);
-
OPERATOR_SCHEMA(ScatterAssign)
.NumInputs(3)
.NumOutputs(1)
@@ -822,15 +742,6 @@
SHOULD_NOT_DO_GRADIENT(LengthsToShape);
SHOULD_NOT_DO_GRADIENT(UnsafeCoalesce);
-class GetFlattenGradient : public GradientMakerBase {
- using GradientMakerBase::GradientMakerBase;
- vector<OperatorDef> GetGradientDefs() override {
- return SingleGradientDef(
- "ResizeLike", "", vector<string>{GO(0), I(0)}, vector<string>{GI(0)});
- }
-};
-REGISTER_GRADIENT(Flatten, GetFlattenGradient);
-
class GetAliasGradient : public GradientMakerBase {
using GradientMakerBase::GradientMakerBase;
vector<OperatorDef> GetGradientDefs() override {
@@ -883,34 +794,6 @@
};
REGISTER_GRADIENT(WeightedSum, GetWeightedSumGradient);
-class GetMaxGradient : public GradientMakerBase {
- using GradientMakerBase::GradientMakerBase;
- vector<OperatorDef> GetGradientDefs() override {
- auto gradInputs = vector<string>();
- auto inputs = vector<string>{O(0), GO(0)};
- for (int i = 0; i < def_.input_size(); i++) {
- gradInputs.push_back(GI(i));
- inputs.push_back(I(i));
- }
- return SingleGradientDef("MaxGradient", "", inputs, gradInputs);
- }
-};
-REGISTER_GRADIENT(Max, GetMaxGradient);
-
-class GetMinGradient : public GradientMakerBase {
- using GradientMakerBase::GradientMakerBase;
- vector<OperatorDef> GetGradientDefs() override {
- auto gradInputs = vector<string>();
- auto inputs = vector<string>{O(0), GO(0)};
- for (int i = 0; i < def_.input_size(); i++) {
- gradInputs.push_back(GI(i));
- inputs.push_back(I(i));
- }
- return SingleGradientDef("MinGradient", "", inputs, gradInputs);
- }
-};
-REGISTER_GRADIENT(Min, GetMinGradient);
-
class GetGatherGradient : public GradientMakerBase {
using GradientMakerBase::GradientMakerBase;
vector<OperatorDef> GetGradientDefs() override {
@@ -1013,63 +896,6 @@
SHOULD_NOT_DO_GRADIENT(LengthsGather);
SHOULD_NOT_DO_GRADIENT(AccumulateHistogram);
-template <typename T, class Context>
-bool MaxOp<T, Context>::Compute() {
- auto& input0 = Input(0);
- const int N = input0.size();
- T* output_data = Output(0)->template mutable_data<T>();
-
- for (int i = 1; i < InputSize(); i++) {
- auto input_data = Input(i).template data<T>();
- EigenVectorMap<T> output_vec(output_data, N);
- output_vec = output_vec.cwiseMax(ConstEigenVectorMap<T>(input_data, N));
- }
-
- return true;
-}
-
-template <typename T, class Context>
-bool MinOp<T, Context>::Compute() {
- auto& input0 = Input(0);
- const int N = input0.size();
- T* output_data = Output(0)->template mutable_data<T>();
-
- for (int i = 1; i < InputSize(); i++) {
- auto input_data = Input(i).template data<T>();
- EigenVectorMap<T> output_vec(output_data, N);
- output_vec = output_vec.cwiseMin(ConstEigenVectorMap<T>(input_data, N));
- }
-
- return true;
-}
-
-template <typename T, class Context>
-bool SelectGradientOpBase<T, Context>::RunOnDevice() {
- auto& output = Input(0);
- auto& grad_output = Input(1);
- const int kInputStartOffset = 2;
-
- const T* data = output.template data<T>();
- ConstEigenArrayMap<T> output_array(
- output.template data<T>(), 1, output.size());
- ConstEigenArrayMap<T> grad_out_array(
- grad_output.template data<T>(), 1, grad_output.size());
-
- for (int i = 0; i < OutputSize(); i++) {
- auto& input = Input(i + kInputStartOffset);
- ConstEigenArrayMap<T> input_array(
- input.template data<T>(), 1, input.size());
-
- auto* grad_input = Output(i);
- grad_input->ResizeLike(input);
- EigenArrayMap<T> grad_in_array(
- grad_input->template mutable_data<T>(), 1, grad_input->size());
- grad_in_array = grad_out_array *
- input_array.cwiseEqual(output_array).template cast<T>();
- }
- return true;
-}
-
template <>
bool NanCheckOp<CPUContext>::RunOnDevice() {
auto& X = Input(0);
diff --git a/caffe2/operators/utility_ops.cu b/caffe2/operators/utility_ops.cu
index 6172609..1644b7f 100644
--- a/caffe2/operators/utility_ops.cu
+++ b/caffe2/operators/utility_ops.cu
@@ -26,6 +26,8 @@
#include <thrust/system/cuda/execution_policy.h>
#include <thrust/unique.h>
#include "caffe2/core/context_gpu.h"
+#include "flatten_op.h"
+#include "minmax_ops.h"
#include "utility_ops.h"
namespace caffe2 {
@@ -33,7 +35,6 @@
REGISTER_CUDA_OPERATOR(EnsureDense, EnsureDenseOp<CUDAContext>);
-
__global__ void NanCheckKernel(int N, const float* X, bool* result) {
bool has_nan = false;
CUDA_1D_KERNEL_LOOP(i, N) {
@@ -220,13 +221,17 @@
return true;
}
-template<typename T_INDEX>
-__global__ void
-GatherKernel(const float* X, float* Y, const T_INDEX* indices, const int N, const int block_size) {
+template <typename T_INDEX>
+__global__ void GatherKernel(
+ const float* X,
+ float* Y,
+ const T_INDEX* indices,
+ const int N,
+ const int block_size) {
for (int i = blockIdx.x; i < N; i += gridDim.x) {
T_INDEX idx = indices[i];
const float* src_offset = X + idx * block_size;
- float* dst_offset = Y + i * block_size;
+ float* dst_offset = Y + i * block_size;
for (int j = threadIdx.x; j < block_size; j += blockDim.x) {
dst_offset[j] = src_offset[j];
}
@@ -235,7 +240,7 @@
template <>
bool GatherOp<CUDAContext>::RunOnDevice() {
- return DispatchHelper<TensorTypes<int32_t,int64_t>>::call(
+ return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
this, OperatorBase::Input<TensorCUDA>(INDICES));
}
@@ -272,9 +277,7 @@
std::min(N, CAFFE_MAXIMUM_NUM_BLOCKS),
CAFFE_CUDA_NUM_THREADS,
0,
- context_.cuda_stream()>>>(
- src_base, out, idxs, N, block_size
- );
+ context_.cuda_stream()>>>(src_base, out, idxs, N, block_size);
return true;
}
@@ -320,7 +323,7 @@
template <>
template <typename Index>
-bool ScatterWeightedSumOp<float,CUDAContext>::DoRunWithType() {
+bool ScatterWeightedSumOp<float, CUDAContext>::DoRunWithType() {
CAFFE_ENFORCE_EQ(InputSize() % 2, 1);
auto& X0 = Input(0);
auto& weight0 = Input(1);
@@ -551,4 +554,4 @@
}
REGISTER_CUDA_OPERATOR(Range, RangeOp<CUDAContext>);
-} // namespace caffe2
+} // namespace caffe2
diff --git a/caffe2/operators/utility_ops.h b/caffe2/operators/utility_ops.h
index 87a9bfe..dac1855 100644
--- a/caffe2/operators/utility_ops.h
+++ b/caffe2/operators/utility_ops.h
@@ -219,33 +219,6 @@
};
template <class Context>
-class FlattenOp : public Operator<Context> {
- public:
- USE_OPERATOR_CONTEXT_FUNCTIONS;
-
- FlattenOp(const OperatorDef& operator_def, Workspace* ws)
- : Operator<Context>(operator_def, ws),
- axis_(OperatorBase::GetSingleArgument<int>("axis", 1)) {}
-
- bool RunOnDevice() override {
- auto& input = Input(0);
- auto* output = Output(0);
- CAFFE_ENFORCE_GE(
- input.dims().size(), axis_, "The rank of the tensor must be >= axis.");
- output->Resize(input.size_to_dim(axis_), input.size_from_dim(axis_));
- context_.template CopyItems<Context, Context>(
- input.meta(),
- input.size(),
- input.raw_data(),
- output->raw_mutable_data(input.meta()));
- return true;
- }
-
- private:
- int axis_;
-};
-
-template <class Context>
class FlattenToVecOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
@@ -581,87 +554,6 @@
Tensor<Context> weights_device_;
};
-template <typename T, class Context>
-class MaxMinOpBase : public Operator<Context> {
- public:
- USE_OPERATOR_CONTEXT_FUNCTIONS;
- USE_SIMPLE_CTOR_DTOR(MaxMinOpBase)
-
- bool RunOnDevice() override {
- auto& input0 = Input(0);
- auto* output = Output(0);
-
- output->ResizeLike(input0);
- output->CopyFrom(input0, &context_);
-
- if (InputSize() == 1) {
- return true;
- }
-
- // Dimension checking
- for (int i = 1; i < InputSize(); ++i) {
- CAFFE_ENFORCE_EQ(
- output->dims(),
- Input(i).dims(),
- "Description: Input #",
- i,
- ", input dimension:",
- Input(i).dims(),
- " should match output dimension: ",
- output->dims());
- }
-
- return this->Compute();
- }
-
- virtual bool Compute() = 0;
-};
-
-template <typename T, class Context>
-class SelectGradientOpBase : public Operator<Context> {
- public:
- USE_OPERATOR_CONTEXT_FUNCTIONS;
- USE_SIMPLE_CTOR_DTOR(SelectGradientOpBase)
-
- bool RunOnDevice() override;
-};
-
-template <typename T, class Context>
-class MaxOp : public MaxMinOpBase<T, Context> {
- public:
- USE_OPERATOR_CONTEXT_FUNCTIONS;
- MaxOp(const OperatorDef& operator_def, Workspace* ws)
- : MaxMinOpBase<T, Context>(operator_def, ws) {}
- virtual ~MaxOp() noexcept {}
- bool Compute() override;
-};
-
-template <typename T, class Context>
-class MaxGradientOp : public SelectGradientOpBase<T, Context> {
- public:
- MaxGradientOp(const OperatorDef& operator_def, Workspace* ws)
- : SelectGradientOpBase<T, Context>(operator_def, ws) {}
- virtual ~MaxGradientOp() noexcept {}
-};
-
-template <typename T, class Context>
-class MinOp : public MaxMinOpBase<T, Context> {
- public:
- USE_OPERATOR_CONTEXT_FUNCTIONS;
- MinOp(const OperatorDef& operator_def, Workspace* ws)
- : MaxMinOpBase<T, Context>(operator_def, ws) {}
- virtual ~MinOp() noexcept {}
- bool Compute() override;
-};
-
-template <typename T, class Context>
-class MinGradientOp : public SelectGradientOpBase<T, Context> {
- public:
- MinGradientOp(const OperatorDef& operator_def, Workspace* ws)
- : SelectGradientOpBase<T, Context>(operator_def, ws) {}
- virtual ~MinGradientOp() noexcept {}
-};
-
/**
* @brief Update slices of the tensor in-place by overriding.
*
diff --git a/caffe2/operators/utility_ops_gpu.cc b/caffe2/operators/utility_ops_gpu.cc
index 880d032..36d4460 100644
--- a/caffe2/operators/utility_ops_gpu.cc
+++ b/caffe2/operators/utility_ops_gpu.cc
@@ -15,7 +15,7 @@
*/
#include "caffe2/core/context_gpu.h"
-#include "caffe2/operators/reshape_op.h"
+#include "caffe2/operators/flatten_op.h"
#include "caffe2/operators/utility_ops.h"
#include "caffe2/utils/math.h"