Caffe2 - C++ API
A deep learning, cross-platform ML framework
data_filler.h
1 #pragma once
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/predictor/predictor.h"
#include "caffe2/utils/filler.h"
6 
7 namespace caffe2 {
8 namespace emulator {
9 
10 typedef caffe2::Predictor::TensorList TensorList_t;
11 
12 /*
13  * A filler to initialize the parameters and inputs of a predictor
14  */
15 class Filler {
16  protected:
17  virtual void fill_input_internal(TensorList_t* input_data) const = 0;
18 
19  public:
20  // initialize the workspace with parameter
21  virtual void fill_parameter(Workspace* ws) const = 0;
22 
23  // generate input data and return input data size
24  size_t fill_input(TensorList_t* input_data) const {
25  CAFFE_ENFORCE(input_data, "input_data is null");
26  input_data->clear();
27 
28  fill_input_internal(input_data);
29 
30  uint64_t bytes = 0;
31  for (const auto& item : *input_data) {
32  bytes += item.nbytes();
33  }
34  if (bytes == 0) {
35  LOG(WARNING) << "0 input bytes filled";
36  }
37 
38  return bytes;
39  }
40 
41  const std::vector<std::string>& get_input_names() const {
42  CAFFE_ENFORCE(!input_names_.empty(), "input names is not initialized");
43  return input_names_;
44  }
45 
46  virtual ~Filler() noexcept {}
47 
48  protected:
49  std::vector<std::string> input_names_;
50 };
51 
52 /*
53  * @init_net: a reader net to generate parameters
54  * @data_net: a reader net to generate inputs
55  */
56 class DataNetFiller : public Filler {
57  public:
58  DataNetFiller(const NetDef&& init_net, const NetDef&& data_net)
59  : init_net_(init_net), data_net_(data_net) {
60  // The output of the data_net_ will be served as the input
61  int op_size = data_net_.op_size();
62  for (int i = 0; i < op_size; ++i) {
63  OperatorDef op_def = data_net_.op(i);
64  // We rely on Fill op to generate inputs
65  CAFFE_ENFORCE(op_def.type().find("Fill") != std::string::npos);
66  int output_size = op_def.output_size();
67  for (int j = 0; j < output_size; ++j) {
68  input_names_.push_back(op_def.output(j));
69  }
70  }
71  }
72 
73  void fill_input_internal(TensorList_t* input_data) const override;
74 
75  void fill_parameter(Workspace* ws) const override;
76 
77  private:
78  const NetDef init_net_;
79  const NetDef data_net_;
80 };
81 
82 /*
83  * @run_net: the predict net with parameter and input names
84  * @input_dims: the input dimentions of all operator inputs of run_net
85  * @input_types: the input types of all operator inputs of run_net
86  */
87 class DataRandomFiller : public Filler {
88  public:
90  const NetDef& run_net,
91  const std::vector<std::vector<std::vector<int64_t>>>& input_dims,
92  const std::vector<std::vector<std::string>>& input_types);
93 
94  void fill_input_internal(TensorList_t* input_data) const override;
95 
96  void fill_parameter(Workspace* ws) const override;
97 
98  protected:
99  DataRandomFiller() {}
100 
101  TensorFiller get_tensor_filler(
102  const OperatorDef& op_def,
103  int input_index,
104  const std::vector<std::vector<int64_t>>& input_dims) {
105  Workspace ws;
106  for (size_t i = 0; i < op_def.input_size(); ++i) {
107  // CreateOperator requires all input blobs present
108  ws.CreateBlob(op_def.input(i));
109  }
110  CAFFE_ENFORCE(op_def.has_type());
111  const OpSchema* schema = caffe2::OpSchemaRegistry::Schema(op_def.type());
112  if (schema == nullptr) {
113  throw std::invalid_argument(
114  op_def.type() + " does not have input fillers");
115  }
116  auto filler = schema->InputFillers(input_dims)[input_index];
117  return filler;
118  }
119 
120  using filler_type_pair_t = std::pair<TensorFiller, std::string>;
121  std::unordered_map<std::string, filler_type_pair_t> parameters_;
122  std::unordered_map<std::string, filler_type_pair_t> inputs_;
123 };
124 
125 // A DataRandomFiller that is more convenient to use in unit tests.
126 // Callers just need to supply input dimensions and types for non-intermediate
127 // blobs.
128 // It also treats parameters the same way as non-intermediate inputs (no
129 // handling of parameters separately).
131  public:
133  const NetDef& net,
134  const std::vector<std::vector<std::vector<int64_t>>>& inputDims,
135  const std::vector<std::vector<std::string>>& inputTypes);
136 
137  // Fill input directly to the workspace.
138  void fillInputToWorkspace(Workspace* workspace) const;
139 };
140 
141 } // namespace emulator
142 } // namespace caffe2
Blob * CreateBlob(const string &name)
Creates a blob of the given name.
Definition: workspace.cc:100
A class to record the schema of an op.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13