Caffe2 - C++ API
A deep learning, cross platform ML framework
interface.h
1 #pragma once
2 
3 #include <ATen/ATen.h>
4 #include <torch/csrc/WindowsTorchApiMacro.h>
5 #include <torch/csrc/jit/ir.h>
6 #include <ATen/core/stack.h>
7 
8 #include <cstdint>
9 #include <memory>
10 #include <vector>
11 
12 namespace torch {
13 namespace jit {
14 
15 constexpr int kCPUDevice = -1;
16 
17 // Assigns a "key" to the given fusion_group that it can use to run its
18 // fusion later (via runFusion() below).
19 TORCH_API int64_t registerFusion(const Node* fusion_group);
20 
21 // Runs the fusion corresponding to the given key on the inputs
22 // found on the stack. Outputs are placed on the same stack.
23 // In some cases a fusion cannot be run and a fallback path where
24 // PyTorch's interpreter runs the graph instead is attempted.
25 TORCH_API void runFusion(const int64_t key, Stack& stack);
26 
27 // True if the respective devices can fuse, false otherwise
28 TORCH_API bool canFuseOnCPU();
29 TORCH_API bool canFuseOnGPU();
30 
31 // Sets whether fusion on the CPU is allowed (disabled by default due to
32 // flakiness)
33 TORCH_API void overrideCanFuseOnCPU(bool value);
34 
35 // Treats the given graph as a fusion group and launches it on the
36 // specified device with the given inputs.
37 // Returns the outputs.
38 TORCH_API std::vector<at::Tensor> debugLaunchGraph(
39  Graph& graph,
41 
42 TORCH_API size_t nCompiledKernels();
43 
44 } // namespace jit
45 } // namespace torch
Definition: jit_type.h:17
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...
Definition: ArrayRef.h:41