Caffe2 - C++ API
A deep learning, cross platform ML framework
fused_kernel.h
1 #pragma once
2 
3 #include <ATen/ATen.h>
4 #include <torch/csrc/jit/fuser/partition_desc.h>
5 #include <torch/csrc/jit/fuser/tensor_desc.h>
6 #include <torch/csrc/utils/disallow_copy.h>
7 
8 #include <cstdint>
9 #include <string>
10 #include <vector>
11 
12 namespace torch {
13 namespace jit {
14 namespace fuser {
15 
16 struct FusedKernel {
17  TH_DISALLOW_COPY_AND_ASSIGN(FusedKernel);
18 
20  std::string name,
21  std::string code,
22  std::vector<TensorDesc> input_desc,
23  std::vector<TensorDesc> output_desc,
24  std::vector<PartitionDesc> chunk_desc,
25  std::vector<PartitionDesc> concat_desc,
26  bool has_random)
27  : name_(std::move(name)),
28  code_(std::move(code)),
29  input_desc_(std::move(input_desc)),
30  output_desc_(std::move(output_desc)),
31  chunk_desc_(std::move(chunk_desc)),
32  concat_desc_(std::move(concat_desc)),
33  has_random_(has_random) {}
34 
35  virtual ~FusedKernel() = default;
36 
37  // arguments is a list of pointers to the arguments for the compiled CUDA/CPU
38  // code.
39  // The format of arguments is suitable for directly passing to a call to
40  // cuLaunchKernel as the kernel arguments.
41  // Currently the first argument is a pointer to numel (for passing to
42  // CUDA code), and the remainder are pointers to the TensorInfo<T> structs
43  // that compiled code uses to load Tensor data.
44  // launch_with_tensors handles packing at::Tensors into this arguments array.
45  // CPU code uses the same convension so that launch_with_tensors can be
46  // shared.
47  virtual void launch_raw(const uint32_t numel, std::vector<void*>& arguments)
48  const = 0;
49  virtual at::Backend backend() const = 0;
50 
51  // Getters
52  const std::string& name() const {
53  return name_;
54  }
55  const std::string& code() const {
56  return code_;
57  }
58  const std::vector<TensorDesc>& inputDesc() const {
59  return input_desc_;
60  }
61  const std::vector<TensorDesc>& outputDesc() const {
62  return output_desc_;
63  }
64  const std::vector<PartitionDesc>& chunkDesc() const {
65  return chunk_desc_;
66  }
67  const std::vector<PartitionDesc>& concatDesc() const {
68  return concat_desc_;
69  }
70  bool hasRandom() const {
71  return has_random_;
72  }
73 
74  protected:
75  const std::string name_;
76  const std::string code_;
77  const std::vector<TensorDesc> input_desc_;
78  const std::vector<TensorDesc> output_desc_;
79 
80  // same size as input_desc, describes whether an
81  // input should be broken into subtensors (chunks)
82  // to be consumed by the fusion group
83  const std::vector<PartitionDesc> chunk_desc_;
84 
85  // same size as output_desc, describes whether
86  // an output is actually a concatenation of
87  // many subtensors that the fusion group produces
88  const std::vector<PartitionDesc> concat_desc_;
89 
90  const bool has_random_;
91 };
92 
93 } // namespace fuser
94 } // namespace jit
95 } // namespace torch
Cross-references (from generated documentation):
- at::Backend — a legacy enum class defining the set of backends supported by the old code-generated, Type-based API. Defined in Backend.h, line 23.
- torch::jit — defined in jit_type.h, line 17.