Caffe2 - C++ API
A deep learning, cross-platform ML framework
fallback.cpp
1 #include <torch/csrc/jit/fuser/fallback.h>
2 
3 #include <ATen/core/functional.h> //fmap
4 #include <ATen/core/stack.h>
5 #include <torch/csrc/jit/custom_operator.h>
6 #include <torch/csrc/jit/fuser/kernel_cache.h>
7 #include <torch/csrc/jit/interpreter.h>
8 #include <torch/csrc/jit/ir.h>
9 
10 #include <stdexcept>
11 
12 namespace torch {
13 namespace jit {
14 namespace fuser {
15 
16 // Registers fused operators so that fused graphs can properly generate fallback
17 // code.
18 RegisterOperators reg_fused_operators(
19  {Operator(prim::FusedConcat, [](const Node* node) {
20  int64_t dim = node->i(attr::dim);
21  int64_t num_inputs = node->inputs().size();
22  return [dim, num_inputs](Stack& stack) {
23  auto result = at::cat(
24  fmap(
25  last(stack, num_inputs),
26  [](const IValue& i) { return i.toTensor(); }),
27  dim);
28  drop(stack, num_inputs);
29  pack(stack, std::move(result));
30  return 0;
31  };
32  })});
33 
34 void runFallback(int64_t key, Stack& stack) {
35  auto maybe_spec = retrieve(key);
36  if (!maybe_spec)
37  throw std::runtime_error("Failed to find fusion spec to run fallback.");
38 
39  InterpreterState{(*maybe_spec)->code()}.run(stack);
40 }
41 
42 } // namespace fuser
43 } // namespace jit
44 } // namespace torch
Definition: jit_type.h:17