Revert D20010383: [jit] Unify augmented assign handling

Test Plan: revert-hammer

Differential Revision:
D20010383

Original commit changeset: 52e559ce907e

fbshipit-source-id: 7ca938070d5e98c91e7a7b8485a3c1e790c3ceb2
diff --git a/test/test_jit.py b/test/test_jit.py
index f42b538..6119dbb 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -5418,73 +5418,6 @@
         with self.assertRaisesRegex(RuntimeError, "Cannot emit inplace op"):
             torch.jit.script(A())
 
-    def test_var_aug_assign(self):
-        @torch.jit.script
-        class SomeNonAddableClass(object):
-            def __init__(self):
-                self.num = 99
-
-            def __eq__(self, other):
-                # type: (SomeNonAddableClass) -> bool
-                return self.num == other.num
-
-        with self.assertRaisesRegex(RuntimeError, "Cannot emit inplace op"):
-            @torch.jit.script
-            def fn():
-                a = SomeNonAddableClass()
-                a += SomeNonAddableClass()
-                return a
-
-        @torch.jit.script
-        class SomeClass(object):
-            def __init__(self):
-                self.num = 99
-
-            def __iadd__(self, x):
-                # type: (int)
-                self.num += x
-                return self
-
-            def __eq__(self, other):
-                # type: (SomeClass) -> bool
-                return self.num == other.num
-
-        @torch.jit.script
-        class SomeOutOfPlaceClass(object):
-            def __init__(self):
-                self.num = 99
-
-            def __add__(self, x):
-                # type: (int)
-                self.num = x
-                return self
-
-            def __eq__(self, other):
-                # type: (SomeClass) -> bool
-                return self.num == other.num
-
-        def fn():
-            a = SomeClass()
-            a_copy = a
-            a += 20
-            assert a is a_copy
-            b = SomeOutOfPlaceClass()
-            b_copy = b
-            b += 99
-            assert b is b_copy
-            c = [1, 2, 3]
-            c_copy = c
-            c *= 2
-            assert c is c_copy
-            c += [4, 5, 6]
-            d = torch.ones(2, 2)
-            d_copy = d
-            d += torch.ones(2, 2)
-            assert d is d_copy
-            return a, b, c, d
-
-        self.checkScript(fn, [])
-
     def test_nested_list_construct(self):
         def foo():
             return [[4]] + [[4, 5]]
diff --git a/torch/csrc/jit/script/ir_emitter.cpp b/torch/csrc/jit/script/ir_emitter.cpp
index 42cf7e3..ea677d0 100644
--- a/torch/csrc/jit/script/ir_emitter.cpp
+++ b/torch/csrc/jit/script/ir_emitter.cpp
@@ -1852,21 +1852,7 @@
     const auto lhsValue =
         lhsSugaredVar->attr(lhs.range(), method, lhs.selector().name())
             ->asValue(lhs.range(), method);
-    auto result = emitAugAssignmentHelper(stmt, lhsValue);
-    lhsSugaredVar->setAttr(stmt.range(), method, lhs.selector().name(), result);
-  }
-
-  void emitAugAssignmentToVar(const AugAssign& stmt) {
-    const auto lhs = Var(stmt.lhs());
-    auto lhsValue = emitExpr(lhs);
-    auto result = emitAugAssignmentHelper(stmt, lhsValue);
-    environment_stack->setVar(lhs.range(), lhs.name().name(), result);
-  }
-
-  Value* emitAugAssignmentHelper(
-      const AugAssign& stmt,
-      Value* lhs) {
-    if (lhs->type()->kind() == TypeKind::ClassType) {
+  if (lhsValue->type()->kind() == TypeKind::ClassType) {
       // Call `__iadd__` so updates happen in place on class types
       // https://docs.python.org/3/reference/datamodel.html#object.__iadd__
       std::string in_place_method_name;
@@ -1877,7 +1863,7 @@
 
       // Determine whether to use __iadd__ or __add__ (use __add__ only if
       // __iadd__ is not present)
-      auto type = lhs->type()->expect<ClassType>();
+      auto type = lhsValue->type()->expect<ClassType>();
       std::string magic_method_name;
       if (type->getMethod(in_place_method_name)) {
         magic_method_name = in_place_method_name;
@@ -1890,21 +1876,60 @@
             << out_of_place_method_name << " method";
       }
 
+      // Insert call to the magic method
+      MethodValue method_value(lhsValue, magic_method_name);
+      auto result = method_value.call(stmt.range(), method, {rhs}, {}, 0)
+                        ->asValue(stmt.range(), method);
+
       // x += y is equivalent to x = x.__iadd__(y) or x = x.__add__(y) if
-      // __iadd__ is not present
-      return MethodValue(lhs, magic_method_name)
-          .call(stmt.range(), method, {rhs}, {}, 0)
-          ->asValue(stmt.range(), method);
+      // __iadd__ is not present, so set the value to the function's return
+      // value
+      lhsSugaredVar->setAttr(
+          stmt.range(), method, lhs.selector().name(), result);
     } else {
       const auto rhs = NamedValue(stmt.rhs().range(), emitExpr(stmt.rhs()))
                            .value(*method.graph());
-      return emitBuiltinCall(
+      auto rhsValue = emitBuiltinCall(
           stmt.range(),
           *method.graph(),
-          getAugOp(stmt, lhs->type()),
-          /*inputs=*/{lhs, rhs},
-          /*attributes=*/{},
+          getAugOp(stmt, lhsValue->type()),
+          {lhsValue, rhs},
+          {},
           /*self=*/c10::nullopt);
+      lhsSugaredVar->setAttr(
+          stmt.range(), method, lhs.selector().name(), rhsValue);
+    }
+  }
+
+  void emitAugAssignmentToVar(const AugAssign& stmt) {
+    const auto lhs = Var(stmt.lhs());
+    const auto lhsValue = environment_stack->getSugaredVar(lhs.name())
+                              ->asValue(lhs.range(), method);
+    auto lhsType = lhsValue->type();
+    if (lhsType->isSubtypeOf(TensorType::get()) ||
+        lhsType->cast<c10::ListType>()) {
+      // for tensors, emit the corresponding in-place op
+      const auto rhs = NamedValue(stmt.rhs().range(), emitExpr(stmt.rhs()));
+      const auto self = NamedValue(stmt.lhs().range(), "self", lhsValue);
+      const auto output = emitBuiltinCall(
+          stmt.range(),
+          *method.graph(),
+          getAugOp(stmt, lhsValue->type()),
+          {rhs},
+          {},
+          self);
+
+      environment_stack->setVar(lhs.range(), lhs.name().name(), output);
+    } else {
+      // for primitive types, desugar into a simple assignment
+      //   e.g. foo += 1 becomes foo.2 = foo + 1
+      Ident lhs = Var(stmt.lhs()).name();
+      Expr expr = BinOp::create(
+          stmt.range(),
+          stmt.aug_op(),
+          Var::create(lhs.range(), lhs),
+          stmt.rhs());
+      environment_stack->setVar(lhs.range(), lhs.name(), emitExpr(expr));
     }
   }