Caffe2 - Python API
A deep learning, cross-platform ML framework
schema.py
1 ## @package schema
2 # Module caffe2.python.schema
3 """
4 Defines a minimal set of data types that allow representing datasets with
5 arbitrary nested structure, including objects of variable length, such as
6 maps and lists.
7 
8 This defines a columnar storage format for such datasets on top of caffe2
9 tensors. In terms of capacity of representation, it can represent most of
10 the data types supported by the Parquet, ORC, and DWRF file formats.
11 
12 See comments in operator_test/dataset_ops_test.py for an example and
13 walkthrough on how to use schema to store and iterate through a structured
14 in-memory dataset.
15 """
16 from __future__ import absolute_import
17 from __future__ import division
18 from __future__ import print_function
19 from __future__ import unicode_literals
20 
21 import logging
22 import numpy as np
23 from caffe2.python import core
24 from caffe2.python import workspace
25 from caffe2.python.core import BlobReference
26 from collections import OrderedDict, namedtuple
27 from past.builtins import basestring
28 from future.utils import viewitems, viewkeys, viewvalues
29 from itertools import islice
30 from six import StringIO
31 
32 logger = logging.getLogger(__name__)
33 logger.setLevel(logging.INFO)
34 
35 FIELD_SEPARATOR = ':'
36 
37 
38 def _join_field_name(prefix, suffix):
39  if prefix and suffix:
40  return '{}{}{}'.format(prefix, FIELD_SEPARATOR, suffix)
41  elif prefix:
42  return prefix
43  elif suffix:
44  return suffix
45  else:
46  return ''
47 
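For illustration, nested field names are joined with the ':' separator and empty parts are dropped; a small sketch, assuming the module's names are in scope:

>>> _join_field_name('values', 'keys')
'values:keys'
>>> _join_field_name('', 'lengths')
'lengths'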
48 
49 def _normalize_field(field_or_type_or_blob, keep_blobs=True):
50  """Clones/normalizes a field before adding it to a container."""
51  if isinstance(field_or_type_or_blob, Field):
52  return field_or_type_or_blob.clone(keep_blobs=keep_blobs)
53  elif type(field_or_type_or_blob) in (type, np.dtype):
54  return Scalar(dtype=field_or_type_or_blob)
55  else:
56  return Scalar(blob=field_or_type_or_blob)
57 
58 
59 FeatureSpec = namedtuple(
60  'FeatureSpec',
61  [
62  'feature_type',
63  'feature_names',
64  'feature_ids',
65  'feature_is_request_only',
66  'desired_hash_size',
67  'feature_to_index',
68  ]
69 )
70 
71 FeatureSpec.__new__.__defaults__ = (None, None, None, None, None, None)
72 
73 
74 class Metadata(
75  namedtuple(
76  'Metadata', ['categorical_limit', 'expected_value', 'feature_specs']
77  )
78 ):
79  """Represents additional information associated with a scalar in schema.
80 
81  `categorical_limit` - for fields of integral type that are guaranteed to be
82  non-negative, it specifies the maximum possible value plus one. It's
83  often used as the size of an embedding table.
84 
85  `expected_value` - anticipated average value of elements in the field.
86  Usually makes sense for length fields of lists.
87 
88  `feature_specs` - information about the features contained in this
89  field. For example, if the field holds more than one feature, it can
90  store the list of feature names contained in this field."""
91  __slots__ = ()
92 
93 
94 Metadata.__new__.__defaults__ = (None, None, None)
95 
96 
97 class Field(object):
98  """Represents an abstract field type in a dataset.
99  """
100 
101  def __init__(self, children):
102  """Derived classes must call this after their initialization."""
103  self._parent = (None, 0)
104  offset = 0
105  self._field_offsets = []
106  for child in children:
107  self._field_offsets.append(offset)
108  offset += len(child.field_names())
109  self._field_offsets.append(offset)
110 
111  def clone_schema(self):
112  return self.clone(keep_blobs=False)
113 
114  def field_names(self):
115  """Return the children field names for this field."""
116  raise NotImplementedError('Field is an abstract class.')
117 
118  def field_types(self):
119  """Return the numpy.dtype for each of the children fields."""
120  raise NotImplementedError('Field is an abstract class.')
121 
122  def field_metadata(self):
123  """Return the Metadata for each of the children fields."""
124  raise NotImplementedError('Field is an abstract class.')
125 
126  def field_blobs(self):
127  """Return the list of blobs with contents for this Field.
128  Values can either be all numpy.ndarray or BlobReference.
129  If any of the fields doesn't have a blob, this throws.
130  """
131  raise NotImplementedError('Field is an abstract class.')
132 
133  def all_scalars(self):
134  """Return the list of all Scalar instances in the Field.
135  The order is the same as for field_names() or field_blobs()"""
136  raise NotImplementedError('Field is an abstract class.')
137 
138  def has_blobs(self):
139  """Return True if every scalar of this field has blobs."""
140  raise NotImplementedError('Field is an abstract class.')
141 
142  def clone(self, keep_blobs=True):
143  """Clone this Field along with its children."""
144  raise NotImplementedError('Field is an abstract class.')
145 
146  def _set_parent(self, parent, relative_id):
147  self._parent = (parent, relative_id)
148 
149  def slice(self):
150  """
151  Returns a slice representing the range of field ids that belong to
152  this field. This slice can be used to index a list of fields.
153 
154  E.g.:
155 
156  >>> s = Struct(
157  >>> ('a', Scalar()),
158  >>> ('b', Struct(
159  >>> ('b1', Scalar()),
160  >>> ('b2', Scalar()),
161  >>> )),
162  >>> ('c', Scalar()),
163  >>> )
164  >>> field_data = ['da', 'db1', 'db2', 'dc']
165  >>> field_data[s.b.slice()]
166  ['db1', 'db2']
167  """
168  base_id = self._child_base_id()
169  return slice(base_id, base_id + len(self.field_names()))
170 
171  def _child_base_id(self, child_index=None):
172  """Get the base id of the given child"""
173  p, i = self._parent
174  pos = 0 if child_index is None else self._field_offsets[child_index]
175  if p:
176  pos += p._child_base_id(i)
177  return pos
178 
179  def __eq__(self, other):
180  """Equivalence of two schemas."""
181  return (
182  (self.field_names() == other.field_names()) and
183  (self.field_types() == other.field_types()) and
184  (self.field_metadata() == other.field_metadata())
185  )
186 
187  def _pprint_impl(self, indent, str_buffer):
188  raise NotImplementedError('Field is an abstract class.')
189 
190  def __repr__(self):
191  str_buffer = StringIO()
192  self._pprint_impl(0, str_buffer)
193  contents = str_buffer.getvalue()
194  str_buffer.close()
195  return contents
196 
197 
198 class List(Field):
199  """Represents a variable-length list.
200 
201  Values of a list can also be complex fields such as Lists and Structs.
202  In addition to the fields exposed by its `values` field, a List exposes an
203  additional `lengths` field, which will contain the size of each list under
204  the parent domain.
205  """
206 
207  def __init__(self, values, lengths_blob=None):
208  if isinstance(lengths_blob, Field):
209  assert isinstance(lengths_blob, Scalar)
210  self.lengths = _normalize_field(lengths_blob)
211  else:
212  self.lengths = Scalar(np.int32, lengths_blob)
213  self._items = _normalize_field(values)
214  self.lengths._set_parent(self, 0)
215  self._items._set_parent(self, 1)
216  Field.__init__(self, [self.lengths, self._items])
217 
218  def field_names(self):
219  value_fields = self._items.field_names()
220  return (
221  ['lengths'] + [_join_field_name('values', v) for v in value_fields]
222  )
223 
224  def field_types(self):
225  return self.lengths.field_types() + self._items.field_types()
226 
227  def field_metadata(self):
228  return self.lengths.field_metadata() + self._items.field_metadata()
229 
230  def field_blobs(self):
231  return self.lengths.field_blobs() + self._items.field_blobs()
232 
233  def all_scalars(self):
234  return self.lengths.all_scalars() + self._items.all_scalars()
235 
236  def has_blobs(self):
237  return self.lengths.has_blobs() and self._items.has_blobs()
238 
239  def clone(self, keep_blobs=True):
240  return type(self)(
241  _normalize_field(self._items, keep_blobs=keep_blobs),
242  _normalize_field(self.lengths, keep_blobs=keep_blobs)
243  )
244 
245  def _pprint_impl(self, indent, str_buffer):
246  str_buffer.write(' ' * indent + "List(\n")
247  str_buffer.write(' ' * (indent + 1) + "lengths=\n")
248  self.lengths._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
249  str_buffer.write(' ' * (indent + 1) + "_items=\n")
250  self._items._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
251  str_buffer.write(' ' * indent + ")\n")
252 
253  def __getattr__(self, item):
254  """If the value of this list is a struct,
255  allow introspecting directly into its fields."""
256  if item.startswith('__'):
257  raise AttributeError(item)
258  if isinstance(self._items, Struct):
259  return getattr(self._items, item)
260  elif item == 'value' or item == 'items':
261  return self._items
262  else:
263  raise AttributeError('Field not found in list: %s.' % item)
264 
265  def __getitem__(self, item):
266  names = item.split(FIELD_SEPARATOR, 1)
267 
268  if len(names) == 1:
269  if item == 'lengths':
270  return self.lengths
271  elif item == 'values':
272  return self._items
273  else:
274  if names[0] == 'values':
275  return self._items[names[1]]
276  raise KeyError('Field not found in list: %s.' % item)
277 
278 
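A short sketch of how a List flattens into field names (assuming numpy is imported as np and the classes above are in scope):

>>> l = List(Scalar(np.float32))
>>> l.field_names()
['lengths', 'values']
>>> l.lengths.field_type()
dtype('int32')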
279 class Struct(Field):
280  """Represents a named list of fields sharing the same domain.
281  """
282 
283  def __init__(self, *fields):
284  """ fields is a list of tuples in the format (name, field). The name is
285  a nested name string, e.g., `a`, `a:b`, `a:b:c`. For example,
286 
287  Struct(
288  ('a', Scalar()),
289  ('b:c', Scalar()),
290  ('b:d:e', Scalar()),
291  ('b', Struct(
292  ('f', Scalar()),
293  )),
294  )
295 
296  is equal to
297 
298  Struct(
299  ('a', Scalar()),
300  ('b', Struct(
301  ('c', Scalar()),
302  ('d', Struct(('e', Scalar()))),
303  ('f', Scalar()),
304  )),
305  )
306  """
307  for field in fields:
308  assert len(field) == 2
309  assert field[0], 'Field names cannot be empty'
310  assert field[0] != 'lengths', (
311  'Struct cannot contain a field named `lengths`.'
312  )
313  fields = [(name, _normalize_field(field)) for name, field in fields]
314  self.fields = OrderedDict()
315  for name, field in fields:
316  if FIELD_SEPARATOR in name:
317  name, field = self._struct_from_nested_name(name, field)
318  if name not in self.fields:
319  self.fields[name] = field
320  continue
321  if (
322  not isinstance(field, Struct) or
323  not isinstance(self.fields[name], Struct)
324  ):
325  raise ValueError('Duplicate field name: %s' % name)
326  self.fields[name] = self.fields[name] + field
327  for id, (_, field) in enumerate(viewitems(self.fields)):
328  field._set_parent(self, id)
329  Field.__init__(self, viewvalues(self.fields))
330  self._frozen = True
331 
332  def _struct_from_nested_name(self, nested_name, field):
333  def create_internal(nested_name, field):
334  names = nested_name.split(FIELD_SEPARATOR, 1)
335  if len(names) == 1:
336  added_field = field
337  else:
338  added_field = create_internal(names[1], field)
339  return Struct((names[0], added_field))
340 
341  names = nested_name.split(FIELD_SEPARATOR, 1)
342  assert len(names) >= 2
343  return names[0], create_internal(names[1], field)
344 
345  def get_children(self):
346  return list(viewitems(self.fields))
347 
348  def field_names(self):
349  names = []
350  for name, field in viewitems(self.fields):
351  names += [_join_field_name(name, f) for f in field.field_names()]
352  return names
353 
354  def field_types(self):
355  types = []
356  for _, field in viewitems(self.fields):
357  types += field.field_types()
358  return types
359 
360  def field_metadata(self):
361  metadata = []
362  for _, field in viewitems(self.fields):
363  metadata += field.field_metadata()
364  return metadata
365 
366  def field_blobs(self):
367  blobs = []
368  for _, field in viewitems(self.fields):
369  blobs += field.field_blobs()
370  return blobs
371 
372  def all_scalars(self):
373  scalars = []
374  for _, field in viewitems(self.fields):
375  scalars += field.all_scalars()
376  return scalars
377 
378  def has_blobs(self):
379  return all(field.has_blobs() for field in viewvalues(self.fields))
380 
381  def clone(self, keep_blobs=True):
382  normalized_fields = [
383  (k, _normalize_field(v, keep_blobs=keep_blobs))
384  for k, v in viewitems(self.fields)
385  ]
386  return type(self)(*normalized_fields)
387 
388  def _get_field_by_nested_name(self, nested_name):
389  names = nested_name.split(FIELD_SEPARATOR, 1)
390  field = self.fields.get(names[0], None)
391 
392  if field is None:
393  return None
394 
395  if len(names) == 1:
396  return field
397 
398  try:
399  return field[names[1]]
400  except (KeyError, TypeError):
401  return None
402 
403  def _pprint_impl(self, indent, str_buffer):
404  str_buffer.write(' ' * indent + "Struct( \n")
405  for name, field in viewitems(self.fields):
406  str_buffer.write(' ' * (indent + 1) + "{}=".format(name) + "\n")
407  field._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
408  str_buffer.write(' ' * indent + ") \n")
409 
410  def __contains__(self, item):
411  field = self._get_field_by_nested_name(item)
412  return field is not None
413 
414  def __len__(self):
415  return len(self.fields)
416 
417  def __getitem__(self, item):
418  """
419  item can be a tuple or list of ints or strings, or a single
420  int or string. A string item is a nested field name, e.g., "a", "a:b",
421  "a:b:c". An int item is the index of a field at the first level of the
422  Struct.
423  """
424  if isinstance(item, list) or isinstance(item, tuple):
425  keys = list(viewkeys(self.fields))
426  return Struct(
427  * [
428  (
429  keys[k]
430  if isinstance(k, int) else k, self[k]
431  ) for k in item
432  ]
433  )
434  elif isinstance(item, int):
435  return next(islice(viewvalues(self.fields), item, None))
436  else:
437  field = self._get_field_by_nested_name(item)
438  if field is None:
439  raise KeyError('field "%s" not found' % (item))
440  return field
441 
442  def get(self, item, default_value):
443  """
444  Similar to Python's dictionary get method: return the field of item if
445  found (i.e. self.item is valid), otherwise return default_value.
446 
447  It's syntactic sugar for Python's built-in getattr method.
448  """
449  return getattr(self, item, default_value)
450 
451  def __getattr__(self, item):
452  if item.startswith('__'):
453  raise AttributeError(item)
454  try:
455  return self.__dict__['fields'][item]
456  except KeyError:
457  raise AttributeError(item)
458 
459  def __setattr__(self, key, value):
460  # Disable setting attributes after initialization to prevent the false
461  # impression of being able to overwrite a field.
462  # Setting internal state is still allowed, mainly so that _parent can
463  # be set post initialization.
464  if getattr(self, '_frozen', None) and not key.startswith('_'):
465  raise TypeError('Struct.__setattr__() is disabled after __init__()')
466  super(Struct, self).__setattr__(key, value)
467 
468  def __add__(self, other):
469  """
470  Allows merging the fields of two schema.Structs using the '+' operator.
471  If the two Structs have common field names, the merge is conducted
472  recursively. Here are examples:
473 
474  Example 1
475  s1 = Struct(('a', Scalar()))
476  s2 = Struct(('b', Scalar()))
477  s1 + s2 == Struct(
478  ('a', Scalar()),
479  ('b', Scalar()),
480  )
481 
482  Example 2
483  s1 = Struct(
484  ('a', Scalar()),
485  ('b', Struct(('c', Scalar()))),
486  )
487  s2 = Struct(('b', Struct(('d', Scalar()))))
488  s1 + s2 == Struct(
489  ('a', Scalar()),
490  ('b', Struct(
491  ('c', Scalar()),
492  ('d', Scalar()),
493  )),
494  )
495  """
496  if not isinstance(other, Struct):
497  return NotImplemented
498 
499  children = OrderedDict(self.get_children())
500  for name, right_field in other.get_children():
501  if name not in children:
502  children[name] = right_field
503  continue
504  left_field = children[name]
505  children[name] = left_field + right_field
506 
507  return Struct(*(viewitems(children)))
508 
509  def __sub__(self, other):
510  """
511  Allows removing the common fields of two schema.Structs from self by
512  using the '-' operator. If the two Structs have common field names, the
513  removal is conducted recursively. If a child struct has no fields
514  inside, it will be removed from its parent. Here are examples:
515 
516  Example 1
517  s1 = Struct(
518  ('a', Scalar()),
519  ('b', Scalar()),
520  )
521  s2 = Struct(('a', Scalar()))
522  s1 - s2 == Struct(('b', Scalar()))
523 
524  Example 2
525  s1 = Struct(
526  ('b', Struct(
527  ('c', Scalar()),
528  ('d', Scalar()),
529  ))
530  )
531  s2 = Struct(
532  ('b', Struct(('c', Scalar()))),
533  )
534  s1 - s2 == Struct(
535  ('b', Struct(
536  ('d', Scalar()),
537  )),
538  )
539 
540  Example 3
541  s1 = Struct(
542  ('a', Scalar()),
543  ('b', Struct(
544  ('d', Scalar()),
545  ))
546  )
547  s2 = Struct(
548  ('b', Struct(
549  ('c', Scalar()),
550  ('d', Scalar()),
551  )),
552  )
553  s1 - s2 == Struct(
554  ('a', Scalar()),
555  )
556  """
557  if not isinstance(other, Struct):
558  return NotImplemented
559 
560  children = OrderedDict(self.get_children())
561  for name, right_field in other.get_children():
562  if name in children:
563  left_field = children[name]
564  if type(left_field) == type(right_field):
565  if isinstance(left_field, Struct):
566  child = left_field - right_field
567  if child.get_children():
568  children[name] = child
569  continue
570  children.pop(name)
571  else:
572  raise TypeError(
573  "Type of left_field, " + str(type(left_field)) +
574  ", is not the same as that of right_field, " +
575  str(type(right_field)) +
576  ", yet they have the same field name, " + name)
577  return Struct(*(children.items()))
578 
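A sketch of nested-name expansion and recursive merge with '+' (assuming numpy as np):

>>> s = Struct(('a', Scalar(np.int32)), ('b:c', Scalar(np.float32)))
>>> s.field_names()
['a', 'b:c']
>>> (s + Struct(('b', Struct(('d', Scalar())))))['b'].field_names()
['c', 'd']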
579 
580 class Scalar(Field):
581  """Represents a typed scalar or tensor of fixed shape.
582 
583  A Scalar is a leaf in a schema tree, translating to exactly one tensor in
584  the dataset's underlying storage.
585 
586  Usually, the tensor storing the actual values of this field is a 1D tensor,
587  representing a series of values in its domain. It is possible however to
588  have higher rank values stored as a Scalar, as long as all entries have
589  the same shape.
590 
591  E.g.:
592 
593  Scalar(np.float64)
594 
595  Scalar field of type float64. Caffe2 will expect readers and
596  datasets to expose it as a 1D tensor of doubles (vector), where
597  the size of the vector is determined by this field's domain.
598 
599  Scalar((np.int32, 5))
600 
601  Tensor field of type int32. Caffe2 will expect readers and
602  datasets to implement it as a 2D tensor (matrix) of shape (L, 5),
603  where L is determined by this field's domain.
604 
605  Scalar((str, (10, 20)))
606 
607  Tensor field of type str. Caffe2 will expect readers and
608  datasets to implement it as a 3D tensor of shape (L, 10, 20),
609  where L is determined by this field's domain.
610 
611  If the field type is unknown at construction time, call Scalar(), which
612  will default to np.void as its dtype.
613 
614  It is an error to pass a structured dtype to Scalar, since it would contain
615  more than one field. Instead, use from_dtype, which will construct
616  a nested `Struct` field reflecting the given dtype's structure.
617 
618  A Scalar can also contain a blob, which represents the value of this
619  Scalar. A blob can be either a numpy.ndarray, in which case it contains
620  the actual contents of the Scalar, or a BlobReference, which represents a
621  blob living in a caffe2 Workspace. If a blob of a different type is passed,
622  a conversion to numpy.ndarray is attempted.
623  """
624 
625  def __init__(self, dtype=None, blob=None, metadata=None):
626  self._metadata = None
627  self.set(dtype, blob, metadata, unsafe=True)
628  Field.__init__(self, [])
629 
630  def field_names(self):
631  return ['']
632 
633  def field_type(self):
634  return self.dtype
635 
636  def field_types(self):
637  return [self.dtype]
638 
639  def field_metadata(self):
640  return [self._metadata]
641 
642  def has_blobs(self):
643  return self._blob is not None
644 
645  def field_blobs(self):
646  assert self._blob is not None, 'Value is not set for this field.'
647  return [self._blob]
648 
649  def all_scalars(self):
650  return [self]
651 
652  def clone(self, keep_blobs=True):
653  return Scalar(
654  dtype=self._original_dtype,
655  blob=self._blob if keep_blobs else None,
656  metadata=self._metadata
657  )
658 
659  def get(self):
660  """Gets the current blob of this Scalar field."""
661  assert self._blob is not None, 'Value is not set for this field.'
662  return self._blob
663 
664  def __call__(self):
665  """Shortcut for self.get()"""
666  return self.get()
667 
668  @property
669  def metadata(self):
670  return self._metadata
671 
672  def set_metadata(self, value):
673  assert isinstance(value, Metadata), \
674  'metadata must be Metadata, got {}'.format(type(value))
675  self._metadata = value
676  self._validate_metadata()
677 
678  def _validate_metadata(self):
679  if self._metadata is None:
680  return
681  if (self._metadata.categorical_limit is not None and
682  self.dtype is not None):
683  assert np.issubdtype(self.dtype, np.integer), \
684  "`categorical_limit` can be specified only in integral " + \
685  "fields but got {}".format(self.dtype)
686 
687  def set_value(self, blob, throw_on_type_mismatch=False, unsafe=False):
688  """Sets only the blob field, still validating against the existing dtype."""
689  if self.dtype.base != np.void and throw_on_type_mismatch:
690  assert isinstance(blob, np.ndarray), "Got {!r}".format(blob)
691  assert blob.dtype.base == self.dtype.base, (
692  "Expected {}, got {}".format(self.dtype.base, blob.dtype.base))
693  self.set(dtype=self._original_dtype, blob=blob, unsafe=unsafe)
694 
695  def set(self, dtype=None, blob=None, metadata=None, unsafe=False):
696  """Set the type and/or blob of this scalar. See __init__ for details.
697 
698  Args:
699  dtype: can be any numpy type. If not provided and `blob` is
700  provided, it will be inferred. If no argument is provided,
701  this Scalar will be of type np.void.
702  blob: if provided, can be either a BlobReference or a
703  numpy.ndarray. If a value of different type is passed,
704  a conversion to numpy.ndarray is attempted. Strings aren't
705  accepted, since they can be ambiguous. If you want to pass
706  a string, use either BlobReference(blob) or np.array(blob).
707  metadata: optional instance of Metadata, if provided overrides
708  the metadata information of the scalar
709  """
710  if not unsafe:
711  logger.warning(
712  "Scalar should be considered immutable. Only call Scalar.set() "
713  "on newly created Scalar with unsafe=True. This will become an "
714  "error soon."
715  )
716  if blob is not None and isinstance(blob, basestring):
717  raise ValueError(
718  'Passing str blob to Scalar.set() is ambiguous. '
719  'Do either set(blob=np.array(blob)) or '
720  'set(blob=BlobReference(blob))'
721  )
722 
723  self._original_dtype = dtype
724  # Numpy will collapse a shape of 1 into an unindexed data array (shape = ()),
725  # which betrays the docstring of this class (which expects shape = (1,)).
726  # >>> import numpy as np
727  # >>> np.dtype((np.int32, 1))
728  # dtype('int32')
729  # >>> np.dtype((np.int32, 5))
730  # dtype(('<i4', (5,)))
731  if dtype is not None and isinstance(dtype, tuple) and dtype[1] == 1:
732  dtype = (dtype[0], (1,))
733  if dtype is not None:
734  if isinstance(dtype, tuple) and dtype[0] == np.void:
735  raise TypeError(
736  "Cannot set the Scalar with type {} for blob {}. "
737  "If this blob is the output of some operation, "
738  "please verify the input of that operation has "
739  "proper type.".format(dtype, blob)
740  )
741  dtype = np.dtype(dtype)
742  # If blob is not None and it is not a BlobReference, we assume that
743  # it is actual tensor data, so we will try to cast it to a numpy array.
744  if blob is not None and not isinstance(blob, BlobReference):
745  preserve_shape = isinstance(blob, np.ndarray)
746  if dtype is not None and dtype != np.void:
747  blob = np.array(blob, dtype=dtype.base)
748  # if array is empty we may need to reshape a little
749  if blob.size == 0 and not preserve_shape:
750  blob = blob.reshape((0, ) + dtype.shape)
751  else:
752  assert isinstance(blob, np.ndarray), (
753  'Invalid blob type: %s' % str(type(blob)))
754 
755  # reshape scalars into 1D arrays
756  # TODO(azzolini): figure out better way of representing this
757  if len(blob.shape) == 0 and not preserve_shape:
758  blob = blob.reshape((1, ))
759 
760  # infer inner shape from the blob given
761  # TODO(dzhulgakov): tweak this to make it work with PackedStruct
762  if (len(blob.shape) > 1 and dtype is not None and
763  dtype.base != np.void):
764  dtype = np.dtype((dtype.base, blob.shape[1:]))
765  # if we were still unable to infer the dtype
766  if dtype is None:
767  dtype = np.dtype(np.void)
768  assert not dtype.fields, (
769  'Cannot create Scalar with a structured dtype. ' +
770  'Use from_dtype instead.'
771  )
772  self.dtype = dtype
773  self._blob = blob
774  if metadata is not None:
775  self.set_metadata(metadata)
776  self._validate_metadata()
777 
778  def set_type(self, dtype):
779  self._original_dtype = dtype
780  if dtype is not None:
781  self.dtype = np.dtype(dtype)
782  else:
783  self.dtype = np.dtype(np.void)
784  self._validate_metadata()
785 
786  def _pprint_impl(self, indent, str_buffer):
787  str_buffer.write(' ' * (indent) +
788  'Scalar({!r}, {!r}, {!r})'.format(
789  self.dtype, self._blob, self._metadata) + "\n")
790 
791  def id(self):
792  """
793  Return the zero-indexed position of this scalar field in its schema.
794  Used in order to index into the field_blob list returned by readers or
795  accepted by writers.
796  """
797  return self._child_base_id()
798 
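A sketch of typed Scalar construction and inner-shape handling (assuming numpy as np; the subarray dtype repr below is for a little-endian machine):

>>> s = Scalar(np.float32, [1.0, 2.0, 3.0])
>>> s.field_type()
dtype('float32')
>>> s.get().shape
(3,)
>>> Scalar((np.int32, 5)).field_type()
dtype(('<i4', (5,)))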
799 
800 def Map(
801  keys,
802  values,
803  keys_name='keys',
804  values_name='values',
805  lengths_blob=None
806 ):
807  """A map is a List of Struct containing keys and values fields.
808  Optionally, you can provide custom name for the key and value fields.
809  """
810  return List(
811  Struct((keys_name, keys), (values_name, values)),
812  lengths_blob=lengths_blob
813  )
814 
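For example, a Map flattens into a lengths column plus key and value columns:

>>> m = Map(Scalar(np.int64), Scalar(np.float32))
>>> m.field_names()
['lengths', 'values:keys', 'values:values']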
815 
816 def NamedTuple(name_prefix, *fields):
817  return Struct(* [('%s_%d' % (name_prefix, i), field)
818  for i, field in enumerate(fields)])
819 
820 
821 def Tuple(*fields):
822  """
823  Creates a Struct with default, sequential, field names of given types.
824  """
825  return NamedTuple('field', *fields)
826 
827 
828 def RawTuple(num_fields, name_prefix='field'):
829  """
830  Creates a tuple of `num_fields` untyped scalars.
831  """
832  assert isinstance(num_fields, int)
833  assert num_fields >= 0
834  return NamedTuple(name_prefix, *([np.void] * num_fields))
835 
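For example, Tuple and RawTuple generate sequential field names:

>>> Tuple(np.int32, np.float32).field_names()
['field_0', 'field_1']
>>> RawTuple(2, name_prefix='x').field_names()
['x_0', 'x_1']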
836 
837 def from_dtype(dtype, _outer_shape=()):
838  """Constructs a Caffe2 schema from the given numpy's dtype.
839 
840  Numpy supports scalar, array-like and structured datatypes, as long as
841  all the shapes are fixed. This function breaks down the given dtype into
842  a Caffe2 schema containing `Struct` and `Scalar` types.
843 
844  Fields containing byte offsets are not currently supported.
845  """
846  if not isinstance(dtype, np.dtype):
847  # wrap into a np.dtype
848  shape = _outer_shape
849  dtype = np.dtype((dtype, _outer_shape))
850  else:
851  # concatenate shapes if necessary
852  shape = _outer_shape + dtype.shape
853  if shape != dtype.shape:
854  dtype = np.dtype((dtype.base, shape))
855 
856  if not dtype.fields:
857  return Scalar(dtype)
858 
859  struct_fields = []
860  for name, (fdtype, offset) in viewitems(dtype.fields):
861  assert offset == 0, ('Fields with byte offsets are not supported.')
862  struct_fields.append((name, from_dtype(fdtype, _outer_shape=shape)))
863  return Struct(*struct_fields)
864 
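A sketch of from_dtype on a plain array dtype and on a structured dtype (the byte-offset assert above restricts structured dtypes to fields at offset 0; little-endian repr assumed):

>>> from_dtype(np.dtype((np.float32, 3))).field_types()
[dtype(('<f4', (3,)))]
>>> from_dtype(np.dtype([('a', np.int32)])).field_names()
['a']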
865 
866 class _SchemaNode(object):
867  """This is a private class used to represent a Schema Node"""
868 
869  def __init__(self, name, type_str=''):
870  self.name = name
871  self.children = []
872  self.type_str = type_str
873  self.field = None
874 
875  def add_child(self, name, type_str=''):
876  for child in self.children:
877  if child.name == name and child.type_str == type_str:
878  return child
879  child = _SchemaNode(name, type_str)
880  self.children.append(child)
881  return child
882 
883  def get_field(self):
884 
885  list_names = ['lengths', 'values']
886  map_names = ['lengths', 'keys', 'values']
887 
888  if len(self.children) == 0 or self.field is not None:
889  if self.field is None:
890  return Struct()
891  else:
892  return self.field
893 
894  child_names = []
895  for child in self.children:
896  child_names.append(child.name)
897 
898  if (set(child_names) == set(list_names)):
899  for child in self.children:
900  if child.name == 'values':
901  values_field = child.get_field()
902  else:
903  lengths_field = child.get_field()
904  self.field = List(
905  values_field,
906  lengths_blob=lengths_field
907  )
908  self.type_str = "List"
909  return self.field
910  elif (set(child_names) == set(map_names)):
911  for child in self.children:
912  if child.name == 'keys':
913  key_field = child.get_field()
914  elif child.name == 'values':
915  values_field = child.get_field()
916  else:
917  lengths_field = child.get_field()
918  self.field = Map(
919  key_field,
920  values_field,
921  lengths_blob=lengths_field
922  )
923  self.type_str = "Map"
924  return self.field
925 
926  else:
927  struct_fields = []
928  for child in self.children:
929  struct_fields.append((child.name, child.get_field()))
930 
931  self.field = Struct(*struct_fields)
932  self.type_str = "Struct"
933  return self.field
934 
935  def print_recursively(self):
936  for child in self.children:
937  child.print_recursively()
938  logger.info("Printing node: Name and type")
939  logger.info(self.name)
940  logger.info(self.type_str)
941 
942 
943 def from_column_list(
944  col_names, col_types=None,
945  col_blobs=None, col_metadata=None
946 ):
947  """
948  Given a list of names, types, and optionally values, construct a Schema.
949  """
950  if col_types is None:
951  col_types = [None] * len(col_names)
952  if col_metadata is None:
953  col_metadata = [None] * len(col_names)
954  if col_blobs is None:
955  col_blobs = [None] * len(col_names)
956  assert len(col_names) == len(col_types), (
957  'col_names and col_types must have the same length.'
958  )
959  assert len(col_names) == len(col_metadata), (
960  'col_names and col_metadata must have the same length.'
961  )
962  assert len(col_names) == len(col_blobs), (
963  'col_names and col_blobs must have the same length.'
964  )
965  root = _SchemaNode('root', 'Struct')
966  for col_name, col_type, col_blob, col_metadata in zip(
967  col_names, col_types, col_blobs, col_metadata
968  ):
969  columns = col_name.split(FIELD_SEPARATOR)
970  current = root
971  for i in range(len(columns)):
972  name = columns[i]
973  type_str = ''
974  field = None
975  if i == len(columns) - 1:
976  type_str = col_type
977  field = Scalar(
978  dtype=col_type,
979  blob=col_blob,
980  metadata=col_metadata
981  )
982  next = current.add_child(name, type_str)
983  if field is not None:
984  next.field = field
985  current = next
986 
987  return root.get_field()
988 
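For example, column names containing the ':' separator are folded back into nested List/Struct structure:

>>> s = from_column_list(['a:lengths', 'a:values', 'b'])
>>> s.field_names()
['a:lengths', 'a:values', 'b']
>>> type(s.a).__name__
'List'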
989 
990 def from_blob_list(schema, values, throw_on_type_mismatch=False):
991  """
992  Create a new schema that clones the given schema, but contains the given
993  list of values.
994  """
995  assert isinstance(schema, Field), 'Argument `schema` must be a Field.'
996  if isinstance(values, BlobReference):
997  values = [values]
998  record = schema.clone_schema()
999  scalars = record.all_scalars()
1000  assert len(scalars) == len(values), (
1001  'Values must have %d elements, got %d.' % (len(scalars), len(values))
1002  )
1003  for scalar, value in zip(scalars, values):
1004  scalar.set_value(value, throw_on_type_mismatch, unsafe=True)
1005  return record
1006 
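For example, attaching concrete arrays to a cloned schema (assuming numpy as np):

>>> s = Struct(('x', Scalar(np.float32)))
>>> r = from_blob_list(s, [np.array([1.0, 2.0])])
>>> r.x().dtype
dtype('float32')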
1007 
1008 def as_record(value):
1009  if isinstance(value, Field):
1010  return value
1011  elif isinstance(value, list) or isinstance(value, tuple):
1012  is_field_list = all(
1013  isinstance(f, tuple) and len(f) == 2 and isinstance(f[0], basestring)
1014  for f in value
1015  )
1016  if is_field_list:
1017  return Struct(* [(k, as_record(v)) for k, v in value])
1018  else:
1019  return Tuple(* [as_record(f) for f in value])
1020  elif isinstance(value, dict):
1021  return Struct(* [(k, as_record(v)) for k, v in viewitems(value)])
1022  else:
1023  return _normalize_field(value)
1024 
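A sketch of how plain Python containers are normalized into schema fields:

>>> as_record([('x', np.int32), ('y', np.float32)]).field_names()
['x', 'y']
>>> as_record((np.int32, np.float32)).field_names()
['field_0', 'field_1']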
1025 
1026 def FetchRecord(blob_record, ws=None, throw_on_type_mismatch=False):
1027  """
1028  Given a record containing BlobReferences, return a new record with the
1029  same schema, containing numpy arrays fetched from the current active workspace.
1030  """
1031 
1032  def fetch(v):
1033  if ws is None:
1034  return workspace.FetchBlob(str(v))
1035  else:
1036  return ws.blobs[str(v)].fetch()
1037 
1038  assert isinstance(blob_record, Field)
1039  field_blobs = blob_record.field_blobs()
1040  assert all(isinstance(v, BlobReference) for v in field_blobs)
1041  field_arrays = [fetch(value) for value in field_blobs]
1042  return from_blob_list(blob_record, field_arrays, throw_on_type_mismatch)
1043 
1044 
1045 def FeedRecord(blob_record, arrays, ws=None):
1046  """
1047  Given a record containing BlobReferences, and arrays, which is either
1048  a list of numpy arrays or a record containing numpy arrays, feeds the
1049  arrays into the current workspace.
1050  """
1051 
1052  def feed(b, v):
1053  if ws is None:
1054  workspace.FeedBlob(str(b), v)
1055  else:
1056  ws.create_blob(str(b))
1057  ws.blobs[str(b)].feed(v)
1058 
1059  assert isinstance(blob_record, Field)
1060  field_blobs = blob_record.field_blobs()
1061  assert all(isinstance(v, BlobReference) for v in field_blobs)
1062  if isinstance(arrays, Field):
1063  # TODO: check schema
1064  arrays = arrays.field_blobs()
1065  assert len(arrays) == len(field_blobs), (
1066  'Values must contain exactly %d ndarrays.' % len(field_blobs)
1067  )
1068  for blob, array in zip(field_blobs, arrays):
1069  feed(blob, array)
1070 
1071 
1072 def NewRecord(net, schema):
1073  """
1074  Given a record of np.arrays, create a BlobReference for each one of them,
1075  returning a record containing BlobReferences. The name of each returned blob
1076  is NextScopedBlob(field_name), which guarantees a unique name in the
1077  current net. Use NameScope explicitly to avoid name conflicts between
1078  different nets.
1079  """
1080  if isinstance(schema, Scalar):
1081  result = schema.clone()
1082  result.set_value(
1083  blob=net.NextScopedBlob('unnamed_scalar'),
1084  unsafe=True,
1085  )
1086  return result
1087 
1088  assert isinstance(schema, Field), 'Record must be a schema.Field instance.'
1089  blob_refs = [
1090  net.NextScopedBlob(prefix=name)
1091  for name in schema.field_names()
1092  ]
1093  return from_blob_list(schema, blob_refs)
1094 
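A minimal end-to-end sketch tying NewRecord, FeedRecord and FetchRecord together (assumes the default workspace; actual blob names depend on the current NameScope):

>>> net = core.Net('example')
>>> record = NewRecord(net, Struct(('x', Scalar(np.float32))))
>>> FeedRecord(record, [np.array([1.0, 2.0], dtype=np.float32)])
>>> FetchRecord(record).x().shape
(2,)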
1095 
1096 def ConstRecord(net, array_record):
1097  """
1098  Given a record of arrays, returns a record of blobs,
1099  initialized with net.Const.
1100  """
1101  blob_record = NewRecord(net, array_record)
1102  for blob, array in zip(
1103  blob_record.field_blobs(), array_record.field_blobs()
1104  ):
1105  net.Const(array, blob)
1106  return blob_record
1107 
1108 
1109 def InitEmptyRecord(net, schema_or_record, enforce_types=False):
1110  if not schema_or_record.has_blobs():
1111  record = NewRecord(net, schema_or_record)
1112  else:
1113  record = schema_or_record
1114 
1115  for blob_type, blob in zip(record.field_types(), record.field_blobs()):
1116  try:
1117  data_type = data_type_for_dtype(blob_type)
1118  shape = [0] + list(blob_type.shape)
1119  net.ConstantFill([], blob, shape=shape, dtype=data_type)
1120  except TypeError:
1121  logger.warning("Blob {} has type error".format(blob))
1122  # If data_type_for_dtype doesn't know how to resolve the given numpy
1123  # type to core.DataType, that function can throw a TypeError (for
1124  # example, that would happen for unknown types such as
1125  # np.void). This is not a problem when the record is going
1126  # to be overwritten by some operator later, though it might be an
1127  # issue for type/shape inference.
1128  if enforce_types:
1129  raise
1130  # If we don't enforce types for all items we'll create a blob with
1131  # the default ConstantFill (FLOAT, no shape)
1132  net.ConstantFill([], blob, shape=[0])
1133 
1134  return record
1135 
1136 
1137 _DATA_TYPE_FOR_DTYPE = [
1138  (np.str, core.DataType.STRING),
1139  (np.float16, core.DataType.FLOAT16),
1140  (np.float32, core.DataType.FLOAT),
1141  (np.float64, core.DataType.DOUBLE),
1142  (np.bool, core.DataType.BOOL),
1143  (np.int8, core.DataType.INT8),
1144  (np.int16, core.DataType.INT16),
1145  (np.int32, core.DataType.INT32),
1146  (np.int64, core.DataType.INT64),
1147  (np.uint8, core.DataType.UINT8),
1148  (np.uint16, core.DataType.UINT16),
1149 ]
1150 
1151 
1152 def is_schema_subset(schema, original_schema):
1153  # TODO add more checks
1154  return set(schema.field_names()).issubset(
1155  set(original_schema.field_names()))
1156 
1157 
1158 def equal_schemas(schema,
1159  original_schema,
1160  check_field_names=True,
1161  check_field_types=True,
1162  check_field_metas=False):
1163  assert isinstance(schema, Field)
1164  assert isinstance(original_schema, Field)
1165 
1166  if check_field_names and (
1167  schema.field_names() != original_schema.field_names()):
1168  return False
1169  if check_field_types and (
1170  schema.field_types() != original_schema.field_types()):
1171  return False
1172  if check_field_metas and (
1173  schema.field_metadata() != original_schema.field_metadata()):
1174  return False
1175 
1176  return True
1177 
1178 
1179 def schema_check(schema, previous=None):
1180  record = as_record(schema)
1181  if previous is not None:
1182  assert equal_schemas(schema, previous)
1183  return record
1184 
1185 
1186 def data_type_for_dtype(dtype):
1187  for np_type, dt in _DATA_TYPE_FOR_DTYPE:
1188  if dtype.base == np_type:
1189  return dt
1190  raise TypeError('Unknown dtype: ' + str(dtype.base))
1191 
1192 
1193 def dtype_for_core_type(core_type):
1194  for np_type, dt in _DATA_TYPE_FOR_DTYPE:
1195  if dt == core_type:
1196  return np_type
1197  raise TypeError('Unknown core type: ' + str(core_type))
1198 
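For example, the two lookups mirror each other through the table above:

>>> data_type_for_dtype(np.dtype(np.float32)) == core.DataType.FLOAT
True
>>> dtype_for_core_type(core.DataType.INT64) is np.int64
True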
1199 
1200 def attach_metadata_to_scalars(field, metadata):
1201  for f in field.all_scalars():
1202  f.set_metadata(metadata)