Extract some utility operators into their own source files to reduce build size.

Summary: Extract some operators from utility_ops and normalize_op into their own source files, to reduce the build-size impact of depending on these files.

Reviewed By: Maratyszcza

Differential Revision: D6616741

fbshipit-source-id: 1757b6b8a3ce4e2a248deee61322344e5095e940
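
The practical effect of the split: a target that needs only one of these operators no longer has to compile and link everything in utility_ops. A minimal sketch of a dependent translation unit (illustrative only; assumes the usual caffe2 include paths):

    // Illustrative: Flatten is now reachable through its own small header
    // instead of the monolithic utility_ops.h.
    #include "caffe2/operators/flatten_op.h"

    namespace caffe2 {
    // Explicitly instantiating the CPU flattener no longer drags Max/Min,
    // Gather, ScatterAssign, etc. into the dependent target.
    template class FlattenOp<CPUContext>;
    } // namespace caffe2
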
diff --git a/caffe2/operators/flatten_op.cc b/caffe2/operators/flatten_op.cc
new file mode 100644
index 0000000..48f6573
--- /dev/null
+++ b/caffe2/operators/flatten_op.cc
@@ -0,0 +1,75 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "caffe2/operators/flatten_op.h"
+
+namespace caffe2 {
+
+REGISTER_CPU_OPERATOR(Flatten, FlattenOp<CPUContext>);
+
+OPERATOR_SCHEMA(Flatten)
+    .NumInputs(1)
+    .NumOutputs(1)
+    .TensorInferenceFunction([](const OperatorDef& def,
+                                const vector<TensorShape>& in) {
+      ArgumentHelper helper(def);
+      const int axis = helper.GetSingleArgument<int>("axis", 1);
+      vector<TensorShape> out(1);
+      TIndex outer = 1;
+      TIndex inner = 1;
+      int index = 0;
+      for (auto d : in[0].dims()) {
+        if (index < axis) {
+          outer *= d;
+        } else {
+          inner *= d;
+        }
+        ++index;
+      }
+      out[0].set_data_type(in[0].data_type());
+      out[0].add_dims(outer);
+      out[0].add_dims(inner);
+      return out;
+    })
+    .SetDoc(R"DOC(
+Flattens the input tensor into a 2D matrix. If input tensor has shape
+(d_0, d_1, ... d_n) then the output will have shape
+(d_0 X d_1 ... d_(axis-1), d_axis X d_(axis+1) ... X dn)
+)DOC")
+    .Input(0, "input", "A tensor of rank >= axis.")
+    .Output(
+        0,
+        "output",
+        "A 2D tensor with the contents of the input tensor, "
+        "with input dimensions up to axis flattened to the outer dimension "
+        "of the output and remaining input dimensions flattened into the inner "
+        "dimension of the output.")
+    .Arg(
+        "axis",
+        "(Defaults to 1) Input dimensions up to, but excluding, axis are "
+        "flattened into the outer dimension of the output.");
+
+class GetFlattenGradient : public GradientMakerBase {
+  using GradientMakerBase::GradientMakerBase;
+  vector<OperatorDef> GetGradientDefs() override {
+    return SingleGradientDef(
+        "ResizeLike", "", vector<string>{GO(0), I(0)}, vector<string>{GI(0)});
+  }
+};
+
+REGISTER_GRADIENT(Flatten, GetFlattenGradient);
+
+} // namespace caffe2
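
The shape-inference function above just splits dims() at axis into an outer and an inner product. A self-contained sketch of the same arithmetic (plain C++17, no Caffe2 types; FlattenShape is a name invented for this example):

    #include <cstdint>
    #include <iostream>
    #include <utility>
    #include <vector>

    // Dims before `axis` multiply into the outer extent, the rest into the
    // inner extent, mirroring the TensorInferenceFunction above.
    std::pair<int64_t, int64_t> FlattenShape(
        const std::vector<int64_t>& dims, size_t axis) {
      int64_t outer = 1, inner = 1;
      for (size_t i = 0; i < dims.size(); ++i) {
        (i < axis ? outer : inner) *= dims[i];
      }
      return {outer, inner};
    }

    int main() {
      auto [outer, inner] = FlattenShape({2, 3, 4, 5}, 2);
      std::cout << outer << " x " << inner << "\n"; // prints "6 x 20"
    }
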
diff --git a/caffe2/operators/flatten_op.h b/caffe2/operators/flatten_op.h
new file mode 100644
index 0000000..db0a1b7
--- /dev/null
+++ b/caffe2/operators/flatten_op.h
@@ -0,0 +1,53 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef CAFFE2_OPERATORS_FLATTEN_OP_H_
+#define CAFFE2_OPERATORS_FLATTEN_OP_H_
+
+#include "caffe2/core/operator.h"
+
+namespace caffe2 {
+
+template <class Context>
+class FlattenOp : public Operator<Context> {
+ public:
+  USE_OPERATOR_CONTEXT_FUNCTIONS;
+
+  FlattenOp(const OperatorDef& operator_def, Workspace* ws)
+      : Operator<Context>(operator_def, ws),
+        axis_(OperatorBase::GetSingleArgument<int>("axis", 1)) {}
+
+  bool RunOnDevice() override {
+    auto& input = Input(0);
+    auto* output = Output(0);
+    CAFFE_ENFORCE_GE(
+        input.dims().size(), axis_, "The rank of the tensor must be >= axis.");
+    output->Resize(input.size_to_dim(axis_), input.size_from_dim(axis_));
+    context_.template CopyItems<Context, Context>(
+        input.meta(),
+        input.size(),
+        input.raw_data(),
+        output->raw_mutable_data(input.meta()));
+    return true;
+  }
+
+ private:
+  int axis_;
+};
+
+} // namespace caffe2
+
+#endif // CAFFE2_OPERATORS_FLATTEN_OP_H_
diff --git a/caffe2/operators/minmax_gradient_ops.cc b/caffe2/operators/minmax_gradient_ops.cc
new file mode 100644
index 0000000..4eb091b
--- /dev/null
+++ b/caffe2/operators/minmax_gradient_ops.cc
@@ -0,0 +1,82 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "caffe2/operators/minmax_ops.h"
+
+namespace caffe2 {
+
+REGISTER_CPU_OPERATOR(MaxGradient, MaxGradientOp<float, CPUContext>);
+REGISTER_CPU_OPERATOR(MinGradient, MinGradientOp<float, CPUContext>);
+
+OPERATOR_SCHEMA(MaxGradient).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX);
+OPERATOR_SCHEMA(MinGradient).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX);
+
+template <typename T, class Context>
+bool SelectGradientOpBase<T, Context>::RunOnDevice() {
+  auto& output = Input(0);
+  auto& grad_output = Input(1);
+  const int kInputStartOffset = 2;
+
+  // Route each output gradient to the inputs that match the selected value.
+  ConstEigenArrayMap<T> output_array(
+      output.template data<T>(), 1, output.size());
+  ConstEigenArrayMap<T> grad_out_array(
+      grad_output.template data<T>(), 1, grad_output.size());
+
+  for (int i = 0; i < OutputSize(); i++) {
+    auto& input = Input(i + kInputStartOffset);
+    ConstEigenArrayMap<T> input_array(
+        input.template data<T>(), 1, input.size());
+
+    auto* grad_input = Output(i);
+    grad_input->ResizeLike(input);
+    EigenArrayMap<T> grad_in_array(
+        grad_input->template mutable_data<T>(), 1, grad_input->size());
+    grad_in_array = grad_out_array *
+        input_array.cwiseEqual(output_array).template cast<T>();
+  }
+  return true;
+}
+
+class GetMaxGradient : public GradientMakerBase {
+  using GradientMakerBase::GradientMakerBase;
+  vector<OperatorDef> GetGradientDefs() override {
+    auto gradInputs = vector<string>();
+    auto inputs = vector<string>{O(0), GO(0)};
+    for (int i = 0; i < def_.input_size(); i++) {
+      gradInputs.push_back(GI(i));
+      inputs.push_back(I(i));
+    }
+    return SingleGradientDef("MaxGradient", "", inputs, gradInputs);
+  }
+};
+REGISTER_GRADIENT(Max, GetMaxGradient);
+
+class GetMinGradient : public GradientMakerBase {
+  using GradientMakerBase::GradientMakerBase;
+  vector<OperatorDef> GetGradientDefs() override {
+    auto gradInputs = vector<string>();
+    auto inputs = vector<string>{O(0), GO(0)};
+    for (int i = 0; i < def_.input_size(); i++) {
+      gradInputs.push_back(GI(i));
+      inputs.push_back(I(i));
+    }
+    return SingleGradientDef("MinGradient", "", inputs, gradInputs);
+  }
+};
+REGISTER_GRADIENT(Min, GetMinGradient);
+
+} // namespace caffe2
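
The gradient rule implemented by SelectGradientOpBase: each input receives the output gradient exactly where it equals the selected (max or min) output, so tied inputs all receive the full gradient. A self-contained sketch without Eigen (SelectGrad is a name invented here):

    #include <iostream>
    #include <vector>

    std::vector<std::vector<float>> SelectGrad(
        const std::vector<float>& output,
        const std::vector<float>& grad_output,
        const std::vector<std::vector<float>>& inputs) {
      std::vector<std::vector<float>> grads(
          inputs.size(), std::vector<float>(output.size()));
      for (size_t i = 0; i < inputs.size(); ++i) {
        for (size_t j = 0; j < output.size(); ++j) {
          // cwiseEqual(...).cast<T>() in the Eigen version above.
          grads[i][j] =
              grad_output[j] * (inputs[i][j] == output[j] ? 1.0f : 0.0f);
        }
      }
      return grads;
    }

    int main() {
      // Max of {1, 5} and {4, 5} is {4, 5}; the tie at index 1 sends the
      // gradient to both inputs.
      auto g = SelectGrad({4, 5}, {10, 20}, {{1, 5}, {4, 5}});
      std::cout << g[0][0] << ' ' << g[0][1] << '\n'; // 0 20
      std::cout << g[1][0] << ' ' << g[1][1] << '\n'; // 10 20
    }
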
diff --git a/caffe2/operators/minmax_ops.cc b/caffe2/operators/minmax_ops.cc
new file mode 100644
index 0000000..4f47ce4
--- /dev/null
+++ b/caffe2/operators/minmax_ops.cc
@@ -0,0 +1,82 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "caffe2/operators/minmax_ops.h"
+
+namespace caffe2 {
+
+REGISTER_CPU_OPERATOR(Max, MaxOp<float, CPUContext>);
+REGISTER_CPU_OPERATOR(Min, MinOp<float, CPUContext>);
+
+OPERATOR_SCHEMA(Max)
+    .NumInputs(1, INT_MAX)
+    .NumOutputs(1)
+    .IdenticalTypeAndShapeOfInput(0)
+    .AllowInplace({{0, 0}})
+    .SetDoc(R"DOC(
+  Element-wise max of each of the input tensors. The first input tensor can be
+  used in-place as the output tensor, in which case the max will be done in
+  place and results will be accumulated in input0. All inputs and outputs must
+  have the same shape and data type.
+  )DOC")
+    .Input(0, "data_0", "First of the input tensors. Can be inplace.")
+    .Output(0, "max", "Output tensor. Same dimension as inputs.");
+
+OPERATOR_SCHEMA(Min)
+    .NumInputs(1, INT_MAX)
+    .NumOutputs(1)
+    .IdenticalTypeAndShapeOfInput(0)
+    .AllowInplace({{0, 0}})
+    .SetDoc(R"DOC(
+Element-wise min of each of the input tensors. The first input tensor can be
+used in-place as the output tensor, in which case the min will be done in
+place and results will be accumulated in input0. All inputs and outputs must
+have the same shape and data type.
+)DOC")
+    .Input(0, "data_0", "First of the input tensors. Can be inplace.")
+    .Output(0, "min", "Output tensor. Same dimension as inputs.");
+
+template <typename T, class Context>
+bool MaxOp<T, Context>::Compute() {
+  auto& input0 = Input(0);
+  const int N = input0.size();
+  T* output_data = Output(0)->template mutable_data<T>();
+
+  for (int i = 1; i < InputSize(); i++) {
+    auto input_data = Input(i).template data<T>();
+    EigenVectorMap<T> output_vec(output_data, N);
+    output_vec = output_vec.cwiseMax(ConstEigenVectorMap<T>(input_data, N));
+  }
+
+  return true;
+}
+
+template <typename T, class Context>
+bool MinOp<T, Context>::Compute() {
+  auto& input0 = Input(0);
+  const int N = input0.size();
+  T* output_data = Output(0)->template mutable_data<T>();
+
+  for (int i = 1; i < InputSize(); i++) {
+    auto input_data = Input(i).template data<T>();
+    EigenVectorMap<T> output_vec(output_data, N);
+    output_vec = output_vec.cwiseMin(ConstEigenVectorMap<T>(input_data, N));
+  }
+
+  return true;
+}
+
+} // namespace caffe2
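
Compute() above relies on MaxMinOpBase::RunOnDevice having already copied input 0 into the output, then folds the remaining inputs in elementwise. The same accumulation in plain C++ (MaxInto is a name invented here):

    #include <algorithm>
    #include <iostream>
    #include <vector>

    // Fold one more input into the running elementwise maximum.
    void MaxInto(std::vector<float>& out, const std::vector<float>& in) {
      for (size_t i = 0; i < out.size(); ++i) {
        out[i] = std::max(out[i], in[i]);
      }
    }

    int main() {
      std::vector<float> out = {1, 7, 3}; // starts as a copy of input 0
      MaxInto(out, {4, 2, 3});
      MaxInto(out, {0, 9, 5});
      for (float v : out) {
        std::cout << v << ' '; // 4 9 5
      }
    }
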
diff --git a/caffe2/operators/minmax_ops.h b/caffe2/operators/minmax_ops.h
new file mode 100644
index 0000000..c4cc642
--- /dev/null
+++ b/caffe2/operators/minmax_ops.h
@@ -0,0 +1,112 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef CAFFE2_OPERATORS_MINMAX_OPS_H_
+#define CAFFE2_OPERATORS_MINMAX_OPS_H_
+
+#include "caffe2/core/common_omp.h"
+#include "caffe2/core/context.h"
+#include "caffe2/core/logging.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/core/types.h"
+#include "caffe2/utils/math.h"
+
+namespace caffe2 {
+
+template <typename T, class Context>
+class MaxMinOpBase : public Operator<Context> {
+ public:
+  USE_OPERATOR_CONTEXT_FUNCTIONS;
+  USE_SIMPLE_CTOR_DTOR(MaxMinOpBase)
+
+  bool RunOnDevice() override {
+    auto& input0 = Input(0);
+    auto* output = Output(0);
+
+    output->ResizeLike(input0);
+    output->CopyFrom(input0, &context_);
+
+    if (InputSize() == 1) {
+      return true;
+    }
+
+    // Dimension checking
+    for (int i = 1; i < InputSize(); ++i) {
+      CAFFE_ENFORCE_EQ(
+          output->dims(),
+          Input(i).dims(),
+          "Description: Input #",
+          i,
+          ", input dimension:",
+          Input(i).dims(),
+          " should match output dimension: ",
+          output->dims());
+    }
+
+    return this->Compute();
+  }
+
+  virtual bool Compute() = 0;
+};
+
+template <typename T, class Context>
+class MaxOp : public MaxMinOpBase<T, Context> {
+ public:
+  USE_OPERATOR_CONTEXT_FUNCTIONS;
+  MaxOp(const OperatorDef& operator_def, Workspace* ws)
+      : MaxMinOpBase<T, Context>(operator_def, ws) {}
+  virtual ~MaxOp() noexcept {}
+  bool Compute() override;
+};
+
+template <typename T, class Context>
+class SelectGradientOpBase : public Operator<Context> {
+ public:
+  USE_OPERATOR_CONTEXT_FUNCTIONS;
+  USE_SIMPLE_CTOR_DTOR(SelectGradientOpBase)
+
+  bool RunOnDevice() override;
+};
+
+template <typename T, class Context>
+class MaxGradientOp : public SelectGradientOpBase<T, Context> {
+ public:
+  MaxGradientOp(const OperatorDef& operator_def, Workspace* ws)
+      : SelectGradientOpBase<T, Context>(operator_def, ws) {}
+  virtual ~MaxGradientOp() noexcept {}
+};
+
+template <typename T, class Context>
+class MinOp : public MaxMinOpBase<T, Context> {
+ public:
+  USE_OPERATOR_CONTEXT_FUNCTIONS;
+  MinOp(const OperatorDef& operator_def, Workspace* ws)
+      : MaxMinOpBase<T, Context>(operator_def, ws) {}
+  virtual ~MinOp() noexcept {}
+  bool Compute() override;
+};
+
+template <typename T, class Context>
+class MinGradientOp : public SelectGradientOpBase<T, Context> {
+ public:
+  MinGradientOp(const OperatorDef& operator_def, Workspace* ws)
+      : SelectGradientOpBase<T, Context>(operator_def, ws) {}
+  virtual ~MinGradientOp() noexcept {}
+};
+
+} // namespace caffe2
+
+#endif // CAFFE2_OPERATORS_MINMAX_OPS_H_
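
MaxMinOpBase is a small template-method split: the base class owns the copy and the shape checks, and Max/Min supply only the folding step. A standalone sketch of that structure (class names invented for this example):

    #include <algorithm>
    #include <iostream>
    #include <vector>

    struct ReduceBase {
      virtual ~ReduceBase() = default;
      // Owns the copy-from-input-0 and the accumulation loop, like
      // MaxMinOpBase::RunOnDevice above (shape checks omitted).
      std::vector<float> Run(const std::vector<std::vector<float>>& in) {
        std::vector<float> out = in[0];
        for (size_t i = 1; i < in.size(); ++i) {
          for (size_t j = 0; j < out.size(); ++j) {
            out[j] = Fold(out[j], in[i][j]);
          }
        }
        return out;
      }
      virtual float Fold(float a, float b) = 0;
    };

    struct MaxReduce : ReduceBase {
      float Fold(float a, float b) override { return std::max(a, b); }
    };

    struct MinReduce : ReduceBase {
      float Fold(float a, float b) override { return std::min(a, b); }
    };

    int main() {
      std::cout << MaxReduce().Run({{1, 7}, {4, 2}})[0] << '\n'; // 4
      std::cout << MinReduce().Run({{1, 7}, {4, 2}})[1] << '\n'; // 2
    }
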
diff --git a/caffe2/operators/normalize_l1_op.cc b/caffe2/operators/normalize_l1_op.cc
new file mode 100644
index 0000000..13d9ee4
--- /dev/null
+++ b/caffe2/operators/normalize_l1_op.cc
@@ -0,0 +1,56 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "caffe2/operators/normalize_l1_op.h"
+
+#include "caffe2/core/tensor.h"
+
+namespace caffe2 {
+
+template <typename T, class Context>
+void NormalizeL1Op<T, Context>::DoNormalize(
+    const T* xData,
+    T* yData,
+    const int m,
+    const int n,
+    const int sf) {
+  using InnerStride = Eigen::InnerStride<Eigen::Dynamic>;
+  using StridedVec =
+      Eigen::Map<Eigen::Matrix<T, 1, Eigen::Dynamic>, 0, InnerStride>;
+  using ConstStridedVec =
+      Eigen::Map<const Eigen::Matrix<T, 1, Eigen::Dynamic>, 0, InnerStride>;
+
+  for (int i = 0; i < n; ++i) {
+    auto base = (i / sf) * sf * m + (i % sf);
+    ConstStridedVec xVec(xData + base, 1, m, InnerStride(sf));
+    auto norm = xVec.template lpNorm<1>();
+    if (norm != 0) {
+      StridedVec yVec(yData + base, 1, m, InnerStride(sf));
+      yVec = xVec / norm;
+    }
+  }
+}
+
+REGISTER_CPU_OPERATOR(NormalizeL1, NormalizeL1Op<float, CPUContext>);
+OPERATOR_SCHEMA(NormalizeL1)
+    .NumInputs(1)
+    .NumOutputs(1)
+    .Arg("axis", "axis to normalize")
+    .SetDoc(R"DOC(
+  Given a matrix, apply L1-normalization along the specified axis.
+  )DOC");
+
+} // namespace caffe2
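
The stride bookkeeping in DoNormalize visits n slices of length m: slice i starts at (i / sf) * sf * m + (i % sf) and steps by sf, which walks the normalized axis wherever it sits in the layout. The same normalization without Eigen (NormalizeL1 here is a standalone sketch, not the operator):

    #include <cmath>
    #include <iostream>
    #include <vector>

    void NormalizeL1(const float* x, float* y, int m, int n, int sf) {
      for (int i = 0; i < n; ++i) {
        const int base = (i / sf) * sf * m + (i % sf);
        float norm = 0;
        for (int j = 0; j < m; ++j) {
          norm += std::abs(x[base + j * sf]);
        }
        if (norm != 0) { // all-zero slices are left untouched, as above
          for (int j = 0; j < m; ++j) {
            y[base + j * sf] = x[base + j * sf] / norm;
          }
        }
      }
    }

    int main() {
      // 2x3 row-major matrix, normalized along the last axis:
      // m = 3 (axis length), sf = 1, n = size / m = 2.
      std::vector<float> x = {1, 2, 1, 3, 0, 3}, y(6, 0);
      NormalizeL1(x.data(), y.data(), 3, 2, 1);
      for (float v : y) {
        std::cout << v << ' '; // 0.25 0.5 0.25 0.5 0 0.5
      }
    }
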
diff --git a/caffe2/operators/normalize_l1_op.h b/caffe2/operators/normalize_l1_op.h
new file mode 100644
index 0000000..bffff04
--- /dev/null
+++ b/caffe2/operators/normalize_l1_op.h
@@ -0,0 +1,55 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef CAFFE2_OPERATORS_NORMALIZE_L1_OP_H_
+#define CAFFE2_OPERATORS_NORMALIZE_L1_OP_H_
+
+#include "caffe2/core/context.h"
+#include "caffe2/core/operator.h"
+#include "caffe2/utils/math.h"
+
+namespace caffe2 {
+
+template <typename T, class Context>
+class NormalizeL1Op final : public Operator<Context> {
+ public:
+  USE_OPERATOR_CONTEXT_FUNCTIONS;
+  USE_SIMPLE_CTOR_DTOR(NormalizeL1Op)
+
+  bool RunOnDevice() override {
+    const auto& x = Input(0);
+    auto* y = Output(0);
+    const auto* xData = x.template data<T>();
+    y->ResizeLike(x);
+    auto* yData = y->template mutable_data<T>();
+
+    const auto canonical_axis = x.canonical_axis_index(
+        OperatorBase::GetSingleArgument<int>("axis", -1));
+    const int m = x.dim32(canonical_axis);
+    const int n = x.size() / m;
+    const int sf = x.size_from_dim(canonical_axis + 1);
+    DoNormalize(xData, yData, m, n, sf);
+    return true;
+  }
+
+ private:
+  void
+  DoNormalize(const T* xData, T* yData, const int m, const int n, const int sf);
+};
+
+} // namespace caffe2
+
+#endif // CAFFE2_OPERATORS_NORMALIZE_L1_OP_H_
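
How m, n, and sf fall out of the shape in RunOnDevice: m is the extent of the canonical axis, sf is the product of the dimensions after it, and n = size / m is the number of slices. A sketch of that computation (SliceParams is invented for this example; canonical_axis_index maps -1 to the last axis):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    void SliceParams(std::vector<int64_t> dims, int axis) {
      if (axis < 0) {
        axis += static_cast<int>(dims.size()); // canonical_axis_index
      }
      int64_t size = 1, sf = 1;
      for (int64_t d : dims) {
        size *= d;
      }
      for (size_t i = axis + 1; i < dims.size(); ++i) {
        sf *= dims[i];
      }
      std::cout << "m=" << dims[axis] << " n=" << size / dims[axis]
                << " sf=" << sf << "\n";
    }

    int main() {
      SliceParams({2, 3, 4}, 1);  // m=3 n=8 sf=4
      SliceParams({2, 3, 4}, -1); // m=4 n=6 sf=1
    }
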
diff --git a/caffe2/operators/normalize_op.cc b/caffe2/operators/normalize_op.cc
index f75be3e..dcd010d 100644
--- a/caffe2/operators/normalize_op.cc
+++ b/caffe2/operators/normalize_op.cc
@@ -17,7 +17,6 @@
 #include "caffe2/operators/normalize_op.h"
 
 #include "caffe2/core/tensor.h"
-#include "caffe2/utils/math.h"
 
 namespace caffe2 {
 
@@ -105,37 +104,4 @@
 };
 REGISTER_GRADIENT(Normalize, GetNormalizeGradient);
 
-template <typename T, class Context>
-void NormalizeL1Op<T, Context>::DoNormalize(
-    const T* xData,
-    T* yData,
-    const int m,
-    const int n,
-    const int sf) {
-  using InnerStride = Eigen::InnerStride<Eigen::Dynamic>;
-  using StridedVec =
-      Eigen::Map<Eigen::Matrix<T, 1, Eigen::Dynamic>, 0, InnerStride>;
-  using ConstStridedVec =
-      Eigen::Map<const Eigen::Matrix<T, 1, Eigen::Dynamic>, 0, InnerStride>;
-
-  for (int i = 0; i < n; ++i) {
-    auto base = (i / sf) * sf * m + (i % sf);
-    ConstStridedVec xVec(xData + base, 1, m, InnerStride(sf));
-    auto norm = xVec.template lpNorm<1>();
-    if (norm != 0) {
-      StridedVec yVec(yData + base, 1, m, InnerStride(sf));
-      yVec = xVec / norm;
-    }
-  }
-};
-
-REGISTER_CPU_OPERATOR(NormalizeL1, NormalizeL1Op<float, CPUContext>);
-OPERATOR_SCHEMA(NormalizeL1)
-    .NumInputs(1)
-    .NumOutputs(1)
-    .Arg("axis", "axis to normalize")
-    .SetDoc(R"DOC(
-Given a matrix, apply L1-normalization along the specified axis.
-)DOC");
-
 } // namespace caffe2
diff --git a/caffe2/operators/normalize_op.h b/caffe2/operators/normalize_op.h
index 11ffb14..7e92d5a 100644
--- a/caffe2/operators/normalize_op.h
+++ b/caffe2/operators/normalize_op.h
@@ -90,33 +90,6 @@
   OUTPUT_TAGS(GRAD_IN);
 };
 
-template <typename T, class Context>
-class NormalizeL1Op final : public Operator<Context> {
- public:
-  USE_OPERATOR_CONTEXT_FUNCTIONS;
-  USE_SIMPLE_CTOR_DTOR(NormalizeL1Op)
-
-  bool RunOnDevice() override {
-    const auto& x = Input(0);
-    auto* y = Output(0);
-    const auto* xData = x.template data<T>();
-    y->ResizeLike(x);
-    auto* yData = y->template mutable_data<T>();
-
-    const auto canonical_axis = x.canonical_axis_index(
-        OperatorBase::GetSingleArgument<int>("axis", -1));
-    const int m = x.dim32(canonical_axis);
-    const int n = x.size() / m;
-    const int sf = x.size_from_dim(canonical_axis + 1);
-    DoNormalize(xData, yData, m, n, sf);
-    return true;
-  }
-
- private:
-  void
-  DoNormalize(const T* xData, T* yData, const int m, const int n, const int sf);
-};
-
 } // namespace caffe2
 
 #endif // CAFFE2_OPERATORS_NORMALIZE_OP_H_
diff --git a/caffe2/operators/normalize_op.cu b/caffe2/operators/normalize_ops.cu
similarity index 98%
rename from caffe2/operators/normalize_op.cu
rename to caffe2/operators/normalize_ops.cu
index a97bba5..9fb1926 100644
--- a/caffe2/operators/normalize_op.cu
+++ b/caffe2/operators/normalize_ops.cu
@@ -17,6 +17,7 @@
 #include <cub/block/block_reduce.cuh>
 
 #include "caffe2/core/context_gpu.h"
+#include "caffe2/operators/normalize_l1_op.h"
 #include "caffe2/operators/normalize_op.h"
 
 namespace caffe2 {
@@ -183,4 +184,4 @@
     NormalizeGradient,
     NormalizeGradientOp<float, CUDAContext>);
 REGISTER_CUDA_OPERATOR(NormalizeL1, NormalizeL1Op<float, CUDAContext>);
-} // namespace
+} // namespace caffe2
diff --git a/caffe2/operators/utility_ops.cc b/caffe2/operators/utility_ops.cc
index ea097a8..8ac5d17 100644
--- a/caffe2/operators/utility_ops.cc
+++ b/caffe2/operators/utility_ops.cc
@@ -74,9 +74,7 @@
 
 REGISTER_CPU_OPERATOR(WallClockTime, WallClockTimeOp<CPUContext>);
 REGISTER_CPU_OPERATOR(Print, PrintOp<CPUContext>);
-REGISTER_CPU_OPERATOR(Flatten, FlattenOp<CPUContext>);
 REGISTER_CPU_OPERATOR(FlattenToVec, FlattenToVecOp<CPUContext>);
-
 REGISTER_CPU_OPERATOR(Alias, AliasOp<CPUContext>);
 REGISTER_CPU_OPERATOR(ResizeLike, ResizeLikeOp<CPUContext>);
 REGISTER_CPU_OPERATOR(SumInt, SumOp<CPUContext>);
@@ -85,10 +83,6 @@
 REGISTER_CPU_OPERATOR(
     ScatterWeightedSum,
     ScatterWeightedSumOp<float, CPUContext>);
-REGISTER_CPU_OPERATOR(Max, MaxOp<float, CPUContext>);
-REGISTER_CPU_OPERATOR(MaxGradient, MaxGradientOp<float, CPUContext>);
-REGISTER_CPU_OPERATOR(Min, MinOp<float, CPUContext>);
-REGISTER_CPU_OPERATOR(MinGradient, MinGradientOp<float, CPUContext>);
 REGISTER_CPU_OPERATOR(ScatterAssign, ScatterAssignOp<CPUContext>);
 // From whatever the current context, ensure the output is TensorCPU
 REGISTER_CPU_OPERATOR(
@@ -140,48 +134,6 @@
 
 OPERATOR_SCHEMA(LengthsToShape).NumInputs(1).NumOutputs(1);
 
-OPERATOR_SCHEMA(Flatten)
-    .NumInputs(1)
-    .NumOutputs(1)
-    .TensorInferenceFunction([](const OperatorDef& def,
-                                const vector<TensorShape>& in) {
-      ArgumentHelper helper(def);
-      const int axis = helper.GetSingleArgument<int>("axis", 1);
-      vector<TensorShape> out(1);
-      TIndex outer = 1;
-      TIndex inner = 1;
-      std::size_t index = 0;
-      for (auto d : in[0].dims()) {
-        if (index < axis) {
-          outer *= d;
-        } else {
-          inner *= d;
-        }
-        ++index;
-      }
-      out[0].set_data_type(in[0].data_type());
-      out[0].add_dims(outer);
-      out[0].add_dims(inner);
-      return out;
-    })
-    .SetDoc(R"DOC(
-Flattens the input tensor into a 2D matrix. If input tensor has shape
-(d_0, d_1, ... d_n) then the output will have shape
-(d_0 X d_1 ... d_(axis-1), d_axis X d_(axis+1) ... X dn)
-)DOC")
-    .Input(0, "input", "A tensor of rank >= axis.")
-    .Output(
-        0,
-        "output",
-        "A 2D tensor with the contents of the input tensor, "
-        "with input dimensions up to axis flattened to the outer dimension "
-        "of the output and remaining input dimensions flattened into the inner "
-        "dimension of the output.")
-    .Arg(
-        "axis",
-        "(Default to 1) Indicate up to which input dimensions "
-        "(exclusive) should be flattened to the outer dimension of the output");
-
 OPERATOR_SCHEMA(FlattenToVec)
     .NumInputs(1)
     .NumOutputs(1)
@@ -319,38 +271,6 @@
     .Output(0, "X_0", "Has to be exactly the same tensor as the input 0")
     .EnforceInplace({{0, 0}});
 
-OPERATOR_SCHEMA(Max)
-    .NumInputs(1, INT_MAX)
-    .NumOutputs(1)
-    .IdenticalTypeAndShapeOfInput(0)
-    .AllowInplace({{0, 0}})
-    .SetDoc(R"DOC(
-Element-wise max of each of the input tensors. The first input tensor can be
-used in-place as the output tensor, in which case the max will be done in
-place and results will be accumulated in input0. All inputs and outputs must
-have the same shape and data type.
-)DOC")
-    .Input(0, "data_0", "First of the input tensors. Can be inplace.")
-    .Output(0, "max", "Output tensor. Same dimension as inputs.");
-
-OPERATOR_SCHEMA(MaxGradient).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX);
-
-OPERATOR_SCHEMA(Min)
-    .NumInputs(1, INT_MAX)
-    .NumOutputs(1)
-    .IdenticalTypeAndShapeOfInput(0)
-    .AllowInplace({{0, 0}})
-    .SetDoc(R"DOC(
-Element-wise min of each of the input tensors. The first input tensor can be
-used in-place as the output tensor, in which case the min will be done in
-place and results will be accumulated in input0. All inputs and outputs must
-have the same shape and data type.
-)DOC")
-    .Input(0, "data_0", "First of the input tensors. Can be inplace.")
-    .Output(0, "min", "Output tensor. Same dimension as inputs.");
-
-OPERATOR_SCHEMA(MinGradient).NumInputs(3, INT_MAX).NumOutputs(1, INT_MAX);
-
 OPERATOR_SCHEMA(ScatterAssign)
     .NumInputs(3)
     .NumOutputs(1)
@@ -822,15 +742,6 @@
 SHOULD_NOT_DO_GRADIENT(LengthsToShape);
 SHOULD_NOT_DO_GRADIENT(UnsafeCoalesce);
 
-class GetFlattenGradient : public GradientMakerBase {
-  using GradientMakerBase::GradientMakerBase;
-  vector<OperatorDef> GetGradientDefs() override {
-    return SingleGradientDef(
-        "ResizeLike", "", vector<string>{GO(0), I(0)}, vector<string>{GI(0)});
-  }
-};
-REGISTER_GRADIENT(Flatten, GetFlattenGradient);
-
 class GetAliasGradient : public GradientMakerBase {
   using GradientMakerBase::GradientMakerBase;
   vector<OperatorDef> GetGradientDefs() override {
@@ -883,34 +794,6 @@
 };
 REGISTER_GRADIENT(WeightedSum, GetWeightedSumGradient);
 
-class GetMaxGradient : public GradientMakerBase {
-  using GradientMakerBase::GradientMakerBase;
-  vector<OperatorDef> GetGradientDefs() override {
-    auto gradInputs = vector<string>();
-    auto inputs = vector<string>{O(0), GO(0)};
-    for (int i = 0; i < def_.input_size(); i++) {
-      gradInputs.push_back(GI(i));
-      inputs.push_back(I(i));
-    }
-    return SingleGradientDef("MaxGradient", "", inputs, gradInputs);
-  }
-};
-REGISTER_GRADIENT(Max, GetMaxGradient);
-
-class GetMinGradient : public GradientMakerBase {
-  using GradientMakerBase::GradientMakerBase;
-  vector<OperatorDef> GetGradientDefs() override {
-    auto gradInputs = vector<string>();
-    auto inputs = vector<string>{O(0), GO(0)};
-    for (int i = 0; i < def_.input_size(); i++) {
-      gradInputs.push_back(GI(i));
-      inputs.push_back(I(i));
-    }
-    return SingleGradientDef("MinGradient", "", inputs, gradInputs);
-  }
-};
-REGISTER_GRADIENT(Min, GetMinGradient);
-
 class GetGatherGradient : public GradientMakerBase {
   using GradientMakerBase::GradientMakerBase;
   vector<OperatorDef> GetGradientDefs() override {
@@ -1013,63 +896,6 @@
 SHOULD_NOT_DO_GRADIENT(LengthsGather);
 SHOULD_NOT_DO_GRADIENT(AccumulateHistogram);
 
-template <typename T, class Context>
-bool MaxOp<T, Context>::Compute() {
-  auto& input0 = Input(0);
-  const int N = input0.size();
-  T* output_data = Output(0)->template mutable_data<T>();
-
-  for (int i = 1; i < InputSize(); i++) {
-    auto input_data = Input(i).template data<T>();
-    EigenVectorMap<T> output_vec(output_data, N);
-    output_vec = output_vec.cwiseMax(ConstEigenVectorMap<T>(input_data, N));
-  }
-
-  return true;
-}
-
-template <typename T, class Context>
-bool MinOp<T, Context>::Compute() {
-  auto& input0 = Input(0);
-  const int N = input0.size();
-  T* output_data = Output(0)->template mutable_data<T>();
-
-  for (int i = 1; i < InputSize(); i++) {
-    auto input_data = Input(i).template data<T>();
-    EigenVectorMap<T> output_vec(output_data, N);
-    output_vec = output_vec.cwiseMin(ConstEigenVectorMap<T>(input_data, N));
-  }
-
-  return true;
-}
-
-template <typename T, class Context>
-bool SelectGradientOpBase<T, Context>::RunOnDevice() {
-  auto& output = Input(0);
-  auto& grad_output = Input(1);
-  const int kInputStartOffset = 2;
-
-  const T* data = output.template data<T>();
-  ConstEigenArrayMap<T> output_array(
-      output.template data<T>(), 1, output.size());
-  ConstEigenArrayMap<T> grad_out_array(
-      grad_output.template data<T>(), 1, grad_output.size());
-
-  for (int i = 0; i < OutputSize(); i++) {
-    auto& input = Input(i + kInputStartOffset);
-    ConstEigenArrayMap<T> input_array(
-        input.template data<T>(), 1, input.size());
-
-    auto* grad_input = Output(i);
-    grad_input->ResizeLike(input);
-    EigenArrayMap<T> grad_in_array(
-        grad_input->template mutable_data<T>(), 1, grad_input->size());
-    grad_in_array = grad_out_array *
-        input_array.cwiseEqual(output_array).template cast<T>();
-  }
-  return true;
-}
-
 template <>
 bool NanCheckOp<CPUContext>::RunOnDevice() {
   auto& X = Input(0);
diff --git a/caffe2/operators/utility_ops.cu b/caffe2/operators/utility_ops.cu
index 6172609..1644b7f 100644
--- a/caffe2/operators/utility_ops.cu
+++ b/caffe2/operators/utility_ops.cu
@@ -26,6 +26,8 @@
 #include <thrust/system/cuda/execution_policy.h>
 #include <thrust/unique.h>
 #include "caffe2/core/context_gpu.h"
+#include "flatten_op.h"
+#include "minmax_ops.h"
 #include "utility_ops.h"
 
 namespace caffe2 {
@@ -33,7 +35,6 @@
 
 REGISTER_CUDA_OPERATOR(EnsureDense, EnsureDenseOp<CUDAContext>);
 
-
 __global__ void NanCheckKernel(int N, const float* X, bool* result) {
   bool has_nan = false;
   CUDA_1D_KERNEL_LOOP(i, N) {
@@ -220,13 +221,17 @@
   return true;
 }
 
-template<typename T_INDEX>
-__global__ void
-GatherKernel(const float* X, float* Y, const T_INDEX* indices, const int N, const int block_size) {
+template <typename T_INDEX>
+__global__ void GatherKernel(
+    const float* X,
+    float* Y,
+    const T_INDEX* indices,
+    const int N,
+    const int block_size) {
   for (int i = blockIdx.x; i < N; i += gridDim.x) {
     T_INDEX idx = indices[i];
     const float* src_offset = X + idx * block_size;
-    float* dst_offset = Y + i   * block_size;
+    float* dst_offset = Y + i * block_size;
     for (int j = threadIdx.x; j < block_size; j += blockDim.x) {
       dst_offset[j] = src_offset[j];
     }
@@ -235,7 +240,7 @@
 
 template <>
 bool GatherOp<CUDAContext>::RunOnDevice() {
-  return DispatchHelper<TensorTypes<int32_t,int64_t>>::call(
+  return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
       this, OperatorBase::Input<TensorCUDA>(INDICES));
 }
 
@@ -272,9 +277,7 @@
       std::min(N, CAFFE_MAXIMUM_NUM_BLOCKS),
       CAFFE_CUDA_NUM_THREADS,
       0,
-      context_.cuda_stream()>>>(
-        src_base, out, idxs, N, block_size
-      );
+      context_.cuda_stream()>>>(src_base, out, idxs, N, block_size);
   return true;
 }
 
@@ -320,7 +323,7 @@
 
 template <>
 template <typename Index>
-bool ScatterWeightedSumOp<float,CUDAContext>::DoRunWithType() {
+bool ScatterWeightedSumOp<float, CUDAContext>::DoRunWithType() {
   CAFFE_ENFORCE_EQ(InputSize() % 2, 1);
   auto& X0 = Input(0);
   auto& weight0 = Input(1);
@@ -551,4 +554,4 @@
 }
 
 REGISTER_CUDA_OPERATOR(Range, RangeOp<CUDAContext>);
-}  // namespace caffe2
+} // namespace caffe2
diff --git a/caffe2/operators/utility_ops.h b/caffe2/operators/utility_ops.h
index 87a9bfe..dac1855 100644
--- a/caffe2/operators/utility_ops.h
+++ b/caffe2/operators/utility_ops.h
@@ -219,33 +219,6 @@
 };
 
 template <class Context>
-class FlattenOp : public Operator<Context> {
- public:
-  USE_OPERATOR_CONTEXT_FUNCTIONS;
-
-  FlattenOp(const OperatorDef& operator_def, Workspace* ws)
-      : Operator<Context>(operator_def, ws),
-        axis_(OperatorBase::GetSingleArgument<int>("axis", 1)) {}
-
-  bool RunOnDevice() override {
-    auto& input = Input(0);
-    auto* output = Output(0);
-    CAFFE_ENFORCE_GE(
-        input.dims().size(), axis_, "The rank of the tensor must be >= axis.");
-    output->Resize(input.size_to_dim(axis_), input.size_from_dim(axis_));
-    context_.template CopyItems<Context, Context>(
-        input.meta(),
-        input.size(),
-        input.raw_data(),
-        output->raw_mutable_data(input.meta()));
-    return true;
-  }
-
- private:
-  int axis_;
-};
-
-template <class Context>
 class FlattenToVecOp : public Operator<Context> {
  public:
   USE_OPERATOR_CONTEXT_FUNCTIONS;
@@ -581,87 +554,6 @@
   Tensor<Context> weights_device_;
 };
 
-template <typename T, class Context>
-class MaxMinOpBase : public Operator<Context> {
- public:
-  USE_OPERATOR_CONTEXT_FUNCTIONS;
-  USE_SIMPLE_CTOR_DTOR(MaxMinOpBase)
-
-  bool RunOnDevice() override {
-    auto& input0 = Input(0);
-    auto* output = Output(0);
-
-    output->ResizeLike(input0);
-    output->CopyFrom(input0, &context_);
-
-    if (InputSize() == 1) {
-      return true;
-    }
-
-    // Dimension checking
-    for (int i = 1; i < InputSize(); ++i) {
-      CAFFE_ENFORCE_EQ(
-          output->dims(),
-          Input(i).dims(),
-          "Description: Input #",
-          i,
-          ", input dimension:",
-          Input(i).dims(),
-          " should match output dimension: ",
-          output->dims());
-    }
-
-    return this->Compute();
-  }
-
-  virtual bool Compute() = 0;
-};
-
-template <typename T, class Context>
-class SelectGradientOpBase : public Operator<Context> {
- public:
-  USE_OPERATOR_CONTEXT_FUNCTIONS;
-  USE_SIMPLE_CTOR_DTOR(SelectGradientOpBase)
-
-  bool RunOnDevice() override;
-};
-
-template <typename T, class Context>
-class MaxOp : public MaxMinOpBase<T, Context> {
- public:
-  USE_OPERATOR_CONTEXT_FUNCTIONS;
-  MaxOp(const OperatorDef& operator_def, Workspace* ws)
-      : MaxMinOpBase<T, Context>(operator_def, ws) {}
-  virtual ~MaxOp() noexcept {}
-  bool Compute() override;
-};
-
-template <typename T, class Context>
-class MaxGradientOp : public SelectGradientOpBase<T, Context> {
- public:
-  MaxGradientOp(const OperatorDef& operator_def, Workspace* ws)
-      : SelectGradientOpBase<T, Context>(operator_def, ws) {}
-  virtual ~MaxGradientOp() noexcept {}
-};
-
-template <typename T, class Context>
-class MinOp : public MaxMinOpBase<T, Context> {
- public:
-  USE_OPERATOR_CONTEXT_FUNCTIONS;
-  MinOp(const OperatorDef& operator_def, Workspace* ws)
-      : MaxMinOpBase<T, Context>(operator_def, ws) {}
-  virtual ~MinOp() noexcept {}
-  bool Compute() override;
-};
-
-template <typename T, class Context>
-class MinGradientOp : public SelectGradientOpBase<T, Context> {
- public:
-  MinGradientOp(const OperatorDef& operator_def, Workspace* ws)
-      : SelectGradientOpBase<T, Context>(operator_def, ws) {}
-  virtual ~MinGradientOp() noexcept {}
-};
-
 /**
  * @brief Update slices of the tensor in-place by overriding.
  *
diff --git a/caffe2/operators/utility_ops_gpu.cc b/caffe2/operators/utility_ops_gpu.cc
index 880d032..36d4460 100644
--- a/caffe2/operators/utility_ops_gpu.cc
+++ b/caffe2/operators/utility_ops_gpu.cc
@@ -15,7 +15,7 @@
  */
 
 #include "caffe2/core/context_gpu.h"
-#include "caffe2/operators/reshape_op.h"
+#include "caffe2/operators/flatten_op.h"
 #include "caffe2/operators/utility_ops.h"
 #include "caffe2/utils/math.h"