3 from __future__ 
import absolute_import
     4 from __future__ 
import division
     5 from __future__ 
import print_function
     6 from __future__ 
import unicode_literals
     9 from google.protobuf.message 
import Message
    10 from multiprocessing 
import Process
    12 from collections 
import defaultdict
    15 from past.builtins 
import basestring
    20 from caffe2.proto 
import caffe2_pb2
logger = logging.getLogger(__name__)

# Thin, PEP8-cased aliases over the pybind11 C extension (`C`). These are the
# public module-level API; they forward directly with no extra behavior.
CreateBlob = C.create_blob
CurrentWorkspace = C.current_workspace
DeserializeBlob = C.deserialize_blob
GlobalInit = C.global_init
RegisteredOperators = C.registered_operators
SerializeBlob = C.serialize_blob
SwitchWorkspace = C.switch_workspace
RootFolder = C.root_folder
Workspaces = C.workspaces
BenchmarkNet = C.benchmark_net
GetStats = C.get_stats

# Maps net name -> {op index -> python traceback captured at net build time};
# consulted by CallWithExceptionIntercept when an op fails.
operator_tracebacks = defaultdict(dict)

# Build-time capability flags exported by the C extension.
has_cuda_support = C.has_cuda_support
has_hip_support = C.has_hip_support
has_gpu_support = C.has_gpu_support
    48     GpuDeviceType = caffe2_pb2.CUDA
    49     NumCudaDevices = C.num_cuda_devices
    52     NumGpuDevices = C.num_cuda_devices
    53     GetCUDAVersion = C.get_cuda_version
    54     GetCuDNNVersion = C.get_cudnn_version
    56     def GetGpuPeerAccessPattern():
    57         return np.asarray(C.get_cuda_peer_access_pattern())
    59     GetDeviceProperties = C.get_device_properties
    61     NumCudaDevices = 
lambda: 0 
    62     GetCUDAVersion = 
lambda: 0 
    63     GetCuDNNVersion = 
lambda: 0 
    66     GpuDeviceType = caffe2_pb2.HIP
    67     NumGpuDevices = C.num_hip_devices
    69     def GetGpuPeerAccessPattern():
    70         return np.asarray(C.get_hip_peer_access_pattern())
    71     GetDeviceProperties = C.get_device_properties
    73 if not has_gpu_support:
    76     GpuDeviceType = caffe2_pb2.CUDA
    77     NumGpuDevices = 
lambda: 0 
    78     GetDeviceProperties = 
