| #include <ATen/autocast_mode.h> |
| |
| #include <mutex> |
| #include <ATen/CachedTensorUtils.h> |
| #include <c10/util/flat_hash_map.h> |
| |
| namespace at::autocast { |
| |
| bool is_autocast_enabled(at::DeviceType device_type) { |
| at::DispatchKey dispatch_key = get_autocast_dispatch_key_from_device_type(device_type); |
| return !c10::impl::tls_is_dispatch_key_excluded(dispatch_key); |
| } |
| |
| void set_autocast_enabled(at::DeviceType device_type, bool enabled) { |
| at::DispatchKey dispatch_key = get_autocast_dispatch_key_from_device_type(device_type); |
| c10::impl::tls_set_dispatch_key_excluded(dispatch_key, !enabled); |
| } |
| |
| namespace { |
| // Imitate Apex and cache some of the casts to streamline parameter reuse. |
| // Our heuristic is to cache lower_precision_fp casts of fp32 model weights (see cached_cast below). |
| // |
| // After discussion with @ezyang, the cache uses the following structure: |
| // The key is the fp32 source tensor's TensorImpl*, a proxy for a Tensor uuid that's |
| // unchanged across shallow copies. |
| // The value is a tuple with a weakref to the source tensor's TensorImpl as the first |
| // element and the casted tensor as the second element. |
| // |
| // The weakref keeps the source's TensorImpl from being deleted. We need to hold it |
| // because we're using the source TensorImpl* as the key. If it were deleted, another |
| // random Tensor could be allocated whose TensorImpl* happened to have the same value. |
| // This TensorImpl* would then mistakenly hit in the cache: a rare, intermittent, |
| // unpredictable bug. |
| // |
| // I'm not using the weak_intrusive_ptr as the key because it's more difficult to compare |
| // directly against incoming TensorImpl*s. |
| using weakref_type = c10::weak_intrusive_ptr<TensorImpl, UndefinedTensorImpl>; |
| using val_type = std::tuple<weakref_type, Tensor>; |
| |
| static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts() { |
| static ska::flat_hash_map<TensorImpl*, val_type> cached_casts; |
| return cached_casts; |
| } |
| std::mutex cached_casts_mutex; |
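| |
| // Illustrative sketch (not part of this file's logic; Python-level torch.autocast API |
| // assumed, tensor names hypothetical): with the cache enabled, repeated uses of the |
| // same fp32 leaf parameter inside one autocast region pay for the cast only once: |
| // |
| //   with torch.autocast("cuda", dtype=torch.float16): |
| //       y1 = torch.nn.functional.linear(x1, w)  # w cast fp32->fp16, result cached |
| //       y2 = torch.nn.functional.linear(x2, w)  # cache hit, no second cast |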
| |
| |
| // nesting tracks the nesting depth of the Python-side context manager. |
| // When the autocast context manager exits to a nesting level that's outside |
| // any instance of autocast (which should occur at the end of each forward pass), |
| // it calls clear_cache() to ensure cached Tensors don't leak outside the autocasting region. |
| thread_local int nesting = 0; |
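| |
| // A minimal sketch of the intended interaction with the Python-side context manager |
| // (torch.autocast API assumed): |
| // |
| //   with torch.autocast("cuda"):      # increment_nesting() -> 1 |
| //       with torch.autocast("cuda"):  # increment_nesting() -> 2 |
| //           ...                       # exiting inner: decrement_nesting() -> 1 |
| //   # exiting outer: decrement_nesting() -> 0, and the caller invokes clear_cache() |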
| |
| // The order of this array MUST exactly match the definition order of DeviceType |
| // in c10/core/DeviceType.h. |
| static_assert( |
| at::COMPILE_TIME_MAX_DEVICE_TYPES == 21, |
| "The number of default autocast dtypes below doesn't match the number of device types; update autocast_dtype whenever DeviceType changes."); |
| thread_local std::array<at::ScalarType, at::COMPILE_TIME_MAX_DEVICE_TYPES> |
| autocast_dtype = { |
| at::kBFloat16, // CPU |
| at::kHalf, // CUDA |
| at::ScalarType::Undefined, // Reserved for explicit MKLDNN |
| at::ScalarType::Undefined, // OpenGL |
| at::ScalarType::Undefined, // OpenCL |
| at::ScalarType::Undefined, // IDEEP |
| at::kHalf, // AMD HIP |
| at::ScalarType::Undefined, // FPGA |
| at::ScalarType::Undefined, // ONNX Runtime / Microsoft |
| at::kBFloat16, // XLA / TPU |
| at::ScalarType::Undefined, // Vulkan |
| at::ScalarType::Undefined, // Metal |
| at::kHalf, // XPU |
| at::ScalarType::Undefined, // MPS |
| at::ScalarType::Undefined, // Meta (tensors with no data) |
| at::kBFloat16, // HPU / HABANA |
| at::ScalarType::Undefined, // SX-Aurora / NEC |
| at::ScalarType::Undefined, // Lazy Tensors |
| at::kHalf, // Graphcore IPU |
| at::ScalarType::Undefined, // Meta training and inference devices |
| at::kHalf, // PrivateUse1 device |
| }; |
| |
| // Controls whether the cast cache above is used inside autocast. |
| thread_local bool cache_enabled = true; |
| |
| } // anonymous namespace |
| |
| void clear_cache() { |
| const std::lock_guard<std::mutex> lock(cached_casts_mutex); |
| get_cached_casts().clear(); |
| } |
| |
| int increment_nesting() { |
| return ++nesting; |
| } |
| |
| int decrement_nesting() { |
| return --nesting; |
| } |
| |
| at::ScalarType get_autocast_dtype(at::DeviceType device_type) { |
| return autocast_dtype[static_cast<int>(device_type)]; |
| } |
| |
| void set_autocast_dtype(at::DeviceType device_type, at::ScalarType dtype) { |
| autocast_dtype[static_cast<int>(device_type)] = dtype; |
| } |
| |
| bool is_autocast_cache_enabled() { |
| return cache_enabled; |
| } |
| |
| void set_autocast_cache_enabled(bool enabled) { |
| cache_enabled = enabled; |
| } |
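| |
| // A minimal C++ usage sketch of the thread-local state managed by the accessors above |
| // (the device type and dtype chosen here are purely illustrative, not defaults): |
| // |
| //   at::autocast::set_autocast_enabled(at::kCUDA, true); |
| //   at::autocast::set_autocast_dtype(at::kCUDA, at::kBFloat16); |
| //   TORCH_INTERNAL_ASSERT(at::autocast::is_autocast_enabled(at::kCUDA)); |
| //   // ... run ops that dispatch through the Autocast key ... |
| //   at::autocast::set_autocast_enabled(at::kCUDA, false); |
| //   at::autocast::clear_cache(); |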
| |
| // Overload to catch Tensor args |
| // TODO (possible optimization): |
| // Move cached_cast to an inline function in a header with cached_casts declared as |
| // extern thread_local in the header. |
| Tensor cached_cast(at::ScalarType to_type, const Tensor& arg, DeviceType device_type) { |
| if (is_eligible(arg, device_type) && (arg.scalar_type() != to_type)) { |
| // Heuristic: Do what Apex does, and cache lower_precision_fp casts of fp32 model weights (leaves). |
| // See cached_casts declaration above for detailed strategy. |
| bool can_try_cache = (to_type == get_lower_precision_fp_from_device_type(device_type) && |
| arg.scalar_type() == at::kFloat && arg.requires_grad() && |
| arg.is_leaf() && !arg.is_view() && cache_enabled && |
| !at::caching::is_cached_tensor(arg)); |
| |
| if (can_try_cache) { |
| const std::lock_guard<std::mutex> lock(cached_casts_mutex); |
| auto it = get_cached_casts().find(arg.unsafeGetTensorImpl()); |
| if (it != get_cached_casts().end()) { |
| return std::get<1>(it->second); |
| } else { |
| auto casted_arg = arg.to(to_type); |
| get_cached_casts().emplace(arg.unsafeGetTensorImpl(), val_type{weakref_type(arg.getIntrusivePtr()), casted_arg}); |
| return casted_arg; |
| } |
| } else { |
| return arg.to(to_type); |
| } |
| } else { |
| return arg; |
| } |
| } |
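| |
| // Hedged illustration of the heuristic above: an fp32, requires_grad, non-view leaf |
| // (a typical model weight) is cast once per autocast region and cached; activations |
| // (non-leaves) and tensors already in the target dtype are never cached, they are |
| // simply cast on the fly (or returned unchanged). |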
| |
| /******************************* |
| Banned functions |
| *******************************/ |
| |
| static Tensor binary_cross_entropy_banned(const Tensor &, const Tensor &, const std::optional<Tensor>&, int64_t) { |
| AT_ERROR("torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.\n" |
| "Many models use a sigmoid layer right before the binary cross entropy layer.\n" |
| "In this case, combine the two layers using torch.nn.functional.binary_cross_entropy_with_logits\n" |
| "or torch.nn.BCEWithLogitsLoss. binary_cross_entropy_with_logits and BCEWithLogits are\n" |
| "safe to autocast."); |
| } |
| |
| namespace { |
| |
| /***************************************** |
| Explicit registration for out-of-place ops |
| *****************************************/ |
| |
| TORCH_LIBRARY_IMPL(_, Autocast, m) { |
| m.fallback(torch::CppFunction::makeFallthrough()); |
| } |
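| |
| // Ops without an explicit Autocast kernel hit the fallthrough above: they simply |
| // redispatch below the Autocast key and run in their inputs' dtypes. |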
| |
| TORCH_LIBRARY_IMPL(aten, Autocast, m) { |
| // lower_precision_fp |
| #define _KERNEL_CUDA_LOW_PRECISION_FP(...) \ |
| KERNEL_CUDA(__VA_ARGS__, lower_precision_fp) |
| |
| AT_FORALL_LOWER_PRECISION_FP(_KERNEL_CUDA_LOW_PRECISION_FP) |
| KERNEL_CUDA(cudnn_convolution, lower_precision_fp) |
| KERNEL_CUDA(cudnn_convolution_transpose, lower_precision_fp) |
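| |
| // A hedged note on the registrations in this block: KERNEL_CUDA(op[, overload], policy), |
| // defined in ATen/autocast_mode.h, roughly registers an Autocast kernel for |
| // "aten::op[.overload]" that casts eligible Tensor args according to the named cast |
| // policy (here lower_precision_fp, the device's autocast dtype) and then redispatches. |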
| |
| // fp32 |
| #define _KERNEL_CUDA_FP32(...) KERNEL_CUDA(__VA_ARGS__, fp32) |
| |
| AT_FORALL_FP32(_KERNEL_CUDA_FP32) |
| |
| // fp32_set_opt_dtype |
| #define _KERNEL_CUDA_FP32_SET_OPT_DTYPE(...) \ |
| KERNEL_CUDA(__VA_ARGS__, fp32_set_opt_dtype) |
| |
| AT_FORALL_FP32_SET_OPT_DTYPE(_KERNEL_CUDA_FP32_SET_OPT_DTYPE) |
| // These are commented out because they accept an explicit (non-optional) dtype, which |
| // we shouldn't try to override even when autocasting. |
| // KERNEL_CUDA(norm, ScalarOpt_dtype, fp32_set_opt_dtype) |
| // KERNEL_CUDA(norm, ScalarOpt_dim_dtype, fp32_set_opt_dtype) |
| // KERNEL_CUDA(norm, names_ScalarOpt_dim_dtype, fp32_set_opt_dtype) |
| |
| // fp32_append_dtype |
| // The fp32_append_dtype wrapper overrides implicit promotion behavior. |
| // norm does not implicitly promote, but be aware of this when adding new ops to this policy. |
| AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE( |
| KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CUDA) |
| |
| // promote |
| #define _KERNEL_CUDA_PROMOTE(...) KERNEL_CUDA(__VA_ARGS__, promote) |
| |
| AT_FORALL_PROMOTE(_KERNEL_CUDA_PROMOTE) |
| |
| m.impl(TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"), |
| TORCH_FN((&at::autocast::binary_cross_entropy_banned))); |
| } |
| |
| TORCH_LIBRARY_IMPL(_, AutocastCPU, m) { |
| m.fallback(torch::CppFunction::makeFallthrough()); |
| } |
| |
| |
| TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) { |
| // lower_precision_fp cast policy |
| KERNEL_CPU(conv1d, lower_precision_fp) |
| KERNEL_CPU(conv1d, padding, lower_precision_fp) |
| KERNEL_CPU(conv2d, lower_precision_fp) |
| KERNEL_CPU(conv2d, padding, lower_precision_fp) |
| KERNEL_CPU(conv3d, lower_precision_fp) |
| KERNEL_CPU(conv3d, padding, lower_precision_fp) |
| KERNEL_CPU(bmm, lower_precision_fp) |
| KERNEL_CPU(mm, lower_precision_fp) |
| KERNEL_CPU(linalg_vecdot, lower_precision_fp) |
| KERNEL_CPU(baddbmm, lower_precision_fp) |
| KERNEL_CPU(addmm, lower_precision_fp) |
| KERNEL_CPU(addbmm, lower_precision_fp) |
| KERNEL_CPU(linear, lower_precision_fp) |
| KERNEL_CPU(_convolution, deprecated, lower_precision_fp) |
| KERNEL_CPU(matmul, lower_precision_fp) |
| KERNEL_CPU(conv_tbc, lower_precision_fp) |
| KERNEL_CPU(mkldnn_rnn_layer, lower_precision_fp) |
| KERNEL_CPU(conv_transpose1d, lower_precision_fp) |
| KERNEL_CPU(conv_transpose2d, input, lower_precision_fp) |
| KERNEL_CPU(conv_transpose3d, input, lower_precision_fp) |
| KERNEL_CPU(prelu, lower_precision_fp) |
| KERNEL_CPU(scaled_dot_product_attention, lower_precision_fp) |
| KERNEL_CPU(_native_multi_head_attention, lower_precision_fp) |
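| |
| // Note on the KERNEL_CPU entries in this block: the two-argument form |
| // KERNEL_CPU(op, policy) targets the default "aten::op" overload, while the |
| // three-argument form KERNEL_CPU(op, overload, policy) targets "aten::op.overload" |
| // (e.g. KERNEL_CPU(conv1d, padding, ...) registers aten::conv1d.padding). |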
| |
| // fp32 cast policy |
| KERNEL_CPU(avg_pool3d, fp32) |
| KERNEL_CPU(binary_cross_entropy, fp32) |
| KERNEL_CPU(grid_sampler, fp32) |
| KERNEL_CPU(polar, fp32) |
| KERNEL_CPU(prod, fp32) |
| KERNEL_CPU(prod, dim_int, fp32) |
| KERNEL_CPU(prod, dim_Dimname, fp32) |
| KERNEL_CPU(quantile, fp32) |
| KERNEL_CPU(quantile, scalar, fp32) |
| KERNEL_CPU(nanquantile, fp32) |
| KERNEL_CPU(nanquantile, scalar, fp32) |
| KERNEL_CPU(stft, fp32) |
| KERNEL_CPU(stft, center, fp32) |
| KERNEL_CPU(cdist, fp32) |
| KERNEL_CPU(grid_sampler_2d, fp32) |
| KERNEL_CPU(_grid_sampler_2d_cpu_fallback, fp32) |
| KERNEL_CPU(grid_sampler_3d, fp32) |
| KERNEL_CPU(trace, fp32) |
| KERNEL_CPU(view_as_complex, fp32) |
| KERNEL_CPU(cholesky, fp32) |
| KERNEL_CPU(cholesky_inverse, fp32) |
| KERNEL_CPU(cholesky_solve, fp32) |
| KERNEL_CPU(inverse, fp32) |
| KERNEL_CPU(lu_solve, fp32) |
| KERNEL_CPU(orgqr, fp32) |
| KERNEL_CPU(ormqr, fp32) |
| KERNEL_CPU(pinverse, fp32) |
| KERNEL_CPU(max_pool3d, fp32) |
| KERNEL_CPU(max_unpool2d, fp32) |
| KERNEL_CPU(max_unpool3d, fp32) |
| KERNEL_CPU(adaptive_avg_pool3d, fp32) |
| KERNEL_CPU(reflection_pad1d, fp32) |
| KERNEL_CPU(reflection_pad2d, fp32) |
| KERNEL_CPU(replication_pad1d, fp32) |
| KERNEL_CPU(replication_pad2d, fp32) |
| KERNEL_CPU(replication_pad3d, fp32) |
| KERNEL_CPU(mse_loss, fp32) |
| KERNEL_CPU(cosine_embedding_loss, fp32) |
| KERNEL_CPU(nll_loss, fp32) |
| KERNEL_CPU(nll_loss2d, fp32) |
| KERNEL_CPU(hinge_embedding_loss, fp32) |
| KERNEL_CPU(poisson_nll_loss, fp32) |
| KERNEL_CPU(smooth_l1_loss, fp32) |
| KERNEL_CPU(cross_entropy_loss, fp32) |
| KERNEL_CPU(l1_loss, fp32) |
| KERNEL_CPU(huber_loss, fp32) |
| KERNEL_CPU(margin_ranking_loss, fp32) |
| KERNEL_CPU(soft_margin_loss, fp32) |
| KERNEL_CPU(triplet_margin_loss, fp32) |
| KERNEL_CPU(multi_margin_loss, fp32) |
| KERNEL_CPU(ctc_loss, IntList, fp32) |
| KERNEL_CPU(ctc_loss, Tensor, fp32) |
| KERNEL_CPU(kl_div, fp32) |
| KERNEL_CPU(multilabel_margin_loss, fp32) |
| KERNEL_CPU(binary_cross_entropy_with_logits, fp32) |
| KERNEL_CPU(fft_fft, fp32) |
| KERNEL_CPU(fft_ifft, fp32) |
| KERNEL_CPU(fft_fft2, fp32) |
| KERNEL_CPU(fft_ifft2, fp32) |
| KERNEL_CPU(fft_fftn, fp32) |
| KERNEL_CPU(fft_ifftn, fp32) |
| KERNEL_CPU(fft_rfft, fp32) |
| KERNEL_CPU(fft_irfft, fp32) |
| KERNEL_CPU(fft_rfft2, fp32) |
| KERNEL_CPU(fft_irfft2, fp32) |
| KERNEL_CPU(fft_rfftn, fp32) |
| KERNEL_CPU(fft_irfftn, fp32) |
| KERNEL_CPU(fft_hfft, fp32) |
| KERNEL_CPU(fft_ihfft, fp32) |
| KERNEL_CPU(linalg_cond, fp32) |
| KERNEL_CPU(linalg_cond, p_str, fp32) |
| KERNEL_CPU(linalg_matrix_rank, fp32) |
| KERNEL_CPU(linalg_matrix_rank, tol_tensor, fp32) |
| KERNEL_CPU(linalg_matrix_rank, atol_rtol_tensor, fp32) |
| KERNEL_CPU(linalg_matrix_rank, atol_rtol_float, fp32) |
| KERNEL_CPU(linalg_solve, fp32) |
| KERNEL_CPU(linalg_cholesky, fp32) |
| KERNEL_CPU(linalg_svdvals, fp32) |
| KERNEL_CPU(linalg_eigvals, fp32) |
| KERNEL_CPU(linalg_eigvalsh, fp32) |
| KERNEL_CPU(linalg_inv, fp32) |
| KERNEL_CPU(linalg_householder_product, fp32) |
| KERNEL_CPU(linalg_tensorinv, fp32) |
| KERNEL_CPU(linalg_tensorsolve, fp32) |
| KERNEL_CPU(fake_quantize_per_tensor_affine, fp32) |
| KERNEL_CPU(geqrf, fp32) |
| KERNEL_CPU(_lu_with_info, fp32) |
| KERNEL_CPU(qr, fp32) |
| KERNEL_CPU(svd, fp32) |
| KERNEL_CPU(triangular_solve, fp32) |
| KERNEL_CPU(fractional_max_pool2d, fp32) |
| KERNEL_CPU(fractional_max_pool3d, fp32) |
| KERNEL_CPU(adaptive_max_pool3d, fp32) |
| KERNEL_CPU(multilabel_margin_loss_forward, fp32) |
| KERNEL_CPU(linalg_qr, fp32) |
| KERNEL_CPU(linalg_cholesky_ex, fp32) |
| KERNEL_CPU(linalg_svd, fp32) |
| KERNEL_CPU(linalg_eig, fp32) |
| KERNEL_CPU(linalg_eigh, fp32) |
| KERNEL_CPU(linalg_lstsq, fp32) |
| KERNEL_CPU(linalg_inv_ex, fp32) |
| |
| // promote |
| KERNEL_CPU(stack, promote) |
| KERNEL_CPU(cat, promote) |
| KERNEL_CPU(index_copy, promote) |
| KERNEL_CPU(index_copy, dimname, promote) |
| |
| } |
| |
| TORCH_LIBRARY_IMPL(_, AutocastXPU, m) { |
| m.fallback(torch::CppFunction::makeFallthrough()); |
| } |
| |
| TORCH_LIBRARY_IMPL(aten, AutocastXPU, m) { |
| // lower_precision_fp |
| #define _KERNEL_XPU_LOW_PRECISION_FP(...) \ |
| KERNEL_XPU(__VA_ARGS__, lower_precision_fp) |
| |
| AT_FORALL_LOWER_PRECISION_FP(_KERNEL_XPU_LOW_PRECISION_FP) |
| |
| // fp32 |
| #define _KERNEL_XPU_FP32(...) KERNEL_XPU(__VA_ARGS__, fp32) |
| |
| AT_FORALL_FP32(_KERNEL_XPU_FP32) |
| |
| // fp32_set_opt_dtype |
| #define _KERNEL_XPU_FP32_SET_OPT_DTYPE(...) \ |
| KERNEL_XPU(__VA_ARGS__, fp32_set_opt_dtype) |
| |
| AT_FORALL_FP32_SET_OPT_DTYPE(_KERNEL_XPU_FP32_SET_OPT_DTYPE) |
| |
| // fp32_append_dtype |
| // The fp32_append_dtype wrapper overrides implicit promotion behavior. |
| // norm does not implicitly promote, but be aware of this when adding new ops to this policy. |
| AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE( |
| KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_XPU) |
| |
| // promote |
| #define _KERNEL_XPU_PROMOTE(...) KERNEL_XPU(__VA_ARGS__, promote) |
| |
| AT_FORALL_PROMOTE(_KERNEL_XPU_PROMOTE) |
| |
| m.impl(TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"), |
| TORCH_FN((&at::autocast::binary_cross_entropy_banned))); |
| } |
| |
| } // namespace |
| } // namespace at::autocast |