// ConvertToVec (its signature starts at the end of the fused include line
// below): builds a std::vector<int64_t> from a protobuf repeated int64 field
// such as TensorShape::dims(), reserving in.size() slots up front.
// NOTE(review): the loop body and the return statement are not visible in
// this extract; presumably each element is appended to `out`, which is then
// returned.
1 #include "bound_shape_inferencer.h" 2 #include "caffe2/core/operator_schema.h" 3 #include "caffe2/core/tensor_impl.h" 4 #include "caffe2/utils/proto_utils.h" 5 #include "caffe2/utils/string_utils.h" 10 std::vector<int64_t> ConvertToVec(
11 const ::google::protobuf::RepeatedField<::google::protobuf::int64>& in) {
12 std::vector<int64_t> out;
13 out.reserve(in.size());
14 for (
const auto d : in) {
// SizeFromDim: iterates over the dims of `shape` from index `axis` up to
// shape.dims_size().
// NOTE(review): the accumulator and return are not visible in this extract;
// by name (and its use for K in InferFC) it presumably returns the product of
// dims [axis, dims_size()). Unlike SizeToDim, no bound check on `axis` is
// visible here.
20 int64_t SizeFromDim(
const TensorShape& shape,
int axis) {
22 for (
int i = axis; i < shape.dims_size(); ++i) {
// SizeToDim: enforces axis <= shape.dims_size(), then iterates over dims
// [0, axis).
// NOTE(review): the accumulator and return are not visible in this extract;
// by name (and its use for K in InferFC when W is transposed) it presumably
// returns the product of dims [0, axis).
28 int64_t SizeToDim(
const TensorShape& shape,
int axis) {
29 CAFFE_ENFORCE_LE(axis, shape.dims_size());
31 for (
int i = 0; i < axis; ++i) {
// EnsureShapeNames: for every entry of *info, sets the stored TensorShape's
// name field to the entry's map key (the blob name).
37 void EnsureShapeNames(std::unordered_map<std::string, ShapeInfo>* info) {
38 for (
auto& kv : *info) {
39 kv.second.shape.set_name(kv.first);
// BoundShapeInferencer::InferBoundShapeAndType: walks net.op() in order and
// dispatches each operator by type:
//   - SparseLengthsSum, SparseLengthsSumFused8BitRowwise,
//     SparseLengthsWeightedSum, SparseLengthsWeightedSumFused8BitRowwise
//       -> InferSparseLengthsSum
//   - LengthsRangeFill -> InferLengthsRangeFill
//   - GivenTensor*Fill, ConstantFill, Int8GivenTensorFill,
//     Int8GivenIntTensorFill -> InferGivenTensorFill
//   - FC / FCTransposed, Concat, Reshape, Shape -> branches whose call lines
//     are not visible here (presumably InferFC, InferConcat, InferReshape,
//     InferShape)
// After the loop, EnsureShapeNames(&shape_info_) stamps every recorded shape
// with its blob name.
// NOTE(review): the `net` parameter declaration, how `info` seeds
// shape_info_, and the branch taken for any other op type are on lines not
// visible in this extract.
44 void BoundShapeInferencer::InferBoundShapeAndType(
46 const std::unordered_map<std::string, ShapeInfo>& info) {
49 for (
const auto& op : net.op()) {
51 if (op.type() ==
"SparseLengthsSum" ||
52 op.type() ==
"SparseLengthsSumFused8BitRowwise" ||
53 op.type() ==
"SparseLengthsWeightedSum" ||
54 op.type() ==
"SparseLengthsWeightedSumFused8BitRowwise") {
55 InferSparseLengthsSum(op);
56 }
else if (op.type() ==
"FC" || op.type() ==
"FCTransposed") {
58 }
else if (op.type() ==
"Concat") {
60 }
else if (op.type() ==
"Reshape") {
62 }
else if (op.type() ==
"LengthsRangeFill") {
63 InferLengthsRangeFill(op);
65 (caffe2::StartsWith(op.type(),
"GivenTensor") &&
66 caffe2::EndsWith(op.type(),
"Fill")) ||
67 op.type() ==
"ConstantFill" || op.type() ==
"Int8GivenTensorFill" ||
68 op.type() ==
"Int8GivenIntTensorFill") {
69 InferGivenTensorFill(op);
70 }
else if (op.type() ==
"Shape") {
78 EnsureShapeNames(&shape_info_);
// BoundShapeInferencer::CheckAndSetTensorShapeAndType: records the bound dims
// and data type for blob `name` in shape_info_. It returns a TensorShape&; the
// return statement is not visible in this extract.
//   - An entry that (presumably) already existed must have the same rank as
//     bound_dims. If its dim type is still UNKNOWN and `t` orders after
//     CONSTANT, it adopts `t` and takes bound_dims.front() as its leading dim.
//     Its dims are then checked ("Shape inconsistency found in tensor ...").
//   - Otherwise the entry takes dim type `t` and its dims are replaced by
//     bound_dims.
//   - The data type is then set to `type`.
// NOTE(review): `t` is not declared on any visible line; presumably it is a
// ShapeInfo::DimType parameter omitted between `name` and `bound_dims`. The
// condition selecting between the two paths is also not visible.
81 TensorShape& BoundShapeInferencer::CheckAndSetTensorShapeAndType(
82 const std::string& name,
84 std::vector<int64_t> bound_dims,
85 TensorProto::DataType type) {
// emplace inserts a default ShapeInfo only when `name` is absent; either way
// rt.first refers to the entry for `name`.
86 auto rt = shape_info_.emplace(name, ShapeInfo());
87 ShapeInfo& shape_info = rt.first->second;
88 TensorShape& shape = shape_info.shape;
// NOTE(review): dims_size() is a protobuf int accessor while
// bound_dims.size() is a size_t, so this is a mixed-sign comparison.
91 CAFFE_ENFORCE_EQ(shape.dims_size(), bound_dims.size());
// Upgrade an entry whose dim type is still UNKNOWN: adopt `t` and write the
// bound into the leading dim. `t > CONSTANT` relies on the declaration order
// of ShapeInfo::DimType.
94 if (shape.dims_size() > 0 &&
95 shape_info.dim_type == ShapeInfo::DimType::UNKNOWN &&
96 t > ShapeInfo::DimType::CONSTANT) {
97 shape_info.dim_type = t;
98 shape.set_dims(0, bound_dims.front());
100 for (
int i = 0; i < shape.dims_size(); ++i) {
104 "Shape inconsistency found in tensor ",
// Presumably the freshly inserted path: take dim type `t` and replace the
// dims with bound_dims.
117 shape_info.dim_type = t;
118 shape.mutable_dims()->Clear();
119 for (
const auto d : bound_dims) {
122 shape.set_data_type(type);
// InferOutput: looks up the registered OpSchema for op.type(), enforcing that
// one exists. Returns the output shapes that the schema's InferTensor derives
// from `input_shapes`.
126 std::vector<TensorShape> InferOutput(
127 const OperatorDef& op,
128 const std::vector<TensorShape>& input_shapes) {
129 const OpSchema* schema = OpSchemaRegistry::Schema(op.type());
130 CAFFE_ENFORCE(schema);
131 return schema->InferTensor(op, input_shapes);
// BoundShapeInferencer::InferGivenTensorFill: handles the fill ops that
// InferBoundShapeAndType routes here (GivenTensor*Fill, ConstantFill,
// Int8GivenTensorFill, Int8GivenIntTensorFill). Enforces a single output. If
// that output already has an entry in shape_info_, its dim type is marked
// CONSTANT. The visible lines create no entry when the output is absent.
134 void BoundShapeInferencer::InferGivenTensorFill(
const OperatorDef& op) {
135 CAFFE_ENFORCE_EQ(op.output_size(), 1, op.type(),
" must have 1 output");
137 auto it = shape_info_.find(op.output(0));
138 if (it != shape_info_.end()) {
139 it->second.dim_type = ShapeInfo::DimType::CONSTANT;
// BoundShapeInferencer::InferLengthsRangeFill: requires exactly one input and
// one output. It bounds the first tensor below as a 1-D INT32 of
// spec_.max_batch_size (dim type BATCH) and the second as a 1-D INT32 of
// spec_.max_seq_size (dim type SEQ), then sets current_dim_type_ to SEQ.
// NOTE(review): the tensor names passed to CheckAndSetTensorShapeAndType are
// on lines not visible in this extract; presumably op.input(0) (the lengths)
// and op.output(0) (the filled ranges).
143 void BoundShapeInferencer::InferLengthsRangeFill(
const OperatorDef& op) {
144 CAFFE_ENFORCE_EQ(op.input_size(), 1,
"LengthsRangeFill must have 1 input");
145 CAFFE_ENFORCE_EQ(op.output_size(), 1,
"LengthsRangeFill must have 1 output");
148 CheckAndSetTensorShapeAndType(
150 ShapeInfo::DimType::BATCH,
151 {spec_.max_batch_size},
152 TensorProto_DataType_INT32);
153 CheckAndSetTensorShapeAndType(
155 ShapeInfo::DimType::SEQ,
156 {spec_.max_seq_size},
157 TensorProto_DataType_INT32);
158 current_dim_type_ = ShapeInfo::DimType::SEQ;
// BoundShapeInferencer::InferSparseLengthsSum: bound-shape inference for the
// SparseLengths{,Weighted}Sum{,Fused8BitRowwise} family.
// The DATA input (op.input(0)) must already have an entry in shape_info_ and
// must be 2-D. The other inputs are bounded by position:
//   - a FLOAT 1-D of spec_.max_seq_size, dim type SEQ (presumably the weights,
//     only for the weighted variants, which must have 4 inputs);
//   - op.input(1 + weight): INT64 1-D of spec_.max_seq_size, SEQ
//     (presumably the indices);
//   - op.input(2 + weight): INT32 1-D of spec_.max_batch_size, BATCH
//     (presumably the lengths).
// The output is bounded as FLOAT {spec_.max_batch_size, output_dim1} with dim
// type BATCH, where output_dim1 starts from DATA's second dim.
// NOTE(review): several lines are not visible in this extract: the enforce
// openers, the rest of the `weight` expression (presumably 1 for weighted
// variants, else 0), the branch guarding the weights call, the weights/output
// tensor names, and whatever adjusts output_dim1 for the Fused8BitRowwise
// variants.
161 void BoundShapeInferencer::InferSparseLengthsSum(
const OperatorDef& op) {
163 op.input_size(), 3, op.type(),
" must have at least 3 inputs");
164 const auto it = shape_info_.find(op.input(0));
166 it != shape_info_.end(),
167 "Shape of DATA input of SparseLengthsSum ",
169 " needs to be presented");
// Hold a reference to DATA's ShapeInfo instead of reusing `it` later.
// CheckAndSetTensorShapeAndType emplaces into shape_info_ (an unordered_map),
// and a rehash triggered by an insert invalidates iterators but not
// references to elements.
const ShapeInfo& data_shape_info = it->second;
171 it->second.shape.dims().size(),
177 int weight = (op.type() ==
"SparseLengthsWeightedSum" ||
178 op.type() ==
"SparseLengthsWeightedSumFused8BitRowwise")
184 op.input_size(), 4,
"SparseLengthsWeightedSum must have 4 inputs");
185 CheckAndSetTensorShapeAndType(
187 ShapeInfo::DimType::SEQ,
188 {spec_.max_seq_size},
189 TensorProto_DataType_FLOAT);
193 CheckAndSetTensorShapeAndType(
194 op.input(1 + weight),
195 ShapeInfo::DimType::SEQ,
196 {spec_.max_seq_size},
197 TensorProto_DataType_INT64);
198 CheckAndSetTensorShapeAndType(
199 op.input(2 + weight),
200 ShapeInfo::DimType::BATCH,
201 {spec_.max_batch_size},
202 TensorProto_DataType_INT32);
205 CAFFE_ENFORCE_EQ(data_shape_info.shape.dims_size(), 2);
206 current_dim_type_ = ShapeInfo::DimType::BATCH;
207 current_max_batch_size_ = spec_.max_batch_size;
208 auto output_dim1 = data_shape_info.shape.dims(1);
211 if (op.type() ==
"SparseLengthsSumFused8BitRowwise" ||
212 op.type() ==
"SparseLengthsWeightedSumFused8BitRowwise") {
215 CheckAndSetTensorShapeAndType(
217 ShapeInfo::DimType::BATCH,
218 {spec_.max_batch_size, output_dim1},
219 TensorProto_DataType_FLOAT);
// BoundShapeInferencer::InferShape: for a Shape op, if output(0) already has
// an entry in shape_info_, marks its dim type CONSTANT.
// NOTE(review): the lines before this check are not visible in this extract
// (presumably generic inference of the output first). count() followed by
// operator[] does two hash lookups; a single find() would do.
222 void BoundShapeInferencer::InferShape(
const OperatorDef& op) {
225 if (op.output_size() > 0 && shape_info_.count(op.output(0))) {
226 shape_info_[op.output(0)].dim_type = ShapeInfo::DimType::CONSTANT;
// BoundShapeInferencer::InferReshape: if the Reshape op has a second output
// that already has an entry in shape_info_, marks that entry's dim type
// CONSTANT. Presumably output(1) is the old-shape output — TODO confirm.
// NOTE(review): the lines before this check are not visible in this extract
// (presumably generic inference first). count() followed by operator[] does
// two hash lookups.
230 void BoundShapeInferencer::InferReshape(
const OperatorDef& op) {
233 if (op.output_size() > 1 && shape_info_.count(op.output(1))) {
234 shape_info_[op.output(1)].dim_type = ShapeInfo::DimType::CONSTANT;
// BoundShapeInferencer::InferConcat: reads the "add_axis" argument (default 0)
// and scans the inputs.
//   - An input with an entry in shape_info_ is checked against the reference
//     input's rank and compared dim by dim ("Mismatched size on dim ...").
//   - When no reference is set yet, a known input becomes the reference
//     (presumably via an `else` line not visible here).
//   - Inputs with no entry are collected in missing_shape_inputs.
// If a reference was found, current_dim_type_ takes its dim type, and each
// missing input receives a copy of the reference ShapeInfo. Finally, a second
// output that already has an entry is marked CONSTANT.
// NOTE(review): how add_axis and ref_name are used, the enforce openers, and
// the exact per-dim rule (e.g. whether the concat axis is exempt) are on lines
// not visible in this extract.
240 void BoundShapeInferencer::InferConcat(
const OperatorDef& op) {
241 ArgumentHelper helper(op);
242 auto add_axis = helper.GetSingleArgument<int32_t>(
"add_axis", 0);
244 ShapeInfo* ref_input_shape =
nullptr;
245 std::string ref_name;
246 std::unordered_set<std::string> missing_shape_inputs;
247 for (
const auto& i : op.input()) {
248 const auto it = shape_info_.find(i);
249 if (it != shape_info_.end()) {
250 const auto& current_input_shape = it->second;
251 if (ref_input_shape) {
253 ref_input_shape->shape.dims_size(),
254 current_input_shape.shape.dims_size(),
258 for (
int j = 0; j < ref_input_shape->shape.dims_size(); ++j) {
260 ref_input_shape->shape.dims(j),
261 current_input_shape.shape.dims(j),
262 "Mismatched size on dim ",
269 ref_input_shape->shape.dims(j),
271 current_input_shape.shape.dims(j),
// Keep a pointer, not the iterator. The emplace calls below may rehash
// shape_info_, which invalidates unordered_map iterators but not pointers to
// its elements.
275 ref_input_shape = &it->second;
279 missing_shape_inputs.emplace(i);
283 if (ref_input_shape) {
284 current_dim_type_ = ref_input_shape->dim_type;
// Every missing input gets the same shape, so the unordered_set's iteration
// order does not matter. emplace is a no-op for a name that already exists.
285 for (
const auto& i : missing_shape_inputs) {
286 shape_info_.emplace(i, *ref_input_shape);
292 if (op.output_size() > 1 && shape_info_.count(op.output(1))) {
293 shape_info_[op.output(1)].dim_type = ShapeInfo::DimType::CONSTANT;
// BoundShapeInferencer::InferFC: bound-shape inference for FC / FCTransposed.
// Requires 3 inputs (X, W, B); W and B must already have entries in
// shape_info_.
//   - If X has no entry, it is deduced as {spec_.max_batch_size, K} with W's
//     data type. K comes from W: SizeFromDim for FC, SizeToDim otherwise.
//     Only axis == 1, axis_w == 1 and a 2-D W are supported (per the error
//     messages below).
//   - If X has an entry whose dim type is not BATCH, its leading dim is
//     overwritten with spec_.max_batch_size and it is marked BATCH.
// The output shape comes from the op schema (InferOutput) and is recorded with
// dim type BATCH.
// NOTE(review): the enforce openers, the X/output tensor names passed to
// CheckAndSetTensorShapeAndType, the `else` line and the closing braces are
// not visible in this extract.
297 void BoundShapeInferencer::InferFC(
const OperatorDef& op) {
298 CAFFE_ENFORCE_EQ(op.input_size(), 3,
"FC has to have 3 inputs");
299 const auto w_it = shape_info_.find(op.input(1));
301 w_it != shape_info_.end(),
302 "Shape of WEIGHT input of FC ",
304 " needs to be presented");
// Keep references (not iterators) to W and B. CheckAndSetTensorShapeAndType
// below may emplace into shape_info_, and a rehash invalidates unordered_map
// iterators but not references to elements.
305 const ShapeInfo& w_shape_info = w_it->second;
// The BIAS presence check must test b_it (not w_it). Otherwise a missing BIAS
// entry slips through and b_it->second dereferences end().
306 const auto b_it = shape_info_.find(op.input(2));
308 b_it != shape_info_.end(),
309 "Shape of BIAS input of FC ",
311 " needs to be presented");
312 const ShapeInfo& b_shape_info = b_it->second;
313 auto x_it = shape_info_.find(op.input(0));
314 if (x_it == shape_info_.end()) {
// X is unknown: deduce {max_batch_size, K} from W.
316 ArgumentHelper helper(op);
317 auto axis = helper.GetSingleArgument<int32_t>(
"axis", 1);
318 auto axis_w = helper.GetSingleArgument<int32_t>(
"axis_w", 1);
322 "Don't know how to deduce input of FC with axis not equal to 1: ",
327 "Don't know how to deduce input of FC with axis_w not equal to 1: ",
// Bind by reference: W's entry outlives this scope, so no copy is needed.
329 const TensorShape& w_shape = w_shape_info.shape;
333 "Don't know how to deduce input of FC other than of dim size 2: ",
// Any type other than "FC" reaching here (FCTransposed, presumably with W
// stored transposed) takes K from the dims before axis_w.
335 bool transposed = op.type() != "FC";
336 const int canonical_axis_w =
337 canonical_axis_index_(axis_w, w_shape.dims().size());
338 const int64_t K = transposed ? SizeToDim(w_shape, canonical_axis_w)
339 : SizeFromDim(w_shape, canonical_axis_w);
340 current_dim_type_ = ShapeInfo::DimType::BATCH;
341 current_max_batch_size_ = spec_.max_batch_size;
342 CheckAndSetTensorShapeAndType(
344 ShapeInfo::DimType::BATCH,
345 {spec_.max_batch_size, K},
346 w_shape.data_type());
// X is already known: force its leading dim to the batch bound unless it is
// already BATCH-typed.
348 ShapeInfo& x_shape_info = x_it->second;
349 if (x_shape_info.dim_type != ShapeInfo::DimType::BATCH) {
350 CAFFE_ENFORCE_GE(x_shape_info.shape.dims_size(), 1);
351 x_shape_info.shape.set_dims(0, spec_.max_batch_size);
352 x_shape_info.dim_type = ShapeInfo::DimType::BATCH;
// Run schema shape inference on (X, W, B) and record the single output as
// BATCH. X has an entry by now in both branches.
357 std::vector<TensorShape> input_shapes{
358 shape_info_[op.input(0)].shape, w_shape_info.shape, b_shape_info.shape};
359 std::vector<TensorShape> output_shapes = InferOutput(op, input_shapes);
360 CAFFE_ENFORCE_EQ(output_shapes.size(), 1);
361 CheckAndSetTensorShapeAndType(
363 ShapeInfo::DimType::BATCH,
364 ConvertToVec(output_shapes[0].dims()),
365 output_shapes[0].data_type());
// BoundShapeInferencer::InferCommonOp: generic inference helper.
//   - Gathers the recorded shapes of all inputs. If an input has no entry in
//     shape_info_, it logs a warning ("Cannot find shape info for ...").
//   - Runs the op schema's InferTensor on the gathered shapes.
//   - Records each resulting output shape via CheckAndSetTensorShapeAndType.
// Exceptions are caught: the visible handlers log enforce failures at ERROR
// and other std::exceptions at WARNING.
// NOTE(review): several lines are not visible in this extract:
//   - the `try` opener;
//   - what follows the missing-input warning (presumably an early return);
//   - the handling of unknown_shape outputs (presumably skipped);
//   - the output names and dim type passed to CheckAndSetTensorShapeAndType;
//   - the EnforceNotMet catch clause.
368 void BoundShapeInferencer::InferCommonOp(
const OperatorDef& op) {
372 std::vector<TensorShape> input_shapes;
373 for (
const auto& input : op.input()) {
374 const auto it = shape_info_.find(input);
375 if (it == shape_info_.end()) {
376 LOG(WARNING) <<
"Cannot find shape info for " << input <<
". Skipping " 380 input_shapes.emplace_back(it->second.shape);
// Same schema lookup + InferTensor as the InferOutput(op, input_shapes)
// helper defined in this file.
383 const OpSchema* schema = OpSchemaRegistry::Schema(op.type());
384 CAFFE_ENFORCE(schema);
385 std::vector<TensorShape> output_shapes;
386 output_shapes = schema->InferTensor(op, input_shapes);
388 for (
const auto& shape : output_shapes) {
389 if (shape.unknown_shape()) {
393 CheckAndSetTensorShapeAndType(
396 ConvertToVec(shape.dims()),
400 LOG(ERROR) <<
"Enforce not met while inferring shapes for " << op.type()
402 }
catch (
const std::exception& e) {
403 LOG(WARNING) <<
"Caught exception while inferring shapes for " << op.type()
The primary ATen error class.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...