lambda x: 
None     79     GetGpuPeerAccessPattern = 
lambda: np.array([]) 
# NUMA topology queries, forwarded from the C extension.
IsNUMAEnabled = C.is_numa_enabled
GetNumNUMANodes = C.get_num_numa_nodes
GetBlobNUMANode = C.get_blob_numa_node
GetBlobSizeBytes = C.get_blob_size_bytes
    86 def _GetFreeFlaskPort():
    87     """Get a free flask port."""    89     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    90     result = sock.connect_ex((
'127.0.0.1', 5000))
    96         port = s.getsockname()[1]
   104 def StartMint(root_folder=None, port=None):
   105     """Start a mint instance.   107     TODO(Yangqing): this does not work well under ipython yet. According to   108         https://github.com/ipython/ipython/issues/5862   109     writing up some fix is a todo item.   112     if root_folder 
is None:
   114         root_folder = C.root_folder()
   116         port = _GetFreeFlaskPort()
   120             [
'-p', str(port), 
'-r', root_folder],
   124     print(
'Mint running at http://{}:{}'.format(socket.getfqdn(), port))
   128 def StringifyProto(obj):
   129     """Stringify a protocol buffer object.   132     obj: a protocol buffer object, or a Pycaffe2 object that has a Proto()   135     string: the output protobuf string.   137     AttributeError: if the passed in object does not have the right attribute.   139     if isinstance(obj, basestring):
   142         if isinstance(obj, Message):
   145             return obj.SerializeToString()
   146         elif hasattr(obj, 
'Proto'):
   147             return obj.Proto().SerializeToString()
   149             raise ValueError(
"Unexpected argument to StringifyProto of type " +
   153 def ResetWorkspace(root_folder=None):
   154     if root_folder 
is None:
   156         return C.reset_workspace(C.root_folder())
   158         if not os.path.exists(root_folder):
   159             os.makedirs(root_folder)
   160         return C.reset_workspace(root_folder)
   163 def CreateNet(net, overwrite=False, input_blobs=None):
   164     if input_blobs 
is None:
   166     for input_blob 
in input_blobs:
   167         C.create_blob(input_blob)
   168     return CallWithExceptionIntercept(
   170         C.Workspace.current._last_failed_op_net_position,
   172         StringifyProto(net), overwrite,
def Predictor(init_net, predict_net):
    """Build a C++ Predictor from serialized init and predict nets."""
    return C.Predictor(StringifyProto(init_net), StringifyProto(predict_net))
def GetOperatorCost(operator, blobs):
    """Return the C extension's cost estimate for `operator` over `blobs`."""
    return C.get_operator_cost(StringifyProto(operator), blobs)
def RunOperatorOnce(operator):
    """Run a single operator (OperatorDef or Proto()-bearing object) once."""
    return C.run_operator_once(StringifyProto(operator))
   188 def RunOperatorsOnce(operators):
   190         success = RunOperatorOnce(op)
   196 def CallWithExceptionIntercept(func, op_id_fetcher, net_name, *args, **kwargs):
   198         return func(*args, **kwargs)
   200         op_id = op_id_fetcher()
   201         net_tracebacks = operator_tracebacks.get(net_name, 
None)
   203             'Original python traceback for operator `{}` in network '   204             '`{}` in exception above (most recent call last):'.format(
   206         if net_tracebacks 
and op_id 
in net_tracebacks:
   207             tb = net_tracebacks[op_id]
   208             for line 
in reversed(tb):
   209                 logger.warning(
'  File "{}", line {}, in {}'.format(
   210                     line[0], line[1], line[2]))
   215     return CallWithExceptionIntercept(
   217         C.Workspace.current._last_failed_op_net_position,
   223 def RunNet(name, num_iter=1, allow_fail=False):
   227       name: the name of the net, or a reference to the net.   228       num_iter: number of iterations to run   229       allow_fail: if True, does not assert on net exec failure but returns False   231       True or an exception.   233     return CallWithExceptionIntercept(
   235         C.Workspace.current._last_failed_op_net_position,
   237         StringifyNetName(name), num_iter, allow_fail,
   241 def RunPlan(plan_or_step):
   244     if isinstance(plan_or_step, core.ExecutionStep):
   245         plan_or_step = core.Plan(plan_or_step)
   246     return C.run_plan(StringifyProto(plan_or_step))
   249 def RunPlanInBackground(plan_or_step):
   252     if isinstance(plan_or_step, core.ExecutionStep):
   253         plan_or_step = core.Plan(plan_or_step)
   254     return C.run_plan_in_background(StringifyProto(plan_or_step))
   257 def InferShapesAndTypes(nets, blob_dimensions=None, nets_proto=False,
   259     """Infers the shapes and types for the specified nets.   262       nets: the list of nets   263       blob_dimensions (optional): a dictionary of blobs and their dimensions.   264           If not specified, the workspace blobs are used.   265       nets_proto (optional): a boolean flag indicating whether the protobuffer   266           representation is passed to the routine.   268       A tuple of (shapes, types) dictionaries keyed by blob name.   271         net_protos = [StringifyProto(n) 
for n 
in nets]
   273         net_protos = [StringifyProto(n.Proto()) 
for n 
in nets]
   274     if blob_dimensions 
is None:
   275         assert blob_types 
is None   276         blobdesc_prototxt = C.infer_shapes_and_types_from_workspace(net_protos)
   277     elif blob_types 
is None:
   278         blobdesc_prototxt = C.infer_shapes_and_types_from_map(
   279             net_protos, blob_dimensions
   282         blobdesc_prototxt = C.infer_shapes_and_types_from_map(
   283             net_protos, blob_dimensions, blob_types
   285     blobdesc_proto = caffe2_pb2.TensorShapes()
   286     blobdesc_proto.ParseFromString(blobdesc_prototxt)
   289     for ts 
in blobdesc_proto.shapes:
   290         if not ts.unknown_shape:
   291             shapes[ts.name] = list(ts.dims)
   292             types[ts.name] = ts.data_type
   294     return (shapes, types)
   297 def _StringifyName(name, expected_type):
   298     if isinstance(name, basestring):
   300     assert type(name).__name__ == expected_type, \
   301         "Expected a string or %s" % expected_type
def StringifyBlobName(name):
    # Accepts a string or a BlobReference; returns the blob's name string.
    return _StringifyName(name, "BlobReference")
def StringifyNetName(name):
    # Accepts a string or a Net object; returns the net's name string.
    return _StringifyName(name, "Net")
   314     if isinstance(net, basestring):
   316     if type(net).__name__ == 
"Net":
   318     if isinstance(net, caffe2_pb2.NetDef):
   320     raise Exception(
"Not a Net object: {}".format(str(net)))
   323 def FeedBlob(name, arr, device_option=None):
   324     """Feeds a blob into the workspace.   327       name: the name of the blob.   328       arr: either a TensorProto object or a numpy array object to be fed into   330       device_option (optional): the device option to feed the data with.   332       True or False, stating whether the feed is successful.   334     ws = C.Workspace.current
   335     return _Workspace_feed_blob(ws, name, arr, device_option)
def FetchBlobs(names):
    """Fetches a list of blobs from the workspace.

    Inputs:
        names: list of names of blobs - strings or BlobReferences
    Returns:
        list of fetched blobs
    """
    return [FetchBlob(name) for name in names]
   350     """Fetches a blob from the workspace.   353       name: the name of the blob - a string or a BlobReference   355       Fetched blob (numpy array or string) if successful   357     result = C.fetch_blob(StringifyBlobName(name))
   358     if isinstance(result, tuple):
   360             "Use FetchInt8Blob to fetch Int8 Blob {}".format(
   361                 StringifyBlobName(name)
def FetchTorch(name):
    # Fetch blob `name` from the current workspace as a torch.Tensor
    # (via Blob.to_torch, which wraps the underlying tensor impl).
    ws = C.Workspace.current
    return ws.blobs[name].to_torch()
   372 Int8Tensor = collections.namedtuple(
   373     'Int8Tensor', [
'data', 
'scale', 
'zero_point']
   377 def FetchInt8Blob(name):
   378     """Fetches an Int8 blob from the workspace. It shared backend implementation   379     with FetchBlob but it is recommened when fetching Int8 Blobs   382       name: the name of the Int8 blob - a string or a BlobReference   384       data: int8 numpy array, data   385       scale: float, fake quantization scale   386       zero_point: int, fake quantization offset   388     result = C.fetch_blob(StringifyBlobName(name))
   389     assert isinstance(result, tuple), \
   390         'You are not fetching an Int8Blob {}. Please use FetchBlob'.format(
   391             StringifyBlobName(name))
   392     return Int8Tensor(*result)
   395 def FetchInt8BlobRealVal(name):
   396     """Fetches an Int8 blob from the workspace and return its real value representation.   399       name: the name of the Int8 blob - a string or a BlobReference   401       real value representation of int8 numpy array   403     result = C.fetch_blob(StringifyBlobName(name))
   404     assert isinstance(result, tuple), \
   405         'You are not fetching an Int8Blob {}. Please use FetchBlob'.format(
   406             StringifyBlobName(name))
   407     int8_blob = Int8Tensor(*result)
   408     return (int8_blob.data.astype(np.int32) - int(int8_blob.zero_point)).astype(
   409         np.float32) * int8_blob.scale
   412 def _Workspace_fetch_int8_blob(ws, name):
   413     """Fetches an Int8 blob from the workspace. It shared backend implementation   414     with FetchBlob but it is recommened when fetching Int8 Blobs   417       name: the name of the Int8 blob - a string or a BlobReference   419       data: int8 numpy array, data   420       scale: float, fake quantization scale   421       zero_point: int, fake quantization offset   423     result = ws.fetch_blob(name)
   424     assert isinstance(result, tuple), \
   425         'You are not fetching an Int8Blob {}. Please use fetch_blob'.format(
   426             StringifyBlobName(name))
   427     return Int8Tensor(*result)
   430 C.Workspace.fetch_int8_blob = _Workspace_fetch_int8_blob
   433 def ApplyTransform(transform_key, net):
   434     """Apply a Transform to a NetDef protobuf object, and returns the new   438       transform_key: the name of the transform, as it is stored in the registry   439       net: a NetDef protobuf object   441       Transformed NetDef protobuf object.   443     transformed_net = caffe2_pb2.NetDef()
   444     transformed_str = C.apply_transform(
   445         str(transform_key).encode(
'utf-8'),
   446         net.SerializeToString(),
   448     transformed_net.ParseFromString(transformed_str)
   449     return transformed_net
   452 def ApplyTransformIfFaster(transform_key, net, init_net, **kwargs):
   453     """Apply a Transform to a NetDef protobuf object, and returns the new   454     transformed NetDef, only if it runs faster than the original.   456     The runs are performed on the current active workspace (gWorkspace).   457     You should initialize that workspace before making a call to this function.   460       transform_key: the name of the transform, as it is stored in the registry   461       net: a NetDef protobuf object   462       init_net: The net to initialize the workspace.   463       warmup_runs (optional):   464         Determines how many times the net is run before testing.   465         Will be 5 by default.   466       main_runs (optional):   467         Determines how many times the net is run during testing.   468         Will be 10 by default.   469       improvement_threshold (optional):   470         Determines the factor which the new net needs to be faster   471         in order to replace the old. Will be 1.01 by default.   474       Either a Transformed NetDef protobuf object, or the original netdef.   477     warmup_runs = kwargs[
'warmup_runs'] 
if 'warmup_runs' in kwargs 
else 5
   478     main_runs = kwargs[
'main_runs'] 
if 'main_runs' in kwargs 
else 10
   479     improvement_threshold = kwargs[
'improvement_threshold'] \
   480         if 'improvement_threshold' in kwargs 
else 1.01
   482     transformed_net = caffe2_pb2.NetDef()
   483     transformed_str = C.apply_transform_if_faster(
   484         str(transform_key).encode(
'utf-8'),
   485         net.SerializeToString(),
   486         init_net.SerializeToString(),
   489         float(improvement_threshold),
   491     transformed_net.ParseFromString(transformed_str)
   492     return transformed_net
   496     """Return the current namescope string. To be used to fetch blobs"""   497     return scope.CurrentNameScope()
   501     """Provides python dict compatible way to do fetching and feeding"""   503     def __getitem__(self, key):
   504         return FetchBlob(key)
   506     def __setitem__(self, key, value):
   507         return FeedBlob(key, value)
   510         return len(C.blobs())
   513         return C.blobs().__iter__()
   515     def __contains__(self, item):
   516         return C.has_blob(item)
   542 _immediate_mode = 
False   543 _immediate_workspace_name = 
"_CAFFE2_IMMEDIATE"   544 _immediate_root_folder = 
''   548     return _immediate_mode
   551 @contextlib.contextmanager
   552 def WorkspaceGuard(workspace_name):
   553     current = CurrentWorkspace()
   554     SwitchWorkspace(workspace_name, 
True)
   556     SwitchWorkspace(current)
   559 def StartImmediate(i_know=False):
   560     global _immediate_mode
   561     global _immediate_root_folder
   566     _immediate_mode = 
True   567     with WorkspaceGuard(_immediate_workspace_name):
   568         _immediate_root_folder = tempfile.mkdtemp()
   569         ResetWorkspace(_immediate_root_folder)
   574     Enabling immediate mode in caffe2 python is an EXTREMELY EXPERIMENTAL   575     feature and may very easily go wrong. This is because Caffe2 uses a   576     declarative way of defining operators and models, which is essentially   577     not meant to run things in an interactive way. Read the following carefully   578     to make sure that you understand the caveats.   580     (1) You need to make sure that the sequences of operators you create are   581     actually runnable sequentially. For example, if you create an op that takes   582     an input X, somewhere earlier you should have already created X.   584     (2) Caffe2 immediate uses one single workspace, so if the set of operators   585     you run are intended to be under different workspaces, they will not run.   586     To create boundaries between such use cases, you can call FinishImmediate()   587     and StartImmediate() manually to flush out everything no longer needed.   589     (3) Underlying objects held by the immediate mode may interfere with your   590     normal run. For example, if there is a leveldb that you opened in immediate   591     mode and did not close, your main run will fail because leveldb does not   592     support double opening. Immediate mode may also occupy a lot of memory esp.   593     on GPUs. Call FinishImmediate() as soon as possible when you no longer   596     (4) Immediate is designed to be slow. Every immediate call implicitly   597     creates a temp operator object, runs it, and destroys the operator. This   598     slow-speed run is by design to discourage abuse. For most use cases other   599     than debugging, do NOT turn on immediate mode.   601     (5) If there is anything FATAL happening in the underlying C++ code, the   602     immediate mode will immediately (pun intended) cause the runtime to crash.   604     Thus you should use immediate mode with extra care. If you still would   605     like to, have fun [https://xkcd.com/149/].   
def FinishImmediate():
    """Stops an immediate mode run."""
    # Phew, that was a dangerous ride.
    global _immediate_mode
    global _immediate_root_folder
    if not IsImmediate():
        return
    with WorkspaceGuard(_immediate_workspace_name):
        ResetWorkspace()
    # Remove the temp folder created by StartImmediate.
    shutil.rmtree(_immediate_root_folder)
    _immediate_root_folder = ''
    _immediate_mode = False
def ImmediateBlobs():
    """Return the blob names present in the immediate workspace."""
    with WorkspaceGuard(_immediate_workspace_name):
        # NOTE(review): Blobs() is defined elsewhere in this file (not visible
        # in this chunk) — presumably wrapping C.blobs(); verify.
        return Blobs()
   628 def RunOperatorImmediate(op):
   629     with WorkspaceGuard(_immediate_workspace_name):
def FetchImmediate(*args, **kwargs):
    """FetchBlob, executed inside the dedicated immediate workspace."""
    with WorkspaceGuard(_immediate_workspace_name):
        return FetchBlob(*args, **kwargs)
def FeedImmediate(*args, **kwargs):
    """FeedBlob, executed inside the dedicated immediate workspace."""
    with WorkspaceGuard(_immediate_workspace_name):
        return FeedBlob(*args, **kwargs)
   645 def _Workspace_create_net_with_exception_intercept(ws, net, overwrite=False):
   646     return CallWithExceptionIntercept(
   648         ws._last_failed_op_net_position,
   650         StringifyProto(net), overwrite,
   654 def _Workspace_run(ws, obj):
   655     if hasattr(obj, 
'Proto'):
   657     if isinstance(obj, caffe2_pb2.PlanDef):
   658         return ws._run_plan(obj.SerializeToString())
   659     if isinstance(obj, caffe2_pb2.NetDef):
   660         return CallWithExceptionIntercept(
   662             ws._last_failed_op_net_position,
   664             obj.SerializeToString(),
   667     if isinstance(obj, caffe2_pb2.OperatorDef):
   668         return ws._run_operator(obj.SerializeToString())
   670         "Don't know how to do Workspace.run() on {}".format(type(obj)))
   673 def _Workspace_feed_blob(ws, name, arr, device_option=None):
   674     if type(arr) 
is caffe2_pb2.TensorProto:
   675         arr = utils.Caffe2TensorToNumpyArray(arr)
   676     if type(arr) 
is np.ndarray 
and arr.dtype.kind 
in 'SU':
   678         arr = arr.astype(np.object)
   680     if device_option 
is None:
   681         device_option = scope.CurrentDeviceScope()
   683     if device_option 
and device_option.device_type == caffe2_pb2.CUDA:
   684         if arr.dtype == np.dtype(
'float64'):
   686                 "CUDA operators do not support 64-bit doubles, " +
   687                 "please use arr.astype(np.float32) or np.int32 for ints." +
   688                 " Blob: {}".format(name) +
   689                 " type: {}".format(str(arr.dtype))
   692     name = StringifyBlobName(name)
   693     if device_option 
is not None:
   694         return ws.create_blob(name).feed(arr, device_option)
   696         return ws.create_blob(name).feed(arr)
def _Workspace_remove_blob(ws, blob):
    # `blob` may be a BlobReference; str() converts it to the blob name.
    ws._remove_blob(str(blob))


# Attach the Python helpers above as methods on the C++ Workspace class.
Workspace = C.Workspace
Workspace.create_net = _Workspace_create_net_with_exception_intercept
Workspace.run = _Workspace_run
Workspace.feed_blob = _Workspace_feed_blob
Workspace.remove_blob = _Workspace_remove_blob
   712 def _Blob_feed(blob, arg, device_option=None):
   714     if type(arg).__name__ == 
'Tensor' and type(arg).__module__ == 
'torch':
   716         if isinstance(arg, torch.Tensor):
   717             assert device_option 
is None, \
   718                 "device_option doesn't make sense with PyTorch tensors"   719             handle = torch._C._tensor_impl_raw_handle(arg)
   720             blob._wrap_tensor_impl(handle)
   722     if device_option 
is not None:
   723         device_option = StringifyProto(device_option)
   724     return blob._feed(arg, device_option)
   727 C.Blob.feed = _Blob_feed
   730 def _Tensor_to_torch(tensor):
   732     PyTorch tensor interop (TensorCPU methods)   735       workspace.Workspace.current.blobs['foo'].tensor().to_torch()   739     handle = tensor._tensor_impl_raw_handle()
   740     return torch._C._wrap_tensor_impl(handle)
   742 C.TensorCPU.to_torch = _Tensor_to_torch
def _Blob_to_torch(blob):
    # Only tensor-typed blobs can be exposed as torch.Tensors.
    if not blob.is_tensor():
        raise RuntimeError("Blob has to be a tensor")
    return blob.as_tensor().to_torch()


C.Blob.to_torch = _Blob_to_torch