#include <torch/csrc/jit/fuser/codegen.h>

#include <c10/util/Exception.h>
#include <torch/csrc/jit/code_template.h>
#include <torch/csrc/jit/fuser/compiler.h>
#include <torch/csrc/jit/fuser/interface.h>
#include <torch/csrc/jit/fuser/tensor_info.h>
#include <torch/csrc/jit/ir.h>

#include <torch/csrc/jit/fuser/cpu/resource_strings.h>
#include <torch/csrc/jit/fuser/cuda/resource_strings.h>

#include <cmath>
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

// Template for computing a tensor's offset from the flat linear index,
// expanded once per dimension by emitIndexingFor() below.
static auto dim_calc = CodeTemplate(R"(
//printf("tensor ${tensor} sizes[${d}] = %d, strides[${d}] = %d\n", ${tensor}.sizes[${d}],${tensor}.strides[${d}]);
size_t ${tensor}_dimIndex${d} = ${tensor}_linearIndex ${mod_sizes};
${tensor}_offset += ${tensor}_dimIndex${d} ${times_stride};
)");

static std::string valueName(const Value* n) {
  return "n" + std::to_string(n->unique());
}
static std::string scalarValue(const int64_t v) {
  return std::to_string(v);
}

static std::string scalarValue(const bool v) {
  return std::to_string(v);
}

// Note: NAN, NEG_INFINITY and POS_INFINITY map to device-specific macros
// defined in the resource strings.
static std::string scalarValue(const double v) {
  std::ostringstream out;
  if (std::isnan(v)) {
    out << "NAN";
  } else if (std::isinf(v)) {
    if (v < 0) {
      out << "NEG_INFINITY";
    } else {
      out << "POS_INFINITY";
    }
  } else {
    out << std::scientific << v << "f";
  }
  return out.str();
}
// Note: Half is special-cased so the generated code sees "half".
static const char* scalarTypeName(const at::ScalarType type) {
  if (type == at::ScalarType::Half) {
    return "half";
  }

  switch (type) {
#define DEFINE_CASE(ctype, name, _) \
  case at::ScalarType::name:        \
    return #ctype;
    AT_FORALL_SCALAR_TYPES_EXCEPT_HALF(DEFINE_CASE)
#undef DEFINE_CASE
    default:
      throw std::runtime_error("unknown scalar type");
  }
}

// Half values are computed in float inside the kernel.
static const char* calcScalarTypeName(const at::ScalarType type) {
  if (type == at::ScalarType::Half) {
    return "float";
  }
  return scalarTypeName(type);
}
static std::string variableType(const std::shared_ptr<c10::Type>& t) {
  if (t->kind() == TypeKind::IntType) {
    return "int64_t";
  } else if (t->kind() == TypeKind::FloatType) {
    return "double";
  } else if (t->kind() == TypeKind::BoolType) {
    return "bool";
  } else if (t->kind() == TypeKind::DimensionedTensorType) {
    auto const tt = t->cast<DimensionedTensorType>();
    return calcScalarTypeName(tt->scalarType());
  }
  // something went wrong with the type analysis during shape propagation
  throw std::runtime_error(
      "unknown scalar type during JIT fusion code generation");
}
static std::string typeCastedValueName(
    const std::shared_ptr<c10::Type>& t,
    const at::ScalarType outtype,
    const std::string& vn) {
  if (t->kind() == TypeKind::IntType || t->kind() == TypeKind::BoolType) {
    if (!isIntegralType(outtype)) {
      return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
    }
    return vn;
  } else if (t->kind() == TypeKind::FloatType) {
    if (!isFloatingType(outtype)) {
      return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
    }
    return vn;
  } else if (t->kind() == TypeKind::DimensionedTensorType) {
    auto const tt = t->cast<DimensionedTensorType>();
    if (tt->scalarType() != outtype) {
      return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
    }
    return vn;
  }
  // something went wrong with the type analysis during shape propagation
  throw std::runtime_error(
      "unknown scalar type during JIT fusion code generation");
}
// Writes the RHS for ops that need special handling (currently only clamp).
static std::string encodeSpecialRHS(const Node* n, TemplateEnv& env) {
  // Note: the bounds are compared first so that a NaN min/max is effectively
  // ignored, while a NaN input still propagates to the output.
  if (n->kind() == aten::clamp) {
    const auto min = n->input(1);
    const auto max = n->input(2);
    env.s("0", valueName(n->input(0)));

    if (!min->node()->mustBeNone() && !max->node()->mustBeNone()) {
      env.s("1", valueName(min));
      env.s("2", valueName(max));
      return format("(${0} < ${1} ? ${1} : (${0} > ${2}? ${2} : ${0}))", env);
    } else if (min->node()->mustBeNone()) {
      env.s("1", valueName(max));
      return format("(${0} > ${1} ? ${1} : ${0})", env);
    } else if (max->node()->mustBeNone()) {
      env.s("1", valueName(min));
      return format("(${0} < ${1} ? ${1} : ${0})", env);
    } else {
      throw std::runtime_error(
          "At least one of 'min' or 'max' must not be None");
    }
  } else {
    throw std::runtime_error(
        "Cannot encode RHS of the node, op not supported");
  }
}
// Writes the RHS of "simple mappable" ops.
static std::string encodeRHS(const Node* n) {
  static std::unordered_map<NodeKind, std::string> simple_map_ops = {
      // unary
      {aten::_cast_Float, "static_cast<float>(${0})"},
      {aten::abs, "fabs(${0})"},
      {aten::sigmoid, "1.f / (1.f + expf(-${0}))"},
      {aten::relu, "${0} < 0 ? 0.f : ${0} "},
      {aten::threshold,
       "${0} <= ${1} ? static_cast<decltype(${0})>(${2}) : ${0} "},
      {aten::log, "logf(${0})"},
      {aten::log10, "log10f(${0})"},
      {aten::log1p, "log1pf(${0})"},
      {aten::log2, "log2f(${0})"},
      {aten::lgamma, "lgammaf(${0})"},
      {aten::exp, "expf(${0})"},
      {aten::expm1, "expm1f(${0})"},
      {aten::erf, "erff(${0})"},
      {aten::erfc, "erfcf(${0})"},
      {aten::cos, "cosf(${0})"},
      {aten::acos, "acosf(${0})"},
      {aten::cosh, "coshf(${0})"},
      {aten::sin, "sinf(${0})"},
      {aten::asin, "asinf(${0})"},
      {aten::sinh, "sinhf(${0})"},
      {aten::tan, "tanf(${0})"},
      {aten::atan, "atanf(${0})"},
      {aten::tanh, "tanhf(${0})"},
      {aten::sqrt, "sqrtf(${0})"},
      {aten::rsqrt, "rsqrtf(${0})"},
      {aten::ceil, "ceilf(${0})"},
      {aten::floor, "floorf(${0})"},
      {aten::round, "roundf(${0})"},
      {aten::trunc, "truncf(${0})"},
      {aten::frac, "fracf(${0})"},
      {aten::reciprocal, "1.f/(${0})"},
      {aten::neg, "-${0}"},

      // simple binary
      {aten::atan2, "atan2(${0}, ${1})"},
      {aten::min, "fminf(${0}, ${1})"},
      {aten::max, "fmaxf(${0}, ${1})"},

      // binary ops; ${cast_i} inserts a cast to the output scalar type
      // when the input's type differs from it
      {aten::__and__, "${0} && ${1}"},
      {aten::__lshift__, "${0} << ${1}"},
      {aten::__or__, "${0} || ${1}"},
      {aten::__rshift__, "${0} >> ${1}"},
      {aten::__xor__, "${0} ^ ${1}"},
      {aten::div, "${cast_0} / ${cast_1}"},
      {aten::eq, "${0} == ${1}"},
      {aten::fmod, "fmodf(${cast_0}, ${cast_1})"},
      {aten::ge, "(${0} >= ${1})"},
      {aten::gt, "${0} > ${1}"},
      {aten::le, "(${0} <= ${1})"},
      {aten::lt, "${0} < ${1}"},
      {aten::type_as, "(${cast_0})"},
      {aten::mul, "${cast_0} * ${cast_1}"},
      {aten::ne, "${0} != ${1}"},
      {aten::remainder, "remainderf(${0}, ${1})"},
      {aten::pow, "powf(${cast_0}, ${cast_1})"},

      // ops with an alpha argument
      {aten::add, "${cast_0} + ${cast_2}*${cast_1}"},
      {aten::sub, "(${cast_0} - ${cast_2}*${cast_1})"},
      {aten::rand_like, "uniform(rnd())"},

      // where
      {aten::where, "(${0} ? ${1} : ${2})"},

      // simple derivatives
      {aten::_sigmoid_backward, "${0} * ${1} * (1.f - ${1})"},
      {aten::_tanh_backward, "${0} * (1.f - ${1} * ${1})"},
  };

  if (n->kind() == prim::Constant) {
    const auto val = toIValue(n->output()).value();
    if (val.isDouble()) {
      return scalarValue(val.toDouble());
    } else if (val.isBool()) {
      return scalarValue(val.toBool());
    } else {
      AT_ASSERT(val.isInt());
      return scalarValue(val.toInt());
    }
  }

  TemplateEnv env;

  if (simple_map_ops.find(n->kind()) == simple_map_ops.end()) {
    return encodeSpecialRHS(n, env);
  } else {
    size_t i = 0;
    auto outtype = n->output()
                       ->type()
                       ->expect<c10::DimensionedTensorType const>()
                       ->scalarType();

    for (auto in : n->inputs()) {
      // Scalar arguments are cast to the output's scalar type before the
      // operator is applied, e.g. 1.4 - torch.tensor(3) = -2
      env.s(std::to_string(i), valueName(in));
      env.s(
          std::string("cast_") + std::to_string(i),
          typeCastedValueName(in->type(), outtype, valueName(in)));
      i++;
    }

    const auto& str = simple_map_ops.at(n->kind());
    return format(str, env);
  }
}
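
// Example (illustrative, not part of the original file): for a node such as
//   %3 : Float = aten::mul(%1, %2)
// encodeRHS substitutes the value names into "${cast_0} * ${cast_1}" and
// returns "n1 * n2" (with "((float) ...)" casts inserted only when an input's
// scalar type differs from the output's).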
// Emits the code that turns the flat linearIndex into a per-tensor offset,
// expanding dim_calc once per dimension, innermost first.
static void emitIndexingFor(
    std::ostream& out,
    const std::string& tensor,
    const int ndim,
    const bool last_is_cont) {
  TemplateEnv env;
  env.s("tensor", tensor);
  out << format("IndexType ${tensor}_offset = 0;\n", env);
  out << format("IndexType ${tensor}_linearIndex = linearIndex;\n", env);
  for (int d = ndim - 1; d >= 0; --d) {
    env.d("d", d);
    env.s("mod_sizes", d > 0 ? format("% ${tensor}.sizes[${d}]", env) : "");
    env.s(
        "times_stride",
        (d < ndim - 1 || !last_is_cont)
            ? format("* ${tensor}.strides[${d}]", env)
            : "");
    out << dim_calc.format(env);
    if (d > 0) {
      out << format("${tensor}_linearIndex /= ${tensor}.sizes[${d}];\n", env);
    }
  }
}
// Generates the complete CUDA or CPU kernel source for a fusion group.
std::string generateKernel(
    const std::string& name,
    const Graph& graph,
    const std::vector<std::pair<const Value*, const TensorDesc>>& inputs,
    const std::vector<std::pair<const Value*, const TensorDesc>>& outputs,
    const bool use_cuda) {
  TemplateEnv env;
  env.s("kernelName", name);
  env.s("IndexType", "unsigned int");

  std::stringstream body;
  std::stringstream tensorOffsets;
  std::vector<std::string> formals;
  std::vector<std::string> argument_loads;

  // Writes the formal parameter and argument load for one tensor.
  auto emitFormal = [&](const Value* n, const TensorDesc& desc) {
    // Can't use valueName() here because the same Value may also be an output.
    std::string tensor = "t" + std::to_string(formals.size());
    const auto nDim = desc.nDim();
    emitIndexingFor(tensorOffsets, tensor, nDim, desc.lastIsContiguous());
    env.s("tensor", tensor);
    // + 1 because the first argument is the linearIndex
    env.d("formal_index", formals.size() + 1);
    env.d("nDim", nDim);
    env.s("scalar_type", scalarTypeName(desc.scalar_type));
    formals.push_back(
        format("TensorInfo<${scalar_type},${nDim}> ${tensor}", env));
    argument_loads.push_back(format(
        "*static_cast<TensorInfo<${scalar_type},${nDim}>*>(args[${formal_index}])",
        env));
  };
  // Writes input parameters
  for (const auto& input : inputs) {
    emitFormal(input.first, input.second);
  }

  // Writes output parameters
  for (const auto& output : outputs) {
    emitFormal(output.first, output.second);
  }

  // Acquires input values
  // Note: half inputs are converted to float on load; this is only supported
  // for CUDA kernels. Access for other types is common to CUDA and CPU.
  bool has_half_tensor = false;
  size_t formal_count = 0;
  for (const auto& input : inputs) {
    auto p = input.first;
    env.s("node", valueName(p));
    env.d("formal", formal_count++);

    const auto is_half = (input.second.scalar_type == at::ScalarType::Half);
    if (is_half) {
      AT_ASSERT(use_cuda);
      env.s(
          "access",
          format("__half2float(t${formal}.data[t${formal}_offset])", env));
      has_half_tensor = true;
    } else {
      env.s("access", format("t${formal}.data[t${formal}_offset]", env));
    }
    env.s("lhs_type", calcScalarTypeName(input.second.scalar_type));

    body << format("${lhs_type} ${node} = ${access};\n", env);
  }
  bool has_random = false;
  // Generates code for intermediate nodes
  // Note: FusedConcat and ConstantChunk work by narrowing tensors before the
  // kernel runs, so no code is generated for them. Random number generation
  // is only supported for CUDA kernels.
  for (const auto& n : graph.nodes()) {
    if (n->kind() == prim::FusedConcat)
      continue;
    if (n->kind() == prim::ConstantChunk)
      continue;
    if (n->kind() == aten::rand_like) {
      AT_ASSERT(use_cuda);
      has_random = true;
    }
    env.s("node", valueName(n->output()));
    env.s("rhs", encodeRHS(n));
    env.s("lhs_type", variableType(n->output()->type()));
    body << format("${lhs_type} ${node} = ${rhs};\n", env);
  }
  // Generates writes to output tensors
  // Note: conversion back to half is only supported for CUDA kernels.
  for (const auto& output : outputs) {
    env.d("formal", formal_count++);
    env.s("access", format("t${formal}.data[t${formal}_offset]", env));
    env.s("node", valueName(output.first));

    const auto is_half = (output.second.scalar_type == at::ScalarType::Half);
    if (is_half) {
      AT_ASSERT(use_cuda);
      body << format("${access} = __float2half(${node});\n", env);
      has_half_tensor = true;
    } else {
      body << format("${access} = ${node};\n", env);
    }
  }
  // Includes half support only if any half tensors are involved
  if (has_half_tensor) {
    env.s("HalfHeader", cuda::half_support_literal);
  } else {
    env.s("HalfHeader", "");
  }

  // Includes random number generation support only if the graph uses it
  if (has_random) {
    env.s("RandHeader", cuda::rand_support_literal);
    env.s("RandParam", cuda::rand_param);
    env.s("RandInit", cuda::rand_init);
  } else {
    env.s("RandHeader", "");
    env.s("RandParam", "");
    env.s("RandInit", "");
  }

  // Instantiates the CUDA or CPU-specific compilation unit template
  env.s("tensorOffsets", tensorOffsets.str());
  env.s("kernelBody", body.str());
  env.v("formals", formals);
  env.v("argument_loads", argument_loads);
  std::string code_string;
  if (use_cuda) {
    env.s("type_declarations", cuda::type_declarations_template.format(env));
    code_string = cuda::cuda_compilation_unit_template.format(env);
  } else {
    env.s("type_declarations", cpu::type_declarations_template.format(env));
    code_string = cpu::cpu_compilation_unit_template.format(env);
  }

  if (debugFuser()) {
    std::cerr << "fusion code:" << code_string << std::endl;
  }

  return code_string;
}