Add helper function _constant in onnx.py
diff --git a/torch/onnx.py b/torch/onnx.py
index 3494a1c..810a377 100644
--- a/torch/onnx.py
+++ b/torch/onnx.py
@@ -190,5 +190,37 @@
return self.op("ATen", *args, operator_s=opname, **kwargs)
+def _constant(self, value, dims, type=None, *args, **kwargs):
+ assert(isinstance(value, (int, long, float)))
+ # Infer the type based on value.
+ if type is None:
+ if isinstance(value, int):
+ type = "int"
+ elif isinstance(value, long):
+ type = "long"
+ elif isinstance(value, float):
+ type = "float"
+
+ if type == "char":
+ tensor = torch.CharTensor(*dims)
+ elif type == "short":
+ tensor = torch.ShortTensor(*dims)
+ elif type == "int":
+ tensor = torch.IntTensor(*dims)
+ elif type == "long":
+ tensor = torch.LongTensor(*dims)
+ elif type == "half":
+ tensor = torch.HalfTensor(*dims)
+ elif type == "float":
+ tensor = torch.FloatTensor(*dims)
+ elif type == "double":
+ tensor = torch.DoubleTensor(*dims)
+ else:
+ raise ValueError("Unknown type, type should be one of the following strings:"
+ "char, short, int, long, half, float, double")
+ tensor.fill_(value)
+ return self.op("Constant", *args, value_t=tensor, **kwargs)
+
torch._C.Graph.op = _op
torch._C.Graph.at = _at
+torch._C.Graph.constant = _constant