ScatterAssign int types
Summary: Closes https://github.com/caffe2/caffe2/pull/1357
Reviewed By: dzhulgakov
Differential Revision: D6107036
Pulled By: bddppq
fbshipit-source-id: 9278dae988c3c0656b4e4fd08bf7ca1e2eec3348
diff --git a/caffe2/operators/utility_ops.h b/caffe2/operators/utility_ops.h
index 0b130d9..ec6aef8 100644
--- a/caffe2/operators/utility_ops.h
+++ b/caffe2/operators/utility_ops.h
@@ -606,10 +606,18 @@
&ScatterAssignOp::DoRun<int32_t, float>},
{{TensorProto_DataType_INT32, TensorProto_DataType_FLOAT16},
&ScatterAssignOp::DoRun<int32_t, float16>},
+ {{TensorProto_DataType_INT32, TensorProto_DataType_INT32},
+ &ScatterAssignOp::DoRun<int32_t, int32_t>},
+ {{TensorProto_DataType_INT32, TensorProto_DataType_INT64},
+ &ScatterAssignOp::DoRun<int32_t, int64_t>},
{{TensorProto_DataType_INT64, TensorProto_DataType_FLOAT},
&ScatterAssignOp::DoRun<int64_t, float>},
{{TensorProto_DataType_INT64, TensorProto_DataType_FLOAT16},
- &ScatterAssignOp::DoRun<int64_t, float16>}}) {}
+ &ScatterAssignOp::DoRun<int64_t, float16>},
+ {{TensorProto_DataType_INT64, TensorProto_DataType_INT32},
+ &ScatterAssignOp::DoRun<int64_t, int32_t>},
+ {{TensorProto_DataType_INT64, TensorProto_DataType_INT64},
+ &ScatterAssignOp::DoRun<int64_t, int64_t>}}) {}
bool RunOnDevice() override {
const auto& data = Input(DATA);
diff --git a/caffe2/python/operator_test/sparse_ops_test.py b/caffe2/python/operator_test/sparse_ops_test.py
index 0990caf..e1a6f4d 100644
--- a/caffe2/python/operator_test/sparse_ops_test.py
+++ b/caffe2/python/operator_test/sparse_ops_test.py
@@ -70,10 +70,11 @@
@given(first_dim=st.integers(1, 20),
index_dim=st.integers(1, 10),
extra_dims=st.lists(st.integers(1, 4), min_size=0, max_size=3),
+ data_type=st.sampled_from([np.float16, np.float32, np.int32, np.int64]),
ind_type=st.sampled_from([np.int32, np.int64]),
**hu.gcs_cpu_only)
def testScatterAssign(
- self, first_dim, index_dim, extra_dims, ind_type, gc, dc):
+ self, first_dim, index_dim, extra_dims, data_type, ind_type, gc, dc):
op = core.CreateOperator('ScatterAssign',
['data', 'indices', 'slices'], ['data'])
def ref(d, ind, x):
@@ -84,10 +85,10 @@
# let's have indices unique
if first_dim < index_dim:
first_dim, index_dim = index_dim, first_dim
- d = rand_array(first_dim, *extra_dims)
+ d = (rand_array(first_dim, *extra_dims) * 10).astype(data_type)
ind = np.random.choice(first_dim, index_dim,
replace=False).astype(ind_type)
- x = rand_array(index_dim, *extra_dims)
+ x = (rand_array(index_dim, *extra_dims) * 10).astype(data_type)
self.assertReferenceChecks(gc, op, [d, ind, x], ref, threshold=1e-3)
if __name__ == "__main__":