Caffe2 - Python API
A deep learning, cross-platform ML framework
workspace.py
1 ## @package workspace
2 # Module caffe2.python.workspace
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 import collections
8 import contextlib
9 from google.protobuf.message import Message
10 from multiprocessing import Process
11 import os
12 from collections import defaultdict
13 import logging
14 import numpy as np
15 from past.builtins import basestring
16 import shutil
17 import socket
18 import tempfile
19 
20 from caffe2.proto import caffe2_pb2
21 from caffe2.python import scope, utils
22 
23 import caffe2.python._import_c_extension as C
24 
25 logger = logging.getLogger(__name__)
26 
27 Blobs = C.blobs
28 CreateBlob = C.create_blob
29 CurrentWorkspace = C.current_workspace
30 DeserializeBlob = C.deserialize_blob
31 GlobalInit = C.global_init
32 HasBlob = C.has_blob
33 RegisteredOperators = C.registered_operators
34 SerializeBlob = C.serialize_blob
35 SwitchWorkspace = C.switch_workspace
36 RootFolder = C.root_folder
37 Workspaces = C.workspaces
38 BenchmarkNet = C.benchmark_net
39 GetStats = C.get_stats
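
These names are thin aliases over the pybind11 C extension (imported as C). For illustration, a minimal usage sketch that is not part of workspace.py; it assumes caffe2 is installed and the module is imported as workspace:

    # Hypothetical usage sketch -- not part of workspace.py.
    import numpy as np
    from caffe2.python import workspace

    workspace.ResetWorkspace()                   # start from an empty workspace
    workspace.FeedBlob("x", np.ones((2, 3), dtype=np.float32))
    print(workspace.Blobs())                     # ['x']
    print(workspace.HasBlob("x"))                # True
    print(workspace.CurrentWorkspace())          # 'default'
    workspace.SwitchWorkspace("scratch", True)   # create the workspace if missing
    print(workspace.HasBlob("x"))                # False: blobs are per-workspace
    workspace.SwitchWorkspace("default")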
40 
41 operator_tracebacks = defaultdict(dict)
42 
43 is_asan = C.is_asan
44 has_cuda_support = C.has_cuda_support
45 has_hip_support = C.has_hip_support
46 has_gpu_support = C.has_gpu_support
47 if has_cuda_support:
48     GpuDeviceType = caffe2_pb2.CUDA
49     NumCudaDevices = C.num_cuda_devices
50     # This is a duplicate of NumCudaDevices. Remove
51     # NumCudaDevices once replaced everywhere in the code
52     NumGpuDevices = C.num_cuda_devices
53     GetCUDAVersion = C.get_cuda_version
54     GetCuDNNVersion = C.get_cudnn_version
55 
56     def GetGpuPeerAccessPattern():
57         return np.asarray(C.get_cuda_peer_access_pattern())
58 
59     GetDeviceProperties = C.get_device_properties
60 else:
61     NumCudaDevices = lambda: 0  # noqa
62     GetCUDAVersion = lambda: 0  # noqa
63     GetCuDNNVersion = lambda: 0  # noqa
64 
65 if has_hip_support:
66     GpuDeviceType = caffe2_pb2.HIP
67     NumGpuDevices = C.num_hip_devices
68 
69     def GetGpuPeerAccessPattern():
70         return np.asarray(C.get_hip_peer_access_pattern())
71     GetDeviceProperties = C.get_device_properties
72 
73 if not has_gpu_support:
74     # setting cuda as the default GpuDeviceType as some tests
75     # like core, scope tests use GpuDeviceType even without gpu support
76     GpuDeviceType = caffe2_pb2.CUDA
77     NumGpuDevices = lambda: 0  # noqa
78     GetDeviceProperties = lambda x: None  # noqa
79     GetGpuPeerAccessPattern = lambda: np.array([])  # noqa
80 
81 IsNUMAEnabled = C.is_numa_enabled
82 GetNumNUMANodes = C.get_num_numa_nodes
83 GetBlobNUMANode = C.get_blob_numa_node
84 GetBlobSizeBytes = C.get_blob_size_bytes
85 
86 def _GetFreeFlaskPort():
87     """Gets a free port for the Flask-based mint server."""
88     # Prefer port 5000; if it is already in use, pick a random free port.
89     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
90     result = sock.connect_ex(('127.0.0.1', 5000))
91     if result != 0:  # connect_ex() returns 0 only if something already listens on 5000
92         return 5000
93     else:
94         s = socket.socket()
95         s.bind(('', 0))
96         port = s.getsockname()[1]
97         s.close()
98         # Race condition: in the interval between closing the socket and
99         # actually starting the mint process, another process might claim
100         # the port. We don't guard against this; mint is mostly a research
101         # convenience rather than a 24x7 service.
102         return port
103 
104 def StartMint(root_folder=None, port=None):
105     """Start a mint instance.
106 
107     TODO(Yangqing): this does not work well under ipython yet. According to
108     https://github.com/ipython/ipython/issues/5862
109     writing up some fix is a todo item.
110     """
111     from caffe2.python.mint import app
112     if root_folder is None:
113         # Get the root folder from the current workspace
114         root_folder = C.root_folder()
115     if port is None:
116         port = _GetFreeFlaskPort()
117     process = Process(
118         target=app.main,
119         args=(
120             ['-p', str(port), '-r', root_folder],
121         )
122     )
123     process.start()
124     print('Mint running at http://{}:{}'.format(socket.getfqdn(), port))
125     return process
126 
127 
128 def StringifyProto(obj):
129     """Stringify a protocol buffer object.
130 
131     Inputs:
132         obj: a protocol buffer object, or a Pycaffe2 object that has a Proto()
133             function.
134     Outputs:
135         string: the output protobuf string.
136     Raises:
137         ValueError: if the object is not a string, a Message, or an object with a Proto() method.
138     """
139     if isinstance(obj, basestring):
140         return obj
141     else:
142         if isinstance(obj, Message):
143             # First, see if this object is a protocol buffer, which we can
144             # simply serialize with the SerializeToString() call.
145             return obj.SerializeToString()
146         elif hasattr(obj, 'Proto'):
147             return obj.Proto().SerializeToString()
148         else:
149             raise ValueError("Unexpected argument to StringifyProto of type " +
150                              type(obj).__name__)
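
A small sketch of the kinds of input StringifyProto accepts (a string, a protobuf Message, or an object with a Proto() method such as core.Net); illustrative only, not part of workspace.py:

    # Hypothetical usage sketch -- not part of workspace.py.
    from caffe2.python import core, workspace

    net = core.Net("example")                   # core.Net exposes a Proto() method
    net.ConstantFill([], ["c"], shape=[1], value=1.0)

    s1 = workspace.StringifyProto(net)          # object with Proto(): Proto().SerializeToString()
    s2 = workspace.StringifyProto(net.Proto())  # protobuf Message: SerializeToString()
    assert s1 == s2                             # a plain string would be returned unchanged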
151 
152 
153 def ResetWorkspace(root_folder=None):
154     if root_folder is None:
155         # Reset the workspace, but keep the current root folder setting.
156         return C.reset_workspace(C.root_folder())
157     else:
158         if not os.path.exists(root_folder):
159             os.makedirs(root_folder)
160         return C.reset_workspace(root_folder)
161 
162 
163 def CreateNet(net, overwrite=False, input_blobs=None):
164     if input_blobs is None:
165         input_blobs = []
166     for input_blob in input_blobs:
167         C.create_blob(input_blob)
168     return CallWithExceptionIntercept(
169         C.create_net,
170         C.Workspace.current._last_failed_op_net_position,
171         GetNetName(net),
172         StringifyProto(net), overwrite,
173     )
174 
175 
176 def Predictor(init_net, predict_net):
177  return C.Predictor(StringifyProto(init_net), StringifyProto(predict_net))
178 
179 
180 def GetOperatorCost(operator, blobs):
181  return C.get_operator_cost(StringifyProto(operator), blobs)
182 
183 
184 def RunOperatorOnce(operator):
185  return C.run_operator_once(StringifyProto(operator))
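
A sketch of running a single operator built with core.CreateOperator; illustrative, not part of workspace.py:

    # Hypothetical usage sketch -- not part of workspace.py.
    import numpy as np
    from caffe2.python import core, workspace

    workspace.FeedBlob("X", np.random.randn(4, 4).astype(np.float32))
    op = core.CreateOperator("Relu", ["X"], ["Y"])   # a single OperatorDef
    assert workspace.RunOperatorOnce(op)
    print(workspace.FetchBlob("Y"))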
186 
187 
188 def RunOperatorsOnce(operators):
189     for op in operators:
190         success = RunOperatorOnce(op)
191         if not success:
192             return False
193     return True
194 
195 
196 def CallWithExceptionIntercept(func, op_id_fetcher, net_name, *args, **kwargs):
197     try:
198         return func(*args, **kwargs)
199     except Exception:
200         op_id = op_id_fetcher()
201         net_tracebacks = operator_tracebacks.get(net_name, None)
202         logger.warning(
203             'Original python traceback for operator `{}` in network '
204             '`{}` in exception above (most recent call last):'.format(
205                 op_id, net_name))
206         if net_tracebacks and op_id in net_tracebacks:
207             tb = net_tracebacks[op_id]
208             for line in reversed(tb):
209                 logger.warning('  File "{}", line {}, in {}'.format(
210                     line[0], line[1], line[2]))
211         raise
212 
213 
214 def RunNetOnce(net):
215     return CallWithExceptionIntercept(
216         C.run_net_once,
217         C.Workspace.current._last_failed_op_net_position,
218         GetNetName(net),
219         StringifyProto(net),
220     )
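
A sketch of the create-run-discard semantics of RunNetOnce, using a self-contained core.Net; illustrative only, not part of workspace.py:

    # Hypothetical usage sketch -- not part of workspace.py.
    from caffe2.python import core, workspace

    net = core.Net("example_net")
    net.GaussianFill([], ["X"], shape=[2, 3], mean=0.0, std=1.0)
    net.Relu(["X"], ["Y"])
    workspace.RunNetOnce(net)        # instantiate, run once, then discard the net
    print(workspace.FetchBlob("Y"))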
221 
222 
223 def RunNet(name, num_iter=1, allow_fail=False):
224     """Runs a given net.
225 
226     Inputs:
227         name: the name of the net, or a reference to the net.
228         num_iter: number of iterations to run
229         allow_fail: if True, does not assert on net exec failure but returns False
230     Returns:
231         True on success, False if allow_fail is True and the run failed; otherwise an exception is raised.
232     """
233     return CallWithExceptionIntercept(
234         C.run_net,
235         C.Workspace.current._last_failed_op_net_position,
236         GetNetName(name),
237         StringifyNetName(name), num_iter, allow_fail,
238     )
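
In contrast to RunNetOnce, the usual pattern is CreateNet once and then RunNet repeatedly; a sketch (illustrative, not part of workspace.py):

    # Hypothetical usage sketch -- not part of workspace.py.
    from caffe2.python import core, workspace

    net = core.Net("relu_loop")
    net.GaussianFill([], ["X"], shape=[4], mean=0.0, std=1.0, run_once=False)
    net.Relu(["X"], ["Y"])
    workspace.CreateNet(net)              # build the C++ net object once
    workspace.RunNet(net, num_iter=10)    # then execute it ten times
    print(workspace.FetchBlob("Y"))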
239 
240 
241 def RunPlan(plan_or_step):
242     # TODO(jiayq): refactor core.py/workspace.py to avoid circular deps
243     import caffe2.python.core as core
244     if isinstance(plan_or_step, core.ExecutionStep):
245         plan_or_step = core.Plan(plan_or_step)
246     return C.run_plan(StringifyProto(plan_or_step))
247 
248 
249 def RunPlanInBackground(plan_or_step):
250     # TODO(jiayq): refactor core.py/workspace.py to avoid circular deps
251     import caffe2.python.core as core
252     if isinstance(plan_or_step, core.ExecutionStep):
253         plan_or_step = core.Plan(plan_or_step)
254     return C.run_plan_in_background(StringifyProto(plan_or_step))
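
A sketch of building a Plan from an execution step and running it; the helper names follow core.py (core.Plan, core.execution_step) and the example is illustrative rather than canonical:

    # Hypothetical usage sketch -- not part of workspace.py.
    from caffe2.python import core, workspace

    net = core.Net("plan_body")
    net.ConstantFill([], ["c"], shape=[1], value=1.0)

    plan = core.Plan("example_plan")
    plan.AddStep(core.execution_step("main_step", net, num_iter=5))
    workspace.RunPlan(plan)        # a bare ExecutionStep would be wrapped in a Plan first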
255 
256 
257 def InferShapesAndTypes(nets, blob_dimensions=None, nets_proto=False,
258                         blob_types=None):
259     """Infers the shapes and types for the specified nets.
260 
261     Inputs:
262         nets: the list of nets
263         blob_dimensions (optional): a dictionary of blobs and their dimensions.
264             If not specified, the workspace blobs are used.
265         nets_proto (optional): whether nets holds NetDef protos rather than core.Net objects.
266         blob_types (optional): a dictionary of blobs and their data types; requires blob_dimensions.
267     Returns:
268         A tuple of (shapes, types) dictionaries keyed by blob name.
269     """
270     if nets_proto:
271         net_protos = [StringifyProto(n) for n in nets]
272     else:
273         net_protos = [StringifyProto(n.Proto()) for n in nets]
274     if blob_dimensions is None:
275         assert blob_types is None
276         blobdesc_prototxt = C.infer_shapes_and_types_from_workspace(net_protos)
277     elif blob_types is None:
278         blobdesc_prototxt = C.infer_shapes_and_types_from_map(
279             net_protos, blob_dimensions
280         )
281     else:
282         blobdesc_prototxt = C.infer_shapes_and_types_from_map(
283             net_protos, blob_dimensions, blob_types
284         )
285     blobdesc_proto = caffe2_pb2.TensorShapes()
286     blobdesc_proto.ParseFromString(blobdesc_prototxt)
287     shapes = {}
288     types = {}
289     for ts in blobdesc_proto.shapes:
290         if not ts.unknown_shape:
291             shapes[ts.name] = list(ts.dims)
292             types[ts.name] = ts.data_type
293 
294     return (shapes, types)
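
A sketch of shape inference driven by a blob_dimensions map rather than the workspace contents; the concrete output shape shown is only what one would expect if FC inference succeeds:

    # Hypothetical usage sketch -- not part of workspace.py.
    from caffe2.python import core, workspace

    net = core.Net("shape_net")
    net.FC(["data", "w", "b"], ["fc1"])

    shapes, types = workspace.InferShapesAndTypes(
        [net],
        blob_dimensions={"data": [16, 100], "w": [10, 100], "b": [10]},
    )
    print(shapes.get("fc1"))   # expected to be [16, 10] when inference succeeds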
295 
296 
297 def _StringifyName(name, expected_type):
298     if isinstance(name, basestring):
299         return name
300     assert type(name).__name__ == expected_type, \
301         "Expected a string or %s" % expected_type
302     return str(name)
303 
304 
305 def StringifyBlobName(name):
306  return _StringifyName(name, "BlobReference")
307 
308 
309 def StringifyNetName(name):
310  return _StringifyName(name, "Net")
311 
312 
313 def GetNetName(net):
314     if isinstance(net, basestring):
315         return net
316     if type(net).__name__ == "Net":
317         return net.Name()
318     if isinstance(net, caffe2_pb2.NetDef):
319         return net.name
320     raise Exception("Not a Net object: {}".format(str(net)))
321 
322 
323 def FeedBlob(name, arr, device_option=None):
324  """Feeds a blob into the workspace.
325 
326  Inputs:
327  name: the name of the blob.
328  arr: either a TensorProto object or a numpy array object to be fed into
329  the workspace.
330  device_option (optional): the device option to feed the data with.
331  Returns:
332  True or False, stating whether the feed is successful.
333  """
334  ws = C.Workspace.current
335  return _Workspace_feed_blob(ws, name, arr, device_option)
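
A sketch of a feed/fetch round trip, including an explicit device option built with core.DeviceOption; illustrative, not part of workspace.py:

    # Hypothetical usage sketch -- not part of workspace.py.
    import numpy as np
    from caffe2.python import core, workspace
    from caffe2.proto import caffe2_pb2

    x = np.random.rand(3, 4).astype(np.float32)   # prefer float32 over float64
    workspace.FeedBlob("x", x)                                         # default device
    workspace.FeedBlob("x_cpu", x, core.DeviceOption(caffe2_pb2.CPU))  # explicit device
    np.testing.assert_array_equal(workspace.FetchBlob("x"), x)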
336 
337 
338 def FetchBlobs(names):
339  """Fetches a list of blobs from the workspace.
340 
341  Inputs:
342  names: list of names of blobs - strings or BlobReferences
343  Returns:
344  list of fetched blobs
345  """
346  return [FetchBlob(name) for name in names]
347 
348 
349 def FetchBlob(name):
350     """Fetches a blob from the workspace.
351 
352     Inputs:
353         name: the name of the blob - a string or a BlobReference
354     Returns:
355         Fetched blob (numpy array or string) if successful
356     """
357     result = C.fetch_blob(StringifyBlobName(name))
358     if isinstance(result, tuple):
359         raise TypeError(
360             "Use FetchInt8Blob to fetch Int8 Blob {}".format(
361                 StringifyBlobName(name)
362             )
363         )
364     return result
365 
366 
367 def FetchTorch(name):
368  ws = C.Workspace.current
369  return ws.blobs[name].to_torch()
370 
371 
372 Int8Tensor = collections.namedtuple(
373  'Int8Tensor', ['data', 'scale', 'zero_point']
374 )
375 
376 
377 def FetchInt8Blob(name):
378     """Fetches an Int8 blob from the workspace. It shares the backend
379     implementation with FetchBlob, but is recommended when fetching Int8 blobs.
380 
381     Inputs:
382         name: the name of the Int8 blob - a string or a BlobReference
383     Returns:
384         data: int8 numpy array, data
385         scale: float, fake quantization scale
386         zero_point: int, fake quantization offset
387     """
388     result = C.fetch_blob(StringifyBlobName(name))
389     assert isinstance(result, tuple), \
390         'You are not fetching an Int8Blob {}. Please use FetchBlob'.format(
391             StringifyBlobName(name))
392     return Int8Tensor(*result)
393 
394 
395 def FetchInt8BlobRealVal(name):
396     """Fetches an Int8 blob from the workspace and returns its real-value representation.
397 
398     Inputs:
399         name: the name of the Int8 blob - a string or a BlobReference
400     Returns:
401         real value representation of the int8 numpy array
402     """
403     result = C.fetch_blob(StringifyBlobName(name))
404     assert isinstance(result, tuple), \
405         'You are not fetching an Int8Blob {}. Please use FetchBlob'.format(
406             StringifyBlobName(name))
407     int8_blob = Int8Tensor(*result)
408     return (int8_blob.data.astype(np.int32) - int(int8_blob.zero_point)).astype(
409         np.float32) * int8_blob.scale
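
The dequantization above is just (data - zero_point) * scale; a standalone numpy sketch of the same arithmetic with made-up values (no workspace blob involved):

    # Hypothetical values -- mirrors the arithmetic of FetchInt8BlobRealVal.
    import numpy as np

    data = np.array([0, 128, 255], dtype=np.uint8)   # quantized storage
    scale, zero_point = 0.05, 128                    # fake-quantization parameters
    real = (data.astype(np.int32) - int(zero_point)).astype(np.float32) * scale
    print(real)   # [-6.4   0.    6.35]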
410 
411 
412 def _Workspace_fetch_int8_blob(ws, name):
413     """Fetches an Int8 blob from the workspace. It shares the backend
414     implementation with FetchBlob, but is recommended when fetching Int8 blobs.
415 
416     Inputs:
417         name: the name of the Int8 blob - a string or a BlobReference
418     Returns:
419         data: int8 numpy array, data
420         scale: float, fake quantization scale
421         zero_point: int, fake quantization offset
422     """
423     result = ws.fetch_blob(name)
424     assert isinstance(result, tuple), \
425         'You are not fetching an Int8Blob {}. Please use fetch_blob'.format(
426             StringifyBlobName(name))
427     return Int8Tensor(*result)
428 
429 
430 C.Workspace.fetch_int8_blob = _Workspace_fetch_int8_blob
431 
432 
433 def ApplyTransform(transform_key, net):
434     """Applies a Transform to a NetDef protobuf object and returns the new,
435     transformed NetDef.
436 
437     Inputs:
438         transform_key: the name of the transform, as it is stored in the registry
439         net: a NetDef protobuf object
440     Returns:
441         Transformed NetDef protobuf object.
442     """
443     transformed_net = caffe2_pb2.NetDef()
444     transformed_str = C.apply_transform(
445         str(transform_key).encode('utf-8'),
446         net.SerializeToString(),
447     )
448     transformed_net.ParseFromString(transformed_str)
449     return transformed_net
450 
451 
452 def ApplyTransformIfFaster(transform_key, net, init_net, **kwargs):
453     """Applies a Transform to a NetDef protobuf object and returns the new,
454     transformed NetDef, but only if it runs faster than the original.
455 
456     The runs are performed on the currently active workspace (gWorkspace).
457     You should initialize that workspace before making a call to this function.
458 
459     Inputs:
460         transform_key: the name of the transform, as it is stored in the registry
461         net: a NetDef protobuf object
462         init_net: the net used to initialize the workspace.
463         warmup_runs (optional):
464             Determines how many times the net is run before testing.
465             Will be 5 by default.
466         main_runs (optional):
467             Determines how many times the net is run during testing.
468             Will be 10 by default.
469         improvement_threshold (optional):
470             Determines the factor by which the new net needs to be faster
471             in order to replace the old. Will be 1.01 by default.
472 
473     Returns:
474         Either a transformed NetDef protobuf object, or the original NetDef.
475     """
476 
477     warmup_runs = kwargs['warmup_runs'] if 'warmup_runs' in kwargs else 5
478     main_runs = kwargs['main_runs'] if 'main_runs' in kwargs else 10
479     improvement_threshold = kwargs['improvement_threshold'] \
480         if 'improvement_threshold' in kwargs else 1.01
481 
482     transformed_net = caffe2_pb2.NetDef()
483     transformed_str = C.apply_transform_if_faster(
484         str(transform_key).encode('utf-8'),
485         net.SerializeToString(),
486         init_net.SerializeToString(),
487         warmup_runs,
488         main_runs,
489         float(improvement_threshold),
490     )
491     transformed_net.ParseFromString(transformed_str)
492     return transformed_net
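
A sketch of how the transform helpers are typically invoked. "SomeRegisteredTransform" is a placeholder; an actual key registered in the C++ transform registry must be used instead:

    # Hypothetical usage sketch -- "SomeRegisteredTransform" is a placeholder key.
    from caffe2.python import core, workspace

    net = core.Net("to_transform")
    net.ConstantFill([], ["c"], shape=[1], value=1.0)
    init_net = core.Net("init")

    new_netdef = workspace.ApplyTransformIfFaster(
        "SomeRegisteredTransform", net.Proto(), init_net.Proto(),
        warmup_runs=2, main_runs=5, improvement_threshold=1.05,
    )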
493 
494 
495 def GetNameScope():
496     """Returns the current namescope string, to be used when fetching blobs."""
497     return scope.CurrentNameScope()
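
Name scopes are set via core.NameScope; a sketch showing how GetNameScope reflects the active scope (illustrative, not part of workspace.py):

    # Hypothetical usage sketch -- not part of workspace.py.
    from caffe2.python import core, workspace

    with core.NameScope("encoder"):
        print(workspace.GetNameScope())       # 'encoder/'
        w = core.ScopedBlobReference("w")     # resolves to 'encoder/w'
    print(workspace.GetNameScope())           # '' once the scope is exited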
498 
499 
500 class _BlobDict(object):
501     """Provides a Python-dict-compatible way of fetching and feeding blobs."""
502 
503     def __getitem__(self, key):
504         return FetchBlob(key)
505 
506     def __setitem__(self, key, value):
507         return FeedBlob(key, value)
508 
509     def __len__(self):
510         return len(C.blobs())
511 
512     def __iter__(self):
513         return C.blobs().__iter__()
514 
515     def __contains__(self, item):
516         return C.has_blob(item)
517 
518 
519 blobs = _BlobDict()
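
A sketch of the dict-style access this object provides over FeedBlob/FetchBlob/HasBlob (illustrative, not part of workspace.py):

    # Hypothetical usage sketch -- not part of workspace.py.
    import numpy as np
    from caffe2.python import workspace

    workspace.blobs["x"] = np.arange(6, dtype=np.float32).reshape(2, 3)  # FeedBlob
    print("x" in workspace.blobs)    # True, via HasBlob
    print(workspace.blobs["x"])      # FetchBlob
    print(len(workspace.blobs), list(workspace.blobs))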
520 
521 
522 ################################################################################
523 # Utilities for immediate mode
524 #
525 # Caffe2's immediate mode implements the following behavior: between the two
526 # function calls StartImmediate() and StopImmediate(), for any operator that is
527 # called through CreateOperator(), we will also run that operator in a workspace
528 # that is specific to the immediate mode. The user is explicitly expected to
529 # make sure that these ops have proper inputs and outputs, i.e. one should not
530 # run an op where an external input is not created or fed.
531 #
532 # Users can use FeedImmediate() and FetchImmediate() to interact with blobs
533 # in the immediate workspace.
534 #
535 # Once StopImmediate() is called, all contents in the immediate workspace are
536 # freed up so one can continue using normal runs.
537 #
538 # The immediate mode is solely for debugging purposes and support will be very
539 # sparse.
540 ################################################################################
541 
542 _immediate_mode = False
543 _immediate_workspace_name = "_CAFFE2_IMMEDIATE"
544 _immediate_root_folder = ''
545 
546 
547 def IsImmediate():
548  return _immediate_mode
549 
550 
551 @contextlib.contextmanager
552 def WorkspaceGuard(workspace_name):
553  current = CurrentWorkspace()
554  SwitchWorkspace(workspace_name, True)
555  yield
556  SwitchWorkspace(current)
557 
558 
559 def StartImmediate(i_know=False):
560     global _immediate_mode
561     global _immediate_root_folder
562     if IsImmediate():
563         # already in immediate mode. We will kill the previous one
564         # and start from fresh.
565         StopImmediate()
566     _immediate_mode = True
567     with WorkspaceGuard(_immediate_workspace_name):
568         _immediate_root_folder = tempfile.mkdtemp()
569         ResetWorkspace(_immediate_root_folder)
570     if i_know:
571         # if the user doesn't want to see the warning message, sure...
572         return
573     print("""
574  Enabling immediate mode in caffe2 python is an EXTREMELY EXPERIMENTAL
575  feature and may very easily go wrong. This is because Caffe2 uses a
576  declarative way of defining operators and models, which is essentially
577  not meant to run things in an interactive way. Read the following carefully
578  to make sure that you understand the caveats.
579 
580  (1) You need to make sure that the sequences of operators you create are
581  actually runnable sequentially. For example, if you create an op that takes
582  an input X, somewhere earlier you should have already created X.
583 
584  (2) Caffe2 immediate uses one single workspace, so if the set of operators
585  you run are intended to be under different workspaces, they will not run.
586  To create boundaries between such use cases, you can call FinishImmediate()
587  and StartImmediate() manually to flush out everything no longer needed.
588 
589  (3) Underlying objects held by the immediate mode may interfere with your
590  normal run. For example, if there is a leveldb that you opened in immediate
591  mode and did not close, your main run will fail because leveldb does not
592  support double opening. Immediate mode may also occupy a lot of memory esp.
593  on GPUs. Call FinishImmediate() as soon as possible when you no longer
594  need it.
595 
596  (4) Immediate is designed to be slow. Every immediate call implicitly
597  creates a temp operator object, runs it, and destroys the operator. This
598  slow-speed run is by design to discourage abuse. For most use cases other
599  than debugging, do NOT turn on immediate mode.
600 
601  (5) If there is anything FATAL happening in the underlying C++ code, the
602  immediate mode will immediately (pun intended) cause the runtime to crash.
603 
604  Thus you should use immediate mode with extra care. If you still would
605  like to, have fun [https://xkcd.com/149/].
606  """)
607 
608 
609 def StopImmediate():
610     """Stops an immediate mode run."""
611     # Phew, that was a dangerous ride.
612     global _immediate_mode
613     global _immediate_root_folder
614     if not IsImmediate():
615         return
616     with WorkspaceGuard(_immediate_workspace_name):
617         ResetWorkspace()
618     shutil.rmtree(_immediate_root_folder)
619     _immediate_root_folder = ''
620     _immediate_mode = False
621 
622 
623 def ImmediateBlobs():
624     with WorkspaceGuard(_immediate_workspace_name):
625         return Blobs()
626 
627 
628 def RunOperatorImmediate(op):
629     with WorkspaceGuard(_immediate_workspace_name):
630         RunOperatorOnce(op)
631 
632 
633 def FetchImmediate(*args, **kwargs):
634     with WorkspaceGuard(_immediate_workspace_name):
635         return FetchBlob(*args, **kwargs)
636 
637 
638 def FeedImmediate(*args, **kwargs):
639     with WorkspaceGuard(_immediate_workspace_name):
640         return FeedBlob(*args, **kwargs)
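
A sketch of a complete immediate-mode session using only the helpers defined above (illustrative, not part of workspace.py):

    # Hypothetical usage sketch -- not part of workspace.py.
    import numpy as np
    from caffe2.python import core, workspace

    workspace.StartImmediate(i_know=True)      # i_know=True suppresses the banner
    workspace.FeedImmediate("X", np.random.randn(2, 2).astype(np.float32))
    workspace.RunOperatorImmediate(core.CreateOperator("Relu", ["X"], ["Y"]))
    print(workspace.FetchImmediate("Y"))
    workspace.StopImmediate()                  # frees the immediate workspace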
641 
642 
643 # C.Workspace methods.
644 
645 def _Workspace_create_net_with_exception_intercept(ws, net, overwrite=False):
646     return CallWithExceptionIntercept(
647         ws._create_net,
648         ws._last_failed_op_net_position,
649         GetNetName(net),
650         StringifyProto(net), overwrite,
651     )
652 
653 
654 def _Workspace_run(ws, obj):
655     if hasattr(obj, 'Proto'):
656         obj = obj.Proto()
657     if isinstance(obj, caffe2_pb2.PlanDef):
658         return ws._run_plan(obj.SerializeToString())
659     if isinstance(obj, caffe2_pb2.NetDef):
660         return CallWithExceptionIntercept(
661             ws._run_net,
662             ws._last_failed_op_net_position,
663             GetNetName(obj),
664             obj.SerializeToString(),
665         )
666         # return ws._run_net(obj.SerializeToString())
667     if isinstance(obj, caffe2_pb2.OperatorDef):
668         return ws._run_operator(obj.SerializeToString())
669     raise ValueError(
670         "Don't know how to do Workspace.run() on {}".format(type(obj)))
671 
672 
673 def _Workspace_feed_blob(ws, name, arr, device_option=None):
674     if type(arr) is caffe2_pb2.TensorProto:
675         arr = utils.Caffe2TensorToNumpyArray(arr)
676     if type(arr) is np.ndarray and arr.dtype.kind in 'SU':
677         # Plain NumPy strings are weird, let's use objects instead
678         arr = arr.astype(np.object)
679 
680     if device_option is None:
681         device_option = scope.CurrentDeviceScope()
682 
683     if device_option and device_option.device_type == caffe2_pb2.CUDA:
684         if arr.dtype == np.dtype('float64'):
685             logger.warning(
686                 "CUDA operators do not support 64-bit doubles, " +
687                 "please use arr.astype(np.float32) or np.int32 for ints." +
688                 " Blob: {}".format(name) +
689                 " type: {}".format(str(arr.dtype))
690             )
691 
692     name = StringifyBlobName(name)
693     if device_option is not None:
694         return ws.create_blob(name).feed(arr, device_option)
695     else:
696         return ws.create_blob(name).feed(arr)
697 
698 
699 def _Workspace_remove_blob(ws, blob):
700  ws._remove_blob(str(blob))
701 
702 
703 Workspace = C.Workspace
704 Workspace.create_net = _Workspace_create_net_with_exception_intercept
705 Workspace.run = _Workspace_run
706 Workspace.feed_blob = _Workspace_feed_blob
707 Workspace.remove_blob = _Workspace_remove_blob
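
With these methods attached, a C.Workspace object can be used much like the module-level helpers but scoped to its own private workspace. A sketch, assuming C.Workspace can be constructed directly (it is exposed here as workspace.Workspace):

    # Hypothetical usage sketch -- assumes Workspace() is directly constructible.
    import numpy as np
    from caffe2.python import core, workspace

    net = core.Net("ws_net")
    net.Relu(["X"], ["Y"])

    ws = workspace.Workspace()                 # private workspace, not the global one
    ws.feed_blob("X", np.random.randn(2, 2).astype(np.float32))
    ws.run(net)                                # dispatches to _run_net for a NetDef
    print(ws.fetch_blob("Y"))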
708 
709 # C.Blob methods.
710 
711 
712 def _Blob_feed(blob, arg, device_option=None):
713     # conservative type check to avoid unnecessary import
714     if type(arg).__name__ == 'Tensor' and type(arg).__module__ == 'torch':
715         import torch
716         if isinstance(arg, torch.Tensor):
717             assert device_option is None, \
718                 "device_option doesn't make sense with PyTorch tensors"
719             handle = torch._C._tensor_impl_raw_handle(arg)
720             blob._wrap_tensor_impl(handle)
721             return True  # _feed() returns True for some reason
722     if device_option is not None:
723         device_option = StringifyProto(device_option)
724     return blob._feed(arg, device_option)
725 
726 
727 C.Blob.feed = _Blob_feed
728 
729 
730 def _Tensor_to_torch(tensor):
731  """
732  PyTorch tensor interop (TensorCPU methods)
733 
734  Can be accessed as:
735  workspace.Workspace.current.blobs['foo'].tensor().to_torch()
736  """
737  # avoiding circular dependency
738  import torch
739  handle = tensor._tensor_impl_raw_handle()
740  return torch._C._wrap_tensor_impl(handle)
741 
742 C.TensorCPU.to_torch = _Tensor_to_torch
743 
744 
745 def _Blob_to_torch(blob):
746     if not blob.is_tensor():
747         raise RuntimeError("Blob has to be a tensor")
748     return blob.as_tensor().to_torch()
749 
750 C.Blob.to_torch = _Blob_to_torch
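
A sketch of the PyTorch interop path; it assumes a build in which PyTorch and Caffe2 share the same TensorImpl (the combined package), otherwise the raw-handle calls are unavailable:

    # Hypothetical usage sketch -- requires a combined PyTorch/Caffe2 build.
    import numpy as np
    import torch
    from caffe2.python import workspace

    workspace.FeedBlob("foo", np.ones((2, 2), dtype=np.float32))
    t = workspace.FetchTorch("foo")             # wraps the blob's TensorImpl, no copy
    print(type(t), t.shape)

    workspace.FeedBlob("bar", torch.zeros(3))   # torch.Tensor goes through _Blob_feed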