1 #ifndef CAFFE2_OPERATORS_SQUARE_ROOT_DIVIDE_OP_H_ 2 #define CAFFE2_OPERATORS_SQUARE_ROOT_DIVIDE_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 6 #include "caffe2/utils/math.h" 10 template <
class Context>
13 USE_OPERATOR_CONTEXT_FUNCTIONS;
16 template <
class... Args>
20 bool RunOnDevice()
override {
25 template <
typename TData>
26 bool DoRunWithType() {
31 template <
typename TData,
typename TScale>
32 bool DoRunWithType2() {
33 auto& data =
Input(DATA);
34 auto& scale =
Input(SCALE);
36 auto* Y = Output(0, data.sizes(), at::dtype<TData>());
37 size_t batchSize = data.size(0);
38 size_t exampleSize = data.size_from_dim(1);
39 CAFFE_ENFORCE(batchSize == scale.size(0), batchSize,
" != ", scale.size(0));
40 auto* scalePtr = scale.template data<TScale>();
41 auto* dataPtr = data.template data<TData>();
42 auto* yPtr = Y->template mutable_data<TData>();
43 for (
auto i = 0; i < batchSize; ++i) {
44 auto scale = scalePtr[i];
45 CAFFE_ENFORCE(scale >= 0, scale,
" < 0");
46 auto multiplier = scale == 0 ? 1.0 : 1 / std::sqrt(scale);
47 math::Scale<float, TData, Context>(
50 dataPtr + i * exampleSize,
51 yPtr + i * exampleSize,
57 INPUT_TAGS(DATA, SCALE);
62 #endif // CAFFE2_OPERATORS_SQUARE_ROOT_DIVIDE_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...