Caffe2 - Python API
A deep learning, cross platform ML framework
workspace.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package workspace
17 # Module caffe2.python.workspace
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
import contextlib
import logging
import os
import shutil
import socket
import tempfile
from collections import defaultdict
from multiprocessing import Process

import numpy as np
from google.protobuf.message import Message
from past.builtins import basestring

from caffe2.proto import caffe2_pb2
from caffe2.python import scope, utils
import caffe2.python._import_c_extension as C
36 
38 
# Module-level logger for this file.
logger = logging.getLogger(__name__)

# Thin CamelCase aliases over the C extension's workspace API, so the rest
# of the Python frontend can call e.g. workspace.Blobs() without touching
# the pybind module directly.
Blobs = C.blobs
CreateBlob = C.create_blob
CurrentWorkspace = C.current_workspace
DeserializeBlob = C.deserialize_blob
GlobalInit = C.global_init
HasBlob = C.has_blob
RegisteredOperators = C.registered_operators
SerializeBlob = C.serialize_blob
SwitchWorkspace = C.switch_workspace
RootFolder = C.root_folder
Workspaces = C.workspaces
BenchmarkNet = C.benchmark_net
GetStats = C.get_stats

# Maps net name -> {op index -> creation-time python traceback}. Filled in
# by net-construction code elsewhere (not visible in this file) and read by
# CallWithExceptionIntercept below to report which op failed.
operator_tracebacks = defaultdict(dict)
56 
# Capability flags exported by the C extension build.
is_asan = C.is_asan
has_gpu_support = C.has_gpu_support
if has_gpu_support:
    NumCudaDevices = C.num_cuda_devices
    GetCUDAVersion = C.get_cuda_version
    GetCuDNNVersion = C.get_cudnn_version

    def GetCudaPeerAccessPattern():
        """Return the CUDA peer-access pattern as a numpy array."""
        return np.asarray(C.get_cuda_peer_access_pattern())

    GetDeviceProperties = C.get_device_properties
else:
    # CPU-only build: provide stubs with the same call signatures so callers
    # can probe GPU capabilities without guarding every call site.
    NumCudaDevices = lambda: 0  # noqa
    # BUG FIX: the original assigned GetCuDNNVersion twice and never defined
    # GetCUDAVersion, which made GetCUDAVersion a NameError on CPU builds.
    GetCUDAVersion = lambda: 0  # noqa
    GetCuDNNVersion = lambda: 0  # noqa
    GetCudaPeerAccessPattern = lambda: np.array([])  # noqa
    GetDeviceProperties = lambda x: None  # noqa
74 
75 
76 def _GetFreeFlaskPort():
77  """Get a free flask port."""
78  # We will prefer to use 5000. If not, we will then pick a random port.
79  sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
80  result = sock.connect_ex(('127.0.0.1', 5000))
81  if result == 0:
82  return 5000
83  else:
84  s = socket.socket()
85  s.bind(('', 0))
86  port = s.getsockname()[1]
87  s.close()
88  # Race condition: between the interval we close the socket and actually
89  # start a mint process, another process might have occupied the port. We
90  # don't do much here as this is mostly for convenience in research
91  # rather than 24x7 service.
92  return port
93 
94 
def StartMint(root_folder=None, port=None):
    """Start a mint instance.

    TODO(Yangqing): this does not work well under ipython yet. According to
    https://github.com/ipython/ipython/issues/5862
    writing up some fix is a todo item.
    """
    from caffe2.python.mint import app
    # Default to the active workspace's root folder and a free port.
    root_folder = C.root_folder() if root_folder is None else root_folder
    port = _GetFreeFlaskPort() if port is None else port
    mint_args = ['-p', str(port), '-r', root_folder]
    process = Process(target=app.main, args=(mint_args,))
    process.start()
    print('Mint running at http://{}:{}'.format(socket.getfqdn(), port))
    return process
