// blob: 29acc92186994dcac3c2e7dc47bd8e2e25230b2b (scrape artifact — not part of the original file)
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <torch/library.h>
#include <c10/util/irange.h>
namespace {
// Boxed fallback kernel for the Functionalize dispatch key.
//
// For ops with no alias/mutation schema info (asserted below — aliasing and
// mutating ops get codegen'd kernels instead), functionalization is a simple
// sandwich:
//   1. unwrap any FunctionalTensorWrapper inputs in-place on the stack
//      (syncing pending view/mutation updates first),
//   2. redispatch to the underlying kernel with functionalization disabled,
//   3. re-wrap the outputs.
//
// Outputs are wrapped if any input was wrapped, OR if the op took no tensor
// inputs at all (a factory function — its fresh outputs must enter the
// functional world wrapped).
void functionalizeFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatchKeySet, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  TORCH_INTERNAL_ASSERT(!schema.hasAnyAliasInfo(), "mutating and aliasing ops should all have codegen'd kernels");
  const auto num_arguments = schema.arguments().size();
  const auto arguments_begin = stack->size() - num_arguments;
  auto arguments = torch::jit::last(stack, num_arguments);
  auto any_functional_inputs = false;
  auto any_tensor_inputs = false;

  // Generic over at::Tensor, TensorList, and OptionalTensorList — the
  // functionalization impl helpers (isFunctionalTensor / sync /
  // from_functional_tensor) are overloaded for all three, so one lambda
  // replaces three copies of the same unwrap logic.
  const auto maybe_unwrap_input = [&](const auto& t, size_t idx) {
    any_tensor_inputs = true;
    if (at::functionalization::impl::isFunctionalTensor(t)) {
      any_functional_inputs = true;
      // Apply any pending view/mutation updates before stripping the wrapper,
      // so the inner kernel sees up-to-date data.
      at::functionalization::impl::sync(t);
      (*stack)[arguments_begin + idx] =
          c10::IValue(at::functionalization::impl::from_functional_tensor(t));
    }
  };

  // NOTE: c10::irange here for consistency with the returns loop below.
  for (const auto idx : c10::irange(num_arguments)) {
    const auto& ivalue = arguments[idx];
    if (ivalue.isTensor()) {
      maybe_unwrap_input(ivalue.toTensor(), idx);
    } else if (ivalue.isTensorList()) {
      maybe_unwrap_input(ivalue.toTensorList(), idx);
    } else if (ivalue.isOptionalTensorList()) {
      maybe_unwrap_input(ivalue.toOptionalTensorList(), idx);
    }
  }

  // we should wrap the output if any inputs were wrapped,
  // OR if we're hitting a factory function (with no tensor inputs)
  const auto should_wrap_outputs = !any_tensor_inputs || any_functional_inputs;
  {
    // Redispatch past the Functionalize key so we don't recurse into this
    // fallback; the guard is scoped so functionalization is re-enabled before
    // we wrap the outputs.
    at::AutoDispatchSkipFunctionalize guard;
    op.callBoxed(stack);
  }

  const auto num_returns = schema.returns().size();
  const auto returns_begin = stack->size() - num_returns;
  auto returns = torch::jit::last(stack, num_returns);
  // Hoisted: the wrap decision is loop-invariant, no need to re-test per return.
  if (should_wrap_outputs) {
    for (const auto idx : c10::irange(num_returns)) {
      const auto& ivalue = returns[idx];
      // to_functional_tensor is overloaded the same way as the unwrap helpers.
      if (ivalue.isTensor()) {
        (*stack)[returns_begin + idx] =
            c10::IValue(at::functionalization::impl::to_functional_tensor(ivalue.toTensor()));
      } else if (ivalue.isTensorList()) {
        (*stack)[returns_begin + idx] =
            c10::IValue(at::functionalization::impl::to_functional_tensor(ivalue.toTensorList()));
      } else if (ivalue.isOptionalTensorList()) {
        (*stack)[returns_begin + idx] =
            c10::IValue(at::functionalization::impl::to_functional_tensor(ivalue.toOptionalTensorList()));
      }
    }
  }
}
}
// Register functionalizeFallback as the boxed fallback for ALL operators
// ("_" namespace) under the Functionalize dispatch key. Per the assert in the
// kernel, ops with alias/mutation schema info are expected to have codegen'd
// kernels and should never reach this fallback.
TORCH_LIBRARY_IMPL(_, Functionalize, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&functionalizeFallback>());
}