5 TensorRT related transformation 6 Note that ONNX-TRT enforce an NCHW input! 9 from __future__
import absolute_import
10 from __future__
import division
11 from __future__
import print_function
12 from __future__
import unicode_literals
14 from caffe2.proto
import caffe2_pb2
21 def _dim_values_to_list(dim_values):
22 return [x.dim_value
for x
in dim_values]
25 def _get_output_shapes(output_value_infos):
26 names = [x.name
for x
in output_value_infos]
27 shapes = [_dim_values_to_list(x.type.tensor_type.shape.dim)
for x
in output_value_infos]
28 return dict(zip(names, shapes))
# NOTE(review): the matching `try:` for this clause (presumably importing the
# CUDA-backed C extension, e.g. `workspace` / `C`) is not visible in this view
# of the file — TODO confirm against the full source.
except Exception as _:
    # Re-raise with a clearer, actionable message when the import fails.
    raise Exception(
        "TensorRT related functions require CUDA support")
def convert_onnx_model_to_trt_op(onnx_model,
                                 max_workspace_size=2*1024*1024,
                                 # NOTE(review): signature truncated in this
                                 # view — additional keyword parameters appear
                                 # to be missing between/after these lines;
                                 # TODO confirm against the full source.
    """
    Convert the whole ONNX model to a TensorRT C2 op.

    Serializes ``onnx_model``, hands it (with the graph's output shapes) to
    the C extension's ``onnx_to_trt_op``, and parses the result back into a
    ``caffe2_pb2.OperatorDef``.
    """
    trt_str = C.onnx_to_trt_op(onnx_model.SerializeToString(),
                               _get_output_shapes(onnx_model.graph.output),
                               # NOTE(review): remaining call arguments are
                               # not visible in this view — TODO confirm.
    op = caffe2_pb2.OperatorDef()
    op.ParseFromString(trt_str)
    # NOTE(review): the trailing `return op` is not visible in this view,
    # though callers presumably receive the parsed OperatorDef — confirm.
def _infer_shapes(pred_net, inputs):
    """Infer blob shapes by actually running ``pred_net`` once.

    NOTE(review): this view of the function is truncated — the accumulator
    dict, the inner loops over ``op.output`` / ``op.input``, the shape
    assignments, and the return statement are not visible here. The comments
    below describe only what the visible lines establish.
    """
    # Execute the whole net so every blob is materialized in the workspace.
    workspace.RunNetOnce(pred_net)
    for op in pred_net.op:
        # Fetch a blob and record its shape only when the fetched object
        # actually carries one (protobuf strings etc. do not).
        blob = workspace.FetchBlob(o)
        if hasattr(blob, 'shape'):
            blob = workspace.FetchBlob(i)
            if hasattr(blob, 'shape'):
def transform_caffe2_net(
        # NOTE(review): the leading positional parameters (the net itself,
        # the input-shape mapping, etc.) are not visible in this view —
        # TODO confirm against the full source.
        populate_shapes = False,
        max_workspace_size=2*1024*1024,
        build_serializable_op=True):
    """
    Transform the caffe2_net by collapsing TRT-runnable nodes into TRT C2 ops.

    Feeds random float32 data of the declared input shapes, infers shape
    hints by running the net, then asks the C extension to rewrite the
    serialized net. Returns the rewritten ``caffe2_pb2.NetDef``
    (NOTE(review): the trailing return is not visible here — confirm).
    """
    # Populate each declared input with random data so the net can be run
    # for shape inference. Shapes come from the caller-provided mapping.
    for k,v in input_shapes.items():
        input_data[k] = np.random.randn(*v).astype(np.float32)
    shape_hints = _infer_shapes(pred_net, input_data)
    for k,v in input_shapes.items():
        # NOTE(review): loop body not visible in this view.
    pred_net_str = C.transform_trt(pred_net.SerializeToString(),
                                   # NOTE(review): intermediate call
                                   # arguments not visible in this view.
                                   build_serializable_op)
    # Parse the transformed, serialized net back into a NetDef message.
    pred_net_cut = caffe2_pb2.NetDef()
    pred_net_cut.ParseFromString(pred_net_str)