Caffe2 - C++ API
A deep learning, cross-platform ML framework
bound_shape_inferencer.h
1 #pragma once
2 
3 #include "caffe2/core/logging.h"
4 #include "caffe2/opt/shape_info.h"
5 #include "caffe2/proto/caffe2_pb.h"
6 
7 #include <sstream>
8 #include <string>
9 #include <unordered_map>
10 #include <unordered_set>
11 
12 namespace caffe2 {
13 // This struct stores the max bound size for batch in the general sense. We have
14 // the conventioal batch size and the look-up sequence, which is also batch in a
15 // sense.
16 struct CAFFE2_API BoundShapeSpec {
17  explicit BoundShapeSpec(int64_t b, int64_t q)
18  : max_batch_size(b), max_seq_size(q) {}
19  int64_t max_batch_size;
20  int64_t max_seq_size;
21 };
22 
30 class CAFFE2_API BoundShapeInferencer {
31  public:
32  explicit BoundShapeInferencer(const BoundShapeSpec& spec) : spec_(spec) {
33  CAFFE_ENFORCE_GE(spec_.max_batch_size, 0);
34  CAFFE_ENFORCE_GE(spec_.max_seq_size, 0);
35  }
36 
37  void InferBoundShapeAndType(
38  const NetDef& net,
39  const std::unordered_map<std::string, ShapeInfo>& info);
40 
41  const ShapeInfoMap& shape_info() const {
42  return shape_info_;
43  }
44 
46  std::string PrintShapeInfo() const {
47  std::stringstream ss;
48  for (const auto& kv : shape_info_) {
49  const auto& s = kv.second;
50  ss << s.shape.name() << ": dim_type: " << s.dim_type << ", dims: [";
51  for (const auto d : s.shape.dims()) {
52  ss << d << ", ";
53  }
54  ss << "], dtype: " << s.shape.data_type() << "\n";
55  }
56  return ss.str();
57  }
58 
59  private:
60  TensorShape& CheckAndSetTensorShapeAndType(
61  const std::string& name,
62  ShapeInfo::DimType t,
63  std::vector<int64_t> bound_dims,
64  TensorProto::DataType type);
65 
66  void InferGivenTensorFill(const OperatorDef& op);
67  void InferSparseLengthsSum(const OperatorDef& op);
68  void InferFC(const OperatorDef& op);
69  void InferConcat(const OperatorDef& op);
70  void InferShape(const OperatorDef& op);
71  void InferReshape(const OperatorDef& op);
72  void InferLengthsRangeFill(const OperatorDef& op);
73 
74  // Standard shape/type inference using op schema registered shape inference
75  // function
76  void InferCommonOp(const OperatorDef& op);
77 
78  const BoundShapeSpec spec_;
79  ShapeInfo::DimType current_dim_type_{ShapeInfo::DimType::BATCH};
80  int64_t current_max_batch_size_{0};
81  std::unordered_map<std::string, ShapeInfo> shape_info_;
82 };
83 
84 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
std::string PrintShapeInfo() const
Print out all the shape info.