blob: 48f983b2917ef8cf2f49df83d181d9a326394568 [file] [log] [blame]
# mypy: allow-untyped-defs
import operator
from typing import Any, Dict, List, Optional, Set, Tuple, Union
import torch
import torch.export._trace
from torch.export.exported_program import ExportedProgram
from torch.export.graph_signature import (
ConstantArgument,
InputKind,
InputSpec,
OutputKind,
OutputSpec,
TensorArgument,
)
from torch.fx import subgraph_rewriter
from torch.onnx.utils import _create_jit_graph
def inplace_optimize_sym_size_div(gm: torch.fx.GraphModule):
def pattern(im, dim, scale):
sym_size_int = torch.ops.aten.sym_size.int(im, dim)
scalar_tensor = torch.ops.aten.scalar_tensor(sym_size_int)
div_scalar_mode = torch.ops.aten.div.Scalar_mode(
scalar_tensor, scale, rounding_mode="trunc"
)
int_tensor = torch.ops.aten.Int.Tensor(div_scalar_mode)
return int_tensor
def replacement(im, dim, scale):
sym_size_int = torch.ops.aten.sym_size.int(im, dim)
return sym_size_int // scale
replaced_patterns = subgraph_rewriter.replace_pattern(gm, pattern, replacement)
def normalize_name(name: str) -> str:
return name.replace(".", "_")
def ir_name_to_func_name(name: str) -> str:
"""prim::If -> convert_prim_If"""
name_list = name.split("::")
return "convert_" + "_".join(name_list)
def get_node_for_param_and_buffer(fx_graph, name, is_top_level_graph):
if is_top_level_graph:
return fx_graph.get_attr(name)
return fx_graph.placeholder(name)
_TORCH_DTYPE_TO_ENUM = {
torch.uint8: 0,
torch.int8: 1,
torch.int16: 2,
torch.int32: 3,
torch.int64: 4,
torch.float16: 5,
torch.float32: 6,
torch.float64: 7,
torch.complex32: 8,
torch.complex64: 9,
torch.complex128: 10,
torch.bool: 11,
torch.bfloat16: 15,
}
def get_dtype_as_int(tensor):
"""
prim::dtype has the signature "Tensor a) -> int", where it gets the dtype of
the tensor and returns the integer corresponding to this dtype based on the
enum in ScalarType.h
"""
dtype = tensor.dtype
if dtype not in _TORCH_DTYPE_TO_ENUM:
raise RuntimeError(f"Unsupported dtype {dtype}")
return _TORCH_DTYPE_TO_ENUM[dtype]
# Those operators will be automatically populated to a instance method
# of TS2FXGraphConverter with name convert_<namespace>_<opname>().
# Please check __init__ for method population implementations.
kind_to_standard_operators = {
"prim::TupleIndex": operator.getitem,
"aten::__is__": operator.is_,
"aten::__isnot__": operator.is_not,
"aten::__not__": operator.not_,
"aten::__contains__": operator.contains,
"prim::dtype": get_dtype_as_int,
"aten::len": len,
}
def get_ir_value_parent_name_and_attr_name(node):
irv_parent_name, irv_name = node.input().debugName(), node.output().debugName()
attr_name = node.s("name")
return irv_name, irv_parent_name, attr_name
def construct_fqn(ir, ref_map, name_map):
name_list = []
while ir in ref_map:
name_list.append(name_map[ir])
ir = ref_map[ir]
return ".".join(reversed(name_list))
def get_block_to_lifted_attrs(graph: torch._C.Graph) -> Dict[torch._C.Block, Set[str]]:
"""
Perform two passes to get a mapping of blocks to a set of FQNs of its lifted attributes.
When a graph has control flow, the graph will be divided into multiple blocks. We want to convert
each block to a graph which will be passed into torch.cond. A restriction for torch.cond is that model
parameters/buffers are expected to be lifted as inputs to the subgraphs. Before converting the model,
we will run this pass which will:
1. Figure out which params/buffers are used within blocks through tracing the GetAttr calls.
2. Process the graph bottom up to find the lifted attributes of each block by taking the union
of the attributes used in the current block, and the lifted attributes of all its child blocks.
Returns:
A mapping of blocks to a set of FQNs of its lifted attributes.
"""
# A map from a block to its expected to be lifted arguments.
blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]] = dict()
# Reference map stores the input (i.e., src) and output (i.e., dest) IR of a
# GetAttr node. By traversing this reference map, we can figure out the
# full IR aliasing pass and figure out the FQN of an attribute.
# E.g., %2 = GetAttr(linear)[%1] --> node_to_parent_map["%2"] = "%1"
node_to_parent_map: Dict[str, str] = dict()
# Used for reconstructing the FQN of an attribute based on the reference map.
# In nutshell, for each GetAttr call, GetAttr(input IR, attribute name) -> output IR
# This name map stores which attribute name is called for a src IR --> dest IR action.
# E.g., %2 = GetAttr(linear)[%1] --> node_to_attr_name["%2"] = "linear"
node_to_attr_name: Dict[str, str] = dict()
def _dfs_get_attr_dependency(entry):
"""
First DFS path to construct reference map and name map.
"""
for node in entry.nodes():
if node.kind() == "prim::GetAttr":
(
irv_name,
irv_parent_name,
attr_name,
) = get_ir_value_parent_name_and_attr_name(node)
node_to_parent_map[irv_name] = irv_parent_name
node_to_attr_name[irv_name] = attr_name
for block in node.blocks():
_dfs_get_attr_dependency(block)
def _map_blocks_to_lifted_attrs(entry):
"""
Walk the graph in a bottom-up fashion to build the expected to be
lifted arguments for each block.
"""
arguments: Set[str] = set()
for node in entry.nodes():
for block in node.blocks():
# Recursively build.
arguments = arguments.union(_map_blocks_to_lifted_attrs(block))
if node.kind() == "prim::GetAttr":
irv_name = node.output().debugName()
# Skip for intermediate GetAttr, which will anyway not result a FQN.
# E.g., node_to_parent_name: {"%3": "%2", "%2": "%1"}
# node_to_attr_name: {"%3": "weight", "%2": "linear", "%1": "self"}
# There is only one FQN %3-->%2-->%1: self.linear.weight
# %2-->%1 is not a FQN: self.linear
if irv_name not in set(node_to_parent_map.values()):
arguments.add(
construct_fqn(irv_name, node_to_parent_map, node_to_attr_name)
)
if not isinstance(entry, torch._C.Graph): # Skip the top level.
blocks_to_lifted_attrs[entry] = arguments
return arguments
_dfs_get_attr_dependency(graph)
_map_blocks_to_lifted_attrs(graph)
return blocks_to_lifted_attrs
def get_op_overload(node: torch._C.Node):
schema_str = node.schema()
schema: torch._C.FunctionSchema = torch._C.parse_schema(schema_str)
ns, op_name = str(schema.name).split("::")
override = schema.overload_name
try:
op_overload_mod = getattr(torch.ops, ns)
op_overload_packet = getattr(op_overload_mod, op_name)
if override:
op_overload = getattr(op_overload_packet, override)
else:
op_overload = op_overload_packet.default
except Exception as e:
raise RuntimeError(
f"Unable to find operator {node.kind()} with schema {node.schema}"
) from e
return op_overload
class TS2FXGraphConverter:
def __init__(
self,
ts_graph: Union[torch._C.Graph, torch._C.Block],
name_to_param_map: Dict[str, torch.Tensor],
name_to_buffer_map: Dict[str, torch.Tensor],
blocks_to_lifted_attrs: Dict[torch._C.Block, Set[str]],
):
self.ts_graph = ts_graph
self.name_to_param_map = name_to_param_map
self.name_to_buffer_map = name_to_buffer_map
self.fx_graph: torch.fx.Graph = torch.fx.Graph()
self.input_specs: List[InputSpec] = []
self.output_specs: List[OutputSpec] = []
self.name_to_node: Dict[
str, Union[torch.fx.Node, List[torch.fx.Node], Dict[Any, torch.fx.Node]]
] = {}
self.constant_map: Dict[str, Any] = {}
self.attribute_map: Dict[str, Any] = {}
self.tensor_constants: Dict[str, torch.Tensor] = {}
self.subgraphs: Dict[str, torch.fx.GraphModule] = {}
self.blocks_to_lifted_attrs = blocks_to_lifted_attrs
# Populate methods for the standard operators.
for k in kind_to_standard_operators.keys():
handler_func_name = ir_name_to_func_name(k)
# Create an indirect function call:
# convert_<namespace>_<opname> --> lambda node: _convert_standard_operator(node)
setattr(
self,
handler_func_name,
lambda node: self._convert_standard_operators(node),
)
def is_top_level_graph(self):
return isinstance(self.ts_graph, torch._C.Graph)
def add_subgraph(self, subgraph) -> str:
name = f"subgraph_{len(self.subgraphs)}"
self.subgraphs[name] = subgraph
return name
def get_args_kwargs(self, node: torch._C.Node, schema):
args = []
kwargs = {}
for input, schema_arg in zip(node.inputs(), schema.arguments):
if schema_arg.kwarg_only:
kwargs[schema_arg.name] = self.get_fx_value(input)
else:
args.append(self.get_fx_value(input))
return tuple(args), kwargs
def get_fx_value(self, value: torch._C.Value):
value_name = value.debugName()
if value_name in self.name_to_node:
input_node = self.name_to_node[value_name]
return input_node
elif value_name in self.attribute_map:
attr_name = self.attribute_map[value_name]
if attr_name in self.name_to_node:
input_node = self.name_to_node[attr_name]
return input_node
else:
raise ValueError(f"Value {attr_name} not found")
elif value_name in self.constant_map:
return self.constant_map[value_name]
else:
raise ValueError(f"Input {value_name} not found")
def convert(self) -> torch.fx.GraphModule:
self.convert_graph_inputs()
for node in self.ts_graph.nodes():
self.convert_node(node)
self.convert_graph_outputs()
# Pass parameter and buffer to the root for lookup.
gm = torch.fx.GraphModule(
{
**self.subgraphs,
**self.name_to_param_map,
**self.name_to_buffer_map,
**self.tensor_constants,
},
self.fx_graph,
)
inplace_optimize_sym_size_div(gm)
gm.graph.lint()
return gm
def convert_graph_inputs(self):
for graph_input in self.ts_graph.inputs():
name = graph_input.debugName()
normalized_name = normalize_name(name)
if name in self.name_to_param_map:
self.input_specs.append(
InputSpec(
InputKind.PARAMETER,
arg=TensorArgument(name=normalized_name),
target=name,
)
)
fx_node = get_node_for_param_and_buffer(
self.fx_graph, name, self.is_top_level_graph()
)
elif name in self.name_to_buffer_map:
self.input_specs.append(
InputSpec(
InputKind.BUFFER,
arg=TensorArgument(name=normalized_name),
target=name,
persistent=True,
)
)
fx_node = get_node_for_param_and_buffer(
self.fx_graph, name, self.is_top_level_graph()
)
else:
self.input_specs.append(
InputSpec(
InputKind.USER_INPUT,
arg=TensorArgument(name=normalized_name),
target=name,
)
)
fx_node = self.fx_graph.placeholder(normalized_name)
self.name_to_node[name] = fx_node
def convert_aten_tensor(self, node: torch._C.Node):
"""aten::tensor creates a constant tensor ad-hoc --> GetAttr"""
args, kwargs = self.get_args_kwargs(node, torch.ops.aten.tensor.default._schema)
for k in kwargs:
if k == "requires_grad":
kwargs[k] = bool(kwargs[k]) # 0 -> False, 1 -> True
tensor = torch.tensor(*args, **kwargs)
output_name = node.output().debugName()
alias_name = f"lifted_tensor_{output_name}"
fx_node = self.fx_graph.get_attr(alias_name)
self.name_to_node[output_name] = fx_node
self.tensor_constants[alias_name] = tensor
def convert_prim_Constant(self, node: torch._C.Node):
name = node.output().debugName()
value: Any = None
if node.hasAttribute("value"):
constant_kind = node.kindOf("value")
if constant_kind == "i":
value = node.i("value")
elif constant_kind == "f":
value = node.f("value")
elif constant_kind == "s":
value = node.s("value")
elif constant_kind == "t":
alias_name = (
f"lifted_tensor_{name}" # Follow naming convention from EP tracing.
)
fx_node = self.fx_graph.get_attr(alias_name)
self.tensor_constants[alias_name] = node.t("value")
value = fx_node
elif constant_kind == "ival":
value = node.ival("value")
else:
raise ValueError(f"Unsupported constant type: {node.kindOf('value')}")
else:
value = None
self.constant_map[name] = value
def convert_prim_device(self, node: torch._C.Node):
input_type = node.input().type()
if input_type.isSubtypeOf(torch._C.TensorType.get()):
device = input_type.device() # type: ignore[attr-defined]
output_name = node.output().debugName()
self.constant_map[output_name] = device
else:
raise ValueError(f"Unsupported JitType ({input_type}) when get device")
def convert_prim_GetAttr(self, node: torch._C.Node):
def get_attr(name: str):
if name in self.attribute_map:
return self.attribute_map[name]
else:
raise ValueError(f"Attribute {name} not found")
output_name = node.output().debugName()
attr_name = node.s("name")
input_name = node.input().debugName()
root_attr_name = get_attr(input_name)
self.attribute_map[output_name] = (
f"{root_attr_name}.{attr_name}" if root_attr_name else attr_name
)
def convert_call_function_op(self, node: torch._C.Node):
target = get_op_overload(node)
if target is torch.ops.aten.size.int:
target = torch.ops.aten.sym_size.int
args, kwargs = self.get_args_kwargs(node, target._schema)
fx_node = self.fx_graph.call_function(target, args, kwargs)
# TODO: covnert sourceRange() into stack_trace
# fx_node.meta["stack_trace"] = node.sourceRange()
output_name = node.output().debugName()
self.name_to_node[output_name] = fx_node
def convert_prim_TupleConstruct(self, node: torch._C.Node):
self._convert_prim_iterator(node)
def convert_prim_ListConstruct(self, node: torch._C.Node):
self._convert_prim_iterator(node)
def _convert_prim_iterator(self, node: torch._C.Node):
output_list = []
for inp in node.inputs():
output_list.append(self.get_fx_value(inp))
output_name = node.output().debugName()
self.name_to_node[output_name] = output_list
def convert_prim_DictConstruct(self, node: torch._C.Node):
output_dict = {}
k, v = None, None
for i, inp in enumerate(node.inputs()):
# We assume key value are stored in pair in the DictConstruct.
# The first element is the key and the following is the value.
if i % 2 == 0:
k = self.get_fx_value(inp)
else:
v = self.get_fx_value(inp)
assert (
k is not None and v is not None
), "DictConstruct has an empty key value pair."
output_dict[k] = v
k, v = None, None
assert (
k is None and v is None
), "DictConstruct has an odd number of elements (violating our assumption)."
output_name = node.output().debugName()
self.name_to_node[output_name] = output_dict
def convert_prim_ListUnpack(self, node: torch._C.Node):
self._convert_prim_unpack_iterator(node)
def convert_prim_TupleUnpack(self, node: torch._C.Node):
self._convert_prim_unpack_iterator(node)
def _convert_prim_unpack_iterator(self, node: torch._C.Node):
# Single input and multiple outputs for unpacking.
for i, outp in enumerate(node.outputs()):
outp_name = outp.debugName()
inp = self.get_fx_value(node.input())
fx_node = self.fx_graph.call_function(operator.getitem, (inp, i))
self.name_to_node[outp_name] = fx_node
def convert_aten_Int(self, node: torch._C.Node):
# converts aten::Int as aten._to_copy + aten::_local_scalar_dense
target = torch.ops.aten._to_copy.default
args = tuple(self.get_fx_value(input) for input in node.inputs())
to_copy_node = self.fx_graph.call_function(target, args, {"dtype": torch.int32})
fx_node = self.fx_graph.call_function(
torch.ops.aten._local_scalar_dense.default, (to_copy_node,)
)
# TODO: covnert sourceRange() into stack_trace
# fx_node.meta["stack_trace"] = node.sourceRange()
output_name = node.output().debugName()
self.name_to_node[output_name] = fx_node
def convert_prim_NumToTensor(self, node: torch._C.Node):
# converts prim::NumToTensor as aten.scalar_tensor
target = torch.ops.aten.scalar_tensor
args = tuple(self.get_fx_value(input) for input in node.inputs())
fx_node = self.fx_graph.call_function(target, args)
output_name = node.output().debugName()
self.name_to_node[output_name] = fx_node
def convert_prim_CreateObject(self, node: torch._C.Node):
output_name = node.output().debugName()
self.attribute_map[output_name] = ""
def convert_aten__convolution(self, node: torch._C.Node):
# converts aten::_convolution as aten.convolution, since aten::_convolution
# doesn't have a meta function
target = torch.ops.aten.convolution.default
args, kwargs = self.get_args_kwargs(node, target._schema)
fx_node = self.fx_graph.call_function(target, args, kwargs)
output_name = node.output().debugName()
self.name_to_node[output_name] = fx_node
def convert_aten_div(self, node: torch._C.Node):
target = get_op_overload(node)
schema = target._schema
args, kwargs = self.get_args_kwargs(node, schema)
# converts aten::div.Tensor_mode(x, tensor_constant)
# as aten.div.Scalar_mode(x, tensor_constant.item())
if schema.overload_name == "Tensor_mode":
arg1_name = args[1].name
if arg1_name in self.tensor_constants:
tensor_constant = self.tensor_constants[arg1_name]
if tensor_constant.numel() == 1:
updated_args = list(args)
updated_args[1] = self.tensor_constants[arg1_name].item()
fx_node = self.fx_graph.call_function(
torch.ops.aten.div.Scalar_mode,
tuple(updated_args),
kwargs,
)
# TODO: covnert sourceRange() into stack_trace
# fx_node.meta["stack_trace"] = node.sourceRange()
output_name = node.output().debugName()
self.name_to_node[output_name] = fx_node
return
self.convert_call_function_op(node)
def convert_aten___getitem__(self, node: torch._C.Node):
input_container, index = tuple(
self.get_fx_value(input) for input in node.inputs()
)
fx_node = self.fx_graph.call_function(
operator.getitem, (input_container, index)
)
output_name = node.output().debugName()
self.name_to_node[output_name] = fx_node
def convert_prim_If(self, node: torch._C.Node):
inputs = list(node.inputs())
assert len(inputs) == 1
predicate = self.get_fx_value(inputs[0])
# Get union of inputs to blocks
arguments = set()
for block in node.blocks():
block_args = set()
# TODO: block.inputs(), not sure what theyre used for
for block_node in block.nodes():
for block_node_in in block_node.inputs():
if block_node_in.debugName() in self.name_to_node:
block_args.add(block_node_in.debugName())
arguments.update(block_args)
# Lift parameters as inputs.
for block in node.blocks():
arguments = arguments.union(self.blocks_to_lifted_attrs[block])
arguments = list(arguments)
# Convert blocks to subgraphs
subgraph_nodes = []
for block in node.blocks():
subgraph_converter = TS2FXGraphConverter(
block, dict(), dict(), self.blocks_to_lifted_attrs
)
subgraph_converter.constant_map = self.constant_map
subgraph_converter.attribute_map = self.attribute_map
for block_arg in arguments:
normalized_block_arg_name = normalize_name(block_arg)
placeholder_node = subgraph_converter.fx_graph.placeholder(
normalized_block_arg_name
)
subgraph_converter.name_to_node[block_arg] = placeholder_node
subgraph = subgraph_converter.convert()
subgraph_name = self.add_subgraph(subgraph)
subgraph_nodes.append(self.fx_graph.get_attr(subgraph_name))
assert len(subgraph_nodes) == 2
fx_block_args = [self.name_to_node[arg_name] for arg_name in arguments]
args = (
predicate,
subgraph_nodes[0],
subgraph_nodes[1],
tuple(fx_block_args),
)
cond_node = self.fx_graph.call_function(torch.cond, args, {})
output_name = node.output().debugName()
self.name_to_node[output_name] = cond_node
def convert_aten_Bool(self, node: torch._C.Node):
self._convert_as_noop(node)
def _convert_as_noop(self, node: torch._C.Node):
# Converts the node as a no-op by mapping its output node as arg[0]
target = get_op_overload(node)
schema = target._schema
args, kwargs = self.get_args_kwargs(node, schema)
output_name = node.output().debugName()
self.name_to_node[output_name] = args[0]
def convert_profiler__record_function_enter_new(self, node: torch._C.Node):
target = torch.ops.profiler._record_function_enter_new
args = tuple(self.get_fx_value(input) for input in node.inputs())
fx_node = self.fx_graph.call_function(target, args)
output_name = node.output().debugName()
self.name_to_node[output_name] = fx_node
def convert_profiler__record_function_exit(self, node: torch._C.Node):
# _record_function_exit has side effect so we keep it in fx.graph
# currently, _record_function_enter_new and _record_function_exit are
# discarded during `retrace_as_exported_program`.
target = torch.ops.profiler._record_function_exit
args = tuple(self.get_fx_value(input) for input in node.inputs())
self.fx_graph.call_function(target, args)
def convert_prim_tolist(self, node: torch._C.Node):
# prim::tolist cannot be supported by `_convert_standard_operators`
# since it requires call_method instead of call_function.
target = "tolist"
args = (self.get_fx_value(next(node.inputs())),)
fx_node = self.fx_graph.call_method(target, args)
output_name = node.output().debugName()
self.name_to_node[output_name] = fx_node
def _convert_standard_operators(self, node: torch._C.Node):
target = kind_to_standard_operators[node.kind()]
args = tuple(self.get_fx_value(input) for input in node.inputs())
fx_node = self.fx_graph.call_function(target, args)
output_name = node.output().debugName()
self.name_to_node[output_name] = fx_node
def convert_node(self, node: torch._C.Node):
node_kind = node.kind()
# Get handler based on namespace and operator name.
# Provide a default node handler as well in case we don't find
# matching converter for that.
handler_func_name = ir_name_to_func_name(node_kind)
handler_func = getattr(self, handler_func_name, self.convert_call_function_op)
handler_func(node)
def convert_graph_outputs(self):
args = []
for graph_output in self.ts_graph.outputs():
output_name = graph_output.debugName()
if output_name in self.name_to_node:
args.append(self.name_to_node[output_name])
self.output_specs.append(
OutputSpec(
OutputKind.USER_OUTPUT,
arg=TensorArgument(name=output_name),
target=output_name,
)
)
elif output_name in self.constant_map:
args.append(self.constant_map[output_name])
self.output_specs.append(
OutputSpec(
OutputKind.USER_OUTPUT,
arg=ConstantArgument(
name=output_name, value=self.constant_map[output_name]
),
target=output_name,
)
)
else:
raise ValueError(f"Output {output_name} not found")
self.fx_graph.output(
args[0]
) # Get rid of an extra list wrapped around final output.
class TS2EPConverter:
# TorchScript model to ExportedProgram converter
def __init__(
self,
ts_model: Union[torch.jit.ScriptModule, torch.jit.ScriptFunction],
sample_args: Tuple[Any, ...],
sample_kwargs: Optional[Dict[str, Any]] = None,
):
self.ts_model = ts_model
self.ts_graph, self.params, _, _ = _create_jit_graph(ts_model, sample_args)
self.sample_args = sample_args
self.sample_kwargs = sample_kwargs
self.name_to_param_map: Dict[str, torch.Tensor] = (
dict(ts_model.named_parameters())
if isinstance(ts_model, torch.jit.ScriptModule)
else dict()
)
self.name_to_buffer_map: Dict[str, torch.Tensor] = (
dict(ts_model.named_buffers())
if isinstance(ts_model, torch.jit.ScriptModule)
else dict()
)
def convert(self) -> ExportedProgram:
blocks_to_lifted_attrs = get_block_to_lifted_attrs(self.ts_graph)
graph_converter = TS2FXGraphConverter(
self.ts_graph,
self.name_to_param_map,
self.name_to_buffer_map,
blocks_to_lifted_attrs,
)
gm = graph_converter.convert()
ep = self.retrace_as_exported_program(gm, graph_converter.tensor_constants)
return ep
def retrace_as_exported_program(
self, gm: torch.fx.GraphModule, tensor_constants: Dict[str, torch.Tensor]
):
# TODO: adjust input orders to match GraphSignature convention
ep = torch.export._trace._export(
gm,
self.sample_args,
strict=False,
pre_dispatch=True,
)
# Post-processing to make sure the ExportedProgram states are correct.
# Because during conversion, we set tensor constants as GetAttr,
# retracing cannot recognize them as tensor constants but instead
# treat them as buffers. We need to set them again here.
ep._constants = tensor_constants
for k in tensor_constants:
ep.state_dict.pop(k, None)
for spec in ep.graph_signature.input_specs:
# Mark as constant tensors for erroneously traced buffers.
if spec.kind == InputKind.BUFFER and spec.target in tensor_constants:
spec.kind = InputKind.CONSTANT_TENSOR
ep.verifier().check(ep)
return ep