6 from .
import _all_functions
16 def _renorm(ctx, indices, weight, max_norm, norm_type):
18 ctx._backend.LookupTable_renorm(
19 ctx._backend.library_state,
20 indices.clone().view(-1),
27 def forward(cls, ctx, weight, indices, offsets,
28 max_norm, norm_type, scale_grad_by_freq, mode):
30 ctx.max_norm = max_norm
31 ctx.norm_type = norm_type
32 ctx.scale_grad_by_freq = scale_grad_by_freq
39 raise ValueError(
"mode needs to be 'sum' or 'mean', but got {}" 42 assert not ctx.needs_input_grad[1],
"EmbeddingBag doesn't " \
43 "compute the gradient w.r.t. the indices" 45 assert not ctx.needs_input_grad[2],
"EmbeddingBag doesn't " \
46 "compute the gradient w.r.t. the offsets" 48 assert indices.dim() == 1
49 if offsets.dim() != 1:
50 raise ValueError(
"offsets has to be a 1D Tensor")
53 raise ValueError(
"offsets[0] has to be 0, i.e. the first sequence" 54 " in the mini-batch has to start from position 0." 55 "However, got {}".format(offsets[0]))
56 if offsets[-1] > indices.size(0):
57 raise ValueError(
"offsets[-1] has to be smaller than indices's length" 58 " ({}), but got offsets[-1] of {}" 59 .format(indices.size(0), offsets[-1]))
61 ctx._backend = type2backend[weight.type()]
62 ctx._weight_size = weight.size()
63 ctx._offset2bag = offsets.new()
65 ctx.save_for_backward(indices)
67 indices = indices.contiguous().view(-1)
70 if ctx.max_norm
is not None:
71 cls.
_renorm(ctx, indices, weight, max_norm=max_norm, norm_type=norm_type)
74 if ctx.mode == MODE_MEAN:
75 ctx.bag_size = offsets.new().resize_(offsets.size())
79 ctx._backend.LookupTableBag_updateOutput(
80 ctx._backend.library_state,
91 index_output = torch.index_select(weight, 0, indices)
93 ctx._offset2bag.resize_(indices.size(0)).zero_()
94 ctx._offset2bag.index_fill_(0, offsets, 1)
95 ctx._offset2bag[0] = 0
96 ctx._offset2bag = ctx._offset2bag.cumsum(0)
97 output.resize_(offsets.size(0), weight.size(1)).zero_()
98 output.index_add_(0, ctx._offset2bag, index_output)
99 if ctx.mode == MODE_MEAN:
100 if offsets.size(0) == 1:
101 ctx.bag_size = indices.size(0)
103 ctx.bag_size = weight.new().resize_(offsets.size())
104 ctx.bag_size[:-1] = offsets[1:] - offsets[:-1]
105 ctx.bag_size[-1] = indices.size(0) - offsets[-1]
106 ctx.bag_size = ctx.bag_size[:,
None].expand_as(output)
107 output /= ctx.bag_size
113 def backward(ctx, grad_output):
114 indices, = ctx.saved_tensors
115 indices = indices.contiguous().view(-1)
116 grad_output = grad_output.contiguous()
119 if grad_output.is_cuda:
120 _sorted = torch.cuda.LongTensor()
121 _indices = torch.cuda.LongTensor()
122 _count = torch.cuda.LongTensor()
124 _count = torch.IntTensor()
125 _sorted = _indices =
None 127 grad_weight = grad_output.new(ctx._weight_size).zero_()
129 if grad_output.is_cuda:
130 ctx._backend.LookupTableBag_accGradParameters(
131 ctx._backend.library_state,
139 ctx.scale_grad_by_freq,
146 if ctx.mode == MODE_MEAN:
148 grad_output = grad_output / ctx.bag_size
150 index_grad_output = grad_output.index_select(0, ctx._offset2bag)
151 ctx._backend.LookupTable_accGradParameters(
152 ctx._backend.library_state,
159 ctx.scale_grad_by_freq,
164 return grad_weight,
None,
None,
None,
None,
None,
None 167 _all_functions.append(EmbeddingBag)
def _renorm(ctx, indices, weight, max_norm, norm_type)