1 #include <torch/csrc/jit/fuser/interface.h> 3 #include <torch/csrc/jit/fuser/compiler.h> 4 #include <torch/csrc/jit/fuser/executor.h> 5 #include <torch/csrc/jit/fuser/fallback.h> 15 bool cpu_fuser_enabled =
false;
19 int64_t registerFusion(
const Node* fusion_group) {
20 return fuser::registerFusion(fusion_group);
23 void runFusion(
const int64_t key, Stack& stack) {
24 const auto result = fuser::runFusion(key, stack);
26 fuser::runFallback(key, stack);
// NOTE(review): the enclosing function header and closing brace were lost in
// extraction (presumably `bool canFuseOnCPU() { ... }`; TODO confirm and restore).
// CPU fusion is reported as available only when a CPU fusion backend exists AND
// the detail::cpu_fuser_enabled switch (off by default) has been turned on.
30 return fuser::hasFusionBackend(at::DeviceType::CPU) &&
31 detail::cpu_fuser_enabled;
// NOTE(review): the enclosing function header and closing brace were lost in
// extraction (presumably `bool canFuseOnGPU() { ... }`; TODO confirm and restore).
// GPU fusion availability depends only on a CUDA fusion backend being present;
// no runtime switch is consulted here.
35 return fuser::hasFusionBackend(at::DeviceType::CUDA);
38 void overrideCanFuseOnCPU(
bool value) {
39 detail::cpu_fuser_enabled = value;
44 std::vector<at::Tensor> debugLaunchGraph(
48 auto wrapper_graph = std::make_shared<Graph>();
50 wrapper_graph->insertNode(wrapper_graph->createFusionGroup());
51 fusion_group->g_(attr::Subgraph, graph.copy());
52 for (
size_t i = 0; i < graph.inputs().size(); ++i) {
53 fusion_group->addInput(wrapper_graph->addInput());
55 for (
size_t i = 0; i < graph.outputs().size(); ++i) {
56 wrapper_graph->registerOutput(fusion_group->addOutput());
60 Stack stack = fmap<IValue>(inputs);
61 const auto key = fuser::registerFusion(fusion_group);
62 fuser::runFusion(key, stack);
63 return fmap(stack, [](
const IValue& iv) {
return iv.toTensor(); });
66 size_t nCompiledKernels() {
67 return fuser::nCompiledKernels();
// at::ArrayRef: represents a constant reference to an array (zero or more elements stored consecutively in memory).