// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <ATen/functorch/DynamicLayer.h>
#include <ATen/functorch/TensorWrapper.h>
#include <ATen/functorch/BatchedTensorImpl.h>
#include <ATen/functorch/BatchRulesHelper.h>
#include <torch/library.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <c10/util/irange.h>
#include <ATen/FuncTorchTLS.h>
#include <iostream>
namespace at {
namespace functorch {
void setDynamicLayerFrontBackKeysIncluded(bool included) {
c10::impl::tls_set_dispatch_key_included(DispatchKey::FuncTorchDynamicLayerFrontMode, included);
c10::impl::tls_set_dispatch_key_included(DispatchKey::FuncTorchDynamicLayerBackMode, included);
}
DynamicLayer::DynamicLayer(
TransformType transform_type,
int64_t layerId,
optional<int64_t> batchSize,
optional<RandomnessType> randomness,
optional<bool> prev_grad_mode,
optional<bool> prev_fwd_grad_mode,
optional<bool> functionalize_add_back_views)
{
if (transform_type == TransformType::Grad) {
TORCH_INTERNAL_ASSERT(prev_grad_mode.has_value());
}
if (transform_type == TransformType::Jvp) {
TORCH_INTERNAL_ASSERT(prev_fwd_grad_mode.has_value());
}
switch (transform_type) {
case TransformType::Vmap:
interpreter_ = Interpreter::Vmap(layerId, batchSize.value(), randomness.value());
break;
case TransformType::Grad:
interpreter_ = Interpreter::Grad(layerId, prev_grad_mode.value());
break;
case TransformType::Jvp:
interpreter_ = Interpreter::Jvp(layerId, prev_fwd_grad_mode.value());
break;
case TransformType::Functionalize:
interpreter_ = Interpreter::Functionalize(layerId, functionalize_add_back_views.value());
break;
default:
TORCH_INTERNAL_ASSERT(false);
}
}
TransformType DynamicLayer::key() const {
return interpreter_.key();
}
int64_t DynamicLayer::layerId() const {
return interpreter_.level();
}
int64_t DynamicLayer::batchSize() const {
return VmapInterpreterPtr(&interpreter_).batchSize();
}
RandomnessType DynamicLayer::randomness() const {
return VmapInterpreterPtr(&interpreter_).randomness();
}
// Maps level to life handle, see NOTE: [Life handles and lexically scoped transforms]
// for details
using DynmetaData = std::unordered_map<int64_t, std::shared_ptr<bool>>;
DynmetaData kDynMetaDataSingleton;
static DynmetaData& getGlobalDynmetaData() {
return kDynMetaDataSingleton;
}
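// Illustrative sketch (not an exact call site): a TensorWrapper created at
// `level` holds the shared_ptr<bool> returned by getLifeHandleForLevel(level).
// popDynamicLayerAndDeleteMetadata() flips the bool to false when the
// transform exits, which is how code like unwrapIfDead() below can tell that
// a wrapper has outlived its transform:
//
//   auto handle = getLifeHandleForLevel(level); // *handle == true while alive
//   ...
//   if (!*handle) { /* level was popped; treat the wrapper as dead */ }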
// functorch stores some thread-local state (TLS), including the stack of transforms.
// Unfortunately, since functorch isn't part of libtorch, we need a level of
// indirection: FuncTorchTLSBase is the interface that lives in libtorch,
// while FuncTorchTLS implements all the methods and stores the data.
//
// TODO: after functorch C++ code is moved into PyTorch, we can get rid of
// this layer of indirection.
class FuncTorchTLS : public FuncTorchTLSBase {
public:
FuncTorchTLS() {}
std::unique_ptr<FuncTorchTLSBase> deepcopy() const override {
auto result = std::make_unique<FuncTorchTLS>();
result->dynamicLayerStack = dynamicLayerStack;
return result;
}
int64_t checkSupportsAutogradFunction() const override {
TORCH_CHECK(dynamicLayerStack.size() == 0,
"functorch functions (vmap, grad, vjp, etc.) currently do not support the use of autograd.Function. ",
"Please rewrite your function to not use autograd.Function while we work on fixing this");
return 0;
}
void checkSupportsInplaceRequiresGrad() const override {
TORCH_CHECK(dynamicLayerStack.size() == 0 || allow_inplace_requires_grad_,
"You are attempting to call Tensor.requires_grad_() (or perhaps using ",
"torch.autograd.functional.* APIs) inside of a function being transformed ",
"by a functorch transform. ",
"This is unsupported, please attempt to use the functorch transforms ",
"(e.g. grad, vjp, jacrev, jacfwd, hessian) or call requires_grad_() "
"outside of a function being transformed instead.");
}
void checkSupportsRetainGrad() const override {
TORCH_CHECK(dynamicLayerStack.size() == 0,
"You are attempting to call Tensor.retain_grad() ",
"inside of a function being transformed ",
"by a functorch transform. ",
"This is unsupported, please attempt to use the functorch transforms ",
"(e.g. grad, vjp, jacrev, jacfwd, hessian) or call retain_grad() "
"outside of a function being transformed instead.");
}
std::vector<DynamicLayer> dynamicLayerStack;
bool allow_inplace_requires_grad_ = false;
};
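// Rough sketch of the indirection described above (assumed call shape, not a
// verbatim libtorch call site): libtorch only sees the FuncTorchTLSBase
// interface and reaches this implementation through the accessor declared in
// ATen/FuncTorchTLS.h, e.g.
//
//   auto& tls = functorchTLSAccessor(); // std::unique_ptr<FuncTorchTLSBase>&
//   if (tls) {
//     tls->checkSupportsRetainGrad(); // virtual dispatch into FuncTorchTLS
//   }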
static FuncTorchTLS* getRawFunctorchTLS() {
auto& state = functorchTLSAccessor();
if (state == nullptr) {
state = std::make_unique<FuncTorchTLS>();
}
// Raw pointer usage OK, `state` keeps the pointer alive
FuncTorchTLSBase* raw_state = state.get();
FuncTorchTLS* result = static_cast<FuncTorchTLS*>(raw_state);
return result;
}
void setInplaceRequiresGradAllowed(bool allowed) {
auto* functorch_tls = getRawFunctorchTLS();
functorch_tls->allow_inplace_requires_grad_ = allowed;
}
bool getInplaceRequiresGradAllowed() {
auto* functorch_tls = getRawFunctorchTLS();
return functorch_tls->allow_inplace_requires_grad_;
}
static std::vector<DynamicLayer>& dynamicLayerStackAccessor() {
return getRawFunctorchTLS()->dynamicLayerStack;
}
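// Returns the life handle for `level`. The level must still be alive, i.e. its
// entry in the global metadata map must not have been deleted yet.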
std::shared_ptr<bool> getLifeHandleForLevel(int64_t level) {
auto it = getGlobalDynmetaData().find(level);
TORCH_INTERNAL_ASSERT(it != kDynMetaDataSingleton.end(), "level should be alive");
return it->second;
}
optional<DynamicLayer> maybeCurrentDynamicLayer() {
auto& dynamicLayerStack = dynamicLayerStackAccessor();
if (dynamicLayerStack.size() == 0) {
return {};
}
return dynamicLayerStack.back();
}
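// RAII guard used by dynamicLayerFrontFallback below: the constructor stashes
// the current LocalDispatchKeySet into the interpreter of the top
// DynamicLayer; the destructor restores that saved key set into TLS and clears
// it from the interpreter.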
struct SaveLocalDispatchKeySet {
public:
SaveLocalDispatchKeySet() {
auto& dynamicLayerStack = dynamicLayerStackAccessor();
TORCH_INTERNAL_ASSERT(dynamicLayerStack.size() > 0);
auto& layer = dynamicLayerStack.back();
auto tmp = c10::impl::tls_local_dispatch_key_set();
layer.interpreter().saveLocalDispatchKeySet(tmp);
}
~SaveLocalDispatchKeySet() {
auto& dynamicLayerStack = dynamicLayerStackAccessor();
TORCH_INTERNAL_ASSERT(dynamicLayerStack.size() > 0);
auto& layer = dynamicLayerStack.back();
auto tmp = layer.interpreter().getSavedLocalDispatchKeySet();
layer.interpreter().clearSavedLocalDispatchKeySet();
c10::impl::_force_tls_local_dispatch_key_set(tmp);
}
SaveLocalDispatchKeySet(const SaveLocalDispatchKeySet&) = delete;
SaveLocalDispatchKeySet& operator=(const SaveLocalDispatchKeySet&) = delete;
};
const std::vector<DynamicLayer>& getDynamicLayerStack() {
return dynamicLayerStackAccessor();
}
void setDynamicLayerStack(const std::vector<DynamicLayer>& stack) {
dynamicLayerStackAccessor() = stack;
}
bool areTransformsActive() {
const auto& data = getGlobalDynmetaData();
return !data.empty();
}
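// push/popDynamicLayer maintain the invariant that the
// FuncTorchDynamicLayer{Front,Back}Mode keys are included in the thread-local
// dispatch key set if and only if the dynamic layer stack is non-empty.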
static DynamicLayer popDynamicLayer() {
auto& dynamicLayerStack = dynamicLayerStackAccessor();
TORCH_INTERNAL_ASSERT(dynamicLayerStack.size() > 0);
auto result = dynamicLayerStack.back();
dynamicLayerStack.pop_back();
if (dynamicLayerStack.size() == 0) {
#ifdef HAS_TORCH_SHOW_DISPATCH_TRACE
if (c10::show_dispatch_trace_enabled()) {
std::cout << "DynamicLayer off" << std::endl;
}
#endif
setDynamicLayerFrontBackKeysIncluded(false);
}
return result;
}
static int64_t pushDynamicLayer(DynamicLayer&& dynamic_layer) {
auto& dynamicLayerStack = dynamicLayerStackAccessor();
int64_t layerId = 1 + dynamicLayerStack.size();
TORCH_INTERNAL_ASSERT(layerId == dynamic_layer.layerId());
dynamicLayerStack.emplace_back(dynamic_layer);
if (layerId == 1) {
setDynamicLayerFrontBackKeysIncluded(true);
#ifdef HAS_TORCH_SHOW_DISPATCH_TRACE
if (c10::show_dispatch_trace_enabled()) {
std::cout << "DynamicLayer on" << std::endl;
}
#endif
}
return layerId;
}
int64_t initAndPushDynamicLayer(
TransformType transform_type,
optional<int64_t> batch_size,
optional<RandomnessType> randomness,
optional<bool> prev_grad_mode,
optional<bool> prev_fwd_grad_mode,
optional<bool> functionalize_add_back_views) {
const auto& dynamicLayerStack = dynamicLayerStackAccessor();
const auto layerId = 1 + dynamicLayerStack.size();
DynamicLayer new_layer(transform_type, layerId, batch_size, randomness, prev_grad_mode, prev_fwd_grad_mode, functionalize_add_back_views);
pushDynamicLayer(std::move(new_layer));
auto& data = getGlobalDynmetaData();
TORCH_INTERNAL_ASSERT(data.find(layerId) == data.end());
if (transform_type == TransformType::Grad) {
TORCH_INTERNAL_ASSERT(prev_grad_mode.has_value());
}
if (transform_type == TransformType::Jvp) {
TORCH_INTERNAL_ASSERT(prev_fwd_grad_mode.has_value());
}
data[layerId] = std::make_shared<bool>(true);
return layerId;
}
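// Illustrative lifetime (rough sketch; the real call sites live in the Python
// bindings for the transforms): a transform calls initAndPushDynamicLayer()
// on entry and the matching popDynamicLayerAndDeleteMetadata() on exit, e.g.
//
//   auto level = initAndPushDynamicLayer(
//       TransformType::Grad, nullopt, nullopt,
//       /*prev_grad_mode=*/c10::GradMode::is_enabled());
//   // ... run the transformed function ...
//   popDynamicLayerAndDeleteMetadata();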
DynamicLayer popDynamicLayerAndDeleteMetadata() {
auto result = popDynamicLayer();
auto level = result.layerId();
// TODO: is this lock safe? No one else should be writing to the same bucket
auto& data = getGlobalDynmetaData();
auto it = data.find(level);
if (it == data.end()) {
return result;
}
// Invalidate the life handle so that TensorWrappers created at this level
// are treated as dead from now on.
*(it->second) = false;
data.erase(level);
return result;
}
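// A TensorWrapper becomes "dead" once the transform that created it exits
// (its life handle is flipped to false). Dead wrappers are transparent: return
// the underlying tensor instead of the wrapper.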
Tensor unwrapIfDead(const Tensor& tensor) {
auto* wrapped = maybeGetTensorWrapper(tensor);
if (!wrapped) {
return tensor;
}
if (wrapped->is_alive()) {
return tensor;
}
return wrapped->value();
}
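// Applies `func`, in place, to every Tensor in args[begin, end), including
// Tensors nested inside Tensor?[] arguments (which arrive as a c10::List of
// IValues) and TensorList arguments.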
void foreachTensorInplace(std::vector<IValue>& args, int64_t begin, int64_t end,
std::function<Tensor(const Tensor&)> func) {
TORCH_INTERNAL_ASSERT(begin >= 0);
TORCH_INTERNAL_ASSERT(end >= 0);
TORCH_INTERNAL_ASSERT(begin <= end);
for (int64_t idx = begin; idx < end; idx++) {
auto ivalue = args[idx];
// Tensor?[] translates to a c10::List<IValue> so we need to peek inside List
if (ivalue.isList()) {
bool modified = false;
// TODO: it might be more efficient to scan for Tensors first and only copy
// the list if one is found.
auto list = ivalue.toList().copy();
for (const auto list_idx : c10::irange(0, list.size())) {
const auto& elt = list.get(list_idx);
if (elt.isTensor()) {
list.set(list_idx, func(elt.toTensor()));
modified = true;
}
}
if (modified) {
args[idx] = list;
}
continue;
}
if (ivalue.isTensorList()) {
auto list = ivalue.toTensorList();
for (const auto list_idx : c10::irange(0, list.size())) {
list[list_idx] = func(list[list_idx]);
}
args[idx] = list;
}
TORCH_INTERNAL_ASSERT(!ivalue.isGenericDict(), "No operators can accept GenericDict");
if (!ivalue.isTensor()) {
continue;
}
Tensor value = ivalue.toTensor();
Tensor replacement = func(value);
args[idx] = std::move(replacement);
// sanity checks
if (ivalue.toTensor().defined()) {
TORCH_INTERNAL_ASSERT(args[idx].toTensor().defined());
}
}
}
std::ostream& operator<< (std::ostream& os, const DynamicLayer& layer) {
os << layer.layerId() << ":" << layer.key();
return os;
}
std::ostream& operator<< (std::ostream& os, const std::vector<DynamicLayer>& dls) {
os << "DynamicLayerStack[ ";
for (const auto& layer : dls) {
os << layer << " ";
}
os << "]";
return os;
}
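// Returns true for schemas of "standard" in-place ops: a single return, the
// first argument is mutated (annotated (a!)), no other argument aliases
// anything, and the mutated argument is the one being returned.
// Illustrative schemas (not checked against native_functions.yaml here):
//   add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)  -> true
//   add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)  -> false
//     (the mutated argument is not the first one)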
bool isInplaceOp(const FunctionSchema& schema) {
if (!schema.is_mutable() || schema.returns().size() != 1) {
return false;
}
// Check that the first argument is being written to
const auto& first_arg_alias_info = schema.arguments().begin()->alias_info();
if (!first_arg_alias_info || !first_arg_alias_info->isWrite()) {
return false;
}
// Check that none of the other args are being aliased
for (auto it = schema.arguments().begin() + 1; it != schema.arguments().end(); ++it) {
const auto& alias_info = it->alias_info();
if (alias_info) {
return false;
}
}
// Check that the first tensor is being returned (i.e., output has a (a!))
const auto& return_alias_info = schema.returns()[0].alias_info();
return return_alias_info && return_alias_info->isWrite();
}
#ifdef HAS_TORCH_SHOW_DISPATCH_TRACE
static void dump_local_tls() {
auto tls = c10::impl::tls_local_dispatch_key_set();
std::cout << "[Local Include] " << tls.included_ << std::endl;
std::cout << "[Local Exclude] " << tls.excluded_ << std::endl;
}
#endif
struct WithoutTop {
WithoutTop();
~WithoutTop();
DynamicLayer layer_;
};
WithoutTop::WithoutTop(): layer_(popDynamicLayer()) {}
WithoutTop::~WithoutTop() {
pushDynamicLayer(std::move(layer_));
}
// NOTE: [functorch front and back key fallbacks]
//
// Please read NOTE: [functorch interpreter stack] first for some context.
// The following doc also provides some visuals:
// https://docs.google.com/document/d/14qyaa3xIjmVxYiMLlIlQErunYgR_uR1WupsKMZlnGY4/edit
//
// functorch's "stack of transforms" is implemented as follows:
// - each transform is associated with one or more dispatch keys in the PyTorch
// dispatcher. For example, vmap -> {FuncTorchBatched, FuncTorchVmapMode},
// Autograd -> {Autograd{Backend}, ADInplaceOrView}
// - Whenever a functorch transform is active, the FuncTorchDynamicLayer{Front, Back}Mode
// keys are added to the dispatcher's local dispatch key set.
//
// DynamicLayerFrontMode is responsible for:
// 1. selecting the transform that is at the top of the stack and grabbing its
// interpreter
// 2. Calling interpreter.process(), which does the following:
// 2a. enables/disables a bunch of dispatch keys, so that the only dispatch
// keys that are enabled are the ones that belong to the transform.
// 2b. redispatching
//
// Eventually, DynamicLayerBackMode captures the redispatch from the transforms.
// DynamicLayerBackMode is responsible for:
// - redirecting back to DynamicLayerFrontMode
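//
// Worked example (rough sketch; kernel details elided): for grad(vmap(f))(x),
// the stack is [Grad (level 1), Vmap (level 2)], with Vmap on top. When an op
// is called inside f:
// 1. DynamicLayerFrontMode (dynamicLayerFrontFallback) grabs the top layer
//    (Vmap) and calls its interpreter's process(), which enables only the
//    vmap-related dispatch keys and redispatches into the batching rule.
// 2. When the batching rule redispatches on the unwrapped tensors, the call
//    hits DynamicLayerBackMode (dynamicLayerBackFallback), which temporarily
//    pops the Vmap layer and forwards the op to the next interpreter (Grad).
// 3. This repeats until the stack is exhausted, at which point the op reaches
//    the regular backend kernels.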
static void dynamicLayerFrontFallback(
const c10::OperatorHandle& op,
torch::jit::Stack* stack) {
auto& dynamicLayerStack = dynamicLayerStackAccessor();
TORCH_INTERNAL_ASSERT(dynamicLayerStack.size() > 0);
#ifdef HAS_TORCH_SHOW_DISPATCH_TRACE
if (c10::show_dispatch_trace_enabled()) {
std::cout << dynamicLayerStack << std::endl;
dump_local_tls();
}
#endif
// Save the current LocalDispatchKeySet (to the current DynamicLayer).
// Upon exiting the current scope, that LocalDispatchKeySet gets restored.
// When the current DynamicLayer dispatches to the next (inner) DynamicLayer,
// it will also temporarily restore the saved LocalDispatchKeySet.
SaveLocalDispatchKeySet guard;
// Unwrap escaped GradWrappers
auto num_args = op.schema().arguments().size();
foreachTensorInplace(*stack, stack->size() - num_args, stack->size(), unwrapIfDead);
auto& layer = dynamicLayerStack.back();
layer.interpreter().process(op, stack);
}
static c10::impl::ForceDispatchKeyGuard
restoreLocalDispatchKeySetRAII(const c10::impl::LocalDispatchKeySet& key_set) {
return c10::impl::ForceDispatchKeyGuard(key_set);
}
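// The back fallback: restore the LocalDispatchKeySet that the front fallback
// saved for the top layer, temporarily pop that layer (WithoutTop), and hand
// the op to the popped layer's interpreter, which routes it to the next layer
// down (or to the regular kernels if none remain).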
void dynamicLayerBackFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
auto& layer = dynamicLayerStackAccessor().back();
auto restore_guard = restoreLocalDispatchKeySetRAII(layer.interpreter().getSavedLocalDispatchKeySet());
WithoutTop guard;
// `layer` dangles once WithoutTop pops the stack, so use the copy of the
// popped layer that WithoutTop holds.
guard.layer_.interpreter().sendToNextInterpreter(op, stack);
}
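// Register the two fallbacks above as boxed kernels for every operator
// (the `_` namespace) under the corresponding functorch dispatch keys.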
TORCH_LIBRARY_IMPL(_, FuncTorchDynamicLayerFrontMode, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&dynamicLayerFrontFallback>());
}
TORCH_LIBRARY_IMPL(_, FuncTorchDynamicLayerBackMode, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&dynamicLayerBackFallback>());
}
} // namespace functorch
} // namespace at