Caffe2 - Python API
A deep learning, cross-platform ML framework
schema.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package schema
17 # Module caffe2.python.schema
18 """
19 Defines a minimal set of data types that allow representing datasets with
20 arbitrary nested structure, including objects of variable length, such as
21 maps and lists.
22 
23 This defines a columnar storage format for such datasets on top of caffe2
24 tensors. In terms of representational capacity, it can express most of the
25 data types supported by the Parquet, ORC, and DWRF file formats.
26 
27 See comments in operator_test/dataset_ops_test.py for an example and
28 walkthrough on how to use schema to store and iterate through a structured
29 in-memory dataset.
30 """
31 from __future__ import absolute_import
32 from __future__ import division
33 from __future__ import print_function
34 from __future__ import unicode_literals
35 
36 import logging
37 import numpy as np
38 from caffe2.python import core
39 from caffe2.python import workspace
40 from caffe2.python.core import BlobReference
41 from collections import OrderedDict, namedtuple
42 from past.builtins import basestring
43 from future.utils import viewitems, viewkeys, viewvalues
44 from itertools import islice
45 from six import StringIO
46 
47 logger = logging.getLogger(__name__)
48 logger.setLevel(logging.INFO)
49 
50 FIELD_SEPARATOR = ':'
51 
52 
53 def _join_field_name(prefix, suffix):
54  if prefix and suffix:
55  return '{}{}{}'.format(prefix, FIELD_SEPARATOR, suffix)
56  elif prefix:
57  return prefix
58  elif suffix:
59  return suffix
60  else:
61  return ''
62 
63 
64 def _normalize_field(field_or_type_or_blob, keep_blobs=True):
65  """Clones/normalizes a field before adding it to a container."""
66  if isinstance(field_or_type_or_blob, Field):
67  return field_or_type_or_blob.clone(keep_blobs=keep_blobs)
68  elif type(field_or_type_or_blob) in (type, np.dtype):
69  return Scalar(dtype=field_or_type_or_blob)
70  else:
71  return Scalar(blob=field_or_type_or_blob)
72 
73 
74 FeatureSpec = namedtuple(
75  'FeatureSpec',
76  [
77  'feature_type',
78  'feature_names',
79  'feature_ids',
80  'feature_is_request_only',
81  'desired_hash_size',
82  ]
83 )
84 
85 FeatureSpec.__new__.__defaults__ = (None, None, None, None, None)
86 
87 
88 class Metadata(
89  namedtuple(
90  'Metadata', ['categorical_limit', 'expected_value', 'feature_specs']
91  )
92 ):
93  """Represents additional information associated with a scalar in schema.
94 
95  `categorical_limit` - for fields of integral type that are guaranteed to be
96  non-negative it specifies the maximum possible value plus one. It's often
97  used as a size of an embedding table.
98 
99  `expected_value` - anticipated average value of elements in the field.
100  Usually makes sense for length fields of lists.
101 
102  `feature_specs` - information about the features that contained in this
103  field. For example if field have more than 1 feature it can have list of
104  feature names contained in this field."""
105  __slots__ = ()
106 
107 
108 Metadata.__new__.__defaults__ = (None, None, None)
109 
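# A short illustrative sketch of attaching Metadata to a Scalar (Scalar is
# defined below):
#
#   s = Scalar(np.int64)
#   s.set_metadata(Metadata(categorical_limit=10))
#   s.metadata.categorical_limit  # 10; often used as an embedding table size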
110 
111 class Field(object):
112  """Represents an abstract field type in a dataset.
113  """
114 
115  def __init__(self, children):
116  """Derived classes must call this after their initialization."""
117  self._parent = (None, 0)
118  offset = 0
119  self._field_offsets = []
120  for child in children:
121  self._field_offsets.append(offset)
122  offset += len(child.field_names())
123  self._field_offsets.append(offset)
124 
125  def clone_schema(self):
126  return self.clone(keep_blobs=False)
127 
128  def field_names(self):
129  """Return the children field names for this field."""
130  raise NotImplementedError('Field is an abstract class.')
131 
132  def field_types(self):
133  """Return the numpy.dtype for each of the children fields."""
134  raise NotImplementedError('Field is an abstract class.')
135 
136  def field_metadata(self):
137  """Return the Metadata for each of the children fields."""
138  raise NotImplementedError('Field is an abstract class.')
139 
140  def field_blobs(self):
141  """Return the list of blobs with contents for this Field.
142  Values can either be all numpy.ndarray or BlobReference.
143  If any of the fields doesn't have a blob, this throws.
144  """
145  raise NotImplementedError('Field is an abstract class.')
146 
147  def all_scalars(self):
148  """Return the list of all Scalar instances in the Field.
149  The order is the same as for field_names() or field_blobs()"""
150  raise NotImplementedError('Field is an abstract class.')
151 
152  def has_blobs(self):
153  """Return True if every scalar of this field has blobs."""
154  raise NotImplementedError('Field is an abstract class.')
155 
156  def clone(self, keep_blobs=True):
157  """Clone this Field along with its children."""
158  raise NotImplementedError('Field is an abstract class.')
159 
160  def _set_parent(self, parent, relative_id):
161  self._parent = (parent, relative_id)
162 
163  def slice(self):
164  """
165  Returns a slice representing the range of field ids that belong to
166  this field. This slice can be used to index a list of fields.
167 
168  E.g.:
169 
170  >>> s = Struct(
171  >>> ('a', Scalar()),
172  >>> ('b', Struct(
173  >>> ('b1', Scalar()),
174  >>> ('b2', Scalar()),
175  >>> )),
176  >>> ('c', Scalar()),
177  >>> )
178  >>> field_data = ['da', 'db1', 'db2', 'dc']
179  >>> field_data[s.b.slice()]
180  ['db1', 'db2']
181  """
182  base_id = self._child_base_id()
183  return slice(base_id, base_id + len(self.field_names()))
184 
185  def _child_base_id(self, child_index=None):
186  """Get the base id of the given child"""
187  p, i = self._parent
188  pos = 0 if child_index is None else self._field_offsets[child_index]
189  if p:
190  pos += p._child_base_id(i)
191  return pos
192 
193  def __eq__(self, other):
194  """Equivalance of two schemas"""
195  return (
196  (self.field_names() == other.field_names()) and
197  (self.field_types() == other.field_types()) and
198  (self.field_metadata() == other.field_metadata())
199  )
200 
201  def _pprint_impl(self, indent, str_buffer):
202  raise NotImplementedError('Field is an abstract class.')
203 
204  def __repr__(self):
205  str_buffer = StringIO()
206  self._pprint_impl(0, str_buffer)
207  contents = str_buffer.getvalue()
208  str_buffer.close()
209  return contents
210 
211 
212 class List(Field):
213  """Represents a variable-length list.
214 
215  Values of a list can also be complex fields such as Lists and Structs.
216  In addition to the fields exposed by its `values` field, a List exposes an
217  additional `lengths` field, which will contain the size of each list under
218  the parent domain.
219  """
220 
221  def __init__(self, values, lengths_blob=None):
222  if isinstance(lengths_blob, Field):
223  assert isinstance(lengths_blob, Scalar)
224  self.lengths = _normalize_field(lengths_blob)
225  else:
226  self.lengths = Scalar(np.int32, lengths_blob)
227  self._items = _normalize_field(values)
228  self.lengths._set_parent(self, 0)
229  self._items._set_parent(self, 1)
230  Field.__init__(self, [self.lengths, self._items])
231 
232  def field_names(self):
233  value_fields = self._items.field_names()
234  return (
235  ['lengths'] + [_join_field_name('values', v) for v in value_fields]
236  )
237 
238  def field_types(self):
239  return self.lengths.field_types() + self._items.field_types()
240 
241  def field_metadata(self):
242  return self.lengths.field_metadata() + self._items.field_metadata()
243 
244  def field_blobs(self):
245  return self.lengths.field_blobs() + self._items.field_blobs()
246 
247  def all_scalars(self):
248  return self.lengths.all_scalars() + self._items.all_scalars()
249 
250  def has_blobs(self):
251  return self.lengths.has_blobs() and self._items.has_blobs()
252 
253  def clone(self, keep_blobs=True):
254  return List(
255  _normalize_field(self._items, keep_blobs=keep_blobs),
256  _normalize_field(self.lengths, keep_blobs=keep_blobs)
257  )
258 
259  def _pprint_impl(self, indent, str_buffer):
260  str_buffer.write(' ' * indent + "List(\n")
261  str_buffer.write(' ' * (indent + 1) + "lengths=\n")
262  self.lengths._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
263  str_buffer.write(' ' * (indent + 1) + "_items=\n")
264  self._items._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
265  str_buffer.write(' ' * indent + ")\n")
266 
267  def __getattr__(self, item):
268  """If the value of this list is a struct,
269  allow to introspect directly into its fields."""
270  if item.startswith('__'):
271  raise AttributeError(item)
272  if isinstance(self._items, Struct):
273  return getattr(self._items, item)
274  elif item == 'value' or item == 'items':
275  return self._items
276  else:
277  raise AttributeError('Field not found in list: %s.' % item)
278 
279  def __getitem__(self, item):
280  names = item.split(FIELD_SEPARATOR, 1)
281 
282  if len(names) == 1:
283  if item == 'lengths':
284  return self.lengths
285  elif item == 'values':
286  return self._items
287  else:
288  if names[0] == 'values':
289  return self._items[names[1]]
290  raise KeyError('Field not found in list: %s.' % item)
291 
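# A brief sketch of List's columnar layout:
#
#   l = List(Scalar(np.float32))
#   l.field_names()   # ['lengths', 'values']
#   l.lengths         # int32 Scalar holding the per-example list sizes
#   l['values']       # the values Scalar, also reachable as l.items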
292 
293 class Struct(Field):
294  """Represents a named list of fields sharing the same domain.
295  """
296 
297  def __init__(self, *fields):
298  """ fields is a list of tuples in format of (name, field). The name is
299  a string of nested name, e.g., `a`, `a:b`, `a:b:c`. For example
300 
301  Struct(
302  ('a', Scalar()),
303  ('b:c', Scalar()),
304  ('b:d:e', Scalar()),
305  ('b', Struct(
306  ('f', Scalar()),
307  )),
308  )
309 
310  is equal to
311 
312  Struct(
313  ('a', Scalar()),
314  ('b', Struct(
315  ('c', Scalar()),
316  ('d', Struct(('e', Scalar()))),
317  ('f', Scalar()),
318  )),
319  )
320  """
321  for field in fields:
322  assert len(field) == 2
323  assert field[0], 'Field names cannot be empty'
324  assert field[0] != 'lengths', (
325  'Struct cannot contain a field named `lengths`.'
326  )
327  fields = [(name, _normalize_field(field)) for name, field in fields]
328  self.fields = OrderedDict()
329  for name, field in fields:
330  if FIELD_SEPARATOR in name:
331  name, field = self._struct_from_nested_name(name, field)
332  if name not in self.fields:
333  self.fields[name] = field
334  continue
335  if (
336  not isinstance(field, Struct) or
337  not isinstance(self.fields[name], Struct)
338  ):
339  raise ValueError('Duplicate field name: %s' % name)
340  self.fields[name] = self.fields[name] + field
341  for id, (_, field) in enumerate(viewitems(self.fields)):
342  field._set_parent(self, id)
343  Field.__init__(self, viewvalues(self.fields))
344  self._frozen = True
345 
346  def _struct_from_nested_name(self, nested_name, field):
347  def create_internal(nested_name, field):
348  names = nested_name.split(FIELD_SEPARATOR, 1)
349  if len(names) == 1:
350  added_field = field
351  else:
352  added_field = create_internal(names[1], field)
353  return Struct((names[0], added_field))
354 
355  names = nested_name.split(FIELD_SEPARATOR, 1)
356  assert len(names) >= 2
357  return names[0], create_internal(names[1], field)
358 
359  def get_children(self):
360  return list(viewitems(self.fields))
361 
362  def field_names(self):
363  names = []
364  for name, field in viewitems(self.fields):
365  names += [_join_field_name(name, f) for f in field.field_names()]
366  return names
367 
368  def field_types(self):
369  types = []
370  for _, field in viewitems(self.fields):
371  types += field.field_types()
372  return types
373 
374  def field_metadata(self):
375  metadata = []
376  for _, field in viewitems(self.fields):
377  metadata += field.field_metadata()
378  return metadata
379 
380  def field_blobs(self):
381  blobs = []
382  for _, field in viewitems(self.fields):
383  blobs += field.field_blobs()
384  return blobs
385 
386  def all_scalars(self):
387  scalars = []
388  for _, field in viewitems(self.fields):
389  scalars += field.all_scalars()
390  return scalars
391 
392  def has_blobs(self):
393  return all(field.has_blobs() for field in viewvalues(self.fields))
394 
395  def clone(self, keep_blobs=True):
396  normalized_fields = [
397  (k, _normalize_field(v, keep_blobs=keep_blobs))
398  for k, v in viewitems(self.fields)
399  ]
400  return Struct(*normalized_fields)
401 
402  def _get_field_by_nested_name(self, nested_name):
403  names = nested_name.split(FIELD_SEPARATOR, 1)
404  field = self.fields.get(names[0], None)
405 
406  if field is None:
407  return None
408 
409  if len(names) == 1:
410  return field
411 
412  try:
413  return field[names[1]]
414  except (KeyError, TypeError):
415  return None
416 
417  def _pprint_impl(self, indent, str_buffer):
418  str_buffer.write(' ' * indent + "Struct( \n")
419  for name, field in viewitems(self.fields):
420  str_buffer.write(' ' * (indent + 1) + "{}=".format(name) + "\n")
421  field._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
422  str_buffer.write(' ' * indent + ") \n")
423 
424  def __contains__(self, item):
425  field = self._get_field_by_nested_name(item)
426  return field is not None
427 
428  def __len__(self):
429  return len(self.fields)
430 
431  def __getitem__(self, item):
432  """
433  item can be a tuple or list of ints or strings, or a single
434  int or string. String item is a nested field name, e.g., "a", "a:b",
435  "a:b:c". Int item is the index of a field at the first level of the
436  Struct.
437  """
438  if isinstance(item, list) or isinstance(item, tuple):
439  keys = list(viewkeys(self.fields))
440  return Struct(
441  * [
442  (
443  keys[k]
444  if isinstance(k, int) else k, self[k]
445  ) for k in item
446  ]
447  )
448  elif isinstance(item, int):
449  return next(islice(viewvalues(self.fields), item, None))
450  else:
451  field = self._get_field_by_nested_name(item)
452  if field is None:
453  raise KeyError('field "%s" not found' % (item))
454  return field
455 
456  def get(self, item, default_value):
457  """
458  Similar to Python's dictionary get method: returns the field named
459  `item` if found (i.e. self.item is valid), otherwise default_value.
460 
461  It is syntactic sugar for Python's builtin getattr.
462  """
463  return getattr(self, item, default_value)
464 
465  def __getattr__(self, item):
466  if item.startswith('__'):
467  raise AttributeError(item)
468  try:
469  return self.__dict__['fields'][item]
470  except KeyError:
471  raise AttributeError(item)
472 
473  def __setattr__(self, key, value):
474  # Disable setting attributes after initialization to prevent false
475  # impression of being able to overwrite a field.
476  # Allowing setting internal states mainly so that _parent can be set
477  # post initialization.
478  if getattr(self, '_frozen', None) and not key.startswith('_'):
479  raise TypeError('Struct.__setattr__() is disabled after __init__()')
480  super(Struct, self).__setattr__(key, value)
481 
482  def __add__(self, other):
483  """
484  Allows merging the fields of two schema.Struct with the '+' operator.
485  If the two Structs have common field names, the merge is conducted
486  recursively. Here are examples:
487 
488  Example 1
489  s1 = Struct(('a', Scalar()))
490  s2 = Struct(('b', Scalar()))
491  s1 + s2 == Struct(
492  ('a', Scalar()),
493  ('b', Scalar()),
494  )
495 
496  Example 2
497  s1 = Struct(
498  ('a', Scalar()),
499  ('b', Struct(('c', Scalar()))),
500  )
501  s2 = Struct(('b', Struct(('d', Scalar()))))
502  s1 + s2 == Struct(
503  ('a', Scalar()),
504  ('b', Struct(
505  ('c', Scalar()),
506  ('d', Scalar()),
507  )),
508  )
509  """
510  if not isinstance(other, Struct):
511  return NotImplemented
512 
513  children = OrderedDict(self.get_children())
514  for name, right_field in other.get_children():
515  if name not in children:
516  children[name] = right_field
517  continue
518  left_field = children[name]
519  children[name] = left_field + right_field
520 
521  return Struct(*(viewitems(children)))
522 
523  def __sub__(self, other):
524  """
525  Allows removing the common fields of two schema.Struct from self
526  with the '-' operator. If the two Structs have common field names,
527  the removal is conducted recursively. If a child struct is left with
528  no fields, it is removed from its parent. Here are examples:
529 
530  Example 1
531  s1 = Struct(
532  ('a', Scalar()),
533  ('b', Scalar()),
534  )
535  s2 = Struct(('a', Scalar()))
536  s1 - s2 == Struct(('b', Scalar()))
537 
538  Example 2
539  s1 = Struct(
540  ('b', Struct(
541  ('c', Scalar()),
542  ('d', Scalar()),
543  ))
544  )
545  s2 = Struct(
546  ('b', Struct(('c', Scalar()))),
547  )
548  s1 - s2 == Struct(
549  ('b', Struct(
550  ('d', Scalar()),
551  )),
552  )
553 
554  Example 3
555  s1 = Struct(
556  ('a', Scalar()),
557  ('b', Struct(
558  ('d', Scalar()),
559  ))
560  )
561  s2 = Struct(
562  ('b', Struct(
563  ('c', Scalar()),
564  ('d', Scalar()),
565  )),
566  )
567  s1 - s2 == Struct(
568  ('a', Scalar()),
569  )
570  """
571  if not isinstance(other, Struct):
572  return NotImplemented
573 
574  children = OrderedDict(self.get_children())
575  for name, right_field in other.get_children():
576  if name in children:
577  left_field = children[name]
578  if type(left_field) == type(right_field):
579  if isinstance(left_field, Struct):
580  child = left_field - right_field
581  if child.get_children():
582  children[name] = child
583  continue
584  children.pop(name)
585  else:
586  raise TypeError(
587  "Type of left_field, " + str(type(left_field)) +
588  ", is not the same as that of right_field, " +
589  str(type(right_field)) +
590  ", yet they have the same field name, " + name)
591  return Struct(*(children.items()))
592 
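# A short sketch of nested construction, merging, and removal with Struct:
#
#   s1 = Struct(('a', Scalar()), ('b:c', Scalar()))
#   s2 = Struct(('b', Struct(('d', Scalar()))))
#   (s1 + s2).field_names()         # ['a', 'b:c', 'b:d']
#   ((s1 + s2) - s2).field_names()  # ['a', 'b:c']
#   (s1 + s2)['b:d']                # nested lookup by field name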
593 
594 class Scalar(Field):
595  """Represents a typed scalar or tensor of fixed shape.
596 
597  A Scalar is a leaf in a schema tree, translating to exactly one tensor in
598  the dataset's underlying storage.
599 
600  Usually, the tensor storing the actual values of this field is a 1D tensor,
601  representing a series of values in its domain. It is possible however to
602  have higher rank values stored as a Scalar, as long as all entries have
603  the same shape.
604 
605  E.g.:
606 
607  Scalar(np.float64)
608 
609  Scalar field of type float64. Caffe2 will expect readers and
610  datasets to expose it as a 1D tensor of doubles (vector), where
611  the size of the vector is determined by this field's domain.
612 
613  Scalar((np.int32, 5))
614 
615  Tensor field of type int32. Caffe2 will expect readers and
616  datasets to implement it as a 2D tensor (matrix) of shape (L, 5),
617  where L is determined by this field's domain.
618 
619  Scalar((str, (10, 20)))
620 
621  Tensor field of type str. Caffe2 will expect readers and
622  datasets to implement it as a 3D tensor of shape (L, 10, 20),
623  where L is determined by this field's domain.
624 
625  If the field type is unknown at construction time, call Scalar(); it will
626  default to np.void as its dtype.
627 
628  It is an error to pass a structured dtype to Scalar, since it would contain
629  more than one field. Instead, use from_dtype, which will construct
630  a nested `Struct` field reflecting the given dtype's structure.
631 
632  A Scalar can also contain a blob, which represents the value of this
633  Scalar. A blob can be either a numpy.ndarray, in which case it contains
634  the actual contents of the Scalar, or a BlobReference, which represents a
635  blob living in a caffe2 Workspace. If a blob of a different type is passed,
636  a conversion to numpy.ndarray is attempted.
637  """
638 
639  def __init__(self, dtype=None, blob=None, metadata=None):
640  self._metadata = None
641  self.set(dtype, blob, metadata, unsafe=True)
642  Field.__init__(self, [])
643 
644  def field_names(self):
645  return ['']
646 
647  def field_type(self):
648  return self.dtype
649 
650  def field_types(self):
651  return [self.dtype]
652 
653  def field_metadata(self):
654  return [self._metadata]
655 
656  def has_blobs(self):
657  return self._blob is not None
658 
659  def field_blobs(self):
660  assert self._blob is not None, 'Value is not set for this field.'
661  return [self._blob]
662 
663  def all_scalars(self):
664  return [self]
665 
666  def clone(self, keep_blobs=True):
667  return Scalar(
668  dtype=self._original_dtype,
669  blob=self._blob if keep_blobs else None,
670  metadata=self._metadata
671  )
672 
673  def get(self):
674  """Gets the current blob of this Scalar field."""
675  assert self._blob is not None, 'Value is not set for this field.'
676  return self._blob
677 
678  def __call__(self):
679  """Shortcut for self.get()"""
680  return self.get()
681 
682  @property
683  def metadata(self):
684  return self._metadata
685 
686  def set_metadata(self, value):
687  assert isinstance(value, Metadata), \
688  'metadata must be Metadata, got {}'.format(type(value))
689  self._metadata = value
690  self._validate_metadata()
691 
692  def _validate_metadata(self):
693  if self._metadata is None:
694  return
695  if (self._metadata.categorical_limit is not None and
696  self.dtype is not None):
697  assert np.issubdtype(self.dtype, np.integer), \
698  "`categorical_limit` can be specified only in integral " + \
699  "fields but got {}".format(self.dtype)
700 
701  def set_value(self, blob, throw_on_type_mismatch=False, unsafe=False):
702  """Sets only the blob field still validating the existing dtype"""
703  if self.dtype.base != np.void and throw_on_type_mismatch:
704  assert isinstance(blob, np.ndarray), "Got {!r}".format(blob)
705  assert blob.dtype.base == self.dtype.base, (
706  "Expected {}, got {}".format(self.dtype.base, blob.dtype.base))
707  self.set(dtype=self._original_dtype, blob=blob, unsafe=unsafe)
708 
709  def set(self, dtype=None, blob=None, metadata=None, unsafe=False):
710  """Set the type and/or blob of this scalar. See __init__ for details.
711 
712  Args:
713  dtype: can be any numpy type. If not provided and `blob` is
714  provided, it will be inferred. If no argument is provided,
715  this Scalar will be of type np.void.
716  blob: if provided, can be either a BlobReference or a
717  numpy.ndarray. If a value of a different type is passed,
718  a conversion to numpy.ndarray is attempted. Strings aren't
719  accepted, since they can be ambiguous. If you want to pass
720  a string, use either BlobReference(blob) or np.array(blob).
721  metadata: optional instance of Metadata; if provided, it overrides
722  the metadata information of the scalar.
723  """
724  if not unsafe:
725  logger.warning(
726  "Scalar should be considered immutable. Only call Scalar.set() "
727  "on newly created Scalar with unsafe=True. This will become an "
728  "error soon."
729  )
730  if blob is not None and isinstance(blob, basestring):
731  raise ValueError(
732  'Passing str blob to Scalar.set() is ambiguous. '
733  'Do either set(blob=np.array(blob)) or '
734  'set(blob=BlobReference(blob))'
735  )
736 
737  self._original_dtype = dtype
738  if dtype is not None:
739  dtype = np.dtype(dtype)
740  # If blob is not None and it is not a BlobReference, we assume that
741  # it is actual tensor data, so we will try to cast it to a numpy array.
742  if blob is not None and not isinstance(blob, BlobReference):
743  preserve_shape = isinstance(blob, np.ndarray)
744  if dtype is not None and dtype != np.void:
745  blob = np.array(blob, dtype=dtype.base)
746  # if array is empty we may need to reshape a little
747  if blob.size == 0 and not preserve_shape:
748  blob = blob.reshape((0, ) + dtype.shape)
749  else:
750  assert isinstance(blob, np.ndarray), (
751  'Invalid blob type: %s' % str(type(blob)))
752 
753  # reshape scalars into 1D arrays
754  # TODO(azzolini): figure out better way of representing this
755  if len(blob.shape) == 0 and not preserve_shape:
756  blob = blob.reshape((1, ))
757 
758  # infer inner shape from the blob given
759  # TODO(dzhulgakov): tweak this to make it work with PackedStruct
760  if (len(blob.shape) > 1 and dtype is not None and
761  dtype.base != np.void):
762  dtype = np.dtype((dtype.base, blob.shape[1:]))
763  # if we were still unable to infer the dtype
764  if dtype is None:
765  dtype = np.dtype(np.void)
766  assert not dtype.fields, (
767  'Cannot create Scalar with a structured dtype. ' +
768  'Use from_dtype instead.'
769  )
770  self.dtype = dtype
771  self._blob = blob
772  if metadata is not None:
773  self.set_metadata(metadata)
774  self._validate_metadata()
775 
776  def set_type(self, dtype):
777  self._original_dtype = dtype
778  if dtype is not None:
779  self.dtype = np.dtype(dtype)
780  else:
781  self.dtype = np.dtype(np.void)
782  self._validate_metadata()
783 
784  def _pprint_impl(self, indent, str_buffer):
785  str_buffer.write(' ' * (indent) +
786  'Scalar({!r}, {!r}, {!r})'.format(
787  self.dtype, self._blob, self._metadata) + "\n")
788 
789  def id(self):
790  """
791  Return the zero-indexed position of this scalar field in its schema.
792  Used in order to index into the field_blob list returned by readers or
793  accepted by writers.
794  """
795  return self._child_base_id()
796 
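# A minimal sketch of constructing Scalars with and without blobs:
#
#   Scalar(np.float64)     # typed, no data yet
#   Scalar((np.int32, 5))  # rows of shape (5,), stored as an (L, 5) tensor
#   s = Scalar(np.int32, np.array([1, 2, 3], dtype=np.int32))
#   s.get()                # the underlying numpy array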
797 
798 def Map(
799  keys,
800  values,
801  keys_name='keys',
802  values_name='values',
803  lengths_blob=None
804 ):
805  """A map is a List of Struct containing keys and values fields.
806  Optionally, you can provide custom name for the key and value fields.
807  """
808  return List(
809  Struct((keys_name, keys), (values_name, values)),
810  lengths_blob=lengths_blob
811  )
812 
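# A quick sketch of the columnar layout produced by Map:
#
#   m = Map(Scalar(np.int64), Scalar(np.float32))
#   m.field_names()   # ['lengths', 'values:keys', 'values:values']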
813 
814 def NamedTuple(name_prefix, *fields):
815  return Struct(* [('%s_%d' % (name_prefix, i), field)
816  for i, field in enumerate(fields)])
817 
818 
819 def Tuple(*fields):
820  """
821  Creates a Struct with default, sequential, field names of given types.
822  """
823  return NamedTuple('field', *fields)
824 
825 
826 def RawTuple(num_fields, name_prefix='field'):
827  """
828  Creates a tuple of `num_fields` untyped scalars.
829  """
830  assert isinstance(num_fields, int)
831  assert num_fields >= 0
832  return NamedTuple(name_prefix, *([np.void] * num_fields))
833 
834 
835 def from_dtype(dtype, _outer_shape=()):
836  """Constructs a Caffe2 schema from the given numpy's dtype.
837 
838  Numpy supports scalar, array-like and structured datatypes, as long as
839  all the shapes are fixed. This function breaks down the given dtype into
840  a Caffe2 schema containing `Struct` and `Scalar` types.
841 
842  Fields containing byte offsets are not currently supported.
843  """
844  if not isinstance(dtype, np.dtype):
845  # wrap into an np.dtype
846  shape = _outer_shape
847  dtype = np.dtype((dtype, _outer_shape))
848  else:
849  # concatenate shapes if necessary
850  shape = _outer_shape + dtype.shape
851  if shape != dtype.shape:
852  dtype = np.dtype((dtype.base, shape))
853 
854  if not dtype.fields:
855  return Scalar(dtype)
856 
857  struct_fields = []
858  for name, (fdtype, offset) in viewitems(dtype.fields):
859  assert offset == 0, ('Fields with byte offsets are not supported.')
860  struct_fields.append((name, from_dtype(fdtype, _outer_shape=shape)))
861  return Struct(*struct_fields)
862 
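# A small sketch of from_dtype. Structured dtypes with several fields carry
# nonzero byte offsets and would trip the assertion above, so this shows a
# single-field structured dtype:
#
#   dt = np.dtype([('values', np.float32, (8,))])
#   record = from_dtype(dt)
#   record.field_names()   # ['values']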
863 
864 class _SchemaNode(object):
865  """This is a private class used to represent a Schema Node"""
866 
867  def __init__(self, name, type_str=''):
868  self.name = name
869  self.children = []
870  self.type_str = type_str
871  self.field = None
872 
873  def add_child(self, name, type_str=''):
874  for child in self.children:
875  if child.name == name and child.type_str == type_str:
876  return child
877  child = _SchemaNode(name, type_str)
878  self.children.append(child)
879  return child
880 
881  def get_field(self):
882 
883  list_names = ['lengths', 'values']
884  map_names = ['lengths', 'keys', 'values']
885 
886  if len(self.children) == 0 or self.field is not None:
887  if self.field is None:
888  return Struct()
889  else:
890  return self.field
891 
892  child_names = []
893  for child in self.children:
894  child_names.append(child.name)
895 
896  if (set(child_names) == set(list_names)):
897  for child in self.children:
898  if child.name == 'values':
899  values_field = child.get_field()
900  else:
901  lengths_field = child.get_field()
902  self.field = List(
903  values_field,
904  lengths_blob=lengths_field
905  )
906  self.type_str = "List"
907  return self.field
908  elif (set(child_names) == set(map_names)):
909  for child in self.children:
910  if child.name == 'keys':
911  key_field = child.get_field()
912  elif child.name == 'values':
913  values_field = child.get_field()
914  else:
915  lengths_field = child.get_field()
916  self.field = Map(
917  key_field,
918  values_field,
919  lengths_blob=lengths_field
920  )
921  self.type_str = "Map"
922  return self.field
923 
924  else:
925  struct_fields = []
926  for child in self.children:
927  struct_fields.append((child.name, child.get_field()))
928 
929  self.field = Struct(*struct_fields)
930  self.type_str = "Struct"
931  return self.field
932 
933  def print_recursively(self):
934  for child in self.children:
935  child.print_recursively()
936  logger.info("Printing node: Name and type")
937  logger.info(self.name)
938  logger.info(self.type_str)
939 
940 
941 def from_column_list(
942  col_names, col_types=None,
943  col_blobs=None, col_metadata=None
944 ):
945  """
946  Given a list of names, types, and optionally values, construct a Schema.
947  """
948  if col_types is None:
949  col_types = [None] * len(col_names)
950  if col_metadata is None:
951  col_metadata = [None] * len(col_names)
952  if col_blobs is None:
953  col_blobs = [None] * len(col_names)
954  assert len(col_names) == len(col_types), (
955  'col_names and col_types must have the same length.'
956  )
957  assert len(col_names) == len(col_metadata), (
958  'col_names and col_metadata must have the same length.'
959  )
960  assert len(col_names) == len(col_blobs), (
961  'col_names and col_blobs must have the same length.'
962  )
963  root = _SchemaNode('root', 'Struct')
964  for col_name, col_type, col_blob, col_metadata in zip(
965  col_names, col_types, col_blobs, col_metadata
966  ):
967  columns = col_name.split(FIELD_SEPARATOR)
968  current = root
969  for i in range(len(columns)):
970  name = columns[i]
971  type_str = ''
972  field = None
973  if i == len(columns) - 1:
974  type_str = col_type
975  field = Scalar(
976  dtype=col_type,
977  blob=col_blob,
978  metadata=col_metadata
979  )
980  next = current.add_child(name, type_str)
981  if field is not None:
982  next.field = field
983  current = next
984 
985  return root.get_field()
986 
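# A brief sketch of rebuilding a schema from flattened column names:
#
#   record = from_column_list(
#       ['a', 'b:lengths', 'b:values'],
#       col_types=[np.float32, np.int32, np.int64],
#   )
#   record.field_names()   # ['a', 'b:lengths', 'b:values']
#   record.b               # reconstructed as a List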
987 
988 def from_blob_list(schema, values, throw_on_type_mismatch=False):
989  """
990  Create a record that clones the given schema but contains the given
991  list of values.
992  """
993  assert isinstance(schema, Field), 'Argument `schema` must be a Field.'
994  if isinstance(values, BlobReference):
995  values = [values]
996  record = schema.clone_schema()
997  scalars = record.all_scalars()
998  assert len(scalars) == len(values), (
999  'Values must have %d elements, got %d.' % (len(scalars), len(values))
1000  )
1001  for scalar, value in zip(scalars, values):
1002  scalar.set_value(value, throw_on_type_mismatch, unsafe=True)
1003  return record
1004 
1005 
1006 def as_record(value):
1007  if isinstance(value, Field):
1008  return value
1009  elif isinstance(value, list) or isinstance(value, tuple):
1010  is_field_list = all(
1011  isinstance(f, tuple) and len(f) == 2 and isinstance(f[0], basestring)
1012  for f in value
1013  )
1014  if is_field_list:
1015  return Struct(* [(k, as_record(v)) for k, v in value])
1016  else:
1017  return Tuple(* [as_record(f) for f in value])
1018  elif isinstance(value, dict):
1019  return Struct(* [(k, as_record(v)) for k, v in viewitems(value)])
1020  else:
1021  return _normalize_field(value)
1022 
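# A short sketch of as_record's conversions:
#
#   as_record({'a': np.int32}).field_names()           # ['a']
#   as_record((np.float32, np.float32)).field_names()  # ['field_0', 'field_1']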
1023 
1024 def FetchRecord(blob_record, ws=None, throw_on_type_mismatch=False):
1025  """
1026  Given a record containing BlobReferences, return a new record with the same
1027  schema, containing numpy arrays fetched from the currently active workspace.
1028  """
1029 
1030  def fetch(v):
1031  if ws is None:
1032  return workspace.FetchBlob(str(v))
1033  else:
1034  return ws.blobs[str(v)].fetch()
1035 
1036  assert isinstance(blob_record, Field)
1037  field_blobs = blob_record.field_blobs()
1038  assert all(isinstance(v, BlobReference) for v in field_blobs)
1039  field_arrays = [fetch(value) for value in field_blobs]
1040  return from_blob_list(blob_record, field_arrays, throw_on_type_mismatch)
1041 
1042 
1043 def FeedRecord(blob_record, arrays, ws=None):
1044  """
1045  Given a record containing BlobReferences and `arrays`, which is either
1046  a list of numpy arrays or a record containing numpy arrays, feeds the
1047  arrays into the workspace.
1048  """
1049 
1050  def feed(b, v):
1051  if ws is None:
1052  workspace.FeedBlob(str(b), v)
1053  else:
1054  ws.create_blob(str(b))
1055  ws.blobs[str(b)].feed(v)
1056 
1057  assert isinstance(blob_record, Field)
1058  field_blobs = blob_record.field_blobs()
1059  assert all(isinstance(v, BlobReference) for v in field_blobs)
1060  if isinstance(arrays, Field):
1061  # TODO: check schema
1062  arrays = arrays.field_blobs()
1063  assert len(arrays) == len(field_blobs), (
1064  'Values must contain exactly %d ndarrays.' % len(field_blobs)
1065  )
1066  for blob, array in zip(field_blobs, arrays):
1067  feed(blob, array)
1068 
1069 
1070 def NewRecord(net, schema):
1071  """
1072  Given a schema, create a BlobReference for each of its scalar fields,
1073  returning a record containing BlobReferences. The name of each returned blob
1074  is NextScopedBlob(field_name), which guarantees a unique name in the current
1075  net. Use NameScope explicitly to avoid name conflicts between different
1076  nets.
1077  """
1078  if isinstance(schema, Scalar):
1079  result = schema.clone()
1080  result.set_value(
1081  blob=net.NextScopedBlob('unnamed_scalar'),
1082  unsafe=True,
1083  )
1084  return result
1085 
1086  assert isinstance(schema, Field), 'Record must be a schema.Field instance.'
1087  blob_refs = [
1088  net.NextScopedBlob(prefix=name)
1089  for name in schema.field_names()
1090  ]
1091  return from_blob_list(schema, blob_refs)
1092 
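# A minimal sketch of the feed/fetch round trip (assumes a running default
# workspace; `init_net` is just an illustrative name):
#
#   init_net = core.Net('init')
#   schema = Struct(('x', Scalar(np.float32)))
#   record = NewRecord(init_net, schema)  # record of BlobReferences
#   FeedRecord(record, [np.array([1., 2.], dtype=np.float32)])
#   FetchRecord(record).x.get()           # array([1., 2.], dtype=float32)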
1093 
1094 def ConstRecord(net, array_record):
1095  """
1096  Given a record of arrays, returns a record of blobs,
1097  initialized with net.Const.
1098  """
1099  blob_record = NewRecord(net, array_record)
1100  for blob, array in zip(
1101  blob_record.field_blobs(), array_record.field_blobs()
1102  ):
1103  net.Const(array, blob)
1104  return blob_record
1105 
1106 
1107 def InitEmptyRecord(net, schema_or_record, enforce_types=False):
1108  if not schema_or_record.has_blobs():
1109  record = NewRecord(net, schema_or_record)
1110  else:
1111  record = schema_or_record
1112 
1113  for blob_type, blob in zip(record.field_types(), record.field_blobs()):
1114  try:
1115  data_type = data_type_for_dtype(blob_type)
1116  shape = [0] + list(blob_type.shape)
1117  net.ConstantFill([], blob, shape=shape, dtype=data_type)
1118  except TypeError:
1119  logger.warning("Blob {} has type error".format(blob))
1120  # If data_type_for_dtype doesn't know how to resolve given numpy
1121  # type to core.DataType, that function can throw type error (for
1122  # example that would happen for cases of unknown types such as
1123  # np.void). This is not a problem for cases when the record is going
1124  # to be overwritten by some operator later, though it might be an
1125  # issue for type/shape inference.
1126  if enforce_types:
1127  raise
1128  # If we don't enforce types for all items we'll create a blob with
1129  # the default ConstantFill (FLOAT, no shape)
1130  net.ConstantFill([], blob, shape=[0])
1131 
1132  return record
1133 
1134 
1135 _DATA_TYPE_FOR_DTYPE = [
1136  (np.str, core.DataType.STRING),
1137  (np.float16, core.DataType.FLOAT16),
1138  (np.float32, core.DataType.FLOAT),
1139  (np.float64, core.DataType.DOUBLE),
1140  (np.bool, core.DataType.BOOL),
1141  (np.int8, core.DataType.INT8),
1142  (np.int16, core.DataType.INT16),
1143  (np.int32, core.DataType.INT32),
1144  (np.int64, core.DataType.INT64),
1145  (np.uint8, core.DataType.UINT8),
1146  (np.uint16, core.DataType.UINT16),
1147 ]
1148 
1149 
1150 def is_schema_subset(schema, original_schema):
1151  # TODO add more checks
1152  return set(schema.field_names()).issubset(
1153  set(original_schema.field_names()))
1154 
1155 
1156 def equal_schemas(schema,
1157  original_schema,
1158  check_field_names=True,
1159  check_field_types=True,
1160  check_field_metas=False):
1161  assert isinstance(schema, Field)
1162  assert isinstance(original_schema, Field)
1163 
1164  if check_field_names and (
1165  schema.field_names() != original_schema.field_names()):
1166  return False
1167  if check_field_types and (
1168  schema.field_types() != original_schema.field_types()):
1169  return False
1170  if check_field_metas and (
1171  schema.field_metadata() != original_schema.field_metadata()):
1172  return False
1173 
1174  return True
1175 
1176 
1177 def schema_check(schema, previous=None):
1178  record = as_record(schema)
1179  if previous is not None:
1180  assert equal_schemas(schema, previous)
1181  return record
1182 
1183 
1184 def data_type_for_dtype(dtype):
1185  for np_type, dt in _DATA_TYPE_FOR_DTYPE:
1186  if dtype.base == np_type:
1187  return dt
1188  raise TypeError('Unknown dtype: ' + str(dtype.base))
1189 
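# A quick sketch of the dtype-to-DataType mapping:
#
#   data_type_for_dtype(np.dtype(np.float32))        # core.DataType.FLOAT
#   data_type_for_dtype(np.dtype((np.int64, (2,))))  # core.DataType.INT64
#   data_type_for_dtype(np.dtype(np.void))           # raises TypeError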
1190 
1191 def attach_metadata_to_scalars(field, metadata):
1192  for f in field.all_scalars():
1193  f.set_metadata(metadata)