Caffe2 - C++ API
A deep learning, cross platform ML framework
fully_connected_op_sparse.h
1 
17 #ifndef CAFFE2_OPERATORS_FULLY_CONNECTED_OP_SPARSE_H_
18 #define CAFFE2_OPERATORS_FULLY_CONNECTED_OP_SPARSE_H_
19 
#include <array>
#include <cstdint>
#include <stdexcept>
#include <vector>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
#ifdef CAFFE2_USE_MKL
#include <mkl.h>
#endif // CAFFE2_USE_MKL
26 
27 namespace caffe2 {
28 
29 namespace {
30 
/// Fixed-rank list of `int` dimension sizes.
template <int N>
using Shape = std::array<int, N>;

/// Widens the N dimension sizes in `vs` to `int64_t` and returns them as a
/// vector.
///
/// The result is stored in a function-local `static thread_local` buffer (one
/// per instantiation of N), so the returned reference is overwritten by the
/// next call to `shape<N>` on the same thread. Consume or copy the result
/// before calling again with the same N.
///
/// @param vs  dimension sizes to convert
/// @return    reference to the per-thread buffer holding the converted sizes
template <int N>
const std::vector<int64_t>& shape(Shape<N> vs) {
  static thread_local std::vector<int64_t> cache;
  // assign() resizes the buffer and converts every element in one step, which
  // avoids the signed/unsigned comparison of a hand-written index loop.
  cache.assign(vs.begin(), vs.end());
  return cache;
}
43 
44 inline const std::vector<int64_t>& shape(int i) {
45  return shape<1>(Shape<1>({i}));
46 }
47 
48 inline const std::vector<int64_t>& shape(int i, int j) {
49  return shape<2>(Shape<2>({i, j}));
50 }
51 
/// Sparse-times-dense matrix product, C = A * B, with A given in CSR form.
/// Only declared here; each supported (T, Context) pair supplies an explicit
/// specialization.
///
/// @param acsr     nonzero values of A
/// @param ia       row-pointer array of A (the float/CPU specialization hands
///                 `ia` and `ia + 1` to MKL as the row begin/end arrays)
/// @param ja       column index array of A
/// @param m        rows of A and of C
/// @param k        columns of A
/// @param n        columns of C (also used as the leading dimension of `b`
///                 and `c` by the float/CPU specialization)
/// @param b        dense right-hand operand
/// @param c        output buffer for the dense result
/// @param context  execution context (unused by the float/CPU specialization)
template <typename T, class Context>
void Sparse_mm(
    const T* acsr,
    const int* ia,
    const int* ja,
    int m,
    int k,
    int n,
    const T* b,
    T* c,
    Context* context);
55 
/// Matrix transpose helper. Only declared here; each supported (T, Context)
/// pair supplies an explicit specialization.
///
/// In the float/CPU specialization in this file, `o` is read as an m-by-n
/// row-major matrix and `t` receives its n-by-m transpose.
///
/// @param o        source matrix
/// @param t        destination buffer
/// @param m        rows of `o`
/// @param n        columns of `o`
/// @param context  execution context (unused by the float/CPU specialization)
template <typename T, class Context>
void trans_mat(
    const T* o,
    T* t,
    int m,
    int n,
    Context* context);
58 
59 template <>
60 void trans_mat<float, CPUContext>(
61  const float* o,
62  float* t,
63  int m,
64  int n,
65  CPUContext* /*context*/) {
66  for(int i = 0; i < m; ++i){
67  for(int j = 0; j < n; ++j){
68  t[j*m+i]=o[i*n+j];
69  }
70  }
71 }
72 
73 // C = A(sparse) * B
74 // No transpose;
75 template <>
76 void Sparse_mm<float, CPUContext>(
77  const float* acsr,
78  const int* ia,
79  const int* ja,
80  int m,
81  int k,
82  int n,
83  const float* b,
84  float* c,
85  CPUContext* /*context*/) {
86  float alpha = 1.0, beta = 0.;
87  mkl_scsrmm("N", &m, &n, &k, &alpha, "GLNC",
88  acsr, ja, ia, ia+1, b, &n, &beta, c, &n);
89 }
90 
91 }
92 
93 // This is Caffe's InnerProductOp, with a name that fits its purpose better.
94 template <typename T, class Context, class Engine=DefaultEngine>
95 class FullyConnectedOp_SPARSE final : public Operator<Context> {
96  public:
97  USE_OPERATOR_CONTEXT_FUNCTIONS;
98  FullyConnectedOp_SPARSE(const OperatorDef& operator_def, Workspace* ws)
99  : Operator<Context>(operator_def, ws) {}
101 
102  bool RunOnDevice() override {
103  const auto& Xt = Input(0); // transposed X
104  const auto& Wcsr = Input(1);
105  const auto& iw = Input(2);
106  const auto& jw = Input(3);
107  // Notice that we do not need to transpose b
108  const auto& b = Input(4);
109  // transposed Y
110  // here we assume X is k-by-m
111  CAFFE_ENFORCE_EQ(Xt.dim(), 2);
112  CAFFE_ENFORCE_EQ(b.dim(), 1);
113  // batch size
114  int K = Xt.dim() > 1 ? Xt.dim32(0) : 1;
115  // Feature dimension
116  int M = Xt.numel() / K;
117  // number of outputs.
118  int N = iw.dim32(0)-1;
119  CAFFE_ENFORCE_EQ(N, b.dim32(0));
120  auto* Yt = Output(0, shape(N, M), at::dtype<T>());
121 
122  // Y' = W * X';
123  Sparse_mm<T, Context>(
124  Wcsr.template data<T>(), iw.template data<int>(),
125  jw.template data<int>(), N, K, M, Xt.template data<T>(),
126  Yt->template mutable_data<T>(), &context_);
127  // Add bias term
128  if (bias_multiplier_.numel() != M) {
129  // If the helper bias multiplier is not M, reshape and fill it with one.
130  bias_multiplier_.Resize(shape(M));
131  math::Set<T, Context>(
132  M, static_cast<T>(1), bias_multiplier_.template mutable_data<T>(),
133  &context_);
134  }
135  math::Gemm<T, Context, Engine>(
136  CblasNoTrans, CblasNoTrans, N, M, 1, 1,
137  b.template data<T>(), bias_multiplier_.template data<T>(), 1,
138  Yt->template mutable_data<T>(), &context_);
139  return true;
140  }
141 
142  protected:
143  Tensor bias_multiplier_{Context::GetDeviceType()};
144 };
145 
146 
147 } // namespace caffe2
148 
#endif // CAFFE2_OPERATORS_FULLY_CONNECTED_OP_SPARSE_H_
Definition: any.cpp:108
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13