Move away from ConstantFill (#16214)

Summary:
Prerequisite of https://github.com/onnx/onnx/pull/1434
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16214

Reviewed By: BIT-silence

Differential Revision: D13755116

Pulled By: houseroad

fbshipit-source-id: a46be8d7df959b5ede93e1f9c911a9a9326e6879
diff --git a/caffe2/onnx/backend.cc b/caffe2/onnx/backend.cc
index 8bcef71..2c060f0 100644
--- a/caffe2/onnx/backend.cc
+++ b/caffe2/onnx/backend.cc
@@ -1776,37 +1776,37 @@
       } else {
         CAFFE_ENFORCE(onnx_tensor.raw_data().size() == sizeof(float));
         float f;
-        memcpy(&f, &onnx_tensor.raw_data(), sizeof(float));
+        memcpy(&f, onnx_tensor.raw_data().c_str(), sizeof(float));
         c2_values->set_f(f);
       }
-    } else if (onnx_tensor.data_type() == TensorProto::DOUBLE){
+    } else if (onnx_tensor.data_type() == TensorProto::DOUBLE) {
       c2_dtype->set_i(caffe2::TensorProto::DOUBLE);
       if (onnx_tensor.double_data_size() > 0) {
         c2_values->set_f(static_cast<float>(onnx_tensor.double_data(0)));
       } else {
         CAFFE_ENFORCE(onnx_tensor.raw_data().size() == sizeof(double));
         double d;
-        memcpy(&d, &onnx_tensor.raw_data(), sizeof(double));
+        memcpy(&d, onnx_tensor.raw_data().c_str(), sizeof(double));
         c2_values->set_f(static_cast<float>(d));
       }
-    } else if (onnx_tensor.data_type() == TensorProto::INT64){
+    } else if (onnx_tensor.data_type() == TensorProto::INT64) {
       c2_dtype->set_i(caffe2::TensorProto::INT64);
       if (onnx_tensor.int64_data_size() > 0) {
         c2_values->set_i(onnx_tensor.int64_data(0));
       } else {
         CAFFE_ENFORCE(onnx_tensor.raw_data().size() == sizeof(int64_t));
         int64_t i;
-        memcpy(&i, &onnx_tensor.raw_data(), sizeof(int64_t));
+        memcpy(&i, onnx_tensor.raw_data().c_str(), sizeof(int64_t));
         c2_values->set_i(i);
       }
-    } else if (onnx_tensor.data_type() == TensorProto::INT32){
+    } else if (onnx_tensor.data_type() == TensorProto::INT32) {
       c2_dtype->set_i(caffe2::TensorProto::INT32);
       if (onnx_tensor.int32_data_size() > 0) {
         c2_values->set_i(onnx_tensor.int32_data(0));
       } else {
         CAFFE_ENFORCE(onnx_tensor.raw_data().size() == sizeof(int32_t));
         int32_t i;
-        memcpy(&i, &onnx_tensor.raw_data(), sizeof(int32_t));
+        memcpy(&i, onnx_tensor.raw_data().c_str(), sizeof(int32_t));
         c2_values->set_i(i);
       }
     } else {
diff --git a/test/onnx/expect/TestOperators.test_full.expect b/test/onnx/expect/TestOperators.test_full.expect
index 455d5bb..3ce773b 100644
--- a/test/onnx/expect/TestOperators.test_full.expect
+++ b/test/onnx/expect/TestOperators.test_full.expect
@@ -92,21 +92,14 @@
   node {
     input: "9"
     output: "10"
-    op_type: "ConstantFill"
-    attribute {
-      name: "dtype"
-      i: 1
-      type: INT
-    }
-    attribute {
-      name: "input_as_shape"
-      i: 1
-      type: INT
-    }
+    op_type: "ConstantOfShape"
     attribute {
       name: "value"
-      f: 2
-      type: FLOAT
+      t {
+        data_type: 1
+        raw_data: "\000\000\000@"
+      }
+      type: TENSOR
     }
   }
   name: "torch-jit-export"
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py
index a3de283..88ba735 100644
--- a/torch/onnx/symbolic.py
+++ b/torch/onnx/symbolic.py
@@ -1128,7 +1128,8 @@
 @parse_args('v', 'i', 'v', 'v')
 def zeros(g, sizes, dtype, layout, device):
     # NOTE: no way to set device and layout in ONNX, so we ignore it
-    return g.op("ConstantFill", sizes, dtype_i=scalar_type_to_onnx[dtype], input_as_shape_i=1, value_f=0)
+    return g.op("ConstantOfShape", sizes,
+                value_t=torch.tensor(0, dtype=scalar_type_to_pytorch_type[dtype]))
 
 
 @parse_args('v', 'i', 'v', 'v')
@@ -1140,7 +1141,8 @@
 
 @parse_args('v', 'i', 'v', 'v')
 def ones(g, sizes, dtype, layout, device):
-    return g.op("ConstantFill", sizes, dtype_i=scalar_type_to_onnx[dtype], input_as_shape_i=1, value_f=1)
+    return g.op("ConstantOfShape", sizes,
+                value_t=torch.tensor(1, dtype=scalar_type_to_pytorch_type[dtype]))
 
 
 @parse_args('v', 'i', 'v', 'v')
@@ -1157,8 +1159,8 @@
         return add(tmp, value, g.op("Constant", value_t=torch.tensor(1)))
     else:
         dtype = _get_const(dtype, 'i', 'dtype')
-        return g.op("ConstantFill", sizes, dtype_i=scalar_type_to_onnx[dtype],
-                    input_as_shape_i=1, value_f=const_value)
+        return g.op("ConstantOfShape", sizes,
+                    value_t=torch.tensor(const_value, dtype=scalar_type_to_pytorch_type[dtype]))
 
 
 @parse_args('v', 'f', 'i', 'v', 'v')