Caffe2 - C++ API
A deep learning, cross-platform ML framework
fused_kernel.h
#pragma once

#include <ATen/ATen.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/fuser/fused_kernel.h>

#include <cuda.h>
#include <cuda_runtime.h>
#include <nvrtc.h>

#include <cstdint>
#include <string>
#include <vector>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

// A class holding metadata for an actual CUDA function.
// Note: CUDA functions are per device.
struct TORCH_API FusedKernelCUDA : public ::torch::jit::fuser::FusedKernel {
  FusedKernelCUDA(
      int16_t device,
      std::string name,
      std::string code,
      std::vector<TensorDesc> input_desc,
      std::vector<TensorDesc> output_desc,
      std::vector<PartitionDesc> chunk_desc,
      std::vector<PartitionDesc> concat_desc,
      bool has_random);

  ~FusedKernelCUDA() override;

  void launch_raw(const uint32_t numel, std::vector<void*>& arguments)
      const override;

  at::Backend backend() const override {
    return at::Backend::CUDA;
  }

 private:
  static constexpr auto kBlockSize = 128;

  // Note: per device to store device properties and compute launch heuristics
  // Acquiring these values at launch time would be too slow
  int16_t device_;
  int maxBlocks_;
  cudaDeviceProp* prop_;
  std::vector<char> ptx_;
  CUmodule module_;
  CUfunction function_;
};

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch
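
The includes (nvrtc.h, cuda.h) and the private members (ptx_, module_, function_, prop_, maxBlocks_) outline the usual NVRTC compile-and-launch pattern: compile generated CUDA source to PTX once per device, load it into a driver module, resolve the kernel, and reuse cached device properties to size the grid at launch time. The sketch below walks through that pattern with plain NVRTC and CUDA driver calls. It is not the FusedKernelCUDA implementation; the scale kernel, the file name, and the block-count heuristic are illustrative assumptions, while the NVRTC and driver API calls themselves are real.

// fused_kernel_sketch.cpp -- illustrative only; not the PyTorch implementation.
#include <cuda.h>
#include <cuda_runtime.h>
#include <nvrtc.h>

#include <vector>

// A toy kernel (hypothetical) standing in for generated fusion code.
// The grid-stride loop lets a capped block count still cover every element.
static const char* kSource = R"(
extern "C" __global__ void scale(float* data, float factor, unsigned int numel) {
  for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel;
       i += gridDim.x * blockDim.x) {
    data[i] *= factor;
  }
})";

int main() {
  constexpr unsigned int kBlockSize = 128;  // same constant as in the header
  constexpr unsigned int numel = 1u << 20;

  // Compile CUDA C++ to PTX with NVRTC (error checking omitted for brevity);
  // the header caches the result in ptx_.
  nvrtcProgram prog;
  nvrtcCreateProgram(&prog, kSource, "scale.cu", 0, nullptr, nullptr);
  nvrtcCompileProgram(prog, 0, nullptr);
  size_t ptxSize = 0;
  nvrtcGetPTXSize(prog, &ptxSize);
  std::vector<char> ptx(ptxSize);
  nvrtcGetPTX(prog, ptx.data());
  nvrtcDestroyProgram(&prog);

  // Load the PTX and resolve the kernel with the driver API
  // (the module_ and function_ members).
  cuInit(0);
  CUdevice dev;
  CUcontext ctx;
  cuDeviceGet(&dev, 0);
  cuCtxCreate(&ctx, 0, dev);
  CUmodule module;
  CUfunction function;
  cuModuleLoadData(&module, ptx.data());
  cuModuleGetFunction(&function, module, "scale");

  // Query device properties once up front (prop_) and derive an illustrative
  // cap on the block count (maxBlocks_); the exact heuristic is an assumption.
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, 0);
  const unsigned int maxBlocks =
      prop.multiProcessorCount * (prop.maxThreadsPerMultiProcessor / kBlockSize);

  // Launch: one thread per element, block count capped by the heuristic.
  CUdeviceptr data;
  cuMemAlloc(&data, numel * sizeof(float));
  float factor = 2.0f;
  unsigned int n = numel;
  void* args[] = {&data, &factor, &n};
  unsigned int blocks = (numel + kBlockSize - 1) / kBlockSize;
  if (blocks > maxBlocks) blocks = maxBlocks;
  cuLaunchKernel(function, blocks, 1, 1, kBlockSize, 1, 1, 0, nullptr, args, nullptr);
  cuCtxSynchronize();

  cuMemFree(data);
  cuModuleUnload(module);
  cuCtxDestroy(ctx);
  return 0;
}

Doing the expensive steps once (compilation, module loading, device-property queries) mirrors the header's note that device properties are cached because acquiring them at launch time would be too slow; launch_raw() then only needs to compute a grid size and invoke the cached function_.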