Fix for ppc64le jit graph difference in sigmoid backward, see #10726 (#11579)
Summary:
As reported in Issue #10726, the jit compiler, when running on ppc64le, may produce an output graph that is isomorphic to the expected one but still fails the diff test against the expected output file. The expected output file was created from a test run on x86_64. This change ensures that, if the ppc64le test output differs, it is instead compared against an expected output file created when the test is run on a ppc64le system.
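For context (not part of the patch): the sigmoid backward computes grad * y * (1 - y), where y = sigmoid(x). Multiplication is commutative, so reordering the factors does not change the result; it only changes the order in which the aten::mul nodes are emitted, which is what shows up as the textual graph difference. A minimal standalone C++ sketch of that equivalence, using arbitrary illustration values (0.3 for the input, 0.5 for the incoming gradient):

    // Sketch: both orderings of the sigmoid-backward product give the same value.
    #include <cassert>
    #include <cmath>

    int main() {
      double y = 1.0 / (1.0 + std::exp(-0.3)); // y = sigmoid(0.3)
      double grad = 0.5;                       // incoming gradient dL/dy
      double old_order = grad * y * (1 - y);   // order before this patch
      double new_order = (1 - y) * y * grad;   // order after this patch
      assert(std::fabs(old_order - new_order) < 1e-12); // same value, different node order
      return 0;
    }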
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11579
Differential Revision: D10080890
Pulled By: soumith
fbshipit-source-id: 7249bf6b5dfa7c853368a3688a982bc9ed642bc9
diff --git a/test/expect/TestScript.test_lstm_fusion_cuda-backward.expect b/test/expect/TestScript.test_lstm_fusion_cuda-backward.expect
index cbdbc74..3674a3f 100644
--- a/test/expect/TestScript.test_lstm_fusion_cuda-backward.expect
+++ b/test/expect/TestScript.test_lstm_fusion_cuda-backward.expect
@@ -56,8 +56,8 @@
%23 : Float(*, *) = aten::neg(%3)
%24 : int = prim::Constant[value=1]()
%25 : Float(*, *) = aten::add(%23, %24, %24)
- %26 : Float(*, *) = aten::mul(%19, %3)
- %27 : Float(*, *) = aten::mul(%26, %25)
+ %26 : Float(*, *) = aten::mul(%25, %3)
+ %27 : Float(*, *) = aten::mul(%26, %19)
%28 : Float(*, *) = aten::mul(%2, %2)
%29 : Float(*, *) = aten::neg(%28)
%30 : int = prim::Constant[value=1]()
@@ -66,13 +66,13 @@
%33 : Float(*, *) = aten::neg(%1)
%34 : int = prim::Constant[value=1]()
%35 : Float(*, *) = aten::add(%33, %34, %34)
- %36 : Float(*, *) = aten::mul(%22, %1)
- %37 : Float(*, *) = aten::mul(%36, %35)
+ %36 : Float(*, *) = aten::mul(%35, %1)
+ %37 : Float(*, *) = aten::mul(%36, %22)
%38 : Float(*, *) = aten::neg(%0)
%39 : int = prim::Constant[value=1]()
%40 : Float(*, *) = aten::add(%38, %39, %39)
- %41 : Float(*, *) = aten::mul(%20, %0)
- %42 : Float(*, *) = aten::mul(%41, %40)
+ %41 : Float(*, *) = aten::mul(%40, %0)
+ %42 : Float(*, *) = aten::mul(%41, %20)
%43 : Float(*, *) = prim::FusedConcat[dim=1](%42, %37, %32, %27)
return (%43, %18);
}
diff --git a/test/expect/TestScript.test_milstm_fusion_cuda-backward.expect b/test/expect/TestScript.test_milstm_fusion_cuda-backward.expect
index b0dc856..fb14a35 100644
--- a/test/expect/TestScript.test_milstm_fusion_cuda-backward.expect
+++ b/test/expect/TestScript.test_milstm_fusion_cuda-backward.expect
@@ -62,8 +62,8 @@
%20 : Float(*, *) = aten::neg(%3)
%21 : int = prim::Constant[value=1]()
%22 : Float(*, *) = aten::add(%20, %21, %21)
- %23 : Float(*, *) = aten::mul(%8, %3)
- %24 : Float(*, *) = aten::mul(%23, %22)
+ %23 : Float(*, *) = aten::mul(%22, %3)
+ %24 : Float(*, *) = aten::mul(%23, %8)
%25 : Float(*, *) = aten::mul(%2, %2)
%26 : Float(*, *) = aten::neg(%25)
%27 : int = prim::Constant[value=1]()
@@ -72,13 +72,13 @@
%30 : Float(*, *) = aten::neg(%1)
%31 : int = prim::Constant[value=1]()
%32 : Float(*, *) = aten::add(%30, %31, %31)
- %33 : Float(*, *) = aten::mul(%19, %1)
- %34 : Float(*, *) = aten::mul(%33, %32)
+ %33 : Float(*, *) = aten::mul(%32, %1)
+ %34 : Float(*, *) = aten::mul(%33, %19)
%35 : Float(*, *) = aten::neg(%0)
%36 : int = prim::Constant[value=1]()
%37 : Float(*, *) = aten::add(%35, %36, %36)
- %38 : Float(*, *) = aten::mul(%17, %0)
- %39 : Float(*, *) = aten::mul(%38, %37)
+ %38 : Float(*, *) = aten::mul(%37, %0)
+ %39 : Float(*, *) = aten::mul(%38, %17)
%40 : Float(*, *) = prim::FusedConcat[dim=1](%39, %34, %29, %24)
return (%40);
}
diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp
index 009bf68..80e196c 100644
--- a/torch/csrc/jit/autodiff.cpp
+++ b/torch/csrc/jit/autodiff.cpp
@@ -166,7 +166,10 @@
return {nullptr, -grads.at(0) * inputs.at(0) / (inputs.at(1) * inputs.at(1))};
} else if (node->matches("aten::sigmoid(Tensor self) -> Tensor")) {
- return {grads.at(0) * outputs.at(0) * (1 - outputs.at(0))};
+ // TODO: The order of operations matters in this case. This
+ // works for ppc64le and x86_64. Need to look at why the
+ // order matters.
+ return {(1 - outputs.at(0)) * outputs.at(0) * grads.at(0)};
} else if (node->matches("aten::tanh(Tensor self) -> Tensor")) {
return {grads.at(0) * (1 - outputs.at(0) * outputs.at(0))};