add constraints for layer_norm function (#82597)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82597
Approved by: https://github.com/jansel
diff --git a/test/fx/test_z3_gradual_types.py b/test/fx/test_z3_gradual_types.py
index 1b792b9..e504070 100644
--- a/test/fx/test_z3_gradual_types.py
+++ b/test/fx/test_z3_gradual_types.py
@@ -428,6 +428,37 @@
         self.assertEqual(s.model()[output].arg(0).arg(1), b[0])
         self.assertEqual(s.model()[output].arg(1).arg(1), b[1])
 
+
+    def test_layer_norm_functional(self):
+
+        class BasicBlock(torch.nn.Module):
+            def __init__(self):
+                super(BasicBlock, self).__init__()
+
+            def forward(self, x: Dyn):
+                return torch.nn.functional.layer_norm(x, (1024,))
+
+        ast_rewriter = RewritingTracer()
+        graph = ast_rewriter.trace(BasicBlock())
+        traced = GraphModule(ast_rewriter.root, graph, "gm")
+        transformed = transform_all_constraints(traced, counter=0)
+
+        s = z3.Solver()
+        s.add(transformed)
+        self.assertEquals(s.check(), z3.sat)
+
+        # make the output a size 1 tensor which should result
+        # in the migration of the input
+
+        b = BasicBlock().forward(torch.rand(1024))
+        input = z3.Const(1, tensor_type)
+        output = z3.Const(2, tensor_type)
+        s.add(output == tensor_type.tensor1(D(1, 1024)))
+        s.check()
+        self.assertEqual(s.model()[input], s.model()[output])
+        # input shape = output shape
+        self.assertEqual(b.shape[0], s.model()[input].arg(0).arg(1))
+
     def test_ne_int_long_type_as(self):
 
         class BasicBlock(torch.nn.Module):
diff --git a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py
index 482daf0..116c70d 100644
--- a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py
+++ b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py
@@ -841,6 +841,15 @@
     return [Disj([both_dyn, *const])], counter
 
 
+@register_inference_rule(torch.nn.functional.layer_norm)
+def layer_norm_functional(n: Node, symbols, constraints, counter):
+    """
+    We generate the constraint: input = output
+    """
+    assert isinstance(n.args[0], Node)
+    return gen_layer_norm_constraints(n, n.args[1], symbols, counter)
+
+
 @register_inference_rule(torch.nn.LayerNorm)
 def layer_norm_inference_rule(n: Node, module_instance, symbols, constraints, counter):
     """
@@ -848,6 +857,10 @@
     Input should be consistent with the normalized_shape
     """
     assert isinstance(n.args[0], Node)
+    return gen_layer_norm_constraints(n, module_instance.normalized_shape, symbols, counter)
+
+
+def gen_layer_norm_constraints(n: Node, normalized_shape, symbols, counter):
     output, counter = gen_tvar(counter)
     symbols[n] = output
     input = symbols[n.args[0]]
@@ -864,16 +877,11 @@
 
         c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims_rhs), op_eq),
                            BinConstraintT(output, TensorType(new_dims_rhs), op_eq)] +
-                          add_layer_norm_constraints(new_dims_rhs, list(module_instance.normalized_shape)) +
+                          add_layer_norm_constraints(new_dims_rhs, list(normalized_shape)) +
                           nat_constraints)
         c2.append(c_tensor_i)
-
-
     return [Disj([c1, Disj(c2)])], counter
 
-    # return [BinConstraintT(input, output, op_eq),
-    #         BinConstraintT(input, normalized_shape, op_consistency)], counter
-
 @register_inference_rule(torch.nn.Dropout)
 @register_inference_rule(torch.nn.ReLU)
 def relu_inference_rule(n: Node, module_instance, symbols, constraints, counter):