#include <torch/csrc/jit/graph_executor.h>

#include <ATen/core/ivalue.h>
#include <c10/util/Exception.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/jit/argument_spec.h>
#include <torch/csrc/jit/autodiff.h>
#include <torch/csrc/jit/custom_operator.h>
#include <torch/csrc/jit/interpreter.h>
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/resource_guard.h>
#include <torch/csrc/jit/passes/batch_mm.h>
#include <torch/csrc/jit/passes/canonicalize_ops.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/graph_fuser.h>
#include <torch/csrc/jit/passes/inline_autodiff_subgraphs.h>
#include <torch/csrc/jit/passes/inplace_check.h>
#include <torch/csrc/jit/passes/loop_unrolling.h>
#include <torch/csrc/jit/passes/lower_grad_of.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/remove_expands.h>
#include <torch/csrc/jit/passes/requires_grad_analysis.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/specialize_autogradzero.h>
#include <torch/csrc/jit/symbolic_variable.h>
#include <torch/csrc/jit/tracer.h>

#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/jit/script/compiler.h>

#include <memory>
#include <mutex>
#include <unordered_map>
#include <utility>
#include <vector>

namespace torch {
namespace jit {

namespace {

using tensor_list = std::vector<at::Tensor>;
using Variable = autograd::Variable;
using autograd::variable_list;
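// An ExecutionPlan pairs a Graph with the interpreter Code compiled from it;
// running the plan executes that code against the given stack.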
struct ExecutionPlan {
  ExecutionPlan() = default;
  ExecutionPlan(std::shared_ptr<Graph> graph)
      : code(graph), graph(std::move(graph)) {}

  void run(Stack& stack) const {
    return InterpreterState(code).run(stack);
  }

  operator bool() const {
    return static_cast<bool>(graph);
  }

  ExecutionPlanState getDebugState() {
    ExecutionPlanState state;
    state.code = &code;
    state.graph = graph.get();
    return state;
  }

  Code code;
  std::shared_ptr<Graph> graph;
};
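// The autograd Function created for the backward pass of a differentiable
// subgraph. It owns a GraphExecutor for the gradient graph (df) plus the
// tensors/IValues captured during the forward run; apply() assembles the
// incoming gradients and the captures onto a stack and runs that executor.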
struct DifferentiableGraphBackward : public autograd::Function {
  DifferentiableGraphBackward(GraphExecutor executor, size_t capture_size)
      : executor(std::move(executor)) {
    is_var_capture.reserve(capture_size);
    var_captures.reserve(capture_size);
    ivalue_captures.reserve(capture_size);
  }
  variable_list apply(variable_list&& inputs) override {
    Stack stack;
    stack.reserve(is_var_capture.size() + inputs.size());
    stack.insert(
        stack.end(),
        std::make_move_iterator(inputs.begin()),
        std::make_move_iterator(inputs.end()));
    auto var_capture_it = var_captures.begin();
    auto ivalue_capture_it = ivalue_captures.begin();
    for (bool is_var : is_var_capture) {
      if (is_var) {
        stack.emplace_back(var_capture_it->unpack(this->shared_from_this()));
        ++var_capture_it;
      } else {
        stack.push_back(*ivalue_capture_it);
        ++ivalue_capture_it;
      }
    }

    executor.run(stack);
    AT_ASSERT(stack.size() == num_outputs());
    variable_list outputs;
    outputs.reserve(num_outputs());
    for (size_t i = 0; i < num_outputs(); ++i) {
      // a gradient can be None even when the input requires grad
      if (should_compute_output(i) && !stack[i].isNone()) {
        auto output = std::move(stack[i]).toTensor();
        const auto& edge = next_edge(i);
        if (output.defined()) {
          outputs.emplace_back(std::move(output));
        } else if (edge.is_valid()) {
          // materialize an undefined gradient as zeros of the right shape
          outputs.emplace_back(
              edge.function->input_metadata(edge.input_nr).zeros_like());
        } else {
          outputs.emplace_back();
        }
      } else {
        outputs.emplace_back();
      }
    }
    return outputs;
  }
  // Save a forward value that the backward graph needs. Tensors become
  // SavedVariables; all other IValues are stored directly.
  void capture(const IValue& val, bool is_output) {
    const bool is_tensor = val.isTensor();
    is_var_capture.push_back(is_tensor);
    if (is_tensor) {
      var_captures.emplace_back(Variable(val.toTensor()), is_output);
    } else {
      ivalue_captures.push_back(val);
    }
  }
 private:
  friend struct ExecutionPlan;
  GraphExecutor executor;

  // INVARIANT: is_var_capture.size() == var_captures.size() +
  // ivalue_captures.size()
  std::vector<bool> is_var_capture;
  std::vector<autograd::SavedVariable> var_captures;
  std::vector<IValue> ivalue_captures;
};
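// An optimized way of executing a differentiable subgraph directly on
// tensors rather than Variables: it detaches the input Variables, runs the
// forward plan, and, when gradients are needed, wires a
// DifferentiableGraphBackward up to the inputs and outputs.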
struct DifferentiableGraphOp {
  DifferentiableGraphOp(Gradient grad)
      : f(grad.f),
        grad(std::move(grad)),
        grad_executor(this->grad.df),
        num_inputs(this->grad.f->inputs().size()),
        num_outputs(this->grad.f->outputs().size()) {}
  // XXX: keep in mind that stack can be larger than the inputs we need!
  int operator()(Stack& stack) const {
    auto grad_fn = std::make_shared<DifferentiableGraphBackward>(
        grad_executor,
        grad.df_input_captured_inputs.size() +
            grad.df_input_captured_outputs.size());
    {
      auto inputs = last(stack, num_inputs);
      // hook up the outputs of df to the gradient functions of the inputs
      // that require gradients
      for (auto idx : grad.df_output_vjps) {
        auto v = Variable(inputs[idx].toTensor());
        grad_fn->add_next_edge(
            v.defined() ? v.gradient_edge() : autograd::Edge{});
      }
      captureInputs(*grad_fn, inputs);
    }
    detachVariables(stack);
    InterpreterState(f).run(stack);
    {
      auto outputs = last(stack, num_outputs);
      // hook up the gradients of the output tensors that require grad as
      // inputs to our gradient function df
      for (auto idx : grad.df_input_vjps) {
        Variable output;
        if (!outputs[idx].isNone()) {
          output = outputs[idx].toTensor();
        }
        // NB: requires_grad is only a heuristic here, and differentiating
        // through integral tensors is generally a hard error in autograd
        if (at::isFloatingType(output.scalar_type())) {
          autograd::create_gradient_edge(output, grad_fn);
          output.set_requires_grad(true);
        } else {
          grad_fn->add_input_metadata(autograd::Function::undefined_input{});
        }
      }
      captureOutputs(*grad_fn, outputs);
      // drop the temporary outputs so that we return the same number of
      // outputs as if we were not also computing the gradient
      const size_t num_temporary_outputs = num_outputs - grad.f_real_outputs;
      stack.erase(stack.end() - num_temporary_outputs, stack.end());
    }
    return 0;
  }
 private:
  friend GraphExecutor* detail::getGradExecutor(Operation& op);
  // Detach the input Variables on the stack so the forward graph runs on
  // tensors that do not track history.
  void detachVariables(Stack& stack) const {
    const int64_t stack_size = stack.size();
    const int64_t stack_offset = stack_size - num_inputs;
    for (int64_t i = stack_offset; i < stack_size; ++i) {
      auto& v = stack[i];
      if (!v.isTensor())
        continue;
      auto t = std::move(v).toTensor();
      v = IValue{t.defined() ? autograd::as_variable_ref(t).detach()
                             : std::move(t)};
    }
  }
  // Capture (save) the inputs that are needed to later run backward.
  void captureInputs(
      DifferentiableGraphBackward& grad_fn,
      at::ArrayRef<IValue> inputs) const {
    for (size_t offset : grad.df_input_captured_inputs) {
      grad_fn.capture(inputs[offset], /*is_output*/ false);
    }
  }
  void captureOutputs(
      DifferentiableGraphBackward& grad_fn,
      at::ArrayRef<IValue> outputs) const {
    for (size_t offset : grad.df_input_captured_outputs) {
      grad_fn.capture(outputs[offset], /*is_output*/ true);
    }
  }
  Code f;
  Gradient grad;
  GraphExecutor grad_executor;

  const size_t num_inputs;
  const size_t num_outputs;
};
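// packGradient and getGradient round-trip a Gradient through attributes on a
// prim::DifferentiableGraph node, so the forward/backward graph pair produced
// by differentiate() can be stored in (and later recovered from) the IR.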
void packGradient(Gradient gradient, Node* dnode) {
  AT_ASSERT(dnode->kind() == prim::DifferentiableGraph);
  dnode->g_(attr::Subgraph, gradient.f)
      ->g_(attr::ReverseSubgraph, gradient.df)
      ->i_(attr::f_real_outputs, gradient.f_real_outputs)
      ->is_(attr::df_input_vjps, fmap<int64_t>(gradient.df_input_vjps))
      ->is_(
          attr::df_input_captured_inputs,
          fmap<int64_t>(gradient.df_input_captured_inputs))
      ->is_(
          attr::df_input_captured_outputs,
          fmap<int64_t>(gradient.df_input_captured_outputs))
      ->is_(attr::df_output_vjps, fmap<int64_t>(gradient.df_output_vjps));
}
Gradient getGradient(const Node* n) {
  AT_ASSERT(n->kind() == prim::DifferentiableGraph);
  Gradient grad;
  grad.f = n->g(attr::Subgraph);
  grad.df = n->g(attr::ReverseSubgraph);
  grad.f_real_outputs = n->i(attr::f_real_outputs);
  grad.df_input_vjps = fmap<size_t>(n->is(attr::df_input_vjps));
  grad.df_input_captured_inputs =
      fmap<size_t>(n->is(attr::df_input_captured_inputs));
  grad.df_input_captured_outputs =
      fmap<size_t>(n->is(attr::df_input_captured_outputs));
  grad.df_output_vjps = fmap<size_t>(n->is(attr::df_output_vjps));
  return grad;
}
} // anonymous namespace
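// Registers the interpreter operation that executes prim::DifferentiableGraph
// nodes by wrapping their stored Gradient in a DifferentiableGraphOp.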
RegisterOperators reg_graph_executor_ops(
    {Operator(prim::DifferentiableGraph, [](const Node* n) -> Operation {
      return DifferentiableGraphOp(getGradient(n));
    })});
namespace detail {

GraphExecutor* getGradExecutor(Operation& op) {
  if (auto diff_op = op.target<DifferentiableGraphOp>()) {
    return &diff_op->grad_executor;
  }
  return nullptr;
}

} // namespace detail
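// A Graph can be created via tracing or via the script frontend, and a
// GraphExecutor runs it. The same graph can be run with many different input
// sizes and requires_grad states; the executor keeps a specialized
// ExecutionPlan for each input configuration it has seen.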
struct GraphExecutorImpl {
  static std::shared_ptr<Graph> prepareGraph(std::shared_ptr<Graph>& graph) {
    auto copy = graph->copy();
    EraseShapeInformation(copy);
    return copy;
  }
  // Counts how many flattened inputs a type contributes: tuples are counted
  // recursively and optionals count as their element type.
  static size_t countFlatInputs(const TypePtr& ptr) {
    if (auto optional_type = ptr->cast<OptionalType>()) {
      return countFlatInputs(optional_type->getElementType());
    }
    if (auto tuple_type = ptr->cast<TupleType>()) {
      size_t total = 0;
      for (auto& elem : tuple_type->elements()) {
        total += countFlatInputs(elem);
      }
      return total;
    }
    return 1;
  }
  static size_t countFlatInputs(const std::shared_ptr<Graph>& graph) {
    size_t total = 0;
    for (Value* input : graph->inputs()) {
      total += countFlatInputs(input->type());
    }
    return total;
  }
  inline bool hasMutableOperators(Block* block) {
    for (auto n : block->nodes()) {
      if (n->kind().is_aten() && n->schema().is_mutable())
        return true;
      for (auto b : n->blocks()) {
        if (hasMutableOperators(b))
          return true;
      }
    }
    return false;
  }
  GraphExecutorImpl(std::shared_ptr<Graph> graph, bool optimize)
      : graph(prepareGraph(graph)),
        // until we have correct alias analysis any use of mutable operators
        // disables all optimization
        optimize(optimize && !hasMutableOperators(this->graph->block())),
        num_inputs(this->graph->inputs().size()),
        num_flat_inputs(countFlatInputs(graph)),
        num_outputs(this->graph->outputs().size()) {}
  // entry point where execution begins
  void run(Stack& stack) {
    AT_CHECK(
        stack.size() >= num_inputs,
        "expected ",
        num_inputs,
        " inputs, but got only ",
        stack.size());

    if (tracer::isTracing()) {
      return runTraced(stack);
    }

    auto& execution_plan =
        optimize ? getOrCompile(stack) : getOrCompileFallback();
    return execution_plan.run(stack);
  }
  std::shared_ptr<Graph> graphFor(const Stack& stack) const {
    AT_ASSERT(stack.size() >= num_inputs);
    auto inputs = last(stack, num_inputs);
    ArgumentSpec spec(
        autograd::GradMode::is_enabled(), inputs, num_flat_inputs);

    if (!optimize) {
      AT_CHECK(fallback, "No graph found for given inputs");
      return fallback.graph;
    }

    auto it = plan_cache.find(spec);
    AT_CHECK(it != plan_cache.end(), "No graph found for given inputs");
    return it->second.graph;
  }
  GraphExecutorState getDebugState() {
    GraphExecutorState state;
    state.graph = graph.get();
    if (fallback) {
      state.fallback = fallback.getDebugState();
    }
    for (auto& entry : plan_cache) {
      state.execution_plans.emplace(entry.first, entry.second.getDebugState());
    }
    return state;
  }
  // This function should be used only for testing purposes
  void debugDisableAutodiffSubgraphInlining() {
    // Allow single-node autodiff subgraphs
    autodiffSubgraphNodeThreshold = 1;
    // Don't inline autodiff subgraphs into autograd functions
    autodiffSubgraphInlineThreshold = 1;
  }
 private:
  friend struct GraphExecutor;

  const ExecutionPlan& getOrCompileFallback() {
    std::lock_guard<std::mutex> lock(compile_mutex);
    if (!fallback) {
      auto graph_ = graph->copy();
      runRequiredPasses(graph_);
      fallback = ExecutionPlan(graph_);
    }
    return fallback;
  }
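  // The ArgumentSpec is constructed (and hashed) before taking the lock, to
  // minimize how long the lock is held on the cache-hit fast path.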
  const ExecutionPlan& getOrCompile(const Stack& stack) {
    ArgumentSpec spec(
        autograd::GradMode::is_enabled(),
        last(stack, num_inputs),
        num_flat_inputs);
    {
      std::lock_guard<std::mutex> lock(compile_mutex);
      auto it = plan_cache.find(spec);
      if (it != plan_cache.end())
        return it->second;
      auto plan = compileSpec(spec);
      auto r = plan_cache.emplace(std::move(spec), std::move(plan));
      return r.first->second;
    }
  }
  ExecutionPlan compileSpec(const ArgumentSpec& spec) {
    auto opt_graph = graph->copy();
    setInputTypes(*opt_graph, spec);

    // Phase 1. Specialize to the input definedness and run the passes
    //          required to bring the graph into an executable form.
    runRequiredPasses(opt_graph);

    // Phase 2. Propagate detailed information about the spec through the
    //          graph. Constant propagation runs first because shape
    //          propagation can benefit from arguments becoming constants.
    ConstantPropagation(opt_graph);
    PropagateInputShapes(opt_graph);
    PropagateRequiresGrad(opt_graph);

    // Phase 3. Run differentiable optimizations (simple graph rewrites that
    //          we can still execute under autograd).
    runOptimization(opt_graph, spec);
    // Phase 4. If the graph will be differentiated, slice out the
    //          symbolically differentiable subgraphs and differentiate them,
    //          applying the non-differentiable optimizations to each forward
    //          graph (or to the whole graph if no gradient is needed).
    if (needsGradient(opt_graph)) {
      auto diff_nodes =
          CreateAutodiffSubgraphs(opt_graph, autodiffSubgraphNodeThreshold);
      for (Node* dnode : diff_nodes) {
        auto diff_graph = std::move(dnode->g(attr::Subgraph));
        Gradient gradient = differentiate(diff_graph);
        runNondiffOptimization(gradient.f);
        packGradient(gradient, dnode);
      }
      InlineAutodiffSubgraphs(opt_graph, autodiffSubgraphInlineThreshold);
    } else {
      runNondiffOptimization(opt_graph);
    }
    // Make sure there are no leftovers from any of the passes.
    EliminateDeadCode(opt_graph);
    return ExecutionPlan(opt_graph);
  }
  void runOptimization(
      std::shared_ptr<Graph>& graph,
      const ArgumentSpec& spec) {
    // Basic graph preprocessing to eliminate noise.
    EliminateDeadCode(graph);
    EliminateCommonSubexpression(graph);
    ConstantPooling(graph);

    PeepholeOptimize(graph);

    // Unroll small loops, and eliminate expressions that are the same at
    // every iteration.
    UnrollLoops(graph);
    EliminateCommonSubexpression(graph);

    BatchMM(graph);
    CheckInplace(graph);
  }
  void runNondiffOptimization(std::shared_ptr<Graph>& graph) {
    FuseGraph(graph);
  }
  static bool needsGradient(const std::shared_ptr<const Graph>& graph) {
    if (!autograd::GradMode::is_enabled())
      return false;
    if (mayIntroduceGradient(graph->block()))
      return true;
    for (const Value* input : graph->inputs()) {
      if (input->type()->requires_grad())
        return true;
    }
    return false;
  }
  static bool mayIntroduceGradient(const Block* b) {
    for (const Node* n : b->nodes()) {
      if (n->kind() == prim::PythonOp)
        return true;
      for (const Block* bb : n->blocks()) {
        if (mayIntroduceGradient(bb))
          return true;
      }
    }
    return false;
  }
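  // When invoked under the tracer we still run the unoptimized fallback to
  // produce correct outputs, but we also inline a copy of the graph into the
  // trace so that its control flow is preserved rather than being flattened
  // into the traced operations.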
  void runTraced(Stack& stack) {
    const auto& state = tracer::getTracingState();
    auto inputs = last(stack, num_inputs);
    auto input_values = fmap(
        inputs, [](const IValue& v) { return tracer::getNestedValueTrace(v); });

    ArgumentSpec spec(
        autograd::GradMode::is_enabled(), inputs, num_flat_inputs);
    getOrCompileFallback().run(stack);

    // Traces always have types propagated through them, so we propagate
    // shapes through the copy of the graph that we inline into the trace.
    auto local_graph = this->graph->copy();
    setInputTypes(*local_graph, spec);
    PropagateInputShapes(local_graph);
    auto output_values =
        inlineCallTo(*state->graph, *local_graph, input_values);

    auto outputs = last(stack, num_outputs);
    for (size_t i = 0; i < outputs.size(); ++i) {
      tracer::setValueTrace(outputs[i], output_values[i]);
    }
  }
  // The unoptimized starting graph; all specializations are derived from it.
  std::shared_ptr<Graph> graph;

  // If false, we run the graph as-is, with only the required passes applied.
  const bool optimize;
  const size_t num_inputs;
  const size_t num_flat_inputs; // number of inputs after tuples are flattened
  const size_t num_outputs;

  // The compiled but unspecialized version of graph; used when optimize is
  // false and while tracing.
  ExecutionPlan fallback;

  // Mapping from argument configurations to versions of the graph optimized
  // for that configuration.
  std::unordered_map<ArgumentSpec, ExecutionPlan> plan_cache;

  // GraphExecutors can be accessed from multiple threads, so this mutex
  // guards compilation and the two structures above.
  std::mutex compile_mutex;

  size_t autodiffSubgraphNodeThreshold = 2;
  size_t autodiffSubgraphInlineThreshold = 5;
};
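// Typical use of the public wrapper below (a sketch; `build_graph()` stands
// in for any code that produces a Graph, e.g. tracing or the script
// compiler):
//
//   std::shared_ptr<Graph> g = build_graph();
//   GraphExecutor executor(g, /*optimize=*/true);
//   Stack stack{/* input IValues */};
//   executor.run(stack); // outputs replace the inputs on the stack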
GraphExecutor::GraphExecutor(std::shared_ptr<Graph> graph, bool optimize)
    : pImpl(new GraphExecutorImpl(std::move(graph), optimize)) {}

void GraphExecutor::run(Stack& inputs) {
  return pImpl->run(inputs);
}
std::shared_ptr<Graph> GraphExecutor::graph() const {
  return pImpl->graph;
}
std::shared_ptr<Graph> GraphExecutor::graphFor(const Stack& inputs) const {
  return pImpl->graphFor(inputs);
}

GraphExecutorState GraphExecutor::getDebugState() {
  return pImpl->getDebugState();
}
void GraphExecutor::debugDisableAutodiffSubgraphInlining() {
  return pImpl->debugDisableAutodiffSubgraphInlining();
}
void runRequiredPasses(const std::shared_ptr<Graph>& g) {
  specializeAutogradZero(*g);
  LowerGradOf(*g);
  // Implicitly inserted expand nodes are not necessarily valid inside script
  // methods whose shapes may change; remove them here and let shape analysis
  // re-insert valid expands once the shapes are known to be stable.
  RemoveExpands(g);
  CanonicalizeOps(g);
  EliminateDeadCode(g);
}

} // namespace jit
} // namespace torch