[operator_benchmark] Added channels last 3d option to interpolate test (#53117)
Summary:
Description:
- Added channels last 3d option to interpolate test
- split config non-4d into two : 3d and 5d
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53117
Reviewed By: NicolasHug
Differential Revision: D26754243
Pulled By: fmassa
fbshipit-source-id: 49bbab3bb47de27790e39537d0fbeca0f01782c4
diff --git a/benchmarks/operator_benchmark/pt/interpolate_test.py b/benchmarks/operator_benchmark/pt/interpolate_test.py
index 694e5cd..b33de76 100644
--- a/benchmarks/operator_benchmark/pt/interpolate_test.py
+++ b/benchmarks/operator_benchmark/pt/interpolate_test.py
@@ -10,7 +10,14 @@
input_image = torch.randint(0, 256, size=input_size, dtype=torch.float, device='cpu',
requires_grad=self.auto_set())
if channels_last:
- input_image = input_image.contiguous(memory_format=torch.channels_last)
+ if input_image.ndim == 4:
+ input_image = input_image.contiguous(memory_format=torch.channels_last)
+ elif input_image.ndim == 5:
+ input_image = input_image.contiguous(memory_format=torch.channels_last_3d)
+ else:
+ raise ValueError(
+ f"Can not set channels_last to the input of {input_image.ndim} dims"
+ )
ndim_to_mode = {
3: 'linear',
@@ -61,13 +68,10 @@
)
-config_not_4d = op_bench.config_list(
- # no channels_last as it's only valid for 4D tensors
+config_3d = op_bench.config_list(
+ # no channels_last for 3D tensors
attr_names=["input_size", "output_size"],
attrs=[
- [(1, 3, 16, 320, 320), (8, 256, 256)],
- [(1, 3, 16, 320, 320), (32, 512, 512)],
-
[(4, 512, 320), (256,)],
[(4, 512, 320), (512,)],
],
@@ -75,7 +79,20 @@
)
-for config in (config_short, config_long, config_not_4d):
+config_5d = op_bench.config_list(
+ attr_names=["input_size", "output_size"],
+ attrs=[
+ [(1, 3, 16, 320, 320), (8, 256, 256)],
+ [(1, 3, 16, 320, 320), (32, 512, 512)],
+ ],
+ cross_product_configs={
+ 'channels_last': [True, False],
+ },
+ tags=["long"],
+)
+
+
+for config in (config_short, config_long, config_3d, config_5d):
op_bench.generate_pt_test(config, InterpolateBenchmark)