Caffe2 - C++ API
A deep learning, cross-platform ML framework
test_create_autodiff_subgraphs.h
1 #pragma once
2 
3 #include "test/cpp/jit/test_base.h"
4 #include "test/cpp/jit/test_utils.h"
5 
6 #include "torch/csrc/jit/passes/create_autodiff_subgraphs.h"
7 
8 namespace torch {
9 namespace jit {
10 namespace test {
11 
12 void testCreateAutodiffSubgraphs() {
13  auto graph = build_lstm();
14  CreateAutodiffSubgraphs(graph, /*threshold=*/2);
15  // all of the ops are within the DifferentiableGraph
16  testing::FileCheck()
17  .check_not("aten::mm")
18  ->check_not("aten::sigmoid")
19  ->check_not("aten::tanh")
20  ->check_not("aten::mul")
21  ->check("DifferentiableGraph")
22  ->check_next("return")
23  ->run(*graph);
24 }
25 
26 } // namespace test
27 } // namespace jit
28 } // namespace torch
Definition: module.cpp:17
Definition: jit_type.h:17