| # Owner(s): ["module: onnx"] | |
| import torch | |
# Autograd function that replicates the autograd function defined in
# test_utility_funs.py (test_autograd_module_name).
class CustomFunction(torch.autograd.Function):
    """ReLU-like autograd function with a hand-written backward pass.

    The forward pass clamps the input from below at 0. The backward pass
    lets the incoming gradient through unchanged, except where the saved
    input was strictly negative; there the gradient is zeroed.

    Intended to be a replica of the autograd function used by
    ``test_autograd_module_name`` in ``test_utility_funs.py``.
    """

    @staticmethod
    def forward(ctx, input):
        """Stash ``input`` for backward and return ``input.clamp(min=0)``."""
        ctx.save_for_backward(input)
        clamped = input.clamp(min=0)
        return clamped

    @staticmethod
    def backward(ctx, grad_output):
        """Return a copy of ``grad_output`` zeroed where the input was < 0.

        Positions where the input was exactly 0 keep their gradient, because
        only strictly negative inputs are masked.
        """
        saved_input, = ctx.saved_tensors
        negative_mask = saved_input < 0
        # Clone so the caller's gradient tensor is never modified in place.
        grad_input = grad_output.clone()
        grad_input[negative_mask] = 0
        return grad_input