Caffe2 - Python API
A deep learning, cross-platform ML framework
utils.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package utils
17 # Module caffe2.python.utils
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 from caffe2.proto import caffe2_pb2
23 from future.utils import viewitems
24 from google.protobuf.message import DecodeError, Message
25 from google.protobuf import text_format
26 import sys
27 import collections
28 import functools
29 import numpy as np
30 from six import integer_types, binary_type, text_type
31 
32 
def CaffeBlobToNumpyArray(blob):
    """Convert a Caffe blob into a float32 numpy array.

    A non-zero ``blob.num`` marks an old-style blob whose layout is given by
    the ``num``/``channels``/``height``/``width`` fields; otherwise the blob
    is new style and its dimensions come from ``blob.shape.dim``.
    """
    flat = np.asarray(blob.data, dtype=np.float32)
    if blob.num == 0:
        # New-style blob: arbitrary rank described by shape.dim.
        return flat.reshape(blob.shape.dim)
    # Old-style blob: fixed 4-D (num, channels, height, width) layout.
    return flat.reshape(blob.num, blob.channels, blob.height, blob.width)
42 
43 
def Caffe2TensorToNumpyArray(tensor):
    """Convert a Caffe2 ``TensorProto`` into a numpy array shaped ``tensor.dims``.

    FLOAT and DOUBLE values are read from ``float_data`` / ``double_data``.
    The integer types INT32, INT16, UINT16, INT8 and UINT8 are all read from
    ``int32_data`` and cast to the matching numpy dtype.

    Raises:
        RuntimeError: if ``tensor.data_type`` is not one of the types above.
    """
    proto = caffe2_pb2.TensorProto
    # data_type -> (repeated field holding the values, numpy dtype).
    # INT32 maps to the builtin ``int`` (numpy's default integer dtype), which
    # is exactly what the former ``np.int`` alias meant; ``np.int`` was
    # removed in NumPy 1.24 and raised AttributeError there.
    conversions = {
        proto.FLOAT: ('float_data', np.float32),
        proto.DOUBLE: ('double_data', np.float64),
        proto.INT32: ('int32_data', int),
        proto.INT16: ('int32_data', np.int16),
        proto.UINT16: ('int32_data', np.uint16),
        proto.INT8: ('int32_data', np.int8),
        proto.UINT8: ('int32_data', np.uint8),
    }
    conversion = conversions.get(tensor.data_type)
    if conversion is None:
        # TODO: complete the data type: bool, float16, byte, int64, string
        raise RuntimeError(
            "Tensor data type not supported yet: " + str(tensor.data_type))
    field, dtype = conversion
    return np.asarray(getattr(tensor, field), dtype=dtype).reshape(tensor.dims)
70 
71 
def NumpyArrayToCaffe2Tensor(arr, name=None):
    """Serialize a numpy array into a Caffe2 ``TensorProto``.

    Args:
        arr: numpy array to serialize; its dtype selects ``data_type``.
        name: optional tensor name, set only when truthy.

    Returns:
        A ``caffe2_pb2.TensorProto`` with ``dims``, ``data_type`` and values
        filled in. The mappings are:
        float32 -> FLOAT (``float_data``), float64 -> DOUBLE (``double_data``),
        default int / int32 -> INT32, and
        int16 / uint16 / int8 / uint8 -> INT16 / UINT16 / INT8 / UINT8.
        All integer types are stored in ``int32_data``.

    Raises:
        RuntimeError: if ``arr.dtype`` has no mapping.
    """
    tensor = caffe2_pb2.TensorProto()
    tensor.dims.extend(arr.shape)
    if name:
        tensor.name = name
    proto = caffe2_pb2.TensorProto
    # (accepted numpy dtypes, TensorProto data_type, repeated value field).
    # ``int`` is the builtin that ``np.int`` aliased; ``np.int`` was removed
    # in NumPy 1.24, where the old comparison raised AttributeError for every
    # non-float array.
    conversions = (
        ((np.float32,), proto.FLOAT, 'float_data'),
        ((np.float64,), proto.DOUBLE, 'double_data'),
        ((int, np.int32), proto.INT32, 'int32_data'),
        ((np.int16,), proto.INT16, 'int32_data'),
        ((np.uint16,), proto.UINT16, 'int32_data'),
        ((np.int8,), proto.INT8, 'int32_data'),
        ((np.uint8,), proto.UINT8, 'int32_data'),
    )
    for dtypes, data_type, field in conversions:
        if any(arr.dtype == dtype for dtype in dtypes):
            tensor.data_type = data_type
            # tolist() yields native Python floats/ints holding the same
            # values the previous per-element astype + list produced.
            getattr(tensor, field).extend(arr.ravel().tolist())
            return tensor
    # TODO: complete the data type: bool, float16, byte, int64, string
    raise RuntimeError(
        "Numpy data type not supported yet: " + str(arr.dtype))
