Caffe2 - C++ API
A deep learning, cross platform ML framework
text_file_reader.cc
1 
17 #include "caffe2/core/context.h"
18 #include "caffe2/core/operator.h"
19 #include "caffe2/core/tensor.h"
20 #include "caffe2/core/types.h"
21 #include "caffe2/operators/text_file_reader_utils.h"
22 #include "caffe2/utils/string_utils.h"
23 
24 namespace caffe2 {
25 
28  const std::vector<char>& delims,
29  char escape,
30  const std::string& filename,
31  int numPasses,
32  const std::vector<int>& types)
33  : fileReader(filename),
34  tokenizer(Tokenizer(delims, escape), &fileReader, numPasses),
35  fieldTypes(types) {
36  for (const auto dt : fieldTypes) {
37  fieldMetas.push_back(
38  DataTypeToTypeMeta(static_cast<TensorProto_DataType>(dt)));
39  fieldByteSizes.push_back(fieldMetas.back().itemsize());
40  }
41  }
42 
43  FileReader fileReader;
44  BufferedTokenizer tokenizer;
45  std::vector<int> fieldTypes;
46  std::vector<TypeMeta> fieldMetas;
47  std::vector<size_t> fieldByteSizes;
48  size_t rowsRead{0};
49 
50  // hack to guarantee thread-safeness of the read op
51  // TODO(azzolini): support multi-threaded reading.
52  std::mutex globalMutex_;
53 };
54 
55 class CreateTextFileReaderOp : public Operator<CPUContext> {
56  public:
57  CreateTextFileReaderOp(const OperatorDef& operator_def, Workspace* ws)
58  : Operator<CPUContext>(operator_def, ws),
59  filename_(GetSingleArgument<string>("filename", "")),
60  numPasses_(GetSingleArgument<int>("num_passes", 1)),
61  fieldTypes_(GetRepeatedArgument<int>("field_types")) {
62  CAFFE_ENFORCE(fieldTypes_.size() > 0, "field_types arg must be non-empty");
63  }
64 
65  bool RunOnDevice() override {
66  *OperatorBase::Output<std::unique_ptr<TextFileReaderInstance>>(0) =
67  std::unique_ptr<TextFileReaderInstance>(new TextFileReaderInstance(
68  {'\n', '\t'}, '\0', filename_, numPasses_, fieldTypes_));
69  return true;
70  }
71 
72  private:
73  std::string filename_;
74  int numPasses_;
75  std::vector<int> fieldTypes_;
76 };
77 
78 inline void convert(
79  TensorProto_DataType dst_type,
80  const char* src_start,
81  const char* src_end,
82  void* dst) {
83  switch (dst_type) {
84  case TensorProto_DataType_STRING: {
85  static_cast<std::string*>(dst)->assign(src_start, src_end);
86  } break;
87  case TensorProto_DataType_FLOAT: {
88  // TODO(azzolini): avoid copy, use faster convertion
89  std::string str_copy(src_start, src_end);
90  const char* src_copy = str_copy.c_str();
91  char* src_copy_end;
92  float val = strtof(src_copy, &src_copy_end);
93  if (src_copy == src_copy_end) {
94  throw std::runtime_error("Invalid float: " + str_copy);
95  }
96  *static_cast<float*>(dst) = val;
97  } break;
98  default:
99  throw std::runtime_error("Unsupported type.");
100  }
101 }
102 
103 class TextFileReaderReadOp : public Operator<CPUContext> {
104  public:
105  TextFileReaderReadOp(const OperatorDef& operator_def, Workspace* ws)
106  : Operator<CPUContext>(operator_def, ws),
107  batchSize_(GetSingleArgument<int>("batch_size", 1)) {}
108 
109  bool RunOnDevice() override {
110  const int numFields = OutputSize();
111  CAFFE_ENFORCE(numFields > 0, "Expected at least one output.");
112 
113  auto instance =
114  OperatorBase::Input<std::unique_ptr<TextFileReaderInstance>>(0).get();
115 
116  CAFFE_ENFORCE(
117  instance->fieldTypes.size() == numFields,
118  "Invalid number of outputs. Expected " +
119  to_string(instance->fieldTypes.size()) + " got " +
120  to_string(numFields));
121 
122  // char* datas[numFields];
123  // MSVC does not allow using const int, so we will need to dynamically allocate
124  // it.
125  std::vector<char*> datas(numFields);
126  for (int i = 0; i < numFields; ++i) {
127  Output(i)->Resize(batchSize_);
128  datas[i] = (char*)Output(i)->raw_mutable_data(instance->fieldMetas[i]);
129  }
130 
131  int rowsRead = 0;
132  {
133  // TODO(azzolini): support multi-threaded reading
134  std::lock_guard<std::mutex> guard(instance->globalMutex_);
135 
136  bool finished = false;
137  Token token;
138  while (!finished && (rowsRead < batchSize_)) {
139  int field;
140  for (field = 0; field < numFields; ++field) {
141  finished = !instance->tokenizer.next(token);
142  if (finished) {
143  CAFFE_ENFORCE(
144  field == 0, "Invalid number of fields at end of file.");
145  break;
146  }
147  CAFFE_ENFORCE(
148  (field == 0 && token.startDelimId == 0) ||
149  (field > 0 && token.startDelimId == 1),
150  "Invalid number of columns at row ",
151  instance->rowsRead + rowsRead + 1);
152  const auto& meta = instance->fieldMetas[field];
153  char*& data = datas[field];
154  convert(
155  (TensorProto_DataType)instance->fieldTypes[field],
156  token.start,
157  token.end,
158  data);
159  data += instance->fieldByteSizes[field];
160  }
161  if (!finished) {
162  ++rowsRead;
163  }
164  }
165  instance->rowsRead += rowsRead;
166  }
167 
168  for (int i = 0; i < numFields; ++i) {
169  Output(i)->Shrink(rowsRead);
170  }
171  return true;
172  }
173 
174  private:
175  TIndex batchSize_;
176 };
177 
178 CAFFE_KNOWN_TYPE(std::unique_ptr<TextFileReaderInstance>);
179 
180 REGISTER_CPU_OPERATOR(CreateTextFileReader, CreateTextFileReaderOp);
181 REGISTER_CPU_OPERATOR(TextFileReaderRead, TextFileReaderReadOp);
182 
183 OPERATOR_SCHEMA(CreateTextFileReader)
184  .NumInputs(0)
185  .NumOutputs(1)
186  .SetDoc("Create a text file reader. Fields are delimited by <TAB>.")
187  .Arg("filename", "Path to the file.")
188  .Arg("num_passes", "Number of passes over the file.")
189  .Arg(
190  "field_types",
191  "List with type of each field. Type enum is found at core.DataType.")
192  .Output(0, "handler", "Pointer to the created TextFileReaderInstance.");
193 
194 OPERATOR_SCHEMA(TextFileReaderRead)
195  .NumInputs(1)
196  .NumOutputs(1, INT_MAX)
197  .SetDoc(
198  "Read a batch of rows from the given text file reader instance. "
199  "Expects the number of fields to be equal to the number of outputs. "
200  "Each output is a 1D tensor containing the values for the given field "
201  "for each row. When end of file is reached, returns empty tensors.")
202  .Input(0, "handler", "Pointer to an existing TextFileReaderInstance.")
203  .Arg("batch_size", "Maximum number of rows to read.");
204 
205 NO_GRADIENT(CreateTextFileReader);
206 NO_GRADIENT(TextFileReaderRead);
207 
208 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.