change unary, pool, max ops to use new interface (#22661)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22661
as title
Reviewed By: hl475
Differential Revision: D16170825
fbshipit-source-id: d80944224b8717e7aa35980907ff48e587b85217
diff --git a/benchmarks/operator_benchmark/pt/pool_test.py b/benchmarks/operator_benchmark/pt/pool_test.py
index fcde4fc..ca81c7d 100644
--- a/benchmarks/operator_benchmark/pt/pool_test.py
+++ b/benchmarks/operator_benchmark/pt/pool_test.py
@@ -29,7 +29,7 @@
pool_1d_ops_list = op_bench.op_list(
- attr_names=["op_name", "op"],
+ attr_names=["op_name", "op_func"],
attrs=[
["MaxPool1d", nn.MaxPool1d],
["AvgPool1d", nn.AvgPool1d],
@@ -38,19 +38,17 @@
class Pool1dBenchmark(op_bench.TorchBenchmarkBase):
- def init(self, kernel, stride, N, C, L):
+ def init(self, kernel, stride, N, C, L, op_func):
self.input = torch.rand(N, C, L)
self.kernel = kernel
self.stride = stride
-
- def set_op(self, op):
- self.op = op(self.kernel, stride=self.stride)
+ self.op_func = op_func(self.kernel, stride=self.stride)
def forward(self):
- return self.op(self.input)
+ return self.op_func(self.input)
-op_bench.generate_pt_tests_from_list(pool_1d_ops_list, pool_1d_configs, Pool1dBenchmark)
+op_bench.generate_pt_tests_from_op_list(pool_1d_ops_list, pool_1d_configs, Pool1dBenchmark)
"""
@@ -73,7 +71,7 @@
pool_2d_ops_list = op_bench.op_list(
- attr_names=["op_name", "op"],
+ attr_names=["op_name", "op_func"],
attrs=[
["MaxPool2d", nn.MaxPool2d],
["AvgPool2d", nn.AvgPool2d],
@@ -82,19 +80,17 @@
class Pool2dBenchmark(op_bench.TorchBenchmarkBase):
- def init(self, kernel, stride, N, C, H, W):
+ def init(self, kernel, stride, N, C, H, W, op_func):
self.input = torch.rand(N, C, H, W)
self.kernel = kernel
self.stride = stride
-
- def set_op(self, op):
- self.op = op(self.kernel, stride=self.stride)
+ self.op_func = op_func(self.kernel, stride=self.stride)
def forward(self):
- return self.op(self.input)
+ return self.op_func(self.input)
-op_bench.generate_pt_tests_from_list(pool_2d_ops_list, pool_2d_configs, Pool2dBenchmark)
+op_bench.generate_pt_tests_from_op_list(pool_2d_ops_list, pool_2d_configs, Pool2dBenchmark)
"""
@@ -117,7 +113,7 @@
pool_3d_ops_list = op_bench.op_list(
- attr_names=["op_name", "op"],
+ attr_names=["op_name", "op_func"],
attrs=[
["MaxPool3d", nn.MaxPool3d],
["AvgPool3d", nn.AvgPool3d],
@@ -126,19 +122,17 @@
class Pool3dBenchmark(op_bench.TorchBenchmarkBase):
- def init(self, kernel, stride, N, C, D, H, W):
+ def init(self, kernel, stride, N, C, D, H, W, op_func):
self.input = torch.rand(N, C, D, H, W)
self.kernel = kernel
self.stride = stride
-
- def set_op(self, op):
- self.op = op(self.kernel, stride=self.stride)
+ self.op_func = op_func(self.kernel, stride=self.stride)
def forward(self):
- return self.op(self.input)
+ return self.op_func(self.input)
-op_bench.generate_pt_tests_from_list(pool_3d_ops_list, pool_3d_configs, Pool3dBenchmark)
+op_bench.generate_pt_tests_from_op_list(pool_3d_ops_list, pool_3d_configs, Pool3dBenchmark)
if __name__ == "__main__":
diff --git a/benchmarks/operator_benchmark/pt/softmax_test.py b/benchmarks/operator_benchmark/pt/softmax_test.py
index 9fc376c..7492929 100644
--- a/benchmarks/operator_benchmark/pt/softmax_test.py
+++ b/benchmarks/operator_benchmark/pt/softmax_test.py
@@ -49,7 +49,7 @@
return self.op_func(self.input_one)
-op_bench.generate_pt_tests_from_list(softmax_ops_list, softmax_configs, SoftmaxBenchmark)
+op_bench.generate_pt_tests_from_op_list(softmax_ops_list, softmax_configs, SoftmaxBenchmark)
if __name__ == "__main__":
diff --git a/benchmarks/operator_benchmark/pt/unary_test.py b/benchmarks/operator_benchmark/pt/unary_test.py
index e9a0289..d74c01b 100644
--- a/benchmarks/operator_benchmark/pt/unary_test.py
+++ b/benchmarks/operator_benchmark/pt/unary_test.py
@@ -24,16 +24,16 @@
class UnaryOpBenchmark(op_bench.TorchBenchmarkBase):
- def init(self, M, N, op_function):
+ def init(self, M, N, op_func):
self.input_one = torch.rand(M, N)
- self.op_function = op_function
+ self.op_func = op_func
def forward(self):
- return self.op_function(self.input_one)
+ return self.op_func(self.input_one)
unary_ops_list = op_bench.op_list(
- attr_names=["op_name", "op_function"],
+ attr_names=["op_name", "op_func"],
attrs=[
["abs", torch.abs],
["abs_", torch.abs_],
@@ -119,7 +119,7 @@
)
-op_bench.generate_pt_tests_from_list(unary_ops_list, unary_ops_configs, UnaryOpBenchmark)
+op_bench.generate_pt_tests_from_op_list(unary_ops_list, unary_ops_configs, UnaryOpBenchmark)
if __name__ == "__main__":