3 from __future__ 
import absolute_import
     4 from __future__ 
import division
     5 from __future__ 
import print_function
     6 from __future__ 
import unicode_literals
     8 from caffe2.proto 
import caffe2_pb2
    10 from future.utils 
import viewitems
    11 from google.protobuf.message 
import DecodeError, Message
    12 from google.protobuf 
import text_format
    18 from six 
import integer_types, binary_type, text_type, string_types
    20 OPTIMIZER_ITERATION_NAME = 
"optimizer_iteration"    21 ITERATION_MUTEX_NAME = 
"iteration_mutex"    24 def OpAlmostEqual(op_a, op_b, ignore_fields=None):
    26     Two ops are identical except for each field in the `ignore_fields`.    28     ignore_fields = ignore_fields 
or []
    29     if not isinstance(ignore_fields, list):
    30         ignore_fields = [ignore_fields]
    32     assert all(isinstance(f, text_type) 
for f 
in ignore_fields), (
    33         'Expect each field is text type, but got {}'.format(ignore_fields))
    36         op = copy.deepcopy(op)
    37         for field 
in ignore_fields:
    38             if op.HasField(field):
    47 def CaffeBlobToNumpyArray(blob):
    50         return (np.asarray(blob.data, dtype=np.float32)
    51                 .reshape(blob.num, blob.channels, blob.height, blob.width))
    54         return (np.asarray(blob.data, dtype=np.float32)
    55                 .reshape(blob.shape.dim))
    58 def Caffe2TensorToNumpyArray(tensor):
    59     if tensor.data_type == caffe2_pb2.TensorProto.FLOAT:
    61             tensor.float_data, dtype=np.float32).reshape(tensor.dims)
    62     elif tensor.data_type == caffe2_pb2.TensorProto.DOUBLE:
    64             tensor.double_data, dtype=np.float64).reshape(tensor.dims)
    65     elif tensor.data_type == caffe2_pb2.TensorProto.INT32:
    67             tensor.int32_data, dtype=np.int).reshape(tensor.dims)   
    68     elif tensor.data_type == caffe2_pb2.TensorProto.INT16:
    70             tensor.int32_data, dtype=np.int16).reshape(tensor.dims)  
    71     elif tensor.data_type == caffe2_pb2.TensorProto.UINT16:
    73             tensor.int32_data, dtype=np.uint16).reshape(tensor.dims)  
    74     elif tensor.data_type == caffe2_pb2.TensorProto.INT8:
    76             tensor.int32_data, dtype=np.int8).reshape(tensor.dims)  
    77     elif tensor.data_type == caffe2_pb2.TensorProto.UINT8:
    79             tensor.int32_data, dtype=np.uint8).reshape(tensor.dims)  
    83             "Tensor data type not supported yet: " + str(tensor.data_type))
    86 def NumpyArrayToCaffe2Tensor(arr, name=None):
    87     tensor = caffe2_pb2.TensorProto()
    88     tensor.dims.extend(arr.shape)
    91     if arr.dtype == np.float32:
    92         tensor.data_type = caffe2_pb2.TensorProto.FLOAT
    93         tensor.float_data.extend(list(arr.flatten().astype(float)))
    94     elif arr.dtype == np.float64:
    95         tensor.data_type = caffe2_pb2.TensorProto.DOUBLE
    96         tensor.double_data.extend(list(arr.flatten().astype(np.float64)))
    97     elif arr.dtype == np.int 
