blob: 21e09e092807ed93fd8accf8c64133a11a6349b4 [file] [log] [blame] [edit]
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from torchgen.model import FunctionSchema, SchemaKind
from torchgen.native_function_generation import (
functional_to_out_signature,
mutable_to_out_signature,
self_to_out_signature,
)
from torchgen.utils import NamespaceHelper
def gen_out_variant_schema(func_op_schema: str) -> str:
    """
    Build the schema string of the out= variant for an operator schema.

    Args:
        func_op_schema: Operator schema string, optionally qualified with a
            namespace (at most one namespace level is accepted).

    Returns:
        The out-variant schema rendered as a string. When the input carried
        a namespace, the result is prefixed with ``<namespace>::``.

    Raises:
        RuntimeError: If the parsed schema's kind is none of inplace,
            functional, mutable, or out.
    """
    # Separate the optional namespace from the bare schema text, then parse.
    parsed = NamespaceHelper.from_namespaced_entity(
        namespaced_entity=func_op_schema, max_level=1
    )
    op_schema = FunctionSchema.parse(parsed.entity_name)
    prefix = parsed.get_cpp_namespace(default="")

    # Each convertible schema kind maps to the helper that derives its
    # out= signature; an `out` schema is already in the desired form.
    converters = {
        SchemaKind.inplace: self_to_out_signature,
        SchemaKind.functional: functional_to_out_signature,
        SchemaKind.mutable: mutable_to_out_signature,
    }
    kind = op_schema.kind()
    if kind == SchemaKind.out:
        out_schema = str(op_schema)
    elif kind in converters:
        out_schema = str(converters[kind](op_schema))
    else:
        raise RuntimeError(f"SchemaKind: {kind} is not supported")

    # Restore the namespace qualifier only if one was present on input.
    if not prefix:
        return out_schema
    return f"{prefix}::{out_schema}"