Caffe2 - Python API
A deep learning, cross platform ML framework
sparse.py
1 import torch
2 from torch.autograd.function import Function
3 from torch._thnn import type2backend
4 from torch.autograd.function import once_differentiable
5 
6 from . import _all_functions
7 
8 
# Integer codes for the bag reduction. forward() maps mode='sum' to MODE_SUM
# and mode='mean' to MODE_MEAN, stores the code in ctx.mode and hands it to
# the backend kernels.
9 MODE_SUM = 0
10 MODE_MEAN = 1
11 
12 
14 
@staticmethod
def _renorm(ctx, indices, weight, max_norm, norm_type):
    """Run the backend's ``LookupTable_renorm`` on ``weight`` for ``indices``.

    Args:
        ctx: autograd context; ``ctx._backend`` must already be set.
        indices: tensor of row indices. It is never handed to the backend
            directly, only a flattened copy of it is.
        weight: embedding matrix passed to the backend.
            NOTE(review): presumably renormalized in place -- confirm
            against the THNN kernel.
        max_norm: norm threshold forwarded to the backend.
        norm_type: norm order forwarded to the backend.
    """
    backend = ctx._backend
    renorm_kernel = backend.LookupTable_renorm
    state = backend.library_state
    # LookupTable_renorm modifies its index argument in place, so give it a
    # private flattened copy instead of the caller's tensor.
    scratch_indices = indices.clone().view(-1)
    renorm_kernel(state, scratch_indices, weight, max_norm, norm_type)
25 
@classmethod
def forward(cls, ctx, weight, indices, offsets,
            max_norm, norm_type, scale_grad_by_freq, mode):
    """Gather rows of ``weight`` and reduce them bag by bag.

    ``indices`` is a flat 1D list of row indices into ``weight`` and
    ``offsets`` holds the position in ``indices`` at which each bag starts.
    For example ``indices = [1, 2, 30, 100, 12]`` with
    ``offsets = [0, 2, 3]`` describes the bags ``[1, 2]``, ``[30]`` and
    ``[100, 12]``.

    Args:
        ctx: autograd context. The settings, the backend handle, the weight
            size, the index-to-bag map and the bag sizes are stored on it
            for ``backward``.
        weight: embedding matrix, indexed along dim 0.
        indices: 1D tensor of row indices.
        offsets: 1D tensor of bag start positions; ``offsets[0]`` must be 0
            and ``offsets[-1]`` must not exceed ``indices.size(0)``.
        max_norm: when not ``None``, ``_renorm`` is applied to ``weight``
            for these indices before the lookup.
        norm_type: norm order forwarded to ``_renorm``.
        scale_grad_by_freq: stored on ``ctx`` and forwarded to the backend
            in ``backward``.
        mode: ``'sum'`` or ``'mean'``.

    Returns:
        The reduced embeddings. On the CPU path the result has size
        ``(offsets.size(0), weight.size(1))``; on the CUDA path the output
        tensor is filled in by the backend kernel.

    Raises:
        ValueError: if ``mode`` is neither ``'sum'`` nor ``'mean'``, if
            ``offsets`` is not 1D, if ``offsets[0] != 0`` or if
            ``offsets[-1] > indices.size(0)``.
        AssertionError: if a gradient is requested for ``indices`` or
            ``offsets``, or if ``indices`` is not 1D.
    """
    ctx.max_norm = max_norm
    ctx.norm_type = norm_type
    ctx.scale_grad_by_freq = scale_grad_by_freq

    if mode == 'sum':
        ctx.mode = MODE_SUM
    elif mode == 'mean':
        ctx.mode = MODE_MEAN
    else:
        raise ValueError("mode needs to be 'sum' or 'mean', but got {}"
                         .format(mode))

    assert not ctx.needs_input_grad[1], "EmbeddingBag doesn't " \
        "compute the gradient w.r.t. the indices"

    assert not ctx.needs_input_grad[2], "EmbeddingBag doesn't " \
        "compute the gradient w.r.t. the offsets"

    # Kept as an assert so callers still see AssertionError here.
    assert indices.dim() == 1, "indices has to be a 1D Tensor"
    if offsets.dim() != 1:
        raise ValueError("offsets has to be a 1D Tensor")

    if offsets[0] != 0:
        raise ValueError("offsets[0] has to be 0, i.e. the first sequence"
                         " in the mini-batch has to start from position 0."
                         " However, got {}".format(offsets[0]))
    if offsets[-1] > indices.size(0):
        # The check accepts equality, so the message says "greater than".
        raise ValueError("offsets[-1] can not be greater than indices's length"
                         " ({}), but got offsets[-1] of {}"
                         .format(indices.size(0), offsets[-1]))

    ctx._backend = type2backend[weight.type()]
    ctx._weight_size = weight.size()
    ctx._offset2bag = offsets.new()

    # Save the caller's tensor before the local name is rebound below.
    ctx.save_for_backward(indices)

    indices = indices.contiguous().view(-1)
    output = weight.new()

    if ctx.max_norm is not None:
        cls._renorm(ctx, indices, weight, max_norm=max_norm, norm_type=norm_type)

    if weight.is_cuda:
        # bag_size is only needed for the mean; the backend receives an
        # uninitialized buffer of the right size, or None for 'sum'.
        if ctx.mode == MODE_MEAN:
            ctx.bag_size = offsets.new().resize_(offsets.size())
        else:
            ctx.bag_size = None

        ctx._backend.LookupTableBag_updateOutput(
            ctx._backend.library_state,
            indices,
            offsets,
            weight,
            output,
            ctx._offset2bag,
            ctx.mode,
            ctx.bag_size
        )
    else:
        # slow CPU implementation
        index_output = torch.index_select(weight, 0, indices)
        # Build the index -> bag map.
        # NOTE(review): index_fill_ marks a repeated offset only once, and
        # an offset equal to indices.size(0) (accepted by the check above)
        # lies past the end of this buffer, so empty bags look unsupported
        # on this path -- confirm.
        # indices = [1, 2, 30, 100, 12], offsets = [0, 2, 3]
        ctx._offset2bag.resize_(indices.size(0)).zero_()  # offset2bag = [0 0 0 0 0]
        ctx._offset2bag.index_fill_(0, offsets, 1)  # offset2bag = [1 0 1 0 1]
        ctx._offset2bag[0] = 0  # offset2bag = [0 0 1 0 1]
        ctx._offset2bag = ctx._offset2bag.cumsum(0)  # offset2bag = [0 0 1 1 2]
        output.resize_(offsets.size(0), weight.size(1)).zero_()
        # Sum every looked-up row into the output row of its bag.
        output.index_add_(0, ctx._offset2bag, index_output)
        if ctx.mode == MODE_MEAN:
            if offsets.size(0) == 1:
                # Single bag: its size is simply the number of indices.
                ctx.bag_size = indices.size(0)
            else:
                ctx.bag_size = weight.new().resize_(offsets.size())
                # Each bag ends where the next one starts; the last bag
                # runs to the end of indices.
                ctx.bag_size[:-1] = offsets[1:] - offsets[:-1]
                ctx.bag_size[-1] = indices.size(0) - offsets[-1]
                ctx.bag_size = ctx.bag_size[:, None].expand_as(output)
            output /= ctx.bag_size

    return output
