1 #include "caffe2/operators/fc_inference.h" 4 std::vector<TensorShape> FCShapeInference(
5 const OperatorDef& def,
6 const vector<TensorShape>& in,
7 bool pretransposed_weight) {
8 vector<TensorShape> out(1);
10 if (in[0].unknown_shape() || in[1].unknown_shape()) {
11 out[0].set_unknown_shape(
true);
15 ArgumentHelper helper(def);
17 auto axis = helper.GetSingleArgument<int32_t>(
"axis", 1);
18 const auto canonical_axis = canonical_axis_index_(axis, in[0].dims().size());
19 auto axis_w = helper.GetSingleArgument<int32_t>(
"axis_w", 1);
20 const int canonical_axis_w =
21 canonical_axis_index_(axis_w, in[1].dims().size());
22 const int64_t N = pretransposed_weight
24 : size_to_dim_(canonical_axis_w, GetDimsVector(in[1]));
26 vector<int64_t> y_shape(in[0].dims().begin(), in[0].dims().end());
27 CAFFE_ENFORCE_LE(canonical_axis + 1, y_shape.size());
28 y_shape.resize(canonical_axis + 1);
29 y_shape[canonical_axis] = N;
31 out[0] = CreateTensorShape(y_shape, in[0].data_type());
35 OpSchema::Cost CostInferenceForFC(
36 const OperatorDef& def,
37 const vector<TensorShape>& in,
38 bool pretransposed_weight) {
39 CAFFE_ENFORCE_EQ(in.size(), 3,
"FC requires three inputs");
40 struct OpSchema::Cost c;
41 ArgumentHelper helper(def);
43 auto axis = helper.GetSingleArgument<int32_t>(
"axis", 1);
44 const auto canonical_axis = canonical_axis_index_(axis, in[0].dims().size());
45 const uint64_t
M = size_to_dim_(canonical_axis, GetDimsVector(in[0]));
46 const uint64_t K = size_from_dim_(canonical_axis, GetDimsVector(in[0]));
47 auto axis_w = helper.GetSingleArgument<int32_t>(
"axis_w", 1);
48 const int canonical_axis_w =
49 canonical_axis_index_(axis_w, in[1].dims().size());
50 const uint64_t N = pretransposed_weight
51 ? size_from_dim_(canonical_axis_w, GetDimsVector(in[1]))
52 : size_to_dim_(canonical_axis_w, GetDimsVector(in[1]));
54 const auto& X = in[0];
55 c.flops = M * N * (2 * K + 1);
56 c.bytes_read = (K * (M + N) + N) *
sizeof(X.data_type());
57 c.bytes_written = M * N *
sizeof(X.data_type());
58 c.params_bytes = (K * N + N) *
sizeof(X.data_type());
int64_t size_from_dim_(int k, IntArrayRef dims)
Return product of all dimensions starting from k.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...