or arr.dtype == np.int32:
    98         tensor.data_type = caffe2_pb2.TensorProto.INT32
    99         tensor.int32_data.extend(arr.flatten().astype(np.int).tolist())
   100     elif arr.dtype == np.int16:
   101         tensor.data_type = caffe2_pb2.TensorProto.INT16
   102         tensor.int32_data.extend(list(arr.flatten().astype(np.int16)))  
   103     elif arr.dtype == np.uint16:
   104         tensor.data_type = caffe2_pb2.TensorProto.UINT16
   105         tensor.int32_data.extend(list(arr.flatten().astype(np.uint16)))  
   106     elif arr.dtype == np.int8:
   107         tensor.data_type = caffe2_pb2.TensorProto.INT8
   108         tensor.int32_data.extend(list(arr.flatten().astype(np.int8)))   
   109     elif arr.dtype == np.uint8:
   110         tensor.data_type = caffe2_pb2.TensorProto.UINT8
   111         tensor.int32_data.extend(list(arr.flatten().astype(np.uint8)))   
   115             "Numpy data type not supported yet: " + str(arr.dtype))
   119 def MakeArgument(key, value):
   120     """Makes an argument based on the value type."""   121     argument = caffe2_pb2.Argument()
   123     iterable = isinstance(value, container_abcs.Iterable)
   129     if isinstance(value, np.ndarray) 
and value.dtype.type 
is np.float32:
   130         argument.floats.extend(value.flatten().tolist())
   133     if isinstance(value, np.ndarray):
   134         value = value.flatten().tolist()
   135     elif isinstance(value, np.generic):
   137         value = np.asscalar(value)
   139     if type(value) 
is float:
   141     elif type(value) 
in integer_types 
or type(value) 
is bool:
   145     elif isinstance(value, binary_type):
   147     elif isinstance(value, text_type):
   148         argument.s = value.encode(
'utf-8')
   149     elif isinstance(value, caffe2_pb2.NetDef):
   150         argument.n.CopyFrom(value)
   151     elif isinstance(value, Message):
   152         argument.s = value.SerializeToString()
   153     elif iterable 
