Caffe2 - Python API
A deep learning, cross platform ML framework
crf.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package crf
17 # Module caffe2.python.crf
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 from caffe2.python import core, recurrent, model_helper, brew
23 import numpy as np
24 
25 '''
26 Due to a limitation in RecurrentNetworkOp, this layer only supports batch_size=1
27 In order to support batch_size > 1, we will have to implement the CRFUnit
28 and its gradient in C++ and handle the different batches there.
29 '''
30 
31 
class CRFWithLoss(object):
    def __init__(self, model, num_classes, transitions_blob=None):
        """Linear-chain CRF layer with a negative log-likelihood loss.

        model: model_helper.ModelHelper that all CRF operators are added to.
        num_classes: number of real tag classes; two synthetic classes
            (BOS and EOS) are appended internally.
        transitions_blob: optional existing transitions-matrix blob; when
            not given, a new one is initialized from U(-1, 1).
        """
        self.model = model
        self.num_classes = num_classes
        self.num_classes_padded = num_classes + 2  # After adding BOS and EOS
        if not transitions_blob:
            transitions_blob = self.model.param_init_net.UniformFill(
                [],
                [core.ScopedBlobReference('crf_transitions')],
                shape=[self.num_classes_padded, self.num_classes_padded],
                min=-1.0,
                max=1.0,
            )
        self.transitions = transitions_blob
        # Register the transitions matrix as a learnable parameter.
        self.model.params.append(self.transitions)
