| from . import CWrapPlugin |
| from string import Template |
| |
| # Arguments to the Broadcast Plugin: |
| # broadcast: args_to_broadcast_against [inplace] [fallback] [dims:...] |
| # [args_to_broadcast_against]: either a single argument (e.g. "arg1") or a comma-separated |
| # list of two arguments (e.g. "tensor1,tensor2") indicating the |
| # arguments to broadcast the specified argument (usually "self") against |
| # [inplace] will generate code for in-place function, which doesn't allow the in-place |
| # argument to be broadcast |
| # [fallback] if tensors aren't broadcastable, preserves "element number" pointwise behavior, |
| # where only number of elements need to match, and tensors are viewed as 1-dimensional. |
| # [dims] if the tensors shouldn't be broadcast to specific tensor or tensors, but a combination |
| # of their individual dimensions. Each dimension is specified as [arg].dim[#] and dimensions |
| # are comma-separated. So, to specify that the tensor should be broadcast to 3-dimensions with |
| # sizes: tensor0->size[0] x tensor1->size[1] x tensor2->size[2], you would write: |
| # dims:tensor0.dim0,tensor1.dim1,tensor2.dim2 |
| |
| # For out of place: |
| # Two args: expand the two args together |
| # Three args (fused kernels): (e.g. addcmul) expand all three args together |
| # Sketch of proof that this is the same: |
| # consider addcmul, under expansion we want: a + (b * c) = (a + b * c) [all expanded together] |
| # Let e(i, j) be the expansion of i with j, e(i, j, k) be the expansion of i with j,k |
| # |
| # Then a + (b * c) = e(a, e(b,c) * e(c,b)) + e(e(b,c) * e(c,b), a) |
| # = e(a, e(b,c)) + e(e(b,c) * e(c,b), a) (only size matters for second param) |
| # = e(a,b,c) + e(e(b,c) * e(c,b), a) (by associativity of max in expand) |
| # = e(a,b,c) + e(b,c,a) * e(c,b,a) (see L1) |
| # which is a + b * c all expanded together |
| # |
| # L1: Show e(i * j, a) = e(i,a) * e(j,a) where i,j have same size |
| # Consider any index _{ s_0, ..., s_n} |
| # e(i * j, a) = (i*j)_{f(s_0), ...,f(s_n)} where f is the expansion of that dimension with a |
| # = i_{f(s_0), ..., f(s_n)} * j_{f(s_0), ..., f(s_n)} by definition of pointwise operator |
| # = e(i,a) * e(j,a) |
| |
class Broadcast(CWrapPlugin):
    """cwrap plugin that surrounds an option's code template with broadcast code.

    For every argument of an option that carries a ``broadcast`` specification,
    C code is generated that (1) saves the original tensor pointers and
    redirects the arguments to freshly allocated guard tensors, (2) expands
    the tensors, and (3) restores the saved pointers after the wrapped code.
    All C code is produced through ``string.Template`` substitution.
    """

    # C snippet emitted when two operands are not broadcastable but the
    # deprecated "same number of elements" fallback is used instead.
    DEPRECATED_WARNING = \
        """PyErr_WarnEx(PyExc_UserWarning, "${op_a} and ${op_other} not broadcastable, but have the same number of "
        "elements. Falling back to deprecated pointwise behavior.", 1);"""

    # Restores an argument to the pointer saved by the pre-arg template, in
    # case later plugins use the original (non-expanded) argument.
    POST_TEMPLATE = Template(
        """${arg_op_other} = ${arg_op_other}_save;\n""")

    def getPreArgStringTemplate(self, includeElementCount=True):
        """Return the Template that saves and redirects one tensor argument.

        The generated C code stores ``${arg_op_other}`` in
        ``${arg_op_other}_save``, allocates a new tensor held by
        ``${arg_op_other}_guard`` and points ``${arg_op_other}`` at it.

        Args:
            includeElementCount: when true, also declare
                ``${arg_op_other}_nElem`` holding the element count of the
                saved (original) tensor.

        Returns:
            A ``string.Template`` with the single placeholder ``arg_op_other``.
        """
        ret = """THTensor *${arg_op_other}_save = ${arg_op_other};
                 THTensorPtr ${arg_op_other}_guard = THTensor_(new)(LIBRARY_STATE_NOARGS);
                 ${arg_op_other}=${arg_op_other}_guard.get();"""
        if includeElementCount:
            ret += "ptrdiff_t ${arg_op_other}_nElem = THTensor_(nElement)(LIBRARY_STATE ${arg_op_other}_save);"
        return Template(ret)

    # Out-of-place, two operands: expand both together.  The fallback branch
    # restores the saved pointers (post_code) and emits the deprecation warning.
    OUT_PLACE_PRE_EXPAND2_TEMPLATE = Template(
        """bool ${arg_op_other}_raise = ${raise_errors} || (${arg_op_a}_nElem != ${arg_op_other}_nElem);
           int ${arg_op_other}_err =
             THTensor_(expand2)(LIBRARY_STATE ${arg_op_a}, ${arg_op_other}, ${arg_op_a}_save, ${arg_op_other}_save, ${arg_op_other}_raise);
           if (${arg_op_other}_err != 0 && !${arg_op_other}_raise) {
             ${post_code}"""
        + DEPRECATED_WARNING + "\n" +
        """}""")

    # Three-operand variant of DEPRECATED_WARNING.
    DEPRECATED_WARNING3 = \
        """PyErr_WarnEx(PyExc_UserWarning, "${op_a}, ${op_other1}, and ${op_other2} not broadcastable, but have the same number of "
        "elements. Falling back to deprecated pointwise behavior.", 1);"""

    # Out-of-place, three operands (fused kernels): expand all three together.
    OUT_PLACE_PRE_EXPAND3_TEMPLATE = Template(
        """bool ${arg_op_other1}_raise = ${raise_errors} || (${arg_op_a}_nElem != ${arg_op_other1}_nElem);
           bool ${arg_op_other2}_raise = ${raise_errors} || (${arg_op_a}_nElem != ${arg_op_other2}_nElem);
           int ${arg_op_a}_err =
             THTensor_(expand3)(LIBRARY_STATE ${arg_op_a}, ${arg_op_other1}, ${arg_op_other2},
                                ${arg_op_a}_save, ${arg_op_other1}_save, ${arg_op_other2}_save,
                                ${arg_op_other1}_raise || ${arg_op_other2}_raise);
           if (${arg_op_a}_err != 0 && !${arg_op_other1}_raise && !${arg_op_other2}_raise) {
             ${post_code}"""
        + DEPRECATED_WARNING3 + "\n" +
        """}""")

    # "dims:" mode: check that the source tensor has the requested dimension
    # and read its size into ${arg_op_a}_dim${idx}_size.
    OUT_PLACE_EXPAND_DIM_SINGLE_TEMPLATE = Template(
        """if(THTensor_(nDimension)(LIBRARY_STATE ${arg_op_dim}) <= ${arg_op_dim_value}) {
             THError("Argument %s requires at least %d dimensions, but only has %d",
                     "${op_dim}", ${arg_op_dim_value} + 1, THTensor_(nDimension)(LIBRARY_STATE ${arg_op_dim}));
           }
           long ${arg_op_a}_dim${idx}_size = THTensor_(size)(LIBRARY_STATE ${arg_op_dim}, ${arg_op_dim_value});
        """)

    # "dims:" mode: expand ${arg_op_a} to the 1, 2 or 3 sizes collected above.
    OUT_PLACE_PRE_EXPAND1_DIM_TEMPLATE = Template(
        """THLongStoragePtr ${arg_op_a}_storage = THLongStorage_newWithSize1(${arg_op_a}_dim0_size);
           THTensor_(expand)(LIBRARY_STATE ${arg_op_a}, ${arg_op_a}_save, ${arg_op_a}_storage, true);
        """)

    OUT_PLACE_PRE_EXPAND2_DIM_TEMPLATE = Template(
        """THLongStoragePtr ${arg_op_a}_storage = THLongStorage_newWithSize2(${arg_op_a}_dim0_size, ${arg_op_a}_dim1_size);
           THTensor_(expand)(LIBRARY_STATE ${arg_op_a}, ${arg_op_a}_save, ${arg_op_a}_storage, true);
        """)

    OUT_PLACE_PRE_EXPAND3_DIM_TEMPLATE = Template(
        """THLongStoragePtr ${arg_op_a}_storage = THLongStorage_newWithSize3(${arg_op_a}_dim0_size, ${arg_op_a}_dim1_size, ${arg_op_a}_dim2_size);
           THTensor_(expand)(LIBRARY_STATE ${arg_op_a}, ${arg_op_a}_save, ${arg_op_a}_storage, true);
        """)

    # Skeleton of the out-of-place prologue.
    OUT_PLACE_PRE_TEMPLATE = Template(
        """${code_arg_op_a}
           ${code_arg_op_other1}
           ${code_arg_op_other2}
           ${expand_code}
        """)

    # In-place, one other operand: only the other operand is expanded, to the
    # size of ${arg_op_a}.
    IN_PLACE_PRE_EXPAND1_TEMPLATE = Template(
        """bool ${arg_op_other}_raise = ${raise_errors} || (${arg_op_a}_nElem != ${arg_op_other}_nElem);
           int ${arg_op_other}_err =
             !skip_expand && THTensor_(expand)(LIBRARY_STATE ${arg_op_other}, ${arg_op_other}_save, ${arg_op_a}_size.get(), ${arg_op_other}_raise);
           if (${arg_op_other}_err != 0 && !${arg_op_other}_raise) {
             skip_expand = true; // don't do further expansions
             ${post_code}"""
        + DEPRECATED_WARNING + "\n" +
        """}""")

    # In-place, two other operands.  The fallback branches require BOTH raise
    # flags to be false: expand is invoked with (other1_raise || other2_raise)
    # as its raise flag, so a non-zero error code can presumably only be
    # observed when neither flag is set (same condition as the out-of-place
    # three-operand template).
    IN_PLACE_PRE_EXPAND2_TEMPLATE = Template(
        """bool ${arg_op_other1}_raise = ${raise_errors} || (${arg_op_a}_nElem != ${arg_op_other1}_nElem);
           bool ${arg_op_other2}_raise = ${raise_errors} || (${arg_op_a}_nElem != ${arg_op_other2}_nElem);
           int ${arg_op_other1}_err =
             !skip_expand && THTensor_(expand)(LIBRARY_STATE ${arg_op_other1}, ${arg_op_other1}_save, ${arg_op_a}_size.get(), ${arg_op_other1}_raise || ${arg_op_other2}_raise);
           if (${arg_op_other1}_err != 0 && !${arg_op_other1}_raise && !${arg_op_other2}_raise) {
             skip_expand = true; // don't do further expansions
             ${post_code}"""
        + DEPRECATED_WARNING3 + "\n" +
        """}
           int ${arg_op_other2}_err =
             !skip_expand && THTensor_(expand)(LIBRARY_STATE ${arg_op_other2}, ${arg_op_other2}_save, ${arg_op_a}_size.get(), ${arg_op_other1}_raise || ${arg_op_other2}_raise);
           if (${arg_op_other2}_err != 0 && !${arg_op_other1}_raise && !${arg_op_other2}_raise) {
             skip_expand = true; // don't do further expansions
             ${post_code}"""
        + DEPRECATED_WARNING3 + "\n" +
        """}""")

    # Skeleton of the in-place prologue.
    IN_PLACE_PRE_TEMPLATE = Template(
        """THLongStoragePtr ${arg_op_a}_size = THTensor_(newSizeOf)(LIBRARY_STATE ${arg_op_a});
           ptrdiff_t ${arg_op_a}_nElem = THTensor_(nElement)(LIBRARY_STATE ${arg_op_a});
           bool skip_expand = false;
           ${code_arg_op_other1}
           ${code_arg_op_other2}
           ${expand_code}
        """)

    def initialize(self, cwrap):
        """Remember the owning cwrap instance."""
        self.cwrap = cwrap

    @staticmethod
    def _parse_dims(params):
        """Collect the entries of the (at most one) ``dims:`` parameter.

        Args:
            params: the space-separated pieces of a ``broadcast`` spec.

        Returns:
            A list of dicts with keys ``op`` (argument name), ``arg_op``
            (``"arg_" + op``) and ``val`` (the text following ``dim``),
            one per comma-separated ``<arg>.dim<#>`` entry.

        Raises:
            ValueError: if ``dims:`` is given more than once, an entry is not
                of the form ``<arg>.dim<#>``, or more than 3 entries are given.
        """
        prefix = "dims:"
        dims_kvs = []
        for p in params:
            if not p.startswith(prefix):
                continue
            if dims_kvs:
                raise ValueError("multiple specifications of dims")
            for dim in p[len(prefix):].split(","):
                batchdim = dim.split(".")
                if len(batchdim) != 2 or not batchdim[1].startswith("dim"):
                    raise ValueError(
                        "invalid dims entry '{}', expected <arg>.dim<#>".format(dim))
                dims_kvs.append({
                    "op": batchdim[0],
                    "arg_op": "arg_" + batchdim[0],
                    "val": batchdim[1][len("dim"):],
                })
        if len(dims_kvs) > 3:
            raise ValueError("dims supports at most 3 dimensions")
        return dims_kvs

    # Arguments:
    # [0]: name of tensor to broadcast with (possibly two comma separated)
    # [1] inplace (optional).  In place operations only broadcast on second tensor argument
    # [2] fallback (optional).  Will fallback to applying to tensor of equal nElem if broadcast fails
    def process_option_code_template(self, template, option):
        """Wrap ``template`` with broadcast prologue/epilogue code.

        Args:
            template: list of code lines for the option.
            option: option dict; each entry of ``option['arguments']`` that
                has a ``broadcast`` key is processed.

        Returns:
            The list ``pre + template + post`` where ``pre``/``post`` hold the
            generated code for all broadcast arguments.

        Raises:
            ValueError: on more than two secondary tensors, an unknown
                parameter, or a malformed ``dims:`` specification.
        """
        new_code_pre = []
        new_code_post = []
        for arg in option['arguments']:
            if 'broadcast' not in arg:
                continue

            params = arg['broadcast'].split(" ")
            op_a = arg.get('assign_name', arg['name'])
            in_place = "inplace" in params
            # The value is pasted into C code, hence the C boolean spelling.
            raise_errors = "false" if "fallback" in params else "true"

            param_others = params[0].split(",")
            if len(param_others) > 2:
                raise ValueError('Broadcast only supports up to 2 secondary parameters')
            op_b = param_others[0]
            op_c = param_others[1] if len(param_others) == 2 else None
            arg_op_b = "arg_" + op_b
            arg_op_a = "arg_" + op_a
            arg_op_c = ("arg_" + op_c) if op_c else None

            dims_kvs = self._parse_dims(params)
            for p in params[1:]:
                if p not in ("inplace", "fallback") and not p.startswith("dims:"):
                    raise ValueError("invalid parameter {}".format(p))

            op_b_mapping = {
                "op_a": op_a,
                "op_other": op_b,
                "arg_op_a": arg_op_a,
                "arg_op_other": arg_op_b,
                "raise_errors": raise_errors,
            }
            # Only substituted when a second tensor (op_c) was given.
            op_c_mapping = {
                "op_a": op_a,
                "op_other": op_c,
                "arg_op_a": arg_op_a,
                "arg_op_other": arg_op_c,
                "raise_errors": raise_errors,
            }

            if in_place:
                code_arg_op_other1 = self.getPreArgStringTemplate().substitute(op_b_mapping)
                code_arg_op_other2 = (
                    self.getPreArgStringTemplate().substitute(op_c_mapping) if op_c else "")

                post_code = self.POST_TEMPLATE.substitute(op_b_mapping)
                if op_c:
                    post_code += self.POST_TEMPLATE.substitute(op_c_mapping)

                if op_c:
                    expand_code = self.IN_PLACE_PRE_EXPAND2_TEMPLATE.substitute(
                        op_b_mapping,
                        op_other1=op_b,
                        op_other2=op_c,
                        arg_op_other1=arg_op_b,
                        arg_op_other2=arg_op_c,
                        post_code=post_code)
                else:
                    expand_code = self.IN_PLACE_PRE_EXPAND1_TEMPLATE.substitute(
                        op_b_mapping, post_code=post_code)

                new_code_pre.append(self.IN_PLACE_PRE_TEMPLATE.substitute(
                    arg_op_a=arg_op_a,
                    code_arg_op_other1=code_arg_op_other1,
                    code_arg_op_other2=code_arg_op_other2,
                    expand_code=expand_code))
                new_code_pre.append("")

                new_code_post.append(post_code)
                new_code_post.append("")
            else:
                if dims_kvs:
                    # "dims:" mode: only op_a is saved/expanded; no element
                    # count is needed because there is no fallback path.
                    code_arg_op_a = self.getPreArgStringTemplate(False).substitute(
                        arg_op_other=arg_op_a)
                    code_arg_op_other1 = ""
                    code_arg_op_other2 = ""
                    expand_code = "".join(
                        self.OUT_PLACE_EXPAND_DIM_SINGLE_TEMPLATE.substitute(
                            arg_op_a=arg_op_a,
                            op_dim=kv["op"],
                            arg_op_dim=kv["arg_op"],
                            arg_op_dim_value=kv["val"],
                            idx=idx)
                        for idx, kv in enumerate(dims_kvs))

                    # _parse_dims guarantees 1 <= len(dims_kvs) <= 3 here.
                    dim_expand_templates = {
                        1: self.OUT_PLACE_PRE_EXPAND1_DIM_TEMPLATE,
                        2: self.OUT_PLACE_PRE_EXPAND2_DIM_TEMPLATE,
                        3: self.OUT_PLACE_PRE_EXPAND3_DIM_TEMPLATE,
                    }
                    expand_code += dim_expand_templates[len(dims_kvs)].substitute(
                        arg_op_a=arg_op_a)
                    post_code = self.POST_TEMPLATE.substitute(arg_op_other=arg_op_a)
                else:
                    code_arg_op_a = self.getPreArgStringTemplate().substitute(
                        arg_op_other=arg_op_a)
                    code_arg_op_other1 = self.getPreArgStringTemplate().substitute(op_b_mapping)
                    code_arg_op_other2 = (
                        self.getPreArgStringTemplate().substitute(op_c_mapping) if op_c else "")

                    post_code = self.POST_TEMPLATE.substitute(arg_op_other=arg_op_a)
                    post_code += self.POST_TEMPLATE.substitute(op_b_mapping)
                    if op_c:
                        post_code += self.POST_TEMPLATE.substitute(op_c_mapping)

                    if op_c:
                        expand_code = self.OUT_PLACE_PRE_EXPAND3_TEMPLATE.substitute(
                            op_b_mapping,
                            op_other1=op_b,
                            op_other2=op_c,
                            arg_op_other1=arg_op_b,
                            arg_op_other2=arg_op_c,
                            post_code=post_code)
                    else:
                        expand_code = self.OUT_PLACE_PRE_EXPAND2_TEMPLATE.substitute(
                            op_b_mapping, post_code=post_code)

                new_code_pre.append(self.OUT_PLACE_PRE_TEMPLATE.substitute(
                    code_arg_op_a=code_arg_op_a,
                    code_arg_op_other1=code_arg_op_other1,
                    code_arg_op_other2=code_arg_op_other2,
                    expand_code=expand_code))
                new_code_pre.append("")

                new_code_post.append(post_code)
                new_code_post.append("")

        # Assemble once, after all arguments were processed, so that the
        # pre/post code of an earlier argument is never emitted twice.
        template = new_code_pre + template + new_code_post
        return template