and all(type(v) 
in [float, np.float_] 
for v 
in value):
   154         argument.floats.extend(
   155             v.item() 
if type(v) 
is np.float_ 
else v 
for v 
in value
   157     elif iterable 
and all(
   158         type(v) 
in integer_types 
or type(v) 
in [bool, np.int_] 
for v 
in value
   160         argument.ints.extend(
   161             v.item() 
if type(v) 
is np.int_ 
else v 
for v 
in value
   163     elif iterable 
and all(
   164         isinstance(v, binary_type) 
or isinstance(v, text_type) 
for v 
in value
   166         argument.strings.extend(
   167             v.encode(
'utf-8') 
if isinstance(v, text_type) 
else v
   170     elif iterable 
and all(isinstance(v, caffe2_pb2.NetDef) 
for v 
in value):
   171         argument.nets.extend(value)
   172     elif iterable 
and all(isinstance(v, Message) 
for v 
in value):
   173         argument.strings.extend(v.SerializeToString() 
for v 
in value)
   177                 "Unknown iterable argument type: key={} value={}, value "   178                 "type={}[{}]".format(
   179                     key, value, type(value), set(type(v) 
for v 
in value)
   184                 "Unknown argument type: key={} value={}, value type={}".format(
   185                     key, value, type(value)
   191 def TryReadProtoWithClass(cls, s):
   192     """Reads a protobuffer with the given proto class.   195       cls: a protobuffer class.   196       s: a string of either binary or text protobuffer content.   199       proto: the protobuffer of cls   202       google.protobuf.message.DecodeError: if we cannot decode the message.   206         text_format.Parse(s, obj)
   208     except text_format.ParseError:
   209         obj.ParseFromString(s)
   213 def GetContentFromProto(obj, function_map):
   214     """Gets a specific field from a protocol buffer that matches the given class   216     for cls, func 
in viewitems(function_map):
   221 def GetContentFromProtoString(s, function_map):
   222     for cls, func 
in viewitems(function_map):
   224             obj = TryReadProtoWithClass(cls, s)
   229         raise DecodeError(
"Cannot find a fit protobuffer class.")
   232 def ConvertProtoToBinary(proto_class, filename, out_filename):
   233     """Convert a text file of the given protobuf class to binary."""   234     with open(filename) 
as f:
   235         proto = TryReadProtoWithClass(proto_class, f.read())
   236     with open(out_filename, 
'w') 
as fid:
   237         fid.write(proto.SerializeToString())
   240 def GetGPUMemoryUsageStats():
   241     """Get GPU memory usage stats from CUDAContext/HIPContext. This requires flag   242        --caffe2_gpu_memory_tracking to be enabled"""   244     workspace.RunOperatorOnce(
   249             device_option=core.DeviceOption(workspace.GpuDeviceType, 0),
   252     b = workspace.FetchBlob(
"____mem____")
   254         'total_by_gpu': b[0, :],
   255         'max_by_gpu': b[1, :],
   256         'total': np.sum(b[0, :]),
   257         'max_total': np.sum(b[1, :])
   261 def ResetBlobs(blobs):
   263     workspace.RunOperatorOnce(
   268             device_option=core.DeviceOption(caffe2_pb2.CPU),
   275     This class allows to drop you into an interactive debugger   276     if there is an unhandled exception in your python script   284     if __name__ == '__main__':   285         from caffe2.python.utils import DebugMode   293         except KeyboardInterrupt:
   299                 'Entering interactive debugger. Type "bt" to print '   300                 'the full stacktrace. Type "help" to see command listing.')
   301             print(sys.exc_info()[1])
   309 def raiseIfNotEqual(a, b, msg):
   311         raise Exception(
"{}. {} != {}".format(msg, a, b))
   316     Use this method to decorate your function with DebugMode's functionality   322         raise Exception("Bar")   327     def wrapper(*args, **kwargs):
   329             return f(*args, **kwargs)
   330         return DebugMode.run(func)
   335 def BuildUniqueMutexIter(
   343     Often, a mutex guarded iteration counter is needed. This function creates a   344     mutex iter in the net uniquely (if the iter already existing, it does   347     This function returns the iter blob   349     iter = iter 
if iter 
is not None else OPTIMIZER_ITERATION_NAME
   350     iter_mutex = iter_mutex 
if iter_mutex 
is not None else ITERATION_MUTEX_NAME
   352     if not init_net.BlobIsDefined(iter):
   354         with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU)):
   355             iteration = init_net.ConstantFill(
   360                 dtype=core.DataType.INT64,
   362             iter_mutex = init_net.CreateMutex([], [iter_mutex])
   363             net.AtomicIter([iter_mutex, iteration], [iteration])
   365         iteration = init_net.GetBlobRef(iter)
   369 def EnumClassKeyVals(cls):
   371     assert type(cls) == type
   377             if isinstance(v, string_types):
   378                 assert v 
not in enum.values(), (
   379                     "Failed to resolve {} as Enum: "   380                     "duplicate entries {}={}, {}={}".format(
   381                         cls, k, v, [key 
for key 
in enum 
if enum[key] == v][0], v
   388 def ArgsToDict(args):
   390     Convert a list of arguments to a name, value dictionary. Assumes that   391     each argument has a name. Otherwise, the argument is skipped.   395         if not arg.HasField(
"name"):
   397         for d 
in arg.DESCRIPTOR.fields:
   400             if d.label == d.LABEL_OPTIONAL 
and arg.HasField(d.name):
   401                 ans[arg.name] = getattr(arg, d.name)
   403             elif d.label == d.LABEL_REPEATED:
   404                 list_ = getattr(arg, d.name)
   406                     ans[arg.name] = list_
   413 def NHWC2NCHW(tensor):
   414     assert tensor.ndim >= 1
   415     return tensor.transpose((0, tensor.ndim - 1) + tuple(range(1, tensor.ndim - 1)))
   418 def NCHW2NHWC(tensor):
   419     assert tensor.ndim >= 2
   420     return tensor.transpose((0,) + tuple(range(2, tensor.ndim)) + (1,))