from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import collections
import contextlib
import logging
import os
import shutil
import socket
import tempfile
from collections import defaultdict
from multiprocessing import Process

import numpy as np
from google.protobuf.message import Message
from past.builtins import basestring

from caffe2.proto import caffe2_pb2
from caffe2.python import scope, utils
import caffe2.python._import_c_extension as C
logger = logging.getLogger(__name__)

# Module-level aliases for functions exported by the C extension, so callers
# can write workspace.CreateBlob(...) etc.
CreateBlob = C.create_blob
CurrentWorkspace = C.current_workspace
DeserializeBlob = C.deserialize_blob
GlobalInit = C.global_init
RegisteredOperators = C.registered_operators
SerializeBlob = C.serialize_blob
SwitchWorkspace = C.switch_workspace
RootFolder = C.root_folder
Workspaces = C.workspaces
BenchmarkNet = C.benchmark_net
GetStats = C.get_stats

# net name -> {operator id -> traceback entries}. Read by
# CallWithExceptionIntercept to report where a failing operator came from.
operator_tracebacks = defaultdict(dict)

# Capability flags reported by the C extension.
has_cuda_support = C.has_cuda_support
has_hip_support = C.has_hip_support
has_gpu_support = C.has_gpu_support
# GPU capability bindings. Each set is guarded by the matching capability
# flag; without the guards the zero-returning fallbacks would overwrite the
# real CUDA bindings, and HIP-only C functions would be looked up on every
# build.
if has_cuda_support:
    GpuDeviceType = caffe2_pb2.CUDA
    NumCudaDevices = C.num_cuda_devices
    # Same function as NumCudaDevices, exposed under a device-agnostic name.
    NumGpuDevices = C.num_cuda_devices
    GetCUDAVersion = C.get_cuda_version
    GetCuDNNVersion = C.get_cudnn_version

    def GetGpuPeerAccessPattern():
        """Return the CUDA peer-access pattern as a numpy array."""
        return np.asarray(C.get_cuda_peer_access_pattern())

    GetDeviceProperties = C.get_device_properties
else:
    def NumCudaDevices():
        """No CUDA support in this build: report zero devices."""
        return 0

    def GetCUDAVersion():
        """No CUDA support in this build: report version 0."""
        return 0

    def GetCuDNNVersion():
        """No CUDA support in this build: report version 0."""
        return 0

if has_hip_support:
    GpuDeviceType = caffe2_pb2.HIP
    NumGpuDevices = C.num_hip_devices

    def GetGpuPeerAccessPattern():
        """Return the HIP peer-access pattern as a numpy array."""
        return np.asarray(C.get_hip_peer_access_pattern())

    GetDeviceProperties = C.get_device_properties

if not has_gpu_support:
    # GpuDeviceType still gets a value without GPU support, presumably so
    # code that only references it keeps working on CPU-only builds.
    GpuDeviceType = caffe2_pb2.CUDA

    def NumGpuDevices():
        """No GPU support in this build: report zero devices."""
        return 0

    def GetDeviceProperties(x):
        """No GPU support in this build: there are no device properties."""
        return None

    def GetGpuPeerAccessPattern():
        """No GPU support in this build: return an empty array."""
        return np.array([])
