# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import List, Tuple, Union

import torch
from executorch.exir.sym_util import eval_shape
from executorch.exir.tensor import TensorSpec
from torch.utils import _pytree as pytree
from typing_extensions import TypeAlias

TensorAllocSpec: TypeAlias = Tuple[Tuple[int], torch.dtype]
AllocSpec: TypeAlias = Union[
    TensorAllocSpec,
    List[TensorAllocSpec],
]


def alloc(spec: AllocSpec) -> pytree.PyTree:
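    """
    Allocate uninitialized tensors for ``spec``, which is a (shape, dtype)
    pair or a list of such pairs; a matching ``torch.empty`` tensor (or list
    of tensors) is returned.
    """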
    if isinstance(spec, list):
        return [alloc(s) for s in spec]

    shape, dtype = spec
    # Evaluate the (possibly symbolic) shape to concrete ints so the traced
    # module can be run in Python for testing.
    shape = eval_shape(shape)
    return torch.empty(shape, dtype=dtype)


def free(spec: TensorSpec) -> None:
"""
The function is nop. The major purpose is to put it in the Fx IR.
E.g., it can be the target of call_function node.
"""
pass
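

# Minimal usage sketch of `alloc`: a (shape, dtype) pair, or a list of such
# pairs, yields matching uninitialized tensors. This assumes `eval_shape`
# passes concrete integer shapes through unchanged.
if __name__ == "__main__":
    single = alloc(((2, 3), torch.float32))
    batch = alloc([((4,), torch.int64), ((1, 5), torch.float32)])
    assert single.shape == (2, 3)
    assert [t.dtype for t in batch] == [torch.int64, torch.float32]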