Caffe2 - Python API
A deep learning, cross-platform ML framework
utils.py
1 # @package utils
2 # Module caffe2.python.utils
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.proto import caffe2_pb2
9 from caffe2.python.compatibility import container_abcs
10 from future.utils import viewitems
11 from google.protobuf.message import DecodeError, Message
12 from google.protobuf import text_format
13 
14 import sys
15 import copy
16 import functools
17 import numpy as np
18 from six import integer_types, binary_type, text_type, string_types
19 
20 OPTIMIZER_ITERATION_NAME = "optimizer_iteration"
21 ITERATION_MUTEX_NAME = "iteration_mutex"
22 
23 
def OpAlmostEqual(op_a, op_b, ignore_fields=None):
    '''
    Check whether two operator protos are identical except for the fields
    named in `ignore_fields` (a text field name or a list of them).
    '''
    fields = ignore_fields or []
    if not isinstance(fields, list):
        fields = [fields]

    assert all(isinstance(f, text_type) for f in fields), (
        'Expect each field is text type, but got {}'.format(fields))

    def strip(op):
        # Work on a deep copy so callers' protos are never mutated.
        stripped = copy.deepcopy(op)
        for field in fields:
            if stripped.HasField(field):
                stripped.ClearField(field)
        return stripped

    return strip(op_a) == strip(op_b)
45 
46 
def CaffeBlobToNumpyArray(blob):
    """Convert a Caffe blob proto (old or new style) to a float32 numpy array."""
    flat = np.asarray(blob.data, dtype=np.float32)
    if blob.num != 0:
        # Old-style blob: shape given by explicit num/channels/height/width.
        return flat.reshape(blob.num, blob.channels, blob.height, blob.width)
    # New-style blob: shape carried in blob.shape.dim.
    return flat.reshape(blob.shape.dim)
56 
57 
def Caffe2TensorToNumpyArray(tensor):
    """Convert a caffe2 TensorProto into a numpy array of matching dtype.

    FLOAT/DOUBLE are read from float_data/double_data; every integer type of
    32 bits or fewer is stored in int32_data on the proto side.

    Raises:
        RuntimeError: for any other (not yet supported) data type.
    """
    if tensor.data_type == caffe2_pb2.TensorProto.FLOAT:
        return np.asarray(
            tensor.float_data, dtype=np.float32).reshape(tensor.dims)
    elif tensor.data_type == caffe2_pb2.TensorProto.DOUBLE:
        return np.asarray(
            tensor.double_data, dtype=np.float64).reshape(tensor.dims)
    elif tensor.data_type == caffe2_pb2.TensorProto.INT32:
        # Bug fix: `np.int` (an alias of the builtin `int`) was removed in
        # NumPy 1.24; `int` is the exact drop-in replacement and keeps the
        # platform-default integer dtype the original produced.
        return np.asarray(
            tensor.int32_data, dtype=int).reshape(tensor.dims)  # pb.INT32 uses int32_data
    elif tensor.data_type == caffe2_pb2.TensorProto.INT16:
        return np.asarray(
            tensor.int32_data, dtype=np.int16).reshape(tensor.dims)  # pb.INT16 uses int32_data
    elif tensor.data_type == caffe2_pb2.TensorProto.UINT16:
        return np.asarray(
            tensor.int32_data, dtype=np.uint16).reshape(tensor.dims)  # pb.UINT16 uses int32_data
    elif tensor.data_type == caffe2_pb2.TensorProto.INT8:
        return np.asarray(
            tensor.int32_data, dtype=np.int8).reshape(tensor.dims)  # pb.INT8 uses int32_data
    elif tensor.data_type == caffe2_pb2.TensorProto.UINT8:
        return np.asarray(
            tensor.int32_data, dtype=np.uint8).reshape(tensor.dims)  # pb.UINT8 uses int32_data
    else:
        # TODO: complete the data type: bool, float16, byte, int64, string
        raise RuntimeError(
            "Tensor data type not supported yet: " + str(tensor.data_type))
