Caffe2 - C++ API
A deep learning, cross-platform ML framework
glu_op.h
1 #ifndef CAFFE2_OPERATOR_GLU_OP_H_
2 #define CAFFE2_OPERATOR_GLU_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 
7 namespace caffe2 {
8 template <typename T, class Context>
9 class GluOp final : public Operator<Context> {
10  public:
11  template <class... Args>
12  explicit GluOp(Args&&... args)
13  : Operator<Context>(std::forward<Args>(args)...),
14  dim_(this->template GetSingleArgument<int>("dim", -1)) {}
15 
16  USE_OPERATOR_CONTEXT_FUNCTIONS;
17 
18  bool RunOnDevice() {
19  auto& X = Input(0);
20 
21  vector<int64_t> Yshape;
22  Yshape.insert(Yshape.end(), X.sizes().begin(), X.sizes().end());
23  const int split_index = dim_ == -1 ? Yshape.size() - 1 : dim_;
24  CAFFE_ENFORCE(
25  Yshape[split_index] % 2 == 0,
26  "Split dimension ",
27  Yshape[split_index],
28  " should be divided by two");
29  const int split_dim_size = Yshape[split_index] / 2;
30  const int M = X.size_to_dim(split_index);
31  const int N = X.size_from_dim(split_index + 1);
32  Yshape[split_index] = split_dim_size;
33  auto* Y = Output(0, Yshape, at::dtype<T>());
34  ComputeGlu(
35  M,
36  split_dim_size,
37  N,
38  X.template data<T>(),
39  Y->template mutable_data<T>());
40  return true;
41  }
42 
43  protected:
44  void ComputeGlu(
45  const int M,
46  const int split_dim_size,
47  const int N,
48  const T* X,
49  T* output);
50 
51  private:
52  const int dim_;
53 };
54 } // namespace caffe2
55 
56 #endif // CAFFE2_OPERATOR_GLU_OP_H_
Definition: any.cpp:108
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13