Caffe2 - C++ API
A deep learning, cross-platform ML framework
bound_shape_inferencer.cc
1 #include "bound_shape_inferencer.h"
2 #include "caffe2/core/operator_schema.h"
3 #include "caffe2/core/tensor_impl.h"
4 #include "caffe2/utils/proto_utils.h"
5 #include "caffe2/utils/string_utils.h"
6 
7 namespace caffe2 {
8 
9 namespace {
10 std::vector<int64_t> ConvertToVec(
11  const ::google::protobuf::RepeatedField<::google::protobuf::int64>& in) {
12  std::vector<int64_t> out;
13  out.reserve(in.size());
14  for (const auto d : in) {
15  out.push_back(d);
16  }
17  return out;
18 }
19 
20 int64_t SizeFromDim(const TensorShape& shape, int axis) {
21  int64_t r = 1;
22  for (int i = axis; i < shape.dims_size(); ++i) {
23  r *= shape.dims(i);
24  }
25  return r;
26 }
27 
28 int64_t SizeToDim(const TensorShape& shape, int axis) {
29  CAFFE_ENFORCE_LE(axis, shape.dims_size());
30  int64_t r = 1;
31  for (int i = 0; i < axis; ++i) {
32  r *= shape.dims(i);
33  }
34  return r;
35 }
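// Example (illustrative): for a shape of [4, 5, 6],
// SizeToDim(shape, 1) == 4 (product of dims before axis 1) and
// SizeFromDim(shape, 1) == 5 * 6 == 30 (product of dims from axis 1 on).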

void EnsureShapeNames(std::unordered_map<std::string, ShapeInfo>* info) {
  for (auto& kv : *info) {
    kv.second.shape.set_name(kv.first);
  }
}
} // namespace

void BoundShapeInferencer::InferBoundShapeAndType(
    const NetDef& net,
    const std::unordered_map<std::string, ShapeInfo>& info) {
  shape_info_ = info;

  for (const auto& op : net.op()) {
    VLOG(1) << op.type();
    if (op.type() == "SparseLengthsSum" ||
        op.type() == "SparseLengthsSumFused8BitRowwise" ||
        op.type() == "SparseLengthsWeightedSum" ||
        op.type() == "SparseLengthsWeightedSumFused8BitRowwise") {
      InferSparseLengthsSum(op);
    } else if (op.type() == "FC" || op.type() == "FCTransposed") {
      InferFC(op);
    } else if (op.type() == "Concat") {
      InferConcat(op);
    } else if (op.type() == "Reshape") {
      InferReshape(op);
    } else if (op.type() == "LengthsRangeFill") {
      InferLengthsRangeFill(op);
    } else if (
        (caffe2::StartsWith(op.type(), "GivenTensor") &&
         caffe2::EndsWith(op.type(), "Fill")) ||
        op.type() == "ConstantFill" || op.type() == "Int8GivenTensorFill" ||
        op.type() == "Int8GivenIntTensorFill") {
      InferGivenTensorFill(op);
    } else if (op.type() == "Shape") {
      InferShape(op);
    } else {
      InferCommonOp(op);
    }
  }

  // Make sure every shape carries its blob name
  EnsureShapeNames(&shape_info_);
}

TensorShape& BoundShapeInferencer::CheckAndSetTensorShapeAndType(
    const std::string& name,
    ShapeInfo::DimType t,
    std::vector<int64_t> bound_dims,
    TensorProto::DataType type) {
  auto rt = shape_info_.emplace(name, ShapeInfo());
  ShapeInfo& shape_info = rt.first->second;
  TensorShape& shape = shape_info.shape;
  if (!rt.second) {
    // Check shape consistency
    CAFFE_ENFORCE_EQ(shape.dims_size(), bound_dims.size());
    // For shapes that were provided as a hint at the input of the net, fix
    // the batch size first.
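    // Example (illustrative): a net input hinted as [1, 100] with dim type
    // UNKNOWN that is later bound with dim type BATCH and bound_dims
    // {max_batch_size, 100} gets only dim 0 rewritten to max_batch_size;
    // every remaining dim must still match exactly.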
    if (shape.dims_size() > 0 &&
        shape_info.dim_type == ShapeInfo::DimType::UNKNOWN &&
        t > ShapeInfo::DimType::CONSTANT) {
      shape_info.dim_type = t;
      shape.set_dims(0, bound_dims.front());
    }
    for (int i = 0; i < shape.dims_size(); ++i) {
      CAFFE_ENFORCE_EQ(
          shape.dims(i),
          bound_dims[i],
          "Shape inconsistency found in tensor ",
          name,
          " on dim ",
          i,
          " (",
          shape.dims(i),
          " vs ",
          bound_dims[i],
          ")");
    }
    return shape;
  }

  shape_info.dim_type = t;
  shape.mutable_dims()->Clear();
  for (const auto d : bound_dims) {
    shape.add_dims(d);
  }
  shape.set_data_type(type);
  return shape;
}

std::vector<TensorShape> InferOutput(
    const OperatorDef& op,
    const std::vector<TensorShape>& input_shapes) {
  const OpSchema* schema = OpSchemaRegistry::Schema(op.type());
  CAFFE_ENFORCE(schema);
  return schema->InferTensor(op, input_shapes);
}

void BoundShapeInferencer::InferGivenTensorFill(const OperatorDef& op) {
  CAFFE_ENFORCE_EQ(op.output_size(), 1, op.type(), " must have 1 output");
  InferCommonOp(op);
  auto it = shape_info_.find(op.output(0));
  if (it != shape_info_.end()) {
    it->second.dim_type = ShapeInfo::DimType::CONSTANT;
  }
}

void BoundShapeInferencer::InferLengthsRangeFill(const OperatorDef& op) {
  CAFFE_ENFORCE_EQ(op.input_size(), 1, "LengthsRangeFill must have 1 input");
  CAFFE_ENFORCE_EQ(op.output_size(), 1, "LengthsRangeFill must have 1 output");
  // Both the input and output of LengthsRangeFill are int32:
  // https://fburl.com/fhwb5666
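  // For example (illustrative), a lengths input of [2, 3] expands to
  // [0, 1, 0, 1, 2], so the output length is bounded by max_seq_size rather
  // than max_batch_size.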
  CheckAndSetTensorShapeAndType(
      op.input(0),
      ShapeInfo::DimType::BATCH,
      {spec_.max_batch_size},
      TensorProto_DataType_INT32);
  CheckAndSetTensorShapeAndType(
      op.output(0),
      ShapeInfo::DimType::SEQ,
      {spec_.max_seq_size},
      TensorProto_DataType_INT32);
  current_dim_type_ = ShapeInfo::DimType::SEQ;
}

void BoundShapeInferencer::InferSparseLengthsSum(const OperatorDef& op) {
  CAFFE_ENFORCE_GE(
      op.input_size(), 3, op.type(), " must have at least 3 inputs");
  const auto it = shape_info_.find(op.input(0));
  CAFFE_ENFORCE(
      it != shape_info_.end(),
      "Shape of DATA input of SparseLengthsSum ",
      op.input(0),
      " needs to be present");
  CAFFE_ENFORCE_EQ(
      it->second.shape.dims().size(),
      2,
      "DATA input ",
      op.input(0),
      " needs to be 2D");

  int weight = (op.type() == "SparseLengthsWeightedSum" ||
                op.type() == "SparseLengthsWeightedSumFused8BitRowwise")
      ? 1
      : 0;

  if (weight) {
    CAFFE_ENFORCE_EQ(
        op.input_size(), 4, "SparseLengthsWeightedSum must have 4 inputs");
    CheckAndSetTensorShapeAndType(
        op.input(weight),
        ShapeInfo::DimType::SEQ,
        {spec_.max_seq_size},
        TensorProto_DataType_FLOAT);
  }

  // Bound inputs
  CheckAndSetTensorShapeAndType(
      op.input(1 + weight),
      ShapeInfo::DimType::SEQ,
      {spec_.max_seq_size},
      TensorProto_DataType_INT64);
  CheckAndSetTensorShapeAndType(
      op.input(2 + weight),
      ShapeInfo::DimType::BATCH,
      {spec_.max_batch_size},
      TensorProto_DataType_INT32);

  // Infer output
  CAFFE_ENFORCE_EQ(it->second.shape.dims_size(), 2);
  current_dim_type_ = ShapeInfo::DimType::BATCH;
  current_max_batch_size_ = spec_.max_batch_size;
  auto output_dim1 = it->second.shape.dims(1);
  // If the op is SparseLengthsSumFused8BitRowwise, we need to subtract
  // 8 bytes per row: 4 for the scale and 4 for the bias
  // (https://fburl.com/t6dp9tsc).
  if (op.type() == "SparseLengthsSumFused8BitRowwise" ||
      op.type() == "SparseLengthsWeightedSumFused8BitRowwise") {
    output_dim1 -= 8;
  }
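  // For instance (illustrative), a fused 8-bit rowwise table with 72 columns
  // per row (64 quantized values plus a 4-byte scale and a 4-byte bias)
  // yields 64-wide float output rows.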
  CheckAndSetTensorShapeAndType(
      op.output(0),
      ShapeInfo::DimType::BATCH,
      {spec_.max_batch_size, output_dim1},
      TensorProto_DataType_FLOAT);
}

void BoundShapeInferencer::InferShape(const OperatorDef& op) {
  InferCommonOp(op);
  // The output of Shape is a constant
  if (op.output_size() > 0 && shape_info_.count(op.output(0))) {
    shape_info_[op.output(0)].dim_type = ShapeInfo::DimType::CONSTANT;
  }
}

void BoundShapeInferencer::InferReshape(const OperatorDef& op) {
  InferCommonOp(op);
  // old_shape should be a constant
  if (op.output_size() > 1 && shape_info_.count(op.output(1))) {
    shape_info_[op.output(1)].dim_type = ShapeInfo::DimType::CONSTANT;
  }
}
// For the Concat op, if shape info is missing for some inputs and the
// add_axis argument is set, all inputs must have the same shape. In that
// case we can infer the shapes of the missing inputs from an input whose
// shape is known.
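// Example (illustrative): concatenating three inputs of shape
// [max_batch_size, 64] with axis = 1 and add_axis = 1 yields a
// [max_batch_size, 3, 64] output; if only one input has a known shape, the
// other two are assumed to share it.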
void BoundShapeInferencer::InferConcat(const OperatorDef& op) {
  ArgumentHelper helper(op);
  auto add_axis = helper.GetSingleArgument<int32_t>("add_axis", 0);
  if (add_axis) {
    ShapeInfo* ref_input_shape = nullptr;
    std::string ref_name;
    std::unordered_set<std::string> missing_shape_inputs;
    for (const auto& i : op.input()) {
      const auto it = shape_info_.find(i);
      if (it != shape_info_.end()) {
        const auto& current_input_shape = it->second;
        if (ref_input_shape) {
          CAFFE_ENFORCE_EQ(
              ref_input_shape->shape.dims_size(),
              current_input_shape.shape.dims_size(),
              ref_name,
              " vs ",
              i);
          for (int j = 0; j < ref_input_shape->shape.dims_size(); ++j) {
            CAFFE_ENFORCE_EQ(
                ref_input_shape->shape.dims(j),
                current_input_shape.shape.dims(j),
                "Mismatched size on dim ",
                j,
                " between ",
                ref_name,
                " and ",
                i,
                " (",
                ref_input_shape->shape.dims(j),
                " vs ",
                current_input_shape.shape.dims(j),
                ")");
          }
        } else {
          ref_input_shape = &it->second;
          ref_name = i;
        }
      } else {
        missing_shape_inputs.emplace(i);
      }
    }

    if (ref_input_shape) {
      current_dim_type_ = ref_input_shape->dim_type;
      for (const auto& i : missing_shape_inputs) {
        shape_info_.emplace(i, *ref_input_shape);
      }
    }
  }
  InferCommonOp(op);
  // split_info should be a constant
  if (op.output_size() > 1 && shape_info_.count(op.output(1))) {
    shape_info_[op.output(1)].dim_type = ShapeInfo::DimType::CONSTANT;
  }
}

void BoundShapeInferencer::InferFC(const OperatorDef& op) {
  CAFFE_ENFORCE_EQ(op.input_size(), 3, "FC has to have 3 inputs");
  const auto w_it = shape_info_.find(op.input(1));
  CAFFE_ENFORCE(
      w_it != shape_info_.end(),
      "Shape of WEIGHT input of FC ",
      op.input(1),
      " needs to be present");
  const ShapeInfo& w_shape_info = w_it->second;
  const auto b_it = shape_info_.find(op.input(2));
  CAFFE_ENFORCE(
      b_it != shape_info_.end(),
      "Shape of BIAS input of FC ",
      op.input(2),
      " needs to be present");
  const ShapeInfo& b_shape_info = b_it->second;
  auto x_it = shape_info_.find(op.input(0));
  if (x_it == shape_info_.end()) {
    // We don't have a hint for the X input, so we try to deduce its shape
    // from the weight shape.
    ArgumentHelper helper(op);
    auto axis = helper.GetSingleArgument<int32_t>("axis", 1);
    auto axis_w = helper.GetSingleArgument<int32_t>("axis_w", 1);
    CAFFE_ENFORCE_EQ(
        axis,
        1,
        "Don't know how to deduce input of FC with axis not equal to 1: ",
        op.input(0));
    CAFFE_ENFORCE_EQ(
        axis_w,
        1,
        "Don't know how to deduce input of FC with axis_w not equal to 1: ",
        op.input(0));
    const TensorShape w_shape = w_shape_info.shape;
    CAFFE_ENFORCE_EQ(
        w_shape.dims_size(),
        2,
        "Don't know how to deduce input of FC other than of dim size 2: ",
        op.input(0));
    const bool transposed = (op.type() != "FC");
    const int canonical_axis_w =
        canonical_axis_index_(axis_w, w_shape.dims().size());
    const int64_t K = transposed ? SizeToDim(w_shape, canonical_axis_w)
                                 : SizeFromDim(w_shape, canonical_axis_w);
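    // For example (illustrative), an FC weight of shape [N, K] = [256, 1024]
    // with axis_w = 1 gives K = 1024, so X is bound to
    // [max_batch_size, 1024]; for FCTransposed the weight is stored as
    // [K, N] and SizeToDim yields the same K.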
    current_dim_type_ = ShapeInfo::DimType::BATCH;
    current_max_batch_size_ = spec_.max_batch_size;
    CheckAndSetTensorShapeAndType(
        op.input(0),
        ShapeInfo::DimType::BATCH,
        {spec_.max_batch_size, K},
        w_shape.data_type());
  } else {
    ShapeInfo& x_shape_info = x_it->second;
    if (x_shape_info.dim_type != ShapeInfo::DimType::BATCH) {
      CAFFE_ENFORCE_GE(x_shape_info.shape.dims_size(), 1);
      x_shape_info.shape.set_dims(0, spec_.max_batch_size);
      x_shape_info.dim_type = ShapeInfo::DimType::BATCH;
    }
  }

  // Standard shape inference for outputs
  std::vector<TensorShape> input_shapes{
      shape_info_[op.input(0)].shape, w_shape_info.shape, b_shape_info.shape};
  std::vector<TensorShape> output_shapes = InferOutput(op, input_shapes);
  CAFFE_ENFORCE_EQ(output_shapes.size(), 1);
  CheckAndSetTensorShapeAndType(
      op.output(0),
      ShapeInfo::DimType::BATCH,
      ConvertToVec(output_shapes[0].dims()),
      output_shapes[0].data_type());
}

void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) {
  // First, check that all the input shapes/types are already present
  try {
    std::vector<TensorShape> input_shapes;
    for (const auto& input : op.input()) {
      const auto it = shape_info_.find(input);
      if (it == shape_info_.end()) {
        LOG(WARNING) << "Cannot find shape info for " << input << ". Skipping "
                     << op.type();
        return;
      }
      input_shapes.emplace_back(it->second.shape);
    }

    const OpSchema* schema = OpSchemaRegistry::Schema(op.type());
    CAFFE_ENFORCE(schema);
    std::vector<TensorShape> output_shapes =
        schema->InferTensor(op, input_shapes);
    int i = 0;
    for (const auto& shape : output_shapes) {
      if (shape.unknown_shape()) {
        ++i;
        continue;
      }
      CheckAndSetTensorShapeAndType(
          op.output(i++),
          current_dim_type_,
          ConvertToVec(shape.dims()),
          shape.data_type());
    }
  } catch (const caffe2::EnforceNotMet& e) {
    LOG(ERROR) << "Enforce not met while inferring shapes for " << op.type()
               << ": " << e.msg();
  } catch (const std::exception& e) {
    LOG(WARNING) << "Caught exception while inferring shapes for " << op.type()
                 << ": " << e.what();
  }
}

} // namespace caffe2
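
A minimal usage sketch follows. It assumes, based on the accompanying header, that BoundShapeInferencer is constructed from a BoundShapeSpec holding max_batch_size and max_seq_size and exposes the inferred map via a shape_info() accessor; the include paths, bound values, and function name below are illustrative rather than part of this file.

#include <string>
#include <unordered_map>

#include "caffe2/core/logging.h"
#include "caffe2/opt/bound_shape_inferencer.h" // include path is an assumption
#include "caffe2/proto/caffe2_pb.h"

// Runs bound shape inference over a net and logs the inferred shapes.
void PrintBoundShapes(const caffe2::NetDef& net) {
  // Bound the batch dimension to 32 and sequence lengths to 128
  // (illustrative values).
  caffe2::BoundShapeSpec spec(32, 128);
  caffe2::BoundShapeInferencer inferencer(spec);

  // Optional shape hints for external inputs; left empty here.
  std::unordered_map<std::string, caffe2::ShapeInfo> hints;
  inferencer.InferBoundShapeAndType(net, hints);

  // Walk the inferred map and print each blob name with its bound dims.
  for (const auto& kv : inferencer.shape_info()) {
    std::string dims;
    for (const auto d : kv.second.shape.dims()) {
      dims += std::to_string(d) + " ";
    }
    LOG(INFO) << kv.first << ": [ " << dims << "]";
  }
}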