117 
118 
def StringifyProto(obj):
    """Stringify a protocol buffer object.

    Inputs:
      obj: a protocol buffer object, or a Pycaffe2 object that has a Proto()
          function.
    Outputs:
      string: the output protobuf string.
    Raises:
      ValueError: if obj cannot be serialized (the original docstring
          incorrectly said AttributeError; the code raises ValueError).
    """
    if isinstance(obj, basestring):
        # Already serialized (or a plain name string): pass through.
        return obj
    if isinstance(obj, Message):
        # A protocol buffer, which we can simply serialize with the
        # SerializeToString() call.
        return obj.SerializeToString()
    if hasattr(obj, 'Proto'):
        # Caffe2 wrapper objects (Net, Plan, ...) expose their proto via
        # a Proto() accessor.
        return obj.Proto().SerializeToString()
    raise ValueError("Unexpected argument to StringifyProto of type " +
                     type(obj).__name__)
142 
143 
def ResetWorkspace(root_folder=None):
    """Reset the current workspace, optionally rooting it at `root_folder`.

    With no argument the workspace is cleared but keeps its current root
    folder; otherwise the given folder is created on demand and used.
    """
    if root_folder is None:
        return C.reset_workspace(C.root_folder())
    if not os.path.exists(root_folder):
        os.makedirs(root_folder)
    return C.reset_workspace(root_folder)
152 
153 
def CreateNet(net, overwrite=False, input_blobs=None):
    """Create `net` in the current workspace, pre-creating external inputs.

    On failure, prints the creation-time traceback recorded for the failing
    operator before re-raising.
    """
    for blob in (input_blobs or []):
        C.create_blob(blob)
    return CallWithExceptionIntercept(
        C.create_net,
        C.Workspace.current._last_failed_op_net_position,
        GetNetName(net),
        StringifyProto(net), overwrite,
    )
165 
166 
def Predictor(init_net, predict_net):
    """Construct a C++ Predictor from the given init and predict nets.

    Both nets may be protobufs, strings, or objects with a Proto() method.
    """
    return C.Predictor(StringifyProto(init_net), StringifyProto(predict_net))
169 
170 
def GetOperatorCost(operator, blobs):
    """Return the cost of running `operator` over `blobs`, as estimated by
    the C extension (exact cost structure defined on the C++ side)."""
    return C.get_operator_cost(StringifyProto(operator), blobs)
173 
174 
def RunOperatorOnce(operator):
    """Create and run a single operator; returns a success flag."""
    return C.run_operator_once(StringifyProto(operator))
177 
178 
def RunOperatorsOnce(operators):
    """Run each operator once, in order; stop at the first failure.

    Returns True if every operator succeeded, False otherwise. all() short-
    circuits exactly like the original explicit loop did.
    """
    return all(RunOperatorOnce(op) for op in operators)
185 
186 
def CallWithExceptionIntercept(func, op_id_fetcher, net_name, *args, **kwargs):
    """Invoke `func(*args, **kwargs)`; on any exception, print the recorded
    creation-time traceback of the operator that failed (if one was stored
    in `operator_tracebacks` for `net_name`) and re-raise."""
    try:
        return func(*args, **kwargs)
    except Exception:
        op_id = op_id_fetcher()
        print("Traceback for operator {} in network {}".format(op_id, net_name))
        tracebacks = operator_tracebacks.get(net_name)
        if tracebacks and op_id in tracebacks:
            for frame in tracebacks[op_id]:
                print(':'.join(map(str, frame)))
        raise
199 
200 
def RunNetOnce(net):
    """Run `net` a single time without creating it in the workspace first.

    On failure, prints the creation-time traceback recorded for the failing
    operator (if any) before re-raising.
    """
    return CallWithExceptionIntercept(
        C.run_net_once,
        C.Workspace.current._last_failed_op_net_position,
        GetNetName(net),
        StringifyProto(net),
    )