84 
85 
def NumpyArrayToCaffe2Tensor(arr, name=None):
    """Convert a numpy array into a caffe2 TensorProto, optionally named.

    float32/float64 use float_data/double_data; all integer types of 32 bits
    or fewer are packed into int32_data (the proto has no narrower fields).

    Raises:
        RuntimeError: for dtypes that are not yet supported.
    """
    tensor = caffe2_pb2.TensorProto()
    tensor.dims.extend(arr.shape)
    if name:
        tensor.name = name
    if arr.dtype == np.float32:
        tensor.data_type = caffe2_pb2.TensorProto.FLOAT
        tensor.float_data.extend(list(arr.flatten().astype(float)))
    elif arr.dtype == np.float64:
        tensor.data_type = caffe2_pb2.TensorProto.DOUBLE
        tensor.double_data.extend(list(arr.flatten().astype(np.float64)))
    elif arr.dtype == int or arr.dtype == np.int32:
        # Bug fix: `np.int` was an alias for the builtin `int` and was
        # removed in NumPy 1.24; comparing/casting with `int` preserves the
        # original behavior on every NumPy version.
        tensor.data_type = caffe2_pb2.TensorProto.INT32
        tensor.int32_data.extend(arr.flatten().astype(int).tolist())
    elif arr.dtype == np.int16:
        tensor.data_type = caffe2_pb2.TensorProto.INT16
        tensor.int32_data.extend(list(arr.flatten().astype(np.int16)))  # np.int16=>pb.INT16 use int32_data
    elif arr.dtype == np.uint16:
        tensor.data_type = caffe2_pb2.TensorProto.UINT16
        tensor.int32_data.extend(list(arr.flatten().astype(np.uint16)))  # np.uint16=>pb.UINT16 use int32_data
    elif arr.dtype == np.int8:
        tensor.data_type = caffe2_pb2.TensorProto.INT8
        tensor.int32_data.extend(list(arr.flatten().astype(np.int8)))  # np.int8=>pb.INT8 use int32_data
    elif arr.dtype == np.uint8:
        tensor.data_type = caffe2_pb2.TensorProto.UINT8
        tensor.int32_data.extend(list(arr.flatten().astype(np.uint8)))  # np.uint8=>pb.UINT8 use int32_data
    else:
        # TODO: complete the data type: bool, float16, byte, int64, string
        raise RuntimeError(
            "Numpy data type not supported yet: " + str(arr.dtype))
    return tensor
117 
118 
def MakeArgument(key, value):
    """Makes an argument based on the value type.

    Scalars map to the f/i/s fields, homogeneous iterables to
    floats/ints/strings/nets, a NetDef to `n`, and any other protobuf Message
    is stored serialized in `s`.

    Raises:
        ValueError: when the value (or its element types) is unsupported.
    """
    argument = caffe2_pb2.Argument()
    argument.name = key
    iterable = isinstance(value, container_abcs.Iterable)

    # Fast tracking common use case where a float32 array of tensor parameters
    # needs to be serialized. The entire array is guaranteed to have the same
    # dtype, so no per-element checking necessary and no need to convert each
    # element separately.
    if isinstance(value, np.ndarray) and value.dtype.type is np.float32:
        argument.floats.extend(value.flatten().tolist())
        return argument

    if isinstance(value, np.ndarray):
        value = value.flatten().tolist()
    elif isinstance(value, np.generic):
        # Convert numpy scalar to native python type. Bug fix: `np.asscalar`
        # was removed in NumPy 1.23; `.item()` is the documented equivalent.
        value = value.item()

    if type(value) is float:
        argument.f = value
    elif type(value) in integer_types or type(value) is bool:
        # We make a relaxation that a boolean variable will also be stored as
        # int.
        argument.i = value
    elif isinstance(value, binary_type):
        argument.s = value
    elif isinstance(value, text_type):
        argument.s = value.encode('utf-8')
    elif isinstance(value, caffe2_pb2.NetDef):
        argument.n.CopyFrom(value)
    elif isinstance(value, Message):
        argument.s = value.SerializeToString()
    # Bug fix: `np.float_` was removed in NumPy 2.0; it was an alias of
    # `np.float64`, so the substitution is behavior-identical. `np.int_` is
    # still available and unchanged.
    elif iterable and all(type(v) in [float, np.float64] for v in value):
        argument.floats.extend(
            v.item() if type(v) is np.float64 else v for v in value
        )
    elif iterable and all(
        type(v) in integer_types or type(v) in [bool, np.int_] for v in value
    ):
        argument.ints.extend(
            v.item() if type(v) is np.int_ else v for v in value
        )
    elif iterable and all(
        isinstance(v, binary_type) or isinstance(v, text_type) for v in value
    ):
        argument.strings.extend(
            v.encode('utf-8') if isinstance(v, text_type) else v
            for v in value
        )
    elif iterable and all(isinstance(v, caffe2_pb2.NetDef) for v in value):
        argument.nets.extend(value)
    elif iterable and all(isinstance(v, Message) for v in value):
        argument.strings.extend(v.SerializeToString() for v in value)
    else:
        if iterable:
            raise ValueError(
                "Unknown iterable argument type: key={} value={}, value "
                "type={}[{}]".format(
                    key, value, type(value), set(type(v) for v in value)
                )
            )
        else:
            raise ValueError(
                "Unknown argument type: key={} value={}, value type={}".format(
                    key, value, type(value)
                )
            )
    return argument
