Caffe2 - Python API
A deep learning, cross-platform ML framework
crf_predict.py
1 from __future__ import absolute_import, division, print_function, unicode_literals
2 
3 import numpy as np
4 from caffe2.python.crf import CRFWithLoss
5 
6 
def crf_update_predictions(model, crf_with_loss, classes):
    """Run ``apply_crf`` on ``classes`` using the nets owned by ``model``.

    Args:
        model: Object exposing ``param_init_net`` and ``net``; both are
            forwarded to ``apply_crf``.
        crf_with_loss: Object exposing ``transitions`` and ``num_classes``
            (presumably a ``CRFWithLoss`` instance -- confirm with callers).
        classes: Predictions to be rewritten by the CRF decoding step.

    Returns:
        Whatever ``apply_crf`` returns for these inputs.
    """
    return apply_crf(
        init_net=model.param_init_net,
        net=model.net,
        transitions=crf_with_loss.transitions,
        predictions=classes,
        num_classes=crf_with_loss.num_classes,
    )
15 
16 
def apply_crf(init_net, net, transitions, predictions, num_classes):
    """Add CRF best-path decoding operators to ``net`` and return the result.

    The predictions are padded via ``CRFWithLoss.pad_predictions``, passed
    through ``ViterbiPath`` and ``SwapBestPath``, and then trimmed again with
    ``RemovePadding`` followed by ``Slice``.

    Args:
        init_net: Initialization net, forwarded to ``pad_predictions``.
        net: Net on which all operators in this function are created.
        transitions: Transition input given to ``ViterbiPath``.
        predictions: Predictions to decode.
        num_classes: Number of classes, forwarded to ``pad_predictions``.

    Returns:
        The output of the final ``Slice`` operator created on ``net``.
    """
    padded = CRFWithLoss.pad_predictions(predictions, init_net, net, num_classes)
    viterbi_path = net.ViterbiPath([padded, transitions])
    swapped = net.SwapBestPath([padded, viterbi_path])

    # Undo the padding added by pad_predictions. RemovePadding is asked to
    # strip a width of 1 at the start and 1 at the end.
    unpadded = net.RemovePadding(
        [swapped], padding_width=1, end_padding_width=1
    )

    # Slice bounds are [0, 0] .. [-1, -3]. NOTE(review): presumably this keeps
    # every remaining row and drops the last two columns, assuming the Slice
    # operator treats negative ends as counted from the back with -1 meaning
    # "through the last element" -- confirm against the operator docs.
    start_blob = net.GivenTensorIntFill(
        [], shape=[2], values=np.array([0, 0], dtype=np.int32)
    )
    end_blob = net.GivenTensorIntFill(
        [], shape=[2], values=np.array([-1, -3], dtype=np.int32)
    )
    return net.Slice([unpadded, start_blob, end_blob])