208 
209 
def RunNet(name, num_iter=1, allow_fail=False):
    """Runs a given net.

    Inputs:
      name: the name of the net, or a reference to the net.
      num_iter: number of iterations to run
      allow_fail: if True, does not assert on net exec failure but returns False
    Returns:
      True or an exception.
    """
    # Route through the intercept helper so that, on failure, the python
    # traceback recorded for the offending operator is printed.
    return CallWithExceptionIntercept(
        C.run_net,
        C.Workspace.current._last_failed_op_net_position,
        GetNetName(name),
        StringifyNetName(name), num_iter, allow_fail,
    )
226 
227 
def RunPlan(plan_or_step):
    """Run a Plan; a bare ExecutionStep is wrapped in a Plan first."""
    # TODO(jiayq): refactor core.py/workspace.py to avoid circular deps
    import caffe2.python.core as core
    if isinstance(plan_or_step, core.ExecutionStep):
        plan_or_step = core.Plan(plan_or_step)
    return C.run_plan(StringifyProto(plan_or_step))
234 
235 
def InferShapesAndTypes(nets, blob_dimensions=None):
    """Infers the shapes and types for the specified nets.

    Inputs:
      nets: the list of nets
      blob_dimensions (optional): a dictionary of blobs and their dimensions.
          If not specified, the workspace blobs are used.
    Returns:
      A tuple of (shapes, types) dictionaries keyed by blob name.
    """
    protos = [StringifyProto(net.Proto()) for net in nets]
    if blob_dimensions is None:
        serialized = C.infer_shapes_and_types_from_workspace(protos)
    else:
        serialized = C.infer_shapes_and_types_from_map(protos, blob_dimensions)
    tensor_shapes = caffe2_pb2.TensorShapes()
    tensor_shapes.ParseFromString(serialized)
    # Blobs whose shape could not be inferred are omitted from both maps.
    known = [ts for ts in tensor_shapes.shapes if not ts.unknown_shape]
    shapes = {ts.name: list(ts.dims) for ts in known}
    types = {ts.name: ts.data_type for ts in known}
    return (shapes, types)
263 
264 
def _StringifyName(name, expected_type):
    """Coerce `name` to a string.

    `name` must be a plain string, or an object whose class name matches
    `expected_type` (checked by name to avoid importing core here).
    """
    if isinstance(name, basestring):
        return name
    actual = type(name).__name__
    assert actual == expected_type, "Expected a string or %s" % expected_type
    return str(name)
271 
272 
def StringifyBlobName(name):
    """Return `name` as a string; accepts a str or a BlobReference."""
    return _StringifyName(name, "BlobReference")
275 
276 
def StringifyNetName(name):
    """Return `name` as a string; accepts a str or a core.Net."""
    return _StringifyName(name, "Net")
279 
280 
def GetNetName(net):
    """Return the name of `net`.

    Accepts a plain string (returned as-is), a core.Net wrapper (checked by
    class name to avoid a circular import), or a NetDef protobuf.

    Raises:
      ValueError: if `net` is none of the above.
    """
    if isinstance(net, basestring):
        return net
    if type(net).__name__ == "Net":
        return net.Name()
    if isinstance(net, caffe2_pb2.NetDef):
        return net.name
    # Raise a specific exception type rather than the bare Exception the
    # original used; callers catching Exception are unaffected.
    raise ValueError("Not a Net object: {}".format(str(net)))
289 
290 
def FeedBlob(name, arr, device_option=None):
    """Feeds a blob into the workspace.

    Inputs:
      name: the name of the blob.
      arr: either a TensorProto object or a numpy array object to be fed into
          the workspace.
      device_option (optional): the device option to feed the data with.
          Defaults to the current device scope, if one is active.
    Returns:
      True or False, stating whether the feed is successful.
    """
    if type(arr) is caffe2_pb2.TensorProto:
        arr = utils.Caffe2TensorToNumpyArray(arr)
    if type(arr) is np.ndarray and arr.dtype.kind in 'SU':
        # Plain NumPy strings are weird, let's use objects instead.
        # FIX: use the builtin `object` — `np.object` was only a deprecated
        # alias for it and is removed in numpy >= 1.24.
        arr = arr.astype(object)

    if device_option is None:
        device_option = scope.CurrentDeviceScope()

    if device_option and device_option.device_type == caffe2_pb2.CUDA:
        # Warn (but still feed) on float64 data bound for CUDA.
        if arr.dtype == np.dtype('float64'):
            logger.warning(
                "CUDA operators do not support 64-bit doubles, " +
                "please use arr.astype(np.float32) or np.int32 for ints." +
                " Blob: {}".format(name) +
                " type: {}".format(str(arr.dtype))
            )

    name = StringifyBlobName(name)
    if device_option is not None:
        return C.feed_blob(name, arr, StringifyProto(device_option))
    else:
        return C.feed_blob(name, arr)