# NUMA introspection helpers exported by the C extension.
IsNUMAEnabled = C.is_numa_enabled
GetNumNUMANodes = C.get_num_numa_nodes
GetBlobNUMANode = C.get_blob_numa_node
GetBlobSizeBytes = C.get_blob_size_bytes
def _GetFreeFlaskPort():
    """Get a free flask port.

    Port 5000 is preferred when nothing on 127.0.0.1 accepts a connection on
    it. Otherwise the OS is asked for an unused ephemeral port.

    Returns:
        int: a TCP port number that was free when checked. All sockets are
        closed before returning, so another process can still grab the port
        before the caller binds it. That race is accepted here.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # connect_ex() returns 0 when the connection SUCCEEDS, i.e. when some
        # process is already listening on 5000. Only a failure means free.
        in_use = probe.connect_ex(('127.0.0.1', 5000)) == 0
    finally:
        probe.close()
    if not in_use:
        return 5000

    picker = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # Binding to port 0 makes the OS choose an unused ephemeral port.
        picker.bind(('', 0))
        return picker.getsockname()[1]
    finally:
        picker.close()
def StartMint(root_folder=None, port=None):
    """Start a mint instance in a child process.

    TODO(Yangqing): this does not work well under ipython yet. According to
        https://github.com/ipython/ipython/issues/5862
    writing up some fix is a todo item.

    Inputs:
      root_folder: folder passed to mint via `-r`. Defaults to the current
          workspace's root folder (C.root_folder()).
      port: port passed to mint via `-p`. When None, a free port is chosen
          with _GetFreeFlaskPort().
    Returns:
      The started multiprocessing.Process.
    """
    # NOTE(review): location of the mint app module assumed; confirm.
    from caffe2.python.mint import app
    if root_folder is None:
        root_folder = C.root_folder()
    if port is None:
        # Only pick a port when the caller did not supply one.
        port = _GetFreeFlaskPort()
    process = Process(
        target=app.main,
        args=(['-p', str(port), '-r', root_folder],),
    )
    process.start()
    print('Mint running at http://{}:{}'.format(socket.getfqdn(), port))
    return process
def StringifyProto(obj):
    """Stringify a protocol buffer object.

    Inputs:
      obj: a protocol buffer object, or a Pycaffe2 object that has a Proto()
          function. A string is returned unchanged (it is assumed to already
          be a serialized protobuf).
    Outputs:
      string: the output protobuf string.
    Raises:
      ValueError: if obj is not a string, not a protobuf Message, and has no
          Proto() method.
    """
    if isinstance(obj, basestring):
        return obj
    if isinstance(obj, Message):
        return obj.SerializeToString()
    if hasattr(obj, 'Proto'):
        # Pycaffe2 wrappers (e.g. nets, plans) expose their proto via Proto().
        return obj.Proto().SerializeToString()
    raise ValueError("Unexpected argument to StringifyProto of type " +
                     type(obj).__name__)
def ResetWorkspace(root_folder=None):
    """Reset the current workspace.

    Inputs:
      root_folder (optional): root folder to use after the reset. It is
          created if it does not exist. When None, the current root folder
          (C.root_folder()) is kept.
    Returns:
      The result of C.reset_workspace.
    """
    if root_folder is not None:
        if not os.path.exists(root_folder):
            os.makedirs(root_folder)
        return C.reset_workspace(root_folder)
    return C.reset_workspace(C.root_folder())
def CreateNet(net, overwrite=False, input_blobs=None):
    """Create `net` in the current workspace.

    Inputs:
      net: anything accepted by both StringifyProto and GetNetName, e.g. a
          caffe2_pb2.NetDef or a Net object.
      overwrite: forwarded to C.create_net.
      input_blobs (optional): iterable of blob names to create before the
          net is created. Defaults to none.
    Returns:
      The result of C.create_net. If it raises, the python traceback of the
      failing operator is logged (see CallWithExceptionIntercept) and the
      exception is re-raised.
    """
    if input_blobs is None:
        input_blobs = ()
    for input_blob in input_blobs:
        C.create_blob(input_blob)
    return CallWithExceptionIntercept(
        C.create_net,
        C.Workspace.current._last_failed_op_net_position,
        GetNetName(net),
        StringifyProto(net), overwrite,
    )
def Predictor(init_net, predict_net):
    """Build a C.Predictor from the serialized init and predict nets."""
    serialized_init = StringifyProto(init_net)
    serialized_predict = StringifyProto(predict_net)
    return C.Predictor(serialized_init, serialized_predict)
def GetOperatorCost(operator, blobs):
    """Return C.get_operator_cost for the serialized `operator`.

    `blobs` is passed to the C extension unchanged.
    """
    serialized_op = StringifyProto(operator)
    return C.get_operator_cost(serialized_op, blobs)
def RunOperatorOnce(operator):
    """Serialize `operator` and run it once via C.run_operator_once.

    Returns whatever C.run_operator_once returns.
    """
    serialized_op = StringifyProto(operator)
    return C.run_operator_once(serialized_op)
def RunOperatorsOnce(operators):
    """Run each operator in `operators` once, in order.

    Stops at the first operator whose RunOperatorOnce result is falsy.

    Returns:
      True if every operator succeeded, False otherwise.
    """
    for op in operators:
        if not RunOperatorOnce(op):
            return False
    return True
def CallWithExceptionIntercept(func, op_id_fetcher, net_name, *args, **kwargs):
    """Call func(*args, **kwargs), logging the failing operator's traceback.

    If the call raises:
    - op_id_fetcher() is asked for the id of the failing operator.
    - The traceback recorded for that id in operator_tracebacks[net_name], if
      any, is logged.
    - The original exception is then re-raised unchanged.

    Inputs:
      func: the callable to invoke.
      op_id_fetcher: zero-argument callable returning the failing op's id.
      net_name: key into operator_tracebacks.
    Returns:
      Whatever func returns.
    """
    try:
        return func(*args, **kwargs)
    except Exception:
        op_id = op_id_fetcher()
        net_tracebacks = operator_tracebacks.get(net_name, None)
        logger.warning(
            'Original python traceback for operator `{}` in network '
            '`{}` in exception above (most recent call last):'.format(
                op_id, net_name))
        if net_tracebacks and op_id in net_tracebacks:
            tb = net_tracebacks[op_id]
            # Each entry is indexed as (file, line, function).
            for line in reversed(tb):
                logger.warning(' File "{}", line {}, in {}'.format(
                    line[0], line[1], line[2]))
        # Bare raise keeps the original exception type and traceback.
        raise
def RunNetOnce(net):
    """Run `net` once via C.run_net_once.

    On failure the python traceback of the failing operator is logged (see
    CallWithExceptionIntercept) before the exception is re-raised.
    """
    return CallWithExceptionIntercept(
        C.run_net_once,
        C.Workspace.current._last_failed_op_net_position,
        GetNetName(net),
        StringifyProto(net),
    )
def RunNet(name, num_iter=1, allow_fail=False):
    """Runs a given net.

    Inputs:
      name: the name of the net, or a reference to the net.
      num_iter: number of iterations to run
      allow_fail: if True, does not assert on net exec failure but returns False
    Returns:
      True or an exception.
    """
    return CallWithExceptionIntercept(
        C.run_net,
        C.Workspace.current._last_failed_op_net_position,
        GetNetName(name),
        StringifyNetName(name), num_iter, allow_fail,
    )
def _PlanToString(plan_or_step):
    """Serialize a plan, wrapping a bare core.ExecutionStep in a core.Plan."""
    # Imported here rather than at module level, presumably to avoid a
    # circular import between core and workspace.
    from caffe2.python import core
    if isinstance(plan_or_step, core.ExecutionStep):
        plan_or_step = core.Plan(plan_or_step)
    return StringifyProto(plan_or_step)


def RunPlan(plan_or_step):
    """Run a plan (or a single ExecutionStep) via C.run_plan."""
    return C.run_plan(_PlanToString(plan_or_step))


def RunPlanInBackground(plan_or_step):
    """Start a plan (or a single ExecutionStep) via C.run_plan_in_background.

    Returns whatever C.run_plan_in_background returns.
    """
    return C.run_plan_in_background(_PlanToString(plan_or_step))
def InferShapesAndTypes(nets, blob_dimensions=None, nets_proto=False,
                        blob_types=None):
    """Infers the shapes and types for the specified nets.

    Inputs:
      nets: the list of nets
      blob_dimensions (optional): a dictionary of blobs and their dimensions.
          If not specified, the workspace blobs are used.
      nets_proto (optional): a boolean flag indicating whether the protobuffer
          representation is passed to the routine.
      blob_types (optional): presumably a dictionary of blobs and their data
          types. Only allowed together with blob_dimensions.
    Returns:
      A tuple of (shapes, types) dictionaries keyed by blob name. Blobs whose
      shape could not be inferred are omitted from both.
    """
    if nets_proto:
        net_protos = [StringifyProto(n) for n in nets]
    else:
        net_protos = [StringifyProto(n.Proto()) for n in nets]

    if blob_dimensions is None:
        assert blob_types is None, \
            "blob_types can only be given together with blob_dimensions"
        blobdesc_prototxt = C.infer_shapes_and_types_from_workspace(net_protos)
    elif blob_types is None:
        blobdesc_prototxt = C.infer_shapes_and_types_from_map(
            net_protos, blob_dimensions
        )
    else:
        blobdesc_prototxt = C.infer_shapes_and_types_from_map(
            net_protos, blob_dimensions, blob_types
        )

    blobdesc_proto = caffe2_pb2.TensorShapes()
    blobdesc_proto.ParseFromString(blobdesc_prototxt)
    shapes = {}
    types = {}
    for ts in blobdesc_proto.shapes:
        if not ts.unknown_shape:
            shapes[ts.name] = list(ts.dims)
            types[ts.name] = ts.data_type

    return (shapes, types)
def _StringifyName(name, expected_type):
    """Return `name` as a string.

    Strings are returned unchanged. Anything else must be an instance of a
    type whose __name__ equals `expected_type`; its str() is returned.
    """
    if isinstance(name, basestring):
        return name
    assert type(name).__name__ == expected_type, \
        "Expected a string or %s" % expected_type
    return str(name)
def StringifyBlobName(name):
    """Return a blob name given either a string or a BlobReference."""
    return _StringifyName(name, expected_type="BlobReference")
def StringifyNetName(name):
    """Return a net name given either a string or a Net object."""
    return _StringifyName(name, expected_type="Net")
def GetNetName(net):
    """Return the name of `net`.

    Accepts:
    - a plain string, returned as-is;
    - an object whose type is named "Net", for which Name() is returned;
    - a caffe2_pb2.NetDef, for which its name field is returned.

    Raises:
      TypeError: for any other input. TypeError subclasses Exception, so
          existing `except Exception` handlers still match.
    """
    if isinstance(net, basestring):
        return net
    if type(net).__name__ == "Net":
        return net.Name()
    if isinstance(net, caffe2_pb2.NetDef):
        return net.name
    raise TypeError("Not a Net object: {}".format(str(net)))
def FeedBlob(name, arr, device_option=None):
    """Feeds a blob into the current workspace.

    Inputs:
      name: the name of the blob.
      arr: either a TensorProto object or a numpy array object to be fed into
          the workspace.
      device_option (optional): the device option to feed the data with.
    Returns:
      True or False, stating whether the feed is successful.
    """
    return _Workspace_feed_blob(C.Workspace.current, name, arr, device_option)
def FetchBlobs(names):
    """Fetches a list of blobs from the workspace.

    Inputs:
        names: list of names of blobs - strings or BlobReferences
    Returns:
        list of fetched blobs, in the same order as `names`
    """
    return list(map(FetchBlob, names))
def FetchBlob(name):
    """Fetches a blob from the workspace.

    Inputs:
      name: the name of the blob - a string or a BlobReference
    Returns:
      Fetched blob (numpy array or string) if successful
    Raises:
      TypeError: if the blob is an Int8 blob; use FetchInt8Blob instead.
    """
    result = C.fetch_blob(StringifyBlobName(name))
    if isinstance(result, tuple):
        # Int8 blobs come back as a (data, scale, zero_point) tuple.
        raise TypeError(
            "Use FetchInt8Blob to fetch Int8 Blob {}".format(
                StringifyBlobName(name)
            )
        )
    return result
def FetchTorch(name):
    """Fetch blob `name` from the current workspace as a PyTorch tensor."""
    return C.Workspace.current.blobs[name].to_torch()
# Value returned by the Int8 fetch helpers: quantized data plus the scale and
# zero point needed to recover real values.
Int8Tensor = collections.namedtuple(
    'Int8Tensor', 'data scale zero_point'
)
def FetchInt8Blob(name):
    """Fetches an Int8 blob from the workspace.

    It shares the backend implementation with FetchBlob but is the
    recommended way to fetch Int8 blobs.

    Inputs:
      name: the name of the Int8 blob - a string or a BlobReference
    Returns:
      Int8Tensor with fields
        data: int8 numpy array, data
        scale: float, fake quantization scale
        zero_point: int, fake quantization offset
    """
    blob_name = StringifyBlobName(name)
    fetched = C.fetch_blob(blob_name)
    assert isinstance(fetched, tuple), \
        'You are not fetching an Int8Blob {}. Please use FetchBlob'.format(
            blob_name)
    return Int8Tensor(*fetched)
def FetchInt8BlobRealVal(name):
    """Fetches an Int8 blob from the workspace and returns its real values.

    Inputs:
      name: the name of the Int8 blob - a string or a BlobReference
    Returns:
      numpy array equal to (data - zero_point) * scale, computed in float32
      before scaling
    """
    int8_blob = FetchInt8Blob(name)
    # Widen to int32 first so subtracting the zero point cannot wrap around.
    centered = int8_blob.data.astype(np.int32) - int(int8_blob.zero_point)
    return centered.astype(np.float32) * int8_blob.scale
def _Workspace_fetch_int8_blob(ws, name):
    """Workspace.fetch_int8_blob: fetch an Int8 blob from workspace `ws`.

    It shares the backend implementation with Workspace.fetch_blob but is the
    recommended way to fetch Int8 blobs.

    Inputs:
      name: the name of the Int8 blob - a string or a BlobReference
    Returns:
      Int8Tensor with fields
        data: int8 numpy array, data
        scale: float, fake quantization scale
        zero_point: int, fake quantization offset
    """
    fetched = ws.fetch_blob(name)
    assert isinstance(fetched, tuple), \
        'You are not fetching an Int8Blob {}. Please use fetch_blob'.format(
            StringifyBlobName(name))
    return Int8Tensor(*fetched)


# Install as a method of C.Workspace.
C.Workspace.fetch_int8_blob = _Workspace_fetch_int8_blob
def ApplyTransform(transform_key, net):
    """Apply a Transform to a NetDef protobuf object and return the new
    transformed NetDef.

    Inputs:
      transform_key: the name of the transform, as it is stored in the registry
      net: a NetDef protobuf object
    Returns:
      Transformed NetDef protobuf object.
    """
    key_bytes = str(transform_key).encode('utf-8')
    serialized = C.apply_transform(key_bytes, net.SerializeToString())
    result = caffe2_pb2.NetDef()
    result.ParseFromString(serialized)
    return result
def ApplyTransformIfFaster(transform_key, net, init_net, **kwargs):
    """Apply a Transform to a NetDef protobuf object, and return the new
    transformed NetDef only if it runs faster than the original.

    The runs are performed on the current active workspace (gWorkspace).
    You should initialize that workspace before making a call to this function.

    Inputs:
      transform_key: the name of the transform, as it is stored in the registry
      net: a NetDef protobuf object
      init_net: The net to initialize the workspace.
      warmup_runs (optional):
        Determines how many times the net is run before testing.
        Will be 5 by default.
      main_runs (optional):
        Determines how many times the net is run during testing.
        Will be 10 by default.
      improvement_threshold (optional):
        Determines the factor which the new net needs to be faster
        in order to replace the old. Will be 1.01 by default.
    Returns:
      Either a Transformed NetDef protobuf object, or the original netdef.
    """
    warmup_runs = kwargs.get('warmup_runs', 5)
    main_runs = kwargs.get('main_runs', 10)
    improvement_threshold = kwargs.get('improvement_threshold', 1.01)

    transformed_str = C.apply_transform_if_faster(
        str(transform_key).encode('utf-8'),
        net.SerializeToString(),
        init_net.SerializeToString(),
        # Coerce so integral floats such as 5.0 are accepted too.
        int(warmup_runs),
        int(main_runs),
        float(improvement_threshold),
    )
    transformed_net = caffe2_pb2.NetDef()
    transformed_net.ParseFromString(transformed_str)
    return transformed_net
def GetNameScope():
    """Return the current namescope string. To be used to fetch blobs"""
    return scope.CurrentNameScope()
class _BlobDict(object):
    """Provides python dict compatible way to do fetching and feeding"""

    def __getitem__(self, key):
        # Fetch the blob named `key` from the current workspace.
        return FetchBlob(key)

    def __setitem__(self, key, value):
        # Feed `value` into the blob named `key`.
        return FeedBlob(key, value)

    def __len__(self):
        # Number of entries reported by C.blobs().
        return len(C.blobs())

    def __iter__(self):
        # Iterate over the entries reported by C.blobs().
        return iter(C.blobs())

    def __contains__(self, item):
        return C.has_blob(item)
# Immediate-mode state:
# - whether immediate mode is on;
# - the name of the dedicated workspace it runs in;
# - the temporary root folder StartImmediate() creates for that workspace.
_immediate_mode = False
_immediate_workspace_name = "_CAFFE2_IMMEDIATE"
_immediate_root_folder = ''


def IsImmediate():
    """Return True while immediate mode is enabled."""
    return _immediate_mode
@contextlib.contextmanager
def WorkspaceGuard(workspace_name):
    """Temporarily switch to workspace `workspace_name`.

    The True flag passed to SwitchWorkspace presumably creates the workspace
    if it is missing. The previously current workspace is restored on exit,
    including when the body of the `with` statement raises.
    """
    current = CurrentWorkspace()
    SwitchWorkspace(workspace_name, True)
    try:
        yield
    finally:
        SwitchWorkspace(current)
def StartImmediate(i_know=False):
    """Enable immediate mode with a fresh immediate workspace.

    Steps:
    - A running immediate session is stopped first.
    - The immediate workspace is reset onto a new temporary root folder.
    - Unless `i_know` is True, a warning listing the caveats is printed.
    """
    global _immediate_mode
    global _immediate_root_folder
    if IsImmediate():
        # Start from a clean slate instead of leaking the previous folder.
        StopImmediate()
    _immediate_mode = True
    with WorkspaceGuard(_immediate_workspace_name):
        _immediate_root_folder = tempfile.mkdtemp()
        ResetWorkspace(_immediate_root_folder)
    if i_know:
        return
    print("""
    Enabling immediate mode in caffe2 python is an EXTREMELY EXPERIMENTAL
    feature and may very easily go wrong. This is because Caffe2 uses a
    declarative way of defining operators and models, which is essentially
    not meant to run things in an interactive way. Read the following carefully
    to make sure that you understand the caveats.

    (1) You need to make sure that the sequences of operators you create are
    actually runnable sequentially. For example, if you create an op that takes
    an input X, somewhere earlier you should have already created X.

    (2) Caffe2 immediate uses one single workspace, so if the set of operators
    you run are intended to be under different workspaces, they will not run.
    To create boundaries between such use cases, you can call StopImmediate()
    and StartImmediate() manually to flush out everything no longer needed.

    (3) Underlying objects held by the immediate mode may interfere with your
    normal run. For example, if there is a leveldb that you opened in immediate
    mode and did not close, your main run will fail because leveldb does not
    support double opening. Immediate mode may also occupy a lot of memory esp.
    on GPUs. Call StopImmediate() as soon as possible when you no longer
    need it.

    (4) Immediate is designed to be slow. Every immediate call implicitly
    creates a temp operator object, runs it, and destroys the operator. This
    slow-speed run is by design to discourage abuse. For most use cases other
    than debugging, do NOT turn on immediate mode.

    (5) If there is anything FATAL happening in the underlying C++ code, the
    immediate mode will immediately (pun intended) cause the runtime to crash.

    Thus you should use immediate mode with extra care. If you still would
    like to, have fun [https://xkcd.com/149/].
    """)


def StopImmediate():
    """Stops an immediate mode run."""
    global _immediate_mode
    global _immediate_root_folder
    if not IsImmediate():
        return
    with WorkspaceGuard(_immediate_workspace_name):
        # Release what the immediate workspace holds before deleting its
        # root folder.
        ResetWorkspace()
    shutil.rmtree(_immediate_root_folder)
    _immediate_root_folder = ''
    _immediate_mode = False


def ImmediateBlobs():
    """Return C.blobs() evaluated inside the immediate-mode workspace."""
    with WorkspaceGuard(_immediate_workspace_name):
        return C.blobs()
def RunOperatorImmediate(op):
    """Run `op` once inside the immediate-mode workspace.

    Returns the result of RunOperatorOnce. Callers that ignored the previous
    None return are unaffected.
    """
    with WorkspaceGuard(_immediate_workspace_name):
        return RunOperatorOnce(op)
def FetchImmediate(*args, **kwargs):
    """FetchBlob, evaluated inside the immediate-mode workspace."""
    guard = WorkspaceGuard(_immediate_workspace_name)
    with guard:
        fetched = FetchBlob(*args, **kwargs)
    return fetched
def FeedImmediate(*args, **kwargs):
    """FeedBlob, evaluated inside the immediate-mode workspace."""
    guard = WorkspaceGuard(_immediate_workspace_name)
    with guard:
        fed = FeedBlob(*args, **kwargs)
    return fed
def _Workspace_create_net_with_exception_intercept(ws, net, overwrite=False):
    """Workspace.create_net: create `net` in workspace `ws`.

    On failure the python traceback of the failing operator is logged (see
    CallWithExceptionIntercept) before the exception is re-raised.
    """
    return CallWithExceptionIntercept(
        ws._create_net,
        ws._last_failed_op_net_position,
        GetNetName(net),
        StringifyProto(net), overwrite,
    )
def _Workspace_run(ws, obj):
    """Workspace.run: run a plan, net or operator in workspace `ws`.

    `obj` may be a caffe2_pb2 PlanDef, NetDef or OperatorDef, or any object
    whose Proto() returns one of those.

    Raises:
      ValueError: if obj is of any other type.
    """
    if hasattr(obj, 'Proto'):
        obj = obj.Proto()
    if isinstance(obj, caffe2_pb2.PlanDef):
        return ws._run_plan(obj.SerializeToString())
    if isinstance(obj, caffe2_pb2.NetDef):
        # Nets go through the interceptor so a failing op's python traceback
        # is logged.
        return CallWithExceptionIntercept(
            ws._run_net,
            ws._last_failed_op_net_position,
            GetNetName(obj),
            obj.SerializeToString(),
        )
    if isinstance(obj, caffe2_pb2.OperatorDef):
        return ws._run_operator(obj.SerializeToString())
    raise ValueError(
        "Don't know how to do Workspace.run() on {}".format(type(obj)))
def _Workspace_feed_blob(ws, name, arr, device_option=None):
    """Workspace.feed_blob: feed `arr` into blob `name` of workspace `ws`.

    Inputs:
      ws: the workspace to feed into.
      name: blob name, a string or a BlobReference.
      arr: a caffe2_pb2.TensorProto (converted to numpy first) or a value
          accepted by Blob.feed, typically a numpy array.
      device_option (optional): target device. Defaults to
          scope.CurrentDeviceScope().
    Returns:
      The result of Blob.feed.
    """
    if type(arr) is caffe2_pb2.TensorProto:
        arr = utils.Caffe2TensorToNumpyArray(arr)
    if type(arr) is np.ndarray and arr.dtype.kind in 'SU':
        # Plain numpy byte/unicode string arrays are fed as object arrays.
        # The builtin `object` replaces the `np.object` alias, which was
        # removed in NumPy 1.24.
        arr = arr.astype(object)

    if device_option is None:
        device_option = scope.CurrentDeviceScope()

    if device_option and device_option.device_type == caffe2_pb2.CUDA:
        # arr need not be a numpy array here (e.g. a string), so look the
        # dtype up defensively instead of raising AttributeError.
        dtype = getattr(arr, 'dtype', None)
        if dtype is not None and dtype == np.dtype('float64'):
            logger.warning(
                "CUDA operators do not support 64-bit doubles, " +
                "please use arr.astype(np.float32) or np.int32 for ints." +
                " Blob: {}".format(name) +
                " type: {}".format(str(arr.dtype))
            )

    name = StringifyBlobName(name)
    if device_option is not None:
        return ws.create_blob(name).feed(arr, device_option)
    return ws.create_blob(name).feed(arr)
def _Workspace_remove_blob(ws, blob):
    """Workspace.remove_blob: remove `blob` (name or reference) from `ws`."""
    blob_name = str(blob)
    ws._remove_blob(blob_name)
# Public alias for the C workspace class, plus the python-side helpers
# installed as its methods.
Workspace = C.Workspace
Workspace.create_net = _Workspace_create_net_with_exception_intercept
Workspace.run = _Workspace_run
Workspace.feed_blob = _Workspace_feed_blob
Workspace.remove_blob = _Workspace_remove_blob
def _Blob_feed(blob, arg, device_option=None):
    """Blob.feed: store `arg` in `blob`.

    - PyTorch tensors: the blob wraps the tensor's raw TensorImpl handle.
    - Anything else: goes through Blob._feed, with `device_option`
      serialized when given.

    Returns:
      True after wrapping a PyTorch tensor, otherwise Blob._feed's result.
    """
    # Cheap type-name check first so torch is only imported when needed.
    if type(arg).__name__ == 'Tensor' and type(arg).__module__ == 'torch':
        import torch
        if isinstance(arg, torch.Tensor):
            assert device_option is None, \
                "device_option doesn't make sense with PyTorch tensors"
            handle = torch._C._tensor_impl_raw_handle(arg)
            blob._wrap_tensor_impl(handle)
            # Done; do not also push the tensor through the generic _feed.
            return True
    if device_option is not None:
        device_option = StringifyProto(device_option)
    return blob._feed(arg, device_option)


C.Blob.feed = _Blob_feed
def _Tensor_to_torch(tensor):
    """
    PyTorch tensor interop (TensorCPU methods)

    Can be accessed as:
      workspace.Workspace.current.blobs['foo'].tensor().to_torch()
    """
    # Imported lazily so torch is only needed when interop is actually used.
    import torch
    handle = tensor._tensor_impl_raw_handle()
    return torch._C._wrap_tensor_impl(handle)


C.TensorCPU.to_torch = _Tensor_to_torch
def _Blob_to_torch(blob):
    """Blob.to_torch: return the blob's tensor as a PyTorch tensor.

    Raises:
      RuntimeError: if the blob does not hold a tensor.
    """
    if blob.is_tensor():
        return blob.as_tensor().to_torch()
    raise RuntimeError("Blob has to be a tensor")


C.Blob.to_torch = _Blob_to_torch