blob: ad413d48c7ebcb04cbc31cdad6fe346ddf5bfd13 [file] [log] [blame]
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <torch/library.h>
#include <ATen/ATen.h>
#include <ATen/functorch/LegacyVmapTransforms.h>
#include <ATen/functorch/BatchedTensorImpl.h>
#include <ATen/functorch/PlumbingHelper.h>
#include <ATen/functorch/DynamicLayer.h>
#include <ATen/core/dispatch/Dispatcher.h>
// functorch's vmap has two Dispatch Keys that implement it:
// FuncTorchBatched and FuncTorchVmapMode. This file contains registrations for
// FuncTorchVmapMode: by default every operator falls through this key, and the
// explicit registrations below make specific random operations error out,
// because we don't support them on regular Tensors inside vmap.
namespace at {
namespace functorch {
static void unsupportedRandomOp(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
TORCH_CHECK(false, "vmap: We do not support calling out variants of random operations inside of vmap. ",
"Please use non-out variants as a workaround");
}
TORCH_LIBRARY_IMPL(_, FuncTorchVmapMode, m) {
m.fallback(torch::CppFunction::makeFallthrough());
}
static void nyiRandomOp(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
TORCH_CHECK(false, "vmap: we do not yet support ", op.schema().operator_name(),
". Please file an issue");
}
#define UNSUPPORTED_RANDOM(op) \
m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&unsupportedRandomOp>());
#define UNSUPPORTED_RANDOM2(op, overload) \
m.impl(#op"."#overload, torch::CppFunction::makeFromBoxedFunction<&unsupportedRandomOp>());
#define NYI_RANDOM(op) \
m.impl(#op, torch::CppFunction::makeFromBoxedFunction<&nyiRandomOp>());
#define NYI_RANDOM2(op, overload) \
m.impl(#op"."#overload, torch::CppFunction::makeFromBoxedFunction<&nyiRandomOp>());
TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) {
UNSUPPORTED_RANDOM2(bernoulli, out);
UNSUPPORTED_RANDOM2(rand, generator_out);
UNSUPPORTED_RANDOM2(rand, out);
UNSUPPORTED_RANDOM2(randint, generator_out);
UNSUPPORTED_RANDOM2(randint, out);
UNSUPPORTED_RANDOM2(randn, generator_out);
UNSUPPORTED_RANDOM2(randn, out);
UNSUPPORTED_RANDOM2(randperm, generator_out);
UNSUPPORTED_RANDOM2(randperm, out);
UNSUPPORTED_RANDOM2(multinomial, out);
UNSUPPORTED_RANDOM2(normal, float_Tensor_out);
UNSUPPORTED_RANDOM2(normal, Tensor_Tensor_out);
UNSUPPORTED_RANDOM2(normal, float_float_out);
UNSUPPORTED_RANDOM2(rrelu_with_noise, out);
NYI_RANDOM(rrelu_with_noise);
NYI_RANDOM(rrelu_with_noise_);
NYI_RANDOM(rrelu_);
NYI_RANDOM(rrelu);
}
}
} // namespace at