3 from __future__
import absolute_import, division, print_function, unicode_literals
10 Due to a limitation in ReccurentNetworkOp, this layer only supports batch_size=1 11 In order to support batch_size > 1, we will have to implement the CRFUnit 12 and its gradient in C++ and handle the different batches there. 17 def __init__(self, model, num_classes, transitions_blob=None):
21 if not transitions_blob:
22 transitions_blob = self.model.param_init_net.UniformFill(
24 [core.ScopedBlobReference(
"crf_transitions")],
32 def crf_loss(self, predictions, labels, seq_lengths=None):
36 transitions_snapshot = self.model.net.Copy(
37 self.
transitions, core.ScopedBlobReference(
"transitions_snapshot")
44 predictions = CRFWithLoss.pad_predictions(
45 predictions, self.model.param_init_net, self.model.net, self.
num_classes 47 labels = CRFWithLoss.pad_labels(
48 labels, self.model.param_init_net, self.model.net, self.
num_classes 52 labels, transitions_snapshot, seq_lengths
54 path_total_score = self.model.net.Add(
55 [path_binary_score, path_unary_score],
56 core.ScopedBlobReference(
"path_total"),
59 zero_index = self.model.param_init_net.ConstantFill([], shape=[1], value=0)
60 initial_state = self.model.net.Gather(
61 [predictions, zero_index],
62 core.ScopedBlobReference(
"rnn_initial"),
65 input_data, _ = self.model.net.RemovePadding(
66 [predictions], padding_width=1, end_padding_width=0, outputs=2
68 input_data = self.model.net.ExpandDims(
69 [input_data], core.ScopedBlobReference(
"rnn_input_data"), dims=[1]
73 transitions_copy = self.model.net.Copy(
74 transitions_snapshot, core.ScopedBlobReference(
"transitions_copy")
77 input_data, initial_state, transitions_copy
79 loss = self.model.net.Sub(
80 [all_paths_scores, path_total_score], core.ScopedBlobReference(
"crf_loss")
84 def _path_binary_scores(self, labels, transitions, seq_lengths=None):
85 column_ids, _ = self.model.net.RemovePadding(
86 [labels], outputs=2, padding_width=1, end_padding_width=0
88 row_ids, _ = self.model.net.RemovePadding(
89 [labels], outputs=2, padding_width=0, end_padding_width=1
94 num_columns_blob = self.model.net.ConstantFill(
97 flattened_ids = self.model.net.Mul([row_ids, num_columns_blob])
98 flattened_ids = self.model.net.Add([flattened_ids, column_ids])
99 flattened_transitions = self.model.net.FlattenToVec([transitions])
100 entries = self.model.net.Gather(
101 [flattened_transitions, flattened_ids], dense_gradient=
True 103 return self.model.ReduceFrontSum(entries)
105 def _gather_entries_sum(self, in_data, indices, index_size):
106 indices = self.model.net.Cast([indices], to=
"int64")
107 index_size_blob = self.model.param_init_net.ConstantFill(
108 [], shape=[1], value=index_size
110 query_one_hot = self.model.net.OneHot([indices, index_size_blob])
111 flattend_query = self.model.net.FlattenToVec(query_one_hot)
112 flattend_data = self.model.net.FlattenToVec(in_data)
113 query_scores = self.model.net.DotProduct([flattend_query, flattend_data])
114 final_sum = self.model.net.ReduceFrontSum([query_scores])
118 self, input_blob, initial_state, transitions_copy, seq_lengths=
None 121 out_last = self.
build_crf_net(input_blob, initial_state, transitions_copy)
122 out_last, _ = self.model.net.Reshape(
125 zero_segment_id = self.model.param_init_net.ConstantFill(
130 accum_score = self.model.net.SortedSegmentRangeLogSumExp(
131 [out_last, zero_segment_id]
133 accum_score, _ = self.model.net.Reshape(accum_score, outputs=2, shape=())
138 Adds the crf_net recurrent operator to the model. 140 model: model_helper.ModelHelper object new operators would be added 143 input_blob: the input sequence in a format T x N x D 144 where T is sequence size, N - batch size and D - input dimention 145 ##Only supports batch-size 1## 147 seq_lengths: blob containing sequence lengths (unused) 156 return "{}/{}".format(str(scope), str(name))
159 input_t, cell_t_prev, _ = step_model.net.AddExternalInputs(
160 core.ScopedBlobReference(
"input_t"),
161 core.ScopedBlobReference(
"cell_t_prev"),
164 zero_segment_id = step_model.param_init_net.ConstantFill(
166 [s(
"zero_segment_id")],
169 dtype=core.DataType.INT32,
173 step_model.param_init_net.AddExternalOutput(zero_segment_id)
176 prev_transpose = brew.transpose(
177 step_model, cell_t_prev, [s(
"prev_transpose")], axes=(0, 2, 1)
179 prev_tiled = step_model.net.Tile(
182 input_t_tiled = step_model.net.Tile(
185 input_with_prev = step_model.net.Add(
186 [prev_tiled, input_t_tiled], [s(
"input_with_prev")]
188 all_with_transitions = step_model.net.Add(
189 [input_with_prev, transitions],
190 [s(
"prev_with_transitions")],
194 all_with_transitions_reshaped, _ = step_model.net.Reshape(
195 all_with_transitions,
196 [s(
"all_with_transitions_reshaped"), s(
"all_with_transitions_orig")],
199 cell_t = step_model.net.SortedSegmentRangeLogSumExp(
200 [all_with_transitions_reshaped, zero_segment_id], [s(
"cell_t")]
202 step_model.net.AddExternalOutputs(cell_t)
203 """ recurrent network """ 204 cell_input_blob = initial_state
205 out_all, out_last = recurrent.recurrent_net(
207 cell_net=step_model.net,
208 inputs=[(input_t, input_blob)],
209 initial_cell_inputs=[(cell_t_prev, cell_input_blob)],
210 links={cell_t_prev: cell_t},
212 outputs_with_grads=(1,),
216 def update_predictions(self, classes):
217 def crf_update_predictions_op(inputs, outputs):
221 predictions = inputs[0].data
222 transitions = inputs[1].data
223 predictions = inputs[0].data
224 predictions_shape = inputs[0].shape
225 outputs[0].reshape(predictions_shape)
227 trellis = np.zeros(predictions_shape)
228 backpointers = np.zeros(predictions_shape, dtype=np.int32)
229 trellis[0] = predictions[0]
231 for t
in range(1, predictions_shape[0]):
232 v = np.expand_dims(trellis[t - 1], 1) + transitions
233 trellis[t] = predictions[t] + np.max(v, 0)
234 backpointers[t] = np.argmax(v, 0)
236 viterbi = [np.argmax(trellis[-1])]
237 for bp
in reversed(backpointers[1:]):
238 viterbi.append(bp[viterbi[-1]])
241 new_predictions = np.zeros(predictions_shape)
243 for i, w_predictions
in enumerate(predictions):
245 new_predictions[i] = predictions[i]
246 old_best = np.argmax(w_predictions)
247 old_bests.append(old_best)
250 w_predictions[viterbi[i]], w_predictions[old_best] = (
251 w_predictions[old_best],
252 w_predictions[viterbi[i]],
254 new_predictions[i] = w_predictions
256 orig_predictions = new_predictions[1:-1, 0:-2]
257 outputs[0].reshape(orig_predictions.shape)
258 outputs[0].data[...] = orig_predictions
260 padded_classes = CRFWithLoss.pad_predictions(
261 classes, self.model.param_init_net, self.model.net, self.
num_classes 263 new_classes = self.model.net.Python(crf_update_predictions_op)(
265 core.ScopedBlobReference(
"post_crf_classes"),
270 def pad_labels(labels, init_net, net, num_classes):
272 eos_i = num_classes + 1
273 bos_i_b = init_net.ConstantFill([], shape=[1], value=bos_i)
274 eos_i_b = init_net.ConstantFill([], shape=[1], value=eos_i)
275 labels = net.Cast([labels], to=
"int64")
276 padded_labels, _ = net.Concat([bos_i_b, labels, eos_i_b], axis=0, outputs=2)
280 def pad_predictions(predictions, init_net, net, num_classes):
286 b_scores = np.array([[low_score] * num_classes + [0, low_score]]).astype(
290 e_scores = np.array([[low_score] * num_classes + [low_score, 0]]).astype(
294 b_scores = init_net.GivenTensorFill(
295 [],
"b_scores", shape=[1, num_classes + 2], values=b_scores
297 e_scores = init_net.GivenTensorFill(
298 [],
"e_scores", shape=[1, num_classes + 2], values=e_scores
301 zero_index = net.ConstantFill([], shape=[1], value=0)
302 length = net.Gather([net.Shape([predictions]), zero_index])
303 length = net.Cast(length, to=
"int32")
304 t_range = net.LengthsRangeFill(length)
305 padding = net.ConstantFill([t_range], value=low_score)
306 padding = net.ExpandDims(padding, dims=[1])
307 padded_predictions, _ = net.Concat(
308 [predictions, padding, padding], outputs=2, axis=1
310 padded_predictions_concat, _ = net.Concat(
311 [b_scores, padded_predictions, e_scores], outputs=2, axis=0
313 return padded_predictions_concat
def build_crf_net(self, input_blob, initial_state, transitions)
def _path_binary_scores(self, labels, transitions, seq_lengths=None)
def _crf_forward(self, input_blob, initial_state, transitions_copy, seq_lengths=None)
def _gather_entries_sum(self, in_data, indices, index_size)