import torch
from torch.jit import BatchTensor
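
# Batched implementations of tensor operations for the experimental
# torch.jit BatchTensor support. Each op below works directly on the triple
# that backs a BatchTensor:
#   data: the padded tensor, with the batch dimension first
#   mask: a uint8 tensor marking which entries of data are valid
#   dims: a uint8 vector with one entry per non-batch dimension of data,
#         where 1 marks a dynamic (per-example) size and 0 a static one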


# TODO: there are some commented-out raise statements below;
# once raising exceptions is supported in script, we want to re-enable these checks
@torch.jit.script
def batch_tanh(data, mask, dims):
    data = torch.tanh(data)
    return data, mask, dims


@torch.jit.script
def batch_sigmoid(data, mask, dims):
    data = torch.sigmoid(data)
    return data, mask, dims


@torch.jit.script
def batch_relu(data, mask, dims):
    data = torch.relu(data)
    return data, mask, dims


@torch.jit.script
def batch_neg(data, mask, dims):
    data = torch.neg(data)
    return data, mask, dims


@torch.jit.script
def batch_neg_scalar(data):
    return torch.neg(data)


@torch.jit.script
def batch_add(data1, mask1, dims1, data2, mask2, dims2, alpha_):
    alpha = float(alpha_)
    data = torch.add(data1, data2, alpha=alpha)
    mask = mask1 * mask2
    dims = dims1.__or__(dims2)
    return data, mask, dims
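
# The mask1 * mask2 and dims1 | dims2 combination above is the general rule
# for elementwise binary batched ops: an output entry is valid only where
# both inputs are valid, and a result dimension is dynamic if it is dynamic
# in either input. The same pattern recurs in batch_sub, batch_mul,
# batch_gt, batch_lt, and batch_eq below.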


@torch.jit.script
def batch_add_scalar(data, mask, dims, other, alpha_):
    alpha = float(alpha_)
    data = torch.add(data, other.type_as(data), alpha=alpha)
    return data, mask, dims


@torch.jit.script
def batch_sub(data1, mask1, dims1, data2, mask2, dims2, alpha_):
    alpha = float(alpha_)
    data = torch.sub(data1, data2, alpha=alpha)
    mask = mask1 * mask2
    dims = dims1.__or__(dims2)
    return data, mask, dims


@torch.jit.script
def batch_sub_scalar(data1, data2):
    return data1 - data2


@torch.jit.script
def batch_mul(data1, mask1, dims1, data2, mask2, dims2):
    data = torch.mul(data1, data2)
    mask = mask1 * mask2
    dims = dims1.__or__(dims2)
    return data, mask, dims


@torch.jit.script
def batch_mul_scalar(data1, data2):
    return data1 * data2


@torch.jit.script
def batch_div(data, mask, dims, other):  # div(batchtensor, scalar)
    data = torch.div(data, other)
    return data, mask, dims


@torch.jit.script
def batch_mm(data1, mask1, dims1, data2, mask2, dims2):
    data1 = data1 * mask1.type_as(data1)
    data2 = data2 * mask2.type_as(data2)
    data = torch.bmm(data1, data2)
    mask = torch.bmm(mask1.narrow(2, 0, 1), mask2.narrow(1, 0, 1))
    dims = torch.cat((dims1[:1], dims2[1:dims2.size(0)]))
    return data, mask, dims


@torch.jit.script
def batch_matmul(data1, mask1, dims1, data2, mask2, dims2):
    d1 = data1.dim() - 1
    d2 = data2.dim() - 1
    data1 = data1 * mask1.type_as(data1)
    data2 = data2 * mask2.type_as(data2)
    if d1 == 1:
        data1 = data1.unsqueeze(-2)
    if d2 == 1:
        data2 = data2.unsqueeze(-1)
    data = torch.bmm(data1, data2)
    mask = mask1
    dims = dims1
    if d1 == 1 and d2 == 1:
        # if (batch1.dims[0] or batch2.dims[0]) and not batch1.mask.eq(batch2.mask).all():
        #     raise ValueError("cannot contract non-matching dimensions")
        data = data.squeeze(-1).squeeze(-1)
        mask = mask1.narrow(1, 0, 1).squeeze(-1)
        dims = dims1[:0]  # empty tensor
    elif d1 == 2 and d2 == 1:
        # if (batch1.dims[1] or batch2.dims[0]) and not batch1.mask[:, 0].eq(batch2.mask).all():
        #     raise ValueError("cannot contract non-matching dimensions")
        data = data.squeeze(-1)
        mask = torch.bmm(mask1.narrow(2, 0, 1), mask2.narrow(1, 0, 1).unsqueeze(-1)).squeeze(-1)
        dims = dims1[:1]
    elif d1 == 1 and d2 == 2:
        # if (batch1.dims[0] or batch2.dims[0]) and not batch1.mask.eq(batch2.mask[:, :, 0]).all():
        #     raise ValueError("cannot contract non-matching dimensions")
        data = data.squeeze(-2)
        mask = torch.bmm(mask1.narrow(1, 0, 1).unsqueeze(-2), mask2.narrow(1, 0, 1)).squeeze(-2)
        dims = dims2[1:dims2.size(0)]
    elif d1 == 2 and d2 == 2:
        # if (batch1.dims[1] or batch2.dims[0]) and not batch1.mask[:, 0].eq(batch2.mask[:, :, 0]).all():
        #     raise ValueError("cannot contract non-matching dimensions")
        mask = torch.bmm(mask1.narrow(2, 0, 1), mask2.narrow(1, 0, 1))
        dims = torch.cat((dims1[:1], dims2[1:dims2.size(0)]))
    # else:
    #     raise NotImplementedError("matmul not implemented with batches of 3+D tensors")
    return data, mask, dims
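
# Shape walk-through (hypothetical sizes): with batch size B, the case
# d1 == 1, d2 == 2 is a batched vector-matrix product. data1 is B x n and is
# unsqueezed to B x 1 x n, data2 is B x n x m, bmm yields B x 1 x m, and the
# squeeze(-2) leaves B x m; the result dims come from data2's last dimension.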


@torch.jit.script
def batch_select(data, mask, dims, dim_, index_):
    dim = int(dim_)
    index = int(index_)
    # if dim == 0:
    #     raise ValueError("Cannot select 0 dim in BatchTensor")
    data = data.select(dim, index)
    if bool(dims[dim - 1]):
        mask = mask.select(dim, index)
    else:
        mask = mask.select(dim, 0)
    dims = torch.cat((dims[:dim - 1], dims[dim:dims.size(0)]))
    return data, mask, dims


@torch.jit.script
def batch_fmod(data, mask, dims, other_):
    other = int(other_)
    data = torch.fmod(data, other)
    return data, mask, dims


@torch.jit.script
def batch_zeros_like(data, mask, dims):
    res_data = torch.zeros_like(data)
    return res_data, mask, dims


@torch.jit.script
def batch_index_select(data, mask, dims, dim_, index_data, index_mask, index_dims):
    dim = int(dim_)
    # if dim == 0:
    #     raise ValueError("Cannot index_select along 0 dim in BatchTensor")
    batch_size = data.size(0)  # TODO: maybe index_mask will be used at some point
    res_data = torch.zeros([0])
    res_mask = torch.zeros([0])
    for i in range(batch_size):
        d = data[i].index_select(dim - 1, index_data[i]).unsqueeze(0)
        if bool(dims[dim - 1]):
            m = mask[i].index_select(dim - 1, index_data[i]).unsqueeze(0)
        else:
            m = mask[i].unsqueeze(0)
        if i == 0:
            res_data = d
            res_mask = m
        else:
            res_data = torch.cat((res_data, d), 0)
            res_mask = torch.cat((res_mask, m), 0)
    return res_data, res_mask, dims


@torch.jit.script
def batch_view_as(data, mask, dims, data1, mask1, dims1):
    # if data.size(0) != data1.size(0):
    #     raise ValueError("In view_as, tensor and target tensor should have the same batch_size")
    # if not torch.equal(dims, dims1):
    #     raise ValueError("In batched view_as, dims and target dims should be the same")
    data = data.view_as(data1)
    mask = mask.view_as(mask1)
    dims = dims1
    return data, mask, dims


# assumes data, data1, and data2 have the same size
@torch.jit.script
def batch_where(data, mask, dims, data1, mask1, dims1, data2, mask2, dims2):
    data = data * mask.type_as(data)
    cond_data = data
    cond_mask = data
    if data.dim() == 1:
        for _ in range(data1.dim() - 1):
            data = data.unsqueeze(data.dim())
        cond_data = data.expand_as(data1)
        cond_mask = data.expand_as(mask1)
    res_data = torch.where(cond_data, data1, data2)
    res_mask = torch.where(cond_mask, mask1, mask2)
    res_dims = dims1.__or__(dims2)
    return res_data, res_mask, res_dims
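
# Hypothetical call (names assumed): the first triple is a batched byte
# condition with the same padded size as both branches, or a per-example
# scalar condition that gets expanded, e.g.
#   res_data, res_mask, res_dims = batch_where(cond_data, cond_mask, cond_dims,
#                                              a_data, a_mask, a_dims,
#                                              b_data, b_mask, b_dims)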


@torch.jit.script
def batch_where_scalar(cond, data1, mask1, dims1, data2, mask2, dims2):
    # cond is a single-element condition; view it as a 1-element uint8 tensor
    # so it broadcasts against both batched operands in torch.where
    cond = cond.view([1]).type_as(torch.zeros([1], dtype=torch.uint8))
    res_data = torch.where(cond, data1, data2)
    res_mask = torch.where(cond, mask1, mask2)
    res_dims = torch.where(cond, dims1, dims2)
    return res_data, res_mask, res_dims


@torch.jit.script
def batch_update(batch_data, batch_mask, batch_dims, new_data, new_mask, new_dims):
    data = torch.where(new_mask, new_data, batch_data)
    return data, new_mask, new_dims  # TODO: consider whether to return new_mask and new_dims


@torch.jit.script
def batch_any(data, mask, dims):
    return torch.gt(torch.sum(data * mask), 0)


@torch.jit.script
def batch_type_as(data, mask, dims, data1, mask1, dims1):
    return data.type_as(data1), mask, dims


@torch.jit.script
def batch_gt(data, mask, dims, data1, mask1, dims1):
    return torch.gt(data, data1), mask * mask1, dims.__or__(dims1)


@torch.jit.script
def batch_gt_scalar(data1, data2):
    return torch.gt(data1, data2)


@torch.jit.script
def batch_gt_one_scalar(data, mask, dims, other_):
    other = float(other_)
    return torch.gt(data, other), mask, dims


@torch.jit.script
def batch_lt(data, mask, dims, data1, mask1, dims1):
    return torch.lt(data, data1), mask * mask1, dims.__or__(dims1)


@torch.jit.script
def batch_eq(data, mask, dims, data1, mask1, dims1):
    return torch.eq(data, data1), mask * mask1, dims.__or__(dims1)


@torch.jit.script
def batch_size(data, mask, dims, dim_):
    dim = int(dim_)
    return data.size(dim)


@torch.jit.script
def batch_dim(data, mask, dims):
    return data.dim()


@torch.jit.script
def batch_squeeze(data, mask, dims, dim_):
    if int(dim_) < 0:
        dim_ += data.dim()
    dim = int(dim_)
    # if dim == 0:
    #     raise ValueError("cannot do squeeze along batch_dim")
    data = data.squeeze(dim)
    mask = mask.squeeze(dim)
    dims = torch.cat((dims[:dim - 1], dims[dim:dims.size(0)]))
    return data, mask, dims


@torch.jit.script
def batch_unsqueeze(data, mask, dims, dim_):
    if int(dim_) < 0:
        dim_ += data.dim() + 1
    dim = int(dim_)
    # if dim == 0:
    #     raise ValueError("cannot do unsqueeze along batch_dim")
    data = data.unsqueeze(dim)
    mask = mask.unsqueeze(dim)
    dims = torch.cat((dims[:dim], torch.zeros([1], dtype=torch.uint8), dims[dim:dims.size(0)]))
    return data, mask, dims


@torch.jit.script
def batch_argmax(data, mask, dims, dim_, keepdim_):
    dim = int(dim_)
    keepdim = bool(keepdim_)
    # if dim == 0:
    #     raise ValueError("cannot do argmax along batch_dim")
    batch_size = data.size(0)
    res_data = torch.zeros([0])
    for i in range(batch_size):
        if bool(dims[dim - 1]):
            # dynamic dimension: reduce only over this example's valid prefix,
            # whose length is the column sum of its mask
            if dim - 1 != 0:
                m = mask[i].transpose(0, dim - 1)
            else:
                m = mask[i]
            valid_num = m.sum(0, keepdim=True)
            while valid_num.dim() >= 1:
                valid_num = valid_num[0]
            d = data[i].unsqueeze(0).narrow(dim, 0, int(valid_num))
        else:
            d = data[i].unsqueeze(0)
        d = d.argmax(dim, keepdim)
        if i == 0:
            res_data = d
        else:
            res_data = torch.cat([res_data, d], 0)
    if keepdim:
        mask = mask
    else:
        mask = mask.select(dim, 0)
        dims = torch.cat((dims[:dim - 1], dims[dim:dims.size(0)]))
    return res_data, mask, dims


@torch.jit.script
def batch_topk(data, mask, dims, k_, dim_, largest_, sorted_):
    k = int(k_)
    dim = int(dim_)
    largest = bool(largest_)
    sorted = bool(sorted_)
    # if dim == 0:
    #     raise ValueError("cannot do topk along batch_dim")
    batch_size = data.size(0)
    res_data = torch.zeros([0])
    res_index = torch.zeros([0])
    for i in range(batch_size):
        if bool(dims[dim - 1]):
            # dynamic dimension: only this example's valid prefix takes part
            # in the topk
            if dim - 1 != 0:
                m = mask[i].transpose(0, dim - 1)
            else:
                m = mask[i]
            valid_num = m.sum(0, keepdim=True)
            while valid_num.dim() >= 1:
                valid_num = valid_num[0]
            d = data[i].unsqueeze(0).narrow(dim, 0, int(valid_num))
        else:
            d = data[i].unsqueeze(0)
        d, idx = d.topk(k, dim, largest, sorted)
        if i == 0:
            res_data = d
            res_index = idx
        else:
            res_data = torch.cat([res_data, d], 0)
            res_index = torch.cat([res_index, idx], 0)
    if bool(dims[dim - 1]):
        mask = mask.narrow(dim, 0, k)
    # returns two triples: one BatchTensor of values and one of indices
    return res_data, mask, dims, res_index, mask, dims


@torch.jit.script
def batch_softmax(data, mask, dims, dim_):
    dim = int(dim_)
    # if dim == 0:
    #     raise ValueError("cannot do softmax along batch_dim")
    batch_size = data.size(0)
    max_len = data.size(dim)
    res_data = torch.zeros([0])
    for i in range(batch_size):
        if bool(dims[dim - 1]):
            if dim - 1 != 0:
                m = mask[i].transpose(0, dim - 1)
            else:
                m = mask[i]
            valid_num = m.sum(0, keepdim=True)
            while valid_num.dim() >= 1:
                valid_num = valid_num[0]
            valid_num = int(valid_num)
            d = data[i].unsqueeze(0).narrow(dim, 0, valid_num).softmax(dim)
            if valid_num < max_len:
                d = torch.cat([d, data[i].unsqueeze(0).narrow(dim, valid_num, max_len - valid_num)], dim)
        else:
            d = data[i].unsqueeze(0).softmax(dim)
        if i == 0:
            res_data = d
        else:
            res_data = torch.cat([res_data, d], 0)
    return res_data, mask, dims


# the size argument in a dynamic dimension has to be -1;
# in a static dimension the size has to be specified (-1 is not supported)
@torch.jit.script
def batch_view(data, mask, dims, sizes):
    batch_size = data.size(0)
    # if sizes[0] != batch_size and sizes[0] != -1 and sizes[0] != 1:
    #     raise ValueError("first dim in view must be 1, -1, or batch size")
    # for i in range(dims.size(0)):
    #     if dims[0] == 1 and sizes[i + 1] != -1:
    #         raise ValueError("size argument in dynamic dimension has to be -1")
    sizes = sizes.type_as(torch.ones([1], dtype=torch.int))
    data_sizes_ = torch.cat([torch.ones([1], dtype=torch.int) * batch_size, sizes.narrow(0, 1, sizes.size(0) - 1)], 0)
    data_sizes = data_sizes_._tensor_to_list()
    res_data = data.view(data_sizes)
    mask_sizes_ = data_sizes_.narrow(0, 0, 1)
    res_dims = data_sizes_.narrow(0, 0, 1)
    for i_ in range(sizes.size(0) - 1):
        i = i_ + 1
        if bool(sizes[i] == -1):
            # dynamic dimension: the mask keeps its full extent here
            cur_size_ = mask.size(i)
            cur_dim = 1
        else:
            # static dimension: the mask is broadcast, so its size is 1
            cur_size_ = 1
            cur_dim = 0
        mask_sizes_ = torch.cat([mask_sizes_, torch.ones([1], dtype=torch.int) * cur_size_])
        res_dims = torch.cat([res_dims, torch.ones([1], dtype=torch.int) * cur_dim])
    mask_sizes = mask_sizes_._tensor_to_list()
    res_mask = mask.view(mask_sizes)
    return res_data, res_mask, res_dims.narrow(0, 1, res_dims.size(0) - 1).type_as(dims)
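
# Example (hypothetical values): for a batch whose dim 1 is dynamic, viewing
# with sizes = torch.tensor([-1, -1, 1]) keeps the batch and dynamic
# dimensions and appends a static dimension of size 1. The mask is viewed
# with size 1 in every static dimension so it still broadcasts over data,
# and the returned dims vector is rebuilt as [1, 0] for the non-batch dims.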


@torch.jit.script
def batch_cat2(data1, mask1, dims1, data2, mask2, dims2, dim_):
    dim = int(dim_)
    data = torch.cat([data1, data2], dim)
    if bool(dims1[dim - 1]):
        mask = torch.cat([mask1, mask2], dim)
    else:
        mask = mask1
    return data, mask, dims1


@torch.jit.script
def batch_cat3(data1, mask1, dims1, data2, mask2, dims2, data3, mask3, dims3, dim_):
    dim = int(dim_)
    data = torch.cat([data1, data2, data3], dim)
    if bool(dims1[dim - 1]):
        mask = torch.cat([mask1, mask2, mask3], dim)
    else:
        mask = mask1
    return data, mask, dims1


@torch.jit.script
def batch_narrow(data, mask, dims, dimension_, start_, length_):
    dimension = int(dimension_)
    start = int(start_)
    length = int(length_)
    # if dimension == 0:
    #     raise ValueError("cannot do narrow along batch_dim")
    data = data.narrow(dimension, start, length)
    if bool(dims[dimension - 1]):
        mask = mask.narrow(dimension, start, length)
    else:
        mask = mask.narrow(dimension, 0, 1)
    return data, mask, dims


@torch.jit.script
def batch_sum(data, mask, dims):
    data = data * mask.type_as(data)
    for _ in range(dims.size(0)):
        data = data.sum(1)
    mask = torch.ones([data.size(0)], dtype=torch.uint8)
    dims = dims[:0]  # empty tensor
    return data, mask, dims


@torch.jit.script
def batch_from_scalar_tensor(data):
    data = data.unsqueeze(0)
    mask = torch.ones([1], dtype=torch.uint8)
    dims = torch.zeros([0], dtype=torch.uint8)
    return data, mask, dims

# register each batched graph under the operator name it overloads; a name
# may be registered more than once to cover different overloads (e.g. add,
# where, cat)
torch.register_batch_operator("tanh", batch_tanh.graph)
torch.register_batch_operator("sigmoid", batch_sigmoid.graph)
torch.register_batch_operator("relu", batch_relu.graph)
torch.register_batch_operator("neg", batch_neg.graph)
torch.register_batch_operator("neg", batch_neg_scalar.graph)
torch.register_batch_operator("add", batch_add.graph)
torch.register_batch_operator("add", batch_add_scalar.graph)
torch.register_batch_operator("sub", batch_sub.graph)
torch.register_batch_operator("sub", batch_sub_scalar.graph)
torch.register_batch_operator("mul", batch_mul.graph)
torch.register_batch_operator("mul", batch_mul_scalar.graph)
torch.register_batch_operator("div", batch_div.graph)
torch.register_batch_operator("matmul", batch_matmul.graph)
torch.register_batch_operator("mm", batch_mm.graph)
torch.register_batch_operator("fmod", batch_fmod.graph)
torch.register_batch_operator("zeros_like", batch_zeros_like.graph)
torch.register_batch_operator("select", batch_select.graph)
torch.register_batch_operator("index_select", batch_index_select.graph)
torch.register_batch_operator("view_as", batch_view_as.graph)
torch.register_batch_operator("where", batch_where.graph)
torch.register_batch_operator("where", batch_where_scalar.graph)
torch.register_batch_operator("update", batch_update.graph)
torch.register_batch_operator("any", batch_any.graph)
torch.register_batch_operator("type_as", batch_type_as.graph)
torch.register_batch_operator("gt", batch_gt.graph)
torch.register_batch_operator("gt", batch_gt_scalar.graph)
torch.register_batch_operator("gt", batch_gt_one_scalar.graph)
torch.register_batch_operator("lt", batch_lt.graph)
torch.register_batch_operator("eq", batch_eq.graph)
torch.register_batch_operator("size", batch_size.graph)
torch.register_batch_operator("dim", batch_dim.graph)
torch.register_batch_operator("squeeze", batch_squeeze.graph)
torch.register_batch_operator("unsqueeze", batch_unsqueeze.graph)
torch.register_batch_operator("argmax", batch_argmax.graph)
torch.register_batch_operator("topk", batch_topk.graph)
torch.register_batch_operator("softmax", batch_softmax.graph)
torch.register_batch_operator("view", batch_view.graph)
torch.register_batch_operator("cat", batch_cat2.graph)
torch.register_batch_operator("cat", batch_cat3.graph)
torch.register_batch_operator("narrow", batch_narrow.graph)
torch.register_batch_operator("sum", batch_sum.graph)
torch.register_batch_operator("batch_from_scalar_tensor", batch_from_scalar_tensor.graph)
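

if __name__ == "__main__":
    # Minimal smoke test: a sketch, assuming a torch build that still ships
    # this experimental BatchTensor support. It drives one scripted op with
    # the raw (data, mask, dims) triple instead of a BatchTensor wrapper.
    ex_data = torch.randn(2, 3)  # batch of 2 examples, padded to length 3
    ex_mask = torch.tensor([[1, 1, 0], [1, 1, 1]], dtype=torch.uint8)
    ex_dims = torch.ones([1], dtype=torch.uint8)  # the one non-batch dim is dynamic
    out_data, out_mask, out_dims = batch_tanh(ex_data, ex_mask, ex_dims)
    print(out_data, out_mask, out_dims)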