#include <torch/csrc/jit/fuser/compiler.h>

#include <ATen/core/jit_type.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/code_template.h>
#include <torch/csrc/jit/fuser/codegen.h>
#include <torch/csrc/jit/fuser/interface.h>
#include <torch/csrc/jit/fuser/kernel_cache.h>
#include <torch/csrc/jit/fuser/tensor_desc.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/graph_fuser.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/utils/hash.h>

#include <algorithm>
#include <atomic>
#include <cstdlib>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

namespace torch {
namespace jit {
namespace fuser {

// Protects the registry of per-device fused kernel constructors.
std::mutex fusion_backends_lock_;

// Lazily constructs and returns the map from device type to the constructor
// used to build fused kernels for that backend.
static std::unordered_map<at::Device::Type, FusedKernelConstructor>&
getFusionBackends() {
  static std::unordered_map<at::Device::Type, FusedKernelConstructor>
      fusion_backends;
  return fusion_backends;
}
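// Registers the constructor used to build fused kernels for the given backend.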
void registerFusionBackend(
    at::Device::Type backend_type,
    FusedKernelConstructor ctor) {
  std::lock_guard<std::mutex> guard(fusion_backends_lock_);
  getFusionBackends()[backend_type] = std::move(ctor);
}
bool hasFusionBackend(at::Device::Type backend_type) {
  std::lock_guard<std::mutex> guard(fusion_backends_lock_);
  return getFusionBackends().count(backend_type);
}
const FusedKernelConstructor& getConstructor(at::Device::Type backend_type) {
  std::lock_guard<std::mutex> guard(fusion_backends_lock_);
  return getFusionBackends().at(backend_type);
}
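// Counter for the number of kernels compiled, also used to give each
// generated kernel a unique name.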
static std::atomic<size_t> next_kernel_id{0};
static int debug_fusion{-1};

size_t nCompiledKernels() {
  return next_kernel_id.load();
}
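// Returns the fusion debug level, reading PYTORCH_FUSION_DEBUG from the
// environment on first use and caching the result.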
int debugFuser() {
  if (debug_fusion < 0) {
    const char* debug_env = getenv("PYTORCH_FUSION_DEBUG");
    debug_fusion = debug_env ? atoi(debug_env) : 0;
  }
  return debug_fusion;
}
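// If the given value is consumed by exactly one prim::ConstantChunk node,
// returns that node; otherwise returns nullptr.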
static const Node* usedInFusedChunk(const Value* input) {
  const auto& uses = input->uses();
  if (uses.size() == 1) {
    const Node* user = uses[0].user;
    if (user->kind() == prim::ConstantChunk) {
      return user;
    }
  }
  return nullptr;
}
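// Records, for each tensor input, how many chunks it is split into and along
// which dimension; inputs that are not chunked get a single chunk along dim 0.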
static void setInputChunkDescriptors(KernelSpec& spec) {
  spec.inputChunks().reserve(spec.nTensorInputs());
  for (int64_t i = 0; i < spec.nTensorInputs(); i++) {
    const Value* input = spec.graph()->inputs()[i];
    if (const Node* chunk = usedInFusedChunk(input)) {
      spec.inputChunks().emplace_back(
          chunk->i(attr::chunks), chunk->i(attr::dim));
    } else {
      spec.inputChunks().emplace_back(1, 0);
    }
  }
}
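// Walks producers depth-first from the given output and returns the sorted
// offsets of the tensor inputs it (transitively) depends on.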
static std::vector<int64_t> getInputDependencies(const Value* output) {
  std::vector<const Value*> queue{output};
  std::unordered_set<const Value*> inputs;
  std::unordered_set<const Value*> seen;
  while (!queue.empty()) {
    const Value* val = queue.back();
    queue.pop_back();
    const Node* producer = val->node();
    if (producer->kind() == prim::Param &&
        val->type()->isSubtypeOf(TensorType::get())) {
      inputs.insert(val);
      continue;
    }
    for (const Value* input : producer->inputs()) {
      if (seen.insert(input).second) {
        queue.push_back(input);
      }
    }
  }

  // Converts the collected Value*s into offsets into the graph's input list.
  std::vector<int64_t> offsets;
  offsets.reserve(inputs.size());
  for (const Value* input : inputs) {
    offsets.push_back(input->offset());
  }

  std::sort(offsets.begin(), offsets.end());
  return offsets;
}
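// Groups the inputs that must be broadcast together: each output contributes
// one group, except fused concats, where each concat input gets its own group.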
static void setInputBroadcastGroups(KernelSpec& spec) {
  std::unordered_set<std::vector<int64_t>, torch::hash<std::vector<int64_t>>>
      broadcast_groups;
  for (const Value* output : (spec.graph())->outputs()) {
    if (output->node()->kind() == prim::FusedConcat) {
      for (const Value* concat_input : output->node()->inputs()) {
        broadcast_groups.insert(getInputDependencies(concat_input));
      }
    } else {
      broadcast_groups.insert(getInputDependencies(output));
    }
  }
  std::copy(
      broadcast_groups.begin(),
      broadcast_groups.end(),
      std::back_inserter(spec.inputBroadcastGroups()));
}
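// Removes aten::_grad_sum_to_size nodes from the fusion subgraph, recording
// for each output which input (if any) carries the size it must be summed to,
// and deduplicates outputs that become identical once the nodes are gone.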
void processGradSumToSize(KernelSpec& spec) {
  auto graph = spec.graph();
  std::vector<int64_t> outputGradSumToSizes(graph->outputs().size(), -1);

  for (auto it = graph->nodes().rbegin(); it != graph->nodes().rend(); it++) {
    auto* node = *it;
    if (node->kind() != aten::_grad_sum_to_size) {
      continue;
    }
    bool success = trackSingleGradSumToSizeToOutputs(
        node->output(), &outputGradSumToSizes);
    AT_ASSERT(success);
    // Cuts the node out of the graph: its input takes over all uses of its
    // output.
    node->output()->replaceAllUsesWith(node->inputs()[0]);
    it.destroyCurrent();
  }

  auto& outputMapAndSizes = spec.outputMapAndSizes();
  AT_ASSERT(outputMapAndSizes.empty());
  std::unordered_map<const Value*, int64_t> reduced_output_indices;
  int64_t newo = 0;
  for (auto osize : outputGradSumToSizes) {
    auto it = reduced_output_indices.find(graph->outputs()[newo]);
    if (it == reduced_output_indices.end()) {
      reduced_output_indices.emplace(graph->outputs()[newo], newo);
      outputMapAndSizes.emplace_back(newo, osize);
      newo++;
    } else {
      graph->eraseOutput(newo);
      outputMapAndSizes.emplace_back(it->second, osize);
    }
  }
}
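// Performs "upfront" compilation: everything that can be computed from the
// graph alone, before concrete input shapes are known.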
static void upfrontCompilation(KernelSpec& spec) {
  setInputBroadcastGroups(spec);
  setInputChunkDescriptors(spec);
  processGradSumToSize(spec);
}
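// Normalizes the fusion group's subgraph and registers it in the kernel
// cache, returning the key under which its KernelSpec is stored.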
int64_t registerFusion(const Node* fusion_group) {
  auto graph = normalizeGraphForCache(fusion_group->g(attr::Subgraph));

  // Reuses the existing spec if this graph was registered before.
  const auto maybe_spec = lookupGraph(graph);
  if (maybe_spec) {
    return (*maybe_spec)->key();
  }

  const auto key = store(graph);
  const auto maybe_retrieved_spec = retrieve(key);
  AT_ASSERT(maybe_retrieved_spec);
  upfrontCompilation(**maybe_retrieved_spec);

  return key;
}
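// Specializes the cached graph to the argument shapes in arg_spec and the
// given map size, generates kernel source, and builds it with the backend's
// registered constructor.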
std::shared_ptr<FusedKernel> compileKernel(
    const KernelSpec& spec,
    const ArgSpec& arg_spec,
    const std::vector<int64_t>& map_size,
    const at::Device device) {
  const std::vector<TensorDesc>& input_desc = arg_spec.descs();

  auto graph = spec.graph()->copy();

  for (size_t i = 0; i < input_desc.size(); i++) {
    const auto& desc = input_desc[i];
    graph->inputs()[i]->setType(DimensionedTensorType::create(
        desc.scalar_type, device, desc.nDim()));
  }

  PropagateInputShapes(graph);
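  // Creates chunk and flattened input descriptions.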
  std::vector<PartitionDesc> chunk_desc;
  std::vector<std::pair<const Value*, const TensorDesc>> flat_inputs;
  {
    size_t input_index = 0;
    for (const auto& p : graph->inputs()) {
      if (!p->type()->isSubtypeOf(TensorType::get())) {
        continue;
      }
      if (const Node* chunk = usedInFusedChunk(p)) {
        int64_t dim = chunk->i(attr::dim);
        int64_t chunks = chunk->i(attr::chunks);
        chunk_desc.emplace_back(input_desc[input_index++], chunks, dim);
        for (const auto* o : chunk->outputs()) {
          flat_inputs.emplace_back(o, *chunk_desc.back().subTensorDesc());
        }
      } else {
        chunk_desc.emplace_back();
        flat_inputs.emplace_back(p, input_desc[input_index++]);
      }
    }
  }
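  // Creates output, concat, and flattened output descriptions.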
  std::vector<TensorDesc> output_desc;
  std::vector<PartitionDesc> concat_desc;
  std::vector<std::pair<const Value*, const TensorDesc>> flat_outputs;
  for (const Value* o : graph->outputs()) {
    // Creates the output description; a fused concat's size along the concat
    // dimension is the map size scaled by the number of inputs.
    std::vector<int64_t> sizes = map_size;
    if (o->node()->kind() == prim::FusedConcat) {
      sizes.at(o->node()->i(attr::dim)) *= o->node()->inputs().size();
    }
    auto scalar_type = o->type()->expect<DimensionedTensorType>()->scalarType();
    auto type = CompleteTensorType::create(scalar_type, device, sizes);
    output_desc.emplace_back(type);
    const auto& desc = output_desc.back();

    // Creates the concat and flattened output descriptions (these rely on the
    // output description above).
    if (o->node()->kind() != prim::FusedConcat) {
      concat_desc.emplace_back();
      flat_outputs.emplace_back(o, desc);
    } else {
      const auto cat = o->node();
      concat_desc.emplace_back(desc, cat->inputs().size(), cat->i(attr::dim));
      for (const auto& c : cat->inputs()) {
        flat_outputs.emplace_back(c, *concat_desc.back().subTensorDesc());
      }
    }
  }
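  // Generates the kernel source and hands it to the constructor registered
  // for the target backend.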
  const std::string name = "kernel_" + std::to_string(next_kernel_id++);
  const bool use_cuda = device.is_cuda();
  std::string code =
      generateKernel(name, *graph, flat_inputs, flat_outputs, use_cuda);
  const FusedKernelConstructor& kernel_ctor =
      getConstructor(use_cuda ? at::DeviceType::CUDA : at::DeviceType::CPU);
  return kernel_ctor(
      device.index(), name, code, input_desc, output_desc,
      chunk_desc, concat_desc, spec.hasRandom());
}

} // namespace fuser
} // namespace jit
} // namespace torch