Caffe2 - C++ API
A deep learning, cross-platform ML framework
fc_inference.cc
1 #include "caffe2/operators/fc_inference.h"
2 
3 namespace caffe2 {
4 std::vector<TensorShape> FCShapeInference(
5  const OperatorDef& def,
6  const vector<TensorShape>& in,
7  bool pretransposed_weight) {
8  vector<TensorShape> out(1);
9 
10  if (in[0].unknown_shape() || in[1].unknown_shape()) {
11  out[0].set_unknown_shape(true);
12  return out;
13  }
14 
15  ArgumentHelper helper(def);
16 
17  auto axis = helper.GetSingleArgument<int32_t>("axis", 1);
18  const auto canonical_axis = canonical_axis_index_(axis, in[0].dims().size());
19  auto axis_w = helper.GetSingleArgument<int32_t>("axis_w", 1);
20  const int canonical_axis_w =
21  canonical_axis_index_(axis_w, in[1].dims().size());
22  const int64_t N = pretransposed_weight
23  ? size_from_dim_(canonical_axis_w, GetDimsVector(in[1]))
24  : size_to_dim_(canonical_axis_w, GetDimsVector(in[1]));
25 
26  vector<int64_t> y_shape(in[0].dims().begin(), in[0].dims().end());
27  CAFFE_ENFORCE_LE(canonical_axis + 1, y_shape.size());
28  y_shape.resize(canonical_axis + 1);
29  y_shape[canonical_axis] = N;
30 
31  out[0] = CreateTensorShape(y_shape, in[0].data_type());
32  return out;
33 }
34 
35 OpSchema::Cost CostInferenceForFC(
36  const OperatorDef& def,
37  const vector<TensorShape>& in,
38  bool pretransposed_weight) {
39  CAFFE_ENFORCE_EQ(in.size(), 3, "FC requires three inputs");
40  struct OpSchema::Cost c;
41  ArgumentHelper helper(def);
42 
43  auto axis = helper.GetSingleArgument<int32_t>("axis", 1);
44  const auto canonical_axis = canonical_axis_index_(axis, in[0].dims().size());
45  const uint64_t M = size_to_dim_(canonical_axis, GetDimsVector(in[0]));
46  const uint64_t K = size_from_dim_(canonical_axis, GetDimsVector(in[0]));
47  auto axis_w = helper.GetSingleArgument<int32_t>("axis_w", 1);
48  const int canonical_axis_w =
49  canonical_axis_index_(axis_w, in[1].dims().size());
50  const uint64_t N = pretransposed_weight
51  ? size_from_dim_(canonical_axis_w, GetDimsVector(in[1]))
52  : size_to_dim_(canonical_axis_w, GetDimsVector(in[1]));
53 
54  const auto& X = in[0];
55  c.flops = M * N * (2 * K + 1);
56  c.bytes_read = (K * (M + N) + N) * sizeof(X.data_type());
57  c.bytes_written = M * N * sizeof(X.data_type());
58  c.params_bytes = (K * N + N) * sizeof(X.data_type());
59  return c;
60 }
61 } // namespace caffe2
Definition: any.cpp:108
int64_t size_from_dim_(int k, IntArrayRef dims)
Returns the product of all dimensions starting from index k.
Definition: TensorImpl.h:53
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13