325 
326 
def FetchBlobs(names):
    """Fetches a list of blobs from the workspace.

    Inputs:
      names: list of names of blobs - strings or BlobReferences
    Returns:
      list of fetched blobs
    """
    return list(map(FetchBlob, names))
336 
337 
def FetchBlob(name):
    """Fetches a blob from the workspace.

    Inputs:
      name: the name of the blob - a string or a BlobReference
    Returns:
      Fetched blob (numpy array or string) if successful
    """
    return C.fetch_blob(StringifyBlobName(name))
347 
348 
def ApplyTransform(transform_key, net):
    """Apply a Transform to a NetDef protobuf object, and returns the new
    transformed NetDef.

    Inputs:
      transform_key: the name of the transform, as it is stored in the registry
      net: a NetDef protobuf object
    Returns:
      Transformed NetDef protobuf object.
    """
    serialized = C.apply_transform(
        str(transform_key).encode('utf-8'),
        net.SerializeToString(),
    )
    result = caffe2_pb2.NetDef()
    result.ParseFromString(serialized)
    return result
366 
367 
def ApplyTransformIfFaster(transform_key, net, init_net, **kwargs):
    """Apply a Transform to a NetDef protobuf object, and returns the new
    transformed NetDef, only if it runs faster than the original.

    The runs are performed on the current active workspace (gWorkspace).
    You should initialize that workspace before making a call to this function.

    Inputs:
      transform_key: the name of the transform, as it is stored in the registry
      net: a NetDef protobuf object
      init_net: The net to initialize the workspace.
      warmup_runs (optional):
          Determines how many times the net is run before testing.
          Will be 5 by default.
      main_runs (optional):
          Determines how many times the net is run during testing.
          Will be 10 by default.
      improvement_threshold (optional):
          Determines the factor which the new net needs to be faster
          in order to replace the old. Will be 1.01 by default.

    Returns:
      Either a Transformed NetDef protobuf object, or the original netdef.
    """
    # dict.get replaces the repetitive `x if 'x' in kwargs else default`
    # pattern of the original; behavior is identical.
    warmup_runs = kwargs.get('warmup_runs', 5)
    main_runs = kwargs.get('main_runs', 10)
    improvement_threshold = kwargs.get('improvement_threshold', 1.01)

    transformed_net = caffe2_pb2.NetDef()
    transformed_str = C.apply_transform_if_faster(
        str(transform_key).encode('utf-8'),
        net.SerializeToString(),
        init_net.SerializeToString(),
        warmup_runs,
        main_runs,
        float(improvement_threshold),
    )
    transformed_net.ParseFromString(transformed_str)
    return transformed_net
409 
410 
def GetNameScope():
    """Return the current namescope string. To be used to fetch blobs
    whose names were created under a scope.NameScope context."""
    return scope.CurrentNameScope()
414 
415 
class _BlobDict(object):
    """Provides python dict compatible way to do fetching and feeding.

    Keys are blob names in the current workspace: reading a key fetches the
    blob's value, assigning to a key feeds the value into the workspace.
    """

    def __getitem__(self, key):
        # Fetch the blob's current content from the workspace.
        return FetchBlob(key)

    def __setitem__(self, key, value):
        # Feed `value` into the workspace under the name `key`.
        return FeedBlob(key, value)

    def __len__(self):
        return len(C.blobs())

    def __iter__(self):
        return C.blobs().__iter__()

    def __contains__(self, item):
        return C.has_blob(item)


