1 #include "caffe2/core/context.h" 2 #include "caffe2/core/operator.h" 3 #include "caffe2/core/tensor.h" 4 #include "caffe2/core/types.h" 5 #include "caffe2/operators/text_file_reader_utils.h" 6 #include "caffe2/utils/string_utils.h" 12 const std::vector<char>& delims,
14 const std::string& filename,
16 const std::vector<int>& types)
17 : fileReader(filename),
18 tokenizer(
Tokenizer(delims, escape), &fileReader, numPasses),
20 for (
const auto dt : fieldTypes) {
22 DataTypeToTypeMeta(static_cast<TensorProto_DataType>(dt)));
23 fieldByteSizes.push_back(fieldMetas.back().itemsize());
29 std::vector<int> fieldTypes;
30 std::vector<TypeMeta> fieldMetas;
31 std::vector<size_t> fieldByteSizes;
36 std::mutex globalMutex_;
41 template <
class... Args>
44 filename_(GetSingleArgument<string>(
"filename",
"")),
45 numPasses_(GetSingleArgument<int>(
"num_passes", 1)),
46 fieldTypes_(GetRepeatedArgument<int>(
"field_types")) {
47 CAFFE_ENFORCE(fieldTypes_.size() > 0,
"field_types arg must be non-empty");
50 bool RunOnDevice()
override {
51 *OperatorBase::Output<std::unique_ptr<TextFileReaderInstance>>(0) =
53 {
'\n',
'\t'},
'\0', filename_, numPasses_, fieldTypes_));
58 std::string filename_;
60 std::vector<int> fieldTypes_;
64 TensorProto_DataType dst_type,
65 const char* src_start,
69 case TensorProto_DataType_STRING: {
70 static_cast<std::string*
>(dst)->assign(src_start, src_end);
72 case TensorProto_DataType_FLOAT: {
74 std::string str_copy(src_start, src_end);
75 const char* src_copy = str_copy.c_str();
77 float val = strtof(src_copy, &src_copy_end);
78 if (src_copy == src_copy_end) {
79 throw std::runtime_error(
"Invalid float: " + str_copy);
81 *
static_cast<float*
>(dst) = val;
84 throw std::runtime_error(
"Unsupported type.");
90 template <
class... Args>
93 batchSize_(GetSingleArgument<int>(
"batch_size", 1)) {}
95 bool RunOnDevice()
override {
96 const int numFields = OutputSize();
97 CAFFE_ENFORCE(numFields > 0,
"Expected at least one output.");
100 OperatorBase::Input<std::unique_ptr<TextFileReaderInstance>>(0).
get();
103 instance->fieldTypes.size() == numFields,
104 "Invalid number of outputs. Expected " +
105 to_string(instance->fieldTypes.size()) +
" got " +
106 to_string(numFields));
111 std::vector<char*> datas(numFields);
112 for (
int i = 0; i < numFields; ++i) {
113 Output(i)->Resize(batchSize_);
114 datas[i] = (
char*)Output(i)->raw_mutable_data(instance->fieldMetas[i]);
120 std::lock_guard<std::mutex> guard(instance->globalMutex_);
122 bool finished =
false;
124 while (!finished && (rowsRead < batchSize_)) {
126 for (field = 0; field < numFields; ++field) {
127 finished = !instance->tokenizer.next(token);
130 field == 0,
"Invalid number of fields at end of file.");
134 (field == 0 && token.startDelimId == 0) ||
135 (field > 0 && token.startDelimId == 1),
136 "Invalid number of columns at row ",
137 instance->rowsRead + rowsRead + 1);
138 const auto& meta = instance->fieldMetas[field];
139 char*& data = datas[field];
141 (TensorProto_DataType)instance->fieldTypes[field],
145 data += instance->fieldByteSizes[field];
151 instance->rowsRead += rowsRead;
154 for (
int i = 0; i < numFields; ++i) {
155 Output(i)->ShrinkTo(rowsRead);
// Register std::unique_ptr<TextFileReaderInstance> with Caffe2's type system.
// CreateTextFileReader stores a value of exactly this type in its output 0,
// and TextFileReaderRead fetches it back from input 0. The blob holding the
// reader handle therefore has to carry this type.
// NOTE(review): presumably the macro assigns the TypeMeta id that the
// Input<>/Output<> accessors rely on; confirm against c10/util/typeid.h.
164 CAFFE_KNOWN_TYPE(std::unique_ptr<TextFileReaderInstance>);
169 OPERATOR_SCHEMA(CreateTextFileReader)
172 .SetDoc(
"Create a text file reader. Fields are delimited by <TAB>.")
173 .Arg(
"filename",
"Path to the file.")
174 .Arg(
"num_passes",
"Number of passes over the file.")
177 "List with type of each field. Type enum is found at core.DataType.")
178 .Output(0,
"handler",
"Pointer to the created TextFileReaderInstance.");
180 OPERATOR_SCHEMA(TextFileReaderRead)
182 .NumOutputs(1, INT_MAX)
184 "Read a batch of rows from the given text file reader instance. " 185 "Expects the number of fields to be equal to the number of outputs. " 186 "Each output is a 1D tensor containing the values for the given field " 187 "for each row. When end of file is reached, returns empty tensors.")
188 .Input(0,
"handler",
"Pointer to an existing TextFileReaderInstance.")
189 .Arg(
"batch_size",
"Maximum number of rows to read.");
// Neither text-file-reader operator registers a gradient.
// - CreateTextFileReader only produces a reader handle (output "handler").
// - TextFileReaderRead fills its outputs with values parsed from the file
//   through that handle.
// Neither output is a differentiable function of a tensor input.
// NOTE(review): presumably NO_GRADIENT marks the op as having no gradient
// definition; confirm against caffe2/core/operator_gradient.h.
191 NO_GRADIENT(CreateTextFileReader);
192 NO_GRADIENT(TextFileReaderRead);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...