blob: fa6d96b9a59376b63df87295e7daeee3248d24aa [file] [log] [blame]
"""
This is the centralized file for all PyTorch operator upgraders.
Each function definition here should satisfy the following
requirements:
1. The body of the function must be Torchscriptable
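2. The naming convention of the upgraders should be:
<op_name>_<op_overload>_upgrader_<old_version>_<new_version>
(NOTE: the upgraders currently defined in this file omit the
"_upgrader" infix, e.g. div_Tensor_0_3 and full_0_4.)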
3. The name of the upgrader must be present in
torch/csrc/jit/operator_upgraders/version_map.h
"""
import torch
from typing import List, no_type_check, Optional, Union
# TODO (tugsuu) This context manager
# prevents the tests from overriding torch.jit.script.
# Without it, torch.jit.load will fail due to
# a circular dependency.
with torch._jit_internal._disable_emit_hooks():
    # NOTE(review): the functions in this block are compiled with
    # torch.jit.script and their graphs are later turned into bytecode by
    # generate_bytecode(), so the exact statements matter. Presumably each one
    # reproduces the behaviour of an older serialized operator version --
    # confirm against version_map.h before changing any of them.

    # Upgrader for `div` with a Tensor divisor. If either operand is floating
    # point it performs true division; otherwise it divides with
    # rounding_mode='trunc'.
    @torch.jit.script
    def div_Tensor_0_3(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
        if (self.is_floating_point() or other.is_floating_point()):
            return self.true_divide(other)
        return self.divide(other, rounding_mode='trunc')

    # Upgrader for `div` with a Python scalar divisor. True division is used
    # when `self` is floating point or `other` is a float; otherwise the
    # division uses rounding_mode='trunc'.
    @torch.jit.script
    def div_Scalar_0_3(self: torch.Tensor, other: Union[int, float]) -> torch.Tensor:
        if (self.is_floating_point() or isinstance(other, float)):
            return self.true_divide(other)
        return self.divide(other, rounding_mode='trunc')

    # Upgrader for the `out=` variant of `div`. True division is used when any
    # of `self`, `other` or `out` is floating point; otherwise the division
    # uses rounding_mode='trunc'. In both cases `out` is forwarded as the
    # keyword-only output tensor.
    @torch.jit.script
    def div_out_0_3(self: torch.Tensor, other: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor:
        if (self.is_floating_point() or other.is_floating_point() or out.is_floating_point()):
            return self.true_divide(other, out=out)
        return self.divide(other, rounding_mode='trunc', out=out) # type: ignore[call-overload]

    # Upgrader for `div_` with a Tensor divisor. Same dtype test as
    # div_Tensor_0_3, but it calls the trailing-underscore methods
    # (`true_divide_` / `divide_`).
    @torch.jit.script
    def div__Tensor_0_3(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
        if (self.is_floating_point() or other.is_floating_point()):
            return self.true_divide_(other)
        return self.divide_(other, rounding_mode='trunc')

    # Upgrader for `div_` with a Python scalar divisor. Same test as
    # div_Scalar_0_3, but it calls the trailing-underscore methods
    # (`true_divide_` / `divide_`).
    @torch.jit.script
    def div__Scalar_0_3(self: torch.Tensor, other: Union[int, float]) -> torch.Tensor:
        if (self.is_floating_point() or isinstance(other, float)):
            return self.true_divide_(other)
        return self.divide_(other, rounding_mode='trunc')

    # Upgrader for `full`. When no dtype is given, `fill_value` is converted
    # to float before being passed to torch.full; all other arguments are
    # forwarded unchanged.
    #
    # TODO: since TS relies on typecheck comment of mypy
    # adding type: ignore at specific lines in this function
    # messes up our type refinement. For now, let's not check
    # type here.
    @no_type_check
    @torch.jit.script
    def full_0_4(size: List[int], fill_value: Union[int, float], *,
                 dtype: Optional[int], layout: Optional[int], device: Optional[torch.device],
                 pin_memory: Optional[bool]) -> torch.Tensor:
        if dtype is None:
            fill_value = float(fill_value)
        return torch.full(size, fill_value, dtype=dtype, layout=layout,
                          device=device, pin_memory=pin_memory)

    # Upgrader for the `out=` variant of `full`. It forwards `size`,
    # `fill_value` and `out` to torch.full without any conversion.
    @torch.jit.script
    def full_out_0_4(size: List[int], fill_value: Union[int, float], *, out: torch.Tensor) -> torch.Tensor:
        return torch.full(size, fill_value, out=out)
def format_bytecode(table):
    """Turn a bytecode table into a dict whose payloads use lists, not tuples.

    Args:
        table: An iterable of entries. For each entry, index 0 is used as the
            key and index 1 as the payload.

    Returns:
        A dict mapping each key to its payload, with every tuple in the
        payload (at any nesting depth reachable through tuples) replaced by a
        list. Non-tuple values, including lists, are passed through untouched.
        If a key occurs more than once, the last entry wins.
    """
    def _tuples_to_lists(node):
        # Only tuples are descended into; everything else is a leaf.
        if isinstance(node, tuple):
            return [_tuples_to_lists(child) for child in node]
        return node

    return {row[0]: _tuples_to_lists(row[1]) for row in table}
def collect_available_upgraders():
    """Cross-check this module's upgraders against the operator version map.

    There must be a 1 to 1 mapping between the upgraders defined in this
    module and the upgraders listed in
    torch/csrc/jit/operator_upgraders/version_map.h.

    Returns:
        The set of upgrader names found in the version map (test operators
        excluded).

    Raises:
        AssertionError: If the version map names an upgrader that is not
            defined in this module, or if this module defines a
            torch.jit.ScriptFunction that the version map does not name.
    """
    module_scope = globals()

    # Operators whose name starts with "aten::_test" are test-only; skip them.
    version_map = {}
    for op_name, upgrader_entries in torch._C._get_operator_version_map().items():
        if op_name.startswith("aten::_test"):
            continue
        version_map[op_name] = upgrader_entries

    # 1. Everything listed in version_map.h must be defined in this module.
    registered_names = set()
    for upgrader_entries in version_map.values():
        for registered in upgrader_entries:
            name = registered.upgrader_name
            if name not in module_scope:
                raise AssertionError("Upgrader entry {} needs to be defined in python".format(name))
            registered_names.add(name)

    # 2. Every scripted function in this module must be listed in version_map.h.
    for name, obj in module_scope.items():
        if isinstance(obj, torch.jit.ScriptFunction) and name not in registered_names:
            raise AssertionError("The upgrader {} is not registered in the version_map.h".format(name))

    return registered_names
def generate_bytecode() -> List:
    """Compile every registered upgrader into a formatted bytecode table.

    The upgrader names come from ``collect_available_upgraders``, so the
    consistency checks performed there run first and may raise.

    Returns:
        A list with one single-key dict per upgrader, mapping the upgrader's
        name to its code table after ``format_bytecode`` has been applied.
    """
    module_scope = globals()
    return [
        {
            name: format_bytecode(
                torch._C._compile_graph_to_code_table(name, module_scope[name].graph)
            )
        }
        for name in collect_available_upgraders()
    ]
# This module only provides definitions to be imported; executing it as a
# script is rejected explicitly.
if __name__ == "__main__":
    raise RuntimeError("This file is not meant to be run directly")