103 
104 
def MakeArgument(key, value):
    """Makes an argument based on the value type.

    Builds a ``caffe2_pb2.Argument`` named ``key``. The field that receives
    ``value`` depends on its Python type:

      * float -> ``f``; int/long/bool -> ``i`` (bools are stored as ints)
      * bytes -> ``s``; text -> ``s`` (UTF-8 encoded)
      * ``caffe2_pb2.NetDef`` -> ``n``; any other protobuf Message -> ``s``
        (serialized)
      * iterables whose elements are all floats, all ints/bools, all
        bytes/text, all NetDefs, or all Messages -> ``floats``, ``ints``,
        ``strings``, ``nets`` or ``strings`` (serialized) respectively

    numpy arrays are flattened first (float32 arrays go straight to
    ``floats``). numpy scalars are first converted to the equivalent native
    Python value.

    Raises:
        ValueError: if ``value`` (or one of its elements) has no mapping.
    """
    # ``collections.Iterable`` was removed in Python 3.10; the ABC lives in
    # ``collections.abc`` (Python 3.3+). Fall back for Python 2.
    try:
        from collections.abc import Iterable
    except ImportError:
        from collections import Iterable

    argument = caffe2_pb2.Argument()
    argument.name = key
    iterable = isinstance(value, Iterable)

    # Fast tracking common use case where a float32 array of tensor parameters
    # needs to be serialized. The entire array is guaranteed to have the same
    # dtype, so no per-element checking necessary and no need to convert each
    # element separately.
    if isinstance(value, np.ndarray) and value.dtype.type is np.float32:
        argument.floats.extend(value.flatten().tolist())
        return argument

    if isinstance(value, np.ndarray):
        value = value.flatten().tolist()
    elif isinstance(value, np.generic):
        # Convert numpy scalar to native python type. ``.item()`` is what the
        # removed ``np.asscalar`` (gone in NumPy 1.23) called.
        value = value.item()

    if type(value) is float:
        argument.f = value
    elif type(value) in integer_types or type(value) is bool:
        # We make a relaxation that a boolean variable will also be stored as
        # int.
        argument.i = value
    elif isinstance(value, binary_type):
        argument.s = value
    elif isinstance(value, text_type):
        argument.s = value.encode('utf-8')
    elif isinstance(value, caffe2_pb2.NetDef):
        argument.n.CopyFrom(value)
    elif isinstance(value, Message):
        argument.s = value.SerializeToString()
    elif iterable:
        # Materialize once: each element check below walks the sequence,
        # which would silently exhaust a generator/iterator and leave the
        # repeated field empty.
        items = list(value)
        # np.float64 is the type the removed np.float_ alias referred to.
        if all(type(v) in (float, np.float64) for v in items):
            argument.floats.extend(
                v.item() if type(v) is np.float64 else v for v in items
            )
        elif all(
            type(v) in integer_types or type(v) in (bool, np.int_)
            for v in items
        ):
            argument.ints.extend(
                v.item() if type(v) is np.int_ else v for v in items
            )
        elif all(isinstance(v, (binary_type, text_type)) for v in items):
            argument.strings.extend(
                v.encode('utf-8') if isinstance(v, text_type) else v
                for v in items
            )
        elif all(isinstance(v, caffe2_pb2.NetDef) for v in items):
            argument.nets.extend(items)
        elif all(isinstance(v, Message) for v in items):
            argument.strings.extend(v.SerializeToString() for v in items)
        else:
            raise ValueError(
                "Unknown iterable argument type: key={} value={}, value "
                "type={}[{}]".format(
                    key, value, type(value), set(type(v) for v in items)
                )
            )
    else:
        raise ValueError(
            "Unknown argument type: key={} value={}, value type={}".format(
                key, value, type(value)
            )
        )
    return argument
175 
176 
def TryReadProtoWithClass(cls, s):
    """Reads a protobuffer with the given proto class.

    The text format is tried first; if that fails, ``s`` is parsed as the
    binary wire format.

    Inputs:
      cls: a protobuffer class.
      s: a string of either binary or text protobuffer content.

    Outputs:
      proto: the protobuffer of cls

    Throws:
      google.protobuf.message.DecodeError: if we cannot decode the message.
    """
    obj = cls()
    try:
        text_format.Parse(s, obj)
    except (text_format.ParseError, UnicodeDecodeError):
        # Binary protobuf content is usually not valid UTF-8. Some protobuf
        # versions report that from text_format as UnicodeDecodeError rather
        # than ParseError; either way, fall back to the binary format.
        # ParseFromString clears the message first, so fields left behind by
        # the failed text parse do not leak into the result.
        obj.ParseFromString(s)
    return obj
197 
198 
def GetContentFromProto(obj, function_map):
    """Gets a specific field from a protocol buffer that matches the given class

    Finds the ``function_map`` entry whose key is exactly ``type(obj)``
    (identity check, so subclasses do not match). Returns the result of
    calling that entry's function on ``obj``, or None when no key matches.
    """
    obj_type = type(obj)
    for proto_cls, extractor in viewitems(function_map):
        if proto_cls is obj_type:
            return extractor(obj)
    return None
