blob: d9462ba58ba439204f6aa0925258384f56056b49 [file] [log] [blame]
import importlib
import logging
import torch
from torch._dynamo import register_backend
# Module-level logger; the ipex backend uses it to report IPEX import
# failures and JIT-trace failures.
log = logging.getLogger(__name__)
@register_backend
def ipex(model, inputs):
try:
import intel_extension_for_pytorch # type: ignore[import] # noqa: F401
except ImportError:
log.exception(
"Unable to import Intel Extension for PyTorch (IPEX). "
"Please install the right version of IPEX that matches the PyTorch version being used. "
"Refer to https://github.com/intel/intel-extension-for-pytorch for details."
)
raise
from torch.utils._mode_utils import no_dispatch
with no_dispatch():
static_inputs = []
for x in inputs:
if x._has_symbolic_sizes_strides:
size = [s.node.shape_env.size_hint(s.node.expr) for s in x.size()]
stride = [s.node.shape_env.size_hint(s.node.expr) for s in x.stride()]
static_inputs.append(
torch.as_strided(
torch.zeros(size, dtype=x.dtype, device=x.device), size, stride
)
)
else:
static_inputs.append(torch.zeros_like(x))
try:
with torch.no_grad():
traced_model = torch.jit.trace(model.eval(), static_inputs)
traced_model = torch.jit.freeze(traced_model)
return traced_model
except Exception:
log.warning("JIT trace failed during the 'ipex' optimize process.")
return model
def has_ipex():
    """Tell whether Intel Extension for PyTorch can be imported.

    Attempts a real import of ``intel_extension_for_pytorch``; only an
    ImportError is treated as "not available".

    Returns:
        bool: True if the import succeeds, False if it raises ImportError.
    """
    ipex_module_name = "intel_extension_for_pytorch"
    try:
        importlib.import_module(ipex_module_name)
    except ImportError:
        return False
    else:
        return True