Support upsample (#13152)
Summary:
This will enable the updated attribute and input format of operator upsample.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13152
Reviewed By: houseroad
Differential Revision: D12812491
Pulled By: zrphercule
fbshipit-source-id: d5db200365f1ab2bd1f052667795841d7ee6beb3
diff --git a/test/onnx/expect/TestOperators.test_upsample.expect b/test/onnx/expect/TestOperators.test_upsample.expect
index e6d222d..2a36818 100644
--- a/test/onnx/expect/TestOperators.test_upsample.expect
+++ b/test/onnx/expect/TestOperators.test_upsample.expect
@@ -3,22 +3,28 @@
producer_version: "0.4"
graph {
node {
- input: "0"
output: "1"
+ op_type: "Constant"
+ attribute {
+ name: "value"
+ t {
+ dims: 4
+ data_type: FLOAT
+ raw_data: "\000\000\200?\000\000\200?\000\000\000@\000\000\000@"
+ }
+ type: TENSOR
+ }
+ }
+ node {
+ input: "0"
+ input: "1"
+ output: "2"
op_type: "Upsample"
attribute {
name: "mode"
s: "linear"
type: STRING
}
- attribute {
- name: "scales"
- floats: 1
- floats: 1
- floats: 2
- floats: 2
- type: FLOATS
- }
}
name: "torch-jit-export"
input {
@@ -44,7 +50,7 @@
}
}
output {
- name: "1"
+ name: "2"
type {
tensor_type {
elem_type: FLOAT
diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py
index f08f121..cb55973 100644
--- a/test/onnx/test_operators.py
+++ b/test/onnx/test_operators.py
@@ -75,7 +75,6 @@
import test_onnx_common
model_def = onnx.ModelProto.FromString(onnx_model_pb)
onnx.checker.check_model(model_def)
-
if _onnx_test:
test_function = inspect.stack()[1][0].f_code.co_name
test_name = test_function[0:4] + "_operator" + test_function[4:]
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py
index cb8ee80..a0ac1dc 100644
--- a/torch/onnx/symbolic.py
+++ b/torch/onnx/symbolic.py
@@ -633,8 +633,10 @@
def upsample_nearest2d(g, input, output_size):
height_scale = float(output_size[-2]) / input.type().sizes()[-2]
width_scale = float(output_size[-1]) / input.type().sizes()[-1]
- return g.op("Upsample", input,
- scales_f=[1., 1., height_scale, width_scale],
+ scales = g.op("Constant", value_t=torch.tensor([1., 1., height_scale,
+ width_scale]))
+
+ return g.op("Upsample", input, scales,
mode_s="nearest")
@@ -644,8 +646,9 @@
return _unimplemented("upsample_bilinear2d", "align_corners == True")
height_scale = float(output_size[-2]) / input.type().sizes()[-2]
width_scale = float(output_size[-1]) / input.type().sizes()[-1]
- return g.op("Upsample", input,
- scales_f=[1., 1., height_scale, width_scale],
+ scales = g.op("Constant", value_t=torch.tensor([1., 1., height_scale,
+ width_scale]))
+ return g.op("Upsample", input, scales,
mode_s="linear")