# Module-level dict-like view over the current workspace's blobs.
blobs = _BlobDict()
436 
437 
438 ################################################################################
439 # Utilities for immediate mode
440 #
441 # Caffe2's immediate mode implements the following behavior: between the two
442 # function calls StartImmediate() and StopImmediate(), for any operator that is
443 # called through CreateOperator(), we will also run that operator in a workspace
444 # that is specific to the immediate mode. The user is explicitly expected to
445 # make sure that these ops have proper inputs and outputs, i.e. one should not
446 # run an op where an external input is not created or fed.
447 #
448 # Users can use FeedImmediate() and FetchImmediate() to interact with blobs
449 # in the immediate workspace.
450 #
451 # Once StopImmediate() is called, all contents in the immediate workspace is
452 # freed up so one can continue using normal runs.
453 #
454 # The immediate mode is solely for debugging purposes and support will be very
455 # sparse.
456 ################################################################################
457 
# Module-level state for immediate mode: whether it is active, the name of
# the dedicated workspace it runs in, and the temp folder backing that
# workspace (created by StartImmediate, removed by StopImmediate).
_immediate_mode = False
_immediate_workspace_name = "_CAFFE2_IMMEDIATE"
_immediate_root_folder = ''


def IsImmediate():
    """Return True if immediate mode is currently active."""
    return _immediate_mode
465 
466 
@contextlib.contextmanager
def WorkspaceGuard(workspace_name):
    """Context manager that temporarily switches to `workspace_name`
    (creating it if needed) and switches back on exit.

    FIX: the switch-back now runs in a finally block; previously an
    exception raised inside the `with` body would skip the restore and
    leave the process stuck in the guarded workspace.
    """
    current = CurrentWorkspace()
    SwitchWorkspace(workspace_name, True)
    try:
        yield
    finally:
        SwitchWorkspace(current)
473 
474 
def StartImmediate(i_know=False):
    """Turn on immediate mode, backed by a fresh temporary workspace.

    Per the banner above, between StartImmediate() and StopImmediate() any
    operator created through CreateOperator() is also run in the dedicated
    immediate workspace.

    Inputs:
      i_know: if True, suppress the long warning message printed below.
    """
    global _immediate_mode
    global _immediate_root_folder
    if IsImmediate():
        # already in immediate mode. We will kill the previous one
        # and start from fresh.
        StopImmediate()
    _immediate_mode = True
    with WorkspaceGuard(_immediate_workspace_name):
        # Back the immediate workspace with a temp directory; it is deleted
        # again by StopImmediate().
        _immediate_root_folder = tempfile.mkdtemp()
        ResetWorkspace(_immediate_root_folder)
    if i_know:
        # if the user doesn't want to see the warning message, sure...
        return
    print("""
    Enabling immediate mode in caffe2 python is an EXTREMELY EXPERIMENTAL
    feature and may very easily go wrong. This is because Caffe2 uses a
    declarative way of defining operators and models, which is essentially
    not meant to run things in an interactive way. Read the following carefully
    to make sure that you understand the caveats.

    (1) You need to make sure that the sequences of operators you create are
    actually runnable sequentially. For example, if you create an op that takes
    an input X, somewhere earlier you should have already created X.

    (2) Caffe2 immediate uses one single workspace, so if the set of operators
    you run are intended to be under different workspaces, they will not run.
    To create boundaries between such use cases, you can call FinishImmediate()
    and StartImmediate() manually to flush out everything no longer needed.

    (3) Underlying objects held by the immediate mode may interfere with your
    normal run. For example, if there is a leveldb that you opened in immediate
    mode and did not close, your main run will fail because leveldb does not
    support double opening. Immediate mode may also occupy a lot of memory esp.
    on GPUs. Call FinishImmediate() as soon as possible when you no longer
    need it.

    (4) Immediate is designed to be slow. Every immediate call implicitly
    creates a temp operator object, runs it, and destroys the operator. This
    slow-speed run is by design to discourage abuse. For most use cases other
    than debugging, do NOT turn on immediate mode.

    (5) If there is anything FATAL happening in the underlying C++ code, the
    immediate mode will immediately (pun intended) cause the runtime to crash.

    Thus you should use immediate mode with extra care. If you still would
    like to, have fun [https://xkcd.com/149/].
    """)