47 
48  def crf_loss(self, predictions, labels, seq_lengths=None):
49  # Since the transitions matrix is a shared parameter, need to
50  # take a snapshot of it at the beginning since it can be updated
51  # in between the operators that uses it when doing parallel updates
52  transitions_snapshot = self.model.net.Copy(
53  self.transitions, core.ScopedBlobReference('transitions_snapshot')
54  )
55  # Compute best path unary score from the logits
56  path_unary_score = self._gather_entries_sum(
57  predictions, labels, self.num_classes
58  )
59  # Append BOS and EOS entries to the predictions and labels
60  predictions = self._pad_predictions(predictions)
61  labels = self._pad_labels(labels)
62  # Compute best path binary scores from the transitions matrix
63  path_binary_score = self._path_binary_scores(
64  labels, transitions_snapshot, seq_lengths
65  )
66  path_total_score = self.model.net.Add(
67  [path_binary_score, path_unary_score],
68  core.ScopedBlobReference('path_total')
69  )
70  # Compute all paths score
71  zero_index = self.model.param_init_net.ConstantFill(
72  [], shape=[1], value=0
73  )
74  initial_state = self.model.net.Gather(
75  [predictions, zero_index],
76  core.ScopedBlobReference('rnn_initial'),
77  dense_gradient=True
78  )
79  input_data, _ = self.model.net.RemovePadding(
80  [predictions],
81  padding_width=1,
82  end_padding_width=0,
83  outputs=2,
84  )
85  input_data = self.model.net.ExpandDims(
86  [input_data],
87  core.ScopedBlobReference('rnn_input_data'),
88  dims=[1]
89  )
90  # Due to a bug in RecurrentNetworkGradientOp, we need to copy the
91  # transitions blob before sending it to the recurrent network
92  transitions_copy = self.model.net.Copy(
93  transitions_snapshot, core.ScopedBlobReference('transitions_copy')
94  )
95  all_paths_scores = self._crf_forward(
96  input_data, initial_state, transitions_copy
97  )
98  loss = self.model.net.Sub(
99  [all_paths_scores, path_total_score],
100  core.ScopedBlobReference('crf_loss')
101  )
102  return loss
103 
104  def _pad_predictions(self, predictions):
105  # This function will introduce two labels for beginning of sequence
106  # And end of sequence, it will make the necessary udpates to the
107  # the predictions blob
108 
109  low_score = -1000.0 # An arbitray very low number
110  b_scores = np.array(
111  [[low_score] * self.num_classes + [0, low_score]]
112  ).astype(np.float32)
113 
114  e_scores = np.array(
115  [[low_score] * self.num_classes + [low_score, 0]]
116  ).astype(np.float32)
117 
118  b_scores = self.model.param_init_net.GivenTensorFill(
119  [], "b_scores", shape=[1, self.num_classes_padded], values=b_scores
120  )
121  e_scores = self.model.param_init_net.GivenTensorFill(
122  [], "e_scores", shape=[1, self.num_classes_padded], values=e_scores
123  )
124 
125  zero_index = self.model.net.ConstantFill(
126  [], shape=[1, ], value=0
127  )
128  length = self.model.net.Gather(
129  [self.model.net.Shape([predictions]), zero_index],
130  )
131  length = self.model.net.Cast(length, to='int32')
132  t_range = self.model.net.LengthsRangeFill(length)
133  padding = self.model.net.ConstantFill([t_range], value=low_score)
134  padding = self.model.net.ExpandDims(padding, dims=[1])
135  padded_predictions, _ = self.model.net.Concat(
136  [predictions, padding, padding],
137  outputs=2,
138  axis=1
139  )
140  padded_predictions_concat, _ = self.model.net.Concat(
141  [b_scores, padded_predictions, e_scores],
142  outputs=2,
143  axis=0
144  )
145  return padded_predictions_concat
146 
147  def _pad_labels(self, labels):
148  bos_i = self.num_classes
149  eos_i = self.num_classes + 1
150  bos_i_b = self.model.param_init_net.ConstantFill(
151  [], shape=[1], value=bos_i
152  )
153  eos_i_b = self.model.param_init_net.ConstantFill(
154  [], shape=[1], value=eos_i
155  )
156  labels = self.model.net.Cast([labels], to='int64')
157  padded_labels, _ = self.model.net.Concat(
158  [bos_i_b, labels, eos_i_b],
159  axis=0,
160  outputs=2
161  )
162  return padded_labels
163 
164  def _path_binary_scores(self, labels, transitions, seq_lengths=None):
165  column_ids, _ = self.model.net.RemovePadding(
166  [labels],
167  outputs=2,
168  padding_width=1,
169  end_padding_width=0
170  )
171  row_ids, _ = self.model.net.RemovePadding(
172  [labels],
173  outputs=2,
174  padding_width=0,
175  end_padding_width=1
176  )
177  # Since there is no multi-dimensional gather, I flatten the matrix to
178  # a 1-d vector and transform the ids to (row_ids * num_columns +
179  # column_ids) and do gather in 1-d
180  num_columns_blob = self.model.net.ConstantFill(
181  [row_ids],
182  value=self.num_classes_padded,
183  )
184  flattened_ids = self.model.net.Mul([row_ids, num_columns_blob])
185  flattened_ids = self.model.net.Add([flattened_ids, column_ids])
186  flattened_transitions = self.model.net.FlattenToVec([transitions])
187  entries = self.model.net.Gather(
188  [flattened_transitions, flattened_ids],
189  dense_gradient=True
190  )
191  return self.model.ReduceFrontSum(entries)
192 
193  def _gather_entries_sum(self, in_data, indices, index_size):
194  indices = self.model.net.Cast([indices], to='int64')
195  index_size_blob = self.model.param_init_net.ConstantFill(
196  [],
197  shape=[1],
198  value=index_size,
199  )
200  query_one_hot = self.model.net.OneHot(
201  [indices, index_size_blob]
202  )
203  flattend_query = self.model.net.FlattenToVec(query_one_hot)
204  flattend_data = self.model.net.FlattenToVec(in_data)
205  query_scores = self.model.net.DotProduct(
206  [flattend_query, flattend_data]
207  )
208  final_sum = self.model.net.ReduceFrontSum([query_scores])
209  return final_sum
210 
211  def _crf_forward(
212  self,
213  input_blob,
214  initial_state,
215  transitions_copy,
216  seq_lengths=None
217  ):
218  # Build the RNN net and get the last timestep output
219  out_last = self.build_crf_net(
220  input_blob, initial_state, transitions_copy
221  )
222  out_last, _ = self.model.net.Reshape(
223  [out_last],
224  outputs=2,
225  shape=(self.num_classes_padded,)
226  )
227  zero_segment_id = self.model.param_init_net.ConstantFill(
228  [],
229  value=0,
230  shape=[self.num_classes_padded],
231  dtype=core.DataType.INT32,
232  )
233 
234  # Compute the accumlated total score of all the paths
235  accum_score = self.model.net.SortedSegmentRangeLogSumExp(
236  [out_last, zero_segment_id]
237  )
238  accum_score, _ = self.model.net.Reshape(
239  accum_score,
240  outputs=2,
241  shape=()
242  )
243  return accum_score
244 
245  def build_crf_net(self, input_blob, initial_state, transitions):
246  '''
247  Adds the crf_net recurrent operator to the model.
248 
249  model: model_helper.ModelHelper object new operators would be added
250  to
251 
252  input_blob: the input sequence in a format T x N x D
253  where T is sequence size, N - batch size and D - input dimention
254  ##Only supports batch-size 1##
255 
256  seq_lengths: blob containing sequence lengths (unused)
257  '''
258 
259  scope = 'crf_net'
260 
261  def s(name):
262  ''
263  # We have to manually scope due to our internal/external blob
264  # relationships.
265  return "{}/{}".format(str(scope), str(name))
266 
267  step_model = model_helper.ModelHelper(name='crf_step',
268  param_model=self.model)
269  input_t, cell_t_prev, _ = (
270  step_model.net.AddExternalInputs(
271  core.ScopedBlobReference('input_t'),
272  core.ScopedBlobReference('cell_t_prev'),
273  transitions
274  )
275  )
276  zero_segment_id = step_model.param_init_net.ConstantFill(
277  [],
278  [s('zero_segment_id')],
279  value=0,
280  shape=[self.num_classes_padded],
281  dtype=core.DataType.INT32,
282  )
283 
284  # A hack to bypass model cloning for test
285  step_model.param_init_net.AddExternalOutput(zero_segment_id)
286  """ the CRF step """
287  # Do tile
288  prev_transpose = brew.transpose(
289  step_model,
290  cell_t_prev,
291  [s('prev_transpose')],
292  axes=(0, 2, 1),
293  )
294  prev_tiled = step_model.net.Tile(
295  prev_transpose,
296  [s('prev_tiled')],
297  tiles=self.num_classes_padded,
298  axis=2,
299  )
300  input_t_tiled = step_model.net.Tile(
301  input_t,
302  [s('input_t_tiled')],
303  tiles=self.num_classes_padded,
304  axis=1,
305  )
306  input_with_prev = step_model.net.Add(
307  [prev_tiled, input_t_tiled],
308  [s('input_with_prev')]
309  )
310  all_with_transitions = step_model.net.Add(
311  [input_with_prev, transitions],
312  [s('prev_with_transitions')],
313  broadcast=1,
314  use_grad_hack=1,
315  )
316  all_with_transitions_reshaped, _ = step_model.net.Reshape(
317  all_with_transitions,
318  [s('all_with_transitions_reshaped'), s('all_with_transitions_orig')],
319  shape=(self.num_classes_padded, self.num_classes_padded)
320  )
321  cell_t = step_model.net.SortedSegmentRangeLogSumExp(
322  [all_with_transitions_reshaped, zero_segment_id],
323  [s('cell_t')],
324  )
325  step_model.net.AddExternalOutputs(cell_t)
326  """ recurrent network """
327  cell_input_blob = initial_state
328  out_all, out_last = recurrent.recurrent_net(
329  net=self.model.net,
330  cell_net=step_model.net,
331  inputs=[(input_t, input_blob)],
332  initial_cell_inputs=[
333  (cell_t_prev, cell_input_blob),
334  ],
335  links={
336  cell_t_prev: cell_t,
337  },
338  scope=scope,
339  outputs_with_grads=(1,)
340  )
341  return out_last
342 
343  def update_predictions(self, classes):
344 
345  def crf_update_predictions_op(inputs, outputs):
346  # This operator will compute the best path of classes by performing
347  # Viterbi decoding and then updates the predictions to make the tag
348  # On the best path has the highest score among the others
349  predictions = inputs[0].data
350  transitions = inputs[1].data
351  predictions = inputs[0].data
352  predictions_shape = inputs[0].shape
353  outputs[0].reshape(predictions_shape)
354 
355  trellis = np.zeros(predictions_shape)
356  backpointers = np.zeros(predictions_shape, dtype=np.int32)
357  trellis[0] = predictions[0]
358 
359  for t in range(1, predictions_shape[0]):
360  v = np.expand_dims(trellis[t - 1], 1) + transitions
361  trellis[t] = predictions[t] + np.max(v, 0)
362  backpointers[t] = np.argmax(v, 0)
363 
364  viterbi = [np.argmax(trellis[-1])]
365  for bp in reversed(backpointers[1:]):
366  viterbi.append(bp[viterbi[-1]])
367  viterbi.reverse()
368 
369  new_predictions = np.zeros(predictions_shape)
370  old_bests = []
371  for i, w_predictions in enumerate(predictions):
372  # Get the current tag with the maximum score
373  new_predictions[i] = predictions[i]
374  old_best = np.argmax(w_predictions)
375  old_bests.append(old_best)
376  # Swap the scores of the current best tag and the tag on the
377  # Viterbi path
378  w_predictions[viterbi[i]], w_predictions[old_best] = \
379  w_predictions[old_best], w_predictions[viterbi[i]]
380  new_predictions[i] = w_predictions
381  # Remove the BOS and EOS entries from the predictions matrix
382  orig_predictions = new_predictions[1:-1, 0:-2]
383  outputs[0].reshape(orig_predictions.shape)
384  outputs[0].data[...] = orig_predictions
385  padded_classes = self._pad_predictions(classes)
386  new_classes = self.model.net.Python(crf_update_predictions_op)(
387  [padded_classes, self.transitions],
388  core.ScopedBlobReference('post_crf_classes')
389  )
390  return new_classes
def build_crf_net(self, input_blob, initial_state, transitions)
Definition: crf.py:245
def _pad_predictions(self, predictions)
Definition: crf.py:104
def _path_binary_scores(self, labels, transitions, seq_lengths=None)
Definition: crf.py:164
def _crf_forward(self, input_blob, initial_state, transitions_copy, seq_lengths=None)
Definition: crf.py:217
def _gather_entries_sum(self, in_data, indices, index_size)
Definition: crf.py:193
def _pad_labels(self, labels)
Definition: crf.py:147