Caffe2 - Python API
A deep learning, cross platform ML framework
crf.py
## @package crf
# Module caffe2.python.crf
from __future__ import absolute_import, division, print_function, unicode_literals

import numpy as np
from caffe2.python import brew, core, model_helper, recurrent


"""
Due to a limitation in RecurrentNetworkOp, this layer only supports batch_size=1.
In order to support batch_size > 1, we will have to implement the CRFUnit
and its gradient in C++ and handle the different batches there.
"""


class CRFWithLoss(object):
    def __init__(self, model, num_classes, transitions_blob=None):
        self.model = model
        self.num_classes = num_classes
        self.num_classes_padded = num_classes + 2  # After adding BOS and EOS
        if not transitions_blob:
            transitions_blob = self.model.param_init_net.UniformFill(
                [],
                [core.ScopedBlobReference("crf_transitions")],
                shape=[self.num_classes_padded, self.num_classes_padded],
                min=-1.0,
                max=1.0,
            )
        self.transitions = transitions_blob
        self.model.params.append(self.transitions)

    def crf_loss(self, predictions, labels, seq_lengths=None):
        # Since the transitions matrix is a shared parameter, we need to
        # take a snapshot of it at the beginning since it can be updated
        # in between the operators that use it when doing parallel updates
        transitions_snapshot = self.model.net.Copy(
            self.transitions, core.ScopedBlobReference("transitions_snapshot")
        )
        # Compute best path unary score from the logits
        path_unary_score = self._gather_entries_sum(
            predictions, labels, self.num_classes
        )
        # Append BOS and EOS entries to the predictions and labels
        predictions = CRFWithLoss.pad_predictions(
            predictions, self.model.param_init_net, self.model.net, self.num_classes
        )
        labels = CRFWithLoss.pad_labels(
            labels, self.model.param_init_net, self.model.net, self.num_classes
        )
        # Compute best path binary scores from the transitions matrix
        path_binary_score = self._path_binary_scores(
            labels, transitions_snapshot, seq_lengths
        )
        path_total_score = self.model.net.Add(
            [path_binary_score, path_unary_score],
            core.ScopedBlobReference("path_total"),
        )
        # Compute the score over all paths
        zero_index = self.model.param_init_net.ConstantFill([], shape=[1], value=0)
        initial_state = self.model.net.Gather(
            [predictions, zero_index],
            core.ScopedBlobReference("rnn_initial"),
            dense_gradient=True,
        )
        input_data, _ = self.model.net.RemovePadding(
            [predictions], padding_width=1, end_padding_width=0, outputs=2
        )
        input_data = self.model.net.ExpandDims(
            [input_data], core.ScopedBlobReference("rnn_input_data"), dims=[1]
        )
        # Due to a bug in RecurrentNetworkGradientOp, we need to copy the
        # transitions blob before sending it to the recurrent network
        transitions_copy = self.model.net.Copy(
            transitions_snapshot, core.ScopedBlobReference("transitions_copy")
        )
        all_paths_scores = self._crf_forward(
            input_data, initial_state, transitions_copy
        )
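        # The loss is the negative log-likelihood of the gold sequence:
        # the log-sum-exp score over all paths (all_paths_scores) minus the
        # score of the gold path (path_total_score).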
        loss = self.model.net.Sub(
            [all_paths_scores, path_total_score], core.ScopedBlobReference("crf_loss")
        )
        return loss

    def _path_binary_scores(self, labels, transitions, seq_lengths=None):
        column_ids, _ = self.model.net.RemovePadding(
            [labels], outputs=2, padding_width=1, end_padding_width=0
        )
        row_ids, _ = self.model.net.RemovePadding(
            [labels], outputs=2, padding_width=0, end_padding_width=1
        )
        # Since there is no multi-dimensional gather, we flatten the matrix to
        # a 1-d vector and transform the ids to (row_ids * num_columns +
        # column_ids) and do the gather in 1-d
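        # For example, with a 5 x 5 padded transitions matrix, the entry at
        # (row=2, col=3) sits at flat index 2 * 5 + 3 = 13 after FlattenToVec.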
        num_columns_blob = self.model.net.ConstantFill(
            [row_ids], value=self.num_classes_padded
        )
        flattened_ids = self.model.net.Mul([row_ids, num_columns_blob])
        flattened_ids = self.model.net.Add([flattened_ids, column_ids])
        flattened_transitions = self.model.net.FlattenToVec([transitions])
        entries = self.model.net.Gather(
            [flattened_transitions, flattened_ids], dense_gradient=True
        )
        return self.model.ReduceFrontSum(entries)

    def _gather_entries_sum(self, in_data, indices, index_size):
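        # Computes sum_i in_data[i, indices[i]] by building a one-hot mask of
        # the indices and taking its dot product with the flattened data,
        # again because there is no multi-dimensional gather.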
        indices = self.model.net.Cast([indices], to="int64")
        index_size_blob = self.model.param_init_net.ConstantFill(
            [], shape=[1], value=index_size
        )
        query_one_hot = self.model.net.OneHot([indices, index_size_blob])
        flattened_query = self.model.net.FlattenToVec(query_one_hot)
        flattened_data = self.model.net.FlattenToVec(in_data)
        query_scores = self.model.net.DotProduct([flattened_query, flattened_data])
        final_sum = self.model.net.ReduceFrontSum([query_scores])
        return final_sum

    def _crf_forward(
        self, input_blob, initial_state, transitions_copy, seq_lengths=None
    ):
        # Build the RNN net and get the last timestep output
        out_last = self.build_crf_net(input_blob, initial_state, transitions_copy)
        out_last, _ = self.model.net.Reshape(
            [out_last], outputs=2, shape=(self.num_classes_padded,)
        )
        zero_segment_id = self.model.param_init_net.ConstantFill(
            [], value=0, shape=[self.num_classes_padded], dtype=core.DataType.INT32
        )

        # Compute the accumulated total score of all the paths
        accum_score = self.model.net.SortedSegmentRangeLogSumExp(
            [out_last, zero_segment_id]
        )
        accum_score, _ = self.model.net.Reshape(accum_score, outputs=2, shape=())
        return accum_score

    def build_crf_net(self, input_blob, initial_state, transitions):
        """
        Adds the crf_net recurrent operator to the model.

        input_blob: the input sequence in a format T x N x D
        where T is sequence size, N - batch size and D - input dimension
        ##Only supports batch-size 1##

        initial_state: the initial cell state (the BOS row of the padded
        predictions)

        transitions: the padded transitions matrix
        """

        scope = "crf_net"

        def s(name):
            # We have to manually scope due to our internal/external blob
            # relationships.
            return "{}/{}".format(str(scope), str(name))

        step_model = model_helper.ModelHelper(name="crf_step", param_model=self.model)
        input_t, cell_t_prev, _ = step_model.net.AddExternalInputs(
            core.ScopedBlobReference("input_t"),
            core.ScopedBlobReference("cell_t_prev"),
            transitions,
        )
        zero_segment_id = step_model.param_init_net.ConstantFill(
            [],
            [s("zero_segment_id")],
            value=0,
            shape=[self.num_classes_padded],
            dtype=core.DataType.INT32,
        )

        # A hack to bypass model cloning for test
        step_model.param_init_net.AddExternalOutput(zero_segment_id)
        """ the CRF step """
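        # One step of the forward algorithm in log space:
        #   cell_t[j] = input_t[j] + logsumexp_i(cell_t_prev[i] + transitions[i, j])
        # The tiling below builds the (i, j) grid of cell_t_prev[i] + input_t[j]
        # so the transitions matrix can be added with broadcasting.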
        # Do tile
        prev_transpose = brew.transpose(
            step_model, cell_t_prev, [s("prev_transpose")], axes=(0, 2, 1)
        )
        prev_tiled = step_model.net.Tile(
            prev_transpose, [s("prev_tiled")], tiles=self.num_classes_padded, axis=2
        )
        input_t_tiled = step_model.net.Tile(
            input_t, [s("input_t_tiled")], tiles=self.num_classes_padded, axis=1
        )
        input_with_prev = step_model.net.Add(
            [prev_tiled, input_t_tiled], [s("input_with_prev")]
        )
        all_with_transitions = step_model.net.Add(
            [input_with_prev, transitions],
            [s("prev_with_transitions")],
            broadcast=1,
            use_grad_hack=1,
        )
        all_with_transitions_reshaped, _ = step_model.net.Reshape(
            all_with_transitions,
            [s("all_with_transitions_reshaped"), s("all_with_transitions_orig")],
            shape=(self.num_classes_padded, self.num_classes_padded),
        )
        cell_t = step_model.net.SortedSegmentRangeLogSumExp(
            [all_with_transitions_reshaped, zero_segment_id], [s("cell_t")]
        )
        step_model.net.AddExternalOutputs(cell_t)
        """ recurrent network """
        cell_input_blob = initial_state
        out_all, out_last = recurrent.recurrent_net(
            net=self.model.net,
            cell_net=step_model.net,
            inputs=[(input_t, input_blob)],
            initial_cell_inputs=[(cell_t_prev, cell_input_blob)],
            links={cell_t_prev: cell_t},
            scope=scope,
            outputs_with_grads=(1,),
        )
        return out_last

    def update_predictions(self, classes):
        def crf_update_predictions_op(inputs, outputs):
            # This operator will compute the best path of classes by performing
            # Viterbi decoding and then update the predictions so that the tag
            # on the best path has the highest score among the others
            predictions = inputs[0].data
            transitions = inputs[1].data
            predictions_shape = inputs[0].shape
            outputs[0].reshape(predictions_shape)

            trellis = np.zeros(predictions_shape)
            backpointers = np.zeros(predictions_shape, dtype=np.int32)
            trellis[0] = predictions[0]

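            # Standard Viterbi recursion:
            #   trellis[t, j] = predictions[t, j] + max_i(trellis[t - 1, i] + transitions[i, j])
            # backpointers[t, j] records the argmax tag i for each j.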
            for t in range(1, predictions_shape[0]):
                v = np.expand_dims(trellis[t - 1], 1) + transitions
                trellis[t] = predictions[t] + np.max(v, 0)
                backpointers[t] = np.argmax(v, 0)

            viterbi = [np.argmax(trellis[-1])]
            for bp in reversed(backpointers[1:]):
                viterbi.append(bp[viterbi[-1]])
            viterbi.reverse()

            new_predictions = np.zeros(predictions_shape)
            old_bests = []
            for i, w_predictions in enumerate(predictions):
                # Get the current tag with the maximum score
                new_predictions[i] = predictions[i]
                old_best = np.argmax(w_predictions)
                old_bests.append(old_best)
                # Swap the scores of the current best tag and the tag on the
                # Viterbi path
                w_predictions[viterbi[i]], w_predictions[old_best] = (
                    w_predictions[old_best],
                    w_predictions[viterbi[i]],
                )
                new_predictions[i] = w_predictions
            # Remove the BOS and EOS entries from the predictions matrix
            orig_predictions = new_predictions[1:-1, 0:-2]
            outputs[0].reshape(orig_predictions.shape)
            outputs[0].data[...] = orig_predictions

        padded_classes = CRFWithLoss.pad_predictions(
            classes, self.model.param_init_net, self.model.net, self.num_classes
        )
        new_classes = self.model.net.Python(crf_update_predictions_op)(
            [padded_classes, self.transitions],
            core.ScopedBlobReference("post_crf_classes"),
        )
        return new_classes

    @staticmethod
    def pad_labels(labels, init_net, net, num_classes):
        bos_i = num_classes
        eos_i = num_classes + 1
        bos_i_b = init_net.ConstantFill([], shape=[1], value=bos_i)
        eos_i_b = init_net.ConstantFill([], shape=[1], value=eos_i)
        labels = net.Cast([labels], to="int64")
        padded_labels, _ = net.Concat([bos_i_b, labels, eos_i_b], axis=0, outputs=2)
        return padded_labels

    @staticmethod
    def pad_predictions(predictions, init_net, net, num_classes):
        # This function will introduce two labels for beginning of sequence
        # and end of sequence; it will make the necessary updates to the
        # predictions blob
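        # The input is T x num_classes; the output is (T + 2) x (num_classes + 2),
        # with a BOS row prepended, an EOS row appended, and two extra columns
        # for the BOS/EOS tags filled with very low scores.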

        low_score = -1000.0  # An arbitrary very low number
        b_scores = np.array([[low_score] * num_classes + [0, low_score]]).astype(
            np.float32
        )

        e_scores = np.array([[low_score] * num_classes + [low_score, 0]]).astype(
            np.float32
        )

        b_scores = init_net.GivenTensorFill(
            [], "b_scores", shape=[1, num_classes + 2], values=b_scores
        )
        e_scores = init_net.GivenTensorFill(
            [], "e_scores", shape=[1, num_classes + 2], values=e_scores
        )

        zero_index = net.ConstantFill([], shape=[1], value=0)
        length = net.Gather([net.Shape([predictions]), zero_index])
        length = net.Cast(length, to="int32")
        t_range = net.LengthsRangeFill(length)
        padding = net.ConstantFill([t_range], value=low_score)
        padding = net.ExpandDims(padding, dims=[1])
        padded_predictions, _ = net.Concat(
            [predictions, padding, padding], outputs=2, axis=1
        )
        padded_predictions_concat, _ = net.Concat(
            [b_scores, padded_predictions, e_scores], outputs=2, axis=0
        )
        return padded_predictions_concat
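Example usage (not part of crf.py): a minimal sketch of wiring CRFWithLoss into a tagging model, assuming batch size 1 as noted above. The blob names, the model name, and num_classes=10 are illustrative; "predictions" is assumed to hold unary scores of shape [seq_length, num_classes] and "labels" the integer gold tags of length seq_length.

from caffe2.python import model_helper
from caffe2.python.crf import CRFWithLoss

model = model_helper.ModelHelper(name="tagger")
# "predictions" and "labels" are assumed to be produced or fed elsewhere;
# they are declared here only as external inputs for illustration.
predictions, labels = model.net.AddExternalInputs("predictions", "labels")

crf = CRFWithLoss(model, num_classes=10)
loss = crf.crf_loss(predictions, labels)        # negative log-likelihood of the gold path
decoded = crf.update_predictions(predictions)   # predictions rescored along the Viterbi path

# Gradients can then be added in the usual way, e.g.
# model.AddGradientOperators([loss])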