Remove python_default_init from ATen and use Optional (#15234)
Summary:
Optional clean-up. This PR removes python_default_init from the yaml files and from the code-gen, and uses optional types to do the work instead.
It also fixes the bug in #13149 so that the as_strided backward is applied correctly.
Fixes #9941
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15234
Differential Revision: D13502044
Pulled By: wanchaol
fbshipit-source-id: 774b61fc4414482cf11d56e22bd0275aefb352a4
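
For illustration, below is a minimal standalone sketch of the pattern this PR adopts (using std::optional as a stand-in for c10::optional; the Tensor struct is a toy, not ATen's). A default that depends on runtime state, such as self.storage_offset(), cannot be expressed as a compile-time default argument, so the old overload pair collapses into a single signature taking an optional, and the callee resolves the default with value_or:

    // Sketch only: data-dependent defaults move from codegen'd
    // python_default_init expressions into the kernel via value_or().
    #include <cstdint>
    #include <iostream>
    #include <optional>

    struct Tensor {
      int64_t offset = 5;
      int64_t storage_offset() const { return offset; }
    };

    // Before: two overloads, one without storage_offset (with the default
    // injected by python_default_init) and one with it. After: one signature.
    Tensor as_strided(const Tensor& self,
                      std::optional<int64_t> storage_offset_ = std::nullopt) {
      // The default depends on `self`, so it is resolved here, not in codegen.
      auto storage_offset = storage_offset_.value_or(self.storage_offset());
      std::cout << "storage_offset = " << storage_offset << "\n";
      return self;
    }

    int main() {
      Tensor t;
      as_strided(t);     // prints 5 (falls back to self.storage_offset())
      as_strided(t, 0);  // prints 0 (explicit value wins)
    }

The same value_or pattern appears throughout the diff: norm's p defaults to 2, stft's hop_length defaults to n_fft >> 2 and win_length to n_fft, and as_strided's storage_offset defaults to self.storage_offset() (or input_geometry.storage_offset() in the backward).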
diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap
index 252cc96..4975dc4 100644
--- a/aten/src/ATen/Declarations.cwrap
+++ b/aten/src/ATen/Declarations.cwrap
@@ -1629,8 +1629,7 @@
- arg: THTensor* result
output: True
- THTensor* self
- - arg: real p
- python_default_init: AS_REAL(2)
+ - real p
- arg: long dim
wrap_dim: self
- arg: bool keepdim
diff --git a/aten/src/ATen/core/Tensor.h b/aten/src/ATen/core/Tensor.h
index c8d03e0..936a881 100644
--- a/aten/src/ATen/core/Tensor.h
+++ b/aten/src/ATen/core/Tensor.h
@@ -298,10 +298,8 @@
Tensor argmax() const;
Tensor argmin(int64_t dim, bool keepdim=false) const;
Tensor argmin() const;
- Tensor as_strided(IntList size, IntList stride) const;
- Tensor & as_strided_(IntList size, IntList stride);
- Tensor as_strided(IntList size, IntList stride, int64_t storage_offset) const;
- Tensor & as_strided_(IntList size, IntList stride, int64_t storage_offset);
+ Tensor as_strided(IntList size, IntList stride, c10::optional<int64_t> storage_offset=c10::nullopt) const;
+ Tensor & as_strided_(IntList size, IntList stride, c10::optional<int64_t> storage_offset=c10::nullopt);
Tensor asin() const;
Tensor & asin_();
Tensor atan() const;
@@ -449,7 +447,7 @@
Tensor & squeeze_();
Tensor & squeeze_(int64_t dim);
Tensor sspaddmm(const Tensor & mat1, const Tensor & mat2, Scalar beta=1, Scalar alpha=1) const;
- Tensor stft(int64_t n_fft, int64_t hop_length, int64_t win_length, const Tensor & window={}, bool normalized=false, bool onesided=true) const;
+ Tensor stft(int64_t n_fft, c10::optional<int64_t> hop_length=c10::nullopt, c10::optional<int64_t> win_length=c10::nullopt, const Tensor & window={}, bool normalized=false, bool onesided=true) const;
int64_t stride(int64_t dim) const;
Tensor sum(ScalarType dtype) const;
Tensor sum() const;
@@ -487,7 +485,7 @@
Tensor view_as(const Tensor & other) const;
Tensor where(const Tensor & condition, const Tensor & other) const;
Tensor norm(Scalar p=2) const;
- Tensor norm(Scalar p, int64_t dim, bool keepdim=false) const;
+ Tensor norm(c10::optional<Scalar> p, int64_t dim, bool keepdim=false) const;
Tensor clone() const;
Tensor & resize_as_(const Tensor & the_template);
Tensor pow(Scalar exponent) const;
diff --git a/aten/src/ATen/core/TensorMethods.h b/aten/src/ATen/core/TensorMethods.h
index 17b74bc..b44ca0f 100644
--- a/aten/src/ATen/core/TensorMethods.h
+++ b/aten/src/ATen/core/TensorMethods.h
@@ -115,16 +115,10 @@
inline Tensor Tensor::argmin() const {
return type().argmin(*this);
}
-inline Tensor Tensor::as_strided(IntList size, IntList stride) const {
- return type().as_strided(*this, size, stride);
-}
-inline Tensor & Tensor::as_strided_(IntList size, IntList stride) {
- return type().as_strided_(*this, size, stride);
-}
-inline Tensor Tensor::as_strided(IntList size, IntList stride, int64_t storage_offset) const {
+inline Tensor Tensor::as_strided(IntList size, IntList stride, c10::optional<int64_t> storage_offset) const {
return type().as_strided(*this, size, stride, storage_offset);
}
-inline Tensor & Tensor::as_strided_(IntList size, IntList stride, int64_t storage_offset) {
+inline Tensor & Tensor::as_strided_(IntList size, IntList stride, c10::optional<int64_t> storage_offset) {
return type().as_strided_(*this, size, stride, storage_offset);
}
inline Tensor Tensor::asin() const {
@@ -568,7 +562,7 @@
inline Tensor Tensor::sspaddmm(const Tensor & mat1, const Tensor & mat2, Scalar beta, Scalar alpha) const {
return type().sspaddmm(*this, mat1, mat2, beta, alpha);
}
-inline Tensor Tensor::stft(int64_t n_fft, int64_t hop_length, int64_t win_length, const Tensor & window, bool normalized, bool onesided) const {
+inline Tensor Tensor::stft(int64_t n_fft, c10::optional<int64_t> hop_length, c10::optional<int64_t> win_length, const Tensor & window, bool normalized, bool onesided) const {
return type().stft(*this, n_fft, hop_length, win_length, window, normalized, onesided);
}
inline int64_t Tensor::stride(int64_t dim) const {
@@ -682,7 +676,7 @@
inline Tensor Tensor::norm(Scalar p) const {
return type().norm(*this, p);
}
-inline Tensor Tensor::norm(Scalar p, int64_t dim, bool keepdim) const {
+inline Tensor Tensor::norm(c10::optional<Scalar> p, int64_t dim, bool keepdim) const {
return type().norm(*this, p, dim, keepdim);
}
inline Tensor Tensor::clone() const {
diff --git a/aten/src/ATen/core/Type.h b/aten/src/ATen/core/Type.h
index 55667e0..df41133 100644
--- a/aten/src/ATen/core/Type.h
+++ b/aten/src/ATen/core/Type.h
@@ -205,10 +205,8 @@
virtual Tensor argmax(const Tensor & self) const = 0;
virtual Tensor argmin(const Tensor & self, int64_t dim, bool keepdim) const = 0;
virtual Tensor argmin(const Tensor & self) const = 0;
- virtual Tensor as_strided(const Tensor & self, IntList size, IntList stride) const = 0;
- virtual Tensor & as_strided_(Tensor & self, IntList size, IntList stride) const = 0;
- virtual Tensor as_strided(const Tensor & self, IntList size, IntList stride, int64_t storage_offset) const = 0;
- virtual Tensor & as_strided_(Tensor & self, IntList size, IntList stride, int64_t storage_offset) const = 0;
+ virtual Tensor as_strided(const Tensor & self, IntList size, IntList stride, c10::optional<int64_t> storage_offset) const = 0;
+ virtual Tensor & as_strided_(Tensor & self, IntList size, IntList stride, c10::optional<int64_t> storage_offset) const = 0;
virtual Tensor asin(const Tensor & self) const = 0;
virtual Tensor & asin_(Tensor & self) const = 0;
virtual Tensor atan(const Tensor & self) const = 0;
@@ -356,7 +354,7 @@
virtual Tensor & squeeze_(Tensor & self) const = 0;
virtual Tensor & squeeze_(Tensor & self, int64_t dim) const = 0;
virtual Tensor sspaddmm(const Tensor & self, const Tensor & mat1, const Tensor & mat2, Scalar beta, Scalar alpha) const = 0;
- virtual Tensor stft(const Tensor & self, int64_t n_fft, int64_t hop_length, int64_t win_length, const Tensor & window, bool normalized, bool onesided) const = 0;
+ virtual Tensor stft(const Tensor & self, int64_t n_fft, c10::optional<int64_t> hop_length, c10::optional<int64_t> win_length, const Tensor & window, bool normalized, bool onesided) const = 0;
virtual int64_t stride(const Tensor & self, int64_t dim) const = 0;
virtual Tensor sum(const Tensor & self, ScalarType dtype) const = 0;
virtual Tensor sum(const Tensor & self) const = 0;
@@ -394,7 +392,7 @@
virtual Tensor view_as(const Tensor & self, const Tensor & other) const = 0;
virtual Tensor where(const Tensor & condition, const Tensor & self, const Tensor & other) const = 0;
virtual Tensor norm(const Tensor & self, Scalar p) const = 0;
- virtual Tensor norm(const Tensor & self, Scalar p, int64_t dim, bool keepdim) const = 0;
+ virtual Tensor norm(const Tensor & self, c10::optional<Scalar> p, int64_t dim, bool keepdim) const = 0;
virtual Tensor clone(const Tensor & self) const = 0;
virtual Tensor & resize_as_(Tensor & self, const Tensor & the_template) const = 0;
virtual Tensor pow(const Tensor & self, Scalar exponent) const = 0;
diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py
index fd6a377..c56f7b8 100644
--- a/aten/src/ATen/function_wrapper.py
+++ b/aten/src/ATen/function_wrapper.py
@@ -395,7 +395,6 @@
'is_nullable': bool,
'default': str,
'default_init': str,
- 'python_default_init': str,
'output': bool,
'size': int,
'declared_type': str,
@@ -422,7 +421,6 @@
'is_nullable': bool,
'default': str,
'default_init': str,
- 'python_default_init': str,
'output': bool,
'size': int,
}, total=False)
@@ -619,10 +617,6 @@
default = translate_default(argument, type_str, argument['default'])
translated['default'] = default
translated['default_init'] = argument.get('default_init', default)
- if 'python_default_init' in argument:
- assert 'default' not in argument
- default = translate_default(argument, type_str, argument['python_default_init'])
- translated['python_default_init'] = default
if argument.get('output'):
translated['output'] = True
if argument.get('size'):
diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp
index e4b3f89..4282d37 100644
--- a/aten/src/ATen/native/ReduceOps.cpp
+++ b/aten/src/ATen/native/ReduceOps.cpp
@@ -103,7 +103,7 @@
// efficiency.
// not generalize this to common mismatched input/output types to avoid cross
// product of templated kernel launches.
- if (self.type().scalarType() == dtype ||
+ if (self.type().scalarType() == dtype ||
(self.is_cuda() && self.type().scalarType() == kHalf && dtype == kFloat)) {
return TensorIterator::reduce_op(viewed_result, self);
}
@@ -401,10 +401,11 @@
}
}
-Tensor& norm_out(Tensor &result, const Tensor &self, Scalar p, int64_t dim, bool keepdim) {
+Tensor& norm_out(Tensor &result, const Tensor &self, optional<Scalar> pOpt, int64_t dim, bool keepdim) {
AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
"norm only supports CPU AND CUDA backend, got: ", toString(self.type().backend()));
AT_CHECK(at::isFloatingType(self.type().scalarType()), "norm only supports floating-point dtypes");
+ auto p = pOpt.value_or(2.0);
dim = maybe_wrap_dim(dim, self.dim());
if (_dimreduce_return_trivial(result, self, 0, dim, keepdim)) {
return result;
@@ -438,7 +439,7 @@
}
}
-Tensor norm(const Tensor& self, Scalar p, int64_t dim, bool keepdim) {
+Tensor norm(const Tensor& self, optional<Scalar> p, int64_t dim, bool keepdim) {
Tensor result = at::empty({0}, self.options());
return at::native::norm_out(result, self, p, dim, keepdim);
}
diff --git a/aten/src/ATen/native/SpectralOps.cpp b/aten/src/ATen/native/SpectralOps.cpp
index 4062fc2..5681079 100644
--- a/aten/src/ATen/native/SpectralOps.cpp
+++ b/aten/src/ATen/native/SpectralOps.cpp
@@ -173,8 +173,8 @@
}
-Tensor stft(const Tensor& self, const int64_t n_fft, const int64_t hop_length,
- const int64_t win_length, const Tensor& window,
+Tensor stft(const Tensor& self, const int64_t n_fft, const optional<int64_t> hop_lengthOpt,
+ const optional<int64_t> win_lengthOpt, const Tensor& window,
const bool normalized, const bool onesided) {
#define REPR(SS) \
SS << "stft(" << self.type() << self.sizes() << ", n_fft=" << n_fft \
@@ -187,6 +187,10 @@
} \
SS << ", normalized=" << normalized << ", onesided=" << onesided << ")"
+ // default_init hop_length and win_length
+ auto hop_length = hop_lengthOpt.value_or(n_fft >> 2);
+ auto win_length = win_lengthOpt.value_or(n_fft);
+
if (!at::isFloatingType(self.type().scalarType()) || self.dim() > 2 || self.dim() < 1) {
std::ostringstream ss;
REPR(ss) << ": expected a 1D or 2D tensor of floating types";
diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index 02a96ae..52320c2 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -298,7 +298,8 @@
return sum_to(self, size);
}
-Tensor as_strided(const Tensor& self, IntList size, IntList stride, int64_t storage_offset) {
+Tensor as_strided(const Tensor& self, IntList size, IntList stride, optional<int64_t> storage_offset_) {
+ auto storage_offset = storage_offset_.value_or(self.storage_offset());
auto tid = self.type_id();
AT_CHECK(
tid == CPUTensorId() || tid == CUDATensorId(),
@@ -308,19 +309,12 @@
return result;
}
-Tensor &as_strided_(Tensor& self, IntList size, IntList stride, int64_t storage_offset) {
+Tensor &as_strided_(Tensor& self, IntList size, IntList stride, optional<int64_t> storage_offset_) {
+ auto storage_offset = storage_offset_.value_or(self.storage_offset());
setStrided(self, size, stride, storage_offset);
return self;
}
-Tensor as_strided(const Tensor& self, IntList size, IntList stride) {
- return at::as_strided(self, size, stride, self.storage_offset());
-}
-
-Tensor &as_strided_(Tensor& self, IntList size, IntList stride) {
- return self.as_strided_(size, stride, self.storage_offset());
-}
-
Tensor narrow_copy_sparse(const Tensor& self, int64_t dim, int64_t start, int64_t length) {
int64_t allDim = self.dim();
int64_t end = start+length;
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 6f527f4..6c8d98f 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -205,26 +205,14 @@
- func: _argmin(Tensor self, int64_t dim, bool keepdim=false) -> Tensor
variants: function
-- func: as_strided(Tensor self, IntList size, IntList stride) -> Tensor
+- func: as_strided(Tensor self, IntList size, IntList stride, int64_t? storage_offset=None) -> Tensor
variants: function, method
device_guard: false
-- func: as_strided_(Tensor self, IntList size, IntList stride) -> Tensor
+- func: as_strided_(Tensor self, IntList size, IntList stride, int64_t? storage_offset=None) -> Tensor
variants: function, method
device_guard: false
-- func: as_strided(Tensor self, IntList size, IntList stride, int64_t storage_offset) -> Tensor
- variants: function, method
- device_guard: false
- python_default_init:
- storage_offset: self.storage_offset()
-
-- func: as_strided_(Tensor self, IntList size, IntList stride, int64_t storage_offset) -> Tensor
- variants: function, method
- device_guard: false
- python_default_init:
- storage_offset: self.storage_offset()
-
- func: asin(Tensor self) -> Tensor
variants: function, method
@@ -1597,11 +1585,8 @@
# missing the `pad_mode` and `center` arguments, which are taken care of at
# `torch.functional.py`. They shall be moved here once we have mapping between
# Python strings and C++ Enum in codegen.
-- func: stft(Tensor self, int64_t n_fft, int64_t hop_length, int64_t win_length, Tensor? window={}, bool normalized=false, bool onesided=true) -> Tensor
+- func: stft(Tensor self, int64_t n_fft, int64_t? hop_length=None, int64_t? win_length=None, Tensor? window={}, bool normalized=false, bool onesided=true) -> Tensor
variants: function, method
- python_default_init:
- hop_length: n_fft >> 2
- win_length: n_fft
- func: stride(Tensor self, int64_t dim) -> int64_t
variants: function, method
@@ -1890,14 +1875,10 @@
- func: norm(Tensor self, Scalar p=2) -> Tensor
variants: function, method
-- func: norm(Tensor self, Scalar p, int64_t dim, bool keepdim=false) -> Tensor
+- func: norm(Tensor self, Scalar? p, int64_t dim, bool keepdim=false) -> Tensor
variants: function, method
- python_default_init:
- p: 2
-- func: norm_out(Tensor result, Tensor self, Scalar p, int64_t dim, bool keepdim=false) -> Tensor
- python_default_init:
- p: 2
+- func: norm_out(Tensor result, Tensor self, Scalar? p, int64_t dim, bool keepdim=false) -> Tensor
- func: frobenius_norm(Tensor self) -> Tensor
variants: function
diff --git a/aten/src/ATen/native_parse.py b/aten/src/ATen/native_parse.py
index 7d2f29d..aacb235 100644
--- a/aten/src/ATen/native_parse.py
+++ b/aten/src/ATen/native_parse.py
@@ -48,7 +48,6 @@
def parse_arguments(args, func_decl, func_name, func_return):
arguments = []
- python_default_inits = func_decl.get('python_default_init', {})
is_out_fn = func_name.endswith('_out')
if is_out_fn and func_decl.get('variants', []) not in [[], 'function', ['function']]:
raise RuntimeError("Native functions suffixed with _out MUST be declared with only the function variant; "
@@ -71,16 +70,11 @@
t, name = type_and_name
default = None
- python_default_init = None
if '=' in name:
ns = name.split('=', 1)
name, default = ns[0], parse_default(ns[1])
- if name in python_default_inits:
- assert default is None
- python_default_init = python_default_inits[name]
-
typ = sanitize_types(t)
assert len(typ) == 1
argument_dict = {'type': typ[0].rstrip('?'), 'name': name, 'is_nullable': typ[0].endswith('?')}
@@ -90,8 +84,6 @@
argument_dict['size'] = int(match.group(1))
if default is not None:
argument_dict['default'] = default
- if python_default_init is not None:
- argument_dict['python_default_init'] = python_default_init
# TODO: convention is that the ith-argument correspond to the i-th return, but it would
# be better if we just named everything and matched by name.
if is_out_fn and arg_idx < len(func_return):
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 1cb5287..500dfb8 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -138,7 +138,7 @@
- name: alias(Tensor self)
self: grad
-- name: as_strided(Tensor self, IntList size, IntList stride, int64_t storage_offset)
+- name: as_strided(Tensor self, IntList size, IntList stride, int64_t? storage_offset)
self: as_strided_backward(grad, TensorGeometry(self), size, stride, storage_offset)
- name: asin(Tensor self)
@@ -561,7 +561,7 @@
- name: norm(Tensor self, Scalar p)
self: norm_backward(grad, self, p, result)
-- name: norm(Tensor self, Scalar p, int64_t dim, bool keepdim)
+- name: norm(Tensor self, Scalar? p, int64_t dim, bool keepdim)
self: norm_backward(grad, self, p, result, dim, keepdim)
- name: _pdist_forward(Tensor self, double p)
diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py
index b1e794e..2c9a9e7 100644
--- a/tools/autograd/gen_python_functions.py
+++ b/tools/autograd/gen_python_functions.py
@@ -251,6 +251,9 @@
'const Type &': 'scalartype',
'const THPLayout &': 'layout',
'const Device &': 'device',
+ 'c10::optional<ScalarType>': 'scalartypeOptional',
+ 'c10::optional<Scalar>': 'scalarOptional',
+ 'c10::optional<int64_t>': 'toInt64Optional',
'int64_t': 'toInt64',
'bool': 'toBool',
'double': 'toDouble',
@@ -320,11 +323,7 @@
default_expr += '.scalarType()'
expr = 'r.{}({}, {})'.format(unpack_with_default, arg_index, default_expr)
else:
- opt_match = re.match(r'c10::optional<(.+)>', typename)
- if (opt_match):
- unpack = opt_match.group(1).lower() + 'Optional'
- else:
- unpack = unpack_methods.get(typename, typename.lower())
+ unpack = unpack_methods.get(typename, typename.lower())
expr = 'r.{}({})'.format(unpack, arg_index)
if unpack_args:
@@ -349,7 +348,7 @@
formal_args.append(formal)
# We always want to unpack when we have TensorOptions.
- unpack = any(arg.get('python_default_init') for arg in inputs) or has_tensor_options
+ unpack = has_tensor_options
for arg in inputs:
if arg['simple_type'] in ['Type', 'TensorOptions']:
continue
@@ -397,7 +396,7 @@
elif arg['name'] == 'layout' and arg['simple_type'] == 'Layout':
# out(s) determines the type and layout if it is present, so only use this if there are no outputs.
if len(outputs) == 0:
- layout = parse_arg(arg, layout_idx, arg.get('python_default_init'))[0]
+ layout = parse_arg(arg, layout_idx)[0]
elif arg['name'] == 'device' and arg['simple_type'] == 'Device':
if len(outputs) == 0:
assert parsed_type_args
@@ -770,8 +769,6 @@
default = arg['default']
if default == 'nullptr' or default == 'nullopt' or default == '{}':
default = 'None'
- if arg.get('python_default_init') is not None:
- default = 'None'
if default is not None:
param += '=' + str(default)
return param
diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp
index 308865f..776f238 100644
--- a/tools/autograd/templates/Functions.cpp
+++ b/tools/autograd/templates/Functions.cpp
@@ -88,8 +88,8 @@
return size;
}
-Tensor norm_backward(const Tensor & grad, const Tensor & self, const Scalar & p_, const Tensor & norm) {
- double p = p_.toDouble();
+Tensor norm_backward(const Tensor & grad, const Tensor & self, const optional<Scalar> & p_, const Tensor & norm) {
+ double p = p_.value_or(2.0).toDouble();
Tensor self_scaled;
Tensor scale_v;
if (p == 0.0) {
@@ -114,7 +114,7 @@
return self_scaled * scale_v;
}
-Tensor norm_backward(Tensor grad, const Tensor & self, const Scalar & p_, Tensor norm, int64_t dim, bool keepdim) {
+Tensor norm_backward(Tensor grad, const Tensor & self, const optional<Scalar> & p_, Tensor norm, int64_t dim, bool keepdim) {
if (!keepdim && self.dim() != 0) {
grad = grad.unsqueeze(dim);
norm = norm.unsqueeze(dim);
@@ -1371,7 +1371,7 @@
}
// See NOTE [ as_strided Backward and layout-aware/agnostic autograd ] for explanation
-Tensor as_strided_backward(Tensor grad, TensorGeometry input_geometry, IntList sizes, IntList strides, int64_t storage_offset) {
+Tensor as_strided_backward(Tensor grad, TensorGeometry input_geometry, IntList sizes, IntList strides, optional<int64_t> storage_offset_) {
// For output geometry,
// check for size 0 dimensions,
// skip size 1 dimensions,
@@ -1379,6 +1379,7 @@
// Step (0) for the algorithm in NOTE [ as_strided Backward and layout-aware/agnostic autograd ]
// Step (0)~(1) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
// on output geometry
+ auto storage_offset = storage_offset_.value_or(input_geometry.storage_offset());
auto odim = grad.dim();
std::vector<int64_t> out_sizes_, out_strides_;
out_sizes_.reserve(odim);
diff --git a/tools/jit/gen_jit_dispatch.py b/tools/jit/gen_jit_dispatch.py
index 66371ad..2f53e67 100644
--- a/tools/jit/gen_jit_dispatch.py
+++ b/tools/jit/gen_jit_dispatch.py
@@ -56,6 +56,7 @@
'ScalarType': 'ScalarType',
'ScalarType?': 'ScalarType?',
'int64_t': 'int',
+ 'int64_t?': 'int?',
'double': 'float',
'bool': 'bool',
'Generator': 'Generator',
@@ -91,6 +92,7 @@
'bool': '{}.toBool()',
'double': '{}.toDouble()',
'int64_t': '{}.toInt()',
+ 'int64_t?': '{}.toOptional<int64_t>()',
'std::string': '{}.toString()->string()',
'Generator': 'nullptr',
'std::array<bool,2>': 'as_bool_array<2>({}.toIntList()->elements())',
diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp
index 85465d5..69e3170 100644
--- a/torch/csrc/jit/passes/shape_analysis.cpp
+++ b/torch/csrc/jit/passes/shape_analysis.cpp
@@ -902,7 +902,7 @@
"aten::argmin(Tensor self, int dim, bool keepdim) -> Tensor",
"aten::max_values(Tensor self, int dim, bool keepdim) -> Tensor",
"aten::min_values(Tensor self, int dim, bool keepdim) -> Tensor",
- "aten::norm(Tensor self, Scalar p, int dim, bool keepdim) -> Tensor",
+ "aten::norm(Tensor self, Scalar? p, int dim, bool keepdim) -> Tensor",
"aten::var(Tensor self, int dim, bool unbiased, bool keepdim) -> Tensor",
"aten::logsumexp(Tensor self, int dim, bool keepdim) -> Tensor",
"aten::all(Tensor self, int dim, bool keepdim) -> Tensor",
@@ -1262,9 +1262,7 @@
node->matches(
"aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor") ||
node->matches(
- "aten::as_strided(Tensor self, int[] size, int[] stride) -> Tensor") ||
- node->matches(
- "aten::as_strided(Tensor self, int[] size, int[] stride, int storage_offset) -> Tensor")) {
+ "aten::as_strided(Tensor self, int[] size, int[] stride, int? storage_offset) -> Tensor")) {
return reshape_prop(node, attr::size, tensor_types);
} else if (node->matches(
"aten::reshape(Tensor self, int[] shape) -> Tensor")) {
diff --git a/torch/csrc/jit/tracer.cpp b/torch/csrc/jit/tracer.cpp
index f2fc0c0..874fb4d 100644
--- a/torch/csrc/jit/tracer.cpp
+++ b/torch/csrc/jit/tracer.cpp
@@ -73,6 +73,18 @@
detail::genericAddInput(n, value);
}
}
+
+void addInputs(Node *n, const char * name, c10::optional<int64_t> value) {
+ if(value) {
+ detail::genericAddInput(n, *value);
+ } else {
+ Graph * g = n->owningGraph();
+ Value* none =
+ g->insertNode(g->createNone(IntType::get()))
+ ->output();
+ n->addInput(none);
+ }
+}
void addInputs(Node *n, const char * name, bool value) { detail::genericAddInput(n, value); }
void addInputs(Node *n, const char * name, double value) { detail::genericAddInput(n, value); }
void addInputs(Node *n, const char * name, const at::Scalar& value) { detail::genericAddInput(n, value); }
diff --git a/torch/csrc/jit/tracer.h b/torch/csrc/jit/tracer.h
index cc4bb6b..bbbca8d 100644
--- a/torch/csrc/jit/tracer.h
+++ b/torch/csrc/jit/tracer.h
@@ -200,6 +200,7 @@
// NB: those serve both as an intermediate steps in addInputs below,
// as well as the overloads that terminate template recursion
TORCH_API void addInputs(Node *n, const char * name, int64_t value);
+TORCH_API void addInputs(Node *n, const char * name, c10::optional<int64_t> value);
TORCH_API void addInputs(Node *n, const char * name, bool value);
TORCH_API void addInputs(Node *n, const char * name, double value);
TORCH_API void addInputs(Node *n, const char * name, const at::Scalar& value);
diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h
index 432383b..215c0b4 100644
--- a/torch/csrc/utils/python_arg_parser.h
+++ b/torch/csrc/utils/python_arg_parser.h
@@ -127,6 +127,7 @@
inline at::ScalarType scalartypeWithDefault(int i, at::ScalarType default_scalartype);
inline c10::optional<at::ScalarType> scalartypeOptional(int i);
inline c10::optional<at::Scalar> scalarOptional(int i);
+ inline c10::optional<int64_t> toInt64Optional(int i);
inline const THPLayout& layout(int i);
inline const THPLayout& layoutWithDefault(int i, const THPLayout& default_layout);
inline at::Device device(int i);
@@ -405,6 +406,12 @@
return toInt64(i);
}
+inline c10::optional<int64_t> PythonArgs::toInt64Optional(int i) {
+ if (!args[i])
+ return c10::nullopt;
+ return toInt64(i);
+}
+
inline double PythonArgs::toDouble(int i) {
if (!args[i]) return signature.params[i].default_double;
return THPUtils_unpackDouble(args[i]);