Caffe2 - C++ API
A deep learning, cross platform ML framework
interface.cpp
1 #include <torch/csrc/jit/fuser/interface.h>
2 
3 #include <torch/csrc/jit/fuser/compiler.h>
4 #include <torch/csrc/jit/fuser/executor.h>
5 #include <torch/csrc/jit/fuser/fallback.h>
6 
7 #include <stdexcept>
8 
9 namespace torch {
10 namespace jit {
11 
namespace detail {

// Process-wide switch consulted by canFuseOnCPU() and written by
// overrideCanFuseOnCPU(). It starts out false: CPU fusion is currently
// disabled by default because of test flakiness.
bool cpu_fuser_enabled{false};

} // namespace detail
18 
19 int64_t registerFusion(const Node* fusion_group) {
20  return fuser::registerFusion(fusion_group);
21 }
22 
23 void runFusion(const int64_t key, Stack& stack) {
24  const auto result = fuser::runFusion(key, stack);
25  if (!result)
26  fuser::runFallback(key, stack);
27 }
28 
29 bool canFuseOnCPU() {
30  return fuser::hasFusionBackend(at::DeviceType::CPU) &&
31  detail::cpu_fuser_enabled;
32 }
33 
34 bool canFuseOnGPU() {
35  return fuser::hasFusionBackend(at::DeviceType::CUDA);
36 }
37 
38 void overrideCanFuseOnCPU(bool value) {
39  detail::cpu_fuser_enabled = value;
40 }
41 
42 // Uses the above interface by stuffing the graph into a node and treating that
43 // node as a fusion group.
44 std::vector<at::Tensor> debugLaunchGraph(
45  Graph& graph,
46  at::ArrayRef<at::Tensor> inputs) {
47  // Creates a fusion group node
48  auto wrapper_graph = std::make_shared<Graph>();
49  Node* fusion_group =
50  wrapper_graph->insertNode(wrapper_graph->createFusionGroup());
51  fusion_group->g_(attr::Subgraph, graph.copy());
52  for (size_t i = 0; i < graph.inputs().size(); ++i) {
53  fusion_group->addInput(wrapper_graph->addInput());
54  }
55  for (size_t i = 0; i < graph.outputs().size(); ++i) {
56  wrapper_graph->registerOutput(fusion_group->addOutput());
57  }
58 
59  // Creates the stack, registers and runs the fusion
60  Stack stack = fmap<IValue>(inputs);
61  const auto key = fuser::registerFusion(fusion_group);
62  fuser::runFusion(key, stack);
63  return fmap(stack, [](const IValue& iv) { return iv.toTensor(); });
64 }
65 
66 size_t nCompiledKernels() {
67  return fuser::nCompiledKernels();
68 }
69 
70 } // namespace jit
71 } // namespace torch
Definition: jit_type.h:17
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...
Definition: ArrayRef.h:41