import itertools
import operator

import numpy as np
import scipy.special

import torch

from . import benchmark


# A template class for elementwise operations.
# A derived class overrides the class variables below to customize its behavior.
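#
# For illustration only (a hypothetical sketch, not part of the registered
# benchmarks): a hand-written subclass would customize those variables like so:
#
#   class ElementMulBench(ElementBench):
#       op_str = "mul"
#       binary_op_pt_func = operator.mul
#       binary_op_np_func = operator.mul
#
# In practice, register_element_ops() below generates such subclasses
# dynamically with type().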
class ElementBench(benchmark.Benchmark):
    # Customization points, overridden by the dynamically generated subclasses.
    op_str = None
    binary_op_pt_func = None
    binary_op_np_func = None
    unary_op_pt_func = None
    unary_op_np_func = None
    split_input = True

    def __init__(self, mode, device, dtype, N):
        super().__init__(mode, device, dtype)
        self.N = N
        # Four random 1-D input tensors of length N.
        self.d1 = self.rand(
            [N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d2 = self.rand(
            [N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d3 = self.rand(
            [N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d4 = self.rand(
            [N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [self.d1, self.d2, self.d3, self.d4]
        # rand-like ops produce random output, so their results cannot be
        # checked against a reference.
        self.deterministic = "rand" not in self.op_str

    def _eval(self, d1, d2, d3, d4, binary_op, unary_op):
        # Fall back to addition / identity when an op is not customized, so
        # the same expression works for unary-only and binary-only benchmarks.
        if not binary_op:

            def binary_op(x, y):
                return x + y

        if not unary_op:

            def unary_op(x):
                return x

        if self.split_input:
            d1 = unary_op(d1)
            d2 = unary_op(d2)
            d3 = unary_op(d3)
            d4 = unary_op(d4)
        else:
            # Shared input: all four operands derive from d1.
            d2 = unary_op(d1 + 0.001)
            d3 = unary_op(d1 + 0.002)
            d4 = unary_op(d1 + 0.003)
            d1 = unary_op(d1)
        a = binary_op(d1, d2)
        b = binary_op(d3, d4)
        c = a + b
        return c

    def forward(self, d1, d2, d3, d4):
        binary_op = self.__class__.binary_op_pt_func
        unary_op = self.__class__.unary_op_pt_func
        return self._eval(d1, d2, d3, d4, binary_op, unary_op)

    def reference(self):
        binary_op = self.__class__.binary_op_np_func
        unary_op = self.__class__.unary_op_np_func
        [d1, d2, d3, d4] = [self.numpy(d) for d in [self.d1, self.d2, self.d3, self.d4]]
        return self._eval(d1, d2, d3, d4, binary_op, unary_op)

    def config(self):
        return [self.N]

    @classmethod
    def module(cls):
        return "element_" + cls.op_str

    def memory_workload(self):
        # "sol" counts the minimum (speed-of-light) number of N-element
        # buffers that must be read or written; "algorithmic" counts the
        # buffers a straightforward implementation would touch.
        input_count = len(self.inputs)
        if self.mode == "fwd":
            if self.split_input:
                # Read each input once, write one output (e.g. 4 + 1 = 5).
                sol_count = input_count + 1
                algorithmic_count = input_count + 1
            else:
                # Shared input: one read plus one write.
                sol_count = 1 + 1
                algorithmic_count = 1 + 1
            if "rand" in self.op_str:
                # rand-like ops read no input; they only write the output.
                sol_count = 1
                algorithmic_count = 1
        else:
            if self.split_input:
                sol_count = (input_count + 1) + (1 + input_count)
                algorithmic_count = (input_count + 1) + ((2 + 1) * input_count)
            else:
                sol_count = 1 + 1
                algorithmic_count = 1 + 1
            if "rand" in self.op_str:
                sol_count = 1
                algorithmic_count = 1

        buffer_size = self.N
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        # 2**25 = ~33M elements per input.
        return [[1 << 25]]


def register_element_ops():
    binary_op_list = [
        ["mul", operator.mul],
        ["add", operator.add],
        ["sub", operator.sub],
        # Offset the divisor to avoid division by zero.
        ["div", lambda a, b: a / (b + 1e-4)],
        [
            "pow",
            torch.pow,
            np.power,
        ],  # no fusion triggered
        ["max", torch.max, np.maximum],
        ["min", torch.min, np.minimum],
    ]

    unary_op_list = [
        ["erf", torch.erf, scipy.special.erf],
        ["exp", torch.exp, np.exp],
        ["sin", torch.sin, np.sin],
        ["cos", torch.cos, np.cos],
        ["rand_like", torch.rand_like, lambda x: np.random.rand(*x.shape)],
    ]

    for split_input, binary_op in itertools.product([True, False], binary_op_list):
        # Entries are [name, pt_func] (NumPy func same as the PyTorch one) or
        # [name, pt_func, np_func].
        if len(binary_op) == 2:
            [op_str, op_pt_func] = binary_op
            op_np_func = op_pt_func
        elif len(binary_op) == 3:
            [op_str, op_pt_func, op_np_func] = binary_op
        split_str = "split" if split_input else "shared"
        op_str = split_str + "_" + op_str
        # Make a specialized copy of ElementBench for this op.
        bm_cls = type("ElementBench_" + op_str, (ElementBench,), {})
        bm_cls.op_str = op_str
        bm_cls.binary_op_pt_func = op_pt_func
        bm_cls.binary_op_np_func = op_np_func
        bm_cls.split_input = split_input
        benchmark.register_benchmark_class(bm_cls)
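    # The loop above generates and registers classes named e.g.
    # ElementBench_split_mul and ElementBench_shared_mul, one per
    # (split/shared, op) combination.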

    for split_input, unary_op in itertools.product([True, False], unary_op_list):
        if len(unary_op) == 2:
            [op_str, op_pt_func] = unary_op
            op_np_func = op_pt_func
        elif len(unary_op) == 3:
            [op_str, op_pt_func, op_np_func] = unary_op
        split_str = "split" if split_input else "shared"
        op_str = split_str + "_" + op_str
        # Make a specialized copy of ElementBench for this op.
        bm_cls = type("ElementBench_" + op_str, (ElementBench,), {})
        bm_cls.op_str = op_str
        bm_cls.unary_op_pt_func = op_pt_func
        bm_cls.unary_op_np_func = op_np_func
        bm_cls.split_input = split_input
        benchmark.register_benchmark_class(bm_cls)


register_element_ops()
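
# A quick way to inspect what got registered (a sketch only: this assumes the
# benchmark module exposes its registry as a list such as
# benchmark.benchmark_classes, which may not match the actual attribute name):
#
#   print(sorted({cls.module() for cls in benchmark.benchmark_classes}))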


class SimpleElementBench(benchmark.Benchmark):
    def __init__(self, mode, device, dtype, N):
        super().__init__(mode, device, dtype)
        self.N = N
        self.data = self.rand(
            [N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [self.data]

    def forward(self, data):
        a = data + 0.001
        b = a + 0.002
        return b

    def reference(self):
        # Mirror forward() on a NumPy copy of the single input.
        data = self.numpy(self.data)
        a = data + 0.001
        b = a + 0.002
        return b

    def config(self):
        return [self.N]

    @staticmethod
    def input_iterable():
        return True

    @classmethod
    def module(cls):
        return "simple_element"

    def memory_workload(self):
        # Two N-element buffers of traffic (one read, one write), identical
        # for fwd and bwd.
        sol_count = 2
        algorithmic_count = 2

        buffer_size = self.N
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        return [[1 << 25]]


benchmark.register_benchmark_class(SimpleElementBench)


class DynamicSimpleElementBench(benchmark.DynamicShape, SimpleElementBench):
    def __init__(self, mode, device, dtype, N):
        benchmark.DynamicShape.__init__(self)
        SimpleElementBench.__init__(self, mode, device, dtype, N)

    @classmethod
    def module(cls):
        return "simple_dynamic_element"

    def instantiate_input(self):
        # Randomize the input shape via the DynamicShape mixin, then
        # regenerate the input tensor at the new size.
        (N,) = self.rand_shape([self.N])
        data = self.rand(
            [N], device=self.device, dtype=self.dtype, requires_grad=self.requires_grad
        )
        self.inputs = [data]


benchmark.register_benchmark_class(DynamicSimpleElementBench)