Caffe2 - C++ API
A deep learning, cross platform ML framework
1 #include <torch/csrc/jit/fuser/codegen.h>
3 #include <ATen/ATen.h>
4 #include <c10/util/Exception.h>
5 #include <torch/csrc/jit/code_template.h>
6 #include <torch/csrc/jit/fuser/compiler.h>
7 #include <torch/csrc/jit/fuser/interface.h>
8 #include <torch/csrc/jit/fuser/tensor_info.h>
9 #include <torch/csrc/jit/ir.h>
11 #include <torch/csrc/jit/fuser/cpu/resource_strings.h>
12 #include <torch/csrc/jit/fuser/cuda/resource_strings.h>
14 #include <cmath>
15 #include <cstdint>
16 #include <iostream>
17 #include <sstream>
18 #include <tuple>
19 #include <vector>
21 namespace torch {
22 namespace jit {
23 namespace fuser {
25 // Template for computing the offset into the tensor to access a value
26 static auto dim_calc = CodeTemplate(R"(
27 //printf("tensor ${tensor} sizes[${d}] = %d, strides[${d}] = %d\n", ${tensor}.sizes[${d}],${tensor}.strides[${d}]);
28 size_t ${tensor}_dimIndex${d} = ${tensor}_linearIndex ${mod_sizes};
29 ${tensor}_offset += ${tensor}_dimIndex${d} ${times_stride};
30 )");
32 static std::string valueName(const Value* n) {
33  return "n" + std::to_string(n->unique());
34 }
// Renders an integer constant as a plain decimal literal for generated code.
static std::string scalarValue(const int64_t v) {
  std::string literal{std::to_string(v)};
  return literal;
}
// Renders a boolean constant as the integer literal "1" (true) or "0" (false),
// the same text std::to_string produces for the promoted bool.
static std::string scalarValue(const bool v) {
  return v ? std::string("1") : std::string("0");
}
// Note: The NAN, NEG_INFINITY and POS_INFINITY strings map to device-specific
// implementations of these special values. These macros are found in the
// resource strings for each device.
//
// Renders a floating-point constant as a float literal for generated code.
// Finite values are printed in scientific notation with an "f" suffix and 17
// significant digits (std::numeric_limits<double>::max_digits10). This ensures
// no digits of `v` are lost before the kernel compiler rounds the literal to
// float. With the default stream precision (6 fractional digits), constants
// such as 1.0507009873554805 were truncated and could land one ulp away from
// the intended float.
static std::string scalarValue(const double v) {
  if (std::isnan(v)) {
    return "NAN";
  }
  if (std::isinf(v)) {
    return v < 0 ? "NEG_INFINITY" : "POS_INFINITY";
  }
  std::ostringstream out;
  // In scientific notation: 1 leading digit + 16 fractional = 17 significant.
  out.precision(16);
  out << std::scientific << v << 'f';
  return out.str();
}
63 // Note: Half is special-cased to avoid returning at::Half
64 static const char* scalarTypeName(const at::ScalarType type) {
65  if (type == at::ScalarType::Half) {
66  return "half";
67  }
69  switch (type) {
70 #define DEFINE_CASE(ctype, name, _) \
71  case at::ScalarType::name: \
72  return #ctype;
74 #undef DEFINE_CASE
75  default:
76  throw std::runtime_error("unknown scalar type");
77  }
78 }
80 static const char* calcScalarTypeName(const at::ScalarType type) {
81  if (type == at::ScalarType::Half) {
82  return "float";
83  }
84  return scalarTypeName(type);
85 }
87 static std::string variableType(const std::shared_ptr<c10::Type>& t) {
88  if (t->kind() == TypeKind::IntType) {
89  return "int";
90  } else if (t->kind() == TypeKind::FloatType) {
91  return "float";
92  } else if (t->kind() == TypeKind::BoolType) {
93  return "bool";
94  } else if (t->kind() == TypeKind::DimensionedTensorType) {
95  auto const tt = t->cast<DimensionedTensorType>();
96  return calcScalarTypeName(tt->scalarType());
97  }
98  // something went wrong with the type analysis during shape propagation
99  throw std::runtime_error(
100  "unknown scalar type during JIT fusion code generation");
101 }
103 static std::string typeCastedValueName(
104  const std::shared_ptr<c10::Type>& t,
105  const at::ScalarType outtype,
106  const std::string& vn) {
107  if (t->kind() == TypeKind::IntType || t->kind() == TypeKind::BoolType) {
108  if (!isIntegralType(outtype)) {
109  return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
110  }
111  return vn;
112  } else if (t->kind() == TypeKind::FloatType) {
113  if (!isFloatingType(outtype)) {
114  return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
115  }
116  return vn;
117  } else if (t->kind() == TypeKind::DimensionedTensorType) {
118  auto const tt = t->cast<DimensionedTensorType>();
119  if (tt->scalarType() != outtype) {
120  return std::string("((") + calcScalarTypeName(outtype) + ") " + vn + ")";
121  }
122  return vn;
123  }
124  // something went wrong with the type analysis during shape propagation
125  throw std::runtime_error(
126  "unknown scalar type during JIT fusion code generation");
127 }
129 // Writes RHS of special handling "simple mappable" ops
130 static std::string encodeSpecialRHS(const Node* n, TemplateEnv& env) {
131  // special case for clamp fusion on missing min/max inputs
132  // Note: It may seem unusual to have the bounds as the first case below,
133  // this is so that if min or max is NaN, they are "ignored"
134  // and when the input is NaN, the output is, too
135  if (n->kind() == aten::clamp) {
136  const auto min = n->input(1);
137  const auto max = n->input(2);
138  env.s("0", valueName(n->input(0)));
140  if (!min->node()->mustBeNone() && !max->node()->mustBeNone()) {
141  env.s("1", valueName(min));
142  env.s("2", valueName(max));
143  return format("(${0} < ${1} ? ${1} : (${0} > ${2}? ${2} : ${0}))", env);
144  } else if (min->node()->mustBeNone()) {
145  env.s("1", valueName(max));
146  return format("(${0} > ${1} ? ${1} : ${0})", env);
147  } else if (max->node()->mustBeNone()) {
148  env.s("1", valueName(min));
149  return format("(${0} < ${1} ? ${1} : ${0})", env);
150  } else {
151  throw std::runtime_error(
152  "At least one of 'min' or 'max' must not be None");
153  }
154  } else {
155  throw std::runtime_error("Cannot encode RHS of the node, op not supported");
156  }
157 }
159 // Writes "simple mappable" ops
160 static std::string encodeRHS(const Node* n) {
161  static std::unordered_map<NodeKind, std::string> simple_map_ops = {
162  // unary
163  {aten::_cast_Float, "static_cast<float>(${0})"},
164  {aten::abs, "fabs(${0})"},
165  {aten::sigmoid, "1.f / (1.f + expf(-${0}))"},
166  {aten::relu, "${0} < 0 ? 0.f : ${0} "},
167  {aten::threshold,
168  "${0} <= ${1} ? static_cast<decltype(${0})>(${2}) : ${0} "},
169  {aten::log, "logf(${0})"},
170  {aten::log10, "log10f(${0})"},
171  {aten::log1p, "log1pf(${0})"},
172  {aten::log2, "log2f(${0})"},
173  {aten::lgamma, "lgammaf(${0})"},
174  {aten::exp, "expf(${0})"},
175  {aten::expm1, "expm1f(${0})"},
176  {aten::erf, "erff(${0})"},
177  {aten::erfc, "erfcf(${0})"},
178  {aten::cos, "cosf(${0})"},
179  {aten::acos, "acosf(${0})"},
180  {aten::cosh, "coshf(${0})"},
181  {aten::sin, "sinf(${0})"},
182  {aten::asin, "asinf(${0})"},
183  {aten::sinh, "sinhf(${0})"},
184  {aten::tan, "tanf(${0})"},
185  {aten::atan, "atanf(${0})"},
186  {aten::tanh, "tanhf(${0})"},
187  {aten::sqrt, "sqrtf(${0})"},
188  {aten::rsqrt, "rsqrtf(${0})"},
189  {aten::ceil, "ceilf(${0})"},
190  {aten::floor, "floorf(${0})"},
191  {aten::round, "roundf(${0})"},
192  {aten::trunc, "truncf(${0})"},
193  {aten::frac, "fracf(${0})"},
194  {aten::reciprocal, "1.f/(${0})"},
195  {aten::neg, "-${0}"},
196  // simple binary
197  {aten::atan2, "atan2(${0}, ${1})"},
198  {aten::min, "fminf(${0}, ${1})"},
199  {aten::max, "fmaxf(${0}, ${1})"},
201  // binary with other
202  // TODO: some of these ops will not get generated because
203  // we only work on float inputs/outputs, but they are here to record
204  // that they are valid mappable ops once we handle more type
206  {aten::__and__, "${0} && ${1}"},
207  {aten::__lshift__, "${0} << ${1}"},
208  {aten::__or__, "${0} || ${1}"},
209  {aten::__rshift__, "${0} >> ${1}"},
210  {aten::__xor__, "${0} ^ ${1}"},
211  {aten::div, "${cast_0} / ${cast_1}"},
212  {aten::eq, "${0} == ${1}"},
213  {aten::fmod, "fmodf(${cast_0}, ${cast_1})"},
214  {aten::ge, "(${0} >= ${1})"},
215  {aten::gt, "${0} > ${1}"},
216  {aten::le, "(${0} <= ${1})"},
217  {aten::lt, "${0} < ${1}"},
218  {aten::type_as, "(${cast_0})"},
219  {aten::mul, "${cast_0} * ${cast_1}"},
220  {aten::ne, "${0} != ${1}"},
221  {aten::remainder, "remainderf(${0}, ${1})"},
222  {aten::pow, "powf(${cast_0}, ${cast_1})"},
224  // alpha
225  {aten::add, "${cast_0} + ${cast_2}*${cast_1}"},
226  {aten::sub, "(${cast_0} - ${cast_2}*${cast_1})"},
227  {aten::rand_like, "uniform(rnd())"},
229  // where
230  {aten::where, "(${0} ? ${1} : ${2})"},
232  // simple derivatives
233  {aten::_sigmoid_backward, "${0} * ${1} * (1.f - ${1})"},
234  {aten::_tanh_backward, "${0} * (1.f - ${1} * ${1})"},
235  };
237  if (n->kind() == prim::Constant) {
238  const auto val = toIValue(n->output()).value();
239  if (val.isDouble()) {
240  return scalarValue(val.toDouble());
241  } else if (val.isBool()) {
242  return scalarValue(val.toBool());
243  } else {
244  AT_ASSERT(val.isInt());
245  return scalarValue(val.toInt());
246  }
247  }
249  TemplateEnv env;
251  if (simple_map_ops.find(n->kind()) == simple_map_ops.end()) {
252  return encodeSpecialRHS(n, env);
253  } else {
254  size_t i = 0;
255  auto outtype = n->output()
256  ->type()
257  ->expect<c10::DimensionedTensorType const>()
258  ->scalarType();
260  for (auto in : n->inputs()) {
261  // PyTorch converts (scalar) argument types to result before applying the
262  // operator e.g. 1.4-torch.tensor(3) = -2
263  env.s(std::to_string(i), valueName(in));
264  env.s(
265  std::string("cast_") + std::to_string(i),
266  typeCastedValueName(in->type(), outtype, valueName(in)));
267  i++;
268  }
270  const auto& str =>kind());
271  return format(str, env);
272  }
273 }
275 static void emitIndexingFor(
276  std::ostream& out,
277  const std::string& tensor,
278  const int ndim,
279  const bool last_is_cont) {
280  TemplateEnv env;
281  env.s("tensor", tensor);
282  out << format("IndexType ${tensor}_offset = 0;\n", env);
283  out << format("IndexType ${tensor}_linearIndex = linearIndex;\n", env);
284  for (int d = ndim - 1; d >= 0; --d) {
285  env.d("d", d);
286  env.s("mod_sizes", d > 0 ? format("% ${tensor}.sizes[${d}]", env) : "");
287  env.s(
288  "times_stride",
289  (d < ndim - 1 || !last_is_cont)
290  ? format("* ${tensor}.strides[${d}]", env)
291  : "");
292  out << dim_calc.format(env);
293  if (d > 0) {
294  out << format("${tensor}_linearIndex /= ${tensor}.sizes[${d}];\n", env);
295  }
296  }
297 }
299 // TODO: handle cases where we need to generate > 2^32 element tensors
300 std::string generateKernel(
301  const std::string& name,
302  const Graph& graph,
303  const std::vector<std::pair<const Value*, const TensorDesc>>& inputs,
304  const std::vector<std::pair<const Value*, const TensorDesc>>& outputs,
305  const bool use_cuda) {
306  TemplateEnv env;
307  env.s("kernelName", name);
308  env.s(
309  "IndexType",
310  "unsigned int"); // Note: not uint32_t to avoid including cstdint
312  std::stringstream body;
313  std::stringstream tensorOffsets;
314  std::vector<std::string> formals;
315  std::vector<std::string> argument_loads;
317  // Lambda for writing arguments
318  auto emitFormal = [&](const Value* n, const TensorDesc& desc) {
319  std::string tensor =
320  "t" +
321  std::to_string(
322  formals.size()); // can't be unique() because Param may be an output
323  const auto nDim = desc.nDim();
324  emitIndexingFor(tensorOffsets, tensor, nDim, desc.lastIsContiguous());
325  env.s("tensor", tensor);
326  env.d(
327  "formal_index",
328  formals.size() +
329  1); // + 1 because the first argument is the linearIndex
330  env.d("nDim", nDim);
331  env.s("scalar_type", scalarTypeName(desc.scalar_type));
332  formals.push_back(
333  format("TensorInfo<${scalar_type},${nDim}> ${tensor}", env));
334  argument_loads.push_back(format(
335  "*static_cast<TensorInfo<${scalar_type},${nDim}>*>(args[${formal_index}])",
336  env));
337  };
339  // Writes input parameters
340  for (const auto& input : inputs) {
341  emitFormal(input.first, input.second);
342  }
344  // Writes output parameters
345  for (const auto& output : outputs) {
346  emitFormal(output.first, output.second);
347  }
349  // Acquires input values
350  bool has_half_tensor = false;
351  size_t formal_count = 0;
352  for (const auto& input : inputs) {
353  auto p = input.first;
354  env.s("node", valueName(p));
355  env.d("formal", formal_count++);
357  // Acquires and converts (if needed) inputs
358  // Note: conversion from half is only supported for CUDA kernels.
359  // The conversion immediately converts fp16 inputs to float.
360  // Access for other types is common to CUDA and CPU kernels.
361  const auto is_half = (input.second.scalar_type == at::ScalarType::Half);
362  if (is_half) {
363  AT_ASSERT(use_cuda);
364  env.s(
365  "access",
366  format("__half2float(t${formal}.data[t${formal}_offset])", env));
367  has_half_tensor = true;
368  } else {
369  env.s("access", format("t${formal}.data[t${formal}_offset]", env));
370  }
371  env.s("lhs_type", calcScalarTypeName(input.second.scalar_type));
373  body << format("${lhs_type} ${node} = ${access};\n", env);
374  }
376  bool has_random = false;
377  // Generates code for intermediate nodes
378  // Note: Concat and Chunk are implicitly generated
379  // Note: Random number generation is only supported for CUDA kernels.
380  // Note: Constant None node is ignored and we will handle it in the
381  // places where the constant None node is used
382  for (const auto& n : graph.nodes()) {
383  // Note: FusedConcat nodes work by narrowing the output Tensors before the
384  // kernel runs
385  if (n->kind() == prim::FusedConcat)
386  continue;
387  if (n->kind() == prim::ConstantChunk)
388  continue;
389  if (n->mustBeNone())
390  continue;
391  if (n->kind() == aten::rand_like) {
392  AT_ASSERT(use_cuda);
393  has_random = true;
394  }
396  env.s("node", valueName(n->output()));
397  env.s("rhs", encodeRHS(n));
398  env.s("lhs_type", variableType(n->output()->type()));
399  body << format("${lhs_type} ${node} = ${rhs};\n", env);
400  }
402  // Generates writes to output tensors
403  for (const auto& output : outputs) {
404  env.d("formal", formal_count++);
405  env.s("access", format("t${formal}.data[t${formal}_offset]", env));
406  env.s("node", valueName(output.first));
408  // Acquires and converts (if needed) outputs
409  // Note: conversion to half is only supported for CUDA kernels.
410  const auto is_half = (output.second.scalar_type == at::ScalarType::Half);
411  if (is_half) {
412  AT_ASSERT(use_cuda);
413  body << format("${access} = __float2half(${node});\n", env);
414  has_half_tensor = true;
415  } else {
416  body << format("${access} = ${node};\n", env);
417  }
418  }
420  // Includes headers
421  // Note: CUDA kernels support halfs and random generation, CPU kernels do not
422  if (has_half_tensor) {
423  env.s("HalfHeader", cuda::half_support_literal);
424  } else {
425  env.s("HalfHeader", "");
426  }
428  if (has_random) {
429  env.s("RandHeader", cuda::rand_support_literal);
430  env.s("RandParam", cuda::rand_param);
431  env.s("RandInit", cuda::rand_init);
432  } else {
433  env.s("RandHeader", "");
434  env.s("RandParam", "");
435  env.s("RandInit", "");
436  }
438  // Insantiates the CUDA or CPU-specific templates
439  env.s("tensorOffsets", tensorOffsets.str());
440  env.s("kernelBody", body.str());
441  env.v("formals", formals);
442  env.v("argument_loads", argument_loads);
443  std::string code_string;
444  if (use_cuda) {
445  env.s("type_declarations", cuda::type_declarations_template.format(env));
446  code_string = cuda::cuda_compilation_unit_template.format(env);
447  } else {
448  env.s("type_declarations", cpu::type_declarations_template.format(env));
449  code_string = cpu::cpu_compilation_unit_template.format(env);
450  }
452  if (debugFuser()) {
453  std::cerr << "fusion code:" << code_string << std::endl;
454  }
455  return code_string;
456 }
458 } // namespace fuser
459 } // namespace jit
460 } // namespace torch
Definition: jit_type.h:17