"""
Defines a minimal set of data types that allow to represent datasets with
arbitrary nested structure, including objects of variable length, such as
maps and lists.

This defines a columnar storage format for such datasets on top of caffe2
tensors. In terms of capacity of representation, it can represent most of
the data types supported by the Parquet, ORC, and DWRF file formats.

See comments in operator_test/dataset_ops_test.py for an example and
walkthrough on how to use schema to store and iterate through a structured
in-memory dataset.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import logging
import numpy as np

from caffe2.python import core
from caffe2.python import workspace
from caffe2.python.core import BlobReference
from collections import OrderedDict, namedtuple
from past.builtins import basestring
from future.utils import viewitems, viewkeys, viewvalues
from itertools import islice
from six import StringIO

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

FIELD_SEPARATOR = ':'


def _join_field_name(prefix, suffix):
    if prefix and suffix:
        return '{}{}{}'.format(prefix, FIELD_SEPARATOR, suffix)
    elif prefix:
        return prefix
    elif suffix:
        return suffix
    else:
        return ''


def _normalize_field(field_or_type_or_blob, keep_blobs=True):
    """Clones/normalizes a field before adding it to a container."""
    if isinstance(field_or_type_or_blob, Field):
        return field_or_type_or_blob.clone(keep_blobs=keep_blobs)
    elif type(field_or_type_or_blob) in (type, np.dtype):
        return Scalar(dtype=field_or_type_or_blob)
    else:
        return Scalar(blob=field_or_type_or_blob)


FeatureSpec = namedtuple(
    'FeatureSpec',
    [
        'feature_type',
        'feature_names',
        'feature_ids',
        'feature_is_request_only',
        'desired_hash_size',
        'feature_to_index',
    ]
)

FeatureSpec.__new__.__defaults__ = (None, None, None, None, None, None)


class Metadata(
    namedtuple(
        'Metadata',
        ['categorical_limit', 'expected_value', 'feature_specs']
    )
):
    """Represents additional information associated with a scalar in schema.

    `categorical_limit` - for fields of integral type that are guaranteed to be
    non-negative, it specifies the maximum possible value plus one. It's often
    used as the size of an embedding table.

    `expected_value` - anticipated average value of elements in the field.
    Usually makes sense for length fields of lists.

    `feature_specs` - information about the features contained in this field.
    For example, if the field contains more than one feature, it can hold the
    list of feature names contained in this field."""
    __slots__ = ()


Metadata.__new__.__defaults__ = (None, None, None)
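
# A minimal usage sketch (the values below are hypothetical):
# `categorical_limit` annotates an integral id field whose values lie in
# [0, 1000), e.g. for sizing an embedding table.
#
# >>> meta = Metadata(categorical_limit=1000)
# >>> ids = Scalar(np.int64, metadata=meta)
# >>> ids.metadata.categorical_limit
# 1000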


class Field(object):
    """Represents an abstract field type in a dataset.
    """

    def __init__(self, children):
        """Derived classes must call this after their initialization."""
        self._parent = (None, 0)
        offset = 0
        self._field_offsets = []
        for child in children:
            self._field_offsets.append(offset)
            offset += len(child.field_names())
        self._field_offsets.append(offset)

    def clone_schema(self):
        return self.clone(keep_blobs=False)

    def field_names(self):
        """Return the children field names for this field."""
        raise NotImplementedError('Field is an abstract class.')

    def field_types(self):
        """Return the numpy.dtype for each of the children fields."""
        raise NotImplementedError('Field is an abstract class.')

    def field_metadata(self):
        """Return the Metadata for each of the children fields."""
        raise NotImplementedError('Field is an abstract class.')

    def field_blobs(self):
        """Return the list of blobs with contents for this Field.
        Values can either be all numpy.ndarray or BlobReference.
        If any of the fields doesn't have a blob, throws.
        """
        raise NotImplementedError('Field is an abstract class.')

    def all_scalars(self):
        """Return the list of all Scalar instances in the Field.
        The order is the same as for field_names() or field_blobs()."""
        raise NotImplementedError('Field is an abstract class.')

    def has_blobs(self):
        """Return True if every scalar of this field has blobs."""
        raise NotImplementedError('Field is an abstract class.')

    def clone(self, keep_blobs=True):
        """Clone this Field along with its children."""
        raise NotImplementedError('Field is an abstract class.')

    def _set_parent(self, parent, relative_id):
        self._parent = (parent, relative_id)

    def slice(self):
        """
        Returns a slice representing the range of field ids that belong to
        this field. This slice can be used to index a list of fields.

        E.g.:

        >>> s = Struct(
        >>>     ('a', Scalar()),
        >>>     ('b', Struct(
        >>>         ('b1', Scalar()),
        >>>         ('b2', Scalar()),
        >>>     )),
        >>>     ('c', Scalar()),
        >>> )
        >>> field_data = ['da', 'db1', 'db2', 'dc']
        >>> field_data[s.b.slice()]
        ['db1', 'db2']
        """
        base_id = self._child_base_id()
        return slice(base_id, base_id + len(self.field_names()))

    def _child_base_id(self, child_index=None):
        """Get the base id of the given child."""
        p, i = self._parent
        pos = 0 if child_index is None else self._field_offsets[child_index]
        if p:
            pos += p._child_base_id(i)
        return pos

    def __eq__(self, other):
        """Equivalence of two schemas."""
        return (
            (self.field_names() == other.field_names()) and
            (self.field_types() == other.field_types()) and
            (self.field_metadata() == other.field_metadata())
        )

    def _pprint_impl(self, indent, str_buffer):
        raise NotImplementedError('Field is an abstract class.')

    def __repr__(self):
        str_buffer = StringIO()
        self._pprint_impl(indent=0, str_buffer=str_buffer)
        contents = str_buffer.getvalue()
        str_buffer.close()
        return contents


class List(Field):
    """Represents a variable-length list.

    Values of a list can also be complex fields such as Lists and Structs.
    In addition to the fields exposed by its `values` field, a List exposes an
    additional `lengths` field, which will contain the size of each list under
    the parent domain.
    """

    def __init__(self, values, lengths_blob=None):
        if isinstance(lengths_blob, Field):
            assert isinstance(lengths_blob, Scalar)
            self.lengths = _normalize_field(lengths_blob)
        else:
            self.lengths = Scalar(np.int32, lengths_blob)
        self._items = _normalize_field(values)
        self.lengths._set_parent(self, 0)
        self._items._set_parent(self, 1)
        Field.__init__(self, [self.lengths, self._items])

    def field_names(self):
        value_fields = self._items.field_names()
        return (
            ['lengths'] +
            [_join_field_name('values', v) for v in value_fields]
        )

    def field_types(self):
        return self.lengths.field_types() + self._items.field_types()

    def field_metadata(self):
        return self.lengths.field_metadata() + self._items.field_metadata()

    def field_blobs(self):
        return self.lengths.field_blobs() + self._items.field_blobs()

    def all_scalars(self):
        return self.lengths.all_scalars() + self._items.all_scalars()

    def has_blobs(self):
        return self.lengths.has_blobs() and self._items.has_blobs()

    def clone(self, keep_blobs=True):
        return List(
            _normalize_field(self._items, keep_blobs=keep_blobs),
            _normalize_field(self.lengths, keep_blobs=keep_blobs)
        )

    def _pprint_impl(self, indent, str_buffer):
        str_buffer.write('  ' * indent + "List(\n")
        str_buffer.write('  ' * (indent + 1) + "lengths=\n")
        self.lengths._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
        str_buffer.write('  ' * (indent + 1) + "_items=\n")
        self._items._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
        str_buffer.write('  ' * indent + ")\n")

    def __getattr__(self, item):
        """If the value of this list is a struct,
        allow to introspect directly into its fields."""
        if item.startswith('__'):
            raise AttributeError(item)
        if isinstance(self._items, Struct):
            return getattr(self._items, item)
        elif item == 'value' or item == 'items':
            return self._items
        else:
            raise AttributeError('Field not found in list: %s.' % item)

    def __getitem__(self, item):
        names = item.split(FIELD_SEPARATOR, 1)

        if len(names) == 1:
            if item == 'lengths':
                return self.lengths
            elif item == 'values':
                return self._items
        else:
            if names[0] == 'values':
                return self._items[names[1]]
        raise KeyError('Field not found in list: %s.' % item)
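
# A small usage sketch: a List of int64 values exposes an int32 `lengths`
# scalar plus the flattened values.
#
# >>> l = List(Scalar(np.int64))
# >>> l.field_names()
# ['lengths', 'values']
# >>> l.field_types()
# [dtype('int32'), dtype('int64')]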


class Struct(Field):
    """Represents a named list of fields sharing the same domain.
    """

    def __init__(self, *fields):
        """ fields is a list of tuples in the format of (name, field). The
        name is a string of nested name, e.g., `a`, `a:b`, `a:b:c`. For example

        Struct(
          ('a', Scalar()),
          ('b:c', Scalar()),
          ('b:d:e', Scalar()),
          ('b', Struct(
            ('f', Scalar()),
          )),
        )

        is equal to

        Struct(
          ('a', Scalar()),
          ('b', Struct(
            ('c', Scalar()),
            ('d', Struct(('e', Scalar()))),
            ('f', Scalar()),
          )),
        )
        """
        for field in fields:
            assert len(field) == 2
            assert field[0], 'Field names cannot be empty'
            assert field[0] != 'lengths', (
                'Struct cannot contain a field named `lengths`.'
            )
        fields = [(name, _normalize_field(field)) for name, field in fields]
        self.fields = OrderedDict()
        for name, field in fields:
            if FIELD_SEPARATOR in name:
                name, field = self._struct_from_nested_name(name, field)
            if name not in self.fields:
                self.fields[name] = field
                continue
            if (
                not isinstance(field, Struct) or
                not isinstance(self.fields[name], Struct)
            ):
                raise ValueError('Duplicate field name: %s' % name)
            self.fields[name] = self.fields[name] + field
        for id, (_, field) in enumerate(viewitems(self.fields)):
            field._set_parent(self, id)
        Field.__init__(self, viewvalues(self.fields))
        self._frozen = True

    def _struct_from_nested_name(self, nested_name, field):
        def create_internal(nested_name, field):
            names = nested_name.split(FIELD_SEPARATOR, 1)
            if len(names) == 1:
                added_field = field
            else:
                added_field = create_internal(names[1], field)
            return Struct((names[0], added_field))

        names = nested_name.split(FIELD_SEPARATOR, 1)
        assert len(names) >= 2
        return names[0], create_internal(names[1], field)

    def get_children(self):
        return list(viewitems(self.fields))

    def field_names(self):
        names = []
        for name, field in viewitems(self.fields):
            names += [_join_field_name(name, f) for f in field.field_names()]
        return names

    def field_types(self):
        types = []
        for _, field in viewitems(self.fields):
            types += field.field_types()
        return types

    def field_metadata(self):
        metadata = []
        for _, field in viewitems(self.fields):
            metadata += field.field_metadata()
        return metadata

    def field_blobs(self):
        blobs = []
        for _, field in viewitems(self.fields):
            blobs += field.field_blobs()
        return blobs

    def all_scalars(self):
        scalars = []
        for _, field in viewitems(self.fields):
            scalars += field.all_scalars()
        return scalars

    def has_blobs(self):
        return all(field.has_blobs() for field in viewvalues(self.fields))

    def clone(self, keep_blobs=True):
        normalized_fields = [
            (k, _normalize_field(v, keep_blobs=keep_blobs))
            for k, v in viewitems(self.fields)
        ]
        return type(self)(*normalized_fields)

    def _get_field_by_nested_name(self, nested_name):
        names = nested_name.split(FIELD_SEPARATOR, 1)
        field = self.fields.get(names[0], None)

        if field is None:
            return None

        if len(names) == 1:
            return field

        try:
            return field[names[1]]
        except (KeyError, TypeError):
            return None

    def _pprint_impl(self, indent, str_buffer):
        str_buffer.write('  ' * indent + "Struct( \n")
        for name, field in viewitems(self.fields):
            str_buffer.write('  ' * (indent + 1) + "{}=".format(name) + "\n")
            field._pprint_impl(indent=indent + 2, str_buffer=str_buffer)
        str_buffer.write('  ' * indent + ") \n")

    def __contains__(self, item):
        field = self._get_field_by_nested_name(item)
        return field is not None

    def __len__(self):
        return len(self.fields)

    def __getitem__(self, item):
        """
        item can be a tuple or list of ints or strings, or a single
        int or string. A string item is a nested field name, e.g., "a", "a:b",
        "a:b:c". An int item is the index of a field at the first level of the
        Struct.
        """
        if isinstance(item, list) or isinstance(item, tuple):
            keys = list(viewkeys(self.fields))
            return Struct(
                * [
                    (
                        keys[k] if isinstance(k, int) else k,
                        self[k]
                    ) for k in item
                ]
            )
        elif isinstance(item, int):
            return next(islice(viewvalues(self.fields), item, None))
        else:
            field = self._get_field_by_nested_name(item)
            if field is None:
                raise KeyError('field "%s" not found' % (item))
            return field

    def get(self, item, default_value):
        """
        Similar to python's dictionary get method: return the field of item if
        found (i.e. self.item is valid), otherwise return default_value.

        It's syntax sugar for python's builtin getattr method.
        """
        return getattr(self, item, default_value)

    def __getattr__(self, item):
        if item.startswith('__'):
            raise AttributeError(item)
        try:
            return self.__dict__['fields'][item]
        except KeyError:
            raise AttributeError(item)

    def __setattr__(self, key, value):
        # Disable setting attributes after initialization to prevent the false
        # impression of being able to overwrite a field. Internal (underscore)
        # state can still be set, mainly so that _parent can be assigned post
        # initialization.
        if getattr(self, '_frozen', None) and not key.startswith('_'):
            raise TypeError('Struct.__setattr__() is disabled after __init__()')
        super(Struct, self).__setattr__(key, value)

    def __add__(self, other):
        """
        Allows to merge fields of two schema.Struct using the '+' operator.
        If two Structs have common field names, the merge is conducted
        recursively. Here are examples:

        Example 1
        s1 = Struct(('a', Scalar()))
        s2 = Struct(('b', Scalar()))
        s1 + s2 == Struct(
            ('a', Scalar()),
            ('b', Scalar()),
        )

        Example 2
        s1 = Struct(
            ('a', Scalar()),
            ('b', Struct(('c', Scalar()))),
        )
        s2 = Struct(('b', Struct(('d', Scalar()))))
        s1 + s2 == Struct(
            ('a', Scalar()),
            ('b', Struct(
                ('c', Scalar()),
                ('d', Scalar()),
            )),
        )
        """
        if not isinstance(other, Struct):
            return NotImplemented

        children = OrderedDict(self.get_children())
        for name, right_field in other.get_children():
            if name not in children:
                children[name] = right_field
                continue
            left_field = children[name]
            children[name] = left_field + right_field

        return Struct(*(viewitems(children)))

    def __sub__(self, other):
        """
        Allows to remove the common fields of two schema.Struct from self by
        using the '-' operator. If two Structs have common field names, the
        removal is conducted recursively. If a child struct has no fields
        inside, it will be removed from its parent. Here are examples:

        Example 1
        s1 = Struct(
            ('a', Scalar()),
            ('b', Scalar()),
        )
        s2 = Struct(('a', Scalar()))
        s1 - s2 == Struct(('b', Scalar()))

        Example 2
        s1 = Struct(
            ('b', Struct(
                ('c', Scalar()),
                ('d', Scalar()),
            ))
        )
        s2 = Struct(
            ('b', Struct(('c', Scalar()))),
        )
        s1 - s2 == Struct(
            ('b', Struct(
                ('d', Scalar()),
            )),
        )
        """
        if not isinstance(other, Struct):
            return NotImplemented

        children = OrderedDict(self.get_children())
        for name, right_field in other.get_children():
            if name in children:
                left_field = children[name]
                if type(left_field) == type(right_field):
                    if isinstance(left_field, Struct):
                        child = left_field - right_field
                        if child.get_children():
                            children[name] = child
                            continue
                    children.pop(name)
                else:
                    raise TypeError(
                        "Type of left_field, " + str(type(left_field)) +
                        ", is not the same as that of right_field, " +
                        str(type(right_field)) +
                        ", yet they have the same field name, " + name
                    )
        return Struct(*(children.items()))
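
# A small usage sketch: nested names expand into nested Structs, and fields
# can be retrieved by attribute or by nested-name indexing.
#
# >>> s = Struct(('a', Scalar(np.int32)), ('b:c', Scalar(np.float32)))
# >>> s.field_names()
# ['a', 'b:c']
# >>> isinstance(s.b, Struct)
# True
# >>> s['b:c'].field_type()
# dtype('float32')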


class Scalar(Field):
    """Represents a typed scalar or tensor of fixed shape.

    A Scalar is a leaf in a schema tree, translating to exactly one tensor in
    the dataset's underlying storage.

    Usually, the tensor storing the actual values of this field is a 1D tensor,
    representing a series of values in its domain. It is possible however to
    have higher rank values stored as a Scalar, as long as all entries have
    the same shape.

    E.g.:

        Scalar(np.float64)

            Scalar field of type float64. Caffe2 will expect readers and
            datasets to expose it as a 1D tensor of doubles (vector), where
            the size of the vector is determined by this field's domain.

        Scalar((np.int32, 5))

            Tensor field of type int32. Caffe2 will expect readers and
            datasets to implement it as a 2D tensor (matrix) of shape (L, 5),
            where L is determined by this field's domain.

        Scalar((str, (10, 20)))

            Tensor field of type str. Caffe2 will expect readers and
            datasets to implement it as a 3D tensor of shape (L, 10, 20),
            where L is determined by this field's domain.

    If the field type is unknown at construction time, call Scalar(), which
    will default to np.void as its dtype.

    It is an error to pass a structured dtype to Scalar, since it would contain
    more than one field. Instead, use from_dtype, which will construct
    a nested `Struct` field reflecting the given dtype's structure.

    A Scalar can also contain a blob, which represents the value of this
    Scalar. A blob can be either a numpy.ndarray, in which case it contains the
    actual contents of the Scalar, or a BlobReference, which represents a
    blob living in a caffe2 Workspace. If a blob of a different type is passed,
    a conversion to numpy.ndarray is attempted.
    """

    def __init__(self, dtype=None, blob=None, metadata=None):
        self._metadata = None
        self.set(dtype, blob, metadata, unsafe=True)
        Field.__init__(self, [])

    def field_names(self):
        return ['']

    def field_type(self):
        return self.dtype

    def field_types(self):
        return [self.dtype]

    def field_metadata(self):
        return [self._metadata]

    def has_blobs(self):
        return self._blob is not None

    def field_blobs(self):
        assert self._blob is not None, 'Value is not set for this field.'
        return [self._blob]

    def all_scalars(self):
        return [self]

    def clone(self, keep_blobs=True):
        return Scalar(
            dtype=self._original_dtype,
            blob=self._blob if keep_blobs else None,
            metadata=self._metadata
        )

    def get(self):
        """Gets the current blob of this Scalar field."""
        assert self._blob is not None, 'Value is not set for this field.'
        return self._blob

    def __call__(self):
        """Shortcut for self.get()"""
        return self.get()

    @property
    def metadata(self):
        return self._metadata

    def set_metadata(self, value):
        assert isinstance(value, Metadata), \
            'metadata must be Metadata, got {}'.format(type(value))
        self._metadata = value
        self._validate_metadata()

    def _validate_metadata(self):
        if self._metadata is None:
            return
        if (self._metadata.categorical_limit is not None and
                self.dtype is not None):
            assert np.issubdtype(self.dtype, np.integer), \
                "`categorical_limit` can be specified only in integral " + \
                "fields but got {}".format(self.dtype)

    def set_value(self, blob, throw_on_type_mismatch=False, unsafe=False):
        """Sets only the blob field, still validating the existing dtype."""
        if self.dtype.base != np.void and throw_on_type_mismatch:
            assert isinstance(blob, np.ndarray), "Got {!r}".format(blob)
            assert blob.dtype.base == self.dtype.base, (
                "Expected {}, got {}".format(self.dtype.base, blob.dtype.base)
            )
        self.set(dtype=self._original_dtype, blob=blob, unsafe=unsafe)

    def set(self, dtype=None, blob=None, metadata=None, unsafe=False):
        """Set the type and/or blob of this scalar. See __init__ for details.

        Args:
            dtype: can be any numpy type. If not provided and `blob` is
                   provided, it will be inferred. If no argument is provided,
                   this Scalar will be of type np.void.
            blob:  if provided, can be either a BlobReference or a
                   numpy.ndarray. If a value of different type is passed,
                   a conversion to numpy.ndarray is attempted. Strings aren't
                   accepted, since they can be ambiguous. If you want to pass
                   a string, use either BlobReference(blob) or np.array(blob).
            metadata: optional instance of Metadata; if provided, overrides
                      the metadata information of the scalar
        """
        if not unsafe:
            logger.warning(
                "Scalar should be considered immutable. Only call Scalar.set() "
                "on newly created Scalar with unsafe=True. This will become an "
                "error soon."
            )
        if blob is not None and isinstance(blob, basestring):
            raise ValueError(
                'Passing str blob to Scalar.set() is ambiguous. '
                'Do either set(blob=np.array(blob)) or '
                'set(blob=BlobReference(blob))'
            )

        self._original_dtype = dtype
        # Numpy collapses a shape of 1 into an unindexed data array
        # (shape == ()), which contradicts the docstring of this class
        # (which expects shape == (1,)), so canonicalize it here.
        if dtype is not None and isinstance(dtype, tuple) and dtype[1] == 1:
            dtype = (dtype[0], (1,))
        if dtype is not None:
            if isinstance(dtype, tuple) and dtype[0] == np.void:
                raise TypeError(
                    "Cannot set the Scalar with type {} for blob {}. "
                    "If this blob is the output of some operation, "
                    "please verify the input of that operation has "
                    "proper type.".format(dtype, blob)
                )
            dtype = np.dtype(dtype)
        # If blob is not None and not a BlobReference, we assume that it is
        # actual tensor data, so we try to cast it to a numpy array.
        if blob is not None and not isinstance(blob, BlobReference):
            preserve_shape = isinstance(blob, np.ndarray)
            if dtype is not None and dtype != np.void:
                blob = np.array(blob, dtype=dtype.base)
                # if the array is empty we may need to reshape a little
                if blob.size == 0 and not preserve_shape:
                    blob = blob.reshape((0, ) + dtype.shape)
            else:
                assert isinstance(blob, np.ndarray), (
                    'Invalid blob type: %s' % str(type(blob)))

            # reshape scalars into 1D arrays
            if len(blob.shape) == 0 and not preserve_shape:
                blob = blob.reshape((1, ))

            # infer the inner shape from the blob given
            if (len(blob.shape) > 1 and dtype is not None and
                    dtype.base != np.void):
                dtype = np.dtype((dtype.base, blob.shape[1:]))
        # if we were still unable to infer the dtype
        if dtype is None:
            dtype = np.dtype(np.void)
        assert not dtype.fields, (
            'Cannot create Scalar with a structured dtype. ' +
            'Use from_dtype instead.'
        )
        self.dtype = dtype
        self._blob = blob
        if metadata is not None:
            self.set_metadata(metadata)
        self._validate_metadata()

    def set_type(self, dtype):
        self._original_dtype = dtype
        if dtype is not None:
            self.dtype = np.dtype(dtype)
        else:
            self.dtype = np.dtype(np.void)
        self._validate_metadata()

    def _pprint_impl(self, indent, str_buffer):
        str_buffer.write(
            '  ' * indent + 'Scalar({!r}, {!r}, {!r})'.format(
                self.dtype, self._blob, self._metadata
            ) + "\n"
        )

    def id(self):
        """
        Return the zero-indexed position of this scalar field in its schema.
        Used in order to index into the field_blob list returned by readers or
        accepted by writers.
        """
        return self._child_base_id()


def Map(
    keys,
    values,
    keys_name='keys',
    values_name='values',
    lengths_blob=None
):
    """A map is a List of Structs containing keys and values fields.
    Optionally, you can provide custom names for the key and value fields.
    """
    return List(
        Struct((keys_name, keys), (values_name, values)),
        lengths_blob=lengths_blob
    )
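
# A small usage sketch: a Map is sugar for a List of (keys, values) Structs,
# so its flattened field names nest under `values`.
#
# >>> m = Map(Scalar(np.int64), Scalar(np.float32))
# >>> m.field_names()
# ['lengths', 'values:keys', 'values:values']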


def NamedTuple(name_prefix, *fields):
    return Struct(* [('%s_%d' % (name_prefix, i), field)
                     for i, field in enumerate(fields)])


def Tuple(*fields):
    """
    Creates a Struct with default, sequential field names of given types.
    """
    return NamedTuple('field', *fields)


def RawTuple(num_fields, name_prefix='field'):
    """
    Creates a tuple of `num_fields` untyped scalars.
    """
    assert isinstance(num_fields, int)
    assert num_fields >= 0
    return NamedTuple(name_prefix, *([np.void] * num_fields))
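
# A small usage sketch: Tuple assigns sequential `field_<i>` names.
#
# >>> t = Tuple(np.int32, np.float32)
# >>> t.field_names()
# ['field_0', 'field_1']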


def from_dtype(dtype, _outer_shape=()):
    """Constructs a Caffe2 schema from the given numpy dtype.

    Numpy supports scalar, array-like and structured datatypes, as long as
    all the shapes are fixed. This function breaks down the given dtype into
    a Caffe2 schema containing `Struct` and `Scalar` types.

    Fields containing byte offsets are not currently supported.
    """
    if not isinstance(dtype, np.dtype):
        # wrap into a numpy dtype
        shape = _outer_shape
        dtype = np.dtype((dtype, _outer_shape))
    else:
        # concatenate shapes if necessary
        shape = _outer_shape + dtype.shape
        if shape != dtype.shape:
            dtype = np.dtype((dtype.base, shape))

    if not dtype.fields:
        return Scalar(dtype)

    struct_fields = []
    for name, (fdtype, offset) in viewitems(dtype.fields):
        assert offset == 0, ('Fields with byte offsets are not supported.')
        struct_fields.append((name, from_dtype(fdtype, _outer_shape=shape)))
    return Struct(*struct_fields)
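
# A small usage sketch (single-field structured dtype, since fields at
# non-zero byte offsets are rejected): the structured dtype is unpacked
# into a nested schema.
#
# >>> r = from_dtype(np.dtype([('x', (np.int32, 5))]))
# >>> r.field_names()
# ['x']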


class _SchemaNode(object):
    """This is a private class used to represent a Schema Node."""

    def __init__(self, name, type_str=''):
        self.name = name
        self.children = []
        self.type_str = type_str
        self.field = None

    def add_child(self, name, type_str=''):
        for child in self.children:
            if child.name == name and child.type_str == type_str:
                return child
        child = _SchemaNode(name, type_str)
        self.children.append(child)
        return child

    def get_field(self):

        list_names = ['lengths', 'values']
        map_names = ['lengths', 'keys', 'values']

        if len(self.children) == 0 or self.field is not None:
            if self.field is None:
                return Struct()
            else:
                return self.field

        child_names = []
        for child in self.children:
            child_names.append(child.name)

        if (set(child_names) == set(list_names)):
            for child in self.children:
                if child.name == 'values':
                    values_field = child.get_field()
                else:
                    lengths_field = child.get_field()
            self.field = List(
                values_field,
                lengths_blob=lengths_field
            )
            self.type_str = "List"
            return self.field
        elif (set(child_names) == set(map_names)):
            for child in self.children:
                if child.name == 'keys':
                    key_field = child.get_field()
                elif child.name == 'values':
                    values_field = child.get_field()
                else:
                    lengths_field = child.get_field()
            self.field = Map(
                key_field,
                values_field,
                lengths_blob=lengths_field
            )
            self.type_str = "Map"
            return self.field
        else:
            struct_fields = []
            for child in self.children:
                struct_fields.append((child.name, child.get_field()))
            self.field = Struct(*struct_fields)
            self.type_str = "Struct"
            return self.field

    def print_recursively(self):
        for child in self.children:
            child.print_recursively()
        logger.info("Printing node: Name and type")
        logger.info(self.name)
        logger.info(self.type_str)


def from_column_list(
    col_names, col_types=None,
    col_blobs=None, col_metadata=None
):
    """
    Given a list of names, types, and optionally values, construct a Schema.
    """
    if col_types is None:
        col_types = [None] * len(col_names)
    if col_metadata is None:
        col_metadata = [None] * len(col_names)
    if col_blobs is None:
        col_blobs = [None] * len(col_names)
    assert len(col_names) == len(col_types), (
        'col_names and col_types must have the same length.'
    )
    assert len(col_names) == len(col_metadata), (
        'col_names and col_metadata must have the same length.'
    )
    assert len(col_names) == len(col_blobs), (
        'col_names and col_blobs must have the same length.'
    )
    root = _SchemaNode('root', 'Struct')
    for col_name, col_type, col_blob, col_metadata in zip(
        col_names, col_types, col_blobs, col_metadata
    ):
        columns = col_name.split(FIELD_SEPARATOR)
        current = root
        for i in range(len(columns)):
            name = columns[i]
            type_str = ''
            field = None
            if i == len(columns) - 1:
                type_str = col_type
                field = Scalar(
                    dtype=col_type,
                    blob=col_blob,
                    metadata=col_metadata
                )
            next = current.add_child(name, type_str)
            if field is not None:
                next.field = field
            current = next

    return root.get_field()
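
# A small usage sketch: columns whose children are exactly
# `lengths`/`values` are folded back into List fields.
#
# >>> r = from_column_list(
# ...     ['a:lengths', 'a:values', 'b'],
# ...     col_types=[np.int32, np.float32, np.int64],
# ... )
# >>> r.field_names()
# ['a:lengths', 'a:values', 'b']
# >>> isinstance(r.a, List)
# True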


def from_blob_list(schema, values, throw_on_type_mismatch=False):
    """
    Create a schema that clones the given schema, but containing the given
    list of values.
    """
    assert isinstance(schema, Field), 'Argument `schema` must be a Field.'
    if isinstance(values, BlobReference):
        values = [values]
    record = schema.clone_schema()
    scalars = record.all_scalars()
    assert len(scalars) == len(values), (
        'Values must have %d elements, got %d.' % (len(scalars), len(values))
    )
    for scalar, value in zip(scalars, values):
        scalar.set_value(value, throw_on_type_mismatch, unsafe=True)
    return record


def as_record(value):
    if isinstance(value, Field):
        return value
    elif isinstance(value, list) or isinstance(value, tuple):
        is_field_list = all(
            isinstance(f, tuple) and len(f) == 2 and isinstance(f[0], basestring)
            for f in value
        )
        if is_field_list:
            return Struct(* [(k, as_record(v)) for k, v in value])
        else:
            return Tuple(* [as_record(f) for f in value])
    elif isinstance(value, dict):
        return Struct(* [(k, as_record(v)) for k, v in viewitems(value)])
    else:
        return _normalize_field(value)
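
# A small usage sketch: (name, value) pairs become a Struct, while bare
# values become a Tuple of Scalars.
#
# >>> as_record([('a', np.int32), ('b', np.float64)]).field_names()
# ['a', 'b']
# >>> as_record((np.int32, np.int32)).field_names()
# ['field_0', 'field_1']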


def FetchRecord(blob_record, ws=None, throw_on_type_mismatch=False):
    """
    Given a record containing BlobReferences, return a new record with the same
    schema, containing numpy arrays fetched from the current active workspace.
    """

    def fetch(v):
        if ws is None:
            return workspace.FetchBlob(str(v))
        else:
            return ws.blobs[str(v)].fetch()

    assert isinstance(blob_record, Field)
    field_blobs = blob_record.field_blobs()
    assert all(isinstance(v, BlobReference) for v in field_blobs)
    field_arrays = [fetch(value) for value in field_blobs]
    return from_blob_list(blob_record, field_arrays, throw_on_type_mismatch)


def FeedRecord(blob_record, arrays, ws=None):
    """
    Given a Record containing blob_references and arrays, which is either
    a list of numpy arrays or a Record containing numpy arrays, feeds the
    record to the current workspace.
    """

    def feed(b, v):
        if ws is None:
            workspace.FeedBlob(str(b), v)
        else:
            ws.create_blob(str(b))
            ws.blobs[str(b)].feed(v)

    assert isinstance(blob_record, Field)
    field_blobs = blob_record.field_blobs()
    assert all(isinstance(v, BlobReference) for v in field_blobs)
    if isinstance(arrays, Field):
        arrays = arrays.field_blobs()
    assert len(arrays) == len(field_blobs), (
        'Values must contain exactly %d ndarrays.' % len(field_blobs)
    )
    for blob, array in zip(field_blobs, arrays):
        feed(blob, array)


def NewRecord(net, schema):
    """
    Given a record of np.arrays, create a BlobReference for each one of them,
    returning a record containing BlobReferences. The name of each returned
    blob is NextScopedBlob(field_name), which guarantees unique names in the
    current net. Use NameScope explicitly to avoid name conflicts between
    different nets.
    """
    if isinstance(schema, Scalar):
        result = schema.clone()
        result.set_value(
            blob=net.NextScopedBlob('unnamed_scalar'),
            unsafe=True,
        )
        return result

    assert isinstance(schema, Field), 'Record must be a schema.Field instance.'
    blob_refs = [
        net.NextScopedBlob(prefix=name)
        for name in schema.field_names()
    ]
    return from_blob_list(schema, blob_refs)
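
# A small workflow sketch (assumes a fresh default workspace): allocate
# scoped blobs for a schema, feed a numpy array into them, then fetch the
# values back as a record of ndarrays.
#
# >>> init_net = core.Net('init')
# >>> record = NewRecord(init_net, Struct(('x', Scalar(np.float32))))
# >>> FeedRecord(record, [np.array([1.0], dtype=np.float32)])
# >>> FetchRecord(record).x.get()
# array([1.], dtype=float32)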


def ConstRecord(net, array_record):
    """
    Given a record of arrays, returns a record of blobs,
    initialized with net.Const.
    """
    blob_record = NewRecord(net, array_record)
    for blob, array in zip(
        blob_record.field_blobs(), array_record.field_blobs()
    ):
        net.Const(array, blob)
    return blob_record


def InitEmptyRecord(net, schema_or_record, enforce_types=False):
    if not schema_or_record.has_blobs():
        record = NewRecord(net, schema_or_record)
    else:
        record = schema_or_record

    for blob_type, blob in zip(record.field_types(), record.field_blobs()):
        try:
            data_type = data_type_for_dtype(blob_type)
            shape = [0] + list(blob_type.shape)
            net.ConstantFill([], blob, shape=shape, dtype=data_type)
        except TypeError:
            logger.warning("Blob {} has type error".format(blob))
            if enforce_types:
                raise
            # If the dtype cannot be resolved to a core.DataType, fall back
            # to initializing an empty, untyped blob.
            net.ConstantFill([], blob, shape=[0])

    return record


_DATA_TYPE_FOR_DTYPE = [
    (np.str, core.DataType.STRING),
    (np.float16, core.DataType.FLOAT16),
    (np.float32, core.DataType.FLOAT),
    (np.float64, core.DataType.DOUBLE),
    (np.bool, core.DataType.BOOL),
    (np.int8, core.DataType.INT8),
    (np.int16, core.DataType.INT16),
    (np.int32, core.DataType.INT32),
    (np.int64, core.DataType.INT64),
    (np.uint8, core.DataType.UINT8),
    (np.uint16, core.DataType.UINT16),
]


def is_schema_subset(schema, original_schema):
    return set(schema.field_names()).issubset(
        set(original_schema.field_names()))


def equal_schemas(schema,
                  original_schema,
                  check_field_names=True,
                  check_field_types=True,
                  check_field_metas=False):
    assert isinstance(schema, Field)
    assert isinstance(original_schema, Field)

    if check_field_names and (
            schema.field_names() != original_schema.field_names()):
        return False
    if check_field_types and (
            schema.field_types() != original_schema.field_types()):
        return False
    if check_field_metas and (
            schema.field_metadata() != original_schema.field_metadata()):
        return False

    return True


def schema_check(schema, previous=None):
    record = as_record(schema)
    if previous is not None:
        assert equal_schemas(schema, previous)
    return record


def data_type_for_dtype(dtype):
    for np_type, dt in _DATA_TYPE_FOR_DTYPE:
        if dtype.base == np_type:
            return dt
    raise TypeError('Unknown dtype: ' + str(dtype.base))


def dtype_for_core_type(core_type):
    for np_type, dt in _DATA_TYPE_FOR_DTYPE:
        if dt == core_type:
            return np_type
    raise TypeError('Unknown core type: ' + str(core_type))


def attach_metadata_to_scalars(field, metadata):
    for f in field.all_scalars():
        f.set_metadata(metadata)