Avoid constant-folding for ops that produce variants.
This is generally a lossy transformation since it results in opaque proto blobs that can only be handled by the TensorFlow runtime. If we need any of these transformations, we should add a separate pass that does them.
PiperOrigin-RevId: 294316034
Change-Id: I934f070094de605ffd237d09dc718265f67c761b
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc
index 11eafde..7b46c6a 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc
@@ -19,6 +19,7 @@
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/tf_status.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/eval_util.h"
#include "tensorflow/core/platform/mutex.h"
@@ -32,6 +33,17 @@
// the original semantics.
if (!inst->hasNoSideEffect()) return failure();
+ // If any of the result types are variants, don't try to constant fold them.
+ // This creates opaque variant constants which lose information and would
+ // require "raising" later.
+ for (auto& type : inst->getResultTypes()) {
+ if (auto tensor_type = type.dyn_cast<TensorType>()) {
+ if (tensor_type.getElementType().isa<VariantType>()) {
+ return failure();
+ }
+ }
+ }
+
// TODO(jpienaar): Currently this persists the entire program execution. This
// should instead be per module/set from the Graph being executed in TF (if
// any) so that the value of variables in the context could be read.