[functorch] Removed some obsolete workarounds in ts_compile and added a new one (pytorch/functorch#875)
diff --git a/functorch/functorch/_src/compilers.py b/functorch/functorch/_src/compilers.py
index 10fe42a..74b71af 100644
--- a/functorch/functorch/_src/compilers.py
+++ b/functorch/functorch/_src/compilers.py
@@ -34,16 +34,9 @@
Torch scripted model.
"""
for node in fx_g.graph.nodes:
- if node.target in (torch.ops.aten.new_zeros, torch.ops.aten.new_empty):
- if node.args[1] == []:
- args = list(node.args)
- args[1] = [1]
- node.args = tuple(args)
- elif node.target is torch.ops.aten.masked_fill and node.args[2] == float("-inf"):
- # Fx graph to torchscript fails for -inf
- args = list(node.args)
- args[2] = -3.403 * 10**37
- node.args = tuple(args)
+ if (node.target == torch.ops.aten._to_copy and len(node.args) == 1
+ and len(node.kwargs) == 1 and 'dtype' in node.kwargs):
+ node.target = torch.ops.aten.to
for node in fx_g.graph.nodes:
new_kwargs = {}
@@ -55,15 +48,6 @@
fx_g.graph.lint()
- # print(set([i.target for i in fx_g.graph.nodes if i.op == 'call_function']))
- # Works around this NVFuser issue: https://github.com/csarofeen/pytorch/issues/1311
- for i in range(1000):
- attr = f"_tensor_constant{i}"
- if hasattr(fx_g, attr):
- setattr(fx_g, attr, getattr(fx_g, attr).cuda())
- else:
- break
-
fx_g.recompile()
f = torch.jit.script(fx_g)
diff --git a/functorch/functorch/_src/python_key.py b/functorch/functorch/_src/python_key.py
index df55629..5fe0aff 100644
--- a/functorch/functorch/_src/python_key.py
+++ b/functorch/functorch/_src/python_key.py
@@ -7,3 +7,4 @@
from torch.fx.experimental.proxy_tensor import make_fx, ProxyTensor, dispatch_trace, PythonKeyTracer, decompose
pythonkey_decompose = decompose
+PythonTensor = ProxyTensor