Use cast::GetCastDataType to handle "from_type" and "to" arguments
Summary: Also enforce the "from_type" argument is supplied when getting gradient
Reviewed By: Yangqing
Differential Revision: D5684399
fbshipit-source-id: bee955d44a04c44142b2212cff548cea6e08b22f
diff --git a/caffe2/operators/cast_op.cc b/caffe2/operators/cast_op.cc
index 292d7d1..750b72a 100644
--- a/caffe2/operators/cast_op.cc
+++ b/caffe2/operators/cast_op.cc
@@ -129,15 +129,22 @@
// now modify the arguments in defs[0]
ArgumentHelper argsHelper(def_);
- auto to_name = argsHelper.GetSingleArgument<string>("to", "");
- auto from_name = argsHelper.GetSingleArgument<string>("from_type", "");
+ auto to_name = cast::GetCastDataType(argsHelper, "to");
+
+ CAFFE_ENFORCE(
+ argsHelper.HasSingleArgumentOfType<string>("from_type") ||
+ argsHelper.HasSingleArgumentOfType<int>("from_type"),
+ "Argument 'from_type' of type int or string"
+ " is required to get the gradient of CastOp");
+
+ auto from_name = cast::GetCastDataType(argsHelper, "from_type");
Argument *to = defs[0].add_arg();
to->set_name("to");
- to->set_s(from_name);
+ to->set_i(from_name);
Argument *from = defs[0].add_arg();
from->set_name("from_type");
- from->set_s(to_name);
+ from->set_i(to_name);
return defs;
}