17 #ifndef CAFFE2_OPERATORS_FULLY_CONNECTED_OP_SPARSE_H_ 18 #define CAFFE2_OPERATORS_FULLY_CONNECTED_OP_SPARSE_H_ 20 #include "caffe2/core/context.h" 21 #include "caffe2/core/operator.h" 22 #include "caffe2/utils/math.h" 25 #endif // CAFFE2_USE_MKL 32 using Shape = std::array<int, N>;
35 const std::vector<int64_t>& shape(Shape<N> vs) {
36 static thread_local std::vector<int64_t> cache;
37 cache.resize(vs.size());
38 for (
auto i = 0; i < vs.size(); ++i) {
44 inline const std::vector<int64_t>& shape(
int i) {
45 return shape<1>(Shape<1>({i}));
48 inline const std::vector<int64_t>& shape(
int i,
int j) {
49 return shape<2>(Shape<2>({i, j}));
/**
 * Sparse-times-dense matrix product: c = A * b.
 *
 * A is an m x k sparse matrix in CSR form described by (acsr, ia, ja);
 * b is a dense k x n operand and c a dense m x n result.
 *
 * The <float, CPUContext> specialization in this header forwards to MKL's
 * mkl_scsrmm with:
 *   - transa = "N";
 *   - matdescra = "GLNC" (general matrix, zero-based indexing);
 *   - alpha = 1 and beta = 0;
 *   - ldb = ldc = n.
 * c is therefore overwritten, not accumulated into. With zero-based indexing
 * MKL treats b and c as row-major.
 *
 * @param acsr    non-zero values of A
 * @param ia      CSR row pointers of A: row r occupies [ia[r], ia[r + 1]),
 *                so ia holds m + 1 entries (the operator derives m as
 *                iw.dim32(0) - 1)
 * @param ja      column index of each non-zero stored in acsr
 * @param m       rows of A and of c
 * @param k       columns of A, i.e. rows of b
 * @param n       columns of b and of c
 * @param b       dense right-hand operand
 * @param c       dense output buffer, written by the call
 * @param context device context the product runs on
 */
template <typename T, class Context>
void Sparse_mm(
    const T* acsr,
    const int* ia,
    const int* ja,
    int m,
    int k,
    int n,
    const T* b,
    T* c,
    Context* context);
/**
 * Dense matrix transpose helper, templated on element type and device context.
 *
 * The only specialization visible in this header is <float, CPUContext>. Its
 * outer loop runs i over [0, m) and its inner loop runs j over [0, n).
 * NOTE(review): the loop body is not visible here. Presumably o is an m x n
 * row-major matrix and t receives its n x m transpose. Confirm before relying
 * on the layout.
 *
 * @param o       source matrix (read only)
 * @param t       destination buffer, written by the call
 * @param m       extent of the outer loop (presumably rows of o)
 * @param n       extent of the inner loop (presumably columns of o)
 * @param context device context the transpose runs on
 */
template <typename T, class Context>
void trans_mat(const T* o, T* t, int m, int n, Context* context);
60 void trans_mat<float, CPUContext>(
66 for(
int i = 0; i < m; ++i){
67 for(
int j = 0; j < n; ++j){
76 void Sparse_mm<float, CPUContext>(
86 float alpha = 1.0, beta = 0.;
87 mkl_scsrmm(
"N", &m, &n, &k, &alpha,
"GLNC",
88 acsr, ja, ia, ia+1, b, &n, &beta, c, &n);
94 template <
typename T,
class Context,
class Engine=DefaultEngine>
97 USE_OPERATOR_CONTEXT_FUNCTIONS;
102 bool RunOnDevice()
override {
103 const auto& Xt =
Input(0);
104 const auto& Wcsr =
Input(1);
105 const auto& iw =
Input(2);
106 const auto& jw =
Input(3);
108 const auto& b =
Input(4);
111 CAFFE_ENFORCE_EQ(Xt.dim(), 2);
112 CAFFE_ENFORCE_EQ(b.dim(), 1);
114 int K = Xt.dim() > 1 ? Xt.dim32(0) : 1;
116 int M = Xt.numel() / K;
118 int N = iw.dim32(0)-1;
119 CAFFE_ENFORCE_EQ(N, b.dim32(0));
120 auto* Yt = Output(0, shape(N, M), at::dtype<T>());
123 Sparse_mm<T, Context>(
124 Wcsr.template data<T>(), iw.template data<int>(),
125 jw.template data<int>(), N, K, M, Xt.template data<T>(),
126 Yt->template mutable_data<T>(), &context_);
128 if (bias_multiplier_.numel() != M) {
130 bias_multiplier_.Resize(shape(M));
131 math::Set<T, Context>(
132 M,
static_cast<T>(1), bias_multiplier_.template mutable_data<T>(),
135 math::Gemm<T, Context, Engine>(
136 CblasNoTrans, CblasNoTrans, N, M, 1, 1,
137 b.template data<T>(), bias_multiplier_.template data<T>(), 1,
138 Yt->template mutable_data<T>(), &context_);
143 Tensor bias_multiplier_{Context::GetDeviceType()};
149 #endif // CAFFE2_OPERATORS_FULLY_CONNECTED_OP_SPARSE_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...