Fix bug in script for `where` (#12385)
Summary:
`where` is declared as:
```
where(Tensor condition, Tensor self, Tensor other)
```
Previously, the compiler assumed that `self` is always the first argument of a schema.
This does not hold for `where`, nor for a few other operators.
This change makes the compiler take an explicit `self` argument, which is matched
to the argument named `self` in the schema, wherever it appears.
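The difference between the two binding rules can be sketched in Python. This is a hypothetical model, not the JIT compiler's actual code: schemas are plain `(type, name)` lists, inputs are strings, and the helper names `bind_positional_self` and `bind_named_self` are made up for illustration. It only shows which schema argument the receiver of `x.where(x > 0.0, y)` lands in under each rule.

```python
from typing import Any, Dict, List, Optional, Tuple

# Hypothetical model: a schema is (name, [(arg_type, arg_name), ...]).
Schema = Tuple[str, List[Tuple[str, str]]]

WHERE: Schema = ("where", [("Tensor", "condition"), ("Tensor", "self"), ("Tensor", "other")])


def bind_positional_self(schema: Schema, receiver: Any, others: List[Any]) -> Dict[str, Any]:
    """Old assumption: the receiver always fills the first schema argument."""
    names = [name for _, name in schema[1]]
    return dict(zip(names, [receiver, *others]))


def bind_named_self(schema: Schema, receiver: Any, others: List[Any]) -> Optional[Dict[str, Any]]:
    """New rule: the receiver fills whichever argument is named `self`."""
    names = [name for _, name in schema[1]]
    if names.count("self") != 1:
        return None  # no unique `self` argument, so no method form
    rest = [name for name in names if name != "self"]
    if len(rest) != len(others):
        return None  # arity mismatch
    bound = {"self": receiver}
    bound.update(zip(rest, others))  # remaining inputs fill the other args in order
    return bound


# `x.where(x > 0.0, y)`, with each input represented by a string for display.
receiver, others = "x", ["x > 0.0", "y"]
print(bind_positional_self(WHERE, receiver, others))
# {'condition': 'x', 'self': 'x > 0.0', 'other': 'y'}  (x wrongly used as the condition)
print(bind_named_self(WHERE, receiver, others))
# {'self': 'x', 'condition': 'x > 0.0', 'other': 'y'}
```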
Note that this requires renaming an argument in one variant of `pow`, which referred to
its exponent Tensor as `self`; otherwise, `t^3` would also match that variant,
with `t` bound as the exponent.
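A similar hypothetical sketch shows why the `pow` rename is needed once `self` is matched by name. It assumes the old variant was declared as `pow(Scalar base, Tensor self)` and the renamed one as `pow(Scalar self, Tensor exponent)`; those exact names, the `FakeTensor` stand-in, and the helpers are illustrative assumptions, and overload resolution is reduced to "bind `self` by name, then type-check every argument".

```python
from typing import Any, Dict, List, Optional, Tuple

# Hypothetical model: a schema is (name, [(arg_type, arg_name), ...]).
Schema = Tuple[str, List[Tuple[str, str]]]


class FakeTensor:
    """Hypothetical stand-in for a Tensor value; only its type matters here."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __repr__(self) -> str:
        return self.name


def type_of(value: Any) -> str:
    """Map a Python value to the schema type name it satisfies."""
    if isinstance(value, FakeTensor):
        return "Tensor"
    if isinstance(value, (int, float)):
        return "Scalar"
    raise TypeError(f"unsupported value: {value!r}")


def bind_named_self(schema: Schema, receiver: Any, others: List[Any]) -> Optional[Dict[str, Any]]:
    """Bind the receiver to the arg named `self`, the rest in order, then type-check."""
    args = schema[1]
    names = [name for _, name in args]
    if names.count("self") != 1:
        return None
    rest = [name for name in names if name != "self"]
    if len(rest) != len(others):
        return None
    bound = {"self": receiver, **dict(zip(rest, others))}
    if any(type_of(bound[name]) != arg_type for arg_type, name in args):
        return None  # some argument has the wrong type
    return bound


def matching_overloads(overloads: List[Schema], receiver: Any, others: List[Any]) -> List[Schema]:
    """Return every overload that accepts the method call."""
    return [s for s in overloads if bind_named_self(s, receiver, others) is not None]


OLD_POW: List[Schema] = [
    ("pow", [("Tensor", "self"), ("Scalar", "exponent")]),
    ("pow", [("Scalar", "base"), ("Tensor", "self")]),  # exponent Tensor named `self`
]
NEW_POW: List[Schema] = [
    ("pow", [("Tensor", "self"), ("Scalar", "exponent")]),
    ("pow", [("Scalar", "self"), ("Tensor", "exponent")]),  # renamed variant
]

t = FakeTensor("t")
print(len(matching_overloads(OLD_POW, t, [3])))  # 2: ambiguous, one binds t as the exponent
print(len(matching_overloads(NEW_POW, t, [3])))  # 1: only pow(Tensor self, Scalar exponent)
```

With the old names, both overloads accept a Tensor receiver plus a Scalar argument, so `t^3` could be read as computing `3^t`. After the rename, the Scalar-first variant binds `self` to a Scalar, so a Tensor receiver no longer matches it.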
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12385
Differential Revision: D10364658
Pulled By: zdevito
fbshipit-source-id: 39e030c6912dd19b4b0b9e35fcbabc167b4cc255
diff --git a/test/test_jit.py b/test/test_jit.py
index 90437ba..fa58964 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -6221,6 +6221,12 @@
self.checkScript(fn, (torch.randn(3, 2, dtype=torch.float), torch.ones(3, 2, dtype=torch.float)))
+ def test_where_method(self):
+ def fn(x, y):
+ return x.where(x > 0.0, y)
+
+ self.checkScript(fn, (torch.randn(3, 2, dtype=torch.float), torch.ones(3, 2, dtype=torch.float)))
+
def test_reassign_module_lhs(self):
with self.assertRaisesRegex(RuntimeError, 'Cannot re-assign \'self\' because it has type value and self is'
' not a first-class value. Only reassignments to first-class values are allowed'):