Revert D28387764: Codegen inplace forward AD formula from out of place one if needed
Test Plan: revert-hammer
Differential Revision:
D28387764 (https://github.com/pytorch/pytorch/commit/22799621626881c216e3a4d00cd6ae4785483093)
Original commit changeset: 7bf3929dd214
fbshipit-source-id: 473851cf7527b0edf303fdb46b9c07357ff7f340
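
For context on what this revert restores: forward-mode AD carries a (primal, tangent) pair for every tensor, and an in-place op must update the tangent in place alongside the primal. The reverted change generated those in-place tangent updates from the out-of-place formulas; this diff goes back to spelling them out in derivatives.yaml. A minimal dual-number sketch in plain Python (illustrative only, not PyTorch internals; the Dual class and its add_ method are invented here to mirror the restored yaml rule "result: self_t.add_(maybe_multiply(other_t, alpha))"):

    class Dual:
        """Hypothetical (primal, tangent) pair; not PyTorch code."""
        def __init__(self, p: float, t: float) -> None:
            self.p = p  # primal value
            self.t = t  # tangent (forward gradient)

        def add_(self, other: "Dual", alpha: float = 1.0) -> "Dual":
            # In-place add must update primal and tangent together,
            # mirroring: result: self_t.add_(maybe_multiply(other_t, alpha))
            self.p += alpha * other.p
            self.t += alpha * other.t
            return self

    x = Dual(2.0, 1.0)    # seed dx/dx = 1
    y = Dual(5.0, 0.0)    # constant w.r.t. x
    x.add_(y, alpha=3.0)
    assert (x.p, x.t) == (17.0, 1.0)   # d(x + 3*y)/dx = 1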
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 72543c3..1795fce 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -188,6 +188,11 @@
   other: handle_r_to_c(other.scalar_type(), maybe_multiply(grad, alpha.conj()))
   result: self_t + maybe_multiply(other_t, alpha)
 
+- name: add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
+  self: handle_r_to_c(self.scalar_type(), grad)
+  other: handle_r_to_c(other.scalar_type(), maybe_multiply(grad, alpha.conj()))
+  result: self_t.add_(maybe_multiply(other_t, alpha))
+
 - name: add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
   self: handle_r_to_c(self.scalar_type(), grad)
@@ -842,6 +847,11 @@
   other: mul_tensor_backward(grad, self, other.scalar_type())
   result: other_t * self_p + self_t * other_p
 
+- name: mul_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  self: mul_tensor_backward(grad, other, self.scalar_type())
+  other: mul_tensor_backward(grad, self, other.scalar_type())
+  result: self_t.mul_(other_p.conj()).add_(other_t * (self_p / other_p).conj())
+
 - name: mul.Scalar(Tensor self, Scalar other) -> Tensor
   self: mul_tensor_backward(grad, at::scalar_to_tensor(other), self.scalar_type())
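
A note on the mul_ formula just restored: by the time the tangent update runs, self has already been multiplied in place, so its primal self_p holds self_orig * other_p, and the original primal needed by the product rule is recovered as self_p / other_p. A plain-float sketch of that algebra (illustrative only; it ignores the .conj() calls, which matter only for complex dtypes, and the division breaks down when other_p contains zeros):

    # Product rule in forward mode: d(a*b) = a_t*b_p + b_t*a_p.
    self_p, self_t = 3.0, 1.0      # primal/tangent of self before mul_
    other_p, other_t = 4.0, 0.5    # primal/tangent of other

    self_p = self_p * other_p                   # in-place mul_: self_p is now 12.0
    original_self_p = self_p / other_p          # recover the pre-mutation primal (3.0)
    self_t = self_t * other_p + other_t * original_self_p

    assert self_t == 1.0 * 4.0 + 0.5 * 3.0      # == 5.5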
diff --git a/tools/codegen/api/autograd.py b/tools/codegen/api/autograd.py
index 9035513..3931111 100644
--- a/tools/codegen/api/autograd.py
+++ b/tools/codegen/api/autograd.py
@@ -253,7 +253,7 @@
                             f"in-place function is not supported: {f.func}")
 
         # For functions that have a single def for out-of-place and inplace (like abs())
-        if info and info.forward_derivatives:
+        if info and info.forward_derivatives and is_exact_match:
             forward_derivatives = info.forward_derivatives
             if f.func.kind() == SchemaKind.inplace:
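
The hunk above restores the is_exact_match guard: out-of-place forward formulas are reused for an in-place variant only when the schemas match exactly; otherwise the op needs its own derivatives.yaml entry, like the add_ and mul_ entries re-added in this diff. A control-flow sketch with invented names (pick_forward_derivatives does not exist in the codegen; this only paraphrases the restored branch):

    def pick_forward_derivatives(info, is_exact_match):
        # With the guard restored, a non-exact match no longer falls back to a
        # copy_()-wrapped out-of-place formula; it simply gets no generated
        # forward derivative unless derivatives.yaml spells one out.
        if info and info.forward_derivatives and is_exact_match:
            return info.forward_derivatives
        return []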
@@ -268,14 +268,9 @@
                 # replace "result" from the formula by self
                 def repl(m: Match[str]) -> str:
                     return f'{m.group(1)}self{m.group(2)}'
-                formula = re.sub(IDENT_REGEX.format("result"), repl, fw_info.formula)
-
-                if not is_exact_match:
-                    # Make sure that the forward grad is modified inplace
-                    formula = f"self_t_raw.defined() ? self_t_raw.copy_({formula}) : {formula}"
 
                 forward_derivatives = [ForwardDerivative(
-                    formula=formula,
+                    formula=re.sub(IDENT_REGEX.format("result"), repl, fw_info.formula),
                     var_name="self",
                     var_type=fw_info.var_type,
                     required_inputs_fw_grad=fw_info.required_inputs_fw_grad,
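
The hunk above drops the non-exact-match fallback that wrapped the rewritten formula in "self_t_raw.defined() ? self_t_raw.copy_(...) : ...", leaving only the token rewrite of result to self for exact matches. A standalone demo of that substitution (the IDENT_REGEX value is inlined here for illustration and assumed to match the codegen's definition):

    import re
    from typing import Match

    # Assumed definition: match the identifier only when it is not part of a
    # longer name, capturing the surrounding non-word characters.
    IDENT_REGEX = r'(^|\W){}($|\W)'

    def repl(m: Match[str]) -> str:
        return f'{m.group(1)}self{m.group(2)}'

    formula = 'other_t * result + result_t * other_p'   # hypothetical formula
    print(re.sub(IDENT_REGEX.format('result'), repl, formula))
    # -> other_t * self + result_t * other_p   (result_t is left untouched)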