Caffe2 - Python API
A deep learning, cross platform ML framework
transform.py
## @package onnx
# Module caffe2.python.trt.transform

"""
TensorRT-related transformations.
Note that ONNX-TRT enforces an NCHW input!
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.proto import caffe2_pb2
from caffe2.python.onnx.helper import c2_native_run_net, c2_native_run_op
from caffe2.python import core, workspace
import caffe2.python.onnx.frontend as c2_front
# C extension exposing onnx_to_trt_op and transform_trt (referenced as C below)
import caffe2.python._import_c_extension as C
import numpy as np

# Extract the integer sizes from a repeated TensorShapeProto dimension field.
def _dim_values_to_list(dim_values):
    return [x.dim_value for x in dim_values]


# Map each graph output name to its shape, read from the ONNX value infos.
def _get_output_shapes(output_value_infos):
    names = [x.name for x in output_value_infos]
    shapes = [_dim_values_to_list(x.type.tensor_type.shape.dim) for x in output_value_infos]
    return dict(zip(names, shapes))

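# For example, if a graph's only output "prob" (name illustrative) has type
# float32 with shape [1, 1000], _get_output_shapes returns {"prob": [1, 1000]}.
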

def check_gpu_():
    try:
        C.get_cuda_version()
    except Exception as _:
        raise Exception("TensorRT related functions require CUDA support")

def convert_onnx_model_to_trt_op(onnx_model,
                                 max_batch_size=64,
                                 max_workspace_size=2*1024*1024,
                                 verbosity=1,
                                 debug_builder=False):
    """
    Convert the whole ONNX model to a TensorRT C2 op
    """
    check_gpu_()
    trt_str = C.onnx_to_trt_op(onnx_model.SerializeToString(),
                               _get_output_shapes(onnx_model.graph.output),
                               max_batch_size,
                               max_workspace_size,
                               verbosity,
                               debug_builder)
    op = caffe2_pb2.OperatorDef()
    op.ParseFromString(trt_str)
    return op


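# Example usage (a minimal sketch; the model path and the "data" blob name are
# illustrative assumptions and depend on the actual ONNX graph):
#
#   import onnx
#   model = onnx.load("model.onnx")
#   trt_op = convert_onnx_model_to_trt_op(model, max_batch_size=32)
#
#   net = caffe2_pb2.NetDef()
#   net.op.extend([trt_op])
#   workspace.FeedBlob("data", np.random.randn(32, 3, 224, 224).astype(np.float32))
#   workspace.RunNetOnce(net)
#   result = workspace.FetchBlob(trt_op.output[0])
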
# Assume the workspace is already filled with init weights
def _infer_shapes(pred_net, inputs):
    workspace.RunNetOnce(pred_net)
    hints = {}
    for op in pred_net.op:
        for o in op.output:
            if o not in hints:
                blob = workspace.FetchBlob(o)
                if hasattr(blob, 'shape'):
                    hints[o] = blob.shape
        for i in op.input:
            if i not in hints:
                blob = workspace.FetchBlob(i)
                if hasattr(blob, 'shape'):
                    hints[i] = blob.shape

    return hints


def transform_caffe2_net(
        pred_net,
        input_shapes,
        populate_shapes=False,
        max_batch_size=64,
        max_workspace_size=2*1024*1024,
        verbosity=1,
        debug_builder=False,
        build_serializable_op=True):
    """
    Transform the caffe2_net by collapsing TRT-runnable nodes into trt c2 ops
    """
    check_gpu_()

    # Hacky way to infer shapes, as not all of our operators have a shape
    # inference function. Normally this is not needed.
    shape_hints = {}
    if populate_shapes:
        input_data = {}
        for k, v in input_shapes.items():
            input_data[k] = np.random.randn(*v).astype(np.float32)
        shape_hints = _infer_shapes(pred_net, input_data)

    for k, v in input_shapes.items():
        shape_hints[k] = v
    pred_net_str = C.transform_trt(pred_net.SerializeToString(),
                                   shape_hints,
                                   max_batch_size,
                                   max_workspace_size,
                                   verbosity,
                                   debug_builder,
                                   build_serializable_op)
    pred_net_cut = caffe2_pb2.NetDef()
    pred_net_cut.ParseFromString(pred_net_str)
    return pred_net_cut
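
# Example usage (a minimal sketch; the file names, the "data" blob name and the
# NCHW input shape are illustrative assumptions):
#
#   init_net = caffe2_pb2.NetDef()
#   with open("init_net.pb", "rb") as f:
#       init_net.ParseFromString(f.read())
#   pred_net = caffe2_pb2.NetDef()
#   with open("predict_net.pb", "rb") as f:
#       pred_net.ParseFromString(f.read())
#
#   workspace.RunNetOnce(init_net)  # fill the workspace with init weights
#   trt_net = transform_caffe2_net(pred_net,
#                                  input_shapes={"data": (1, 3, 224, 224)},
#                                  max_batch_size=1)
#
#   # trt_net is a regular NetDef whose TRT-runnable subgraphs have been
#   # collapsed into TensorRT ops; run it like any other Caffe2 net.
#   workspace.FeedBlob("data", np.zeros((1, 3, 224, 224), dtype=np.float32))
#   workspace.CreateNet(trt_net)
#   workspace.RunNet(trt_net.name)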