Caffe2 - C++ API
A deep learning, cross-platform ML framework
symbolic_script.cpp
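This page lists the source of symbolic_script.cpp, which implements the TorchScript symbolic autodiff formulas for the JIT: the raw string literals in `functions` below are compiled into forward/backward graph pairs on first use and looked up by operator schema through gradientInfoForSchema().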
1 #include <torch/csrc/jit/symbolic_script.h>
2 
3 namespace torch {
4 namespace jit {
5 namespace {
6 std::mutex lock;
7 const std::vector<std::string> functions = {
8  R"(
9 
10  #### HELPER FUNCTIONS ###
11  #### PREFIX: AD_ ###
12  #### SCHEMA NOT SAVED IN CACHE ###
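 #  Each non-helper definition below returns a pair (forward_result, backward);
 #  backward(grad_output) yields one gradient per forward input, with None for
 #  arguments that do not receive gradients.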
13 
14  def AD_unsqueeze_multiple(t,
15  dims: List[int],
16  n_dims: int):
17  seen = [False] * n_dims
18  for i in range(len(dims)):
19  seen[dims[i]] = True
20 
21  for d in range(n_dims):
22  if seen[d]:
23  t = t.unsqueeze(d)
24  return t
25 
26  def AD_sum_backward(grad,
27  sizes: List[int],
28  dims: List[int],
29  keepdim: bool):
30  if not keepdim and len(sizes) > 0:
31  if len(dims) == 1:
32  return grad.unsqueeze(dims[0]).expand(sizes)
33  else:
34  res = AD_unsqueeze_multiple(grad, dims, len(sizes))
35  return res.expand(sizes)
36  else:
37  return grad.expand(sizes)
38 
39  def AD_logsumexp_backward(grad, self, result,
40  dim: List[int],
41  keepdim: bool):
42  if not keepdim and self.dim() != 0:
43  n_dims = len(self.size())
44  grad = AD_unsqueeze_multiple(grad, dim, n_dims)
45  result = AD_unsqueeze_multiple(result, dim, n_dims)
46  return grad * (self - result).exp()
47 
48  def mean_0(self):
49  self_size = self.size()
50  self_numel = self.numel()
51  def backward(grad_output):
52  grad_self = grad_output.expand(self_size) / self_numel
53  return grad_self
54 
55  return torch.mean(self), backward
56 
57  def mean_1(self,
58  dim: List[int],
59  keepdim: bool):
60  self_size = self.size()
61  def backward(grad_output):
62  grad_self = AD_sum_backward(grad_output, self_size, dim, keepdim) / AD_safe_size(self_size, dim)
63  return grad_self, None, None
64 
65  return torch.mean(self, dim, keepdim), backward
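 # Example: for x with sizes [2, 3], torch.mean(x, [1], False) has shape [2];
 # backward unsqueezes grad_output to [2, 1], expands it to [2, 3], and divides
 # by AD_safe_size([2, 3], [1]) == 3, so each element receives grad_output[i] / 3.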
66 
67  def logsumexp(self,
68  dim: List[int],
69  keepdim: bool):
70  result = torch.logsumexp(self, dim, keepdim)
71  self_dim = self.dim()
72  def backward(grad_output):
73  grad_self = AD_logsumexp_backward(grad_output, self, result, dim, keepdim)
74  return grad_self, None, None
75 
76  return result, backward
77 
78  def AD_bool_to_int(b: bool):
79  # FIXME: torchscript: int - bool
80  if b:
81  i = 1
82  else:
83  i = 0
84  return i
85 
86  def AD_var_backward_0(grad, self, unbiased: bool):
87  b = AD_bool_to_int(unbiased)
88 
89  # FIXME: torchscript: div(float, float)
90  return grad * (self - self.mean()) * 2.0 / (self.numel() - b)
91 
92  def AD_safe_size(sizes: List[int],
93  dims: List[int]):
94  if len(sizes) == 0:
95  return 1
96 
97  size = 1
98  for i in range(len(dims)):
99  d = dims[i]
100  size *= sizes[d]
101 
102  return size
103 
104  def AD_var_backward_1(grad,
105  self,
106  dim: List[int],
107  unbiased: bool,
108  keepdim: bool):
109  if self.dim() == 0:
110  return AD_var_backward_0(grad, self, unbiased)
111  self_size = self.size()
112  b = AD_bool_to_int(unbiased)
113  if not keepdim and self.dim() > 1:
114  grad = AD_unsqueeze_multiple(grad, dim, len(self_size))
115 
116  # FIXME: torchscript: div(float, float)
117  return grad * (self - self.mean(dim, True)) * 2.0 / (AD_safe_size(self_size, dim) - b)
118 
119  def std_0(self,
120  unbiased: bool=True):
121  std_out = torch.std(self, unbiased)
122  def backward(grad_output):
123  grad_self = AD_var_backward_0(grad_output / (std_out * 2), self, unbiased)
124  return grad_self, None
125 
126  return std_out, backward
127 
128  def std_1(self,
129  dim: List[int],
130  unbiased: bool,
131  keepdim: bool):
132  std_out = torch.std(self, dim, unbiased, keepdim)
133  def backward(grad_output):
134  grad_self = AD_var_backward_1(grad_output / (std_out * 2), self, dim, unbiased, keepdim)
135  return grad_self, None, None, None
136 
137  return std_out, backward
138 
139  def var_0(self,
140  unbiased: bool=True):
141  def backward(grad_output):
142  grad_self = AD_var_backward_0(grad_output, self, unbiased)
143  return grad_self, None
144 
145  return torch.var(self, unbiased), backward
146 
147  def var_1(self,
148  dim: List[int],
149  unbiased: bool,
150  keepdim: bool):
151  def backward(grad_output):
152  grad_self = AD_var_backward_1(grad_output, self, dim, unbiased, keepdim)
153  return grad_self, None, None, None
154 
155  return torch.var(self, dim, unbiased, keepdim), backward
156 
157  def tanh(self):
158  output = torch.tanh(self)
159  def backward(grad_output):
160  return grad_output * (1 - output * output)
161 
162  return output, backward
163 
164  def AD_index_select_backward(grad,
165  dim: int,
166  indices,
167  sizes: List[int],
168  keepdim: bool):
169  if not keepdim and len(sizes) > 0:
170  grad = grad.unsqueeze(dim)
171  indices = indices.unsqueeze(dim)
172 
173  # FIXME: torchscript: torch.zeros(sizes, grad.options())
174  return torch.zeros(sizes).to(grad).scatter_(dim, indices, grad)
175 
176  # def topk(self,
177  # k: int,
178  # dim: int = -1,
179  # largest: bool = True,
180  # sorted: bool = True):
181  # result0, result1 = torch.topk(self, k, dim, largest, sorted)
182  # self_size = self.size()
183  # def backward(grad_output):
184  # grad_self = AD_index_select_backward(grad_output, dim, result1, self_size, True)
185  # return grad_self, None, None, None, None
186 
187  # return result0, result1, backward
188 
189  # def kthvalue(self,
190  # k: int,
191  # dim: int,
192  # keepdim: bool):
193  # result0, result1 = torch.kthvalue(self, k, dim, keepdim)
194  # self_size = self.size()
195  # def backward(grad_output):
196  # grad_self = AD_index_select_backward(grad_output, dim, result1, self_size, keepdim)
197  # return grad_self, None, None, None
198 
199  # return result0, result1, backward
200 
201  def AD_mm_backward_self(grad, mat2):
202  return grad.mm(mat2.t())
203 
204  def AD_mm_backward_mat2(grad, self):
205  return self.t().mm(grad)
206 
207  def mm(self, mat2):
208  def backward(grad_output):
209  grad_self = AD_mm_backward_self(grad_output, mat2)
210  grad_mat2 = AD_mm_backward_mat2(grad_output, self)
211  return grad_self, grad_mat2
212 
213  return torch.mm(self, mat2), backward
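 # These follow the standard identities for C = A.mm(B):
 #   dL/dA = dL/dC @ B^T   (AD_mm_backward_self)
 #   dL/dB = A^T @ dL/dC   (AD_mm_backward_mat2)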
214 
215  def AD_permute_backward(grad,
216  fwd_dims: List[int]):
217  ndims = len(fwd_dims)
218  dims = [0] * ndims
219 
220  for i in range(ndims):
221  dims[fwd_dims[i]] = i
222 
223  return grad.permute(dims)
224 
225  def permute(self,
226  dims: List[int]):
227  def backward(grad_output):
228  grad_self = AD_permute_backward(grad_output, dims)
229  return grad_self, None
230 
231  return torch.permute(self, dims), backward
232 
233  def AD_select_backward(grad,
234  input_sizes: List[int],
235  dim: int,
236  index: int):
237  # FIXME: torchscript: torch.zeros(sizes, grad.options())
238  grad_input = torch.zeros(input_sizes).to(grad)
239  grad_input.select(dim, index).copy_(grad)
240  return grad_input
241 
242  # TODO: fix torch.zeros(sizes, grad.options()) before enabling select, topk, kthvalue
243  # def select(self,
244  # dim: int,
245  # index: int):
246  # self_size = self.size()
247  # def backward(grad_output):
248  # grad_self = AD_select_backward(grad_output, self_size, dim, index)
249  # return grad_self, None, None
250 
251  # return torch.select(self, dim, index), backward
252 
253  def AD_slice_backward(grad,
254  input_sizes: List[int],
255  dim: int,
256  start: int,
257  end: int,
258  step: int):
259  # FIXME: torchscript: torch.zeros(sizes, grad.options())
260  grad_input = torch.zeros(input_sizes).to(grad)
261  grad_input.slice(dim, start, end, step).copy_(grad)
262  return grad_input
263 
264  # DON'T enable slice unless we can correctly handle view ops in graph executor.
265  # It triggers failure of TestJit.test_sample in test_distributions.py.
266  # def slice(self,
267  # dim: int=0,
268  # start: int=0,
269  # end: int=9223372036854775807,
270  # step: int=1):
271  # def backward(grad_output):
272  # grad_self = AD_slice_backward(grad_output, self.size(), dim, start, end, step)
273  # return grad_self, None, None, None, None
274 
275  # return torch.slice(self, dim, start, end, step), backward
276 
277  def AD_unsqueeze_to_0(self,
278  sizes: List[int]):
279  ndims = len(sizes)
280  for i in range(ndims):
281  if sizes[i] == 1:
282  self = self.unsqueeze(i)
283 
284  return self
285 
286  def AD_unsqueeze_to_1(self,
287  dim: int,
288  sizes: List[int]):
289  if len(sizes) > 0 and sizes[dim] == 1:
290  return self.unsqueeze(dim)
291  return self
292 
293  def squeeze_0(self):
294  self_size = self.size()
295  def backward(grad_output):
296  grad_self = AD_unsqueeze_to_0(grad_output, self_size)
297  return grad_self
298 
299  return torch.squeeze(self), backward
300 
301  def squeeze_1(self,
302  dim: int):
303  self_size = self.size()
304  def backward(grad_output):
305  grad_self = AD_unsqueeze_to_1(grad_output, dim, self_size)
306  return grad_self, None
307 
308  return torch.squeeze(self, dim), backward
309 
310  def AD_infer_size(a: List[int],
311  b: List[int]):
312  dimsA = len(a)
313  dimsB = len(b)
314 
315  ndim = dimsA if dimsA > dimsB else dimsB
316  expand_sizes = [0] * ndim
317 
318  for i in range(ndim):
319  idx = - i + ndim - 1
320  sizeA = a[i] if dimsA + i >= 0 else 1
321  sizeB = b[i] if dimsB + i >= 0 else 1
322 
323  # Assert sizeA == sizeB or sizeA == 1 or sizeB == 1
324  expand_sizes[i] = sizeB if sizeA == 1 else sizeA
325 
326  return expand_sizes
327 
328  def AD_bmm_backward_self(grad, mat2):
329  return grad.bmm(mat2.transpose(1, 2))
330 
331  def AD_bmm_backward_mat2(grad, self):
332  return self.transpose(1, 2).bmm(grad)
333 
334  def bmm(self, mat2):
335  def backward(grad_output):
336  grad_self = AD_bmm_backward_self(grad_output, mat2)
337  grad_mat2 = AD_bmm_backward_mat2(grad_output, self)
338  return grad_self, grad_mat2
339  return torch.bmm(self, mat2), backward
340 
341  def AD_mat_transpose(mat):
342  dim = mat.dim()
343  if dim == 1:
344  out = mat
345  elif dim == 2:
346  out = mat.t()
347  else:
348  dims = rangelist(dim)
349  dims[-1] = dim - 2
350  dims[-2] = dim - 1
351  out = mat.permute(dims)
352  return out
353 
354  def AD_matmul_size(mat1, mat2,
355  out_size: List[int]):
356  dim1 = mat1.dim()
357  dim2 = mat2.dim()
358  dim_out = len(out_size)
359  if dim1 == 0 or dim2 == 0:
360  out = mat1 * mat2
361  elif dim1 + dim2 == dim_out:
362  if dim2 == 1:
363  target_dim2 = 0
364  else:
365  target_dim2 = -2
366  out = torch.matmul(mat1.unsqueeze(dim1), mat2.unsqueeze(target_dim2))
367  elif dim_out == dim1 - dim2:
368  out = torch.matmul(mat1, mat2.unsqueeze(dim2)).squeeze(-1)
369  elif dim_out == dim2 - dim1:
370  out = torch.matmul(mat1.unsqueeze(-2), mat2).squeeze(-2)
371  else:
372  out = torch.matmul(mat1, mat2)
373  return out
374 
375  def matmul(self, other):
376  def backward(grad_output):
377  self_size = self.size()
378  other_size = other.size()
379  grad_self = AD_matmul_size(grad_output, AD_mat_transpose(other), self_size)._grad_sum_to_size(self_size)
380  grad_other = AD_matmul_size(AD_mat_transpose(self), grad_output, other_size)._grad_sum_to_size(other_size)
381  return grad_self, grad_other
382 
383  return torch.matmul(self, other), backward
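 # matmul backward reuses the mm identities (grad_self = grad @ other^T,
 # grad_other = self^T @ grad): AD_matmul_size rebuilds a product of the correct
 # broadcasted rank and _grad_sum_to_size then sums away broadcast dimensions so
 # each gradient matches the shape of its input.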
384  )",
385  R"(
386  def _dim_arange(like,
387  dim: int):
388  def backward(grad_output):
389  return None, None
390 
391  return torch._dim_arange(like, dim), backward
392 
393  def contiguous(self):
394  def backward(grad_output):
395  return None
396 
397  return self.contiguous(), backward
398 
399  def dot(self, tensor):
400  def backward(grad_output):
401  grad_self = grad_output * tensor
402  grad_tensor = grad_output * self
403  return grad_self, grad_tensor
404 
405  return torch.dot(self, tensor), backward
406 
407  def erf(self):
408  def backward(grad_output):
409  # Precomputed constant C = 2.0 / math.sqrt(math.pi)
410  C = 1.1283791670955126
411  grad_self = C * torch.exp(- self * self) * grad_output
412  return grad_self
413 
414  return torch.erf(self), backward
415 
416  def expand(self,
417  size: List[int],
418  *,
419  implicit: bool=False):
420  self_size = self.size()
421  def backward(grad_output):
422  grad_self = torch._grad_sum_to_size(grad_output, self_size)
423  return grad_self, None, None
424 
425  return torch.expand(self, size, implicit=implicit), backward
426 
427  def expand_as(self, other):
428  self_size = self.size()
429  def backward(grad_output):
430  grad_self = grad_output._grad_sum_to_size(self_size)
431  return grad_self, None
432 
433  return torch.expand_as(self, other), backward
434 
435  def full_like(self,
436  fill_value: float):
437  def backward(grad_output):
438  return None, None
439 
440  return torch.full_like(self, fill_value), backward
441 
442  def mul(self, other):
443  def backward(grad_output):
 444  # self & other are used in backward. No need to pass in their sizes
 445  # from the forward pass.
446  grad_self = (grad_output * other)._grad_sum_to_size(self.size())
447  grad_other = (grad_output * self)._grad_sum_to_size(other.size())
448  return grad_self, grad_other
449 
450  return self * other, backward
451 
452  def mv(self, vec):
453  def backward(grad_output):
454  grad_self = grad_output.ger(vec)
455  grad_vec = self.t().mv(grad_output)
456  return grad_self, grad_vec
457 
458  return torch.mv(self, vec), backward
459 
460  def nonzero(self):
461  def backward(grad_output):
462  return None
463 
464  return torch.nonzero(self), backward
465 
466  def ones_like(self):
467  def backward(grad_output):
468  return None
469 
470  return torch.ones_like(self), backward
471 
472  def pow_0(self,
473  exponent: float):
474  def backward(grad_output):
475  grad_self = torch.where(torch.tensor(exponent == 0.0), torch.zeros_like(self), grad_output * exponent * torch.pow(self, exponent - 1))
476  return grad_self, None
477 
478  return torch.pow(self, exponent), backward
479 
480  def pow_1(self, exponent):
481  def backward(grad_output):
 482  # self & exponent are used in backward, no need to pass in their sizes explicitly
483  grad_self = torch.where(exponent == 0.0, torch.zeros_like(self), grad_output * exponent * torch.pow(self, exponent - 1))._grad_sum_to_size(self.size())
484  grad_exponent = (grad_output * torch.pow(self, exponent) * torch.log(self))._grad_sum_to_size(exponent.size())
485  return grad_self, grad_exponent
486 
487  return torch.pow(self, exponent), backward
488 
489  def pow_2(self: float,
490  exponent):
491  def backward(grad_output):
492  grad_exponent = grad_output * torch.pow(self, exponent) * torch.log(torch.tensor(self))
493  return None, grad_exponent
494 
495  return torch.pow(self, exponent), backward
496 
497  def rsub_0(self, other,
498  alpha: float = 1.0):
499  self_size = self.size()
500  other_size = other.size()
501  def backward(grad_output):
502  grad_self = (- grad_output * alpha)._grad_sum_to_size(self_size)
503  grad_other = (grad_output)._grad_sum_to_size(other_size)
504  return grad_self, grad_other, None
505 
506  return torch.rsub(self, other, alpha), backward
507 
508  def rsub_1(self,
509  other: float,
510  alpha: float = 1.0):
511  self_size = self.size()
512  def backward(grad_output):
513  grad_self = (- grad_output * alpha)._grad_sum_to_size(self_size)
514  return grad_self, None, None
515 
516  return torch.rsub(self, other, alpha), backward
517 
518  def sqrt(self):
519  result = torch.sqrt(self)
520  def backward(grad_output):
521  grad_self = grad_output / (2 * result)
522  return grad_self
523 
524  return result, backward
525 
526  def t(self):
527  def backward(grad_output):
528  grad_self = torch.t(grad_output)
529  return grad_self
530 
531  return torch.t(self), backward
532 
533  def to_0(self,
534  device: Optional[Device],
535  dtype: Optional[int],
536  non_blocking: bool=False,
537  copy: bool=False):
538  self_device = self.device
539  self_dtype = self.dtype
540  if device is not None:
541  result = self.to(device, dtype=dtype, non_blocking=non_blocking, copy=copy)
542  else:
543  result = self.to(dtype, non_blocking=non_blocking, copy=copy)
544  def backward(grad_output):
545  grad_self = grad_output.to(self_device, dtype=self_dtype, non_blocking=non_blocking, copy=copy)
546  return grad_self, None, None, None, None
547 
548  return result, backward
549 
550 
551  def to_1(self,
552  dtype: int,
553  non_blocking: bool=False,
554  copy: bool=False):
555  self_dtype = self.dtype
556  def backward(grad_output):
557  grad_self = grad_output.to(self_dtype, non_blocking, copy)
558  return grad_self, None, None, None
559 
560  return self.to(dtype=dtype, non_blocking=non_blocking, copy=copy), backward
561 
562  def to_2(self,
563  other,
564  non_blocking: bool=False,
565  copy: bool=False):
566  def backward(grad_output):
567  grad_self = grad_output.to(self, non_blocking, copy)
568  return grad_self, None, None, None
569 
570  return self.to(other, non_blocking=non_blocking, copy=copy), backward
571 
572  def transpose(self,
573  dim0: int,
574  dim1: int):
575  def backward(grad_output):
576  grad_self = torch.transpose(grad_output, dim0, dim1)
577  return grad_self, None, None
578 
579  return torch.transpose(self, dim0, dim1), backward
580 
581  def view(self,
582  size: List[int]):
583  self_size = self.size()
584  def backward(grad_output):
585  grad_self = grad_output.reshape(self_size)
586  return grad_self, None
587 
588  return torch.view(self, size), backward
589  )",
590  R"(
591  def AD_adaptive_avg_pool2d_backward(grad,
592  self,
593  output_size: List[int]):
594  if output_size[0] == 1 and output_size[1] == 1:
595  self_size = self.size()
596  grad_self = grad.expand(self.size()) / (self_size[-1] * self_size[-2])
597  else:
598  grad_self = torch._adaptive_avg_pool2d_backward(grad, self)
599 
600  return grad_self
601 
602  def AD_adaptive_avg_pool1d_backward(grad,
603  input,
604  output_size: List[int]):
605  output_size_2d = [1, output_size[0]]
606  grad_input = AD_adaptive_avg_pool2d_backward(grad.unsqueeze(2), input.unsqueeze(2), output_size_2d).squeeze(2)
607  return grad_input
608 
609  def adaptive_avg_pool1d(self,
610  output_size: List[int]):
611  def backward(grad_output):
612  grad_self = AD_adaptive_avg_pool1d_backward(grad_output, self, output_size)
613  return grad_self, None
614 
615  return torch.adaptive_avg_pool1d(self, output_size), backward
616 
617  def adaptive_avg_pool2d(self,
618  output_size: List[int]):
619  def backward(grad_output):
620  # self is used in backward, no need to pass in its size explicitly
621  grad_self = AD_adaptive_avg_pool2d_backward(grad_output, self, output_size)
622  return grad_self, None
623  return torch.adaptive_avg_pool2d(self, output_size), backward
624 
625  def adaptive_avg_pool3d(self,
626  output_size: List[int]):
627  def backward(grad_output):
628  grad_self = torch.adaptive_avg_pool3d_backward(grad_output, self)
629  return grad_self, None
630 
631  return torch.adaptive_avg_pool3d(self, output_size), backward
632 
633  def batch_norm(input : Tensor,
634  weight : Optional[Tensor],
635  bias : Optional[Tensor],
636  running_mean : Optional[Tensor],
637  running_var : Optional[Tensor],
638  training : bool,
639  momentum : float,
640  eps : float,
641  cudnn_enabled : bool):
642 
643  output, save1, save2, impl_idx = torch._batch_norm_impl_index(
644  input, weight, bias, running_mean, running_var, training,
645  momentum, eps, cudnn_enabled)
646  has_weight = weight is not None
647  has_bias = bias is not None
648 
649  def backward(grad_output):
650  dinput, dweight, dbias = torch._batch_norm_impl_index_backward(
651  impl_idx, input, grad_output, weight, running_mean, running_var,
652  save1, save2, training, eps, [True, has_weight, has_bias])
653  return dinput, dweight, dbias, None, None, None, None, None, None
654 
655  return output, backward
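 # The forward pass records which backend was chosen (impl_idx) so backward can
 # call the matching _batch_norm_impl_index_backward; the output mask
 # [True, has_weight, has_bias] always requests grad_input and requests
 # grad_weight / grad_bias only when those optional tensors are present.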
656 
657  def layer_norm(input : Tensor,
 658  normalized_shape : List[int],
659  weight : Optional[Tensor],
660  bias : Optional[Tensor],
661  eps : float,
662  cudnn_enable : bool):
663 
664  bn_out, save1, save2, impl_idx = torch._batch_norm_impl_index(
665  input, weight, bias, None, None, True,
666  0.0, eps, cudnn_enable)
667  has_weight = weight is not None
668  has_bias = bias is not None
669 
670  bn_out = bn_out.view(input.sizes())
671  if weight is not None and bias is not None:
672  output = bias.addcmul(bn_out, weight)
673  elif weight is not None:
674  output = bn_out.mul(weight)
675  elif bias is not None:
676  output = bn_out.add(bias)
677  else:
678  output = bn_out
679 
680  def backward(grad_output):
681  if weight is not None:
682  grad_output = grad_output * torch.t(weight)
683  weight = grad_output * torch.t(bn_out)
684 
685  grad_output = grad_output.reshape(input.sizes())
686 
687  dinput, dweight, dbias = torch._batch_norm_impl_index_backward(
688  impl_idx, input, grad_output, weight, None, None,
689  save1, save2, True, eps, [True, has_weight, has_bias])
690  return dinput, None, dweight, dbias, None, None
691 
692  return output, backward
693 
694  def AD_fused_dropout_backward(grad,
695  mask,
696  p1m: float):
697  p1r = 1. / p1m
698  if grad.requires_grad:
699  grad_input = grad * (mask.type_as(grad) * p1r)
700  else:
701  grad_input = torch._masked_scale(grad, mask, p1r)
702  return grad_input
703 
704  def dropout(input,
705  p: float,
706  train: bool):
707  use_cuda = input.is_cuda
708  # CUDA has a fused dropout implementation
709  p1m = 1. - p
710  if use_cuda:
711  res, mask = torch._fused_dropout(input, p1m)
712  else:
713  mask = torch.empty_like(input)
714  mask.bernoulli_(p1m)
715  res = mask * input / p1m
716 
717  def backward(grad_output):
718  if use_cuda:
719  grad_input = AD_fused_dropout_backward(grad_output, mask, p1m)
720  else:
721  grad_input = grad_output * mask / p1m
722  return grad_input, None, None
723  return res, backward
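 # This is inverted dropout: elements survive with probability p1m = 1 - p and
 # surviving activations are scaled by 1 / p1m so the expected value is unchanged.
 # The CUDA path reuses the fused kernel's mask; the CPU path draws an explicit
 # Bernoulli mask, and backward rescales grad_output by the same mask.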
724 
725  def embedding(weight,
726  indices,
727  padding_idx: int,
728  scale_grad_by_freq: bool,
729  sparse: bool):
730  weight_size_0 = weight.size()[0]
731  def backward(grad_output):
732  grad_weight = torch.embedding_backward(grad_output, indices, weight_size_0, padding_idx, scale_grad_by_freq, sparse)
733  return grad_weight, None, None, None, None
734 
735  return torch.embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse), backward
736 
737  def nll_loss(self, target, weight: Optional[Tensor], reduction: int, ignore_index: int):
738  result, total_weight = torch.nll_loss_forward(self, target, weight, reduction, ignore_index)
739  def backward(grad):
740  return torch.nll_loss_backward(grad, self, target, weight, reduction, ignore_index, total_weight), None, None, None, None
741  return result, backward
742 
743  def softmax_0(self, dim: int):
744  result = torch.softmax(self, dim)
745  def backward(grad_output):
746  grad_self = torch._softmax_backward_data(grad_output, result, dim, self)
747  return grad_self, None
748 
749  return result, backward
750 
751  def softmax_1(self, dim: int, dtype: int):
752  result = torch.softmax(self, dim, dtype)
753  def backward(grad_output):
754  grad_self = torch._softmax_backward_data(grad_output, result, dim, self)
755  return grad_self, None, None
756 
757  return torch.softmax(self, dim, dtype), backward
758 
759  def AD_interpolate_backward(grad,
760  input,
761  mode: str,
762  align_corners: bool):
763  output_size = grad.size()[2:]
764  input_size = input.size()
765  input_dim = len(input_size)
766  if input_dim == 3 and mode == 'nearest':
767  grad_input = torch.upsample_nearest1d_backward(grad, output_size, input_size)
768  elif input_dim == 4 and mode == 'nearest':
769  grad_input = torch.upsample_nearest2d_backward(grad, output_size, input_size)
770  elif input_dim == 5 and mode == 'nearest':
771  grad_input = torch.upsample_nearest3d_backward(grad, output_size, input_size)
772  elif input_dim == 3 and mode == 'linear':
773  grad_input = torch.upsample_linear1d_backward(grad, output_size, input_size, align_corners)
774  elif input_dim == 4 and mode == 'bilinear':
775  grad_input = torch.upsample_bilinear2d_backward(grad, output_size, input_size, align_corners)
776  elif input_dim == 5 and mode == 'trilinear':
777  grad_input = torch.upsample_trilinear3d_backward(grad, output_size, input_size, align_corners)
778  elif input_dim == 4 and mode == 'bicubic':
779  grad_input = torch.upsample_bicubic2d_backward(grad, output_size, input_size, align_corners)
780  elif input_dim == 3 and mode == 'area':
781  grad_input = AD_adaptive_avg_pool1d_backward(grad, input, output_size)
782  elif input_dim == 4 and mode == 'area':
783  grad_input = AD_adaptive_avg_pool2d_backward(grad, input, output_size)
784  elif input_dim == 5 and mode == 'area':
785  grad_input = torch.adaptive_avg_pool3d_backward(grad, input)
786  else:
787  # NEVER REACH HERE
788  grad_input = torch.zeros_like(input)
789  raise RuntimeError('Input Error: Only 3D, 4D and 5D input Tensors supported')
790 
791  return grad_input
792 
793  def __interpolate_0(input,
794  size: Optional[int],
795  scale_factor: Optional[List[float]],
796  mode: str='nearest',
797  align_corners: Optional[bool]):
798  def backward(grad_output):
799  if align_corners is None:
800  align_corners = False
801  grad_self = AD_interpolate_backward(grad_output, input, mode, align_corners)
802  return grad_self, None, None, None, None
803 
804  return torch.__interpolate(input, size, scale_factor, mode, align_corners), backward
805 
806  def __interpolate_1(input,
807  size: Optional[List[int]],
808  scale_factor: Optional[List[float]],
809  mode: str='nearest',
810  align_corners: Optional[bool]):
811  def backward(grad_output):
812  if align_corners is None:
813  align_corners = False
814  grad_self = AD_interpolate_backward(grad_output, input, mode, align_corners)
815  return grad_self, None, None, None, None
816 
817  return torch.__interpolate(input, size, scale_factor, mode, align_corners), backward
818 
819  def __interpolate_2(input,
820  size: Optional[int],
821  scale_factor: Optional[float],
822  mode: str='nearest',
823  align_corners: Optional[bool]):
824  def backward(grad_output):
825  if align_corners is None:
826  align_corners = False
827  grad_self = AD_interpolate_backward(grad_output, input, mode, align_corners)
828  return grad_self, None, None, None, None
829 
830  return torch.__interpolate(input, size, scale_factor, mode, align_corners), backward
831 
832  def __interpolate_3(input,
833  size: Optional[List[int]],
834  scale_factor: Optional[float],
835  mode: str='nearest',
836  align_corners: Optional[bool]):
837  def backward(grad_output):
838  if align_corners is None:
839  align_corners = False
840  grad_self = AD_interpolate_backward(grad_output, input, mode, align_corners)
841  return grad_self, None, None, None, None
842 
843  return torch.__interpolate(input, size, scale_factor, mode, align_corners), backward
844 
845  )"};
846 std::unordered_map<std::string, GradientPair> schema_to_graphs;
847 
848 // This map is a workaround to cache compiled GradientPairs. Ideally each graph
849 // should be compiled only once and saved in the Operator structure.
850 // This should be done as part of merging these formulas into native_functions.yaml.
851 std::unordered_map<const FunctionSchema*, GradientPair> cached_gradient_pairs;
852 } // anonymous namespace
853 
854 std::pair<std::shared_ptr<Graph>, Value*> extractClosure(Value* closure) {
855  AT_CHECK(
856  closure->node()->kind() == prim::TupleConstruct,
857  "closure must be a literal tuple construct");
858  Value* fn = closure->node()->inputs().at(0);
859  Value* context = closure->node()->inputs().at(1);
860 
861  AT_CHECK(
862  fn->node()->kind() == prim::Function,
863  "closure tuple must contain a prim::Function");
864  return std::make_pair(fn->node()->g(attr::Subgraph), context);
865 }
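// Each TorchScript `backward` closure above is lowered to a tuple of
// (prim::Function, captured context). extractClosure splits that tuple into the
// backward subgraph and the context Value that loadModule() reattaches to the
// forward graph's outputs.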
866 
867 Argument originalReturnType(const TupleTypePtr& tup) {
868  AT_CHECK(tup->elements().size() > 1);
869  if (tup->elements().size() == 2)
870  return Argument("", tup->elements().at(0));
871  std::vector<TypePtr> types = tup->elements().vec();
872  types.pop_back();
873  return Argument("", TupleType::create(std::move(types)));
874 }
875 
876 // In TorchScript AD formulas, we define {func_0, func_1, ...} as
877 // overloaded functions of `func`.
878 // Remove the suffix before adding the schema string to the
879 // schema_to_graphs map.
880 std::string overloadedSchemaString(const FunctionSchema& schema) {
881  const auto& schema_name = schema.name();
882  auto pos = schema_name.find_last_of('_');
883  auto schema_name_suffix = schema_name.substr(pos + 1);
884  std::string schema_string = canonicalSchemaString(schema);
885  if (!schema_name_suffix.empty() &&
886  schema_name_suffix.find_first_not_of("0123456789") == std::string::npos) {
887  schema_string.replace(
888  schema_string.find(schema_name),
889  schema_name.length(),
890  schema_name.substr(0, pos));
891  }
892  return schema_string;
893 }
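// For example, `mean_0(self)` and `mean_1(self, dim, keepdim)` defined above
// both produce schema strings keyed under `aten::mean`, one per overload of the
// real operator.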
894 
895 bool isHelperFunction(const std::string& method_name) {
896  std::string helper_prefix = "AD_";
897  return method_name.compare(0, helper_prefix.length(), helper_prefix) == 0;
898 }
899 
900 void loadModule(const std::shared_ptr<script::Module>& module) {
901  for (const auto& method_ : module->get_methods()) {
902  if (isHelperFunction(method_.key()))
903  continue;
904 
905  const auto& method = method_.value();
906  GradientPair pair;
907  pair.forward = method->graph();
908 
 909  // look up the backward function
910  Node* forward_tuple = pair.forward->outputs().at(0)->node();
911 
912  if (forward_tuple->kind() != prim::TupleConstruct) {
913  throw script::ErrorReport(forward_tuple->getSourceLocation())
 914  << "gradient must return a literal tuple";
915  }
916 
917  Value* context;
918  std::tie(pair.backward, context) =
919  extractClosure(forward_tuple->inputs().back());
920 
921  // do surgery on the forward function to remove the closure tuple and
922  // replace it with the context variable:
923  // backward = (<lambda>, context_tuple)
924  // return original, backward
925  // -----
926  // return original, context_tuple
927  std::vector<Value*> new_inputs = forward_tuple->inputs().vec();
928  new_inputs.back() = context;
929  Value* new_tuple =
930  pair.forward->appendNode(pair.forward->createTuple(new_inputs))
931  ->output();
932  pair.forward->eraseOutput(0);
933  pair.forward->registerOutput(new_tuple);
934  forward_tuple->destroy();
935 
936  // derive schema from original function's schema:
937  const FunctionSchema& loaded_schema = method->getSchema();
938  FunctionSchema actual_schema(
939  Symbol::aten(loaded_schema.name()),
940  loaded_schema.overload_name(),
941  loaded_schema.arguments(),
942  {originalReturnType(new_tuple->type()->expect<TupleType>())});
943 
944  // modify canonical string for function overloading
945  // prefer not to modify the schema name
946  auto schema_string = overloadedSchemaString(actual_schema);
947 
948  schema_to_graphs[schema_string] = std::move(pair);
949  }
950 }
951 
952 void loadFunctions() {
953  for (const std::string& str : functions) {
954  auto cu = std::make_shared<script::Module>();
955  script::defineMethodsInModule(
956  cu, str, script::nativeResolver, c10::nullopt);
957  loadModule(cu);
958  }
959 }
960 
961 c10::optional<GradientPair> gradientInfoForSchema(
962  const FunctionSchema& schema) {
963  std::lock_guard<std::mutex> guard(lock);
964  if (schema_to_graphs.size() == 0) {
965  loadFunctions();
966  }
967  auto cache_it = cached_gradient_pairs.find(&schema);
968  if (cache_it != cached_gradient_pairs.end()) {
969  return cache_it->second;
970  } else {
971  auto schema_str = canonicalSchemaString(schema);
972  auto sym_script_it = schema_to_graphs.find(schema_str);
973 
974  if (sym_script_it != schema_to_graphs.end()) {
975  cached_gradient_pairs.emplace_hint(
976  cache_it, &schema, sym_script_it->second);
977  return sym_script_it->second;
978  }
979  }
980  return c10::nullopt;
981 }
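// Illustrative usage (not part of this file): a caller holding an operator's
// schema can ask
//   if (auto pair = gradientInfoForSchema(node->schema())) {
//     // pair->forward / pair->backward are the compiled gradient graphs
//   }
// The TorchScript sources above are compiled on the first query and the result
// is cached per FunctionSchema pointer afterwards.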
982 
983 bool hasGradientInfoForSchema(const FunctionSchema& schema) {
984  return gradientInfoForSchema(schema).has_value();
985 }
986 
987 } // namespace jit
988 } // namespace torch