add constraints for layer_norm function (#82597)
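
This registers an inference rule for the functional form, torch.nn.functional.layer_norm, by factoring the existing torch.nn.LayerNorm rule into a shared helper, gen_layer_norm_constraints, which receives the normalized_shape directly (from the module instance in one case, from the call arguments in the other).

A minimal sketch of how the new rule can be exercised, mirroring the added test; the import paths are assumed to match those already used in test/fx/test_z3_gradual_types.py, and z3 must be installed:

    import torch
    import z3
    from torch.fx.experimental.rewriter import RewritingTracer
    from torch.fx.graph_module import GraphModule
    from torch.fx.experimental.migrate_gradual_types.transform_to_z3 import (
        transform_all_constraints,
    )
    from torch.fx.tensor_type import Dyn

    class BasicBlock(torch.nn.Module):
        def forward(self, x: Dyn):
            # the functional call is now handled by the new layer_norm_functional rule
            return torch.nn.functional.layer_norm(x, (1024,))

    ast_rewriter = RewritingTracer()
    graph = ast_rewriter.trace(BasicBlock())
    traced = GraphModule(ast_rewriter.root, graph, "gm")

    # encode the whole graph as z3 constraints and check satisfiability
    constraints = transform_all_constraints(traced, counter=0)
    s = z3.Solver()
    s.add(constraints)
    assert s.check() == z3.sat
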
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82597
Approved by: https://github.com/jansel
diff --git a/test/fx/test_z3_gradual_types.py b/test/fx/test_z3_gradual_types.py
index 1b792b9..e504070 100644
--- a/test/fx/test_z3_gradual_types.py
+++ b/test/fx/test_z3_gradual_types.py
@@ -428,6 +428,37 @@
self.assertEqual(s.model()[output].arg(0).arg(1), b[0])
self.assertEqual(s.model()[output].arg(1).arg(1), b[1])
+
+ def test_layer_norm_functional(self):
+
+ class BasicBlock(torch.nn.Module):
+ def __init__(self):
+ super(BasicBlock, self).__init__()
+
+ def forward(self, x: Dyn):
+ return torch.nn.functional.layer_norm(x, (1024,))
+
+ ast_rewriter = RewritingTracer()
+ graph = ast_rewriter.trace(BasicBlock())
+ traced = GraphModule(ast_rewriter.root, graph, "gm")
+ transformed = transform_all_constraints(traced, counter=0)
+
+ s = z3.Solver()
+ s.add(transformed)
+        self.assertEqual(s.check(), z3.sat)
+
+        # constrain the output to be a rank-1 tensor of size 1024,
+        # which should force the input to migrate to the same type
+
+ b = BasicBlock().forward(torch.rand(1024))
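+        # tensor variables 1 and 2 correspond to the placeholder x and the layer_norm output node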
+ input = z3.Const(1, tensor_type)
+ output = z3.Const(2, tensor_type)
+ s.add(output == tensor_type.tensor1(D(1, 1024)))
+ s.check()
+ self.assertEqual(s.model()[input], s.model()[output])
+        # the migrated input's dimension should match the shape of the eager result
+ self.assertEqual(b.shape[0], s.model()[input].arg(0).arg(1))
+
def test_ne_int_long_type_as(self):
class BasicBlock(torch.nn.Module):
diff --git a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py
index 482daf0..116c70d 100644
--- a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py
+++ b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py
@@ -841,6 +841,15 @@
return [Disj([both_dyn, *const])], counter
+@register_inference_rule(torch.nn.functional.layer_norm)
+def layer_norm_functional(n: Node, symbols, constraints, counter):
+    """
+    We generate the constraints: input = output, and the input must be consistent with the given normalized_shape
+    """
+ assert isinstance(n.args[0], Node)
+ return gen_layer_norm_constraints(n, n.args[1], symbols, counter)
+
+
@register_inference_rule(torch.nn.LayerNorm)
def layer_norm_inference_rule(n: Node, module_instance, symbols, constraints, counter):
"""
@@ -848,6 +857,10 @@
Input should be consistent with the normalized_shape
"""
assert isinstance(n.args[0], Node)
+ return gen_layer_norm_constraints(n, module_instance.normalized_shape, symbols, counter)
+
+
+def gen_layer_norm_constraints(n: Node, normalized_shape, symbols, counter):
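+    """Shared constraint generation for layer_norm: the input and output types are equal and consistent with normalized_shape."""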
output, counter = gen_tvar(counter)
symbols[n] = output
input = symbols[n.args[0]]
@@ -864,16 +877,11 @@
c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims_rhs), op_eq),
BinConstraintT(output, TensorType(new_dims_rhs), op_eq)] +
- add_layer_norm_constraints(new_dims_rhs, list(module_instance.normalized_shape)) +
+ add_layer_norm_constraints(new_dims_rhs, list(normalized_shape)) +
nat_constraints)
c2.append(c_tensor_i)
-
-
return [Disj([c1, Disj(c2)])], counter
- # return [BinConstraintT(input, output, op_eq),
- # BinConstraintT(input, normalized_shape, op_consistency)], counter
-
@register_inference_rule(torch.nn.Dropout)
@register_inference_rule(torch.nn.ReLU)
def relu_inference_rule(n: Node, module_instance, symbols, constraints, counter):