[ONNX] Allow None as operator argument (#105040)
Needed by 'aten.index.Tensor', where 'indices' is list of optional
tensors.
Related https://github.com/microsoft/onnxscript/pull/862
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105040
Approved by: https://github.com/titaiwangms, https://github.com/thiagocrepaldi
diff --git a/torch/onnx/_internal/fx/fx_onnx_interpreter.py b/torch/onnx/_internal/fx/fx_onnx_interpreter.py
index 13541a6..493c2b2 100644
--- a/torch/onnx/_internal/fx/fx_onnx_interpreter.py
+++ b/torch/onnx/_internal/fx/fx_onnx_interpreter.py
@@ -139,11 +139,11 @@
output.shape = [len(sequence_mixed_elements)]
return output
elif isinstance(onnx_tensor, (tuple, list)) and all(
- isinstance(node, torch.fx.Node) for node in onnx_tensor
+ isinstance(node, torch.fx.Node) or node is None for node in onnx_tensor
):
sequence_elements: List[
Union[
- onnxscript_graph_building.TorchScriptTensor,
+ Optional[onnxscript_graph_building.TorchScriptTensor],
Tuple[
onnxscript_graph_building.TorchScriptTensor,
...,
@@ -151,7 +151,9 @@
]
] = []
for tensor in onnx_tensor:
- sequence_elements.append(fx_name_to_onnxscript_value[tensor.name])
+ sequence_elements.append(
+ fx_name_to_onnxscript_value[tensor.name] if tensor is not None else None
+ )
return sequence_elements
if isinstance(onnx_tensor, torch.dtype):
onnx_tensor = int(