2 #include "caffe2/core/logging.h" 3 #include "caffe2/core/operator.h" 4 #include "caffe2/predictor/predictor.h" 5 #include "caffe2/utils/filler.h" 10 typedef caffe2::Predictor::TensorList TensorList_t;
17 virtual void fill_input_internal(TensorList_t* input_data)
const = 0;
21 virtual void fill_parameter(
Workspace* ws)
const = 0;
24 size_t fill_input(TensorList_t* input_data)
const {
25 CAFFE_ENFORCE(input_data,
"input_data is null");
28 fill_input_internal(input_data);
31 for (
const auto& item : *input_data) {
32 bytes += item.nbytes();
35 LOG(WARNING) <<
"0 input bytes filled";
41 const std::vector<std::string>& get_input_names()
const {
42 CAFFE_ENFORCE(!input_names_.empty(),
"input names is not initialized");
46 virtual ~
Filler() noexcept {}
49 std::vector<std::string> input_names_;
58 DataNetFiller(
const NetDef&& init_net,
const NetDef&& data_net)
59 : init_net_(init_net), data_net_(data_net) {
61 int op_size = data_net_.op_size();
62 for (
int i = 0; i < op_size; ++i) {
63 OperatorDef op_def = data_net_.op(i);
65 CAFFE_ENFORCE(op_def.type().find(
"Fill") != std::string::npos);
66 int output_size = op_def.output_size();
67 for (
int j = 0; j < output_size; ++j) {
68 input_names_.push_back(op_def.output(j));
73 void fill_input_internal(TensorList_t* input_data)
const override;
75 void fill_parameter(
Workspace* ws)
const override;
78 const NetDef init_net_;
79 const NetDef data_net_;
90 const NetDef& run_net,
91 const std::vector<std::vector<std::vector<int64_t>>>& input_dims,
92 const std::vector<std::vector<std::string>>& input_types);
94 void fill_input_internal(TensorList_t* input_data)
const override;
96 void fill_parameter(
Workspace* ws)
const override;
102 const OperatorDef& op_def,
104 const std::vector<std::vector<int64_t>>& input_dims) {
106 for (
size_t i = 0; i < op_def.input_size(); ++i) {
110 CAFFE_ENFORCE(op_def.has_type());
111 const OpSchema* schema = caffe2::OpSchemaRegistry::Schema(op_def.type());
112 if (schema ==
nullptr) {
113 throw std::invalid_argument(
114 op_def.type() +
" does not have input fillers");
116 auto filler = schema->InputFillers(input_dims)[input_index];
120 using filler_type_pair_t = std::pair<TensorFiller, std::string>;
121 std::unordered_map<std::string, filler_type_pair_t> parameters_;
122 std::unordered_map<std::string, filler_type_pair_t> inputs_;
134 const std::vector<std::vector<std::vector<int64_t>>>& inputDims,
135 const std::vector<std::vector<std::string>>& inputTypes);
138 void fillInputToWorkspace(
Workspace* workspace)
const;
Blob * CreateBlob(const string &name)
Creates a blob of the given name.
A class to record the schema of an op.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...