Caffe2 - Python API
A deep learning, cross platform ML framework
adaptive.py
# -*- coding: utf-8 -*-

from collections import namedtuple

import torch

from . import Sequential, ModuleList, Linear
from .module import Module
from ..functional import log_softmax


_ASMoutput = namedtuple('ASMoutput', ['output', 'loss'])


16  r"""Efficient softmax approximation as described in
17  `Efficient softmax approximation for GPUs`_ by Edouard Grave, Armand Joulin,
18  Moustapha Cissé, David Grangier, and Hervé Jégou.
19 
20  Adaptive softmax is an approximate strategy for training models with large
21  output spaces. It is most effective when the label distribution is highly
22  imbalanced, for example in natural language modelling, where the word
23  frequency distribution approximately follows the `Zipf's law`_.
24 
25  Adaptive softmax partitions the labels into several clusters, according to
26  their frequency. These clusters may contain different number of targets
27  each.
28  Additionally, clusters containing less frequent labels assign lower
29  dimensional embeddings to those labels, which speeds up the computation.
30  For each minibatch, only clusters for which at least one target is
31  present are evaluated.
32 
33  The idea is that the clusters which are accessed frequently
34  (like the first one, containing most frequent labels), should also be cheap
35  to compute -- that is, contain a small number of assigned labels.
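
    Roughly speaking, every example in a minibatch always pays for a single
    matrix product with the head (whose output width is the shortlist size plus
    one entry per cluster), while each tail cluster's small two-layer projection
    is evaluated during training only for the rows whose target falls into that
    cluster.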

    We highly recommend taking a look at the original paper for more details.

    * :attr:`cutoffs` should be an ordered Sequence of integers sorted
      in increasing order.
      It controls the number of clusters and the partitioning of targets into
      clusters. For example, setting ``cutoffs = [10, 100, 1000]``
      means that the first `10` targets will be assigned
      to the 'head' of the adaptive softmax, targets `11, 12, ..., 100` will be
      assigned to the first cluster, and targets `101, 102, ..., 1000` will be
      assigned to the second cluster, while targets
      `1001, 1002, ..., n_classes - 1` will be assigned
      to the last, third cluster. A worked illustration of such a configuration
      is given after this list.

    * :attr:`div_value` is used to compute the size of each additional cluster,
      which is given as
      :math:`\left\lfloor\frac{in\_features}{div\_value^{idx}}\right\rfloor`,
      where :math:`idx` is the cluster index (with clusters
      for less frequent words having larger indices,
      and indices starting from :math:`1`).

    * :attr:`head_bias` if set to True, adds a bias term to the 'head' of the
      adaptive softmax. See paper for details. Set to False in the official
      implementation.
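
    As a concrete (purely illustrative) configuration tying these options
    together, ``cutoffs = [10, 100, 1000]`` with ``n_classes = 2000``,
    ``in_features = 512`` and the default ``div_value = 4.0`` gives a head of
    size :math:`10 + 3 = 13` (shortlist plus one entry per cluster), tail
    clusters covering `90`, `900` and `1000` targets, and tail projections of
    :math:`\lfloor 512/4 \rfloor = 128`, :math:`\lfloor 512/16 \rfloor = 32`
    and :math:`\lfloor 512/64 \rfloor = 8` features respectively::

        >>> asm = AdaptiveLogSoftmaxWithLoss(512, 2000, cutoffs=[10, 100, 1000])
        >>> asm.head_size
        13
        >>> [m[1].out_features for m in asm.tail]
        [90, 900, 1000]
        >>> [m[0].out_features for m in asm.tail]
        [128, 32, 8]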

    .. warning::
        Labels passed as inputs to this module should be sorted according to
        their frequency. This means that the most frequent label should be
        represented by the index `0`, and the least frequent
        label should be represented by the index `n_classes - 1`.
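
        For example, in a language-modelling vocabulary this means assigning
        index `0` to the most common word, index `1` to the second most common,
        and so on (a purely illustrative mapping; the module does not verify
        this ordering).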

    .. note::
        This module returns a ``NamedTuple`` with ``output``
        and ``loss`` fields. See further documentation for details.

    .. note::
        To compute log-probabilities for all classes, the ``log_prob``
        method can be used.

    Args:
        in_features (int): Number of features in the input tensor
        n_classes (int): Number of classes in the dataset
        cutoffs (Sequence): Cutoffs used to assign targets to their buckets
        div_value (float, optional): value used as the base of the exponent
            that computes the sizes of the clusters. Default: 4.0
        head_bias (bool, optional): If ``True``, adds a bias term to the 'head' of the
            adaptive softmax. Default: ``False``

    Returns:
        ``NamedTuple`` with ``output`` and ``loss`` fields:

        * **output** is a Tensor of size ``N`` containing computed target
          log probabilities for each example
        * **loss** is a Scalar representing the computed negative
          log likelihood loss

    Shape:
        - input: :math:`(N, in\_features)`
        - target: :math:`(N)` where each value satisfies :math:`0 <= target[i] < n\_classes`
        - output1: :math:`(N)`
        - output2: ``Scalar``
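
    Example (a minimal usage sketch; sizes are the illustrative ones from
    above, not a recommendation)::

        >>> asm = AdaptiveLogSoftmaxWithLoss(in_features=512, n_classes=2000,
        ...                                  cutoffs=[10, 100, 1000])
        >>> input = torch.randn(128, 512)
        >>> target = torch.randint(0, 2000, (128,))
        >>> out, loss = asm(input, target)       # out: (128,), loss: scalar
        >>> log_probs = asm.log_prob(input)      # (128, 2000)
        >>> predictions = asm.predict(input)     # (128,)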

    .. _Efficient softmax approximation for GPUs:
        https://arxiv.org/abs/1609.04309

    .. _Zipf's law:
        https://en.wikipedia.org/wiki/Zipf%27s_law
    """

    def __init__(self, in_features, n_classes, cutoffs, div_value=4., head_bias=False):
        super(AdaptiveLogSoftmaxWithLoss, self).__init__()

        cutoffs = list(cutoffs)

        if (cutoffs != sorted(cutoffs)) \
                or (min(cutoffs) <= 0) \
                or (max(cutoffs) > (n_classes - 1)) \
                or (len(set(cutoffs)) != len(cutoffs)) \
                or any([int(c) != c for c in cutoffs]):

            raise ValueError("cutoffs should be a sequence of unique, positive "
                             "integers sorted in an increasing order, where "
                             "each value is between 1 and n_classes-1")

        self.in_features = in_features
        self.n_classes = n_classes
        self.cutoffs = cutoffs + [n_classes]
        self.div_value = div_value
        self.head_bias = head_bias

        # the head predicts the shortlist (most frequent targets) directly,
        # plus one logit per tail cluster
        self.shortlist_size = self.cutoffs[0]
        self.n_clusters = len(self.cutoffs) - 1
        self.head_size = self.shortlist_size + self.n_clusters

        self.head = Linear(self.in_features, self.head_size, bias=self.head_bias)
        self.tail = ModuleList()

        for i in range(self.n_clusters):

            # each tail cluster is a two-layer bottleneck: in_features -> hsz -> osz
            hsz = int(self.in_features // (self.div_value ** (i + 1)))
            osz = self.cutoffs[i + 1] - self.cutoffs[i]

            projection = Sequential(
                Linear(self.in_features, hsz, bias=False),
                Linear(hsz, osz, bias=False)
            )

            self.tail.append(projection)

    def reset_parameters(self):
        self.head.reset_parameters()
        for i2h, h2o in self.tail:
            i2h.reset_parameters()
            h2o.reset_parameters()

    def forward(self, input, target):
        if input.size(0) != target.size(0):
            raise RuntimeError('Input and target should have the same size '
                               'in the batch dimension.')

        used_rows = 0
        batch_size = target.size(0)

        output = input.new_zeros(batch_size)
        gather_inds = target.new_empty(batch_size)

        # buckets are [0, c0), [c0, c1), ...: bucket 0 is the head shortlist,
        # bucket i > 0 corresponds to tail cluster i - 1
        cutoff_values = [0] + self.cutoffs
        for i in range(len(cutoff_values) - 1):

            low_idx = cutoff_values[i]
            high_idx = cutoff_values[i + 1]

            target_mask = (target >= low_idx) & (target < high_idx)
            row_indices = target_mask.nonzero().squeeze()

            if row_indices.numel() == 0:
                continue

            if i == 0:
                # shortlist targets are read directly from the head output
                gather_inds.index_copy_(0, row_indices, target[target_mask])

            else:
                relative_target = target[target_mask] - low_idx
                input_subset = input.index_select(0, row_indices)

                cluster_output = self.tail[i - 1](input_subset)
                cluster_index = self.shortlist_size + i - 1

                # the head predicts the cluster; the cluster predicts the target within it
                gather_inds.index_fill_(0, row_indices, cluster_index)

                cluster_logprob = log_softmax(cluster_output, dim=1)
                local_logprob = cluster_logprob.gather(1, relative_target.unsqueeze(1))
                output.index_copy_(0, row_indices, local_logprob.squeeze(1))

            used_rows += row_indices.numel()

        if used_rows != batch_size:
            raise RuntimeError("Target values should be in [0, {}], "
                               "but values in range [{}, {}] "
                               "were found. ".format(self.n_classes - 1,
                                                     target.min().item(),
                                                     target.max().item()))

        head_output = self.head(input)
        head_logprob = log_softmax(head_output, dim=1)
        output += head_logprob.gather(1, gather_inds.unsqueeze(1)).squeeze()
        loss = (-output).mean()

        return _ASMoutput(output, loss)

    def _get_full_log_prob(self, input, head_output):
        """Given the input tensor and the output of ``self.head``,
        compute the log of the full distribution."""

        out = input.new_empty((head_output.size(0), self.n_classes))
        head_logprob = log_softmax(head_output, dim=1)

        out[:, :self.shortlist_size] = head_logprob[:, :self.shortlist_size]

        for i, (start_idx, stop_idx) in enumerate(zip(self.cutoffs, self.cutoffs[1:])):
            cluster_output = self.tail[i](input)
            cluster_logprob = log_softmax(cluster_output, dim=1)
            # log p(class) = log p(cluster | input) + log p(class | cluster, input)
            output_logprob = cluster_logprob + head_logprob[:, self.shortlist_size + i].unsqueeze(1)

            out[:, start_idx:stop_idx] = output_logprob

        return out

    def log_prob(self, input):
        r"""Computes log probabilities for all :math:`n\_classes`

        Args:
            input (Tensor): a minibatch of examples

        Returns:
            log-probabilities for each class :math:`c`
            in range :math:`0 <= c < n\_classes`, where :math:`n\_classes` is a
            parameter passed to the ``AdaptiveLogSoftmaxWithLoss`` constructor.

        Shape:
            - Input: :math:`(N, in\_features)`
            - Output: :math:`(N, n\_classes)`
        """

        head_output = self.head(input)
        return self._get_full_log_prob(input, head_output)

    def predict(self, input):
        r"""This is equivalent to ``self.log_prob(input).argmax(dim=1)``,
        but is more efficient in some cases.

        Args:
            input (Tensor): a minibatch of examples

        Returns:
            output (Tensor): a class with the highest probability for each example

        Shape:
            - Input: :math:`(N, in\_features)`
            - Output: :math:`(N)`
        """

        head_output = self.head(input)
        output = torch.argmax(head_output, dim=1)
        not_in_shortlist = (output >= self.shortlist_size)
        all_in_shortlist = not (not_in_shortlist.any())

        if all_in_shortlist:
            # every row's best head prediction is a shortlist class,
            # so no tail cluster needs to be evaluated
            return output

        elif not_in_shortlist.all():
            log_prob = self._get_full_log_prob(input, head_output)
            return torch.argmax(log_prob, dim=1)

        else:
            # compute the full distribution only for the rows that need it
            log_prob = self._get_full_log_prob(input[not_in_shortlist],
                                               head_output[not_in_shortlist])
            output[not_in_shortlist] = torch.argmax(log_prob, dim=1)
            return output