from . import CWrapPlugin
from string import Template
# Arguments to the Broadcast Plugin:
# broadcast: args_to_broadcast_against [inplace] [fallback] [dims:...] [types:...]
# [args_to_broadcast_against]: either a single argument (e.g. "arg1") or a comma-separated
#                              list of two arguments (e.g. "tensor1,tensor2") indicating
#                              arguments to broadcast the specified argument (usually "self") against
# [inplace] will generate code for an in-place function, which doesn't allow the in-place
#           argument to be broadcast
# [fallback] if the tensors aren't broadcastable, preserves the "element number" pointwise behavior,
#            where only the number of elements needs to match, and the tensors are viewed as 1-dimensional
# [dims] use if the tensor should be broadcast not against a specific tensor or tensors, but to a
#        combination of individual dimension sizes of a set of tensors.  For example: addbmm(C,A,B),
#        a.k.a. [C + A @ B], broadcasts C to the first dimension of A and the second dimension of B.
#        Each dimension is specified as [arg].dim[#] and dimensions are comma-separated.  So, to
#        specify that the tensor should be broadcast to 3 dimensions with sizes:
#        tensor0->size[0] x tensor1->size[1] x tensor2->size[2]
#        you would write:
#        dims:tensor0.dim0,tensor1.dim1,tensor2.dim2
# [types] if the tensors should be of different types than THTensor, specify as X where
#         the actual type to use is THXTensor (i.e. Byte for THByteTensor).  If the type
#         should be THTensor, use 'Real'
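#
# For illustration, a declaration using these options might look like the following
# sketch (hypothetical entry; the exact Declarations.cwrap syntax and argument names
# may differ):
#
#   [[
#     name: addbmm
#     arguments:
#       - arg: THTensor* self
#         broadcast: batch1,batch2 dims:batch1.dim0,batch2.dim2
#       - THTensor* batch1
#       - THTensor* batch2
#   ]]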
# For out of place:
# Two args: expand the two args together
# Three args (fused kernels, e.g. addcmul): expand all three args together
# Sketch of proof that this is the same:
# Consider addcmul: under expansion we want a + (b * c) = (a + b * c) [all expanded together].
# Let e(i, j) be the expansion of i with j, and e(i, j, k) the expansion of i with j, k.
#
# Then a + (b * c) = e(a, e(b,c) * e(c,b)) + e(e(b,c) * e(c,b), a)
#                  = e(a, e(b,c)) + e(e(b,c) * e(c,b), a)    (only the size of the second param matters,
#                                                             and e(b,c) * e(c,b) has the size of e(b,c))
#                  = e(a,b,c) + e(e(b,c) * e(c,b), a)        (by associativity of max in expand)
#                  = e(a,b,c) + e(b,c,a) * e(c,b,a)          (by Lemma 1 below)
# which is a + b * c all expanded together.
#
# Lemma 1: e(i * j, a) = e(i,a) * e(j,a), where i and j have the same size.
# Consider any index (s_0, ..., s_n) of the expanded result:
#   e(i * j, a)_{s_0, ..., s_n} = (i*j)_{f(s_0), ..., f(s_n)}    where f maps each expanded index back to
#                                                                the corresponding source index
#                               = i_{f(s_0), ..., f(s_n)} * j_{f(s_0), ..., f(s_n)}    by definition of a pointwise operator
#                               = (e(i,a) * e(j,a))_{s_0, ..., s_n}
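#
# A quick numerical check of the claim (a sketch, assuming a torch build with
# broadcasting; kept as a comment so importing this module stays side-effect free):
#
#   import torch
#   a, b, c = torch.randn(3, 1), torch.randn(1, 4), torch.randn(3, 4)
#   joint = a + b * c                                             # all broadcast together
#   pairwise = a.expand(3, 4) + b.expand(3, 4) * c.expand(3, 4)   # pairwise expansion
#   assert torch.equal(joint, pairwise)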
class Broadcast(CWrapPlugin):
    # Save and restore the passed-in arguments in case later plugins use them.
POST_TEMPLATE = Template(
"""${arg_op_other} = ${arg_op_other}_save;\n""")
def getPreArgStringTemplate(self, type=None):
if type is None:
ret = """THTensor *${arg_op_other}_save = ${arg_op_other};
THTensorPtr ${arg_op_other}_guard(nullptr);\n"""
else:
cpu_t = "TH" + type + "Tensor"
gpu_t = "THCuda" + type + "Tensor"
ret = ("#if !IS_CUDA\n" +
cpu_t + " *${arg_op_other}_save = ${arg_op_other};\n" +
cpu_t + "Ptr ${arg_op_other}_guard(nullptr);\n" +
"#else\n" +
gpu_t + " *${arg_op_other}_save = ${arg_op_other};\n" +
"THPPointer<" + gpu_t + "> ${arg_op_other}_guard(nullptr);\n" +
"#endif\n")
return Template(ret)
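    # For example, getPreArgStringTemplate(type="Byte") renders (with ${arg_op_other}
    # substituted as, say, "arg_mask" -- a hypothetical argument name):
    #
    #   #if !IS_CUDA
    #   THByteTensor *arg_mask_save = arg_mask;
    #   THByteTensorPtr arg_mask_guard(nullptr);
    #   #else
    #   THCudaByteTensor *arg_mask_save = arg_mask;
    #   THPPointer<THCudaByteTensor> arg_mask_guard(nullptr);
    #   #endif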
def getNewForExpand(self, type):
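        # Returns a plain string rather than a string.Template: callers splice the
        # result into larger template strings.  For type=None this renders as:
        #   THTensor_(new)(LIBRARY_STATE_NOARGS);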
if type is None:
ret = """THTensor_(new)(LIBRARY_STATE_NOARGS);\n"""
else:
cpu_t = "TH" + type + "Tensor"
gpu_t = "THCuda" + type + "Tensor"
ret = ("#if !IS_CUDA\n" +
cpu_t + "_new(LIBRARY_STATE_NOARGS);\n" +
"#else\n" +
gpu_t + "_new(LIBRARY_STATE_NOARGS);\n" +
"#endif\n")
return ret
def getExpandTemplate(self, same_size_check, expand_call, success_code, raise_errors):
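        # With raise_errors=False the generated code has roughly this shape (a sketch;
        # the placeholders stand for the substituted check/expand/success fragments):
        #
        #   bool try_expand = !<same_size_check>;
        #   if (try_expand) {
        #     bool expand_success = false;
        #     try {
        #       <expand_call>
        #       expand_success = true;
        #     }
        #     catch (std::exception &e) {}
        #     if (expand_success) {
        #       <success_code>
        #     }
        #   }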
if not raise_errors:
return Template(
"bool try_expand = !" + same_size_check + "\n" +
"if (try_expand) {\n" +
"bool expand_success = false;\n" +
"try {\n" +
expand_call +
"\nexpand_success = true;\n" +
"}\n"
"catch (std::exception &e) {}\n" +
"if(expand_success) {\n" +
success_code +
"\n}" +
"\n}\n")
else:
return Template(
"bool try_expand = !" + same_size_check + "\n" +
"if (try_expand) {\n" +
expand_call + "\n" +
success_code + "\n"
"}\n")
def getOutPlacePreExpand2Template(self, type_op_a, type_op_other, raise_errors):
size_check = """THSize_isSameSizeAs(${arg_op_a}->size, ${arg_op_a}->nDimension,
${arg_op_other}->size, ${arg_op_other}->nDimension);"""
expand_code = ("${arg_op_a}_guard = \n" + self.getNewForExpand(type_op_a) + "\n" +
"${arg_op_other}_guard = \n" + self.getNewForExpand(type_op_other) + "\n" +
"""expand_outplace2(LIBRARY_STATE ${arg_op_a}_guard.get(), ${arg_op_other}_guard.get(),
${arg_op_a}, ${arg_op_other},
\"${op_a}\", \"${op_other}\", !${raise_errors});""")
success_code = """${arg_op_a} = ${arg_op_a}_guard.get();
${arg_op_other} = ${arg_op_other}_guard.get();"""
return self.getExpandTemplate(size_check, expand_code, success_code, raise_errors)
def getOutPlacePreExpand3Template(self, type_op_a, type_op_other1, type_op_other2, raise_errors):
size_check = """(THSize_isSameSizeAs(${arg_op_a}->size, ${arg_op_a}->nDimension,
${arg_op_other1}->size, ${arg_op_other1}->nDimension) &&
THSize_isSameSizeAs(${arg_op_a}->size, ${arg_op_a}->nDimension,
${arg_op_other2}->size, ${arg_op_other2}->nDimension));"""
expand_code = ("${arg_op_a}_guard = \n" + self.getNewForExpand(type_op_a) + "\n" +
"${arg_op_other1}_guard = \n" + self.getNewForExpand(type_op_other1) + "\n" +
"${arg_op_other2}_guard = \n" + self.getNewForExpand(type_op_other2) + "\n" +
"""expand_outplace3(LIBRARY_STATE ${arg_op_a}_guard.get(),
${arg_op_other1}_guard.get(), ${arg_op_other2}_guard.get(),
${arg_op_a}, ${arg_op_other1}, ${arg_op_other2},
\"${op_a}\", \"${op_other1}\", \"${op_other2}\", !${raise_errors});""")
success_code = """${arg_op_a} = ${arg_op_a}_guard.get();
${arg_op_other1} = ${arg_op_other1}_guard.get();
${arg_op_other2} = ${arg_op_other2}_guard.get();"""
return self.getExpandTemplate(size_check, expand_code, success_code, raise_errors)
OUT_PLACE_PRE_EXPAND_PRE_DIM_TEMPLATE = Template(
"""if(THTensor_(nDimension)(LIBRARY_STATE ${arg_op_dim}) <= ${arg_op_dim_value}) {
THError("Argument %s requires at least %d dimensions, but only has %d",
"${op_dim}", ${arg_op_dim_value} + 1, THTensor_(nDimension)(LIBRARY_STATE ${arg_op_dim}));
}
int64_t ${arg_op_a}_dim${idx}_size = THTensor_(size)(LIBRARY_STATE ${arg_op_dim}, ${arg_op_dim_value});\n""")
OUT_PLACE_PRE_EXPAND1_DIM_TEMPLATE = Template(
"""THLongStoragePtr ${arg_op_a}_storage(THLongStorage_newWithSize1(${arg_op_a}_dim0_size));\n""")
OUT_PLACE_PRE_EXPAND2_DIM_TEMPLATE = Template(
"""THLongStoragePtr ${arg_op_a}_storage(
THLongStorage_newWithSize2(${arg_op_a}_dim0_size, ${arg_op_a}_dim1_size));\n""")
OUT_PLACE_PRE_EXPAND3_DIM_TEMPLATE = Template(
"""THLongStoragePtr ${arg_op_a}_storage(
THLongStorage_newWithSize3(${arg_op_a}_dim0_size, ${arg_op_a}_dim1_size, ${arg_op_a}_dim2_size));\n""")
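    # For example, a spec like "dims:batch1.dim0,batch2.dim2" (as addbmm might use;
    # argument names hypothetical) renders one PRE_DIM dimension check plus size
    # lookup per entry:
    #
    #   int64_t arg_self_dim0_size = THTensor_(size)(LIBRARY_STATE arg_batch1, 0);
    #   int64_t arg_self_dim1_size = THTensor_(size)(LIBRARY_STATE arg_batch2, 2);
    #
    # followed by the matching EXPAND#_DIM template, here:
    #
    #   THLongStoragePtr arg_self_storage(
    #       THLongStorage_newWithSize2(arg_self_dim0_size, arg_self_dim1_size));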
def getOutPlacePreExpandPostDimTemplate(self, type_op_a, raise_errors):
size_check = """THSize_isSameSizeAs(${arg_op_a}->size, ${arg_op_a}->nDimension,
${arg_op_a}_storage->data, ${arg_op_a}_storage->size);"""
expand_code = ("${arg_op_a}_guard = \n" + self.getNewForExpand(type_op_a) + "\n" +
"""expand(LIBRARY_STATE ${arg_op_a}_guard.get(), ${arg_op_a}, ${arg_op_a}_storage);""")
success_code = """${arg_op_a} = ${arg_op_a}_guard.get();"""
return self.getExpandTemplate(size_check, expand_code, success_code, raise_errors)
OUT_PLACE_PRE_TEMPLATE = Template(
"""${code_arg_op_a}${code_arg_op_other1}${code_arg_op_other2}
${expand_code}""")
def getInPlacePreExpand1Template(self, type_op_other, raise_errors):
size_check = """THSize_isSameSizeAs(${arg_op_a}->size, ${arg_op_a}->nDimension,
${arg_op_other}->size, ${arg_op_other}->nDimension);"""
expand_code = ("${arg_op_other}_guard = \n" + self.getNewForExpand(type_op_other) + "\n" +
"""expand_inplace1(LIBRARY_STATE ${arg_op_other}_guard.get(), ${arg_op_other}, ${arg_op_a},
\"${op_other}\", \"${op_a}\", !${raise_errors});""")
success_code = """${arg_op_other} = ${arg_op_other}_guard.get();"""
return self.getExpandTemplate(size_check, expand_code, success_code, raise_errors)
def getInPlacePreExpand2Template(self, type_op_other1, type_op_other2, raise_errors):
size_check = """(THSize_isSameSizeAs(${arg_op_a}->size, ${arg_op_a}->nDimension,
${arg_op_other1}->size, ${arg_op_other1}->nDimension) &&
THSize_isSameSizeAs(${arg_op_a}->size, ${arg_op_a}->nDimension,
${arg_op_other2}->size, ${arg_op_other2}->nDimension));"""
expand_code = ("${arg_op_other1}_guard = \n" + self.getNewForExpand(type_op_other1) + "\n" +
"${arg_op_other2}_guard = \n" + self.getNewForExpand(type_op_other2) + "\n" +
"""expand_inplace2(LIBRARY_STATE ${arg_op_other1}_guard.get(), ${arg_op_other2}_guard.get(),
${arg_op_other1}, ${arg_op_other2}, ${arg_op_a},
\"${op_other1}\", \"${op_other2}\", \"${op_a}\", !${raise_errors});""")
success_code = """${arg_op_other1} = ${arg_op_other1}_guard.get();
${arg_op_other2} = ${arg_op_other2}_guard.get();"""
return self.getExpandTemplate(size_check, expand_code, success_code, raise_errors)
IN_PLACE_PRE_TEMPLATE = Template(
"""${code_arg_op_other1}${code_arg_op_other2}
${expand_code}""")
def initialize(self, cwrap):
self.cwrap = cwrap
    # Arguments (see the header comment at the top of this file for full details):
    # [0]: name of tensor to broadcast with (possibly two comma-separated names)
    # [1] inplace (optional).  In-place operations only broadcast the second tensor argument.
    # [2] fallback (optional).  Falls back to pointwise application over tensors with an
    #     equal number of elements if broadcasting fails.
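    # For example, the spec "batch1,batch2 fallback" (argument names hypothetical)
    # parses below to op_b = "batch1", op_c = "batch2", in_place = False, and
    # raise_errors = "false".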
def process_option_code_template(self, template, option):
new_code_pre = []
new_code_post = []
        for arg in option['arguments']:
if 'broadcast' not in arg:
continue
params = arg.get('broadcast').split(" ")
op_a = arg.get('assign_name', arg['name'])
in_place = "inplace" in params
raise_errors = "false" if "fallback" in params else "true"
param_others = params[0].split(",")
if len(param_others) > 2:
raise ValueError('Broadcast only supports up to 2 secondary parameters')
op_b = param_others[0]
op_c = param_others[1] if len(param_others) == 2 else None
arg_op_b = "arg_" + op_b
arg_op_a = "arg_" + op_a
arg_op_c = ("arg_" + op_c) if op_c else None
dims_kvs = []
for p in params:
if p.startswith("dims:"):
                assert raise_errors == "true"
if len(dims_kvs) != 0:
raise ValueError("multiple specifications of dims")
dims = p[len("dims:"):].split(",")
for dim in dims:
batchdim = dim.split(".")
assert len(batchdim) == 2
assert batchdim[1].startswith("dim")
dim_val = batchdim[1][len("dim"):]
dims_kvs.append({"op": batchdim[0], "arg_op": "arg_" + batchdim[0], "val": dim_val})
assert len(dims_kvs) <= 3
for p in params[1:]:
if p != "inplace" and p != "fallback" and not p.startswith("dims:") and not p.startswith("types:"):
raise ValueError("invalid parameter {}".format(p))
type_op_b = None
type_op_c = None
for p in params:
if p.startswith("types:"):
if not in_place and len(dims_kvs) > 0:
raise ValueError("type specification not supported yet for out-of-place functions "
"that specify explicit dimensions")
types = p[len("types:"):].split(",")
                assert len(types) == (2 if op_c else 1)
type_op_b = None if types[0] == "Real" else types[0]
if op_c:
type_op_c = None if types[1] == "Real" else types[1]
op_b_mapping = {
"op_a": op_a,
"op_other": op_b,
"arg_op_a": arg_op_a,
"arg_op_other": arg_op_b,
"raise_errors": raise_errors
}
op_c_mapping = {
"op_a": op_a,
"op_other": op_c,
"arg_op_a": arg_op_a,
"arg_op_other": arg_op_c,
"raise_errors": raise_errors
}
raise_errors_s = raise_errors == "true"
if in_place:
code_arg_op_other1 = self.getPreArgStringTemplate(type=type_op_b).substitute(op_b_mapping)
code_arg_op_other2 = (
self.getPreArgStringTemplate(type=type_op_c).substitute(op_c_mapping) if op_c else "")
if op_c:
expand_code = self.getInPlacePreExpand2Template(type_op_b, type_op_c, raise_errors_s).substitute(
op_b_mapping,
op_other1=op_b,
op_other2=op_c,
arg_op_other1=arg_op_b,
arg_op_other2=arg_op_c)
else:
expand_code = self.getInPlacePreExpand1Template(type_op_b, raise_errors_s).substitute(op_b_mapping)
new_code_pre.append(self.IN_PLACE_PRE_TEMPLATE.substitute(
arg_op_a=arg_op_a,
code_arg_op_other1=code_arg_op_other1,
code_arg_op_other2=code_arg_op_other2,
expand_code=expand_code,
raise_errors=raise_errors))
new_code_pre.append("")
post_code = self.POST_TEMPLATE.substitute(op_b_mapping)
if op_c:
post_code += self.POST_TEMPLATE.substitute(op_c_mapping)
new_code_post.append(post_code)
new_code_post.append("")
else:
if len(dims_kvs) != 0:
code_arg_op_a = self.getPreArgStringTemplate().substitute(arg_op_other=arg_op_a)
code_arg_op_other1 = ""
code_arg_op_other2 = ""
expand_code = ""
for idx, kv in enumerate(dims_kvs):
expand_code += self.OUT_PLACE_PRE_EXPAND_PRE_DIM_TEMPLATE.substitute(
arg_op_a=arg_op_a,
op_dim=kv["op"],
arg_op_dim=kv["arg_op"],
arg_op_dim_value=kv["val"],
idx=idx)
if len(dims_kvs) == 1:
expand_code += self.OUT_PLACE_PRE_EXPAND1_DIM_TEMPLATE.substitute(
arg_op_a=arg_op_a,
arg_op_dim0=dims_kvs[0]["arg_op"])
elif len(dims_kvs) == 2:
expand_code += self.OUT_PLACE_PRE_EXPAND2_DIM_TEMPLATE.substitute(
arg_op_a=arg_op_a,
arg_op_dim0=dims_kvs[0]["arg_op"],
arg_op_dim1=dims_kvs[1]["arg_op"])
else:
expand_code += self.OUT_PLACE_PRE_EXPAND3_DIM_TEMPLATE.substitute(
arg_op_a=arg_op_a,
arg_op_dim0=dims_kvs[0]["arg_op"],
arg_op_dim1=dims_kvs[1]["arg_op"],
arg_op_dim2=dims_kvs[2]["arg_op"])
expand_code += self.getOutPlacePreExpandPostDimTemplate(None, raise_errors_s).substitute(
arg_op_a=arg_op_a,
raise_errors=raise_errors)
post_code = self.POST_TEMPLATE.substitute(arg_op_other=arg_op_a)
else:
code_arg_op_a = self.getPreArgStringTemplate().substitute(arg_op_other=arg_op_a)
code_arg_op_other1 = self.getPreArgStringTemplate(type=type_op_b).substitute(op_b_mapping)
code_arg_op_other2 = (self.getPreArgStringTemplate(type=type_op_c).substitute(op_c_mapping)
if op_c else "")
if op_c:
expand_template = self.getOutPlacePreExpand3Template(None, type_op_b, type_op_c, raise_errors_s)
expand_code = expand_template.substitute(
op_b_mapping,
op_other1=op_b,
op_other2=op_c,
arg_op_other1=arg_op_b,
arg_op_other2=arg_op_c)
else:
expand_code = self.getOutPlacePreExpand2Template(None, type_op_b, raise_errors_s).substitute(
op_b_mapping)
post_code = self.POST_TEMPLATE.substitute(arg_op_other=arg_op_a)
post_code += self.POST_TEMPLATE.substitute(op_b_mapping)
post_code += self.POST_TEMPLATE.substitute(op_c_mapping) if op_c else ""
new_code_pre.append(self.OUT_PLACE_PRE_TEMPLATE.substitute(
code_arg_op_a=code_arg_op_a,
code_arg_op_other1=code_arg_op_other1,
code_arg_op_other2=code_arg_op_other2,
expand_code=expand_code))
new_code_pre.append("")
new_code_post.append(post_code)
new_code_post.append("")
template = new_code_pre + template + new_code_post
return template
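
# A minimal usage sketch (hypothetical wiring; in PyTorch the plugin is registered
# by the cwrap driver, e.g. from setup.py):
#
#   from tools.cwrap import cwrap
#   from tools.cwrap.plugins.Broadcast import Broadcast
#   cwrap('Declarations.cwrap', plugins=[Broadcast()])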