[JIT] Implement may_contain_alias in FunctionSchema (#81352)

- Created may_contain_alias method in FunctionSchema to publicize more detailed aliasing information about inputs and outputs of a schema. This method returns whether the first argument may contain an alias to the second argument (ie if the first argument is a list[Tensor], it can contain an alias to the second argument of the second argument is Tensor(*)) and vice versa if bidirectional = true.
- Created helper methods are explained more thoroughly in detail in function_schema.h
-Tested may_contain_alias methods for basic functionality, bidirectional functionality, wildcard functionality and dual container functionality in test_schema_info.cpp.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81352
Approved by: https://github.com/davidberard98, https://github.com/Gamrix
diff --git a/aten/src/ATen/core/function_schema.cpp b/aten/src/ATen/core/function_schema.cpp
index 2f8d079..213c044 100644
--- a/aten/src/ATen/core/function_schema.cpp
+++ b/aten/src/ATen/core/function_schema.cpp
@@ -1,6 +1,7 @@
 #include <ATen/core/function_schema.h>
 
 #include <iostream>
+#include <stack>
 
 namespace c10 {
 
@@ -16,6 +17,48 @@
   }
 }
 
+bool FunctionSchema::canAliasTypeSetsAlias(const c10::optional<std::vector<TypePtr>> &lhs, const c10::optional<std::vector<TypePtr>> &rhs) const {
+  if (!lhs || !rhs) {
+    return false;
+  }
+  for (const TypePtr& lhsType : *lhs) {
+    for (const TypePtr& rhsType : *rhs) {
+      if (lhsType == rhsType) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+c10::optional<std::vector<TypePtr>> FunctionSchema::getAliasTypeSetContainedTypes(const c10::optional<std::vector<TypePtr>> &aliasTypeSet) const {
+  if (!aliasTypeSet) {
+    return c10::nullopt;
+  }
+  std::unordered_set<TypePtr> containedTypes;
+  std::stack<TypePtr> typeStack;
+  // Push all 1st level contained types into the stack.
+  for (const TypePtr& type: *aliasTypeSet) {
+    for (const TypePtr& containedType : type->containedTypes()){
+      typeStack.push(containedType);
+    }
+  }
+
+  // process all further level contained types.
+  while (!typeStack.empty()) {
+    TypePtr current = typeStack.top();
+    typeStack.pop();
+    if (!containedTypes.count(current)) {
+      for (const TypePtr& containedType : current->containedTypes()) {
+        typeStack.push(containedType);
+      }
+    }
+    containedTypes.insert(current);
+  }
+
+  return std::vector<TypePtr>(containedTypes.begin(), containedTypes.end());
+}
+
 c10::optional<std::vector<TypePtr>> FunctionSchema::mapTypeToAliasTypeSet(const TypePtr& type) const {
   switch(type->kind()) {
     case TypeKind::ListType:
@@ -65,7 +108,7 @@
   }
 }
 
-  bool FunctionSchema::may_alias(const SchemaArgument& lhs, const SchemaArgument& rhs) const {
+bool FunctionSchema::may_alias(const SchemaArgument& lhs, const SchemaArgument& rhs) const {
   TORCH_INTERNAL_ASSERT(
       (lhs.index < getCorrectList(lhs.type).size()),
       "Invalid index for schema.");
@@ -79,20 +122,8 @@
   c10::optional<std::vector<TypePtr>> lhsTypes = mapTypeToAliasTypeSet(lhsArg.type());
   c10::optional<std::vector<TypePtr>> rhsTypes = mapTypeToAliasTypeSet(rhsArg.type());
 
-  // Check to see if the lhs and rhs types can alias each other
-  bool typesCanAlias = false;
-  if (lhsTypes && rhsTypes) {
-    for (const TypePtr& lhsType : *lhsTypes) {
-      for (const TypePtr& rhsType : *rhsTypes) {
-        if (lhsType == rhsType) {
-          typesCanAlias = true;
-        }
-      }
-    }
-  }
-
   // Check to see if lhs and rhs have the same alias set
-  if (typesCanAlias) {
+  if (canAliasTypeSetsAlias(lhsTypes, rhsTypes)) {
     if (lhsArg.alias_info() && rhsArg.alias_info()) {
       for (const auto& lhsSet : lhsArg.alias_info()->afterSets()) {
         for (const auto& rhsSet : rhsArg.alias_info()->afterSets()) {
@@ -107,4 +138,27 @@
   return false;
 }
 
+bool FunctionSchema::may_contain_alias(const SchemaArgument& lhs, const SchemaArgument& rhs, bool bidirectional) const {
+  bool may_alias_result = may_alias(lhs, rhs);
+  if (may_alias_result) {
+    return true;
+  }
+
+  const c10::Argument lhsArg = getCorrectList(lhs.type)[lhs.index];
+  const c10::Argument rhsArg = getCorrectList(rhs.type)[rhs.index];
+  c10::optional<std::vector<TypePtr>> lhsTypes = mapTypeToAliasTypeSet(lhsArg.type());
+  c10::optional<std::vector<TypePtr>> rhsTypes = mapTypeToAliasTypeSet(rhsArg.type());
+  c10::optional<std::vector<TypePtr>> lhsContainedTypes = getAliasTypeSetContainedTypes(lhsTypes);
+  c10::optional<std::vector<TypePtr>> rhsContainedTypes = getAliasTypeSetContainedTypes(rhsTypes);
+
+  // Checks if one side is wildcard and the other side is a container of the same type
+  bool lhsWildcard = lhsArg.alias_info() && lhsArg.alias_info()->isWildcardAfter() && canAliasTypeSetsAlias(lhsTypes, rhsContainedTypes);
+  bool rhsWildcard = rhsArg.alias_info() && rhsArg.alias_info()->isWildcardAfter() && canAliasTypeSetsAlias(rhsTypes, lhsContainedTypes);
+
+  if (bidirectional) {
+    return lhsWildcard || rhsWildcard || canAliasTypeSetsAlias(lhsContainedTypes, rhsContainedTypes);
+  } else {
+    return rhsWildcard || canAliasTypeSetsAlias(lhsContainedTypes, rhsContainedTypes);
+  }
 }
+} // namespace c10
diff --git a/aten/src/ATen/core/function_schema.h b/aten/src/ATen/core/function_schema.h
index de497be5..2cb89c1 100644
--- a/aten/src/ATen/core/function_schema.h
+++ b/aten/src/ATen/core/function_schema.h
@@ -339,6 +339,13 @@
     }
   }
 
+  // Returns whether the two AliasTypeSets contain any similarities
+  // ie: whether the two type sets can alias.
+  bool canAliasTypeSetsAlias(const c10::optional<std::vector<TypePtr>> &lhs, const c10::optional<std::vector<TypePtr>> &rhs) const;
+
+  // Recursively Finds all contained types within the AliasTypeSet.
+  c10::optional<std::vector<TypePtr>> getAliasTypeSetContainedTypes(const c10::optional<std::vector<TypePtr>> &aliasTypeSet) const ;
+
   // Similar to mapTypeToAliasTypeSet defined in alias_analysis.cpp.
   // Used to map types to a type such that all types that can alias will be mapped to the same type.
   // For example, calling this method on 'Optional[List[int]]' is the same as calling this method
@@ -402,6 +409,12 @@
   // FunctionSchema::may_contain_alias will include that functionality.
   bool may_alias(const SchemaArgument& lhs, const SchemaArgument& rhs) const;
 
+  // Returns whether lhs and rhs may alias directly or whether lhs/rhs are a container
+  // that may contain elements that alias the other argument.
+  // bidirectional = false only returns whether lhs may contain an alias of rhs
+  // while bidirectional = true returns both directions.
+  bool may_contain_alias(const SchemaArgument& lhs, const SchemaArgument& rhs, bool bidirectional = true) const;
+
   c10::optional<int> argumentIndexWithName(c10::string_view name) const {
     for (const auto i : c10::irange(arguments().size())) {
       if(name == arguments()[i].name())
diff --git a/test/cpp/jit/test_schema_info.cpp b/test/cpp/jit/test_schema_info.cpp
index e4c5d33..828129f3 100644
--- a/test/cpp/jit/test_schema_info.cpp
+++ b/test/cpp/jit/test_schema_info.cpp
@@ -173,5 +173,43 @@
       {c10::SchemaArgType::input, 1}, {c10::SchemaArgType::output, 0}));
 }
 
+TEST(FunctionSchemaMayContainAliasTest, Basic) {
+  c10::FunctionSchema schema = torch::jit::parseSchema(
+      "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))");
+  ASSERT_TRUE(schema.may_contain_alias(
+      {c10::SchemaArgType::input, 0}, {c10::SchemaArgType::output, 0}));
+  ASSERT_FALSE(schema.may_contain_alias(
+      {c10::SchemaArgType::input, 1}, {c10::SchemaArgType::output, 0}));
+  ASSERT_FALSE(schema.may_contain_alias(
+      {c10::SchemaArgType::input, 1}, {c10::SchemaArgType::input, 0}));
+}
+
+TEST(FunctionSchemaMayContainAliasTest, Wildcard) {
+  c10::FunctionSchema schema = torch::jit::parseSchema(
+      "aten::test.Tensor(Tensor(*) self) -> (Tensor[], Tensor)");
+  ASSERT_FALSE(schema.may_alias(
+      {c10::SchemaArgType::output, 0}, {c10::SchemaArgType::input, 0}));
+  ASSERT_TRUE(schema.may_contain_alias(
+      {c10::SchemaArgType::output, 0}, {c10::SchemaArgType::input, 0}));
+  ASSERT_TRUE(schema.may_contain_alias(
+      {c10::SchemaArgType::output, 0}, {c10::SchemaArgType::input, 0}, false));
+  ASSERT_FALSE(schema.may_contain_alias(
+      {c10::SchemaArgType::input, 0}, {c10::SchemaArgType::output, 0}, false));
+  ASSERT_FALSE(schema.may_alias(
+      {c10::SchemaArgType::output, 1}, {c10::SchemaArgType::input, 0}));
+}
+
+TEST(FunctionSchemaMayContainAliasTest, InputAndOutputContainers) {
+  c10::FunctionSchema schema =
+      torch::jit::parseSchema("aten::test.Tensor(Tensor[] self) -> Tensor[]");
+  ASSERT_FALSE(schema.may_alias(
+      {c10::SchemaArgType::output, 0}, {c10::SchemaArgType::input, 0}));
+  ASSERT_TRUE(schema.may_contain_alias(
+      {c10::SchemaArgType::output, 0}, {c10::SchemaArgType::input, 0}));
+  ASSERT_TRUE(schema.may_contain_alias(
+      {c10::SchemaArgType::output, 0}, {c10::SchemaArgType::input, 0}, false));
+  ASSERT_TRUE(schema.may_contain_alias(
+      {c10::SchemaArgType::input, 0}, {c10::SchemaArgType::output, 0}, false));
+}
 } // namespace utils
 } // namespace torch