Caffe2 - Python API
A deep learning, cross-platform ML framework
tools.py
1 ## @package tools
2 # Module caffe2.python.helpers.tools
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 
def image_input(
    model, blob_in, blob_out, order="NCHW", use_gpu_transform=False, **kwargs
):
    """Add an ``ImageInput`` operator to ``model.net`` in the requested order.

    Behaviour by configuration:

    * ``order != "NCHW"``: ``ImageInput`` is added writing directly to
      ``blob_out``; ``use_gpu_transform`` is not forwarded.
    * ``order == "NCHW"`` with a truthy ``use_gpu_transform``:
      ``use_gpu_transform=1`` is forwarded to ``ImageInput`` and no separate
      layout conversion operator is added.
    * ``order == "NCHW"`` otherwise: the first output of ``ImageInput`` is
      written to ``blob_out[0] + '_nhwc'`` and an ``NHWC2NCHW`` operator
      converts it into ``blob_out[0]``.

    Args:
        model: Object exposing a ``net`` attribute with ``ImageInput`` and
            ``NHWC2NCHW`` operator builders.
        blob_in: Input blob(s), passed to ``ImageInput`` unchanged.
        blob_out: Sequence (list or tuple) of output blobs; element 0 is the
            one subject to layout conversion.
        order: Target layout name; only the exact value ``"NCHW"`` triggers
            the conversion logic.
        use_gpu_transform: Truthy to let the operator perform the layout
            change itself (NCHW only).
        **kwargs: Extra operator arguments forwarded to ``ImageInput``.
            Must contain ``is_test``.

    Returns:
        Whatever ``model.net.ImageInput`` returns, except in the
        CPU-converted NCHW case, where a tuple is returned whose first
        element is the result of ``NHWC2NCHW``.

    Raises:
        AssertionError: If ``is_test`` is missing from ``kwargs``.
    """
    if 'is_test' not in kwargs:
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type is unchanged.
        raise AssertionError("Argument 'is_test' is required")

    if order != "NCHW":
        return model.net.ImageInput(blob_in, blob_out, **kwargs)

    if use_gpu_transform:
        # GPU transform will handle NHWC -> NCHW
        kwargs['use_gpu_transform'] = 1
        return model.net.ImageInput(blob_in, blob_out, **kwargs)

    # Route the first output through a temporary '_nhwc' blob, then convert
    # it into the blob the caller asked for. ``list(...)`` lets ``blob_out``
    # be a tuple as well as a list.
    nhwc_blob = blob_out[0] + '_nhwc'
    outputs = list(
        model.net.ImageInput(
            blob_in, [nhwc_blob] + list(blob_out[1:]), **kwargs
        )
    )
    outputs[0] = model.net.NHWC2NCHW(outputs[0], blob_out[0])
    return tuple(outputs)
29 
30 
def video_input(model, blob_in, blob_out, **kwargs):
    """Add a ``VideoInput`` operator to ``model.net``.

    Args:
        model: Object exposing a ``net`` attribute with a ``VideoInput``
            operator builder.
        blob_in: Input blob(s), passed through unchanged.
        blob_out: Output blob(s), passed through unchanged.
        **kwargs: Operator arguments forwarded to ``VideoInput``.

    Returns:
        The result of ``model.net.VideoInput``. The number of outputs is not
        fixed here; it can vary depending on ``kwargs``.
    """
    # size of outputs can vary depending on kwargs
    return model.net.VideoInput(blob_in, blob_out, **kwargs)