[FX] Remove extraneous newlines at end of code (#50117)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50117
Test Plan: Imported from OSS
Reviewed By: ansley
Differential Revision: D25791847
Pulled By: jamesr66a
fbshipit-source-id: 9c0b296e117e6bcf69ed9624ad0b243fa3db0f76
diff --git a/test/test_fx.py b/test/test_fx.py
index 65d5aa3..2511adc 100644
--- a/test/test_fx.py
+++ b/test/test_fx.py
@@ -861,6 +861,11 @@
x, w = torch.rand(3, 4), torch.rand(4, 4)
self.assertTrue(any(n.target == torch.relu for n in traced.graph.nodes))
+ def test_empty_graph_codegen(self):
+ graph = torch.fx.Graph()
+ gm = torch.fx.GraphModule(torch.nn.Module(), graph)
+ self.assertEqual(gm(), None)
+
def test_sequential(self):
m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1))
gm = torch.fx.symbolic_trace(m)
diff --git a/torch/fx/graph.py b/torch/fx/graph.py
index fd0087d..6e49367 100644
--- a/torch/fx/graph.py
+++ b/torch/fx/graph.py
@@ -693,13 +693,18 @@
import_strs = [f'import {name}' for name in sorted(modules_used)]
import_block = '\n'.join(import_strs)
+ if len(body) == 0:
+ # If the Graph has no non-placeholder nodes, no lines for the body
+ # have been emitted. To continue to have valid Python code, emit a
+ # single pass statement
+ body.append('pass\n')
+
code = ''.join(body)
- code = '\n'.join(' ' + line for line in code.split('\n')) + '\n'
+ code = '\n'.join(' ' + line for line in code.split('\n'))
fn_code = f"""\
{import_block}
def forward(self, {', '.join(free_vars)}){maybe_return_annotation[0]}:
-{code}
-"""
+{code}"""
return fn_code