from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.proto import caffe2_pb2
from future.utils import viewitems
from google.protobuf.message import DecodeError, Message
from google.protobuf import text_format

import copy
import functools
import sys

import numpy as np
from six import integer_types, binary_type, text_type, string_types

# Abstract container classes live in collections.abc on Python 3 and directly
# in collections on Python 2; MakeArgument below only needs `Iterable`.
try:
    import collections.abc as container_abcs
except ImportError:
    import collections as container_abcs


OPTIMIZER_ITERATION_NAME = "optimizer_iteration"
ITERATION_MUTEX_NAME = "iteration_mutex"


def OpAlmostEqual(op_a, op_b, ignore_fields=None):
    '''
    Two ops are identical except for each field in the `ignore_fields`.
    '''
    ignore_fields = ignore_fields or []
    if not isinstance(ignore_fields, list):
        ignore_fields = [ignore_fields]

    assert all(isinstance(f, text_type) for f in ignore_fields), (
        'Expect each field is text type, but got {}'.format(ignore_fields))

    def clean_op(op):
        # Work on a copy so the callers' protos are not mutated.
        op = copy.deepcopy(op)
        for field in ignore_fields:
            if op.HasField(field):
                op.ClearField(field)
        return op

    return clean_op(op_a) == clean_op(op_b)
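

# Illustrative usage (example only; assumes `from caffe2.python import core`):
#
#   op_a = core.CreateOperator("Relu", ["x"], ["y"])
#   op_b = core.CreateOperator("Relu", ["x"], ["y"],
#                              device_option=caffe2_pb2.DeviceOption())
#   OpAlmostEqual(op_a, op_b, ignore_fields='device_option')  # -> True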


def CaffeBlobToNumpyArray(blob):
    if (blob.num != 0):
        # old style caffe blob.
        return (np.asarray(blob.data, dtype=np.float32)
                .reshape(blob.num, blob.channels, blob.height, blob.width))
    else:
        # new style caffe blob.
        return (np.asarray(blob.data, dtype=np.float32)
                .reshape(blob.shape.dim))


def Caffe2TensorToNumpyArray(tensor):
    # Integer types narrower than 32 bits are all stored in int32_data.
    if tensor.data_type == caffe2_pb2.TensorProto.FLOAT:
        return np.asarray(
            tensor.float_data, dtype=np.float32).reshape(tensor.dims)
    elif tensor.data_type == caffe2_pb2.TensorProto.DOUBLE:
        return np.asarray(
            tensor.double_data, dtype=np.float64).reshape(tensor.dims)
    elif tensor.data_type == caffe2_pb2.TensorProto.INT32:
        return np.asarray(
            tensor.int32_data, dtype=np.int32).reshape(tensor.dims)
    elif tensor.data_type == caffe2_pb2.TensorProto.INT16:
        return np.asarray(
            tensor.int32_data, dtype=np.int16).reshape(tensor.dims)
    elif tensor.data_type == caffe2_pb2.TensorProto.UINT16:
        return np.asarray(
            tensor.int32_data, dtype=np.uint16).reshape(tensor.dims)
    elif tensor.data_type == caffe2_pb2.TensorProto.INT8:
        return np.asarray(
            tensor.int32_data, dtype=np.int8).reshape(tensor.dims)
    elif tensor.data_type == caffe2_pb2.TensorProto.UINT8:
        return np.asarray(
            tensor.int32_data, dtype=np.uint8).reshape(tensor.dims)
    else:
        raise RuntimeError(
            "Tensor data type not supported yet: " + str(tensor.data_type))


def NumpyArrayToCaffe2Tensor(arr, name=None):
    tensor = caffe2_pb2.TensorProto()
    tensor.dims.extend(arr.shape)
    if name:
        tensor.name = name
    if arr.dtype == np.float32:
        tensor.data_type = caffe2_pb2.TensorProto.FLOAT
        tensor.float_data.extend(list(arr.flatten().astype(float)))
    elif arr.dtype == np.float64:
        tensor.data_type = caffe2_pb2.TensorProto.DOUBLE
        tensor.double_data.extend(list(arr.flatten().astype(np.float64)))
    elif arr.dtype == np.int_ or arr.dtype == np.int32:
        tensor.data_type = caffe2_pb2.TensorProto.INT32
        tensor.int32_data.extend(arr.flatten().astype(np.int32).tolist())
    elif arr.dtype == np.int16:
        tensor.data_type = caffe2_pb2.TensorProto.INT16
        # Narrow integer types are stored in int32_data as well.
        tensor.int32_data.extend(list(arr.flatten().astype(np.int16)))
    elif arr.dtype == np.uint16:
        tensor.data_type = caffe2_pb2.TensorProto.UINT16
        tensor.int32_data.extend(list(arr.flatten().astype(np.uint16)))
    elif arr.dtype == np.int8:
        tensor.data_type = caffe2_pb2.TensorProto.INT8
        tensor.int32_data.extend(list(arr.flatten().astype(np.int8)))
    elif arr.dtype == np.uint8:
        tensor.data_type = caffe2_pb2.TensorProto.UINT8
        tensor.int32_data.extend(list(arr.flatten().astype(np.uint8)))
    else:
        raise RuntimeError(
            "Numpy data type not supported yet: " + str(arr.dtype))
    return tensor


def MakeArgument(key, value):
    """Makes an argument based on the value type."""
    argument = caffe2_pb2.Argument()
    argument.name = key
    iterable = isinstance(value, container_abcs.Iterable)

    # Fast path: a float32 ndarray of tensor parameters can be serialized
    # directly without per-element type checks.
    if isinstance(value, np.ndarray) and value.dtype.type is np.float32:
        argument.floats.extend(value.flatten().tolist())
        return argument

    if isinstance(value, np.ndarray):
        value = value.flatten().tolist()
    elif isinstance(value, np.generic):
        # Convert a numpy scalar to the corresponding native Python type.
        value = value.item()

    if type(value) is float:
        argument.f = value
    elif type(value) in integer_types or type(value) is bool:
        argument.i = value
    elif isinstance(value, binary_type):
        argument.s = value
    elif isinstance(value, text_type):
        argument.s = value.encode('utf-8')
    elif isinstance(value, caffe2_pb2.NetDef):
        argument.n.CopyFrom(value)
    elif isinstance(value, Message):
        argument.s = value.SerializeToString()
    elif iterable and all(type(v) in [float, np.float_] for v in value):
        argument.floats.extend(
            v.item() if type(v) is np.float_ else v for v in value
        )
    elif iterable and all(
        type(v) in integer_types or type(v) in [bool, np.int_] for v in value
    ):
        argument.ints.extend(
            v.item() if type(v) is np.int_ else v for v in value
        )
    elif iterable and all(
        isinstance(v, binary_type) or isinstance(v, text_type) for v in value
    ):
        argument.strings.extend(
            v.encode('utf-8') if isinstance(v, text_type) else v
            for v in value
        )
    elif iterable and all(isinstance(v, caffe2_pb2.NetDef) for v in value):
        argument.nets.extend(value)
    elif iterable and all(isinstance(v, Message) for v in value):
        argument.strings.extend(v.SerializeToString() for v in value)
    elif iterable:
        raise ValueError(
            "Unknown iterable argument type: key={} value={}, value "
            "type={}[{}]".format(
                key, value, type(value), set(type(v) for v in value)
            )
        )
    else:
        raise ValueError(
            "Unknown argument type: key={} value={}, value type={}".format(
                key, value, type(value)
            )
        )
    return argument
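

# Illustrative usage (example only): the Argument field that gets set depends
# on the Python type of the value.
#
#   MakeArgument("kernel", 3)           # sets argument.i
#   MakeArgument("scale", 0.5)          # sets argument.f
#   MakeArgument("order", "NCHW")       # sets argument.s (utf-8 encoded)
#   MakeArgument("strides", [1, 2, 2])  # sets argument.ints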


def TryReadProtoWithClass(cls, s):
    """Reads a protobuffer with the given proto class.

    Inputs:
      cls: a protobuffer class.
      s: a string of either binary or text protobuffer content.

    Outputs:
      proto: the protobuffer of cls

    Throws:
      google.protobuf.message.DecodeError: if we cannot decode the message.
    """
    obj = cls()
    try:
        text_format.Parse(s, obj)
    except text_format.ParseError:
        obj.ParseFromString(s)
    return obj
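

# Illustrative usage (example only): the same helper accepts either the text
# or the binary serialization of the proto.
#
#   net = TryReadProtoWithClass(caffe2_pb2.NetDef, 'name: "example_net"')
#   net.name  # -> "example_net"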


def GetContentFromProto(obj, function_map):
    """Gets a specific field from a protocol buffer that matches the given class.
    """
    for cls, func in viewitems(function_map):
        if type(obj) is cls:
            return func(obj)


def GetContentFromProtoString(s, function_map):
    for cls, func in viewitems(function_map):
        try:
            obj = TryReadProtoWithClass(cls, s)
            return func(obj)
        except DecodeError:
            continue
    raise DecodeError("Cannot find a fit protobuffer class.")


def ConvertProtoToBinary(proto_class, filename, out_filename):
    """Convert a text file of the given protobuf class to binary."""
    with open(filename) as f:
        proto = TryReadProtoWithClass(proto_class, f.read())
    # SerializeToString() returns bytes, so write the output in binary mode.
    with open(out_filename, 'wb') as fid:
        fid.write(proto.SerializeToString())
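

# Illustrative usage (example only; file names are hypothetical):
#
#   ConvertProtoToBinary(
#       caffe2_pb2.NetDef, "predict_net.pbtxt", "predict_net.pb")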


def GetGPUMemoryUsageStats():
    """Get GPU memory usage stats from CUDAContext/HIPContext. This requires
    the flag --caffe2_gpu_memory_tracking to be enabled."""
    # Imported lazily to avoid a circular dependency with caffe2.python.core.
    from caffe2.python import workspace, core
    workspace.RunOperatorOnce(
        core.CreateOperator(
            "GetGPUMemoryUsage",
            [],
            ["____mem____"],
            device_option=core.DeviceOption(workspace.GpuDeviceType, 0),
        ),
    )
    # The op writes a 2 x num_gpus tensor: row 0 holds the current totals and
    # row 1 the observed maxima, per GPU.
    b = workspace.FetchBlob("____mem____")
    return {
        'total_by_gpu': b[0, :],
        'max_by_gpu': b[1, :],
        'total': np.sum(b[0, :]),
        'max_total': np.sum(b[1, :]),
    }
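

# Illustrative usage (example only): the workspace must be initialized with
# --caffe2_gpu_memory_tracking for the stats to be populated.
#
#   stats = GetGPUMemoryUsageStats()
#   stats['total_by_gpu']  # currently allocated bytes, per GPU
#   stats['max_total']     # peak allocation summed over all GPUs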


def ResetBlobs(blobs):
    # The "Free" operator releases the memory held by the given blobs without
    # removing them from the workspace.
    from caffe2.python import workspace, core
    workspace.RunOperatorOnce(
        core.CreateOperator(
            "Free",
            list(blobs),
            list(blobs),
            device_option=core.DeviceOption(caffe2_pb2.CPU),
        ),
    )


class DebugMode(object):
    '''
    This class allows to drop you into an interactive debugger
    if there is an unhandled exception in your python script

    Example of usage:

    def main():
        # your code here
        pass

    if __name__ == '__main__':
        from caffe2.python.utils import DebugMode
        DebugMode.run(main)
    '''

    @classmethod
    def run(cls, func):
        try:
            return func()
        except KeyboardInterrupt:
            raise
        except Exception:
            import pdb

            print(
                'Entering interactive debugger. Type "bt" to print '
                'the full stacktrace. Type "help" to see command listing.')
            print(sys.exc_info()[1])
            pdb.post_mortem()
            sys.exit(1)


def raiseIfNotEqual(a, b, msg):
    if a != b:
        raise Exception("{}. {} != {}".format(msg, a, b))


def debug(f):
    '''
    Use this method to decorate your function with DebugMode's functionality

    Example:

    @debug
    def test_foo(self):
        raise Exception("Bar")
    '''

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        def func():
            return f(*args, **kwargs)
        return DebugMode.run(func)

    return wrapper


def BuildUniqueMutexIter(
    init_net,
    net,
    iter=None,
    iter_mutex=None,
    iter_val=0
):
    '''
    Often, a mutex guarded iteration counter is needed. This function creates a
    mutex iter in the net uniquely (if the iter already exists, it does
    nothing).

    This function returns the iter blob.
    '''
    iter = iter if iter is not None else OPTIMIZER_ITERATION_NAME
    iter_mutex = iter_mutex if iter_mutex is not None else ITERATION_MUTEX_NAME
    # Imported lazily to avoid a circular dependency with caffe2.python.core.
    from caffe2.python import core
    if not init_net.BlobIsDefined(iter):
        # The iteration counter lives on CPU regardless of the net's device.
        with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU)):
            iteration = init_net.ConstantFill(
                [],
                iter,
                shape=[1],
                value=iter_val,
                dtype=core.DataType.INT64,
            )
            iter_mutex = init_net.CreateMutex([], [iter_mutex])
            net.AtomicIter([iter_mutex, iteration], [iteration])
    else:
        iteration = init_net.GetBlobRef(iter)
    return iteration
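

# Illustrative usage (example only; assumes `from caffe2.python import core`
# and hypothetical net names):
#
#   init_net = core.Net("init")
#   train_net = core.Net("train")
#   iteration = BuildUniqueMutexIter(init_net, train_net)
#   # Each run of train_net now atomically increments the
#   # "optimizer_iteration" blob created by init_net.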


def EnumClassKeyVals(cls):
    # `cls` must be a plain class acting as an enum container.
    assert type(cls) == type
    # Enum attribute keys are expected to be all-uppercase and the values to
    # be strings; everything else on the class is ignored.
    enum = {}
    for k in dir(cls):
        if k == k.upper():
            v = getattr(cls, k)
            if isinstance(v, string_types):
                assert v not in enum.values(), (
                    "Failed to resolve {} as Enum: "
                    "duplicate entries {}={}, {}={}".format(
                        cls, k, v,
                        [key for key in enum if enum[key] == v][0], v
                    )
                )
                enum[k] = v
    return enum
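

# Illustrative usage (example only; the class below is hypothetical):
#
#   class Color(object):
#       RED = "red"
#       GREEN = "green"
#
#   EnumClassKeyVals(Color)  # -> {"RED": "red", "GREEN": "green"}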


def ArgsToDict(args):
    """
    Convert a list of arguments to a name, value dictionary. Assumes that
    each argument has a name. Otherwise, the argument is skipped.
    """
    ans = {}
    for arg in args:
        if not arg.HasField("name"):
            continue
        for d in arg.DESCRIPTOR.fields:
            if d.name == "name":
                continue
            if d.label == d.LABEL_OPTIONAL and arg.HasField(d.name):
                ans[arg.name] = getattr(arg, d.name)
                break
            elif d.label == d.LABEL_REPEATED:
                list_ = getattr(arg, d.name)
                if len(list_) > 0:
                    ans[arg.name] = list_
                    break
        else:
            # No value field was set for this argument.
            ans[arg.name] = None
    return ans
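

# Illustrative usage (example only): recover a {name: value} dict from the
# arguments of an OperatorDef built with MakeArgument.
#
#   op = caffe2_pb2.OperatorDef()
#   op.arg.extend([MakeArgument("kernel", 3), MakeArgument("pads", [1, 1])])
#   ArgsToDict(op.arg)  # -> {"kernel": 3, "pads": [1, 1]}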


def NHWC2NCHW(tensor):
    assert tensor.ndim >= 1
    return tensor.transpose(
        (0, tensor.ndim - 1) + tuple(range(1, tensor.ndim - 1)))


def NCHW2NHWC(tensor):
    assert tensor.ndim >= 2
    return tensor.transpose((0,) + tuple(range(2, tensor.ndim)) + (1,))
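

# Illustrative usage (example only): move the channel axis between the last
# and the second position.
#
#   x = np.zeros((8, 32, 32, 3))     # NHWC
#   NHWC2NCHW(x).shape               # -> (8, 3, 32, 32)
#   NCHW2NHWC(NHWC2NCHW(x)).shape    # -> (8, 32, 32, 3)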