Add utilities for typing equality and subtyping, along with tests.
PiperOrigin-RevId: 413412663
Change-Id: I2c3fd03cb33102ed1222123b6ddacf9d4c4df27c
diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD
index 3dd93fd..2bdb65e 100644
--- a/tensorflow/core/framework/BUILD
+++ b/tensorflow/core/framework/BUILD
@@ -1059,6 +1059,7 @@
":op_def_proto_cc",
":tensor",
":types_proto_cc",
+ "//tensorflow/core/platform:errors",
"//tensorflow/core/platform:statusor",
"//tensorflow/core/protobuf:error_codes_proto_impl_cc",
],
@@ -1288,6 +1289,7 @@
"device_base_test.cc",
"disable_jit_test.cc",
"full_type_inference_util_test.cc",
+ "full_type_util_test.cc",
"function_test.cc",
"graph_def_util_test.cc",
"graph_to_functiondef_test.cc",
diff --git a/tensorflow/core/framework/full_type_util.cc b/tensorflow/core/framework/full_type_util.cc
index 89617dc..06e9c62 100644
--- a/tensorflow/core/framework/full_type_util.cc
+++ b/tensorflow/core/framework/full_type_util.cc
@@ -15,12 +15,16 @@
#include "tensorflow/core/framework/full_type_util.h"
+#include <algorithm>
+#include <string>
+
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/full_type.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
@@ -132,6 +136,92 @@
return ft;
}
+const FullTypeDef& GetArgDefaultUnset(const FullTypeDef& t, int i) {
+ static FullTypeDef* unset_type = []() {
+ FullTypeDef* t = new FullTypeDef();
+ return t;
+ }();
+
+ if (i < t.args_size()) {
+ return t.args(i);
+ }
+ return *unset_type;
+}
+
+const FullTypeDef& GetArgDefaultAny(const FullTypeDef& t, int i) {
+ static FullTypeDef* any_type = []() {
+ FullTypeDef* t = new FullTypeDef();
+ t->set_type_id(TFT_ANY);
+ return t;
+ }();
+
+ if (i < t.args_size()) {
+ const FullTypeDef& f_val = t.args(i);
+ if (f_val.type_id() == TFT_UNSET) {
+ return *any_type;
+ }
+ return f_val;
+ }
+ return *any_type;
+}
+
+bool IsEqual(const FullTypeDef& lhs, const FullTypeDef& rhs) {
+ if (lhs.type_id() != rhs.type_id()) {
+ return false;
+ }
+ const auto& lhs_s = lhs.s();
+ const auto& rhs_s = rhs.s();
+ if (lhs_s.empty()) {
+ if (!rhs_s.empty()) {
+ return false;
+ }
+ } else if (rhs_s != lhs_s) {
+ return false;
+ }
+ for (int i = 0; i < std::max(lhs.args_size(), rhs.args_size()); i++) {
+ const FullTypeDef& lhs_arg = GetArgDefaultAny(lhs, i);
+ const FullTypeDef& rhs_arg = GetArgDefaultAny(rhs, i);
+
+ if (!IsEqual(lhs_arg, rhs_arg)) {
+ return false;
+ }
+ }
+ return true;
+}
+
+bool IsSubtype(const FullTypeDef& lhs, const FullTypeDef& rhs, bool covariant) {
+ // Rule: ANY is a supertype of all types.
+ if (rhs.type_id() == TFT_ANY) {
+ return true;
+ }
+ // Compatibility rule: UNSET is treated as ANY for the purpose of subtyping.
+ if (rhs.type_id() == TFT_UNSET) {
+ return true;
+ }
+ // Default rule: type IDs must match.
+ if (lhs.type_id() != rhs.type_id()) {
+ return false;
+ }
+
+ for (int i = 0; i < std::max(lhs.args_size(), rhs.args_size()); i++) {
+ const FullTypeDef& lhs_arg = GetArgDefaultAny(lhs, i);
+ const FullTypeDef& rhs_arg = GetArgDefaultAny(rhs, i);
+
+ if (covariant) {
+ if (!IsSubtype(lhs_arg, rhs_arg)) {
+ return false;
+ }
+ } else {
+ if (!IsSubtype(rhs_arg, lhs_arg)) {
+ return false;
+ }
+ }
+ }
+
+ // Invariant: type IDs are eaqual, and all args are subtype of one another.
+ return true;
+}
+
} // namespace full_type
} // namespace tensorflow
diff --git a/tensorflow/core/framework/full_type_util.h b/tensorflow/core/framework/full_type_util.h
index 44d6ad6..e5c9ff7 100644
--- a/tensorflow/core/framework/full_type_util.h
+++ b/tensorflow/core/framework/full_type_util.h
@@ -31,16 +31,19 @@
namespace full_type {
// TODO(mdan): Specific helpers won't get too far. Use a parser instead.
+// TODO(mdan): Move constructors into a separate file.
// Helpers that allow shorthand expression for the more common kinds of type
// constructors.
// Note: The arity below refers to the number of arguments of parametric types,
// not to the number of return values from a particular op.
+// Note: Type constructors are meant to create static type definitions in the
+// op definition (i.e. the OpDef proto).
// Helper for a type constructor of <t>[] (with no parameters).
OpTypeConstructor Nullary(FullTypeId t);
-// Helper for a type constructor of <t>[FT_VAR[<param_name>]].
+// Helper for a type constructor of <t>[FT_VAR[<var_name>]].
OpTypeConstructor Unary(FullTypeId t, const string& var_name);
// Helper for a type constructor of <t>[FT_ANY].
@@ -56,6 +59,14 @@
StatusOr<FullTypeDef> SpecializeType(const AttrSlice& attrs,
const OpDef& op_def);
+const FullTypeDef& GetArgDefaultUnset(const FullTypeDef& t, int i);
+const FullTypeDef& GetArgDefaultAny(const FullTypeDef& t, int i);
+
+bool IsEqual(const FullTypeDef& lhs, const FullTypeDef& rhs);
+
+bool IsSubtype(const FullTypeDef& lhs, const FullTypeDef& rhs,
+ bool covariant = true);
+
} // namespace full_type
} // namespace tensorflow
diff --git a/tensorflow/core/framework/full_type_util_test.cc b/tensorflow/core/framework/full_type_util_test.cc
new file mode 100644
index 0000000..8eb5054
--- /dev/null
+++ b/tensorflow/core/framework/full_type_util_test.cc
@@ -0,0 +1,358 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/full_type.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+namespace full_type {
+
+namespace {
+
+// TODO(mdan): Use ParseTextProto, ProtoEquals when available in a clean lib.
+
+TEST(Nullary, Basic) {
+ OpTypeConstructor ctor = Nullary(TFT_TENSOR);
+
+ OpDef op;
+ op.add_output_arg();
+
+ TF_ASSERT_OK(ctor(&op));
+
+ const FullTypeDef& t = op.output_arg(0).experimental_full_type();
+ EXPECT_EQ(t.type_id(), TFT_TENSOR);
+ EXPECT_EQ(t.args_size(), 0);
+}
+
+TEST(Unary, Basic) {
+ OpTypeConstructor ctor = Unary(TFT_TENSOR, "T");
+
+ OpDef op;
+ op.add_output_arg();
+
+ TF_ASSERT_OK(ctor(&op));
+
+ const FullTypeDef& t = op.output_arg(0).experimental_full_type();
+ EXPECT_EQ(t.type_id(), TFT_TENSOR);
+ EXPECT_EQ(t.args_size(), 1);
+ EXPECT_EQ(t.args(0).type_id(), TFT_VAR);
+ EXPECT_EQ(t.args(0).args_size(), 0);
+ EXPECT_EQ(t.args(0).s(), "T");
+}
+
+TEST(UnaryGeneric, Basic) {
+ OpTypeConstructor ctor = UnaryGeneric(TFT_TENSOR);
+
+ OpDef op;
+ op.add_output_arg();
+
+ TF_ASSERT_OK(ctor(&op));
+
+ const FullTypeDef& t = op.output_arg(0).experimental_full_type();
+ EXPECT_EQ(t.type_id(), TFT_TENSOR);
+ EXPECT_EQ(t.args_size(), 1);
+ EXPECT_EQ(t.args(0).type_id(), TFT_ANY);
+ EXPECT_EQ(t.args(0).args_size(), 0);
+}
+
+TEST(UnaryTensorContainer, Fixed) {
+ OpTypeConstructor ctor = UnaryTensorContainer(TFT_ARRAY, TFT_INT32);
+
+ OpDef op;
+ op.add_output_arg();
+
+ TF_ASSERT_OK(ctor(&op));
+
+ const FullTypeDef& t = op.output_arg(0).experimental_full_type();
+ EXPECT_EQ(t.type_id(), TFT_ARRAY);
+ EXPECT_EQ(t.args_size(), 1);
+ EXPECT_EQ(t.args(0).type_id(), TFT_TENSOR);
+ EXPECT_EQ(t.args(0).args_size(), 1);
+ EXPECT_EQ(t.args(0).args(0).type_id(), TFT_INT32);
+ EXPECT_EQ(t.args(0).args(0).args_size(), 0);
+}
+
+TEST(GetArgDefaults, DefaultUnsetFromNoArgs) {
+ FullTypeDef t;
+
+ const auto& d = GetArgDefaultUnset(t, 0);
+
+ EXPECT_EQ(d.type_id(), TFT_UNSET);
+}
+
+TEST(GetArgDefaults, DefaultUnsetFromOutOfBounds) {
+ FullTypeDef t;
+ t.add_args()->set_type_id(TFT_TENSOR);
+
+ const auto& d = GetArgDefaultUnset(t, 1);
+
+ EXPECT_EQ(d.type_id(), TFT_UNSET);
+}
+
+TEST(GetArgDefaults, NoDefaultUnsetFromArg) {
+ FullTypeDef t;
+ t.add_args()->set_type_id(TFT_TENSOR);
+ t.mutable_args(0)->add_args();
+
+ const auto& d = GetArgDefaultUnset(t, 0);
+
+ EXPECT_EQ(d.type_id(), TFT_TENSOR);
+ EXPECT_EQ(d.args_size(), 1);
+}
+
+TEST(GetArgDefaults, DefaultAnyFromNoArgs) {
+ FullTypeDef t;
+
+ const auto& d = GetArgDefaultAny(t, 0);
+
+ EXPECT_EQ(d.type_id(), TFT_ANY);
+}
+
+TEST(GetArgDefaults, DefaultAnyFromOutOfBounds) {
+ FullTypeDef t;
+ t.add_args()->set_type_id(TFT_TENSOR);
+
+ const auto& d = GetArgDefaultAny(t, 1);
+
+ EXPECT_EQ(d.type_id(), TFT_ANY);
+}
+
+TEST(GetArgDefaults, DefaultAnyFromUnset) {
+ FullTypeDef t;
+ t.add_args();
+
+ const auto& d = GetArgDefaultAny(t, 0);
+
+ EXPECT_EQ(d.type_id(), TFT_ANY);
+}
+
+TEST(GetArgDefaults, NoDefaultAnyFromArg) {
+ FullTypeDef t;
+ t.add_args()->set_type_id(TFT_TENSOR);
+ t.mutable_args(0)->add_args();
+
+ const auto& d = GetArgDefaultAny(t, 0);
+
+ EXPECT_EQ(d.type_id(), TFT_TENSOR);
+ EXPECT_EQ(d.args_size(), 1);
+}
+
+TEST(IsEqual, Reflexivity) {
+ FullTypeDef t;
+ t.set_type_id(TFT_TENSOR);
+ t.add_args()->set_type_id(TFT_INT32);
+ t.add_args()->set_type_id(TFT_INT64);
+
+ EXPECT_TRUE(IsEqual(t, t));
+}
+
+TEST(IsEqual, Copy) {
+ FullTypeDef t;
+ t.set_type_id(TFT_TENSOR);
+ t.add_args()->set_type_id(TFT_INT32);
+ t.add_args()->set_type_id(TFT_INT64);
+
+ FullTypeDef u;
+ u = t;
+ EXPECT_TRUE(IsEqual(t, u));
+ EXPECT_TRUE(IsEqual(u, t));
+}
+
+TEST(IsEqual, DifferentTypesNotEqual) {
+ FullTypeDef t;
+ t.set_type_id(TFT_TENSOR);
+ t.add_args()->set_type_id(TFT_INT32);
+ t.add_args()->set_type_id(TFT_INT64);
+
+ FullTypeDef u;
+ u = t;
+ u.set_type_id(TFT_ARRAY);
+
+ EXPECT_FALSE(IsEqual(t, u));
+ EXPECT_FALSE(IsEqual(u, t));
+}
+
+TEST(IsEqual, DifferentAritiesNotEqual) {
+ FullTypeDef t;
+ t.set_type_id(TFT_TENSOR);
+ t.add_args()->set_type_id(TFT_INT32);
+ t.add_args()->set_type_id(TFT_INT64);
+
+ FullTypeDef u;
+ u = t;
+ u.add_args()->set_type_id(TFT_FLOAT);
+
+ EXPECT_FALSE(IsEqual(t, u));
+ EXPECT_FALSE(IsEqual(u, t));
+}
+
+TEST(IsEqual, MissingArgsEquivalentToAny) {
+ FullTypeDef t;
+ t.set_type_id(TFT_TENSOR);
+ t.add_args()->set_type_id(TFT_INT32);
+
+ FullTypeDef u;
+ u = t;
+ u.add_args()->set_type_id(TFT_ANY);
+
+ EXPECT_TRUE(IsEqual(t, u));
+ EXPECT_TRUE(IsEqual(u, t));
+}
+
+TEST(IsEqual, DifferentArgsNotEqual) {
+ FullTypeDef t;
+ t.set_type_id(TFT_TENSOR);
+ t.add_args()->set_type_id(TFT_INT32);
+ t.add_args()->set_type_id(TFT_INT64);
+
+ FullTypeDef u;
+ u = t;
+ u.mutable_args(1)->set_type_id(TFT_FLOAT);
+
+ EXPECT_FALSE(IsEqual(t, u));
+ EXPECT_FALSE(IsEqual(u, t));
+}
+
+TEST(IsEqual, DifferentStringValuesNotEqual) {
+ FullTypeDef t;
+ t.set_type_id(TFT_VAR);
+ t.set_s("T");
+
+ FullTypeDef u;
+ u = t;
+ u.set_type_id(TFT_VAR);
+ u.set_s("U");
+
+ EXPECT_FALSE(IsEqual(t, u));
+ EXPECT_FALSE(IsEqual(u, t));
+}
+
+TEST(IsSubtype, Reflexivity) {
+ FullTypeDef t;
+ t.set_type_id(TFT_TENSOR);
+ t.add_args()->set_type_id(TFT_INT32);
+ t.add_args()->set_type_id(TFT_INT64);
+
+ EXPECT_TRUE(IsSubtype(t, t));
+}
+
+TEST(IsSubtype, Copy) {
+ FullTypeDef t;
+ t.set_type_id(TFT_TENSOR);
+ t.add_args()->set_type_id(TFT_INT32);
+ t.add_args()->set_type_id(TFT_INT64);
+
+ FullTypeDef u;
+ u = t;
+ EXPECT_TRUE(IsSubtype(t, u));
+}
+
+TEST(IsSubtype, Any) {
+ FullTypeDef t;
+ t.set_type_id(TFT_TENSOR);
+ t.add_args()->set_type_id(TFT_INT32);
+ t.add_args()->set_type_id(TFT_INT64);
+
+ FullTypeDef u;
+ u.set_type_id(TFT_ANY);
+
+ EXPECT_TRUE(IsSubtype(t, u));
+ EXPECT_FALSE(IsSubtype(u, t));
+}
+
+TEST(IsSubtype, Unset) {
+ FullTypeDef t;
+ t.set_type_id(TFT_TENSOR);
+ t.add_args()->set_type_id(TFT_INT32);
+ t.add_args()->set_type_id(TFT_INT64);
+
+ FullTypeDef u;
+ u.set_type_id(TFT_UNSET);
+
+ EXPECT_TRUE(IsSubtype(t, u));
+ EXPECT_FALSE(IsSubtype(u, t));
+}
+
+TEST(IsSubtype, Covariance) {
+ FullTypeDef t;
+ t.set_type_id(TFT_TENSOR);
+ t.add_args()->set_type_id(TFT_ARRAY);
+ t.mutable_args(0)->add_args()->set_type_id(TFT_INT32);
+
+ FullTypeDef u;
+ u.set_type_id(TFT_TENSOR);
+ u.add_args()->set_type_id(TFT_ANY);
+
+ EXPECT_TRUE(IsSubtype(t, u, /*covariant=*/true));
+ EXPECT_FALSE(IsSubtype(u, t, /*covariant=*/true));
+
+ EXPECT_FALSE(IsSubtype(t, u, /*covariant=*/false));
+ EXPECT_TRUE(IsSubtype(u, t, /*covariant=*/false));
+}
+
+TEST(IsSubtype, DifferentTypesNotSubtype) {
+ FullTypeDef t;
+ t.set_type_id(TFT_TENSOR);
+ t.add_args()->set_type_id(TFT_INT32);
+ t.add_args()->set_type_id(TFT_INT64);
+
+ FullTypeDef u;
+ u = t;
+ u.set_type_id(TFT_ARRAY);
+
+ EXPECT_FALSE(IsSubtype(t, u));
+ EXPECT_FALSE(IsSubtype(u, t));
+}
+
+TEST(IsSubtype, DifferentAritiesDefaultToAny) {
+ FullTypeDef t;
+ t.set_type_id(TFT_TENSOR);
+ t.add_args()->set_type_id(TFT_INT32);
+ t.add_args()->set_type_id(TFT_INT64);
+
+ FullTypeDef u;
+ u = t;
+ u.add_args()->set_type_id(TFT_FLOAT);
+
+ EXPECT_FALSE(IsSubtype(t, u));
+ EXPECT_TRUE(IsSubtype(u, t));
+}
+
+TEST(IsSubtype, DifferentArgsNotSubtype) {
+ FullTypeDef t;
+ t.set_type_id(TFT_TENSOR);
+ t.add_args()->set_type_id(TFT_INT32);
+ t.add_args()->set_type_id(TFT_INT64);
+
+ FullTypeDef u;
+ u = t;
+ u.mutable_args(1)->set_type_id(TFT_FLOAT);
+
+ EXPECT_FALSE(IsSubtype(t, u));
+ EXPECT_FALSE(IsSubtype(u, t));
+}
+
+} // namespace
+
+} // namespace full_type
+
+} // namespace tensorflow