Avoid constant-folding for ops that produce variants.

This is generally a lossy transformation since it results in opaque proto blobs that can only be handled by the TensorFlow runtime. If we need any of these transformations, we should add a separate pass that does them.

PiperOrigin-RevId: 294316034
Change-Id: I934f070094de605ffd237d09dc718265f67c761b
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc
index 11eafde..7b46c6a 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc
@@ -19,6 +19,7 @@
 
 #include "tensorflow/c/eager/c_api.h"
 #include "tensorflow/c/tf_status.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/eval_util.h"
 #include "tensorflow/core/platform/mutex.h"
 
@@ -32,6 +33,17 @@
   // the original semantics.
   if (!inst->hasNoSideEffect()) return failure();
 
+  // If any of the result types are variants, don't try to constant fold them.
+  // This creates opaque variant constants which lose information and would
+  // require "raising" later.
+  for (auto& type : inst->getResultTypes()) {
+    if (auto tensor_type = type.dyn_cast<TensorType>()) {
+      if (tensor_type.getElementType().isa<VariantType>()) {
+        return failure();
+      }
+    }
+  }
+
   // TODO(jpienaar): Currently this persists the entire program execution. This
   // should instead be per module/set from the Graph being executed in TF (if
   // any) so that the value of variables in the context could be read.