add relu to jit and exp to autodiff (#8573)
diff --git a/test/test_jit.py b/test/test_jit.py
index 7647ca6..e63cce6 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -416,6 +416,28 @@
         ge = self.checkTrace(f, (x, y))
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    def test_relu(self):
+        def f(x, y):
+            return F.relu(x + .5 * y)
+
+        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
+        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+
+        ge = self.checkTrace(f, (x, y))
+
+    @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
+    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    def test_exp(self):
+        def f(x, y):
+            return (x + .5 * y).exp()
+
+        x = torch.randn(4, 4, dtype=torch.float, device='cuda')
+        y = torch.randn(4, 4, dtype=torch.float, device='cuda')
+
+        ge = self.checkTrace(f, (x, y))
+
     # TODO: adapt this test to check that GraphExecutor treats them differently
     @unittest.skip("Need to be adjusted to Graph Executor")
     def test_arg_configurations(self):
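
checkTrace here is the JIT test-harness helper that traces f, runs the traced graph, and compares outputs (and gradients) against eager execution. A rough standalone sketch of what the new relu test exercises, assuming a CUDA build and the torch.jit.trace(fn, example_inputs) calling convention:

    import torch
    import torch.nn.functional as F

    def f(x, y):
        return F.relu(x + .5 * y)

    x = torch.randn(4, 4, dtype=torch.float, device='cuda')
    y = torch.randn(4, 4, dtype=torch.float, device='cuda')

    traced = torch.jit.trace(f, (x, y))
    traced(x, y)  # run twice so the graph executor can optimize/fuse
    assert torch.allclose(traced(x, y), f(x, y))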
diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp
index e1a6d58..625a554 100644
--- a/torch/csrc/jit/autodiff.cpp
+++ b/torch/csrc/jit/autodiff.cpp
@@ -16,7 +16,8 @@
   static std::unordered_set<Symbol> differentiable_kinds = {
     aten::add, aten::sub, aten::mul, prim::Constant, prim::ReplaceIfUndef,
     aten::sigmoid, aten::tanh, aten::mm, aten::chunk, aten::split, aten::t, aten::neg,
-    aten::unsqueeze, aten::expand, aten::addmm, aten::gt, aten::lt, aten::eq, aten::ne, aten::ge, aten::le, aten::type_as
+    aten::unsqueeze, aten::expand, aten::addmm, aten::gt, aten::lt, aten::eq, aten::ne, aten::ge, aten::le, aten::type_as,
+    aten::relu, aten::exp
   };
   // TODO: check this more generally via schema
   // This check ensures that the `alpha` and `beta` attributes on this addmm
@@ -87,6 +88,10 @@
       return {grads.at(0) * outputs.at(0) * (1 - outputs.at(0))};
     case aten::tanh:
       return {grads.at(0) * (1 - outputs.at(0) * outputs.at(0))};
+    case aten::relu:
+      return {grads.at(0) * (outputs.at(0) > at::Scalar(0)).type_as(outputs.at(0))};
+    case aten::exp:
+      return {grads.at(0) * (outputs.at(0))};
     case aten::chunk:
     case aten::split:
       return {SymbolicVariable::cat(grads, node->i(attr::dim))};
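
Both new symbolic gradients are written in terms of the saved forward output: for relu, grad_input = grad_output * (output > 0), and for exp, grad_input = grad_output * output, since the derivative of exp is exp itself. A quick numerical sanity check of those closed forms against ordinary eager-mode autograd (independent of the JIT) might look like:

    import torch

    x = torch.randn(4, 4, dtype=torch.double, requires_grad=True)
    go = torch.randn(4, 4, dtype=torch.double)  # incoming gradient

    # relu: grad_input = grad_output * (output > 0)
    out = torch.relu(x)
    (auto_grad,) = torch.autograd.grad(out, x, go)
    assert torch.allclose(auto_grad, go * (out > 0).type_as(out))

    # exp: grad_input = grad_output * output
    out = x.exp()
    (auto_grad,) = torch.autograd.grad(out, x, go)
    assert torch.allclose(auto_grad, go * out)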
diff --git a/torch/csrc/jit/fusion_compiler.cpp b/torch/csrc/jit/fusion_compiler.cpp
index 7b2f8a1..f1264d7 100644
--- a/torch/csrc/jit/fusion_compiler.cpp
+++ b/torch/csrc/jit/fusion_compiler.cpp
@@ -174,6 +174,7 @@
   // unary
   {aten::abs, "absf(${0})"},
   {aten::sigmoid, "1.f / (1.f + expf(-${0}))"},
+  {aten::relu, "${0} < 0 ? 0.f : ${0}"},
   {aten::log, "logf(${0})"},
   {aten::log10, "log10f(${0})"},
   {aten::log1p, "log1pf(${0})"},
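
In the fuser's simple map table, ${0} is replaced with the operand expression when the kernel source is generated, so relu lowers to the elementwise ternary v < 0 ? 0.f : v. A small equivalence check of that formulation against F.relu, written with plain tensor ops rather than the generated CUDA:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1024)
    # relu(v) = v < 0 ? 0 : v, mirroring the codegen template
    ternary = torch.where(x < 0, torch.zeros_like(x), x)
    assert torch.equal(ternary, F.relu(x))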
diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp
index d359bc7..7f68be8 100644
--- a/torch/csrc/jit/passes/graph_fuser.cpp
+++ b/torch/csrc/jit/passes/graph_fuser.cpp
@@ -57,6 +57,7 @@
   aten::neg,
   aten::pow,
   aten::reciprocal,
+  aten::relu,
   aten::remainder,
   aten::round,
   aten::rsqrt,
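
With aten::relu added to graph_fuser's set of simply mappable ops, the add, mul, and relu from the new test can all be pulled into one fusion group on CUDA. One rough way to check this on a recent PyTorch build (graph_for is an internal debugging helper whose availability and output vary by version):

    import torch
    import torch.nn.functional as F

    @torch.jit.script
    def f(x, y):
        return F.relu(x + .5 * y)

    x = torch.randn(4, 4, device='cuda')
    y = torch.randn(4, 4, device='cuda')
    for _ in range(3):
        f(x, y)                  # warm up so the executor can optimize
    print(f.graph_for(x, y))     # look for a prim::FusionGroup node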