523 
524 
def StopImmediate():
    """Stops an immediate mode run, clearing the immediate workspace and
    deleting its backing temp folder. No-op if immediate mode is off."""
    # Phew, that was a dangerous ride.
    global _immediate_mode
    global _immediate_root_folder
    if not IsImmediate():
        return
    with WorkspaceGuard(_immediate_workspace_name):
        # Free all blobs held by the immediate workspace.
        ResetWorkspace()
    shutil.rmtree(_immediate_root_folder)
    _immediate_root_folder = ''
    _immediate_mode = False
537 
538 
def ImmediateBlobs():
    """Return the blob names in the immediate workspace."""
    with WorkspaceGuard(_immediate_workspace_name):
        return Blobs()
542 
543 
def RunOperatorImmediate(op):
    """Run a single operator inside the immediate workspace."""
    with WorkspaceGuard(_immediate_workspace_name):
        RunOperatorOnce(op)
547 
548 
def FetchImmediate(*args, **kwargs):
    """FetchBlob, but against the immediate workspace."""
    with WorkspaceGuard(_immediate_workspace_name):
        return FetchBlob(*args, **kwargs)
552 
553 
def FeedImmediate(*args, **kwargs):
    """FeedBlob, but against the immediate workspace."""
    with WorkspaceGuard(_immediate_workspace_name):
        return FeedBlob(*args, **kwargs)
557 
558 
559 # CWorkspace utilities
560 
def _Workspace_create_net_with_exception_intercept(ws, net, overwrite=False):
    """Workspace.create_net wrapper: on failure, prints the creation-time
    traceback recorded for the offending operator before re-raising."""
    return CallWithExceptionIntercept(
        ws._create_net,
        ws._last_failed_op_net_position,
        GetNetName(net),
        StringifyProto(net), overwrite,
    )


# Install as the public create_net on the pybind Workspace class.
C.Workspace.create_net = _Workspace_create_net_with_exception_intercept
571 
572 
def _Workspace_run(ws, obj):
    """Run a plan, net, or operator on workspace `ws`.

    Inputs:
      ws: a C.Workspace instance.
      obj: a PlanDef, NetDef, or OperatorDef protobuf, or any object that
          exposes one via a Proto() method.
    Returns:
      the result of the corresponding underlying _run_* call.
    Raises:
      ValueError: if obj is not one of the supported types.
    """
    if hasattr(obj, 'Proto'):
        obj = obj.Proto()
    if isinstance(obj, caffe2_pb2.PlanDef):
        return ws._run_plan(obj.SerializeToString())
    if isinstance(obj, caffe2_pb2.NetDef):
        # Route net runs through the intercept helper so the creation-time
        # traceback of a failing operator is printed before re-raising.
        # (Removed a stale commented-out direct _run_net call here.)
        return CallWithExceptionIntercept(
            ws._run_net,
            ws._last_failed_op_net_position,
            GetNetName(obj),
            obj.SerializeToString(),
        )
    if isinstance(obj, caffe2_pb2.OperatorDef):
        return ws._run_operator(obj.SerializeToString())
    raise ValueError(
        "Don't know how to do Workspace.run() on {}".format(type(obj)))


C.Workspace.run = _Workspace_run
593 
594 
def _Blob_feed(blob, arg, device_option=None):
    """Blob.feed implementation: feed `arg` into `blob`, serializing the
    device option (if given) before handing off to the C extension."""
    serialized_option = None
    if device_option is not None:
        serialized_option = StringifyProto(device_option)
    return blob._feed(arg, serialized_option)


# Install as the public feed method on the pybind Blob class.
C.Blob.feed = _Blob_feed