Kill THNN function auto generation. (#25322)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25322
As far as I can tell, none of these are actually used anymore.
Test Plan: Imported from OSS
Differential Revision: D17097301
Pulled By: gchanan
fbshipit-source-id: 649ee0fd549f6e2a875faef7c32b19c70bb969b6
diff --git a/torch/nn/_functions/thnn/__init__.py b/torch/nn/_functions/thnn/__init__.py
index 7336673..c418d2f 100644
--- a/torch/nn/_functions/thnn/__init__.py
+++ b/torch/nn/_functions/thnn/__init__.py
@@ -1,5 +1,4 @@
_all_functions = []
-from .auto import * # noqa: F401
from .normalization import * # noqa: F401
from .sparse import * # noqa: F401
diff --git a/torch/nn/_functions/thnn/auto.py b/torch/nn/_functions/thnn/auto.py
deleted file mode 100644
index 6da2976..0000000
--- a/torch/nn/_functions/thnn/auto.py
+++ /dev/null
@@ -1,342 +0,0 @@
-from itertools import repeat
-from collections import defaultdict
-
-import torch
-from torch._thnn.utils import parse_header, THNN_H_PATH
-from torch.autograd.function import Function, InplaceFunction
-from torch._thnn import type2backend
-from .auto_double_backwards import double_backwards_fns
-from .auto_symbolic import symbolic_fns
-
-from . import _all_functions
-
-
-def _make_function_class_criterion(class_name, update_output, update_grad_input, acc_grad_parameters,
- double_backwards_fn, symbolic_fn):
- weight_arg_idx = -1
- for i, arg in enumerate(update_output.arguments):
- if arg.name.startswith('weight'):
- weight_arg_idx = i
- break
-
- reduce_arg_idx = -1
- for i, arg in enumerate(update_output.arguments):
- if arg.name == 'reduce':
- reduce_arg_idx = i
- break
-
- buffers_idx = []
- additional_arg_idx = 0
- for arg in update_output.arguments[4:]:
- if not arg.name.startswith('weight') and arg.type == 'THTensor*':
- buffers_idx.append(additional_arg_idx)
- additional_arg_idx += 1
-
- @staticmethod
- def symbolic(*args, **kwargs):
- a = symbolic_fn(*args, **kwargs)
- return a
-
- @staticmethod
- def forward(ctx, input, target, *args):
- ctx._backend = type2backend[input.type()]
- ctx.save_for_backward(input, target)
- if weight_arg_idx >= 0:
- ctx.weight = args[0]
- args = args[1:]
- ctx.additional_args = list(args)
- insert_idx = weight_arg_idx - 4 # state, input, target, output
- ctx.additional_args.insert(insert_idx, ctx.weight)
- else:
- ctx.additional_args = list(args)
-
- ctx.forward_args_count = len(ctx.additional_args)
- for idx in buffers_idx:
- ctx.additional_args.insert(idx, input.new(1))
- output = input.new(1)
- getattr(ctx._backend, update_output.name)(ctx._backend.library_state, input, target,
- output, *ctx.additional_args)
- return output
-
- @staticmethod
- def backward(ctx, grad_output):
- input, target = ctx.saved_tensors
- # apply returns grad_input, so we need to return Nones for target (1) + 1 for each extra arg passed to forward.
- return ((backward_cls.apply(input, target, grad_output, ctx.additional_args, ctx._backend),) +
- (None,) * (ctx.forward_args_count + 1))
-
- @staticmethod
- def backward_cls_forward(ctx, input, target, grad_output, additional_args_ctx, backend_ctx):
- ctx.additional_args = additional_args_ctx
- ctx._backend = backend_ctx
- ctx.save_for_backward(input, target, grad_output)
- grad_input = grad_output.new().resize_as_(input).zero_()
-
- if reduce_arg_idx >= 0:
- getattr(ctx._backend, update_grad_input.name)(ctx._backend.library_state, input, target,
- grad_output, grad_input, *ctx.additional_args)
- return grad_input
-
- getattr(ctx._backend, update_grad_input.name)(ctx._backend.library_state, input, target,
- grad_input, *ctx.additional_args)
- grad_output_expanded = grad_output.view(*repeat(1, grad_input.dim()))
- grad_input.mul_(grad_output_expanded.expand_as(grad_input))
- return grad_input
-
- @staticmethod
- def backward_cls_backward(ctx, *grad_params):
- return double_backwards_fn(ctx, *grad_params)
-
- backward_cls = type(class_name + "Backward", (Function,),
- dict(forward=backward_cls_forward, backward=backward_cls_backward))
- return type(class_name, (Function,), dict(forward=forward, backward=backward, symbolic=symbolic)), backward_cls
-
-
-def _find_buffers(args, ignored_args):
- additional_arg_idx = 0
- buffers = []
- for arg in args:
- if arg.name in ignored_args:
- continue
- if arg.type == 'THTensor*':
- buffers.append((additional_arg_idx, arg.name))
- additional_arg_idx += 1
- return buffers
-
-
-def _make_function_class(class_name, update_output, update_grad_input, acc_grad_parameters,
- double_backwards_fn, symbolic_fn):
- def has_argument(fn, name):
- for arg in fn.arguments:
- if arg.name == name:
- return True
- return False
- save_output = has_argument(update_grad_input, 'output')
-
- param_args = {'weight', 'bias'}
- ignored_args = {'weight', 'bias', 'gradWeight', 'gradBias', 'output'}
- expected_params = [arg for arg in update_output.arguments[3:]
- if arg.name in param_args]
- buffers = {}
- buffers['update_output'] = _find_buffers(update_output.arguments[3:],
- ignored_args)
- buffers['update_grad_input'] = _find_buffers(
- update_grad_input.arguments[4:], ignored_args)
- if acc_grad_parameters is not None:
- buffers['acc_grad_parameters'] = _find_buffers(
- acc_grad_parameters.arguments[3:], ignored_args)
-
- # This assumes that only the last argument can be
- # an inplace flag
- is_inplace = update_output.arguments[-1].name == 'inplace'
-
- def _initialize_buffers(ctx, fn_name):
- additional_args = ctx.additional_args
- for idx, name in buffers[fn_name]:
- # TODO: some buffers are necessary only for update output and can be
- # freed right afterwards
- buffer = ctx.buffers[name]
- additional_args = additional_args[:idx] + [buffer] + additional_args[idx:]
- return tuple(additional_args)
-
- @staticmethod
- def symbolic(*args, **kwargs):
- return symbolic_fn(*args, **kwargs)
-
- @staticmethod
- def forward(ctx, input, *params):
- ctx._backend = type2backend[input.type()]
-
- ctx.additional_args = []
- tensor_param_list = []
- for param in params:
- if isinstance(param, torch.Tensor):
- if type(param) != type(input):
- raise RuntimeError("input type ({}) doesn't match the type of "
- "a parameter tensor ({})".format(torch.typename(input),
- torch.typename(param)))
- tensor_param_list.append(param)
- else:
- ctx.additional_args.append(param)
-
- tensor_params = tuple(tensor_param_list)
- if is_inplace:
- ctx.inplace = params[-1]
- # Allocate temporary buffers and insert them into additional_args
- ctx.buffers = defaultdict(type(input))
- additional_args = _initialize_buffers(ctx, 'update_output')
-
- # Fill in optional params with None
- args = tensor_params
- for i in range(len(params), len(expected_params)):
- param = expected_params[i]
- if param.is_optional:
- args += (None,)
- else:
- raise ValueError("missing required argument '%s'" % param.name)
-
- args += tuple(additional_args)
-
- # If the module is working in-place its output will be set to the
- # same storage as input, but its tensor won't be dirty.
- if is_inplace and ctx.inplace:
- ctx.mark_dirty(input)
- output = input
- else:
- output = input.new()
-
- if save_output:
- ctx.save_for_backward(input, output, *tensor_params)
- else:
- ctx.save_for_backward(input, *tensor_params)
-
- if not ctx.requires_grad:
- del ctx.buffers
-
- getattr(ctx._backend, update_output.name)(ctx._backend.library_state, input, output, *args)
- return output
-
- @staticmethod
- def backward(ctx, grad_output):
- t = ctx.saved_tensors
- input, tensor_params = t[0], t[1:]
- # Some notes on this function call:
- # 1) We need to pass params as *params so they are unwrapped correctly in backward_cls_forward.
- # 2) apply returns the grad_input / grad_tensor_params, so we need to append Nones equal to the number
- # of non tensor_params, i.e. the additional_args
- # 3) it may be simpler to recalculate some of these parameters (e.g. ctx._backend) in backward_cls_forward?
-
- return (backward_cls.apply(input, grad_output, ctx.additional_args, ctx._backend, ctx.buffers, *tensor_params) +
- (None,) * len(ctx.additional_args))
-
- @staticmethod
- def backward_cls_forward(ctx, input, grad_output, additional_args_ctx, backend_ctx, buffers_ctx, *params):
- ctx.additional_args = additional_args_ctx
- ctx.buffers = buffers_ctx
- ctx._backend = backend_ctx
- ctx.save_for_backward(input, grad_output, *params)
- if save_output:
- output = params[0]
- params = params[1:]
-
- grad_params = tuple(None for p in params)
- grad_input_tuple = (None,)
- if is_inplace:
- ctx.inplace = additional_args_ctx[-1]
-
- if ctx.needs_input_grad[0]:
- additional_args = _initialize_buffers(ctx, 'update_grad_input')
- if save_output:
- additional_args = (output,) + additional_args
-
- if is_inplace and ctx.inplace:
- assert additional_args[-1] is True
- tmp_args = list(additional_args)
- tmp_args[-1] = False
- additional_args = tuple(tmp_args)
- grad_input = input.new(input.size())
- params_without_bias = params if len(params) < 2 else params[:1]
- update_grad_input_fn = getattr(ctx._backend, update_grad_input.name)
- gi_args = params_without_bias + additional_args
- update_grad_input_fn(ctx._backend.library_state, input, grad_output, grad_input, *gi_args)
- grad_input_tuple = (grad_input,)
-
- if acc_grad_parameters and any(ctx.needs_input_grad[1:]):
- additional_args = _initialize_buffers(ctx, 'acc_grad_parameters')
- grad_params = tuple(p.new(p.size()).zero_() for p in params)
- appended_grads = len(expected_params) - len(grad_params)
- grad_params += (None,) * appended_grads
- acc_grad_parameters_fn = getattr(ctx._backend, acc_grad_parameters.name)
- param_args = grad_params + additional_args + (1,)
- acc_grad_parameters_fn(ctx._backend.library_state, input, grad_output, *param_args)
- if appended_grads:
- grad_params = grad_params[:-appended_grads]
-
- return grad_input_tuple + grad_params
-
- @staticmethod
- def backward_cls_backward(ctx, *grad_params):
- return double_backwards_fn(ctx, *grad_params)
-
- base_class = Function if not is_inplace else InplaceFunction
- backward_cls = type(class_name + "Backward", (base_class,), dict(forward=backward_cls_forward,
- backward=backward_cls_backward))
-
- return type(class_name, (base_class,), dict(forward=forward, backward=backward, symbolic=symbolic)), backward_cls
-
-
-def _generate_function_classes(scope_dict):
- global function_list, function_by_name
- function_list = parse_header(THNN_H_PATH)
- function_by_name = {fn.name: fn for fn in function_list}
- classes_to_generate = {fn.name.partition('_')[0] for fn in function_list}
- exceptions = {
- 'Linear',
- 'IndexLinear',
- 'SpatialConvolutionMM',
- 'TemporalConvolution',
- 'SpatialMaxUnpooling',
- 'VolumetricMaxUnpooling',
- 'VolumetricConvolutionMM',
- 'TemporalMaxPooling',
- 'BatchNormalization',
- 'LookupTable',
- 'LookupTableBag',
- 'PReLU',
- 'RReLU',
- 'SoftMax',
- 'LogSoftMax',
- 'GRUFused',
- 'LSTMFused',
- 'unfolded',
- }
- name_remap = {
- 'TemporalConvolution': 'Conv1d',
- 'HardTanh': 'Hardtanh',
- 'HardShrink': 'Hardshrink',
- 'SoftPlus': 'Softplus',
- 'SoftShrink': 'Softshrink',
- 'MSECriterion': 'MSELoss',
- 'AbsCriterion': 'L1Loss',
- 'BCECriterion': 'BCELoss',
- 'ClassNLLCriterion': 'NLLLoss',
- 'DistKLDivCriterion': 'KLDivLoss',
- 'SpatialClassNLLCriterion': 'NLLLoss2d',
- 'MultiLabelMarginCriterion': 'MultiLabelMarginLoss',
- 'MultiMarginCriterion': 'MultiMarginLoss',
- 'SmoothL1Criterion': 'SmoothL1Loss',
- 'SoftMarginCriterion': 'SoftMarginLoss',
- }
-
- classes_to_generate -= exceptions
- for fn in classes_to_generate:
- update_output = function_by_name[fn + '_updateOutput']
- update_grad_input = function_by_name[fn + '_updateGradInput']
- acc_grad_parameters = function_by_name.get(fn + '_accGradParameters')
- class_name = name_remap.get(fn, fn)
- double_backwards_fn = double_backwards_fns.get(class_name)
- if double_backwards_fn is None:
- def make_default_double_backwards_fn(class_name):
- def default_double_backwards_fn(ctx, *grad_params):
- raise ValueError(class_name + " can only be differentiated once.")
- return default_double_backwards_fn
- double_backwards_fn = make_default_double_backwards_fn(class_name)
- symbolic_fn = symbolic_fns.get(class_name)
- # This has to call a function to retain correct references to functions
- is_criterion_fn = 'Criterion' in fn
- if is_criterion_fn:
- cls, backward_cls = _make_function_class_criterion(class_name, update_output,
- update_grad_input, acc_grad_parameters,
- double_backwards_fn, symbolic_fn)
- else:
- cls, backward_cls = _make_function_class(class_name, update_output,
- update_grad_input, acc_grad_parameters,
- double_backwards_fn, symbolic_fn)
- scope_dict[class_name] = cls
- scope_dict[backward_cls.__name__] = backward_cls
- if not class_name.startswith('_'):
- _all_functions.append(cls)
- _all_functions.append(backward_cls)
-
-
-_generate_function_classes(locals())
diff --git a/torch/nn/_functions/thnn/auto_double_backwards.py b/torch/nn/_functions/thnn/auto_double_backwards.py
deleted file mode 100644
index 54ce0b5..0000000
--- a/torch/nn/_functions/thnn/auto_double_backwards.py
+++ /dev/null
@@ -1,256 +0,0 @@
-import torch
-
-
-def elu_double_backwards(ctx, ggI):
- t = ctx.saved_tensors
- input, grad_output = t[0], t[1]
- alpha = ctx.additional_args[0]
-
- negative_mask = (input < 0).type_as(ggI)
- exp_alpha = input.exp() * alpha * negative_mask
- gI = ggI * grad_output * exp_alpha
-
- non_negative_mask = (input >= 0).type_as(ggI)
- ggO = ggI * (exp_alpha + non_negative_mask)
- return gI, ggO, None, None, None, None
-
-
-def gatedlinear_double_backwards(ctx, ggI):
- input, gO = ctx.saved_tensors
- dim = ctx.additional_args[0]
-
- input_size = input.size(dim) // 2
-
- first_half = input.narrow(dim, 0, input_size)
- second_half = input.narrow(dim, input_size, input_size)
- sig_second_half = second_half.sigmoid()
- one_sub_sig_second_half = 1 - sig_second_half
- sig_one_sub_sig = sig_second_half * one_sub_sig_second_half
-
- ggI_first_half = ggI.narrow(dim, 0, input_size)
- ggI_second_half = ggI.narrow(dim, input_size, input_size)
- ggI_second_half_times_first_half = ggI_second_half * first_half
-
- gI_first_half = ggI_second_half * gO * sig_one_sub_sig
- second_order_sh = sig_one_sub_sig * one_sub_sig_second_half - sig_second_half * sig_one_sub_sig
- gI_second_half = ggI_second_half_times_first_half * gO * second_order_sh + ggI_first_half * gO * sig_one_sub_sig
- gI = torch.cat((gI_first_half, gI_second_half), dim)
-
- ggO = ggI_first_half * sig_second_half + ggI_second_half_times_first_half * sig_one_sub_sig
-
- return gI, ggO, None, None, None
-
-
-def hardshrink_double_backwards(ctx, ggI):
- t = ctx.saved_tensors
- input = t[0]
- lambd = ctx.additional_args[0]
- gI = None
-
- mask = torch.zeros_like(input).masked_fill_(input > lambd, 1).masked_fill_(input < -lambd, 1)
- ggO = ggI * mask
-
- return gI, ggO, None, None, None
-
-
-def hardtanh_double_backwards(ctx, ggI):
- t = ctx.saved_tensors
- input, grad_output = t[0], t[1]
- min_val, max_val = ctx.additional_args[0:2]
-
- max_mask = input <= max_val
- min_mask = input <= min_val
- gI = torch.zeros_like(ggI)
- ggO = ggI * (max_mask - min_mask).type_as(grad_output)
- return gI, ggO, None, None, None
-
-
-def leakyrelu_double_backwards(ctx, ggI):
- t = ctx.saved_tensors
- input = t[0]
- negative_slope = ctx.additional_args[0]
-
- gI = torch.zeros_like(ggI)
- input_lt_0 = (input < 0).type_as(ggI)
- input_ge_0 = (input >= 0).type_as(ggI)
- ggO = ggI * (input_lt_0 * negative_slope + input_ge_0)
- return gI, ggO, None, None, None
-
-
-def logsigmoid_double_backwards(ctx, ggI):
- t = ctx.saved_tensors
- # maybe more efficient in terms of output, but save_output is False
- input, gO = t[0], t[1]
-
- exp_input = input.exp()
- exp_input_plus_1 = exp_input + 1
- gI = ggI * gO * -1 * exp_input / (exp_input_plus_1.pow(2))
- ggO = ggI / exp_input_plus_1
-
- return gI, ggO, None, None, None, None
-
-
-def softplus_double_backwards(ctx, ggI):
- t = ctx.saved_tensors
- input, gO, output = t[0], t[1], t[2]
- beta, threshold = ctx.additional_args[0], ctx.additional_args[1]
-
- input_beta = input * beta
- above_threshold = torch.zeros_like(ggI).masked_fill_(input_beta > threshold, 1)
- below_threshold = torch.zeros_like(ggI).masked_fill_(input_beta <= threshold, 1)
-
- exp_output_beta = (output * beta).exp()
- first_deriv = (exp_output_beta - 1) / exp_output_beta
- first_deriv_below_threshold = first_deriv * below_threshold
-
- gI = ggI * gO * first_deriv_below_threshold * beta / exp_output_beta
- ggO = ggI * (above_threshold + first_deriv_below_threshold)
-
- return gI, ggO, None, None, None, None
-
-
-def softshrink_double_backwards(ctx, ggI):
- return hardshrink_double_backwards(ctx, ggI)
-
-
-def threshold_double_backwards(ctx, ggI):
- t = ctx.saved_tensors
- input = t[0]
- threshold, value = ctx.additional_args[0:2]
-
- gI = torch.zeros_like(ggI)
- input_gt_threshold = (input > threshold).type_as(ggI)
- ggO = ggI * input_gt_threshold
- return gI, ggO, None, None, None
-
-
-def klddivloss_double_backwards(ctx, ggI):
- size_average = ctx.additional_args[0]
- input, target, gO = ctx.saved_tensors
- div_factor = input.nelement() if size_average else 1
-
- gI = None
- ggO = (ggI * target).sum() / -div_factor
-
- return gI, None, ggO, None, None
-
-
-def l1loss_double_backwards(ctx, ggI):
- size_average = ctx.additional_args[0]
- input, target, grad_output = ctx.saved_tensors
- gI = torch.zeros_like(ggI)
-
- positive_mask = (input > target).type_as(ggI)
- negative_mask = (input < target).type_as(ggI)
- ggO = (ggI * (positive_mask - negative_mask)).sum()
- if size_average:
- ggO = ggO / input.nelement()
- return gI, None, ggO, None, None
-
-
-def mseloss_double_backwards(ctx, ggI):
- size_average = ctx.additional_args[0]
- reduce = ctx.additional_args[1]
- input, target, gO = ctx.saved_tensors
- div_factor = input.nelement() if size_average and reduce else 1
-
- gI = ggI * (gO * 2. / div_factor).expand_as(input)
- if reduce:
- ggO = (ggI * (input - target)).sum() * (2. / div_factor)
- else:
- ggO = (ggI * (input - target)) * 2.
-
- return gI, None, ggO, None, None
-
-
-def nllloss_double_backwards(ctx, ggI):
- t = ctx.saved_tensors
- target = t[1]
- weights = ctx.additional_args[1]
- size_average = ctx.additional_args[0]
- ignore_index = ctx.additional_args[3]
- reduce = ctx.additional_args[4]
-
- gI = None
-
- # can't scatter/gather on indices outside of range, let's just put them in range
- # and 0 out the weights later (so it doesn't matter where in range we put them)
- target_mask = target == ignore_index
- safe_target = target.clone()
- safe_target.masked_fill_(target_mask, 0)
-
- if weights.dim() == 0:
- weights_to_scatter = torch.ones_like(safe_target)
- else:
- weights_maybe_resized = weights
- while weights_maybe_resized.dim() < target.dim():
- weights_maybe_resized = weights_maybe_resized.unsqueeze(1)
-
- weights_maybe_resized = weights_maybe_resized.expand(weights.size()[0:1] + target.size()[1:])
- weights_to_scatter = weights_maybe_resized.gather(0, safe_target)
-
- weights_to_scatter.masked_fill_(target_mask, 0)
- divisor = weights_to_scatter.sum() if size_average and reduce else 1
- weights_to_scatter = -1 * weights_to_scatter / divisor
- zeros = torch.zeros_like(ggI)
- mask = zeros.scatter_(1, safe_target.unsqueeze(1), weights_to_scatter.unsqueeze(1))
-
- if reduce:
- ggO = (ggI * mask).sum()
- else:
- ggO = (ggI * mask).sum(dim=1)
-
- return gI, None, ggO, None, None, None
-
-
-def smoothl1loss_double_backwards(ctx, ggI):
- size_average = ctx.additional_args[0]
- input, target, gO = ctx.saved_tensors
- div_factor = input.nelement() if size_average else 1
-
- input_sub_target = input - target
- small_error_mask = (input_sub_target.abs() < 1)
- large_error_mask = (small_error_mask == 0)
- large_error_pos_mask = (((input_sub_target > 0) + large_error_mask) == 2).type_as(ggI)
- large_error_neg_mask = (((input_sub_target <= 0) + large_error_mask) == 2).type_as(ggI)
- small_error_mask = small_error_mask.type_as(ggI)
-
- gI = small_error_mask * ggI * gO / div_factor
- ggO = (ggI * (input_sub_target * small_error_mask + large_error_pos_mask - large_error_neg_mask)).sum() / div_factor
-
- return gI, None, ggO, None, None, None
-
-
-def softmarginloss_double_backwards(ctx, ggI):
- size_average = ctx.additional_args[0]
- input, target, gO = ctx.saved_tensors
- div_factor = input.nelement() if size_average else 1
-
- t0 = (1 + (-target * input).exp()).pow(-1)
- t1 = (-target * (-target * input).exp())
- first_deriv = t0 * t1
-
- gI = -1 * gO * ggI / div_factor * (first_deriv.pow(2) + first_deriv * target)
- ggO = (ggI * first_deriv).sum() / div_factor
-
- return gI, None, ggO, None, None, None
-
-
-double_backwards_fns = {
- 'ELU': elu_double_backwards,
- 'GatedLinear': gatedlinear_double_backwards,
- 'Hardshrink': hardshrink_double_backwards,
- 'Hardtanh': hardtanh_double_backwards,
- 'LeakyReLU': leakyrelu_double_backwards,
- 'LogSigmoid': logsigmoid_double_backwards,
- 'Softplus': softplus_double_backwards,
- 'Softshrink': softshrink_double_backwards,
- 'Threshold': threshold_double_backwards,
- 'KLDivLoss': klddivloss_double_backwards,
- 'L1Loss': l1loss_double_backwards,
- 'MSELoss': mseloss_double_backwards,
- 'NLLLoss': nllloss_double_backwards,
- 'NLLLoss2d': nllloss_double_backwards,
- 'SmoothL1Loss': smoothl1loss_double_backwards,
- 'SoftMarginLoss': softmarginloss_double_backwards,
-}
diff --git a/torch/nn/_functions/thnn/auto_symbolic.py b/torch/nn/_functions/thnn/auto_symbolic.py
deleted file mode 100644
index 4bbad81..0000000
--- a/torch/nn/_functions/thnn/auto_symbolic.py
+++ /dev/null
@@ -1,22 +0,0 @@
-from torch.autograd._functions.utils import prepare_onnx_paddings
-
-
-def reflectionpad_symbolic(g, input, *params):
- mode = "reflect"
- paddings = prepare_onnx_paddings(len(input.type().sizes()), params)
- return g.op("Pad", input, pads_i=paddings, mode_s=mode)
-
-
-def replicationpad_symbolic(g, input, *params):
- mode = "edge"
- paddings = prepare_onnx_paddings(len(input.type().sizes()), params)
- return g.op("Pad", input, pads_i=paddings, mode_s=mode)
-
-
-symbolic_fns = {
- 'ReflectionPad1d': reflectionpad_symbolic,
- 'ReflectionPad2d': reflectionpad_symbolic,
- 'ReplicationPad1d': replicationpad_symbolic,
- 'ReplicationPad2d': replicationpad_symbolic,
- 'ReplicationPad3d': replicationpad_symbolic,
-}