189 
190 
def TryReadProtoWithClass(cls, s):
    """Reads a protobuffer with the given proto class.

    Inputs:
      cls: a protobuffer class.
      s: a string of either binary or text protobuffer content.

    Outputs:
      proto: the protobuffer of cls

    Throws:
      google.protobuf.message.DecodeError: if we cannot decode the message.
    """
    proto = cls()
    try:
        # Try the human-readable text format first.
        text_format.Parse(s, proto)
    except text_format.ParseError:
        # Fall back to the binary wire format; a failure here propagates
        # as DecodeError.
        proto.ParseFromString(s)
    return proto
211 
212 
def GetContentFromProto(obj, function_map):
    """Gets a specific field from a protocol buffer that matches the given class
    """
    # Classes hash/compare by identity, so a direct key lookup is equivalent
    # to scanning for `type(obj) is cls`. Returns None when no class matches.
    cls = type(obj)
    if cls in function_map:
        return function_map[cls](obj)
219 
220 
def GetContentFromProtoString(s, function_map):
    """Decode `s` with the first proto class in `function_map` that accepts it
    and return the matching function applied to the decoded proto.

    Raises DecodeError when no registered class can decode `s`.
    """
    for cls, func in viewitems(function_map):
        try:
            # Keep func() inside the try: a DecodeError raised by the
            # handler also moves on to the next candidate class.
            return func(TryReadProtoWithClass(cls, s))
        except DecodeError:
            continue
    raise DecodeError("Cannot find a fit protobuffer class.")
230 
231 
def ConvertProtoToBinary(proto_class, filename, out_filename):
    """Convert a text file of the given protobuf class to binary."""
    with open(filename) as f:
        proto = TryReadProtoWithClass(proto_class, f.read())
    # Bug fix: SerializeToString() returns bytes, so the output must be
    # opened in binary mode ('wb'); text mode raises TypeError on Python 3
    # and corrupts the data via newline translation on Windows.
    with open(out_filename, 'wb') as fid:
        fid.write(proto.SerializeToString())
238 
239 
def GetGPUMemoryUsageStats():
    """Get GPU memory usage stats from CUDAContext/HIPContext. This requires flag
    --caffe2_gpu_memory_tracking to be enabled"""
    from caffe2.python import workspace, core
    query_op = core.CreateOperator(
        "GetGPUMemoryUsage",
        [],
        ["____mem____"],
        device_option=core.DeviceOption(workspace.GpuDeviceType, 0),
    )
    workspace.RunOperatorOnce(query_op)
    stats = workspace.FetchBlob("____mem____")
    # Row 0 of the fetched blob holds per-GPU totals, row 1 per-GPU maxima.
    return {
        'total_by_gpu': stats[0, :],
        'max_by_gpu': stats[1, :],
        'total': np.sum(stats[0, :]),
        'max_total': np.sum(stats[1, :])
    }
259 
260 
def ResetBlobs(blobs):
    """Release the memory held by the given blobs via the CPU Free op."""
    from caffe2.python import workspace, core
    # Bug fix: materialize `blobs` once. The original called list(blobs)
    # twice, which exhausts a generator/iterator argument and would pass an
    # empty output list to the operator.
    blob_list = list(blobs)
    workspace.RunOperatorOnce(
        core.CreateOperator(
            "Free",
            blob_list,
            blob_list,
            device_option=core.DeviceOption(caffe2_pb2.CPU),
        ),
    )
271 
272 
class DebugMode(object):
    '''
    This class allows to drop you into an interactive debugger
    if there is an unhandled exception in your python script

    Example of usage:

    def main():
        # your code here
        pass

    if __name__ == '__main__':
        from caffe2.python.utils import DebugMode
        DebugMode.run(main)
    '''

    @classmethod
    def run(cls, func):
        """Run func() and return its result.

        On an unhandled exception (other than KeyboardInterrupt, which is
        re-raised so Ctrl-C still terminates the process) print the error,
        enter pdb post-mortem, then exit with status 1.
        """
        try:
            return func()
        except KeyboardInterrupt:
            raise
        except Exception:
            import pdb

            print(
                'Entering interactive debugger. Type "bt" to print '
                'the full stacktrace. Type "help" to see command listing.')
            print(sys.exc_info()[1])
            # Bug fix: the original had a bare `print` expression, which is a
            # no-op under `from __future__ import print_function`; call it to
            # emit the intended blank line.
            print()

            pdb.post_mortem()
            # sys.exit raises SystemExit; the original's trailing bare
            # `raise` after it was unreachable and has been removed.
            sys.exit(1)
