1 #ifndef CAFFE2_UTILS_MATH_H_ 2 #define CAFFE2_UTILS_MATH_H_ 9 #include "caffe2/utils/cblas.h" 12 #ifdef CAFFE2_USE_ACCELERATE 13 #include <Accelerate/Accelerate.h> 14 #endif // CAFFE2_USE_ACCELERATE 16 #include "caffe2/core/common.h" 17 #include "caffe2/core/types.h" 18 #include "caffe2/utils/math/broadcast.h" 19 #include "caffe2/utils/math/elementwise.h" 20 #include "caffe2/utils/math/reduce.h" 21 #include "caffe2/utils/math/transpose.h" 22 #include "caffe2/utils/math/utils.h" 35 #define C10_DECLARE_COMPARE_OP(Comp) \ 36 template <typename T, class Context, bool kBroadcast1st = false> \ 45 template <typename T, class Context, bool kBroadcast1st = false> \ 54 template <typename T, class Context> \ 65 C10_DECLARE_COMPARE_OP(EQ)
66 C10_DECLARE_COMPARE_OP(NE)
67 C10_DECLARE_COMPARE_OP(LT)
68 C10_DECLARE_COMPARE_OP(LE)
69 C10_DECLARE_COMPARE_OP(GT)
70 C10_DECLARE_COMPARE_OP(GE)
72 #undef C10_DECLARE_COMPARE_OP 74 #define C10_DECLARE_BINARY_OP(Func) \ 75 template <typename T, class Context, bool kBroadcast1st = false> \ 84 template <typename T, class Context, bool kBroadcast1st = false> \ 93 template <typename T, class Context> \ 104 C10_DECLARE_BINARY_OP(
Add)
105 C10_DECLARE_BINARY_OP(Sub)
106 C10_DECLARE_BINARY_OP(Mul)
107 C10_DECLARE_BINARY_OP(Div)
109 C10_DECLARE_BINARY_OP(And)
110 C10_DECLARE_BINARY_OP(Or)
111 C10_DECLARE_BINARY_OP(Xor)
113 C10_DECLARE_BINARY_OP(BitwiseAnd)
114 C10_DECLARE_BINARY_OP(BitwiseOr)
115 C10_DECLARE_BINARY_OP(BitwiseXor)
117 #undef C10_DECLARE_BINARY_OP 120 template <
typename T,
class Context>
121 CAFFE2_API
void Broadcast(
132 template <
typename T,
class Context>
133 CAFFE2_API
void InvStd(
142 template <
typename T,
class Context>
143 CAFFE2_API
void AddStripedBatch(
153 template <
typename T,
class Context>
155 RowwiseMax(
const int N,
const int D,
const T* x,
T* y, Context* context);
159 template <
typename T,
class Context>
161 ColwiseMax(
const int N,
const int D,
const T* x,
T* y, Context* context);
164 template <
typename T,
class Context>
166 Maximum(
const int N,
const float alpha,
const T* x,
T* y, Context* context);
170 template <
typename T,
class Context,
class Engine = DefaultEngine>
171 CAFFE2_API
void Gemm(
172 const CBLAS_TRANSPOSE trans_A,
173 const CBLAS_TRANSPOSE trans_B,
183 TensorProto::DataType math_type = TensorProto_DataType_FLOAT);
187 template <
typename T,
class Context,
class Engine = DefaultEngine>
188 CAFFE2_API
void GemmEx(
189 const CBLAS_TRANSPOSE trans_A,
190 const CBLAS_TRANSPOSE trans_B,
205 template <
typename T,
class Context,
class Engine = DefaultEngine>
206 CAFFE2_API
void GemmBatched(
207 const CBLAS_TRANSPOSE trans_A,
208 const CBLAS_TRANSPOSE trans_B,
209 const int batch_size,
219 TensorProto::DataType math_type = TensorProto_DataType_FLOAT);
221 template <
typename T,
class Context,
class Engine = DefaultEngine>
222 CAFFE2_API
void GemmStridedBatched(
223 const CBLAS_TRANSPOSE trans_A,
224 const CBLAS_TRANSPOSE trans_B,
225 const int batch_size,
238 TensorProto::DataType math_type = TensorProto_DataType_FLOAT);
244 template <
typename T,
class Context,
class Engine = DefaultEngine>
245 CAFFE2_API
void Gemv(
246 const CBLAS_TRANSPOSE trans_A,
255 TensorProto::DataType math_type = TensorProto_DataType_FLOAT);
257 template <
typename T,
class Context>
259 RandUniform(
const size_t n,
const T a,
const T b,
T* r, Context* context);
263 template <
typename T,
class Context>
264 CAFFE2_API
void RandFixedSum(
272 template <
typename T,
class Context>
273 CAFFE2_API
void RandUniformUnique(
284 template <
typename T,
class Context>
286 RandSyntheticData(
const size_t n,
const T a,
const T b,
T* r, Context* context);
288 template <
typename T,
class Context>
290 RandGaussian(
const size_t n,
const T mean,
const T std,
T* r, Context* context);
293 template <
typename T,
class Context>
295 Dot(
const int N,
const T* a,
const T* b,
T* y, Context* context);
298 template <
typename T,
class Context>
304 Tensor* scratch_ptr =
nullptr);
307 template <
typename T,
class Context>
308 CAFFE2_API
void SumSqr(
313 Tensor* scratch_ptr =
nullptr);
317 template <
typename T,
class Context>
318 CAFFE2_API
void Select(
326 template <
typename T,
class Context>
328 Axpy(
const int N,
const float alpha,
const T* x,
T* y, Context* context);
333 template <
typename T,
class Context>
335 Axpy(
const int N,
const float* alpha,
const T* x,
T* y, Context* context);
337 template <
typename TCoeff,
typename TData,
class Context>
338 CAFFE2_API
void Axpby(
346 template <
typename TCoeff,
typename TData,
class Context>
347 CAFFE2_API
void Axpby(
360 template <
typename T,
class Context, StorageOrder kOrder>
361 CAFFE2_API
void Im2Col(
367 const int dilation_h,
368 const int dilation_w,
378 const int groups = 1);
381 template <
typename T,
class Context, StorageOrder kOrder>
382 CAFFE2_API
void Im2ColNd(
386 const int* img_shape,
387 const int* col_shape,
388 const int* kernel_shape,
395 const int groups = 1);
402 template <
typename T,
class Context, StorageOrder kOrder>
403 CAFFE2_API
void Col2Im(
409 const int dilation_h,
410 const int dilation_w,
420 const int groups = 1);
427 template <
typename T,
class Context, StorageOrder kOrder>
428 CAFFE2_API
void Col2ImNd(
432 const int* img_shape,
433 const int* col_shape,
434 const int* kernel_shape,
441 const int groups = 1);
445 template <
typename T,
class Context>
446 CAFFE2_API
void BiasCHW(
448 const T* bias_multiplier,
449 const int bias_channels,
450 const int image_size,
454 template <
class Context>
455 CAFFE2_API
void CopyMatrix(
456 const size_t item_size,
464 TypeMeta::Copy copy =
nullptr);
466 template <
typename T,
class Context>
467 CAFFE2_API
void CopyMatrix(
476 template <
typename T,
class Context>
477 CAFFE2_API
void CopyMatrix(
481 const int A_outer_stride,
482 const int A_inner_stride,
484 const int B_outer_stride,
485 const int B_inner_stride,
488 template <
typename T,
class Context>
489 CAFFE2_API
void CopyVector(
const int N,
const T* A,
T* B, Context* context);
495 #include "caffe2/utils/math-detail.h" 496 #endif // CAFFE2_UTILS_MATH_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...