from __future__ import absolute_import, division, print_function, unicode_literals

import numpy as np

from caffe2.python import brew, core, model_helper, recurrent
    def __init__(self, model, num_classes, transitions_blob=None):
        """Set up CRF parameters on top of `model`.

        Due to a limitation in RecurrentNetworkOp, this layer only supports
        batch_size=1. In order to support batch_size > 1, we will have to
        implement the CRFUnit and its gradient in C++ and handle the
        different batches there.

        model: model_helper.ModelHelper that new operators/params are added to
        num_classes: number of real classes; two extra entries (BOS, EOS)
            are appended internally
        transitions_blob: optional pre-initialized transitions matrix blob;
            when falsy, a uniformly initialized one is created
        """
        self.model = model
        self.num_classes = num_classes
        # +2 accounts for the synthetic BOS and EOS classes added by padding.
        self.num_classes_padded = num_classes + 2
        if not transitions_blob:
            transitions_blob = self.model.param_init_net.UniformFill(
                [],
                [core.ScopedBlobReference("crf_transitions")],
                shape=[self.num_classes_padded, self.num_classes_padded],
                min=-1.0,
                max=1.0,
            )
        self.transitions = transitions_blob
        # Register the transitions matrix as a learnable parameter.
        self.model.params.append(self.transitions)
    def crf_loss(self, predictions, labels, seq_lengths=None):
        """Add ops computing the CRF negative log-likelihood loss.

        predictions: T x num_classes unary scores (logits) for one sequence
        labels: length-T gold label indices
        seq_lengths: unused; kept for interface compatibility
        Returns a scalar loss blob (all-paths score minus gold-path score).
        """
        # Since the transitions matrix is a shared parameter, we need to
        # take a snapshot of it at the beginning since it can be updated
        # in between the operators that use it when doing parallel updates.
        transitions_snapshot = self.model.net.Copy(
            self.transitions, core.ScopedBlobReference("transitions_snapshot")
        )
        # Compute best path unary score from the logits
        path_unary_score = self._gather_entries_sum(
            predictions, labels, self.num_classes
        )
        # Append BOS and EOS entries to the predictions and labels
        predictions = CRFWithLoss.pad_predictions(
            predictions, self.model.param_init_net, self.model.net, self.num_classes
        )
        labels = CRFWithLoss.pad_labels(
            labels, self.model.param_init_net, self.model.net, self.num_classes
        )
        # Compute best path binary scores from the transitions matrix
        path_binary_score = self._path_binary_scores(
            labels, transitions_snapshot, seq_lengths
        )
        path_total_score = self.model.net.Add(
            [path_binary_score, path_unary_score],
            core.ScopedBlobReference("path_total"),
        )
        # Compute the score of all possible paths via the forward algorithm.
        zero_index = self.model.param_init_net.ConstantFill([], shape=[1], value=0)
        initial_state = self.model.net.Gather(
            [predictions, zero_index],
            core.ScopedBlobReference("rnn_initial"),
            dense_gradient=True,
        )
        input_data, _ = self.model.net.RemovePadding(
            [predictions], padding_width=1, end_padding_width=0, outputs=2
        )
        input_data = self.model.net.ExpandDims(
            [input_data], core.ScopedBlobReference("rnn_input_data"), dims=[1]
        )
        # NOTE(review): the recurrent network needs its own copy of the
        # transitions blob, so copy the snapshot before handing it over.
        transitions_copy = self.model.net.Copy(
            transitions_snapshot, core.ScopedBlobReference("transitions_copy")
        )
        all_paths_scores = self._crf_forward(
            input_data, initial_state, transitions_copy
        )
        loss = self.model.net.Sub(
            [all_paths_scores, path_total_score], core.ScopedBlobReference("crf_loss")
        )
        return loss
    def _path_binary_scores(self, labels, transitions, seq_lengths=None):
        """Sum the transition (binary) scores along the gold label path.

        labels: padded label sequence (BOS ... EOS)
        transitions: snapshot of the transitions matrix
        seq_lengths: unused; kept for interface compatibility
        """
        # Consecutive label pairs: (row_ids[t], column_ids[t]) = (y_t, y_{t+1}).
        column_ids, _ = self.model.net.RemovePadding(
            [labels], outputs=2, padding_width=1, end_padding_width=0
        )
        row_ids, _ = self.model.net.RemovePadding(
            [labels], outputs=2, padding_width=0, end_padding_width=1
        )
        # Since there is no multi-dimensional gather, flatten the matrix to a
        # 1-d vector and transform the ids to (row_ids * num_columns +
        # column_ids) so the lookup can be done with a 1-d Gather.
        num_columns_blob = self.model.net.ConstantFill(
            [row_ids], value=self.num_classes_padded
        )
        flattened_ids = self.model.net.Mul([row_ids, num_columns_blob])
        flattened_ids = self.model.net.Add([flattened_ids, column_ids])
        flattened_transitions = self.model.net.FlattenToVec([transitions])
        entries = self.model.net.Gather(
            [flattened_transitions, flattened_ids], dense_gradient=True
        )
        return self.model.ReduceFrontSum(entries)
   105     def _gather_entries_sum(self, in_data, indices, index_size):
   106         indices = self.model.net.Cast([indices], to=
"int64")
   107         index_size_blob = self.model.param_init_net.ConstantFill(
   108             [], shape=[1], value=index_size
   110         query_one_hot = self.model.net.OneHot([indices, index_size_blob])
   111         flattend_query = self.model.net.FlattenToVec(query_one_hot)
   112         flattend_data = self.model.net.FlattenToVec(in_data)
   113         query_scores = self.model.net.DotProduct([flattend_query, flattend_data])
   114         final_sum = self.model.net.ReduceFrontSum([query_scores])
   118         self, input_blob, initial_state, transitions_copy, seq_lengths=
None   121         out_last = self.
build_crf_net(input_blob, initial_state, transitions_copy)
   122         out_last, _ = self.model.net.Reshape(
   125         zero_segment_id = self.model.param_init_net.ConstantFill(
   130         accum_score = self.model.net.SortedSegmentRangeLogSumExp(
   131             [out_last, zero_segment_id]
   133         accum_score, _ = self.model.net.Reshape(accum_score, outputs=2, shape=())
   138             Adds the crf_net recurrent operator to the model.   140             model: model_helper.ModelHelper object new operators would be added   143             input_blob: the input sequence in a format T x N x D   144             where T is sequence size, N - batch size and D - input dimention   145             ##Only supports batch-size 1##   147             seq_lengths: blob containing sequence lengths (unused)   156             return "{}/{}".format(str(scope), str(name))
   159         input_t, cell_t_prev, _ = step_model.net.AddExternalInputs(
   160             core.ScopedBlobReference(
"input_t"),
   161             core.ScopedBlobReference(
"cell_t_prev"),
   164         zero_segment_id = step_model.param_init_net.ConstantFill(
   166             [s(
"zero_segment_id")],
   169             dtype=core.DataType.INT32,
   173         step_model.param_init_net.AddExternalOutput(zero_segment_id)
   176         prev_transpose = brew.transpose(
   177             step_model, cell_t_prev, [s(
"prev_transpose")], axes=(0, 2, 1)
   179         prev_tiled = step_model.net.Tile(
   182         input_t_tiled = step_model.net.Tile(
   185         input_with_prev = step_model.net.Add(
   186             [prev_tiled, input_t_tiled], [s(
"input_with_prev")]
   188         all_with_transitions = step_model.net.Add(
   189             [input_with_prev, transitions],
   190             [s(
"prev_with_transitions")],
   194         all_with_transitions_reshaped, _ = step_model.net.Reshape(
   195             all_with_transitions,
   196             [s(
"all_with_transitions_reshaped"), s(
"all_with_transitions_orig")],
   199         cell_t = step_model.net.SortedSegmentRangeLogSumExp(
   200             [all_with_transitions_reshaped, zero_segment_id], [s(
"cell_t")]
   202         step_model.net.AddExternalOutputs(cell_t)
   203         """ recurrent network """   204         cell_input_blob = initial_state
   205         out_all, out_last = recurrent.recurrent_net(
   207             cell_net=step_model.net,
   208             inputs=[(input_t, input_blob)],
   209             initial_cell_inputs=[(cell_t_prev, cell_input_blob)],
   210             links={cell_t_prev: cell_t},
   212             outputs_with_grads=(1,),
   216     def update_predictions(self, classes):
   217         def crf_update_predictions_op(inputs, outputs):
   221             predictions = inputs[0].data
   222             transitions = inputs[1].data
   223             predictions = inputs[0].data
   224             predictions_shape = inputs[0].shape
   225             outputs[0].reshape(predictions_shape)
   227             trellis = np.zeros(predictions_shape)
   228             backpointers = np.zeros(predictions_shape, dtype=np.int32)
   229             trellis[0] = predictions[0]
   231             for t 
in range(1, predictions_shape[0]):
   232                 v = np.expand_dims(trellis[t - 1], 1) + transitions
   233                 trellis[t] = predictions[t] + np.max(v, 0)
   234                 backpointers[t] = np.argmax(v, 0)
   236             viterbi = [np.argmax(trellis[-1])]
   237             for bp 
in reversed(backpointers[1:]):
   238                 viterbi.append(bp[viterbi[-1]])
   241             new_predictions = np.zeros(predictions_shape)
   243             for i, w_predictions 
in enumerate(predictions):
   245                 new_predictions[i] = predictions[i]
   246                 old_best = np.argmax(w_predictions)
   247                 old_bests.append(old_best)
   250                 w_predictions[viterbi[i]], w_predictions[old_best] = (
   251                     w_predictions[old_best],
   252                     w_predictions[viterbi[i]],
   254                 new_predictions[i] = w_predictions
   256             orig_predictions = new_predictions[1:-1, 0:-2]
   257             outputs[0].reshape(orig_predictions.shape)
   258             outputs[0].data[...] = orig_predictions
   260         padded_classes = CRFWithLoss.pad_predictions(
   261             classes, self.model.param_init_net, self.model.net, self.
num_classes   263         new_classes = self.model.net.Python(crf_update_predictions_op)(
   265             core.ScopedBlobReference(
"post_crf_classes"),
   270     def pad_labels(labels, init_net, net, num_classes):
   272         eos_i = num_classes + 1
   273         bos_i_b = init_net.ConstantFill([], shape=[1], value=bos_i)
   274         eos_i_b = init_net.ConstantFill([], shape=[1], value=eos_i)
   275         labels = net.Cast([labels], to=
"int64")
   276         padded_labels, _ = net.Concat([bos_i_b, labels, eos_i_b], axis=0, outputs=2)
   280     def pad_predictions(predictions, init_net, net, num_classes):
   286         b_scores = np.array([[low_score] * num_classes + [0, low_score]]).astype(
   290         e_scores = np.array([[low_score] * num_classes + [low_score, 0]]).astype(
   294         b_scores = init_net.GivenTensorFill(
   295             [], 
"b_scores", shape=[1, num_classes + 2], values=b_scores
   297         e_scores = init_net.GivenTensorFill(
   298             [], 
"e_scores", shape=[1, num_classes + 2], values=e_scores
   301         zero_index = net.ConstantFill([], shape=[1], value=0)
   302         length = net.Gather([net.Shape([predictions]), zero_index])
   303         length = net.Cast(length, to=
"int32")
   304         t_range = net.LengthsRangeFill(length)
   305         padding = net.ConstantFill([t_range], value=low_score)
   306         padding = net.ExpandDims(padding, dims=[1])
   307         padded_predictions, _ = net.Concat(
   308             [predictions, padding, padding], outputs=2, axis=1
   310         padded_predictions_concat, _ = net.Concat(
   311             [b_scores, padded_predictions, e_scores], outputs=2, axis=0
   313         return padded_predictions_concat
 def build_crf_net(self, input_blob, initial_state, transitions)
 
def _path_binary_scores(self, labels, transitions, seq_lengths=None)
 
def _crf_forward(self, input_blob, initial_state, transitions_copy, seq_lengths=None)
 
def _gather_entries_sum(self, in_data, indices, index_size)