[JIT] Implement may_contain_alias in FunctionSchema (#81352)
- Created may_contain_alias method in FunctionSchema to publicize more detailed aliasing information about inputs and outputs of a schema. This method returns whether the first argument may contain an alias to the second argument (ie if the first argument is a list[Tensor], it can contain an alias to the second argument of the second argument is Tensor(*)) and vice versa if bidirectional = true.
- Created helper methods are explained more thoroughly in detail in function_schema.h
-Tested may_contain_alias methods for basic functionality, bidirectional functionality, wildcard functionality and dual container functionality in test_schema_info.cpp.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81352
Approved by: https://github.com/davidberard98, https://github.com/Gamrix
diff --git a/aten/src/ATen/core/function_schema.cpp b/aten/src/ATen/core/function_schema.cpp
index 2f8d079..213c044 100644
--- a/aten/src/ATen/core/function_schema.cpp
+++ b/aten/src/ATen/core/function_schema.cpp
@@ -1,6 +1,7 @@
#include <ATen/core/function_schema.h>
#include <iostream>
+#include <stack>
namespace c10 {
@@ -16,6 +17,48 @@
}
}
+bool FunctionSchema::canAliasTypeSetsAlias(const c10::optional<std::vector<TypePtr>> &lhs, const c10::optional<std::vector<TypePtr>> &rhs) const {
+ if (!lhs || !rhs) {
+ return false;
+ }
+ for (const TypePtr& lhsType : *lhs) {
+ for (const TypePtr& rhsType : *rhs) {
+ if (lhsType == rhsType) {
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+c10::optional<std::vector<TypePtr>> FunctionSchema::getAliasTypeSetContainedTypes(const c10::optional<std::vector<TypePtr>> &aliasTypeSet) const {
+ if (!aliasTypeSet) {
+ return c10::nullopt;
+ }
+ std::unordered_set<TypePtr> containedTypes;
+ std::stack<TypePtr> typeStack;
+ // Push all 1st level contained types into the stack.
+ for (const TypePtr& type: *aliasTypeSet) {
+ for (const TypePtr& containedType : type->containedTypes()){
+ typeStack.push(containedType);
+ }
+ }
+
+ // process all further level contained types.
+ while (!typeStack.empty()) {
+ TypePtr current = typeStack.top();
+ typeStack.pop();
+ if (!containedTypes.count(current)) {
+ for (const TypePtr& containedType : current->containedTypes()) {
+ typeStack.push(containedType);
+ }
+ }
+ containedTypes.insert(current);
+ }
+
+ return std::vector<TypePtr>(containedTypes.begin(), containedTypes.end());
+}
+
c10::optional<std::vector<TypePtr>> FunctionSchema::mapTypeToAliasTypeSet(const TypePtr& type) const {
switch(type->kind()) {
case TypeKind::ListType:
@@ -65,7 +108,7 @@
}
}
- bool FunctionSchema::may_alias(const SchemaArgument& lhs, const SchemaArgument& rhs) const {
+bool FunctionSchema::may_alias(const SchemaArgument& lhs, const SchemaArgument& rhs) const {
TORCH_INTERNAL_ASSERT(
(lhs.index < getCorrectList(lhs.type).size()),
"Invalid index for schema.");
@@ -79,20 +122,8 @@
c10::optional<std::vector<TypePtr>> lhsTypes = mapTypeToAliasTypeSet(lhsArg.type());
c10::optional<std::vector<TypePtr>> rhsTypes = mapTypeToAliasTypeSet(rhsArg.type());
- // Check to see if the lhs and rhs types can alias each other
- bool typesCanAlias = false;
- if (lhsTypes && rhsTypes) {
- for (const TypePtr& lhsType : *lhsTypes) {
- for (const TypePtr& rhsType : *rhsTypes) {
- if (lhsType == rhsType) {
- typesCanAlias = true;
- }
- }
- }
- }
-
// Check to see if lhs and rhs have the same alias set
- if (typesCanAlias) {
+ if (canAliasTypeSetsAlias(lhsTypes, rhsTypes)) {
if (lhsArg.alias_info() && rhsArg.alias_info()) {
for (const auto& lhsSet : lhsArg.alias_info()->afterSets()) {
for (const auto& rhsSet : rhsArg.alias_info()->afterSets()) {
@@ -107,4 +138,27 @@
return false;
}
+bool FunctionSchema::may_contain_alias(const SchemaArgument& lhs, const SchemaArgument& rhs, bool bidirectional) const {
+ bool may_alias_result = may_alias(lhs, rhs);
+ if (may_alias_result) {
+ return true;
+ }
+
+ const c10::Argument lhsArg = getCorrectList(lhs.type)[lhs.index];
+ const c10::Argument rhsArg = getCorrectList(rhs.type)[rhs.index];
+ c10::optional<std::vector<TypePtr>> lhsTypes = mapTypeToAliasTypeSet(lhsArg.type());
+ c10::optional<std::vector<TypePtr>> rhsTypes = mapTypeToAliasTypeSet(rhsArg.type());
+ c10::optional<std::vector<TypePtr>> lhsContainedTypes = getAliasTypeSetContainedTypes(lhsTypes);
+ c10::optional<std::vector<TypePtr>> rhsContainedTypes = getAliasTypeSetContainedTypes(rhsTypes);
+
+ // Checks if one side is wildcard and the other side is a container of the same type
+ bool lhsWildcard = lhsArg.alias_info() && lhsArg.alias_info()->isWildcardAfter() && canAliasTypeSetsAlias(lhsTypes, rhsContainedTypes);
+ bool rhsWildcard = rhsArg.alias_info() && rhsArg.alias_info()->isWildcardAfter() && canAliasTypeSetsAlias(rhsTypes, lhsContainedTypes);
+
+ if (bidirectional) {
+ return lhsWildcard || rhsWildcard || canAliasTypeSetsAlias(lhsContainedTypes, rhsContainedTypes);
+ } else {
+ return rhsWildcard || canAliasTypeSetsAlias(lhsContainedTypes, rhsContainedTypes);
+ }
}
+} // namespace c10
diff --git a/aten/src/ATen/core/function_schema.h b/aten/src/ATen/core/function_schema.h
index de497be5..2cb89c1 100644
--- a/aten/src/ATen/core/function_schema.h
+++ b/aten/src/ATen/core/function_schema.h
@@ -339,6 +339,13 @@
}
}
+ // Returns whether the two AliasTypeSets contain any similarities
+ // ie: whether the two type sets can alias.
+ bool canAliasTypeSetsAlias(const c10::optional<std::vector<TypePtr>> &lhs, const c10::optional<std::vector<TypePtr>> &rhs) const;
+
+ // Recursively Finds all contained types within the AliasTypeSet.
+ c10::optional<std::vector<TypePtr>> getAliasTypeSetContainedTypes(const c10::optional<std::vector<TypePtr>> &aliasTypeSet) const ;
+
// Similar to mapTypeToAliasTypeSet defined in alias_analysis.cpp.
// Used to map types to a type such that all types that can alias will be mapped to the same type.
// For example, calling this method on 'Optional[List[int]]' is the same as calling this method
@@ -402,6 +409,12 @@
// FunctionSchema::may_contain_alias will include that functionality.
bool may_alias(const SchemaArgument& lhs, const SchemaArgument& rhs) const;
+ // Returns whether lhs and rhs may alias directly or whether lhs/rhs are a container
+ // that may contain elements that alias the other argument.
+ // bidirectional = false only returns whether lhs may contain an alias of rhs
+ // while bidirectional = true returns both directions.
+ bool may_contain_alias(const SchemaArgument& lhs, const SchemaArgument& rhs, bool bidirectional = true) const;
+
c10::optional<int> argumentIndexWithName(c10::string_view name) const {
for (const auto i : c10::irange(arguments().size())) {
if(name == arguments()[i].name())
diff --git a/test/cpp/jit/test_schema_info.cpp b/test/cpp/jit/test_schema_info.cpp
index e4c5d33..828129f3 100644
--- a/test/cpp/jit/test_schema_info.cpp
+++ b/test/cpp/jit/test_schema_info.cpp
@@ -173,5 +173,43 @@
{c10::SchemaArgType::input, 1}, {c10::SchemaArgType::output, 0}));
}
+TEST(FunctionSchemaMayContainAliasTest, Basic) {
+ c10::FunctionSchema schema = torch::jit::parseSchema(
+ "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
+ ASSERT_TRUE(schema.may_contain_alias(
+ {c10::SchemaArgType::input, 0}, {c10::SchemaArgType::output, 0}));
+ ASSERT_FALSE(schema.may_contain_alias(
+ {c10::SchemaArgType::input, 1}, {c10::SchemaArgType::output, 0}));
+ ASSERT_FALSE(schema.may_contain_alias(
+ {c10::SchemaArgType::input, 1}, {c10::SchemaArgType::input, 0}));
+}
+
+TEST(FunctionSchemaMayContainAliasTest, Wildcard) {
+ c10::FunctionSchema schema = torch::jit::parseSchema(
+ "aten::test.Tensor(Tensor(*) self) -> (Tensor[], Tensor)");
+ ASSERT_FALSE(schema.may_alias(
+ {c10::SchemaArgType::output, 0}, {c10::SchemaArgType::input, 0}));
+ ASSERT_TRUE(schema.may_contain_alias(
+ {c10::SchemaArgType::output, 0}, {c10::SchemaArgType::input, 0}));
+ ASSERT_TRUE(schema.may_contain_alias(
+ {c10::SchemaArgType::output, 0}, {c10::SchemaArgType::input, 0}, false));
+ ASSERT_FALSE(schema.may_contain_alias(
+ {c10::SchemaArgType::input, 0}, {c10::SchemaArgType::output, 0}, false));
+ ASSERT_FALSE(schema.may_alias(
+ {c10::SchemaArgType::output, 1}, {c10::SchemaArgType::input, 0}));
+}
+
+TEST(FunctionSchemaMayContainAliasTest, InputAndOutputContainers) {
+ c10::FunctionSchema schema =
+ torch::jit::parseSchema("aten::test.Tensor(Tensor[] self) -> Tensor[]");
+ ASSERT_FALSE(schema.may_alias(
+ {c10::SchemaArgType::output, 0}, {c10::SchemaArgType::input, 0}));
+ ASSERT_TRUE(schema.may_contain_alias(
+ {c10::SchemaArgType::output, 0}, {c10::SchemaArgType::input, 0}));
+ ASSERT_TRUE(schema.may_contain_alias(
+ {c10::SchemaArgType::output, 0}, {c10::SchemaArgType::input, 0}, false));
+ ASSERT_TRUE(schema.may_contain_alias(
+ {c10::SchemaArgType::input, 0}, {c10::SchemaArgType::output, 0}, false));
+}
} // namespace utils
} // namespace torch