Caffe2 - Python API
A deep learning, cross-platform ML framework
functional.py
# @package functional
# Module caffe2.python.layers.functional
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core, schema, scope, workspace
from caffe2.python.layers.layers import (
    ModelLayer,
)
import caffe2.proto.caffe2_pb2 as caffe2_pb2
import numpy as np
import six
import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class Functional(ModelLayer):

    def __init__(self, model, input_record, output_names_or_num, function,
                 name='functional', output_dtypes=None, tags=None, **kwargs):

        # allow coercion
        input_record = schema.as_record(input_record)

        super(Functional, self).__init__(
            model, name, input_record, tags=tags, **kwargs)
        self._function = function
        self._kwargs = kwargs
        return_struct = (
            isinstance(output_names_or_num, list) or
            (isinstance(output_names_or_num, six.integer_types) and
             output_names_or_num != 1)
        )

        with scope.NameScope(self.name, reset=True):
            if isinstance(output_names_or_num, int):
                struct_output_schema = schema.NewRecord(
                    model.net, schema.RawTuple(output_names_or_num))
            elif isinstance(output_names_or_num, schema.Field):
                self.output_schema = output_names_or_num.clone(keep_blobs=True)
                return
            else:
                if not isinstance(output_names_or_num, list):
                    output_names_or_num = [output_names_or_num]
                out_tuple = [(out, np.void) for out in output_names_or_num]
                struct_output_schema = schema.NewRecord(
                    model.net, schema.Struct(*out_tuple))

        num_outputs = len(struct_output_schema.field_blobs())

        # The functional layer returns a Struct if there is more than one
        # output or the output is a list; otherwise it returns a Scalar.
        if return_struct:
            self.output_schema = struct_output_schema
        else:
            self.output_schema = struct_output_schema[0]

        # If output_dtypes is provided, use it for the output schema.
        # Otherwise the shape and type will be inferred.
        if output_dtypes is not None:
            if not isinstance(output_dtypes, list):
                output_dtypes = [output_dtypes] * num_outputs
            assert len(output_dtypes) == num_outputs
            for dtype, scalar in zip(output_dtypes,
                                     self.output_schema.all_scalars()):
                scalar.set_type(dtype)
            return

        # Fake execution of the function to infer shapes and types automatically
        had_issues = False
        try:
            type_net = core.Net('_temp_type_and_shape_inference_net')
            schema.InitEmptyRecord(type_net, input_record, enforce_types=True)

            function(type_net, self.input_record, self.output_schema, **kwargs)
            (shapes, types) = workspace.InferShapesAndTypes([type_net], {})
            for i in range(num_outputs):
                scalar_schema = (self.output_schema[i] if return_struct
                                 else self.output_schema)
                blob = scalar_schema()
                if blob not in types or blob not in shapes:
                    had_issues = True
                    continue
                if shapes[blob] == []:
                    # Scalar type
                    shape = tuple()
                elif shapes[blob][0] == 0:
                    shape = tuple(shapes[blob][1:])
                else:
                    logger.warning("unexpected shape: {}".format(shapes[blob]))
                    # If the batch dimension is not first, give up on shape
                    # inference for that blob
                    had_issues = True
                    continue

                # TODO(amalevich): Move it to some shared library
                dtype = None
                if types[blob] == caffe2_pb2.TensorProto.DOUBLE:
                    dtype = (np.float64, shape)
                elif types[blob] == caffe2_pb2.TensorProto.FLOAT:
                    dtype = (np.float32, shape)
                elif types[blob] == caffe2_pb2.TensorProto.INT32:
                    dtype = (np.int32, shape)
                elif types[blob] == caffe2_pb2.TensorProto.INT64:
                    dtype = (np.int64, shape)
                elif types[blob] == caffe2_pb2.TensorProto.FLOAT16:
                    dtype = (np.float16, shape)

                if dtype is not None:
                    scalar_schema.set_type(dtype)
        except TypeError as ex:
            had_issues = True
            logger.warning(str(ex))

        if had_issues:
            logger.warning(
                "Type inference had problems for layer: {}".format(self.name))

    def add_ops(self, net):
        # Run the wrapped function against the real net so its operators are
        # added between the layer's input record and output schema.
        self._function(
            net, self.input_record, self.output_schema, **(self._kwargs))
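Usage sketch (not part of the file above): Functional wraps an arbitrary net-building callable as a layer. The callable receives (net, input_record, output_schema, **kwargs) and is expected to fill the output blobs. The snippet below assumes a LayerModelHelper named model whose registered layers are reachable as methods, and a single dense float feature; the names model, apply_relu, and relu_wrap are illustrative, not taken from the file.

import numpy as np
from caffe2.python import schema

# Hypothetical input record: one dense float32 feature of width 32
# (illustrative only; any record with typed scalars would do).
input_record = schema.NewRecord(
    model.net, schema.Scalar((np.float32, (32,))))

# The wrapped callable builds operators from the input record's blobs
# into the output schema's blobs.
def apply_relu(net, in_record, out_record):
    net.Relu(in_record.field_blobs(), out_record.field_blobs())

# One output requested, so the layer's output_schema is a Scalar whose
# type and shape come from the fake-execution inference pass in __init__.
relu_output = model.Functional(input_record, 1, apply_relu, name='relu_wrap')

Because output_dtypes is omitted here, the constructor runs the callable on a temporary net and uses workspace.InferShapesAndTypes to fill in the output Scalar's dtype and shape, falling back to a warning if inference fails.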