Caffe2 - C++ API
A deep learning, cross platform ML framework
backend_transformer_base.h
1 #pragma once
2 
3 #include "caffe2/core/common.h"
4 #include "caffe2/core/workspace.h"
5 #include "caffe2/opt/bound_shape_inferencer.h"
6 #include "caffe2/proto/caffe2_pb.h"
7 
8 #include <string>
9 #include <unordered_map>
#include <unordered_set>
10 #include <vector>
11 
12 namespace caffe2 {
13 namespace {
// String constants "net_pos" and "model_id". Their consumers are not visible
// in this header; presumably they are argument names looked up on operators
// or on the NetDef (cf. getModelId below) -- TODO confirm against the
// implementation file.
// NOTE(review): this is an unnamed namespace inside a header, so every
// translation unit that includes the header gets its own copy of these
// constants.
14 constexpr char kNetPos[] = "net_pos";
15 constexpr char kModelId[] = "model_id";
16 } // namespace
17 
18 // This class contains some common functions for backend lowering and graph
19 // cutting
21  public:
// Virtual destructor: the class declares virtual member functions, so objects
// of derived transformers may be destroyed through a base pointer. The body
// was empty, so it is defaulted rather than user-provided.
23  virtual ~BackendTransformerBase() = default;
24 
// Read-only accessor: hands out the input_mapping_ member by const reference,
// so no copy of the map is made.
25  const std::unordered_map<std::string, std::string>& input_mapping() const {
26  return this->input_mapping_;
27  }
28 
// Read-only accessor: hands out the reverse_input_mapping_ member by const
// reference, so no copy of the map is made.
29  const std::unordered_map<std::string, std::string>& reverse_input_mapping()
30  const {
31  return this->reverse_input_mapping_;
32  }
33 
// Pure virtual entry point of the transformer: this class provides no
// implementation, so every concrete subclass has to override it.
//
// The parameter notes below are read off the names and types only; the
// contract itself is defined by the overriding classes -- confirm there.
//   ws              - workspace, passed by mutable pointer.
//   pred_net        - net to be transformed, passed by mutable pointer
//                     (presumably rewritten in place).
//   weight_names    - names of blobs to be treated as weights (assumed).
//   shape_hints     - name -> TensorShape hints supplied by the caller.
//   blacklisted_ops - set of ints identifying ops to leave untouched;
//                     presumably net positions (cf. kNetPos) -- TODO confirm.
//
// std::unordered_set is declared in <unordered_set>, which this header must
// include directly rather than rely on a transitive include.
34  virtual void transform(
35  Workspace* ws,
36  NetDef* pred_net,
37  const std::vector<std::string>& weight_names,
38  const std::unordered_map<std::string, TensorShape>& shape_hints,
39  const std::unordered_set<int>& blacklisted_ops) = 0;
40 
41  protected:
42  // Get the model ID from the NetDef.
// Declaration only; the lookup logic is defined outside this header.
// Presumably it reads the kModelId ("model_id") entry of `net` -- TODO
// confirm against the definition. Returns the ID as a std::string.
43  std::string getModelId(const NetDef& net);
44 
45  // SSA-rewrite the net and return the name mapping.
// Declaration only. Takes the workspace and the prediction net by mutable
// pointer, plus the caller's input shape hints (name -> TensorShape).
// NOTE(review): the return type is also a name -> TensorShape map, so the
// result is presumably the shape hints re-keyed by the rewritten names rather
// than a name -> name table, and the name tables themselves presumably end up
// in input_mapping_ / reverse_input_mapping_ -- confirm against the
// definition.
46  std::unordered_map<std::string, TensorShape> ssaRewriteAndMapNames(
47  Workspace* ws,
48  NetDef* pred_net,
49  const std::unordered_map<std::string, TensorShape>& input_shape_hints);
50 
51  // Wrap shape information into a TensorProto.
// Declaration only. Builds the returned TensorProto from `name` and
// `shape_info`; which proto fields are populated is decided by the
// definition, not visible here.
// NOTE(review): the parameter type is ShapeInfo, not a bare TensorShape.
// The member function is const-qualified.
52  TensorProto wrapShapeInfoIntoTensorProto(
53  const std::string& name,
54  const ShapeInfo& shape_info) const;
55 
56  // Run bound shape inference and collect the resulting shape infos.
// Declaration only. Receives the workspace and net by mutable pointer, the
// shape hints (the `_mapped` suffix suggests they are already keyed by the
// rewritten names -- TODO confirm) and a BoundShapeSpec that parameterizes the
// inference. Returns the collected shapes as a ShapeInfoMap.
57  ShapeInfoMap inferShapes(
58  Workspace* ws,
59  NetDef* pred_net,
60  const std::unordered_map<std::string, TensorShape>& shape_hints_mapped,
61  const BoundShapeSpec& spec);
62 
63  // Input mapping of input name -> original input name.
// Exposed read-only through input_mapping().
64  std::unordered_map<std::string, std::string> input_mapping_;
65 
66  // Input mapping of original input name -> input name.
// Exposed read-only through reverse_input_mapping().
67  std::unordered_map<std::string, std::string> reverse_input_mapping_;
68 };
69 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13