Use cast::GetCastDataType to handle "from_type" and "to" arguments

Summary: Also enforce the "from_type" argument is supplied when getting gradient

Reviewed By: Yangqing

Differential Revision: D5684399

fbshipit-source-id: bee955d44a04c44142b2212cff548cea6e08b22f
diff --git a/caffe2/operators/cast_op.cc b/caffe2/operators/cast_op.cc
index 292d7d1..750b72a 100644
--- a/caffe2/operators/cast_op.cc
+++ b/caffe2/operators/cast_op.cc
@@ -129,15 +129,22 @@
     // now modify the arguments in defs[0]
     ArgumentHelper argsHelper(def_);
 
-    auto to_name = argsHelper.GetSingleArgument<string>("to", "");
-    auto from_name = argsHelper.GetSingleArgument<string>("from_type", "");
+    auto to_name = cast::GetCastDataType(argsHelper, "to");
+
+    CAFFE_ENFORCE(
+        argsHelper.HasSingleArgumentOfType<string>("from_type") ||
+            argsHelper.HasSingleArgumentOfType<int>("from_type"),
+        "Argument 'from_type' of type int or string"
+        " is required to get the gradient of CastOp");
+
+    auto from_name = cast::GetCastDataType(argsHelper, "from_type");
     Argument *to = defs[0].add_arg();
     to->set_name("to");
-    to->set_s(from_name);
+    to->set_i(from_name);
 
     Argument *from = defs[0].add_arg();
     from->set_name("from_type");
-    from->set_s(to_name);
+    from->set_i(to_name);
 
     return defs;
   }