Misc changes to reduce binary size.
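
The savings come from a few places: RTTI can now be disabled via a new
USE_RTTI option (typeid.h guards its typeid() use behind __GXX_RTTI),
Android builds compile with -Os instead of -O2 (keeping most -O2
optimizations but disabling those that tend to grow code) and drop the
-pie link flag, and the gradient registry no longer stamps out a full
copy of the device-option/engine/argument copying logic for every
GradientRegisterer<T> instantiation: GetGradientDefBaseVerbose becomes a
plain base class with a virtual Create(), and the shared logic moves to
the single out-of-line CreateGradientDefsInternal() in operator.cc.

The registry change is an instance of a common size-reduction idiom:
hoist everything that does not depend on the template parameter into a
non-template function, so it is compiled once rather than once per
instantiation. A minimal standalone sketch of the idiom (hypothetical
names, not code from this patch):

    #include <cstdio>
    #include <string>

    // Compiled exactly once, shared by every Registerer<T> below.
    inline void RegisterImpl(const std::string& key) {
      std::printf("registered %s\n", key.c_str());
    }

    template <class T>
    struct Registerer {
      // Thin forwarding shim: nothing in the body depends on T, so each
      // instantiation adds almost no code of its own to the binary.
      explicit Registerer(const std::string& key) { RegisterImpl(key); }
    };

    // Usage: static Registerer<GetFooGradient> g_reg("Foo");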
diff --git a/build.py b/build.py
index fce583c..331712d 100644
--- a/build.py
+++ b/build.py
@@ -41,6 +41,11 @@
     # True.
     USE_GLOG = False
 
+    # Whether to use RTTI or not. Note that disabling RTTI may not always
+    # work: all of your dependencies, most notably protobuf, would have to be
+    # built without RTTI as well. If you are not sure, leave USE_RTTI True.
+    USE_RTTI = False
+
     # Manually specified defines.
     DEFINES = []
 
diff --git a/build_android.py b/build_android.py
index c71fe27..f6ca328 100644
--- a/build_android.py
+++ b/build_android.py
@@ -31,6 +31,11 @@
     # True.
     USE_GLOG = False
 
+    # Whether to use RTTI or not. Note that disabling RTTI may not always
+    # work: all of your dependencies, most notably protobuf, would have to be
+    # built without RTTI as well. If you are not sure, leave USE_RTTI True.
+    USE_RTTI = False
+
     # Manually specified defines.
     DEFINES = []
 
@@ -46,9 +51,7 @@
     CFLAGS = []
 
     # Additional link flags you would like to add to the compilation.
-    LINKFLAGS = [
-        "-pie",
-    ]
+    LINKFLAGS = []
 
     ###########################################################################
     # (optional) CUDA. If you do not specify this, the GPU part of Caffe2 will
@@ -107,7 +110,7 @@
     # build command, do it here.
     ENVIRONMENTAL_VARS = {}
     # Optimization flags: -O2 in default.
-    OPTIMIZATION_FLAGS = ["-O2"]
+    OPTIMIZATION_FLAGS = ["-Os"]
 
 
 # brew.py
diff --git a/caffe2/core/operator.cc b/caffe2/core/operator.cc
index 63206da..1acad9a 100644
--- a/caffe2/core/operator.cc
+++ b/caffe2/core/operator.cc
@@ -3,6 +3,7 @@
 
 #include "caffe2/core/net.h"
 #include "caffe2/core/operator.h"
+#include "caffe2/core/operator_gradient.h"
 #include "caffe2/core/workspace.h"
 #include "caffe2/proto/caffe2.pb.h"
 
@@ -146,4 +147,50 @@
 
 DEFINE_REGISTRY(GradientRegistry, vector<OperatorDef>, const OperatorDef&);
 
+
+vector<OperatorDef>* CreateGradientDefsInternal(
+    const OperatorDef& def, GetGradientDefBaseVerbose* obj) {
+  CAFFE_DCHECK_NOTNULL(obj);
+  vector<OperatorDef>* grad_defs = obj->Create(def);
+  CAFFE_CHECK(grad_defs != nullptr);
+  // Copy device option if needed.
+  if (obj->CopyDeviceOption() && def.has_device_option()) {
+    for (OperatorDef& grad_def : *grad_defs) {
+      grad_def.mutable_device_option()->CopyFrom(def.device_option());
+    }
+  }
+  // Copy engine if needed.
+  if (obj->CopyEngine() && def.has_engine()) {
+    for (OperatorDef& grad_def : *grad_defs) {
+      grad_def.set_engine(def.engine());
+    }
+  }
+  // Copy arguments if needed.
+  if (obj->CopyArguments() && def.arg_size()) {
+    for (OperatorDef& grad_def : *grad_defs) {
+      grad_def.mutable_arg()->CopyFrom(def.arg());
+    }
+  }
+  for (const OperatorDef& grad_def : *grad_defs) {
+    CAFFE_VLOG(1) << "Gradient: " << grad_def.DebugString();
+  }
+  delete obj;
+  return grad_defs;
+}
+
+vector<OperatorDef>* ThrowTheTowelIfGradientIsCalled::Create(
+    const OperatorDef& def) {
+  CAFFE_LOG_FATAL << "You should not call the gradient of operator of type "
+                  << def.type();
+  // Just to suppress compiler warnings
+  return new vector<OperatorDef>();
+}
+
+vector<OperatorDef>* GradientNotImplementedYet::Create(
+    const OperatorDef& def) {
+  CAFFE_LOG_FATAL << "Gradient for operator type "
+                  << def.type() << " has not been implemented yet.";
+  return nullptr;
+}
+
 }  // namespace caffe2
diff --git a/caffe2/core/operator_gradient.h b/caffe2/core/operator_gradient.h
index cf0c90b..6c9e72d2 100644
--- a/caffe2/core/operator_gradient.h
+++ b/caffe2/core/operator_gradient.h
@@ -14,48 +14,19 @@
 
 DECLARE_REGISTRY(GradientRegistry, vector<OperatorDef>, const OperatorDef&);
 
-template <class GetGradientDef>
-class GradientRegisterer {
- public:
-  GradientRegisterer(const string& key) {
-    GradientRegistry()->Register(
-        key, GradientRegisterer<GetGradientDef>::Creator);
-  }
 
-  static vector<OperatorDef>* Creator(const OperatorDef& def) {
-    CAFFE_VLOG(1) << "Creator: " << def.DebugString();
-    vector<OperatorDef>* grad_defs = GetGradientDef::Create(def);
-    CAFFE_CHECK(grad_defs != nullptr);
-    // Copy device option if needed.
-    if (GetGradientDef().CopyDeviceOption() && def.has_device_option()) {
-      for (OperatorDef& grad_def : *grad_defs) {
-        grad_def.mutable_device_option()->CopyFrom(def.device_option());
-      }
-    }
-    // Copy engine if needed.
-    if (GetGradientDef().CopyEngine() && def.has_engine()) {
-      for (OperatorDef& grad_def : *grad_defs) {
-        grad_def.set_engine(def.engine());
-      }
-    }
-    // Copy arguments if needed.
-    if (GetGradientDef().CopyArguments() && def.arg_size()) {
-      for (OperatorDef& grad_def : *grad_defs) {
-        grad_def.mutable_arg()->CopyFrom(def.arg());
-      }
-    }
-    for (const OperatorDef& grad_def : *grad_defs) {
-      CAFFE_VLOG(1) << "Gradient: " << grad_def.DebugString();
-    }
-    return grad_defs;
-  }
-};
-
-template <bool copy_device_option, bool copy_engine, bool copy_args>
 struct GetGradientDefBaseVerbose {
-  constexpr bool CopyDeviceOption() const { return copy_device_option; }
-  constexpr bool CopyEngine() const { return copy_engine; }
-  constexpr bool CopyArguments() const { return copy_args; }
+ public:
+  GetGradientDefBaseVerbose(
+      const bool copy_device_option, const bool copy_engine,
+      const bool copy_args)
+      : copy_device_option_(copy_device_option), copy_engine_(copy_engine),
+      copy_args_(copy_args) {}
+  virtual ~GetGradientDefBaseVerbose() {}
+
+  bool CopyDeviceOption() const { return copy_device_option_; }
+  bool CopyEngine() const { return copy_engine_; }
+  bool CopyArguments() const { return copy_args_; }
   inline static string I(const OperatorDef& def, const int i) {
     return def.input(i);
   }
@@ -68,32 +39,59 @@
   inline static string GO(const OperatorDef& def, const int i) {
     return GradientName(def.output(i));
   }
+
+  virtual vector<OperatorDef>* Create(const OperatorDef& def) {
+    NOT_IMPLEMENTED;
+    return nullptr;
+  }
+
   template <class... Args>
   inline static vector<OperatorDef>* SingleGradientDef(Args ... args) {
     return new vector<OperatorDef>{CreateOperatorDef(args...)};
   }
+
+  bool copy_device_option_;
+  bool copy_engine_;
+  bool copy_args_;
 };
-typedef struct GetGradientDefBaseVerbose<true, true, true> GetGradientDefBase;
+
+
+struct GetGradientDefBase : public GetGradientDefBaseVerbose {
+ public:
+  GetGradientDefBase() : GetGradientDefBaseVerbose(true, true, true) {}
+};
 
 struct NoGradient : public GetGradientDefBase {
-  static vector<OperatorDef>* Create(const OperatorDef& def) {
+  vector<OperatorDef>* Create(const OperatorDef& def) override {
     return new vector<OperatorDef>();
   }
 };
 
+// This is used when the operator definition is designed not to have a
+// gradient; requesting one will cause Caffe2 to throw the towel.
 struct ThrowTheTowelIfGradientIsCalled : public GetGradientDefBase {
-  static vector<OperatorDef>* Create(const OperatorDef& def) {
-    CAFFE_LOG_FATAL << "You should not call the gradient of operator of type "
-                    << def.type();
-    // Just to suppress compiler warnings
-    return new vector<OperatorDef>();
-  }
+  vector<OperatorDef>* Create(const OperatorDef& def) override;
 };
 
+// This should only be used sparingly, when the gradient does exist but we
+// have not implemented it yet.
 struct GradientNotImplementedYet : public GetGradientDefBase {
-  static vector<OperatorDef>* Create(const OperatorDef& def) {
-    CAFFE_LOG_FATAL << "Gradient for operator type "
-                    << def.type() << " has not been implemented yet.";
+  vector<OperatorDef>* Create(const OperatorDef& def) override;
+};
+
+vector<OperatorDef>* CreateGradientDefsInternal(
+    const OperatorDef& def, GetGradientDefBaseVerbose* obj);
+
+template <class GetGradientDef>
+class GradientRegisterer {
+ public:
+  GradientRegisterer(const string& key) {
+    GradientRegistry()->Register(
+        key, GradientRegisterer<GetGradientDef>::Creator);
+  }
+
+  static vector<OperatorDef>* Creator(const OperatorDef& def) {
+    return CreateGradientDefsInternal(def, new GetGradientDef());
   }
 };
 
@@ -111,13 +109,6 @@
   GradientRegisterer<ThrowTheTowelIfGradientIsCalled>                          \
       g_GradientRegisterer_##name(#name)
 
-// SHOULD_NOT_DO_GRADIENT means that the operator is not designed to have
-// gradient operators. If you attempt to call the gradient, a log fatal will
-// occur.
-#define SHOULD_NOT_DO_GRADIENT(name)                                           \
-  GradientRegisterer<ThrowTheTowelIfGradientIsCalled>                          \
-      g_GradientRegisterer_##name(#name)
-
 #define GRADIENT_NOT_IMPLEMENTED_YET(name)                                     \
   GradientRegisterer<GradientNotImplementedYet>                                \
       g_GradientRegisterer_##name(#name)
diff --git a/caffe2/core/operator_test.cc b/caffe2/core/operator_test.cc
index be6baf9..8bd252e 100644
--- a/caffe2/core/operator_test.cc
+++ b/caffe2/core/operator_test.cc
@@ -303,7 +303,7 @@
 }
 
 struct GetFooGradient : public GetGradientDefBase {
-  static vector<OperatorDef>* Create(const OperatorDef& def) {
+  vector<OperatorDef>* Create(const OperatorDef& def) override {
     return new vector<OperatorDef>{
         CreateOperatorDef(
             "FooGradient", "",
diff --git a/caffe2/core/typeid.h b/caffe2/core/typeid.h
index c64ecfe..ea9454f 100644
--- a/caffe2/core/typeid.h
+++ b/caffe2/core/typeid.h
@@ -2,7 +2,9 @@
 #define CAFFE2_CORE_TYPEID_H_
 
 #include <map>
+#ifdef __GXX_RTTI
 #include <typeinfo>
+#endif
 
 #include "caffe2/core/common.h"
 
@@ -79,7 +81,13 @@
    * Returns the printable name of the type.
    */
   template <typename T>
-  static const char* Name() { return typeid(T).name(); }
+  static const char* Name() {
+#ifdef __GXX_RTTI
+    return typeid(T).name();
+#else  // __GXX_RTTI
+    return "(RTTI disabled, cannot show name)";
+#endif
+  }
   /**
    * Returns a TypeMeta object that corresponds to the typename T.
    */
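
(Note on the typeid.h change above: __GXX_RTTI is the GCC spelling of
"RTTI is enabled". A more portable guard could look like the following
sketch; CAFFE2_HAS_RTTI is a hypothetical macro, not part of this patch.)

    #ifndef __has_feature          // non-clang compilers lack __has_feature
    #define __has_feature(x) 0
    #endif
    // clang feature test, GCC's __GXX_RTTI, MSVC's _CPPRTTI (set by /GR).
    #if __has_feature(cxx_rtti) || defined(__GXX_RTTI) || defined(_CPPRTTI)
    #define CAFFE2_HAS_RTTI 1
    #endif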
diff --git a/caffe2/operators/averagepool_op.cc b/caffe2/operators/averagepool_op.cc
index 2d94268..45860a7 100644
--- a/caffe2/operators/averagepool_op.cc
+++ b/caffe2/operators/averagepool_op.cc
@@ -193,7 +193,7 @@
                       AveragePoolGradientOp<float, CPUContext>);
 
 struct GetAveragePoolGradient : public GetGradientDefBase {
-  static vector<OperatorDef>* Create(const OperatorDef& def) {
+  vector<OperatorDef>* Create(const OperatorDef& def) override {
     return SingleGradientDef(
         "AveragePoolGradient", "",
         vector<string>{I(def, 0), GO(def, 0)},
diff --git a/caffe2/operators/clip_op.cc b/caffe2/operators/clip_op.cc
index 3ee0105..33bef74 100644
--- a/caffe2/operators/clip_op.cc
+++ b/caffe2/operators/clip_op.cc
@@ -38,7 +38,7 @@
 REGISTER_CPU_OPERATOR(ClipGradient, ClipGradientOp<float, CPUContext>);
 
 struct GetClipGradient : public GetGradientDefBase {
-  static vector<OperatorDef>* Create(const OperatorDef& def) {
+  vector<OperatorDef>* Create(const OperatorDef& def) override {
     return SingleGradientDef(
         "ClipGradient", "",
         vector<string>{O(def, 0), GO(def, 0)},
diff --git a/caffe2/operators/conv_op.cc b/caffe2/operators/conv_op.cc
index 6cdadd2..2cfa874 100644
--- a/caffe2/operators/conv_op.cc
+++ b/caffe2/operators/conv_op.cc
@@ -7,7 +7,7 @@
 REGISTER_CPU_OPERATOR(ConvGradient, ConvGradientOp<float, CPUContext>);
 
 struct GetConvGradient : public GetGradientDefBase {
-  static vector<OperatorDef>* Create(const OperatorDef& def) {
+  vector<OperatorDef>* Create(const OperatorDef& def) override {
     CAFFE_CHECK_EQ(def.input_size(), 3);
     return SingleGradientDef(
         "ConvGradient", "",
diff --git a/caffe2/operators/conv_op_cudnn.cc b/caffe2/operators/conv_op_cudnn.cc
index 28bd798..da43830 100644
--- a/caffe2/operators/conv_op_cudnn.cc
+++ b/caffe2/operators/conv_op_cudnn.cc
@@ -380,7 +380,7 @@
 REGISTER_CUDNN_OPERATOR(ConvFp16Gradient, CudnnConvGradientOp<float16>);
 
 struct GetConvFp16Gradient : public GetGradientDefBase {
-  static vector<OperatorDef>* Create(const OperatorDef& def) {
+  vector<OperatorDef>* Create(const OperatorDef& def) override {
     CAFFE_CHECK_EQ(def.input_size(), 3);
     return SingleGradientDef(
         "ConvFp16Gradient", "",
diff --git a/caffe2/operators/cross_entropy_op.cc b/caffe2/operators/cross_entropy_op.cc
index bec6d61..ec31767 100644
--- a/caffe2/operators/cross_entropy_op.cc
+++ b/caffe2/operators/cross_entropy_op.cc
@@ -58,7 +58,7 @@
                       LabelCrossEntropyGradientOp<float, CPUContext>);
 
 struct GetLabelCrossEntropyGradient : public GetGradientDefBase {
-  static vector<OperatorDef>* Create(const OperatorDef& def) {
+  vector<OperatorDef>* Create(const OperatorDef& def) override {
     return SingleGradientDef(
         "LabelCrossEntropyGradient", "",
         vector<string>{I(def, 0), I(def, 1), GO(def, 0)},
diff --git a/caffe2/operators/depth_split_op.cc b/caffe2/operators/depth_split_op.cc
index e02c863..029551d 100644
--- a/caffe2/operators/depth_split_op.cc
+++ b/caffe2/operators/depth_split_op.cc
@@ -6,7 +6,7 @@
 REGISTER_CPU_OPERATOR(DepthConcat, DepthConcatOp<CPUContext>);
 
 struct GetDepthSplitGradient : public GetGradientDefBase {
-  static vector<OperatorDef>* Create(const OperatorDef& def) {
+  vector<OperatorDef>* Create(const OperatorDef& def) override {
     vector<string> grad_inputs;
     for (const string& out : def.output()) {
       grad_inputs.push_back(GradientName(out));
@@ -19,7 +19,7 @@
 REGISTER_GRADIENT(DepthSplit, GetDepthSplitGradient);
 
 struct GetDepthConcatGradient : public GetGradientDefBase {
-  static vector<OperatorDef>* Create(const OperatorDef& def) {
+  vector<OperatorDef>* Create(const OperatorDef& def) override {
     vector<string> grad_outputs;
     for (const string& in : def.input()) {
       grad_outputs.push_back(GradientName(in));
diff --git a/caffe2/operators/dropout_op.cc b/caffe2/operators/dropout_op.cc
index 54f252d..a447e98 100644
--- a/caffe2/operators/dropout_op.cc
+++ b/caffe2/operators/dropout_op.cc
@@ -61,7 +61,7 @@
 REGISTER_CPU_OPERATOR(DropoutGrad, DropoutGradientOp<float, CPUContext>);
 
 struct GetDropoutGradient : public GetGradientDefBase {
-  static vector<OperatorDef>* Create(const OperatorDef& def) {
+  vector<OperatorDef>* Create(const OperatorDef& def) override {
     return SingleGradientDef(
         "DropoutGrad", "",
         vector<string>{GO(def, 0), O(def, 1)},
diff --git a/caffe2/operators/fully_connected_op.cc b/caffe2/operators/fully_connected_op.cc
index 0300771..26c7438 100644
--- a/caffe2/operators/fully_connected_op.cc
+++ b/caffe2/operators/fully_connected_op.cc
@@ -7,7 +7,7 @@
 REGISTER_CPU_OPERATOR(FCGradient, FullyConnectedGradientOp<float, CPUContext>);
 
 struct GetFCGradient : public GetGradientDefBase {
-  static vector<OperatorDef>* Create(const OperatorDef& def) {
+  vector<OperatorDef>* Create(const OperatorDef& def) override {
     CAFFE_CHECK_EQ(def.input_size(), 3);
     return SingleGradientDef(
         "FCGradient", "",
diff --git a/caffe2/operators/half_float_ops.cu b/caffe2/operators/half_float_ops.cu
index a274084..4bdd43a 100644
--- a/caffe2/operators/half_float_ops.cu
+++ b/caffe2/operators/half_float_ops.cu
@@ -64,7 +64,7 @@
 REGISTER_CUDA_OPERATOR(HalfToFloat, HalfToFloatCUDA);
 
 struct GetFloatToHalfGradient : public GetGradientDefBase {
-  static vector<OperatorDef>* Create(const OperatorDef& def) {
+  vector<OperatorDef>* Create(const OperatorDef& def) override {
     return SingleGradientDef(
         "HalfToFloat", "",
         vector<string>{GO(def, 0)},
@@ -74,7 +74,7 @@
 REGISTER_GRADIENT(FloatToHalf, GetFloatToHalfGradient);
 
 struct GetHalfToFloatGradient : public GetGradientDefBase {
-  static vector<OperatorDef>* Create(const OperatorDef& def) {
+  vector<OperatorDef>* Create(const OperatorDef& def) override {
     return SingleGradientDef(
         "FloatToHalf", "",
         vector<string>{GO(def, 0)},
diff --git a/caffe2/operators/l2_distance_op.cc b/caffe2/operators/l2_distance_op.cc
index d3c0413..0f618ef 100644
--- a/caffe2/operators/l2_distance_op.cc
+++ b/caffe2/operators/l2_distance_op.cc
@@ -35,7 +35,7 @@
                       SquaredL2DistanceGradientOp<float, CPUContext>);
 
 struct GetSquaredL2DistanceGradient : public GetGradientDefBase {
-  static vector<OperatorDef>* Create(const OperatorDef& def) {
+  vector<OperatorDef>* Create(const OperatorDef& def) override {
     return SingleGradientDef(
         "SquaredL2DistanceGradient", "",
         vector<string>{I(def, 0), I(def, 1), GO(def, 0)},
diff --git a/caffe2/operators/local_response_normalization_op.cc b/caffe2/operators/local_response_normalization_op.cc
index 59ac2fa..5da070a 100644
--- a/caffe2/operators/local_response_normalization_op.cc
+++ b/caffe2/operators/local_response_normalization_op.cc
@@ -232,7 +232,7 @@
 REGISTER_CPU_OPERATOR(LRNGradient, LRNGradientOp<float, CPUContext>);
 
 struct GetLRNGradient : public GetGradientDefBase {
-  static vector<OperatorDef>* Create(const OperatorDef& def) {
+  vector<OperatorDef>* Create(const OperatorDef& def) override {
     return SingleGradientDef(
         "LRNGradient", "",
         vector<string>{I(def, 0), O(def, 0), O(def, 1), GO(def, 0)},
diff --git a/caffe2/operators/loss_op.cc b/caffe2/operators/loss_op.cc
index 4d5e535..15c12e4 100644
--- a/caffe2/operators/loss_op.cc
+++ b/caffe2/operators/loss_op.cc
@@ -12,7 +12,7 @@
 
 
 struct GetAveragedLossGradient : public GetGradientDefBase {
-  static vector<OperatorDef>* Create(const OperatorDef& def) {
+  vector<OperatorDef>* Create(const OperatorDef& def) override {
     return SingleGradientDef(
         "AveragedLossGradient", "",
         vector<string>{I(def, 0)},
@@ -22,7 +22,7 @@
 REGISTER_GRADIENT(AveragedLoss, GetAveragedLossGradient);
 
 struct GetWeightedSumLossGradient : public GetGradientDefBase {
-  static vector<OperatorDef>* Create(const OperatorDef& def) {
+  vector<OperatorDef>* Create(const OperatorDef& def) override {
     return SingleGradientDef(
         "WeightedSumLossGradient", "",
         vector<string>{I(def, 1)},
diff --git a/caffe2/operators/maxpool_op.cc b/caffe2/operators/maxpool_op.cc
index d3c744a..e78aa156 100644
--- a/caffe2/operators/maxpool_op.cc
+++ b/caffe2/operators/maxpool_op.cc
@@ -141,7 +141,7 @@
 REGISTER_CPU_OPERATOR(MaxPoolGradient, MaxPoolGradientOp<float, CPUContext>);
 
 struct GetMaxPoolGradient : public GetGradientDefBase {
-  static vector<OperatorDef>* Create(const OperatorDef& def) {
+  vector<OperatorDef>* Create(const OperatorDef& def) override {
     return SingleGradientDef(
         "MaxPoolGradient", "",
         vector<string>{I(def, 0), GO(def, 0), O(def, 1)},
diff --git a/caffe2/operators/order_switch_ops.cc b/caffe2/operators/order_switch_ops.cc
index 3fdf7b9..5b584fc 100644
--- a/caffe2/operators/order_switch_ops.cc
+++ b/caffe2/operators/order_switch_ops.cc
@@ -50,7 +50,7 @@
 REGISTER_CPU_OPERATOR(NCHW2NHWC, NCHW2NHWCOp<float, CPUContext>);
 
 struct GetNHWC2NCHWGradient : public GetGradientDefBase {
-  static vector<OperatorDef>* Create(const OperatorDef& def) {
+  vector<OperatorDef>* Create(const OperatorDef& def) override {
     return SingleGradientDef(
         "NCHW2NHWC", "",
         vector<string>{GO(def, 0)},
@@ -60,7 +60,7 @@
 REGISTER_GRADIENT(NHWC2NCHW, GetNHWC2NCHWGradient);
 
 struct GetNCHW2NHWCGradient : public GetGradientDefBase {
-  static vector<OperatorDef>* Create(const OperatorDef& def) {
+  vector<OperatorDef>* Create(const OperatorDef& def) override {
     return SingleGradientDef(
         "NHWC2NCHW", "",
         vector<string>{GO(def, 0)},
diff --git a/caffe2/operators/relu_op.cc b/caffe2/operators/relu_op.cc
index 857819c..42e577a1 100644
--- a/caffe2/operators/relu_op.cc
+++ b/caffe2/operators/relu_op.cc
@@ -38,7 +38,7 @@
 REGISTER_CPU_OPERATOR(ReluGradient, ReluGradientOp<float, CPUContext>);
 
 struct GetReluGradient : public GetGradientDefBase {
-  static vector<OperatorDef>* Create(const OperatorDef& def) {
+  vector<OperatorDef>* Create(const OperatorDef& def) override {
     return SingleGradientDef(
         "ReluGradient", "",
         vector<string>{O(def, 0), GO(def, 0)},
diff --git a/caffe2/operators/sigmoid_op.cc b/caffe2/operators/sigmoid_op.cc
index 323cc34..dfb831b 100644
--- a/caffe2/operators/sigmoid_op.cc
+++ b/caffe2/operators/sigmoid_op.cc
@@ -42,7 +42,7 @@
                                      SigmoidGradientCPUFunctor<float> >);
 
 struct GetSigmoidGradient : public GetGradientDefBase {
-  static vector<OperatorDef>* Create(const OperatorDef& def) {
+  vector<OperatorDef>* Create(const OperatorDef& def) override {
     return SingleGradientDef(
         "SigmoidGradient", "",
         vector<string>{O(def, 0), GO(def, 0)},
diff --git a/caffe2/operators/softmax_op.cc b/caffe2/operators/softmax_op.cc
index 1735f31..f5356a1 100644
--- a/caffe2/operators/softmax_op.cc
+++ b/caffe2/operators/softmax_op.cc
@@ -90,7 +90,7 @@
 REGISTER_CPU_OPERATOR(SoftmaxGradient, SoftmaxGradientOp<float, CPUContext>);
 
 struct GetSoftmaxGradient : public GetGradientDefBase {
-  static vector<OperatorDef>* Create(const OperatorDef& def) {
+  vector<OperatorDef>* Create(const OperatorDef& def) override {
     return SingleGradientDef(
         "SoftmaxGradient", "",
         vector<string>{O(def, 0), GO(def, 0)},
diff --git a/caffe2/operators/spatial_batch_norm_op.cc b/caffe2/operators/spatial_batch_norm_op.cc
index 3bb1806..ef06a1a 100644
--- a/caffe2/operators/spatial_batch_norm_op.cc
+++ b/caffe2/operators/spatial_batch_norm_op.cc
@@ -8,7 +8,7 @@
 // is a bit more complex than usual gradient operators.
 namespace {
 struct GetSpatialBNGradient : public GetGradientDefBase {
-  static vector<OperatorDef>* Create(const OperatorDef& def) {
+  vector<OperatorDef>* Create(const OperatorDef& def) override {
     // Check if we are in training or testing mode.
     bool is_test = false;
     if (HasArgument(def, "is_test")) {
diff --git a/caffe2/operators/tanh_op.cc b/caffe2/operators/tanh_op.cc
index bfbf50a..91bf13e 100644
--- a/caffe2/operators/tanh_op.cc
+++ b/caffe2/operators/tanh_op.cc
@@ -42,7 +42,7 @@
                                      TanhGradientCPUFunctor<float> >);
 
 struct GetTanhGradient : public GetGradientDefBase {
-  static vector<OperatorDef>* Create(const OperatorDef& def) {
+  vector<OperatorDef>* Create(const OperatorDef& def) override {
     return SingleGradientDef(
         "TanhGradient", "",
         std::vector<string>{O(def, 0), GO(def, 0)},
diff --git a/caffe2/operators/utility_ops.cc b/caffe2/operators/utility_ops.cc
index ab6ec97..47cb7c7 100644
--- a/caffe2/operators/utility_ops.cc
+++ b/caffe2/operators/utility_ops.cc
@@ -19,7 +19,7 @@
 SHOULD_NOT_DO_GRADIENT(PrintInt);
 
 struct GetFlattenGradient : public GetGradientDefBase {
-  static vector<OperatorDef>* Create(const OperatorDef& def) {
+  vector<OperatorDef>* Create(const OperatorDef& def) override {
     return SingleGradientDef(
         "ReshapeLike", "",
         vector<string>{GO(def, 0), I(def, 0)},
@@ -29,7 +29,7 @@
 REGISTER_GRADIENT(Flatten, GetFlattenGradient);
 
 struct GetAliasGradient : public GetGradientDefBase {
-  static vector<OperatorDef>* Create(const OperatorDef& def) {
+  vector<OperatorDef>* Create(const OperatorDef& def) override {
     return SingleGradientDef(
         "Alias", "",
         vector<string>{GO(def, 0)},
@@ -41,7 +41,7 @@
 SHOULD_NOT_DO_GRADIENT(ReshapeLike);
 
 struct GetSplitGradient : public GetGradientDefBase {
-  static vector<OperatorDef>* Create(const OperatorDef& def) {
+  vector<OperatorDef>* Create(const OperatorDef& def) override {
     vector<string> grad_input;
     for (const string out : def.output()) {
       grad_input.push_back(GradientName(out));
diff --git a/pycaffe2/convnet_benchmarks.py b/pycaffe2/convnet_benchmarks.py
index 8bec892..48b7a70 100644
--- a/pycaffe2/convnet_benchmarks.py
+++ b/pycaffe2/convnet_benchmarks.py
@@ -259,7 +259,7 @@
   if (not args.batch_size or not args.model or not args.order or not args.cudnn_ws):
     parser.print_help()
 
-  workspace.GlobalInit(['caffe2', '--caffe2_log_level=0'])
+  workspace.GlobalInit(['caffe2', '--caffe2_log_level=-1'])
   model_map = {'AlexNet': AlexNet, 'OverFeat': OverFeat, 'VGGA': VGGA, 'Inception': Inception}
   Benchmark(model_map[args.model], args.order, args.batch_size, args.cudnn_ws,
             args.forward_only, args.iterations)