stochastic_depth support (#71536)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71536

as titled

Test Plan: buck test mode/dev-nosan caffe2/test:test_fx_acc_tracer -- test_stochastic_depth

Reviewed By: yinghai

Differential Revision: D33668640

fbshipit-source-id: 3a8e6fc04b5529373d9dc77fef4514e9d01bf088
(cherry picked from commit e346d1d7a306f64334146a1a5c107c3db3ce8cd8)
diff --git a/test/fx_acc/test_acc_tracer.py b/test/fx_acc/test_acc_tracer.py
index 7949758..675cfd1 100644
--- a/test/fx_acc/test_acc_tracer.py
+++ b/test/fx_acc/test_acc_tracer.py
@@ -63,7 +63,7 @@
                 return self._torch_op(a, *self._args, **self._kwargs)
 
         m = TestModule(torch_op, args, kwargs)
-
+        m.eval()
         a = torch.randn(*input_shape)
         traced = acc_tracer.trace(m, [a])
         ph_a = acc_op_node = None
@@ -1227,6 +1227,16 @@
             input_shape=(1, 2, 3),
         )
 
+    def test_stochastic_depth(self):
+        self._make_acc_op_function_test(
+            None,
+            lambda x, p, mode, training: torchvision.ops.stochastic_depth(x, p=p, mode=mode, training=training),
+            input_shape=(1, 2, 3),
+            p=0.5,
+            mode="row",
+            training=False,
+        )
+
     def test_hardsigmoid(self):
         self._make_acc_op_function_test(
             acc_ops.hardsigmoid,
diff --git a/torch/fx/experimental/fx_acc/acc_ops.py b/torch/fx/experimental/fx_acc/acc_ops.py
index a3907e2..ea524f4 100644
--- a/torch/fx/experimental/fx_acc/acc_ops.py
+++ b/torch/fx/experimental/fx_acc/acc_ops.py
@@ -1,5 +1,6 @@
 # encoding: utf-8
 import operator
+import warnings
 
 import torch  # isort:skip
 from typing import Sequence, Optional, List, cast
@@ -461,6 +462,20 @@
     """
     return node.kwargs["input"]
 
+try:
+    from torchvision.ops import stochastic_depth
+except Exception as e:
+    warnings.warn(f"Unable to import torchvision related libraries.: {e}")
+else:
+    @register_custom_acc_mapper_fn(
+        op_and_target=("call_function", stochastic_depth),
+        arg_replacement_tuples=[("input", "input")],
+    )
+    def stochastic_depth_mapper(node: torch.fx.Node, mod: nn.Module):
+        """
+        Remove dropout node and directly map its input to output.
+        """
+        return node.kwargs["input"]
 
 @register_acc_op_properties(AccOpProperty.pointwise, AccOpProperty.unary)
 @register_acc_op_mapping(