1 #include <torch/csrc/jit/fuser/fallback.h> 3 #include <ATen/core/functional.h> 4 #include <ATen/core/stack.h> 5 #include <torch/csrc/jit/custom_operator.h> 6 #include <torch/csrc/jit/fuser/kernel_cache.h> 7 #include <torch/csrc/jit/interpreter.h> 8 #include <torch/csrc/jit/ir.h> 18 RegisterOperators reg_fused_operators(
19 {Operator(prim::FusedConcat, [](
const Node* node) {
20 int64_t dim = node->i(attr::dim);
21 int64_t num_inputs = node->inputs().size();
22 return [dim, num_inputs](Stack& stack) {
23 auto result = at::cat(
25 last(stack, num_inputs),
26 [](
const IValue& i) {
return i.toTensor(); }),
28 drop(stack, num_inputs);
29 pack(stack, std::move(result));
34 void runFallback(int64_t key, Stack& stack) {
35 auto maybe_spec = retrieve(key);
37 throw std::runtime_error(
"Failed to find fusion spec to run fallback.");
39 InterpreterState{(*maybe_spec)->code()}.run(stack);