3 #include "test/cpp/jit/test_base.h" 5 #include <torch/csrc/jit/passes/canonicalize.h> 6 #include "ATen/core/interned_strings.h" 7 #include "torch/csrc/autograd/generated/variable_factories.h" 8 #include "torch/csrc/autograd/variable.h" 9 #include "torch/csrc/jit/argument_spec.h" 10 #include "torch/csrc/jit/attributes.h" 11 #include "torch/csrc/jit/autodiff.h" 12 #include "torch/csrc/jit/code_template.h" 13 #include "torch/csrc/jit/custom_operator.h" 14 #include "torch/csrc/jit/dynamic_dag.h" 15 #include "torch/csrc/jit/fuser/interface.h" 16 #include "torch/csrc/jit/import.h" 17 #include "torch/csrc/jit/interpreter.h" 18 #include "torch/csrc/jit/passes/alias_analysis.h" 19 #include "torch/csrc/jit/passes/common_subexpression_elimination.h" 20 #include "torch/csrc/jit/passes/constant_propagation.h" 21 #include "torch/csrc/jit/passes/create_autodiff_subgraphs.h" 22 #include "torch/csrc/jit/passes/dead_code_elimination.h" 23 #include "torch/csrc/jit/passes/graph_fuser.h" 24 #include "torch/csrc/jit/passes/lower_grad_of.h" 25 #include "torch/csrc/jit/passes/lower_tuples.h" 26 #include "torch/csrc/jit/passes/requires_grad_analysis.h" 27 #include "torch/csrc/jit/passes/shape_analysis.h" 28 #include "torch/csrc/jit/passes/utils/subgraph_utils.h" 29 #include "torch/csrc/jit/symbolic_script.h" 30 #include "torch/csrc/jit/symbolic_variable.h" 31 #include "torch/csrc/jit/tracer.h" 32 #include "torch/csrc/utils/hash.h" 33 #include "torch/csrc/utils/memory.h" 35 #include "torch/csrc/autograd/engine.h" 36 #include "torch/csrc/autograd/variable.h" 38 #include <torch/csrc/jit/testing/file_check.h> 39 #include "ATen/core/ivalue.h" 40 #include "torch/csrc/jit/graph_executor.h" 41 #include "torch/csrc/jit/script/compiler.h" 42 #include "torch/csrc/jit/script/module.h" 44 #include "onnx/onnx_pb.h" 46 #include <ATen/ATen.h> 48 #include <c10/util/Exception.h> 58 #include <unordered_set> 66 using Var = SymbolicVariable;
// testSimple: registers two symbolic inputs on `graph`, launches the graph on
// CUDA tensors through debugLaunchGraph, and requires the single output to
// match a reference tensor exactly.
// NOTE(review): this listing omits some of the original lines (the `graph`
// declaration, the op applied to i0/i1, and the definition of `o2`), so the
// operation under test is not visible here.
69 auto testSimple = [&] {
// Two symbolic inputs created on `graph`.
71 Var i0 = Var::asNewInput(graph);
72 Var i1 = Var::asNewInput(graph);
// `b` is allocated as {4, 3} and transposed on dims 0/1, so it ends up with
// the same {3, 4} shape as `a` -- presumably to feed a non-contiguous input;
// TODO confirm.
75 auto a = at::rand({3, 4}, at::kCUDA);
76 auto b = at::rand({4, 3}, at::kCUDA).transpose(0, 1);
// `o` is not read in the visible lines; presumably it is the base for the
// reference result `o2` -- verify against the full file.
77 auto o = at::zeros({3, 4}, at::kCUDA);
78 auto outputs = debugLaunchGraph(graph, {a, b});
79 ASSERT_EQ(outputs.size(), 1);
// The largest absolute element-wise difference is read out as double and
// then narrowed to float by the declaration.
81 float max_diff = (o2 - outputs[0]).abs().max().item<
double>();
// Exact comparison: any nonzero difference from the reference fails.
83 ASSERT_EQ(max_diff, 0);
// testOne(ti, tj): builds a five-input graph of sigmoid/tanh/multiply ops,
// runs it on CUDA inputs whose dimensions `ti` and `tj` have been transposed,
// and compares the first graph output with the same computation done
// directly on the tensors.
// NOTE(review): several graph-construction lines are missing from this
// listing (the `graph` declaration, `p18`, and the remaining ops/outputs);
// the reference computation further down is fully visible.
87 auto testOne = [&](
int ti,
int tj) {
// Five symbolic inputs created on `graph`.
90 Var i0 = Var::asNewInput(graph);
91 Var i1 = Var::asNewInput(graph);
92 Var i2 = Var::asNewInput(graph);
93 Var i3 = Var::asNewInput(graph);
94 Var i4 = Var::asNewInput(graph);
// Symbolic ops; the pNN names line up with the tNN reference values below.
// `p18` is defined on a line not shown here.
96 auto p22 = i4.sigmoid();
97 auto p20 = i3.sigmoid();
99 auto p16 = i1.sigmoid();
101 auto p11 = p22 * p18;
110 std::vector<at::Tensor> inputs;
// One random CUDA tensor per graph input. Sizes at ti/tj are swapped before
// allocation and transposed back afterwards, so every input has the logical
// shape {128, 128, 32} but is produced via transpose(ti, tj) -- presumably
// to exercise non-default strides; TODO confirm.
115 for (
size_t i = 0; i < graph.inputs().size(); i++) {
116 std::vector<int64_t> dims = {128, 128, 32};
117 std::swap(dims[ti], dims[tj]);
118 inputs.push_back(at::rand(dims, at::kCUDA).transpose(ti, tj));
// Reference computation on the concrete tensors:
//   out1 = sigmoid(in3) * in0 + sigmoid(in4) * tanh(in2)
//   out0 = sigmoid(in1) * tanh(out1)
121 auto t22 = inputs[4].sigmoid();
122 auto t20 = inputs[3].sigmoid();
123 auto t18 = inputs[2].tanh();
124 auto t16 = inputs[1].sigmoid();
125 auto t14 = t20 * inputs[0];
126 auto t11 = t22 * t18;
127 auto out1 = t14 + t11;
128 auto t5 = out1.tanh();
129 auto out0 = t16 * t5;
// Launch the graph and check output count, shape, and values.
131 auto outputs = debugLaunchGraph(graph, inputs);
132 ASSERT_EQ(outputs.size(), graph.outputs().size());
133 ASSERT_TRUE(out0.is_same_size(outputs.front()));
// Only the first output is compared against `out0`. The difference is read
// out as double and narrowed to float before the tolerance check.
134 float max_diff = (outputs.front() - out0).abs().max().item<
double>();
// Tolerance-based check: the difference must stay below 1e-6.
135 ASSERT_TRUE(max_diff < 1e-6);
// createFusedConcat: helper bound to a callable whose parameter list and
// surrounding expression are not shown in this listing. The visible part
// creates a prim::FusedConcat node from `inputs`, stores `dim` in its
// attr::dim integer attribute, and passes the node to insertNode.
// NOTE(review): judging by the call site it is invoked as
// createFusedConcat(graph, {...}, dim) -- confirm the signature and return
// value against the full file.
142 auto createFusedConcat =
145 .insertNode(graph.create(prim::FusedConcat, inputs)->i_(attr::dim, dim))
// testConcat(dim): builds a two-input graph whose second output is a
// FusedConcat along `dim`, runs it on CUDA tensors, and requires both
// outputs to match reference tensors exactly.
// NOTE(review): the `graph` declaration and the definitions of `o0` (graph
// side) and `o_r` (reference side) are missing from this listing, so the op
// feeding the concat is not visible here.
149 auto testConcat = [&](
int dim) {
// Two symbolic inputs created on `graph`.
151 Var i0 = Var::asNewInput(graph);
152 Var i1 = Var::asNewInput(graph);
// Concatenate i0 with `o0` along `dim` and add the result as a graph output.
155 Var(createFusedConcat(graph, {i0, o0}, dim)).addAsOutput();
// `b` is allocated as {4, 3, 5} and transposed on dims 0/1, giving it the
// same {3, 4, 5} shape as `a`.
157 auto a = at::rand({3, 4, 5}, at::kCUDA);
158 auto b = at::rand({4, 3, 5}, at::kCUDA).transpose(0, 1);
// Reference for the concat output, mirroring the {i0, o0} concat above.
161 auto o2_r = at::cat({a, o_r}, dim);
162 auto outputs = debugLaunchGraph(graph, {a, b});
// Two outputs are expected: outputs[0] is compared with `o_r`, outputs[1]
// with the concatenated reference `o2_r`.
163 ASSERT_EQ(outputs.size(), 2);
// Both differences are read out as double, narrowed to float, and required
// to be exactly zero.
165 float max_diff = (o_r - outputs[0]).abs().max().item<
double>();
166 ASSERT_EQ(max_diff, 0);
167 float max_diff2 = (o2_r - outputs[1]).abs().max().item<
double>();
168 ASSERT_EQ(max_diff2, 0);
// Checks that registerFusion returns the same key for the FusionGroup nodes
// of two graphs that are built by the same helper and differ only in the
// unique names given to two of their values.
// `out` defaults to std::cout and is not used in the lines visible here.
// NOTE(review): this listing omits some original lines, including the
// definitions of `c` and `d` and the closing braces.
175 void testRegisterFusionCachesKernel(std::ostream& out = std::cout) {
// Builds a graph with two inputs of identical complete tensor type, names
// the values `c` and `d` as requested, and registers `d` as the output.
177 auto createGraphWithNames = [](std::string cname, std::string dname) {
178 auto graph = std::make_shared<Graph>();
179 at::ScalarType s = at::ScalarType::Float;
// Float tensor type on CPU with sizes {2, 3, 4} and strides {12, 4, 1}.
180 auto type = CompleteTensorType::create(s, at::kCPU, {2, 3, 4}, {12, 4, 1});
181 auto a = SymbolicVariable::asNewInput(*graph, type);
182 auto b = SymbolicVariable::asNewInput(*graph, type);
// `c` and `d` are computed on lines not shown here.
185 c.value()->setUniqueName(cname);
186 d.value()->setUniqueName(dname);
187 graph->registerOutput(d.value());
// CPU fusion is switched on and then off again; the call made in between
// (original line 189) is not shown -- presumably the fusion pass; TODO
// confirm.
188 torch::jit::overrideCanFuseOnCPU(
true);
190 torch::jit::overrideCanFuseOnCPU(
false);
// Returns the first node in `graph` whose kind is prim::FusionGroup.
194 auto getFusionGroup = [](
const std::shared_ptr<Graph>& graph) {
195 const auto& nodes = graph->nodes();
196 auto maybe_fusion_group =
197 std::find_if(nodes.begin(), nodes.end(), [](
const Node* node) {
198 return node->kind() == prim::FusionGroup;
// Arguments of a check whose opening line is missing from this listing:
// the search must have found a node, with the message below attached --
// presumably an assertion macro; verify against the full file.
201 maybe_fusion_group != nodes.end(),
202 "testRegisterFusionCachesKernel: could not create FusionGroup");
203 return *maybe_fusion_group;
// First graph and its fusion group, using names "c1"/"d1".
207 auto graph1 = createGraphWithNames(
"c1",
"d1");
208 auto fg1 = getFusionGroup(graph1);
// Second graph built the same way, using names "c2"/"d2".
210 auto graph2 = createGraphWithNames(
"c2",
"d2");
211 auto fg2 = getFusionGroup(graph2);
// Register both fusion groups and keep the returned keys.
214 auto expected_key = registerFusion(fg1);
215 auto second_key = registerFusion(fg2);
// The second registration must yield the same key as the first.
219 ASSERT_EQ(second_key, expected_key);
ArrayRef - Represents a constant reference to an array (0 or more elements stored consecutively in memory)...