4 #include <c10/util/Optional.h> 5 #include <torch/csrc/WindowsTorchApiMacro.h> 6 #include <torch/csrc/jit/fuser/arg_spec.h> 7 #include <torch/csrc/jit/fuser/fused_kernel.h> 8 #include <torch/csrc/jit/fuser/interface.h> 9 #include <torch/csrc/jit/interpreter.h> 10 #include <torch/csrc/jit/ir.h> 11 #include <ATen/core/stack.h> 16 #include <unordered_map> 30 : nSubTensors_{_nSubTensors}, dim_{_dim} {};
32 int64_t nSubTensors()
const {
53 : offset_{_offset}, sizeInput_{_sizeInput} {};
55 int64_t offset()
const {
58 int64_t sizeInput()
const {
61 bool needsSumToSize()
const {
62 return sizeInput_ != -1;
85 KernelSpec(
const int64_t _key,
const std::shared_ptr<Graph>& _graph)
89 nInputs_{_graph->inputs().size()},
91 inputBroadcastGroups_{},
96 for (
const auto& n : graph_->nodes()) {
97 if (n->kind() == aten::rand_like) {
102 nTensorInputs_ = std::count_if(
103 graph_->inputs().begin(), graph_->inputs().end(), [](
const Value* v) {
104 return v->type()->isSubtypeOf(TensorType::get());
109 int64_t key()
const {
112 std::shared_ptr<Graph> graph()
const {
115 const Code& code()
const {
118 int64_t nInputs()
const {
121 int64_t nTensorInputs()
const {
122 return nTensorInputs_;
125 std::vector<std::vector<int64_t>>& inputBroadcastGroups() {
126 return inputBroadcastGroups_;
128 const std::vector<std::vector<int64_t>>& inputBroadcastGroups()
const {
129 return inputBroadcastGroups_;
132 std::vector<PartitionInfo>& inputChunks() {
135 const std::vector<PartitionInfo>& inputChunks()
const {
139 std::vector<OutputMapAndSize>& outputMapAndSizes() {
140 return outputMapAndSizes_;
143 bool hasRandom()
const {
149 const ArgSpec& arg_spec)
const {
150 std::lock_guard<std::mutex> guard{mutex_};
151 const auto it = kernels_.find(arg_spec);
152 if (it == kernels_.end())
156 void cacheKernel(
const ArgSpec& arg_spec, std::shared_ptr<FusedKernel> kernel)
158 std::lock_guard<std::mutex> guard{mutex_};
159 kernels_.emplace(arg_spec, kernel);
164 std::shared_ptr<Graph> graph_;
167 uint64_t nTensorInputs_;
168 std::vector<std::vector<int64_t>> inputBroadcastGroups_;
169 std::vector<PartitionInfo> inputChunks_;
174 std::vector<OutputMapAndSize> outputMapAndSizes_;
176 mutable std::mutex mutex_;