Caffe2 - C++ API
A deep learning, cross-platform ML framework
text_file_reader.cc
#include <climits>
#include <cstdlib>
#include <memory>
#include <mutex>
#include <stdexcept>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
#include "caffe2/core/types.h"
#include "caffe2/operators/text_file_reader_utils.h"
#include "caffe2/utils/string_utils.h"

namespace caffe2 {

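// State shared between the create and read operators below: the open file,
// a tokenizer over it, and type metadata for each column given in field_types.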
struct TextFileReaderInstance {
  TextFileReaderInstance(
      const std::vector<char>& delims,
      char escape,
      const std::string& filename,
      int numPasses,
      const std::vector<int>& types)
      : fileReader(filename),
        tokenizer(Tokenizer(delims, escape), &fileReader, numPasses),
        fieldTypes(types) {
    // Cache the TypeMeta and element size for each declared field type.
    for (const auto dt : fieldTypes) {
      fieldMetas.push_back(
          DataTypeToTypeMeta(static_cast<TensorProto_DataType>(dt)));
      fieldByteSizes.push_back(fieldMetas.back().itemsize());
    }
  }

  FileReader fileReader;
  BufferedTokenizer tokenizer;
  std::vector<int> fieldTypes;
  std::vector<TypeMeta> fieldMetas;
  std::vector<size_t> fieldByteSizes;
  size_t rowsRead{0};

  // Hack to guarantee thread-safety of the read op.
  // TODO(azzolini): support multi-threaded reading.
  std::mutex globalMutex_;
};

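// Opens the file given by the "filename" argument and stores a
// TextFileReaderInstance handle in the single output blob.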
class CreateTextFileReaderOp : public Operator<CPUContext> {
 public:
  template <class... Args>
  explicit CreateTextFileReaderOp(Args&&... args)
      : Operator<CPUContext>(std::forward<Args>(args)...),
        filename_(GetSingleArgument<string>("filename", "")),
        numPasses_(GetSingleArgument<int>("num_passes", 1)),
        fieldTypes_(GetRepeatedArgument<int>("field_types")) {
    CAFFE_ENFORCE(fieldTypes_.size() > 0, "field_types arg must be non-empty");
  }

  bool RunOnDevice() override {
    // Rows are separated by '\n' and fields by '\t'.
    *OperatorBase::Output<std::unique_ptr<TextFileReaderInstance>>(0) =
        std::unique_ptr<TextFileReaderInstance>(new TextFileReaderInstance(
            {'\n', '\t'}, '\0', filename_, numPasses_, fieldTypes_));
    return true;
  }

 private:
  std::string filename_;
  int numPasses_;
  std::vector<int> fieldTypes_;
};

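// Parses a single token [src_start, src_end) into dst_type and writes the
// result at dst. Only STRING and FLOAT fields are supported.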
inline void convert(
    TensorProto_DataType dst_type,
    const char* src_start,
    const char* src_end,
    void* dst) {
  switch (dst_type) {
    case TensorProto_DataType_STRING: {
      static_cast<std::string*>(dst)->assign(src_start, src_end);
    } break;
    case TensorProto_DataType_FLOAT: {
      // TODO(azzolini): avoid copy, use faster conversion
      std::string str_copy(src_start, src_end);
      const char* src_copy = str_copy.c_str();
      char* src_copy_end;
      float val = strtof(src_copy, &src_copy_end);
      if (src_copy == src_copy_end) {
        throw std::runtime_error("Invalid float: " + str_copy);
      }
      *static_cast<float*>(dst) = val;
    } break;
    default:
      throw std::runtime_error("Unsupported type.");
  }
}

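// Reads up to batch_size rows from the reader handle given as input, writing
// one 1D output tensor per field. Outputs are empty once end of file is reached.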
class TextFileReaderReadOp : public Operator<CPUContext> {
 public:
  template <class... Args>
  explicit TextFileReaderReadOp(Args&&... args)
      : Operator<CPUContext>(std::forward<Args>(args)...),
        batchSize_(GetSingleArgument<int>("batch_size", 1)) {}

  bool RunOnDevice() override {
    const int numFields = OutputSize();
    CAFFE_ENFORCE(numFields > 0, "Expected at least one output.");

    auto instance =
        OperatorBase::Input<std::unique_ptr<TextFileReaderInstance>>(0).get();

    CAFFE_ENFORCE(
        instance->fieldTypes.size() == numFields,
        "Invalid number of outputs. Expected " +
            to_string(instance->fieldTypes.size()) + " got " +
            to_string(numFields));

    // A stack array `char* datas[numFields]` would be a variable-length array,
    // which MSVC does not support, so allocate the per-field write pointers
    // dynamically instead.
    std::vector<char*> datas(numFields);
    for (int i = 0; i < numFields; ++i) {
      Output(i)->Resize(batchSize_);
      datas[i] = (char*)Output(i)->raw_mutable_data(instance->fieldMetas[i]);
    }

    int rowsRead = 0;
    {
      // TODO(azzolini): support multi-threaded reading
      std::lock_guard<std::mutex> guard(instance->globalMutex_);

      bool finished = false;
      Token token;
      while (!finished && (rowsRead < batchSize_)) {
        int field;
        for (field = 0; field < numFields; ++field) {
          finished = !instance->tokenizer.next(token);
          if (finished) {
            CAFFE_ENFORCE(
                field == 0, "Invalid number of fields at end of file.");
            break;
          }
          // The first field of a row must follow a row delimiter ('\n', id 0);
          // every other field must follow a field delimiter ('\t', id 1).
          CAFFE_ENFORCE(
              (field == 0 && token.startDelimId == 0) ||
                  (field > 0 && token.startDelimId == 1),
              "Invalid number of columns at row ",
              instance->rowsRead + rowsRead + 1);
          const auto& meta = instance->fieldMetas[field];
          char*& data = datas[field];
          convert(
              (TensorProto_DataType)instance->fieldTypes[field],
              token.start,
              token.end,
              data);
          data += instance->fieldByteSizes[field];
        }
        if (!finished) {
          ++rowsRead;
        }
      }
      instance->rowsRead += rowsRead;
    }

    // Trim each output to the number of rows actually read; at end of file
    // this yields empty tensors.
    for (int i = 0; i < numFields; ++i) {
      Output(i)->ShrinkTo(rowsRead);
    }
    return true;
  }

 private:
  int64_t batchSize_;
};

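// Register the reader handle type with the Caffe2 type registry.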
CAFFE_KNOWN_TYPE(std::unique_ptr<TextFileReaderInstance>);

REGISTER_CPU_OPERATOR(CreateTextFileReader, CreateTextFileReaderOp);
REGISTER_CPU_OPERATOR(TextFileReaderRead, TextFileReaderReadOp);

OPERATOR_SCHEMA(CreateTextFileReader)
    .NumInputs(0)
    .NumOutputs(1)
    .SetDoc("Create a text file reader. Fields are delimited by <TAB>.")
    .Arg("filename", "Path to the file.")
    .Arg("num_passes", "Number of passes over the file.")
    .Arg(
        "field_types",
        "List with type of each field. Type enum is found at core.DataType.")
    .Output(0, "handler", "Pointer to the created TextFileReaderInstance.");

OPERATOR_SCHEMA(TextFileReaderRead)
    .NumInputs(1)
    .NumOutputs(1, INT_MAX)
    .SetDoc(
        "Read a batch of rows from the given text file reader instance. "
        "Expects the number of fields to be equal to the number of outputs. "
        "Each output is a 1D tensor containing the values for the given field "
        "for each row. When end of file is reached, returns empty tensors.")
    .Input(0, "handler", "Pointer to an existing TextFileReaderInstance.")
    .Arg("batch_size", "Maximum number of rows to read.");

NO_GRADIENT(CreateTextFileReader);
NO_GRADIENT(TextFileReaderRead);

} // namespace caffe2
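
For orientation, the following is a minimal sketch (not part of the file above) of how these two operators could be driven directly from C++ through a Workspace. It assumes a Caffe2 build where caffe2/core/workspace.h, caffe2/core/init.h and caffe2/utils/proto_utils.h provide Workspace, GlobalInit, CreateOperatorDef and MakeArgument as in mainline Caffe2, and it assumes a hypothetical tab-separated file at /tmp/data.tsv with a string column followed by a float column.

#include <iostream>
#include <string>
#include <vector>

#include "caffe2/core/init.h"
#include "caffe2/core/workspace.h"
#include "caffe2/utils/proto_utils.h"

int main(int argc, char** argv) {
  caffe2::GlobalInit(&argc, &argv);
  caffe2::Workspace ws;

  // Create the reader handle; field_types holds core.DataType enum values.
  // The path and column layout are assumptions for this sketch.
  auto create_def = caffe2::CreateOperatorDef(
      "CreateTextFileReader",
      "create_reader",
      /*inputs=*/{},
      /*outputs=*/{"reader"},
      std::vector<caffe2::Argument>{
          caffe2::MakeArgument<std::string>("filename", "/tmp/data.tsv"),
          caffe2::MakeArgument<std::vector<int>>(
              "field_types",
              std::vector<int>{
                  static_cast<int>(caffe2::TensorProto_DataType_STRING),
                  static_cast<int>(caffe2::TensorProto_DataType_FLOAT)})});
  CAFFE_ENFORCE(ws.RunOperatorOnce(create_def));

  // Read one batch of up to 32 rows; one 1D output blob per field.
  auto read_def = caffe2::CreateOperatorDef(
      "TextFileReaderRead",
      "read_batch",
      /*inputs=*/{"reader"},
      /*outputs=*/{"names", "values"},
      std::vector<caffe2::Argument>{
          caffe2::MakeArgument<int>("batch_size", 32)});
  CAFFE_ENFORCE(ws.RunOperatorOnce(read_def));

  const auto& values = ws.GetBlob("values")->Get<caffe2::TensorCPU>();
  std::cout << "rows read: " << values.size() << std::endl;
  return 0;
}

In a real driver one would keep running TextFileReaderRead until the outputs come back empty, which per the schema above signals end of file.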