Move away from ConstantFill (#16214)
Summary:
Prerequisite of https://github.com/onnx/onnx/pull/1434.
Replaces the ConstantFill op with the standard ONNX ConstantOfShape op in the zeros/ones/full symbolics, and fixes the scalar raw_data memcpy calls in caffe2/onnx/backend.cc to copy from the string's character buffer instead of the address of the returned std::string.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16214
Reviewed By: BIT-silence
Differential Revision: D13755116
Pulled By: houseroad
fbshipit-source-id: a46be8d7df959b5ede93e1f9c911a9a9326e6879
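
For context, a minimal illustration (assuming the `onnx` Python package; not part of this patch) of the node format the exporter now emits: ConstantOfShape takes the target shape as a regular graph input and carries the fill value as a one-element tensor attribute, whereas ConstantFill encoded the same information in the dtype_i, input_as_shape_i, and value_f attributes removed below.

    import numpy as np
    from onnx import helper, numpy_helper

    # The fill value travels as a one-element TENSOR attribute; its data_type
    # field replaces ConstantFill's separate dtype_i, and the shape arrives as
    # an input, so input_as_shape_i is no longer needed.
    value = numpy_helper.from_array(np.array([2.0], dtype=np.float32), name="value")
    node = helper.make_node("ConstantOfShape", inputs=["sizes"], outputs=["filled"], value=value)
    print(node)  # shows the "value" attribute of type TENSOR, as in the expect file below
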
diff --git a/caffe2/onnx/backend.cc b/caffe2/onnx/backend.cc
index 8bcef71..2c060f0 100644
--- a/caffe2/onnx/backend.cc
+++ b/caffe2/onnx/backend.cc
@@ -1776,37 +1776,37 @@
} else {
CAFFE_ENFORCE(onnx_tensor.raw_data().size() == sizeof(float));
float f;
- memcpy(&f, &onnx_tensor.raw_data(), sizeof(float));
+ memcpy(&f, onnx_tensor.raw_data().c_str(), sizeof(float));
c2_values->set_f(f);
}
- } else if (onnx_tensor.data_type() == TensorProto::DOUBLE){
+ } else if (onnx_tensor.data_type() == TensorProto::DOUBLE) {
c2_dtype->set_i(caffe2::TensorProto::DOUBLE);
if (onnx_tensor.double_data_size() > 0) {
c2_values->set_f(static_cast<float>(onnx_tensor.double_data(0)));
} else {
CAFFE_ENFORCE(onnx_tensor.raw_data().size() == sizeof(double));
double d;
- memcpy(&d, &onnx_tensor.raw_data(), sizeof(double));
+ memcpy(&d, onnx_tensor.raw_data().c_str(), sizeof(double));
c2_values->set_f(static_cast<float>(d));
}
- } else if (onnx_tensor.data_type() == TensorProto::INT64){
+ } else if (onnx_tensor.data_type() == TensorProto::INT64) {
c2_dtype->set_i(caffe2::TensorProto::INT64);
if (onnx_tensor.int64_data_size() > 0) {
c2_values->set_i(onnx_tensor.int64_data(0));
} else {
CAFFE_ENFORCE(onnx_tensor.raw_data().size() == sizeof(int64_t));
int64_t i;
- memcpy(&i, &onnx_tensor.raw_data(), sizeof(int64_t));
+ memcpy(&i, onnx_tensor.raw_data().c_str(), sizeof(int64_t));
c2_values->set_i(i);
}
- } else if (onnx_tensor.data_type() == TensorProto::INT32){
+ } else if (onnx_tensor.data_type() == TensorProto::INT32) {
c2_dtype->set_i(caffe2::TensorProto::INT32);
if (onnx_tensor.int32_data_size() > 0) {
c2_values->set_i(onnx_tensor.int32_data(0));
} else {
CAFFE_ENFORCE(onnx_tensor.raw_data().size() == sizeof(int32_t));
int32_t i;
- memcpy(&i, &onnx_tensor.raw_data(), sizeof(int32_t));
+ memcpy(&i, onnx_tensor.raw_data().c_str(), sizeof(int32_t));
c2_values->set_i(i);
}
} else {
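
The memcpy fix above matters because raw_data() returns a const reference to the underlying std::string; taking its address copied the beginning of the string object itself rather than the serialized scalar it holds, whereas c_str() points at the actual byte buffer. For reference, a rough Python equivalent of the corrected decode (assuming the onnx and numpy packages; illustration only, not part of the patch):

    import struct
    import numpy as np
    from onnx import numpy_helper

    # A scalar tensor whose payload lives in raw_data, the case handled above.
    onnx_tensor = numpy_helper.from_array(np.array(3.5, dtype=np.float32))
    assert len(onnx_tensor.raw_data) == 4       # mirrors CAFFE_ENFORCE(... == sizeof(float))
    (f,) = struct.unpack("<f", onnx_tensor.raw_data)  # read the buffer contents (little-endian)
    print(f)  # 3.5
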
diff --git a/test/onnx/expect/TestOperators.test_full.expect b/test/onnx/expect/TestOperators.test_full.expect
index 455d5bb..3ce773b 100644
--- a/test/onnx/expect/TestOperators.test_full.expect
+++ b/test/onnx/expect/TestOperators.test_full.expect
@@ -92,21 +92,14 @@
node {
input: "9"
output: "10"
- op_type: "ConstantFill"
- attribute {
- name: "dtype"
- i: 1
- type: INT
- }
- attribute {
- name: "input_as_shape"
- i: 1
- type: INT
- }
+ op_type: "ConstantOfShape"
attribute {
name: "value"
- f: 2
- type: FLOAT
+ t {
+ data_type: 1
+ raw_data: "\000\000\000@"
+ }
+ type: TENSOR
}
}
name: "torch-jit-export"
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py
index a3de283..88ba735 100644
--- a/torch/onnx/symbolic.py
+++ b/torch/onnx/symbolic.py
@@ -1128,7 +1128,8 @@
@parse_args('v', 'i', 'v', 'v')
def zeros(g, sizes, dtype, layout, device):
# NOTE: no way to set device and layout in ONNX, so we ignore it
- return g.op("ConstantFill", sizes, dtype_i=scalar_type_to_onnx[dtype], input_as_shape_i=1, value_f=0)
+ return g.op("ConstantOfShape", sizes,
+ value_t=torch.tensor(0, dtype=scalar_type_to_pytorch_type[dtype]))
@parse_args('v', 'i', 'v', 'v')
@@ -1140,7 +1141,8 @@
@parse_args('v', 'i', 'v', 'v')
def ones(g, sizes, dtype, layout, device):
- return g.op("ConstantFill", sizes, dtype_i=scalar_type_to_onnx[dtype], input_as_shape_i=1, value_f=1)
+ return g.op("ConstantOfShape", sizes,
+ value_t=torch.tensor(1, dtype=scalar_type_to_pytorch_type[dtype]))
@parse_args('v', 'i', 'v', 'v')
@@ -1157,8 +1159,8 @@
return add(tmp, value, g.op("Constant", value_t=torch.tensor(1)))
else:
dtype = _get_const(dtype, 'i', 'dtype')
- return g.op("ConstantFill", sizes, dtype_i=scalar_type_to_onnx[dtype],
- input_as_shape_i=1, value_f=const_value)
+ return g.op("ConstantOfShape", sizes,
+ value_t=torch.tensor(const_value, dtype=scalar_type_to_pytorch_type[dtype]))
@parse_args('v', 'f', 'i', 'v', 'v')
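
End to end, the emitted pattern behaves like the old ConstantFill with input_as_shape=1: given a runtime shape input, it produces a tensor of that shape filled with the attribute value, whose dtype now comes from the value tensor itself. A sketch of that behaviour (assuming the onnx and onnxruntime packages are installed; the opset and provider choices here are illustrative, not part of this change):

    import numpy as np
    import onnx
    import onnxruntime as ort
    from onnx import TensorProto, helper, numpy_helper

    # Build the ConstantOfShape pattern the new zeros/ones/full symbolics emit.
    value = numpy_helper.from_array(np.array([1.0], dtype=np.float32), name="value")
    node = helper.make_node("ConstantOfShape", ["sizes"], ["out"], value=value)
    graph = helper.make_graph(
        [node],
        "constant_of_shape_sketch",
        [helper.make_tensor_value_info("sizes", TensorProto.INT64, [2])],
        [helper.make_tensor_value_info("out", TensorProto.FLOAT, None)],
    )
    model = helper.make_model(graph, opset_imports=[helper.make_operatorsetid("", 13)])
    onnx.checker.check_model(model)

    sess = ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
    (out,) = sess.run(None, {"sizes": np.array([2, 3], dtype=np.int64)})
    print(out)  # 2x3 array of ones, matching what ConstantFill(input_as_shape=1, value_f=1) produced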