Caffe2 - Python API
A deep-learning, cross-platform ML framework
text_file_reader.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package text_file_reader
17 # Module caffe2.python.text_file_reader
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 from caffe2.python import core
23 from caffe2.python.dataio import Reader
24 from caffe2.python.schema import Scalar, Struct, data_type_for_dtype
25 
26 
28  """
29  Wrapper around operators for reading from text files.
30  """
def __init__(self, init_net, filename, schema, num_passes=1, batch_size=1):
    """Add a CreateTextFileReader op to ``init_net`` and remember the batch size.

    Args:
        init_net: Net that will be run only once at startup; the
            reader-creation op is appended to it.
        filename: Path of the text file to read from.
        schema: ``schema.Struct`` describing one row. Every child field
            must be a ``schema.Scalar`` (the original docs add that only
            string fields are currently supported).
        num_passes: Forwarded unchanged as the ``num_passes`` argument of
            CreateTextFileReader.
        batch_size: Stored and later passed as ``batch_size`` to
            TextFileReaderRead by ``read``.
    """
    # NOTE(review): these checks use ``assert`` and are therefore skipped
    # when Python runs with -O.
    assert isinstance(schema, Struct), 'Schema must be a schema.Struct'
    for _field_name, field in schema.get_children():
        assert isinstance(field, Scalar), (
            'Only scalar fields are supported in TextFileReader.')

    # One converted data type per schema field, in schema order.
    column_types = list(map(data_type_for_dtype, schema.field_types()))

    Reader.__init__(self, schema)
    self._reader = init_net.CreateTextFileReader(
        [],
        filename=filename,
        num_passes=num_passes,
        field_types=column_types,
    )
    self._batch_size = batch_size
55  self._batch_size = batch_size
56 
def read(self, net):
    """Add ops to ``net`` that read one batch of rows.

    A TextFileReaderRead op is created with one output per schema field,
    followed by an IsEmpty op on the first output.

    Args:
        net: Net to which the read and IsEmpty ops are appended.

    Returns:
        Tuple ``(is_empty, blobs)``. ``blobs`` holds the read op's output
        blobs (always a sequence, even with a single field). ``is_empty`` is
        the output of IsEmpty applied to ``blobs[0]``, stored in a blob named
        from ``net.NextName('should_stop')``.
    """
    blobs = net.TextFileReaderRead(
        [self._reader],
        len(self.schema().field_names()),
        batch_size=self._batch_size)
    # The op helper may hand back a bare BlobReference (presumably when
    # only one output is requested); wrap it so callers can always index.
    # isinstance also accepts BlobReference subclasses, unlike an exact
    # type() comparison.
    if isinstance(blobs, core.BlobReference):
        blobs = [blobs]

    is_empty = net.IsEmpty(
        [blobs[0]],
        core.ScopedBlobReference(net.NextName('should_stop'))
    )

    return (is_empty, blobs)
def __init__(self, init_net, filename, schema, num_passes=1, batch_size=1)