stochastic_depth support (#71536)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71536
as titled
Test Plan: buck test mode/dev-nosan caffe2/test:test_fx_acc_tracer -- test_stochastic_depth
Reviewed By: yinghai
Differential Revision: D33668640
fbshipit-source-id: 3a8e6fc04b5529373d9dc77fef4514e9d01bf088
(cherry picked from commit e346d1d7a306f64334146a1a5c107c3db3ce8cd8)
diff --git a/test/fx_acc/test_acc_tracer.py b/test/fx_acc/test_acc_tracer.py
index 7949758..675cfd1 100644
--- a/test/fx_acc/test_acc_tracer.py
+++ b/test/fx_acc/test_acc_tracer.py
@@ -63,7 +63,7 @@
return self._torch_op(a, *self._args, **self._kwargs)
m = TestModule(torch_op, args, kwargs)
-
+ m.eval()
a = torch.randn(*input_shape)
traced = acc_tracer.trace(m, [a])
ph_a = acc_op_node = None
@@ -1227,6 +1227,16 @@
input_shape=(1, 2, 3),
)
+ def test_stochastic_depth(self):
+ self._make_acc_op_function_test(
+ None,
+ lambda x, p, mode, training: torchvision.ops.stochastic_depth(x, p=p, mode=mode, training=training),
+ input_shape=(1, 2, 3),
+ p=0.5,
+ mode="row",
+ training=False,
+ )
+
def test_hardsigmoid(self):
self._make_acc_op_function_test(
acc_ops.hardsigmoid,
diff --git a/torch/fx/experimental/fx_acc/acc_ops.py b/torch/fx/experimental/fx_acc/acc_ops.py
index a3907e2..ea524f4 100644
--- a/torch/fx/experimental/fx_acc/acc_ops.py
+++ b/torch/fx/experimental/fx_acc/acc_ops.py
@@ -1,5 +1,6 @@
# encoding: utf-8
import operator
+import warnings
import torch # isort:skip
from typing import Sequence, Optional, List, cast
@@ -461,6 +462,20 @@
"""
return node.kwargs["input"]
+try:
+ from torchvision.ops import stochastic_depth
+except Exception as e:
+ warnings.warn(f"Unable to import torchvision related libraries.: {e}")
+else:
+ @register_custom_acc_mapper_fn(
+ op_and_target=("call_function", stochastic_depth),
+ arg_replacement_tuples=[("input", "input")],
+ )
+ def stochastic_depth_mapper(node: torch.fx.Node, mod: nn.Module):
+ """
+ Remove dropout node and directly map its input to output.
+ """
+ return node.kwargs["input"]
@register_acc_op_properties(AccOpProperty.pointwise, AccOpProperty.unary)
@register_acc_op_mapping(