Revert D20010383: [jit] Unify augmented assign handling
Test Plan: revert-hammer
Differential Revision: D20010383
Original commit changeset: 52e559ce907e
fbshipit-source-id: 7ca938070d5e98c91e7a7b8485a3c1e790c3ceb2
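Context for reviewers: the reverted change collapsed the attribute and variable
paths for augmented assignment into a single emitAugAssignmentHelper, mirroring
Python's own protocol, where x += y calls x.__iadd__(y) if it exists and
otherwise falls back to x = x.__add__(y). A minimal eager-Python sketch of that
protocol (illustrative only; the class names are made up and not part of this
diff):

    class InPlace(object):
        def __init__(self):
            self.num = 0

        def __iadd__(self, x):
            self.num += x
            return self  # mutates and returns itself, so identity is kept

    class OutOfPlace(object):
        def __init__(self):
            self.num = 0

        def __add__(self, x):
            out = OutOfPlace()
            out.num = self.num + x
            return out  # returns a fresh object, so identity changes

    a = InPlace()
    a_copy = a
    a += 1
    assert a is a_copy      # __iadd__ keeps the same object

    b = OutOfPlace()
    b_copy = b
    b += 1
    assert b is not b_copy  # fallback rebinds b to __add__'s result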
diff --git a/test/test_jit.py b/test/test_jit.py
index f42b538..6119dbb 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -5418,73 +5418,6 @@
         with self.assertRaisesRegex(RuntimeError, "Cannot emit inplace op"):
             torch.jit.script(A())
-    def test_var_aug_assign(self):
-        @torch.jit.script
-        class SomeNonAddableClass(object):
-            def __init__(self):
-                self.num = 99
-
-            def __eq__(self, other):
-                # type: (SomeNonAddableClass) -> bool
-                return self.num == other.num
-
-        with self.assertRaisesRegex(RuntimeError, "Cannot emit inplace op"):
-            @torch.jit.script
-            def fn():
-                a = SomeNonAddableClass()
-                a += SomeNonAddableClass()
-                return a
-
-        @torch.jit.script
-        class SomeClass(object):
-            def __init__(self):
-                self.num = 99
-
-            def __iadd__(self, x):
-                # type: (int) -> SomeClass
-                self.num += x
-                return self
-
-            def __eq__(self, other):
-                # type: (SomeClass) -> bool
-                return self.num == other.num
-
-        @torch.jit.script
-        class SomeOutOfPlaceClass(object):
-            def __init__(self):
-                self.num = 99
-
-            def __add__(self, x):
-                # type: (int) -> SomeOutOfPlaceClass
-                self.num = x
-                return self
-
-            def __eq__(self, other):
-                # type: (SomeOutOfPlaceClass) -> bool
-                return self.num == other.num
-
-        def fn():
-            a = SomeClass()
-            a_copy = a
-            a += 20
-            assert a is a_copy
-            b = SomeOutOfPlaceClass()
-            b_copy = b
-            b += 99
-            assert b is b_copy
-            c = [1, 2, 3]
-            c_copy = c
-            c *= 2
-            assert c is c_copy
-            c += [4, 5, 6]
-            d = torch.ones(2, 2)
-            d_copy = d
-            d += torch.ones(2, 2)
-            assert d is d_copy
-            return a, b, c, d
-
-        self.checkScript(fn, [])
-
     def test_nested_list_construct(self):
         def foo():
             return [[4]] + [[4, 5]]
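Note: the deleted test above pins down the identity-preserving behavior of +=
and *= on lists and tensors. In eager Python/PyTorch the same guarantees can be
checked directly (a standalone sketch, assuming only a working torch install,
not part of the test suite):

    import torch

    c = [1, 2, 3]
    c_copy = c
    c *= 2                   # list.__imul__ mutates the list in place
    assert c is c_copy

    d = torch.ones(2, 2)
    d_copy = d
    d += torch.ones(2, 2)    # dispatches to the in-place variant (add_)
    assert d is d_copy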
diff --git a/torch/csrc/jit/script/ir_emitter.cpp b/torch/csrc/jit/script/ir_emitter.cpp
index 42cf7e3..ea677d0 100644
--- a/torch/csrc/jit/script/ir_emitter.cpp
+++ b/torch/csrc/jit/script/ir_emitter.cpp
@@ -1852,21 +1852,7 @@
     const auto lhsValue =
         lhsSugaredVar->attr(lhs.range(), method, lhs.selector().name())
             ->asValue(lhs.range(), method);
-    auto result = emitAugAssignmentHelper(stmt, lhsValue);
-    lhsSugaredVar->setAttr(stmt.range(), method, lhs.selector().name(), result);
-  }
-
-  void emitAugAssignmentToVar(const AugAssign& stmt) {
-    const auto lhs = Var(stmt.lhs());
-    auto lhsValue = emitExpr(lhs);
-    auto result = emitAugAssignmentHelper(stmt, lhsValue);
-    environment_stack->setVar(lhs.range(), lhs.name().name(), result);
-  }
-
-  Value* emitAugAssignmentHelper(
-      const AugAssign& stmt,
-      Value* lhs) {
-    if (lhs->type()->kind() == TypeKind::ClassType) {
+    if (lhsValue->type()->kind() == TypeKind::ClassType) {
       // Call `__iadd__` so updates happen in place on class types
       // https://docs.python.org/3/reference/datamodel.html#object.__iadd__
       std::string in_place_method_name;
@@ -1877,7 +1863,7 @@
       // Determine whether to use __iadd__ or __add__ (use __add__ only if
       // __iadd__ is not present)
-      auto type = lhs->type()->expect<ClassType>();
+      auto type = lhsValue->type()->expect<ClassType>();
       std::string magic_method_name;
       if (type->getMethod(in_place_method_name)) {
         magic_method_name = in_place_method_name;
@@ -1890,21 +1876,60 @@
             << out_of_place_method_name << " method";
       }
+      // Insert call to the magic method
+      MethodValue method_value(lhsValue, magic_method_name);
+      auto result = method_value.call(stmt.range(), method, {rhs}, {}, 0)
+                        ->asValue(stmt.range(), method);
+
       // x += y is equivalent to x = x.__iadd__(y) or x = x.__add__(y) if
-      // __iadd__ is not present
-      return MethodValue(lhs, magic_method_name)
-          .call(stmt.range(), method, {rhs}, {}, 0)
-          ->asValue(stmt.range(), method);
+      // __iadd__ is not present, so set the value to the function's return
+      // value
+      lhsSugaredVar->setAttr(
+          stmt.range(), method, lhs.selector().name(), result);
     } else {
       const auto rhs = NamedValue(stmt.rhs().range(), emitExpr(stmt.rhs()))
                            .value(*method.graph());
-      return emitBuiltinCall(
+      auto rhsValue = emitBuiltinCall(
           stmt.range(),
           *method.graph(),
-          getAugOp(stmt, lhs->type()),
-          /*inputs=*/{lhs, rhs},
-          /*attributes=*/{},
+          getAugOp(stmt, lhsValue->type()),
+          {lhsValue, rhs},
+          {},
           /*self=*/c10::nullopt);
+      lhsSugaredVar->setAttr(
+          stmt.range(), method, lhs.selector().name(), rhsValue);
+    }
+  }
+
+  void emitAugAssignmentToVar(const AugAssign& stmt) {
+    const auto lhs = Var(stmt.lhs());
+    const auto lhsValue = environment_stack->getSugaredVar(lhs.name())
+                              ->asValue(lhs.range(), method);
+    auto lhsType = lhsValue->type();
+    if (lhsType->isSubtypeOf(TensorType::get()) ||
+        lhsType->cast<c10::ListType>()) {
+      // for tensors and lists, emit the corresponding in-place op
+      const auto rhs = NamedValue(stmt.rhs().range(), emitExpr(stmt.rhs()));
+      const auto self = NamedValue(stmt.lhs().range(), "self", lhsValue);
+      const auto output = emitBuiltinCall(
+          stmt.range(),
+          *method.graph(),
+          getAugOp(stmt, lhsValue->type()),
+          {rhs},
+          {},
+          self);
+
+      environment_stack->setVar(lhs.range(), lhs.name().name(), output);
+    } else {
+      // for primitive types, desugar into a simple assignment
+      // e.g. foo += 1 becomes foo.2 = foo + 1
+      Ident lhs = Var(stmt.lhs()).name();
+      Expr expr = BinOp::create(
+          stmt.range(),
+          stmt.aug_op(),
+          Var::create(lhs.range(), lhs),
+          stmt.rhs());
+      environment_stack->setVar(lhs.range(), lhs.name(), emitExpr(expr));
     }
   }
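Note: after the revert, emitAugAssignmentToVar again special-cases tensors and
lists (emitting the in-place op) and desugars everything else into a plain
binary op plus a rebind, as the restored comment notes (foo += 1 becomes
foo.2 = foo + 1 in SSA form). A quick behavioral check of the desugared path
(a hypothetical snippet, matching the py2-style type comments used in the test
file; not part of this diff):

    import torch

    @torch.jit.script
    def bump(x):
        # type: (int) -> int
        x += 1  # emitted as `x = x + 1` for primitive types
        return x

    assert bump(41) == 42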