blob: 9cd82cbad5fe0cbb0b95b367283b408f78a3fc77 [file] [log] [blame]
#include "torch/csrc/jit/fuser/interface.h"
#include "torch/csrc/jit/fuser/config.h"
#if USE_CUDA_FUSER || USE_CPU_FUSER
#include "torch/csrc/jit/fuser/compiler.h"
#include "torch/csrc/jit/fuser/executor.h"
#include "torch/csrc/jit/fuser/fallback.h"
#endif // USE_CUDA_FUSER || USE_CPU_FUSER
#include <stdexcept>
namespace torch { namespace jit {
namespace detail {
// Flag read by canFuseOnCPU() (only in builds that define USE_CPU_FUSER) and
// written by overrideCanFuseOnCPU().
// Note: CPU fusion is currently disabled by default due to test flakiness.
bool cpu_fuser_enabled{false};
} // namespace detail
// Registers `fusion_group` with the fuser backend and returns the key produced
// by fuser::registerFusion.
// Throws std::runtime_error when the build contains neither the CUDA nor the
// CPU fuser.
int64_t registerFusion(const Node* fusion_group) {
#if !(USE_CUDA_FUSER || USE_CPU_FUSER)
  throw std::runtime_error("Fusion not supported for this build.");
#else
  return fuser::registerFusion(fusion_group);
#endif // !(USE_CUDA_FUSER || USE_CPU_FUSER)
}
// Runs the fusion registered under `key` on `stack`. When fuser::runFusion
// returns false (presumably: the fused kernel could not be used for these
// inputs), fuser::runFallback is invoked with the same key and stack.
// Throws std::runtime_error when the build contains neither the CUDA nor the
// CPU fuser.
void runFusion(const int64_t key, Stack& stack) {
#if !(USE_CUDA_FUSER || USE_CPU_FUSER)
  throw std::runtime_error("Fusion not supported for this build.");
#else
  if (!fuser::runFusion(key, stack)) {
    fuser::runFallback(key, stack);
  }
#endif // !(USE_CUDA_FUSER || USE_CPU_FUSER)
}
// Returns whether CPU fusion is allowed. In builds that define USE_CPU_FUSER
// this is the runtime flag detail::cpu_fuser_enabled (see
// overrideCanFuseOnCPU); otherwise it is always false.
// The #if/#else form compiles exactly one return statement in either
// configuration, so no unreachable code is emitted.
bool canFuseOnCPU() {
#if USE_CPU_FUSER
  return detail::cpu_fuser_enabled;
#else
  return false;
#endif // USE_CPU_FUSER
}
// Returns true exactly when the build includes the CUDA fuser
// (USE_CUDA_FUSER); there is no runtime toggle for GPU fusion here.
// The #if/#else form compiles exactly one return statement in either
// configuration, so no unreachable code is emitted.
bool canFuseOnGPU() {
#if USE_CUDA_FUSER
  return true;
#else
  return false;
#endif // USE_CUDA_FUSER
}
// Sets the runtime flag that canFuseOnCPU() reports. The flag is only
// consulted in builds that define USE_CPU_FUSER.
void overrideCanFuseOnCPU(bool value) {
  auto& enabled = detail::cpu_fuser_enabled;
  enabled = value;
}
// Debug helper: runs `graph` on `inputs` through the fuser. It wraps a copy of
// `graph` as the subgraph of a fusion-group node in a fresh wrapper graph,
// registers that node with fuser::registerFusion, and launches it.
// This mirrors runFusion(): when fuser::runFusion returns false,
// fuser::runFallback is invoked on the same key and stack. A declined fusion
// therefore still computes results, rather than returning whatever was left on
// the stack.
// Returns the tensors left on the stack after execution.
// Throws std::runtime_error if the number of `inputs` differs from the number
// of graph inputs, or if the build contains no fuser.
std::vector<at::Tensor> debugLaunchGraph(
    Graph& graph,
    at::ArrayRef<at::Tensor> inputs) {
#if USE_CUDA_FUSER || USE_CPU_FUSER
  // The fusion group gets one input per graph input, so the stack must match.
  if (inputs.size() != graph.inputs().size()) {
    throw std::runtime_error(
        "debugLaunchGraph: number of inputs does not match the graph's inputs.");
  }

  // Creates a fusion group node wrapping a copy of the graph
  auto wrapper_graph = std::make_shared<Graph>();
  Node* fusion_group =
      wrapper_graph->insertNode(wrapper_graph->createFusionGroup());
  fusion_group->g_(attr::Subgraph, graph.copy());
  for (size_t i = 0; i < graph.inputs().size(); ++i) {
    fusion_group->addInput(wrapper_graph->addInput());
  }
  for (size_t i = 0; i < graph.outputs().size(); ++i) {
    wrapper_graph->registerOutput(fusion_group->addOutput());
  }

  // Creates the stack, registers and runs the fusion (with fallback)
  Stack stack = fmap<IValue>(inputs);
  const auto key = fuser::registerFusion(fusion_group);
  if (!fuser::runFusion(key, stack)) {
    fuser::runFallback(key, stack);
  }
  return fmap(stack, [](const IValue& iv) { return iv.toTensor(); });
#else
  throw std::runtime_error("Fusion not supported for this build.");
#endif // USE_CUDA_FUSER || USE_CPU_FUSER
}
// Number of compiled kernels as reported by fuser::nCompiledKernels();
// always 0 in builds that contain neither the CUDA nor the CPU fuser.
size_t nCompiledKernels() {
#if !(USE_CUDA_FUSER || USE_CPU_FUSER)
  return 0;
#else
  return fuser::nCompiledKernels();
#endif // !(USE_CUDA_FUSER || USE_CPU_FUSER)
}
} // namespace jit
} // namespace torch