Fix ONNX_ATEN mode (#14239)
Summary:
Fix ONNX_ATEN mode by adding it to the validateBlock method.
Before this pr, validateBlock will throw an exception when using this mode.
I will add related test cases for ONNX_ATEN mode in a different pr once this is merged, since we dont have any currently.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14239
Differential Revision: D13145443
Pulled By: zrphercule
fbshipit-source-id: 60e7942aa126acfe67bdb428ef231ac3066234b1
diff --git a/torch/csrc/jit/export.cpp b/torch/csrc/jit/export.cpp
index 7f1b47e..4d870a5 100644
--- a/torch/csrc/jit/export.cpp
+++ b/torch/csrc/jit/export.cpp
@@ -90,8 +90,11 @@
"Cannot export individual pack_padded_sequence or pad_packed_sequence; these operations must occur in pairs.\n\nUsage of this operation occurred at:\n" +
getNodeStackTraceString(node));
}
- bool is_aten_fallback = operator_export_type == onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK;
- if (!node->kind().is_onnx() && !is_aten_fallback && node->kind() != prim::Undefined) {
+ bool is_aten_enabled = operator_export_type ==
+ onnx_torch::OperatorExportTypes::ONNX_ATEN_FALLBACK ||
+ operator_export_type == onnx_torch::OperatorExportTypes::ONNX_ATEN;
+ if (!node->kind().is_onnx() && !is_aten_enabled &&
+ node->kind() != prim::Undefined) {
FAIL_EXPORT(
"Couldn't export operator " + node->kind().toDisplayString() + "\n\nDefined at:\n" +
getNodeStackTraceString(node));