[functorch] Removed some obsolete workarounds in ts_compile and added a new one (pytorch/functorch#875)
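
Summary, as read from the diff below: two old TorchScript workarounds are dropped (padding the empty size argument of aten.new_zeros/aten.new_empty to [1], and replacing float("-inf") in aten.masked_fill with a large finite negative constant), along with the loop that moved _tensor_constant* attributes to CUDA to work around NVFuser issue csarofeen/pytorch#1311. In their place, one new workaround is added: aten._to_copy(x, dtype=...) nodes are retargeted to aten.to, presumably because graphs containing _to_copy do not yet pass through torch.jit.script cleanly.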

diff --git a/functorch/functorch/_src/compilers.py b/functorch/functorch/_src/compilers.py
index 10fe42a..74b71af 100644
--- a/functorch/functorch/_src/compilers.py
+++ b/functorch/functorch/_src/compilers.py
@@ -34,16 +34,9 @@
         Torch scripted model.
     """
     for node in fx_g.graph.nodes:
-        if node.target in (torch.ops.aten.new_zeros, torch.ops.aten.new_empty):
-            if node.args[1] == []:
-                args = list(node.args)
-                args[1] = [1]
-                node.args = tuple(args)
-        elif node.target is torch.ops.aten.masked_fill and node.args[2] == float("-inf"):
-            # Fx graph to torchscript fails for -inf
-            args = list(node.args)
-            args[2] = -3.403 * 10**37
-            node.args = tuple(args)
+        if (node.target == torch.ops.aten._to_copy and len(node.args) == 1
+           and len(node.kwargs) == 1 and 'dtype' in node.kwargs):
+            node.target = torch.ops.aten.to
 
     for node in fx_g.graph.nodes:
         new_kwargs = {}
@@ -55,15 +48,6 @@
 
     fx_g.graph.lint()
 
-    # print(set([i.target for i in fx_g.graph.nodes if i.op == 'call_function']))
-    # Works around this NVFuser issue: https://github.com/csarofeen/pytorch/issues/1311
-    for i in range(1000):
-        attr = f"_tensor_constant{i}"
-        if hasattr(fx_g, attr):
-            setattr(fx_g, attr, getattr(fx_g, attr).cuda())
-        else:
-            break
-
     fx_g.recompile()
 
     f = torch.jit.script(fx_g)
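
For illustration, a minimal end-to-end sketch of the path this pass sits on. This is not part of the patch: the functorch.compile import and the (fx_g, inps) calling convention of ts_compile are assumed from functorch's public API of this period, while make_fx comes from torch.fx.experimental.proxy_tensor as in the second file below.

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx
    from functorch.compile import ts_compile  # assumed export

    def f(x):
        # .to(dtype) dispatches to aten._to_copy(x, dtype=...), the exact
        # node shape the new workaround retargets to aten.to.
        return x.to(torch.float32) + 1

    inps = [torch.randn(4, dtype=torch.float64)]
    fx_g = make_fx(f)(*inps)           # FX graph of ATen ops
    scripted = ts_compile(fx_g, inps)  # rewrites run, then torch.jit.script
    print(scripted(*inps))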
diff --git a/functorch/functorch/_src/python_key.py b/functorch/functorch/_src/python_key.py
index df55629..5fe0aff 100644
--- a/functorch/functorch/_src/python_key.py
+++ b/functorch/functorch/_src/python_key.py
@@ -7,3 +7,4 @@
 from torch.fx.experimental.proxy_tensor import make_fx, ProxyTensor, dispatch_trace, PythonKeyTracer, decompose
 
 pythonkey_decompose = decompose
+PythonTensor = ProxyTensor
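
The added alias keeps the old name importable now that the implementation lives in torch.fx.experimental.proxy_tensor as ProxyTensor, so a pre-existing call site like the following (hypothetical) import keeps resolving:

    # Old-style import; PythonTensor is now just an alias of ProxyTensor.
    from functorch._src.python_key import PythonTensor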