create tensor based on provided datatype (#22468)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22468
as title
Reviewed By: ajauhri
Differential Revision: D15744503
fbshipit-source-id: 050b32dd7f135512385fc04f098c376c664211a9
diff --git a/benchmarks/operator_benchmark/benchmark_caffe2.py b/benchmarks/operator_benchmark/benchmark_caffe2.py
index 09a18a2..7cb5ac9 100644
--- a/benchmarks/operator_benchmark/benchmark_caffe2.py
+++ b/benchmarks/operator_benchmark/benchmark_caffe2.py
@@ -24,14 +24,19 @@
self.args = {}
self.user_provided_name = None
- # TODO: Add other dtype support
- def tensor(self, *shapes):
- """ A wapper function to create tensor (blob in caffe2) filled with random
- value. The name/label of the tensor is returned and it is available
+ def tensor(self, shapes, dtype='float32'):
+ """ A wapper function to create C2 tensor filled with random data.
+ The name/label of the tensor is returned and it is available
throughout the benchmark execution phase.
+ Args:
+ shapes: int or a sequence of ints to defining the shapes of the tensor
+ dtype: use the dtypes from numpy
+ (https://docs.scipy.org/doc/numpy/user/basics.types.html)
+ Return:
+ C2 tensor of dtype
"""
blob_name = 'blob_' + str(Caffe2BenchmarkBase.tensor_index)
- workspace.FeedBlob(blob_name, benchmark_utils.numpy_random_fp32(*shapes))
+ workspace.FeedBlob(blob_name, benchmark_utils.numpy_random(dtype, *shapes))
Caffe2BenchmarkBase.tensor_index += 1
return blob_name
diff --git a/benchmarks/operator_benchmark/benchmark_utils.py b/benchmarks/operator_benchmark/benchmark_utils.py
index 8f54e6e..e7e730a 100644
--- a/benchmarks/operator_benchmark/benchmark_utils.py
+++ b/benchmarks/operator_benchmark/benchmark_utils.py
@@ -19,12 +19,18 @@
return ', '.join([str(x) for x in shape])
-def numpy_random_fp32(*shape):
- """Return a random numpy tensor of float32 type.
+def numpy_random(dtype, *shapes):
+ """ Return a random numpy tensor of the provided dtype.
+ Args:
+ shapes: int or a sequence of ints to defining the shapes of the tensor
+ dtype: use the dtypes from numpy
+ (https://docs.scipy.org/doc/numpy/user/basics.types.html)
+ Return:
+ numpy tensor of dtype
"""
# TODO: consider more complex/custom dynamic ranges for
# comprehensive test coverage.
- return np.random.rand(*shape).astype(np.float32)
+ return np.random.rand(*shapes).astype(dtype)
def set_omp_threads(num_threads):