307 
308 
def raiseIfNotEqual(a, b, msg):
    """Raise a plain Exception combining `msg` and the two values when
    `a` differs from `b`; otherwise do nothing."""
    if a != b:
        detail = "{}. {} != {}".format(msg, a, b)
        raise Exception(detail)
312 
313 
def debug(f):
    '''
    Use this method to decorate your function with DebugMode's functionality

    Example:

        @debug
        def test_foo(self):
            raise Exception("Bar")
    '''

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # Defer the call into DebugMode.run so unhandled exceptions drop
        # into the interactive debugger.
        return DebugMode.run(lambda: f(*args, **kwargs))

    return wrapper
333 
334 
def BuildUniqueMutexIter(
    init_net,
    net,
    iter=None,
    iter_mutex=None,
    iter_val=0
):
    '''
    Often, a mutex guarded iteration counter is needed. This function creates a
    mutex iter in the net uniquely (if the iter already existing, it does
    nothing)

    This function returns the iter blob
    '''
    from caffe2.python import core
    iter_name = OPTIMIZER_ITERATION_NAME if iter is None else iter
    mutex_name = ITERATION_MUTEX_NAME if iter_mutex is None else iter_mutex
    if init_net.BlobIsDefined(iter_name):
        # Counter already exists; just hand back a reference to it.
        return init_net.GetBlobRef(iter_name)
    # Add training operators. The counter and its mutex are pinned to CPU,
    # matching the original device scope.
    with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU)):
        iteration = init_net.ConstantFill(
            [],
            iter_name,
            shape=[1],
            value=iter_val,
            dtype=core.DataType.INT64,
        )
        mutex = init_net.CreateMutex([], [mutex_name])
        net.AtomicIter([mutex, iteration], [iteration])
    return iteration
367 
368 
def EnumClassKeyVals(cls):
    """Build a {KEY: value} dict from the all-caps string attributes of `cls`,
    asserting the values are unique so the class behaves like an enum."""
    # cls can only be derived from object
    assert type(cls) == type
    enum = {}
    for attr in dir(cls):
        # Enum attribute keys are all capitalized and values are strings.
        if attr != attr.upper():
            continue
        val = getattr(cls, attr)
        if not isinstance(val, string_types):
            continue
        assert val not in enum.values(), (
            "Failed to resolve {} as Enum: "
            "duplicate entries {}={}, {}={}".format(
                cls, attr, val, [key for key in enum if enum[key] == val][0], val
            )
        )
        enum[attr] = val
    return enum
386 
387 
def ArgsToDict(args):
    """
    Convert a list of arguments to a name, value dictionary. Assumes that
    each argument has a name. Otherwise, the argument is skipped.
    """
    result = {}
    for arg in args:
        if not arg.HasField("name"):
            continue
        # Find the first populated value field: a set optional field, or a
        # non-empty repeated field. None when neither is present.
        value = None
        for field in arg.DESCRIPTOR.fields:
            if field.name == "name":
                continue
            if field.label == field.LABEL_OPTIONAL and arg.HasField(field.name):
                value = getattr(arg, field.name)
                break
            if field.label == field.LABEL_REPEATED:
                repeated = getattr(arg, field.name)
                if len(repeated) > 0:
                    value = repeated
                    break
        result[arg.name] = value
    return result
411 
412 
def NHWC2NCHW(tensor):
    """Move the trailing channel axis of `tensor` to position 1 (NHWC -> NCHW)."""
    assert tensor.ndim >= 1
    last = tensor.ndim - 1
    # Permutation: batch axis, then channels, then the spatial axes in order.
    perm = [0, last] + list(range(1, last))
    return tensor.transpose(perm)
416 
417 
def NCHW2NHWC(tensor):
    """Move channel axis 1 of `tensor` to the last position (NCHW -> NHWC)."""
    assert tensor.ndim >= 2
    # Permutation: batch axis, the spatial axes in order, then channels last.
    perm = [0] + list(range(2, tensor.ndim)) + [1]
    return tensor.transpose(perm)