idtt + zch distributed inference (#35763)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35763
Adds inference function and test for ScatterAssign
Test Plan: Updated unit test
Reviewed By: yyetim, shunting1986
Differential Revision: D20501079
fbshipit-source-id: 7ec6ef0127a151250dd699c90c2b80c35cfb1fe4
diff --git a/caffe2/operators/utility_ops.cc b/caffe2/operators/utility_ops.cc
index cb67503..9b0ad56 100644
--- a/caffe2/operators/utility_ops.cc
+++ b/caffe2/operators/utility_ops.cc
@@ -345,6 +345,12 @@
.NumInputs(3)
.NumOutputs(1)
.EnforceInplace({{0, 0}})
+ .TensorInferenceFunction([](const OperatorDef& /* unused */,
+ const vector<TensorShape>& in) {
+ vector<TensorShape> out(1);
+ out[0] = in[0];
+ return out;
+ })
.SetDoc(R"DOC(
Update slices of the tensor in-place by overriding current value.
diff --git a/caffe2/python/operator_test/sparse_ops_test.py b/caffe2/python/operator_test/sparse_ops_test.py
index 50c127c..19030fb 100644
--- a/caffe2/python/operator_test/sparse_ops_test.py
+++ b/caffe2/python/operator_test/sparse_ops_test.py
@@ -76,7 +76,7 @@
ind = np.random.choice(first_dim, index_dim,
replace=False).astype(ind_type)
x = (rand_array(index_dim, *extra_dims) * 10).astype(data_type)
- self.assertReferenceChecks(gc, op, [d, ind, x], ref, threshold=1e-3)
+ self.assertReferenceChecks(gc, op, [d, ind, x], ref, threshold=1e-3, ensure_outputs_are_inferred=True)
if __name__ == "__main__":
import unittest