205 
206 
def GetContentFromProtoString(s, function_map):
    """Decode ``s`` with the first proto class in ``function_map`` that fits.

    Entries are tried in the map's iteration order. For each one, ``s`` is
    read with TryReadProtoWithClass and the entry's function is called on the
    decoded message; that result is returned. A DecodeError raised by either
    step moves on to the next entry.

    Raises:
        DecodeError: if no entry succeeds.
    """
    for proto_cls, extractor in viewitems(function_map):
        try:
            return extractor(TryReadProtoWithClass(proto_cls, s))
        except DecodeError:
            pass
    raise DecodeError("Cannot find a fit protobuffer class.")
216 
217 
def ConvertProtoToBinary(proto_class, filename, out_filename):
    """Convert a text file of the given protobuf class to binary.

    Args:
        proto_class: protobuf message class used to parse ``filename``.
        filename: path of the input file (text format, or binary as a
            fallback via TryReadProtoWithClass).
        out_filename: path the serialized binary message is written to.
    """
    # Use a context manager so the input handle is closed deterministically.
    with open(filename) as fin:
        proto = TryReadProtoWithClass(proto_class, fin.read())
    # SerializeToString() returns bytes, so the output file must be opened
    # in binary mode; text mode 'w' raises TypeError on Python 3.
    with open(out_filename, 'wb') as fout:
        fout.write(proto.SerializeToString())
223 
224 
def GetGPUMemoryUsageStats():
    """Get GPU memory usage stats from CUDAContext. This requires flag
    --caffe2_gpu_memory_tracking to be enabled

    Runs a GetGPUMemoryUsage operator on CUDA device 0 and fetches the 2-D
    blob it writes. Row 0 is reported as 'total_by_gpu' and row 1 as
    'max_by_gpu'; 'total' and 'max_total' are the sums of those rows.
    """
    from caffe2.python import workspace, core
    output_blob = "____mem____"
    usage_op = core.CreateOperator(
        "GetGPUMemoryUsage",
        [],
        [output_blob],
        device_option=core.DeviceOption(caffe2_pb2.CUDA, 0),
    )
    workspace.RunOperatorOnce(usage_op)
    usage = workspace.FetchBlob(output_blob)
    totals_row, max_row = usage[0, :], usage[1, :]
    return {
        'total_by_gpu': totals_row,
        'max_by_gpu': max_row,
        'total': np.sum(totals_row),
        'max_total': np.sum(max_row),
    }
244 
245 
def ResetBlobs(blobs):
    """Run a CPU "Free" operator over ``blobs``.

    Args:
        blobs: iterable of blobs; each one is passed as both an input and an
            output of the Free operator.
    """
    from caffe2.python import workspace, core
    # Materialize once: calling list(blobs) twice would give an empty output
    # list when ``blobs`` is a one-shot iterator such as a generator.
    blob_list = list(blobs)
    workspace.RunOperatorOnce(
        core.CreateOperator(
            "Free",
            blob_list,
            list(blob_list),
            device_option=core.DeviceOption(caffe2_pb2.CPU),
        ),
    )
256 
257 
class DebugMode(object):
    '''
    This class drops you into an interactive debugger when your Python
    script raises an unhandled exception.

    Example of usage:

        def main():
            # your code here
            pass

        if __name__ == '__main__':
            from caffe2.python.utils import DebugMode
            DebugMode.run(main)
    '''

    @classmethod
    def run(cls, func):
        """Call ``func()`` and return its result.

        KeyboardInterrupt propagates unchanged. Any other Exception is
        printed, a pdb post-mortem session is opened on it, and the process
        then exits with status 1.
        """
        try:
            return func()
        except KeyboardInterrupt:
            raise
        except Exception:
            import pdb

            print(
                'Entering interactive debugger. Type "bt" to print '
                'the full stacktrace. Type "help" to see command listing.')
            print(sys.exc_info()[1])
            # Blank line before the pdb prompt. A bare ``print`` expression
            # is a no-op under print_function, so it must be called.
            print()

            pdb.post_mortem()
            # sys.exit raises SystemExit, so a trailing ``raise`` here would
            # be unreachable.
            sys.exit(1)
292 
def raiseIfNotEqual(a, b, msg):
    """Raise ``Exception("<msg>. <a> != <b>")`` when ``a != b``; else return None."""
    differs = a != b
    if not differs:
        return
    raise Exception("{}. {} != {}".format(msg, a, b))
296 
def debug(f):
    '''
    Use this method to decorate your function with DebugMode's functionality

    The decorated function is executed through DebugMode.run, and whatever
    it returns is passed back to the caller.

    Example:

        @debug
        def test_foo(self):
            raise Exception("Bar")

    '''

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        def func():
            return f(*args, **kwargs)
        # Propagate the result; previously the wrapper always returned None.
        return DebugMode.run(func)

    return wrapper