idtt + zch distributed inference (#35763)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35763

Adds inference function and test for ScatterAssign

Test Plan: Updated unit test

Reviewed By: yyetim, shunting1986

Differential Revision: D20501079

fbshipit-source-id: 7ec6ef0127a151250dd699c90c2b80c35cfb1fe4
diff --git a/caffe2/operators/utility_ops.cc b/caffe2/operators/utility_ops.cc
index cb67503..9b0ad56 100644
--- a/caffe2/operators/utility_ops.cc
+++ b/caffe2/operators/utility_ops.cc
@@ -345,6 +345,12 @@
     .NumInputs(3)
     .NumOutputs(1)
     .EnforceInplace({{0, 0}})
+    .TensorInferenceFunction([](const OperatorDef& /* unused */,
+                                const vector<TensorShape>& in) {
+      vector<TensorShape> out(1);
+      out[0] = in[0];
+      return out;
+    })
     .SetDoc(R"DOC(
 Update slices of the tensor in-place by overriding current value.
 
diff --git a/caffe2/python/operator_test/sparse_ops_test.py b/caffe2/python/operator_test/sparse_ops_test.py
index 50c127c..19030fb 100644
--- a/caffe2/python/operator_test/sparse_ops_test.py
+++ b/caffe2/python/operator_test/sparse_ops_test.py
@@ -76,7 +76,7 @@
         ind = np.random.choice(first_dim, index_dim,
                                replace=False).astype(ind_type)
         x = (rand_array(index_dim, *extra_dims) * 10).astype(data_type)
-        self.assertReferenceChecks(gc, op, [d, ind, x], ref, threshold=1e-3)
+        self.assertReferenceChecks(gc, op, [d, ind, x], ref, threshold=1e-3, ensure_outputs_are_inferred=True)
 
 if __name__ == "__main__":
     import unittest