Caffe2 - Python API
A deep learning, cross-platform ML framework
rewrite_graph.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 from __future__ import absolute_import
17 from __future__ import division
18 from __future__ import print_function
19 from __future__ import unicode_literals
20 
21 import copy
22 from caffe2.proto import caffe2_pb2
23 from caffe2.python import core
24 
25 
def rewrite_init_net_simple(net):
    """Force every op of ``net`` onto the MKLDNN device, in place.

    Each op's ``device_option.device_type`` is overwritten with
    ``caffe2_pb2.MKLDNN``; no other field is touched and nothing is
    returned.

    Args:
        net: the (presumably initialization) ``NetDef`` proto to modify.
    """
    for dev_opt in (operator.device_option for operator in net.op):
        dev_opt.device_type = caffe2_pb2.MKLDNN
29 
def last_producer(ops, blob):
    """Return the index of the last op in ``ops`` that writes ``blob``.

    Args:
        ops: an iterable of ops, each exposing an ``output`` container of
            blob names (e.g. the repeated ``op`` field of a ``NetDef``).
        blob: the blob name to look for.

    Returns:
        The largest index ``i`` such that ``blob`` is in ``ops[i].output``.

    Raises:
        ValueError: if no op lists ``blob`` among its outputs.
    """
    indexed_ops = list(enumerate(ops))
    for i, op in reversed(indexed_ops):
        if blob in op.output:
            return i
    # Format the message eagerly: exceptions do not apply logging-style
    # "%s" arguments, so passing ``blob`` separately left it unformatted.
    raise ValueError(
        "Failed to find last producer of blob, {}".format(blob))
35 
36 
def rewrite_run_net_simple(net):
    """Rewrite the run net ``net`` in place so that it executes on MKLDNN.

    This is a deliberately simple rewrite that assumes the entire graph can
    be executed with MKL. It:

    * prepends a ``CopyCPUToMKL`` op that copies ``net.external_input[0]``
      into an MKL temporary named ``<blob>__MKL__``, and points the first
      input of ``net.op[0]`` at that temporary;
    * for *every* blob in ``net.external_output``, renames it to its MKL
      temporary in the outputs of its last producer and in the inputs of
      all ops after that producer, then appends a ``CopyMKLToCPU`` op that
      copies the temporary back to the original blob name;
    * merges an MKLDNN device option into every op (including the inserted
      copy ops) and clears each op's ``engine`` field.

    Only ``net.external_input[0]`` is copied; other external inputs, and
    any consumer of the input blob other than ``net.op[0].input[0]``, are
    left unchanged. Assumes ``net`` has at least one external input and at
    least one op.

    Args:
        net: the ``NetDef`` proto to rewrite. It is modified in place and
            nothing is returned.

    Raises:
        ValueError: if ``net.external_input[0]`` is not the first input of
            ``net.op[0]``, or if an external output has no producing op.
    """
    def mkl_tmp(name):
        # Name of the MKL-side temporary that shadows CPU blob ``name``.
        return "{}__MKL__".format(name)

    input_blob = net.external_input[0]
    if input_blob != net.op[0].input[0]:
        # ValueError (still an Exception subclass, so existing handlers keep
        # working) lets callers catch this failure specifically.
        raise ValueError(
            "Input blob: {} is not consumed by first op: {}".format(
                input_blob, net.op[0]))

    # Copy the CPU input into its MKL temporary and feed that to op 0.
    copy_input_op = core.CreateOperator(
        "CopyCPUToMKL", input_blob, mkl_tmp(input_blob))
    net.op[0].input[0] = mkl_tmp(input_blob)

    # One MKL -> CPU copy per external output, appended after all other ops.
    copy_output_ops = [
        core.CreateOperator("CopyMKLToCPU", mkl_tmp(output_blob), output_blob)
        for output_blob in net.external_output]

    for output_blob in net.external_output:
        mkl_blob = mkl_tmp(output_blob)
        producer_idx = last_producer(net.op, output_blob)
        producer = net.op[producer_idx]
        producer.output[:] = [
            mkl_blob if blob == output_blob else blob
            for blob in producer.output]
        # Any op after the last producer that reads this output must now
        # read the MKL temporary instead.
        for op in net.op[producer_idx + 1:]:
            op.input[:] = [
                mkl_blob if blob == output_blob else blob
                for blob in op.input]

    ops = [copy_input_op] + net.op[:] + copy_output_ops
    del net.op[:]
    net.op.extend(ops)

    # Build the MKLDNN device option once; MergeFrom copies its fields into
    # each op's own device_option.
    mkl_device = core.DeviceOption(device_type=caffe2_pb2.MKLDNN)
    for op in net.op:
        op.device_option.MergeFrom(mkl_device)
        op.engine = ""
76 
77 
def rewrite_model_helper_simple(model):
    """Return a deep copy of ``model`` whose nets are rewritten for MKL.

    The caller's ``model`` is not modified. On the copy, the parameter
    initialization net is passed to ``rewrite_init_net_simple`` and the
    main net to ``rewrite_run_net_simple``, both operating on the
    underlying protos in place.

    Args:
        model: a model helper exposing ``param_init_net`` and ``net``
            objects with a ``Proto()`` method.

    Returns:
        The rewritten copy of ``model``.
    """
    rewritten = copy.deepcopy(model)
    # Parameter initialization has to run on MKL too, not just the main net.
    rewrite_init_net_simple(rewritten.param_init_net.Proto())
    rewrite_run_net_simple(rewritten.net.Proto())
    return rewritten