# Caffe2 - Python API
# A deep learning, cross platform ML framework
# text_file_reader.py
1 ## @package text_file_reader
2 # Module caffe2.python.text_file_reader
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 from caffe2.python import core
8 from caffe2.python.dataio import Reader
9 from caffe2.python.schema import Scalar, Struct, data_type_for_dtype
10 
11 
13  """
14  Wrapper around operators for reading from text files.
15  """
16  def __init__(self, init_net, filename, schema, num_passes=1, batch_size=1):
17  """
18  Create op for building a TextFileReader instance in the workspace.
19 
20  Args:
21  init_net : Net that will be run only once at startup.
22  filename : Path to file to read from.
23  schema : schema.Struct representing the schema of the data.
24  Currently, only support Struct of strings.
25  num_passes : Number of passes over the data.
26  batch_size : Number of rows to read at a time.
27  """
28  assert isinstance(schema, Struct), 'Schema must be a schema.Struct'
29  for name, child in schema.get_children():
30  assert isinstance(child, Scalar), (
31  'Only scalar fields are supported in TextFileReader.')
32  field_types = [
33  data_type_for_dtype(dtype) for dtype in schema.field_types()]
34  Reader.__init__(self, schema)
35  self._reader = init_net.CreateTextFileReader(
36  [],
37  filename=filename,
38  num_passes=num_passes,
39  field_types=field_types)
40  self._batch_size = batch_size
41 
42  def read(self, net):
43  """
44  Create op for reading a batch of rows.
45  """
46  blobs = net.TextFileReaderRead(
47  [self._reader],
48  len(self.schema().field_names()),
49  batch_size=self._batch_size)
50  if type(blobs) is core.BlobReference:
51  blobs = [blobs]
52 
53  is_empty = net.IsEmpty(
54  [blobs[0]],
55  core.ScopedBlobReference(net.NextName('should_stop'))
56  )
57 
58  return (is_empty, blobs)
# (removed stray duplicated `__init__` signature — doc-extraction artifact, not code)