110 
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
    """Accumulate the gradient w.r.t. ``weight`` from ``grad_output``.

    Reads the state that ``forward`` left on ``ctx`` (``_backend``,
    ``_weight_size``, ``_offset2bag``, ``mode``, ``bag_size``,
    ``scale_grad_by_freq``) together with the saved ``indices``.

    Args:
        ctx: autograd context populated by ``forward``.
        grad_output: gradient of the loss w.r.t. the output of ``forward``.

    Returns:
        A 7-tuple matching the inputs of ``forward``: the gradient for
        ``weight`` followed by ``None`` for the six remaining inputs.
    """
    (saved_indices,) = ctx.saved_tensors
    flat_indices = saved_indices.contiguous().view(-1)
    grad_output = grad_output.contiguous()
    on_gpu = grad_output.is_cuda
    backend = ctx._backend

    # Allocate the scratch buffers on the device that holds grad_output.
    with torch.cuda.device_of(grad_output):
        if on_gpu:
            sort_buf = torch.cuda.LongTensor()
            index_buf = torch.cuda.LongTensor()
            count_buf = torch.cuda.LongTensor()
        else:
            count_buf = torch.IntTensor()
            sort_buf = index_buf = None

    grad_weight = grad_output.new(ctx._weight_size).zero_()

    if not on_gpu:
        # slow CPU implementation
        if ctx.mode == MODE_MEAN:
            # the forward pass divided by the bag size, so the gradient
            # is divided by it as well
            grad_output = grad_output / ctx.bag_size

        # Expand the per-bag gradient to one row per index.
        per_index_grad = grad_output.index_select(0, ctx._offset2bag)
        # NOTE(review): the trailing -1 and 1 are presumably the padding
        # index ("none") and the gradient scale -- confirm against the
        # THNN LookupTable_accGradParameters signature.
        backend.LookupTable_accGradParameters(
            backend.library_state,
            flat_indices,
            per_index_grad,
            grad_weight,
            count_buf,
            sort_buf,
            index_buf,
            ctx.scale_grad_by_freq,
            -1,
            1
        )
    else:
        # NOTE(review): the trailing 1 is presumably the gradient scale --
        # confirm against the THCUNN signature.
        backend.LookupTableBag_accGradParameters(
            backend.library_state,
            flat_indices,
            grad_output,
            grad_weight,
            ctx._offset2bag,
            count_buf,
            sort_buf,
            index_buf,
            ctx.scale_grad_by_freq,
            ctx.mode,
            ctx.bag_size,
            1
        )

    # Only weight is differentiable; the other six forward inputs get None.
    return (grad_weight,) + (None,) * 6
165 
166 
# Add EmbeddingBag to `_all_functions`, the collection imported from the
# enclosing package at the top of this module.
167 _all_functions.append(EmbeddingBag)
def _renorm(ctx, indices, weight, max_norm, norm_type)
Definition: sparse.py:16