1 #ifndef CAFFE2_OPERATOR_GLU_OP_H_ 2 #define CAFFE2_OPERATOR_GLU_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 8 template <
typename T,
class Context>
11 template <
class... Args>
12 explicit GluOp(Args&&... args)
14 dim_(this->
template GetSingleArgument<int>(
"dim", -1)) {}
16 USE_OPERATOR_CONTEXT_FUNCTIONS;
21 vector<int64_t> Yshape;
22 Yshape.insert(Yshape.end(), X.sizes().begin(), X.sizes().end());
23 const int split_index = dim_ == -1 ? Yshape.size() - 1 : dim_;
25 Yshape[split_index] % 2 == 0,
28 " should be divided by two");
29 const int split_dim_size = Yshape[split_index] / 2;
30 const int M = X.size_to_dim(split_index);
31 const int N = X.size_from_dim(split_index + 1);
32 Yshape[split_index] = split_dim_size;
33 auto* Y = Output(0, Yshape, at::dtype<T>());
39 Y->template mutable_data<T>());
46 const int split_dim_size,
56 #endif // CAFFE2_OPERATOR_GLU_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...