3 #include <torch/csrc/WindowsTorchApiMacro.h> 4 #include <torch/csrc/jit/fuser/arg_spec.h> 5 #include <torch/csrc/jit/fuser/fused_kernel.h> 6 #include <torch/csrc/jit/fuser/interface.h> 7 #include <torch/csrc/jit/fuser/kernel_spec.h> 8 #include <torch/csrc/jit/ir.h> 9 #include <ATen/core/stack.h> 21 TORCH_API int64_t registerFusion(
const Node* fusion_group);
// Declares compileKernel, which returns a std::shared_ptr<FusedKernel> given a
// KernelSpec, an ArgSpec and a map_size vector -- presumably compiling a fused
// kernel specialized to those arguments; TODO confirm.
// NOTE(review): this declaration appears truncated in this copy -- the
// parameter list ends with a trailing "," and no ")" / ";" follows before the
// next declaration. The leading numbers (26, 27, ...) look like extraction line
// markers, not code. Restore the missing parameter(s) from the upstream header.
26 TORCH_API std::shared_ptr<FusedKernel> compileKernel(
27 const KernelSpec& spec,
28 const ArgSpec& arg_spec,
29 const std::vector<int64_t>& map_size,
32 TORCH_API
size_t nCompiledKernels();
34 TORCH_API
int debugFuser();
// Factory type that builds a FusedKernel (returned as std::shared_ptr) from
// input/output tensor descriptors and chunk/concat partition descriptors,
// among other parameters not visible here.
// NOTE(review): the std::function signature appears truncated in this copy --
// the embedded markers jump from 36 to 40 and from 43 to 46, and the list ends
// with a trailing ",". The leading numbers look like extraction line markers,
// not code. Restore the missing parameter(s) and closing ")>;" from upstream.
36 using FusedKernelConstructor = std::function<std::shared_ptr<FusedKernel>(
40 std::vector<TensorDesc> input_desc,
41 std::vector<TensorDesc> output_desc,
42 std::vector<PartitionDesc> chunk_desc,
43 std::vector<PartitionDesc> concat_desc,
46 TORCH_API
void registerFusionBackend(
47 at::Device::Type backend_type,
48 FusedKernelConstructor ctor);
49 TORCH_API
bool hasFusionBackend(at::Device::Type backend_type);
// NOTE(review): fragment -- the opening of this function/constructor (its name
// and any preceding text) and its closing brace are missing from this copy, and
// the leading numbers look like extraction line markers, not code. The visible
// body forwards backend_type and ctor (moved) to registerFusionBackend.
52 at::Device::Type backend_type,
53 FusedKernelConstructor ctor) {
54 registerFusionBackend(backend_type, std::move(ctor));
Represents a compute device on which a tensor is located.