blob: 48286bf4bcf64bee03b34693c75d0d59902e906b [file] [log] [blame]
import copy
from typing import Optional, Tuple
import torch
from torch.fx.experimental.symbolic_shapes import ShapeEnv
import torch.utils._pytree as pytree
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from .logical_schema import TensorMeta # type: ignore[attr-defined]
__all__ = ["convert_fake_tensor_to_tensor_meta", "convert_tensor_meta_to_fake_tensor"]
def _extract_tensor_meta(result: torch.Tensor) -> TensorMeta:
"""
Extract a TensorMeta describing `result`.
"""
return TensorMeta(
dtype=result.dtype,
sizes=result.shape,
requires_grad=result.requires_grad,
device=result.device,
strides=result.stride(),
storage_offset=0,
layout=result.layout,
)
def convert_fake_tensor_to_tensor_meta(
gm: torch.fx.GraphModule
) -> Tuple[torch.fx.GraphModule, Optional[ShapeEnv]]:
"""
Replace the faketensor metadata with the tensor metadata dataclass since we
cannot serialize faketensors
"""
gm = copy.deepcopy(gm)
shape_env = None
for node in gm.graph.nodes:
def get_shape_env(val) -> Optional[ShapeEnv]:
val_flat, _ = pytree.tree_flatten(val)
curr_shape_env = None
for v in val_flat:
if not isinstance(v, FakeTensor):
continue
if curr_shape_env is None:
curr_shape_env = v.fake_mode.shape_env
else:
assert (
curr_shape_env is v.fake_mode.shape_env
), "Multiple shape envs detected."
return curr_shape_env
if (val := node.meta.get("val", None)) is not None:
if shape_env is None:
shape_env = get_shape_env(val)
elif (new_shape_env := get_shape_env(val)) is not None:
assert (
shape_env is new_shape_env
), "Multiple shape envs detected."
node.meta["tensor_meta"] = pytree.tree_map_only(
torch.Tensor, _extract_tensor_meta, val
)
del node.meta["val"]
return gm, shape_env
def convert_tensor_meta_to_fake_tensor(gm: torch.fx.GraphModule, shape_env: ShapeEnv = None) -> torch.fx.GraphModule:
"""
Replace (inplace) the tensor metadata with faketensor
"""
fake_tensor_mode = FakeTensorMode(allow_non_fake_inputs=True, shape_env=shape_env)
for node in gm.graph.nodes:
if (val := node.meta.get("tensor_meta", None)) is not None:
def _extract_faketensor(tensor_meta: TensorMeta):
return FakeTensor(
fake_tensor_mode,
torch.empty(
tensor_meta.sizes,
dtype=tensor_meta.dtype,
device="meta",
requires_grad=tensor_meta.requires_grad,
),
torch.device("cpu"),
)
node.meta["val"] = pytree.tree_map_only(
TensorMeta, _extract_faketensor, val
)
return gm