#include <torch/csrc/jit/symbolic_script.h>

namespace torch {
namespace jit {

std::mutex lock;
const std::vector<std::string> functions = {
    R"(
        ####  HELPER FUNCTIONS  ###
        #### SCHEMA NOT SAVED IN CACHE ###

        def AD_unsqueeze_multiple(t,
                                  dims: List[int],
                                  n_dims: int):
            seen = [False] * n_dims
            for i in range(len(dims)):
                seen[dims[i]] = True

            for d in range(n_dims):
                if seen[d]:
                    t = t.unsqueeze(d)
            return t

        def AD_sum_backward(grad,
                            sizes: List[int],
                            dims: List[int],
                            keepdim: bool):
            if not keepdim and len(sizes) > 0:
                if len(dims) == 1:
                    return grad.unsqueeze(dims[0]).expand(sizes)
                else:
                    res = AD_unsqueeze_multiple(grad, dims, len(sizes))
                    return res.expand(sizes)
            else:
                return grad.expand(sizes)

        def AD_logsumexp_backward(grad, self, result,
                                  dim: List[int],
                                  keepdim: bool):
            if not keepdim and self.dim() != 0:
                n_dims = len(self.size())
                grad = AD_unsqueeze_multiple(grad, dim, n_dims)
                result = AD_unsqueeze_multiple(result, dim, n_dims)
            return grad * (self - result).exp()

        def mean_0(self):
            self_size = self.size()
            self_numel = self.numel()
            def backward(grad_output):
                grad_self = grad_output.expand(self_size) / self_numel
                return grad_self

            return torch.mean(self), backward

        def mean_1(self,
                   dim: List[int],
                   keepdim: bool=False):
            self_size = self.size()
            def backward(grad_output):
                grad_self = AD_sum_backward(grad_output, self_size, dim, keepdim) / AD_safe_size(self_size, dim)
                return grad_self, None, None

            return torch.mean(self, dim, keepdim), backward

        def logsumexp(self,
                      dim: List[int],
                      keepdim: bool=False):
            result = torch.logsumexp(self, dim, keepdim)

            def backward(grad_output):
                grad_self = AD_logsumexp_backward(grad_output, self, result, dim, keepdim)
                return grad_self, None, None

            return result, backward

        def AD_bool_to_int(b: bool):
            # FIXME: torchscript: int - bool
            if b:
                i = 1
            else:
                i = 0
            return i

        def AD_var_backward_0(grad, self, unbiased: bool):
            b = AD_bool_to_int(unbiased)

            # FIXME: torchscript: div(float, float)
            return grad * (self - self.mean()) * 2.0 / (self.numel() - b)

        def AD_safe_size(sizes: List[int],
                         dims: List[int]):
            if len(sizes) == 0:
                return 1

            size = 1
            for i in range(len(dims)):
                d = dims[i]
                size *= sizes[d]

            return size

        def AD_var_backward_1(grad,
                              self,
                              dim: List[int],
                              unbiased: bool,
                              keepdim: bool):
            if self.dim() == 0:
                return AD_var_backward_0(grad, self, unbiased)
            self_size = self.size()
            b = AD_bool_to_int(unbiased)
            if not keepdim and self.dim() > 1:
                grad = AD_unsqueeze_multiple(grad, dim, len(self_size))

            # FIXME: torchscript: div(float, float)
            return grad * (self - self.mean(dim, True)) * 2.0 / (AD_safe_size(self_size, dim) - b)

        def std_0(self,
                  unbiased: bool=True):
            std_out = torch.std(self, unbiased)
            def backward(grad_output):
                grad_self = AD_var_backward_0(grad_output / (std_out * 2), self, unbiased)
                return grad_self, None

            return std_out, backward

        def std_1(self,
                  dim: List[int],
                  unbiased: bool=True,
                  keepdim: bool=False):
            std_out = torch.std(self, dim, unbiased, keepdim)
            def backward(grad_output):
                grad_self = AD_var_backward_1(grad_output / (std_out * 2), self, dim, unbiased, keepdim)
                return grad_self, None, None, None

            return std_out, backward

        def var_0(self,
                  unbiased: bool=True):
            def backward(grad_output):
                grad_self = AD_var_backward_0(grad_output, self, unbiased)
                return grad_self, None

            return torch.var(self, unbiased), backward

        def var_1(self,
                  dim: List[int],
                  unbiased: bool=True,
                  keepdim: bool=False):
            def backward(grad_output):
                grad_self = AD_var_backward_1(grad_output, self, dim, unbiased, keepdim)
                return grad_self, None, None, None

            return torch.var(self, dim, unbiased, keepdim), backward

        def tanh(self):
            output = torch.tanh(self)
            def backward(grad_output):
                return grad_output * (1 - output * output)

            return output, backward

        def AD_index_select_backward(grad,
                                     dim: int,
                                     indices,
                                     sizes: List[int],
                                     keepdim: bool):
            if not keepdim and len(sizes) > 0:
                grad = grad.unsqueeze(dim)
                indices = indices.unsqueeze(dim)

            # FIXME: torchscript: torch.zeros(sizes, grad.options())
            return torch.zeros(sizes).to(grad).scatter_(dim, indices, grad)

        # def topk(self,
        #          k: int,
        #          dim: int = -1,
        #          largest: bool = True,
        #          sorted: bool = True):
        #     result0, result1 = torch.topk(self, k, dim, largest, sorted)
        #     self_size = self.size()
        #     def backward(grad_output):
        #         grad_self = AD_index_select_backward(grad_output, dim, result1, self_size, True)
        #         return grad_self, None, None, None, None

        #     return result0, result1, backward

        # def kthvalue(self,
        #              k: int,
        #              dim: int,
        #              keepdim: bool):
        #     result0, result1 = torch.kthvalue(self, k, dim, keepdim)
        #     self_size = self.size()
        #     def backward(grad_output):
        #         grad_self = AD_index_select_backward(grad_output, dim, result1, self_size, keepdim)
        #         return grad_self, None, None, None

        #     return result0, result1, backward

        def AD_mm_backward_self(grad, mat2):
            return grad.mm(mat2.t())

        def AD_mm_backward_mat2(grad, self):
            return self.t().mm(grad)

        def mm(self, mat2):
            def backward(grad_output):
                grad_self = AD_mm_backward_self(grad_output, mat2)
                grad_mat2 = AD_mm_backward_mat2(grad_output, self)
                return grad_self, grad_mat2

            return torch.mm(self, mat2), backward

        def AD_permute_backward(grad,
                                fwd_dims: List[int]):
            ndims = len(fwd_dims)
            dims = [0] * ndims

            for i in range(ndims):
                dims[fwd_dims[i]] = i

            return grad.permute(dims)

        def permute(self,
                    dims: List[int]):
            def backward(grad_output):
                grad_self = AD_permute_backward(grad_output, dims)
                return grad_self, None

            return torch.permute(self, dims), backward

        def AD_select_backward(grad,
                               input_sizes: List[int],
                               dim: int,
                               index: int):
            # FIXME: torchscript: torch.zeros(sizes, grad.options())
            grad_input = torch.zeros(input_sizes).to(grad)
            grad_input.select(dim, index).copy_(grad)
            return grad_input

        # TODO: fix torch.zeros(sizes, grad.options()) before enabling select, topk, kthvalue
        # def select(self,
        #            dim: int,
        #            index: int):
        #     self_size = self.size()
        #     def backward(grad_output):
        #         grad_self = AD_select_backward(grad_output, self_size, dim, index)
        #         return grad_self, None, None

        #     return torch.select(self, dim, index), backward

        def AD_slice_backward(grad,
                              input_sizes: List[int],
                              dim: int,
                              start: int,
                              end: int,
                              step: int):
            # FIXME: torchscript: torch.zeros(sizes, grad.options())
            grad_input = torch.zeros(input_sizes).to(grad)
            grad_input.slice(dim, start, end, step).copy_(grad)
            return grad_input

        # DON'T enable slice unless we can correctly handle view ops in graph executor.
        # It triggers failure of TestJit.test_sample in test_distributions.py.
        # def slice(self,
        #           dim: int=0,
        #           start: int=0,
        #           end: int=9223372036854775807,
        #           step: int=1):
        #     def backward(grad_output):
        #         grad_self = AD_slice_backward(grad_output, self.size(), dim, start, end, step)
        #         return grad_self, None, None, None, None

        #     return torch.slice(self, dim, start, end, step), backward

        def AD_unsqueeze_to_0(self,
                              sizes: List[int]):
            ndims = len(sizes)
            for i in range(ndims):
                if sizes[i] == 1:
                    self = self.unsqueeze(i)

            return self

        def AD_unsqueeze_to_1(self,
                              dim: int,
                              sizes: List[int]):
            if len(sizes) > 0 and sizes[dim] == 1:
                return self.unsqueeze(dim)
            return self

        def squeeze_0(self):
            self_size = self.size()
            def backward(grad_output):
                grad_self = AD_unsqueeze_to_0(grad_output, self_size)
                return grad_self

            return torch.squeeze(self), backward

        def squeeze_1(self,
                      dim: int):
            self_size = self.size()
            def backward(grad_output):
                grad_self = AD_unsqueeze_to_1(grad_output, dim, self_size)
                return grad_self, None

            return torch.squeeze(self, dim), backward

        def AD_infer_size(a: List[int],
                          b: List[int]):
            dimsA = len(a)
            dimsB = len(b)

            ndim = dimsA if dimsA > dimsB else dimsB
            expand_sizes = [0] * ndim

            for i in range(ndim):
                sizeA = a[i] if dimsA + i >= 0 else 1
                sizeB = b[i] if dimsB + i >= 0 else 1

                # Assert sizeA == sizeB or sizeA == 1 or sizeB == 1
                expand_sizes[i] = sizeB if sizeA == 1 else sizeA

            return expand_sizes

        def AD_bmm_backward_self(grad, mat2):
            return grad.bmm(mat2.transpose(1, 2))

        def AD_bmm_backward_mat2(grad, self):
            return self.transpose(1, 2).bmm(grad)

        def bmm(self, mat2):
            def backward(grad_output):
                grad_self = AD_bmm_backward_self(grad_output, mat2)
                grad_mat2 = AD_bmm_backward_mat2(grad_output, self)
                return grad_self, grad_mat2
            return torch.bmm(self, mat2), backward

        def AD_mat_transpose(mat):
            dim = mat.dim()
            if dim == 1:
                out = mat
            elif dim == 2:
                out = mat.t()
            else:
                dims = rangelist(dim)
                dims[-1] = dim - 2
                dims[-2] = dim - 1
                out = mat.permute(dims)
            return out

        def AD_matmul_size(mat1, mat2,
                           out_size: List[int]):
            dim1 = mat1.dim()
            dim2 = mat2.dim()
            dim_out = len(out_size)
            if dim1 == 0 or dim2 == 0:
                out = mat1 * mat2
            elif dim1 + dim2 == dim_out:
                if dim2 == 1:
                    target_dim2 = 0
                else:
                    target_dim2 = -2
                out = torch.matmul(mat1.unsqueeze(dim1), mat2.unsqueeze(target_dim2))
            elif dim_out == dim1 - dim2:
                out = torch.matmul(mat1, mat2.unsqueeze(dim2)).squeeze(-1)
            elif dim_out == dim2 - dim1:
                out = torch.matmul(mat1.unsqueeze(-2), mat2).squeeze(-2)
            else:
                out = torch.matmul(mat1, mat2)
            return out

        def matmul(self, other):
            def backward(grad_output):
                self_size = self.size()
                other_size = other.size()
                grad_self = AD_matmul_size(grad_output, AD_mat_transpose(other), self_size)._grad_sum_to_size(self_size)
                grad_other = AD_matmul_size(AD_mat_transpose(self), grad_output, other_size)._grad_sum_to_size(other_size)
                return grad_self, grad_other

            return torch.matmul(self, other), backward

        def _dim_arange(like,
                        dim: int):
            def backward(grad_output):
                return None, None

            return torch._dim_arange(like, dim), backward

        def contiguous(self):
            def backward(grad_output):
                return grad_output

            return self.contiguous(), backward

        def dot(self, tensor):
            def backward(grad_output):
                grad_self = grad_output * tensor
                grad_tensor = grad_output * self
                return grad_self, grad_tensor

            return torch.dot(self, tensor), backward

        def erf(self):
            def backward(grad_output):
                # Precomputed constant C = 2.0 / math.sqrt(math.pi)
                C = 1.1283791670955126
                grad_self = C * torch.exp(- self * self) * grad_output
                return grad_self

            return torch.erf(self), backward

        def expand(self,
                   size: List[int],
                   implicit: bool=False):
            self_size = self.size()
            def backward(grad_output):
                grad_self = torch._grad_sum_to_size(grad_output, self_size)
                return grad_self, None, None

            return torch.expand(self, size, implicit=implicit), backward

        def expand_as(self, other):
            self_size = self.size()
            def backward(grad_output):
                grad_self = grad_output._grad_sum_to_size(self_size)
                return grad_self, None

            return torch.expand_as(self, other), backward
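
        # _grad_sum_to_size sums a broadcasted gradient back down to the target
        # size, undoing broadcasting in backward; the broadcasting formulas here
        # (matmul above, mul / pow_1 / rsub below) rely on it instead of passing
        # the original sizes through the closure.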

        def full_like(self,
                      fill_value: float):
            def backward(grad_output):
                return None, None

            return torch.full_like(self, fill_value), backward

        def mul(self, other):
            def backward(grad_output):
                # self & other are used in backward. No need to pass in their sizes.
                grad_self = (grad_output * other)._grad_sum_to_size(self.size())
                grad_other = (grad_output * self)._grad_sum_to_size(other.size())
                return grad_self, grad_other

            return self * other, backward

        def mv(self, vec):
            def backward(grad_output):
                grad_self = grad_output.ger(vec)
                grad_vec = self.t().mv(grad_output)
                return grad_self, grad_vec

            return torch.mv(self, vec), backward

        def nonzero(self):
            def backward(grad_output):
                return None

            return torch.nonzero(self), backward

        def ones_like(self):
            def backward(grad_output):
                return None

            return torch.ones_like(self), backward

        def pow_0(self,
                  exponent: number):
            def backward(grad_output):
                grad_self = torch.where(torch.tensor(exponent == 0.0), torch.zeros_like(self), grad_output * exponent * torch.pow(self, exponent - 1))
                return grad_self, None

            return torch.pow(self, exponent), backward

        def pow_1(self, exponent):
            def backward(grad_output):
                # self & exponent are used in backward, no need to pass in their sizes explicitly
                grad_self = torch.where(exponent == 0.0, torch.zeros_like(self), grad_output * exponent * torch.pow(self, exponent - 1))._grad_sum_to_size(self.size())
                grad_exponent = (grad_output * torch.pow(self, exponent) * torch.log(self))._grad_sum_to_size(exponent.size())
                return grad_self, grad_exponent

            return torch.pow(self, exponent), backward

        def pow_2(self: float,
                  exponent):
            def backward(grad_output):
                grad_exponent = grad_output * torch.pow(self, exponent) * torch.log(torch.tensor(self))
                return None, grad_exponent

            return torch.pow(self, exponent), backward

        def rsub_0(self, other,
                   alpha: number = 1.0):
            self_size = self.size()
            other_size = other.size()
            def backward(grad_output):
                grad_self = (- grad_output * alpha)._grad_sum_to_size(self_size)
                grad_other = (grad_output)._grad_sum_to_size(other_size)
                return grad_self, grad_other, None

            return torch.rsub(self, other, alpha), backward

        def rsub_1(self,
                   other: number,
                   alpha: number = 1.0):
            self_size = self.size()
            def backward(grad_output):
                grad_self = (- grad_output * alpha)._grad_sum_to_size(self_size)
                return grad_self, None, None

            return torch.rsub(self, other, alpha), backward

        def sqrt(self):
            result = torch.sqrt(self)
            def backward(grad_output):
                grad_self = grad_output / (2 * result)
                return grad_self

            return result, backward

        def t(self):
            def backward(grad_output):
                grad_self = torch.t(grad_output)
                return grad_self

            return torch.t(self), backward

        def to_0(self,
                 device: Optional[Device],
                 dtype: Optional[int],
                 non_blocking: bool=False,
                 copy: bool=False):
            self_device = self.device
            self_dtype = self.dtype
            if device is not None:
                result = self.to(device, dtype=dtype, non_blocking=non_blocking, copy=copy)
            else:
                result = self.to(dtype, non_blocking=non_blocking, copy=copy)
            def backward(grad_output):
                grad_self = grad_output.to(self_device, dtype=self_dtype, non_blocking=non_blocking, copy=copy)
                return grad_self, None, None, None, None

            return result, backward

        def to_1(self,
                 dtype: int,
                 non_blocking: bool=False,
                 copy: bool=False):
            self_dtype = self.dtype
            def backward(grad_output):
                grad_self = grad_output.to(self_dtype, non_blocking, copy)
                return grad_self, None, None, None

            return self.to(dtype=dtype, non_blocking=non_blocking, copy=copy), backward

        def to_2(self,
                 other,
                 non_blocking: bool=False,
                 copy: bool=False):
            def backward(grad_output):
                grad_self = grad_output.to(self, non_blocking, copy)
                return grad_self, None, None, None

            return self.to(other, non_blocking=non_blocking, copy=copy), backward

        def transpose(self,
                      dim0: int,
                      dim1: int):
            def backward(grad_output):
                grad_self = torch.transpose(grad_output, dim0, dim1)
                return grad_self, None, None

            return torch.transpose(self, dim0, dim1), backward

        def view(self,
                 size: List[int]):
            self_size = self.size()
            def backward(grad_output):
                grad_self = grad_output.reshape(self_size)
                return grad_self, None

            return torch.view(self, size), backward

        def AD_adaptive_avg_pool2d_backward(grad,
                                            self,
                                            output_size: List[int]):
            if output_size[0] == 1 and output_size[1] == 1:
                self_size = self.size()
                grad_self = grad.expand(self.size()) / (self_size[-1] * self_size[-2])
            else:
                grad_self = torch._adaptive_avg_pool2d_backward(grad, self)

            return grad_self

        def AD_adaptive_avg_pool1d_backward(grad,
                                            input,
                                            output_size: List[int]):
            output_size_2d = [1, output_size[0]]
            grad_input = AD_adaptive_avg_pool2d_backward(grad.unsqueeze(2), input.unsqueeze(2), output_size_2d).squeeze(2)
            return grad_input

        def adaptive_avg_pool1d(self,
                                output_size: List[int]):
            def backward(grad_output):
                grad_self = AD_adaptive_avg_pool1d_backward(grad_output, self, output_size)
                return grad_self, None

            return torch.adaptive_avg_pool1d(self, output_size), backward

        def adaptive_avg_pool2d(self,
                                output_size: List[int]):
            def backward(grad_output):
                # self is used in backward, no need to pass in its size explicitly
                grad_self = AD_adaptive_avg_pool2d_backward(grad_output, self, output_size)
                return grad_self, None
            return torch.adaptive_avg_pool2d(self, output_size), backward

        def adaptive_avg_pool3d(self,
                                output_size: List[int]):
            def backward(grad_output):
                grad_self = torch.adaptive_avg_pool3d_backward(grad_output, self)
                return grad_self, None

            return torch.adaptive_avg_pool3d(self, output_size), backward

        def batch_norm(input : Tensor,
                       weight : Optional[Tensor],
                       bias : Optional[Tensor],
                       running_mean : Optional[Tensor],
                       running_var : Optional[Tensor],
                       training : bool,
                       momentum : float,
                       eps : float,
                       cudnn_enabled : bool):

            output, save1, save2, impl_idx = torch._batch_norm_impl_index(
                input, weight, bias, running_mean, running_var, training,
                momentum, eps, cudnn_enabled)
            has_weight = weight is not None
            has_bias = bias is not None

            def backward(grad_output):
                dinput, dweight, dbias = torch._batch_norm_impl_index_backward(
                    impl_idx, input, grad_output, weight, running_mean, running_var,
                    save1, save2, training, eps, [True, has_weight, has_bias])
                return dinput, dweight, dbias, None, None, None, None, None, None

            return output, backward

        def layer_norm(input : Tensor,
                       normalized_shape : List[int],
                       weight : Optional[Tensor],
                       bias : Optional[Tensor],
                       eps : float,
                       cudnn_enable : bool):

            bn_out, save1, save2, impl_idx = torch._batch_norm_impl_index(
                input, weight, bias, None, None, True,
                0.0, eps, cudnn_enable)
            has_weight = weight is not None
            has_bias = bias is not None

            bn_out = bn_out.view(input.size())
            if weight is not None and bias is not None:
                output = bias.addcmul(bn_out, weight)
            elif weight is not None:
                output = bn_out.mul(weight)
            elif bias is not None:
                output = bn_out.add(bias)
            else:
                output = bn_out

            def backward(grad_output):
                if weight is not None:
                    grad_output = grad_output * torch.t(weight)
                    weight = grad_output * torch.t(bn_out)

                grad_output = grad_output.reshape(input.size())

                dinput, dweight, dbias = torch._batch_norm_impl_index_backward(
                    impl_idx, input, grad_output, weight, None, None,
                    save1, save2, True, eps, [True, has_weight, has_bias])
                return dinput, None, dweight, dbias, None, None

            return output, backward

        def AD_fused_dropout_backward(grad,
                                      mask,
                                      p1r: float):
            if grad.requires_grad:
                grad_input = grad * (mask.type_as(grad) * p1r)
            else:
                grad_input = torch._masked_scale(grad, mask, p1r)
            return grad_input

        def dropout(input,
                    p: float,
                    train: bool):
            use_cuda = input.is_cuda
            # CUDA has a fused dropout implementation
            p1m = 1. - p
            if use_cuda:
                res, mask = torch._fused_dropout(input, p1m)
            else:
                mask = torch.empty_like(input)
                mask.bernoulli_(p1m)
                res = mask * input / p1m

            def backward(grad_output):
                if use_cuda:
                    grad_input = AD_fused_dropout_backward(grad_output, mask, p1m)
                else:
                    grad_input = grad_output * mask / p1m
                return grad_input, None, None
            return res, backward

        def embedding(weight,
                      indices,
                      padding_idx: int,
                      scale_grad_by_freq: bool,
                      sparse: bool):
            weight_size_0 = weight.size()[0]
            def backward(grad_output):
                grad_weight = torch.embedding_backward(grad_output, indices, weight_size_0, padding_idx, scale_grad_by_freq, sparse)
                return grad_weight, None, None, None, None

            return torch.embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse), backward

        def nll_loss(self, target, weight: Optional[Tensor], reduction: int, ignore_index: int):
            result, total_weight = torch.nll_loss_forward(self, target, weight, reduction, ignore_index)
            def backward(grad):
                return torch.nll_loss_backward(grad, self, target, weight, reduction, ignore_index, total_weight), None, None, None, None
            return result, backward

        def softmax_0(self, dim: int):
            result = torch.softmax(self, dim)
            def backward(grad_output):
                grad_self = torch._softmax_backward_data(grad_output, result, dim, self)
                return grad_self, None

            return result, backward

        def softmax_1(self, dim: int, dtype: int):
            result = torch.softmax(self, dim, dtype)
            def backward(grad_output):
                grad_self = torch._softmax_backward_data(grad_output, result, dim, self)
                return grad_self, None, None

            return torch.softmax(self, dim, dtype), backward

        def AD_interpolate_backward(grad,
                                    input,
                                    mode: str,
                                    align_corners: bool):
            output_size = grad.size()[2:]
            input_size = input.size()
            input_dim = len(input_size)
            if input_dim == 3 and mode == 'nearest':
                grad_input = torch.upsample_nearest1d_backward(grad, output_size, input_size)
            elif input_dim == 4 and mode == 'nearest':
                grad_input = torch.upsample_nearest2d_backward(grad, output_size, input_size)
            elif input_dim == 5 and mode == 'nearest':
                grad_input = torch.upsample_nearest3d_backward(grad, output_size, input_size)
            elif input_dim == 3 and mode == 'linear':
                grad_input = torch.upsample_linear1d_backward(grad, output_size, input_size, align_corners)
            elif input_dim == 4 and mode == 'bilinear':
                grad_input = torch.upsample_bilinear2d_backward(grad, output_size, input_size, align_corners)
            elif input_dim == 5 and mode == 'trilinear':
                grad_input = torch.upsample_trilinear3d_backward(grad, output_size, input_size, align_corners)
            elif input_dim == 4 and mode == 'bicubic':
                grad_input = torch.upsample_bicubic2d_backward(grad, output_size, input_size, align_corners)
            elif input_dim == 3 and mode == 'area':
                grad_input = AD_adaptive_avg_pool1d_backward(grad, input, output_size)
            elif input_dim == 4 and mode == 'area':
                grad_input = AD_adaptive_avg_pool2d_backward(grad, input, output_size)
            elif input_dim == 5 and mode == 'area':
                grad_input = torch.adaptive_avg_pool3d_backward(grad, input)
            else:
                grad_input = torch.zeros_like(input)
                raise RuntimeError('Input Error: Only 3D, 4D and 5D input Tensors supported')

            return grad_input

        def __interpolate_0(input,
                            size: Optional[int],
                            scale_factor: Optional[List[float]],
                            mode: str,
                            align_corners: Optional[bool]):
            def backward(grad_output):
                if align_corners is None:
                    align_corners = False
                grad_self = AD_interpolate_backward(grad_output, input, mode, align_corners)
                return grad_self, None, None, None, None

            return torch.__interpolate(input, size, scale_factor, mode, align_corners), backward

        def __interpolate_1(input,
                            size: Optional[List[int]],
                            scale_factor: Optional[List[float]],
                            mode: str,
                            align_corners: Optional[bool]):
            def backward(grad_output):
                if align_corners is None:
                    align_corners = False
                grad_self = AD_interpolate_backward(grad_output, input, mode, align_corners)
                return grad_self, None, None, None, None

            return torch.__interpolate(input, size, scale_factor, mode, align_corners), backward

        def __interpolate_2(input,
                            size: Optional[int],
                            scale_factor: Optional[float],
                            mode: str,
                            align_corners: Optional[bool]):
            def backward(grad_output):
                if align_corners is None:
                    align_corners = False
                grad_self = AD_interpolate_backward(grad_output, input, mode, align_corners)
                return grad_self, None, None, None, None

            return torch.__interpolate(input, size, scale_factor, mode, align_corners), backward

        def __interpolate_3(input,
                            size: Optional[List[int]],
                            scale_factor: Optional[float],
                            mode: str,
                            align_corners: Optional[bool]):
            def backward(grad_output):
                if align_corners is None:
                    align_corners = False
                grad_self = AD_interpolate_backward(grad_output, input, mode, align_corners)
                return grad_self, None, None, None, None

            return torch.__interpolate(input, size, scale_factor, mode, align_corners), backward
    )"};

std::unordered_map<std::string, GradientPair> schema_to_graphs;

std::unordered_map<const FunctionSchema*, GradientPair> cached_gradient_pairs;
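
// Each TorchScript definition above returns a (result, backward) pair, where
// the backward closure is represented in the graph as a literal
// (prim::Function, context) tuple. extractClosure splits that tuple back into
// the backward subgraph and the context value it captured.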
std::pair<std::shared_ptr<Graph>, Value*> extractClosure(Value* closure) {
  AT_CHECK(
      closure->node()->kind() == prim::TupleConstruct,
      "closure must be a literal tuple construct");
  Value* fn = closure->node()->inputs().at(0);
  Value* context = closure->node()->inputs().at(1);

  AT_CHECK(
      fn->node()->kind() == prim::Function,
      "closure tuple must contain a prim::Function");
  return std::make_pair(fn->node()->g(attr::Subgraph), context);
}
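
// The forward graph returns (original outputs..., backward closure). To derive
// a schema for the underlying op, drop the trailing closure element: a
// 2-element tuple collapses to its first element, otherwise the remaining
// elements stay a tuple.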
Argument originalReturnType(const TupleTypePtr& tup) {
  AT_CHECK(tup->elements().size() > 1);
  if (tup->elements().size() == 2)
    return Argument("", tup->elements().at(0));
  std::vector<TypePtr> types = tup->elements().vec();
  types.pop_back();
  return Argument("", TupleType::create(std::move(types)));
}

std::string overloadedSchemaString(const FunctionSchema& schema) {
  const auto& schema_name = schema.name();
  auto pos = schema_name.find_last_of('_');
  auto schema_name_suffix = schema_name.substr(pos + 1);
  std::string schema_string = canonicalSchemaString(schema);
  if (!schema_name_suffix.empty() &&
      schema_name_suffix.find_first_not_of("0123456789") == std::string::npos) {
    schema_string.replace(
        schema_string.find(schema_name),
        schema_name.length(),
        schema_name.substr(0, pos));
  }
  return schema_string;
}

bool isHelperFunction(const std::string& method_name) {
  std::string helper_prefix = "AD_";
  return method_name.compare(0, helper_prefix.length(), helper_prefix) == 0;
}
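
// Compile each non-helper method into a GradientPair: rewrite the forward
// graph so its output tuple carries the closure's captured context instead of
// the closure itself, and register the pair under the op's canonical schema
// string (with any numeric overload suffix stripped).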
void loadModule(const std::shared_ptr<script::Module>& module) {
  for (const auto& method_ : module->get_methods()) {
    if (isHelperFunction(method_.key()))
      continue;

    const auto& method = method_.value();
    GradientPair pair;
    pair.forward = method->graph();

    Node* forward_tuple = pair.forward->outputs().at(0)->node();

    if (forward_tuple->kind() != prim::TupleConstruct) {
      throw script::ErrorReport(forward_tuple->getSourceLocation())
          << "gradient must return a literal tuple";
    }

    Value* context;
    std::tie(pair.backward, context) =
        extractClosure(forward_tuple->inputs().back());

    // Replace the closure tuple in the forward output with just the captured
    // context, so the forward graph returns (original outputs..., context).
    std::vector<Value*> new_inputs = forward_tuple->inputs().vec();
    new_inputs.back() = context;
    Value* new_tuple =
        pair.forward->appendNode(pair.forward->createTuple(new_inputs))
            ->output();
    pair.forward->eraseOutput(0);
    pair.forward->registerOutput(new_tuple);
    forward_tuple->destroy();

    // Derive the schema for the AD formula from the original method's schema,
    // using the return type of the rewritten forward tuple.
    const FunctionSchema& loaded_schema = method->getSchema();
    FunctionSchema actual_schema(
        Symbol::aten(loaded_schema.name()),
        loaded_schema.overload_name(),
        loaded_schema.arguments(),
        {originalReturnType(new_tuple->type()->expect<TupleType>())});

    auto schema_string = overloadedSchemaString(actual_schema);
    schema_to_graphs[schema_string] = std::move(pair);
  }
}
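
// Compile every source string in `functions` and register its gradient
// formulas; this runs lazily, under `lock`, the first time a schema lookup
// happens.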
void loadFunctions() {
  for (const std::string& str : functions) {
    auto cu = std::make_shared<script::Module>();
    script::defineMethodsInModule(
        cu, str, script::nativeResolver, c10::nullopt);
    loadModule(cu);
  }
}
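
// Public lookup: returns the forward/backward graph pair registered for
// `schema`, compiling the TorchScript formulas on first use and caching the
// result per FunctionSchema.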
c10::optional<GradientPair> gradientInfoForSchema(
    const FunctionSchema& schema) {
  std::lock_guard<std::mutex> guard(lock);
  if (schema_to_graphs.size() == 0) {
    loadFunctions();
  }
  auto cache_it = cached_gradient_pairs.find(&schema);
  if (cache_it != cached_gradient_pairs.end()) {
    return cache_it->second;
  } else {
    auto schema_str = canonicalSchemaString(schema);
    auto sym_script_it = schema_to_graphs.find(schema_str);

    if (sym_script_it != schema_to_graphs.end()) {
      cached_gradient_pairs.emplace_hint(
          cache_it, &schema, sym_script_it->second);
      return sym_script_it->second;
    }
  }
  return c10::nullopt;
}

bool hasGradientInfoForSchema(const FunctionSchema& schema) {
  return gradientInfoForSchema(schema).has_value();
}
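
// Usage sketch (hypothetical caller, not part of this file): a pass that wants
// a symbolic AD formula for a node could do
//
//   const FunctionSchema& schema = node->schema();
//   if (auto pair = gradientInfoForSchema(schema)) {
//     std::shared_ptr<Graph> fw = pair->forward;   // rewritten forward graph
//     std::shared_ptr<Graph> bw = pair->backward;  // backward subgraph
//     // ... splice fw/bw into the differentiated graph ...
//   }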