from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import numpy as np

from caffe2.proto import caffe2_pb2
from caffe2.python import core, utils
def add_tensor(net, name, blob):
    ''' Create an operator to store the tensor 'blob',
        run the operator to put the blob to workspace.
        uint8 is stored as an array of string with one element.

        Args:
            net: net that receives the fill operator.
            name: name of the output blob the operator produces.
            blob: numpy array holding the tensor data; its dtype selects
                the fill operator type.

        Raises:
            KeyError: if blob.dtype has no entry in the operator mapping.
            AssertionError: if an object-dtype blob holds a non-bytes element.
    '''
    # Maps the numpy dtype of the blob to the fill operator that stores it.
    kTypeNameMapper = {
        np.dtype('float32'): "GivenTensorFill",
        np.dtype('int32'): "GivenTensorIntFill",
        np.dtype('int64'): "GivenTensorInt64Fill",
        np.dtype('uint8'): "GivenTensorByteStringToUInt8Fill",
        np.dtype('O'): "GivenTensorStringFill",
    }

    # NOTE(review): the initial assignments of `shape` and `values` were
    # missing from the extracted source; reconstructed from their use in the
    # MakeArgument calls below -- confirm against upstream.
    shape = blob.shape
    values = blob

    # uint8 data is passed as a single byte string (see docstring).
    if blob.dtype == np.dtype('uint8'):
        values = [blob.tobytes()]

    # Object arrays are only accepted when every element is a bytes string.
    if blob.dtype == np.dtype('O'):
        # NOTE(review): loop header reconstructed; only the assert on
        # `blob_val` was visible in the extracted source.
        for blob_val in blob:
            assert(isinstance(blob_val, bytes))

    # NOTE(review): the input/output lists and the `arg=` keyword are
    # reconstructed to mirror the visible CreateOperator call in Export.
    op = core.CreateOperator(
        kTypeNameMapper[blob.dtype],
        [], [name],
        arg=[
            utils.MakeArgument("shape", shape),
            utils.MakeArgument("values", values),
        ]
    )
    # NOTE(review): appending the operator to `net` is reconstructed from the
    # docstring and from how Export passes `init_net` -- confirm.
    net.op.extend([op])
def Export(workspace, net, params):
    """Returns init_net and predict_net suitable for writing to disk
       and loading into a Predictor

       Args:
           workspace: object providing FetchBlob(name) for parameter values.
           net: a caffe2_pb2.NetDef, or an object exposing Proto().
           params: iterable of parameter blob references; each is converted
               with str() to obtain the blob name.

       Returns:
           Tuple (init_net, predict_net) of caffe2_pb2.NetDef messages.
    """
    proto = net if isinstance(net, caffe2_pb2.NetDef) else net.Proto()
    predict_net = caffe2_pb2.NetDef()
    predict_net.CopyFrom(proto)
    init_net = caffe2_pb2.NetDef()

    ssa, blob_versions = core.get_ssa(net)

    # Every blob name consumed as an input by some operator. A set keeps the
    # membership test below O(1); only membership is ever queried.
    # NOTE(review): the initialisation of this collection was missing from
    # the extracted source.
    inputs = set()
    for versioned_inputs, _ in ssa:
        inputs.update(name for name, _ in versioned_inputs)

    # NOTE(review): the iterable in the two comprehensions below was missing
    # from the extracted source; `blob_versions.items()` is assumed from the
    # (blob_name, version) unpacking -- confirm against upstream.
    # Blobs still at version 0 that are not parameters.
    input_blobs = [blob_name for blob_name, version in
                   blob_versions.items()
                   if version == 0 and blob_name not in params]
    # Blobs with a non-zero version that no operator consumes as an input.
    output_blobs = [blob_name for blob_name, version in
                    blob_versions.items()
                    if version != 0 and blob_name not in inputs]

    # Store every parameter's current value in init_net.
    for blob_ref in params:
        blob_name = str(blob_ref)
        blob = workspace.FetchBlob(blob_name)
        add_tensor(init_net, blob_name, blob)

    # Give each non-parameter input a placeholder fill op of shape [1, 1].
    # NOTE(review): the `init_net.op.extend([...])` / `core.CreateOperator(`
    # wrapper and `arg=` keyword are reconstructed; only the operator
    # arguments were visible in the extracted source.
    for blob_name in input_blobs:
        init_net.op.extend(
            [
                core.CreateOperator(
                    "GivenTensorFill", [], [blob_name],
                    arg=[
                        utils.MakeArgument("shape", [1, 1]),
                        utils.MakeArgument("values", [0.0])
                    ]
                )
            ]
        )

    # Rebuild external_input: computed inputs first, then any original
    # external inputs not already listed, preserving their order.
    del predict_net.external_input[:]

    new_external_inputs = input_blobs
    for external_input in proto.external_input:
        if external_input not in new_external_inputs:
            new_external_inputs.append(external_input)

    predict_net.external_input.extend(new_external_inputs)

    # Replace external_output with the computed output blobs.
    del predict_net.external_output[:]
    predict_net.external_output.extend(output_blobs)
    return init_net, predict_net