Fix some bug in the old code: the net device option overwite used to not work.
diff --git a/caffe2/core/net.cc b/caffe2/core/net.cc
index 127d8ee..7e8d56a 100644
--- a/caffe2/core/net.cc
+++ b/caffe2/core/net.cc
@@ -21,13 +21,18 @@
SimpleNet::SimpleNet(const NetDef& net_def, Workspace* ws)
: NetBase(net_def, ws) {
+ bool net_def_has_device_option = net_def.has_device_option();
// Initialize the operators
for (const OperatorDef& operator_def : net_def.operators()) {
VLOG(1) << "Creating operator " << operator_def.name()
<< ":" << operator_def.type();
- if (!operator_def.has_device_option()) {
- operators_.emplace_back(
- CreateOperator(operator_def, net_def.device_option(), ws));
+ if (!operator_def.has_device_option() && net_def_has_device_option) {
+ // In the case that the operator def does not specify a device option but
+ // the net def has a default option, we copy the device option over to the
+ // operator def.
+ OperatorDef temp_def(operator_def);
+ temp_def.mutable_device_option()->CopyFrom(net_def.device_option());
+ operators_.emplace_back(CreateOperator(temp_def, ws));
} else {
operators_.emplace_back(CreateOperator(operator_def, ws));
}
@@ -60,14 +65,16 @@
: NetBase(net_def, ws), operator_nodes_(net_def.operators_size()) {
// Blob creator allows us to track which operator created which blob.
std::map<string, int> blob_creator;
+ bool net_def_has_device_option = net_def.has_device_option();
// Initialize the operators
for (int idx = 0; idx < net_def.operators_size(); ++idx) {
const OperatorDef& op_def = net_def.operators(idx);
VLOG(1) << "Creating operator #" << idx << ": "
<< op_def.name() << ":" << op_def.type();
- if (!op_def.has_device_option()) {
- operator_nodes_[idx].operator_.reset(
- CreateOperator(op_def, net_def.device_option(), ws));
+ if (!op_def.has_device_option() && net_def_has_device_option) {
+ OperatorDef temp_def(op_def);
+ temp_def.mutable_device_option()->CopyFrom(net_def.device_option());
+ operator_nodes_[idx].operator_.reset(CreateOperator(temp_def, ws));
} else {
operator_nodes_[idx].operator_.reset(CreateOperator(op_def, ws));
}
diff --git a/caffe2/core/operator.cc b/caffe2/core/operator.cc
index a9e2b45..fff6dd9 100644
--- a/caffe2/core/operator.cc
+++ b/caffe2/core/operator.cc
@@ -90,9 +90,7 @@
return true;
}
-OperatorBase* CreateOperator(const OperatorDef& operator_def,
- const DeviceOption& device_option,
- Workspace* ws) {
+OperatorBase* CreateOperator(const OperatorDef& operator_def, Workspace* ws) {
const string& key = operator_def.type();
switch (operator_def.device_option().device_type()) {
case CPU:
diff --git a/caffe2/core/operator.h b/caffe2/core/operator.h
index 9d60c63..b86930c 100644
--- a/caffe2/core/operator.h
+++ b/caffe2/core/operator.h
@@ -216,17 +216,8 @@
#define REGISTER_CUDNN_OPERATOR(name, ...) \
REGISTER_CLASS(CUDNNOperatorRegistry, name, __VA_ARGS__)
-// Creates an operator with the given operator definition and device option.
-OperatorBase* CreateOperator(const OperatorDef& operator_def,
- const DeviceOption& device_option,
- Workspace* ws);
-
-// Create an operator with the given operator definition, and the device
-// option that is specified in the operator definition.
-inline OperatorBase* CreateOperator(const OperatorDef& operator_def,
- Workspace* ws) {
- return CreateOperator(operator_def, operator_def.device_option(), ws);
-}
+// Creates an operator with the given operator definition.
+OperatorBase* CreateOperator(const OperatorDef& operator_def, Workspace* ws);
} // namespace caffe2