Caffe2 - C++ API
A deep-learning, cross-platform ML framework
glu_op.h
1 #ifndef CAFFE2_OPERATOR_GLU_OP_H_
2 #define CAFFE2_OPERATOR_GLU_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 
7 namespace caffe2 {
8 template <typename T, class Context>
9 class GluOp final : public Operator<Context> {
10  public:
11  GluOp(const OperatorDef& operator_def, Workspace* ws)
12  : Operator<Context>(operator_def, ws),
13  dim_(OperatorBase::GetSingleArgument<int>("dim", -1)) {}
14 
15  USE_OPERATOR_CONTEXT_FUNCTIONS;
16 
17  bool RunOnDevice() {
18  auto& X = Input(0);
19  auto* Y = Output(0);
20  vector<TIndex> Yshape;
21  Yshape.insert(Yshape.end(), X.dims().begin(), X.dims().end());
22  const int split_index = dim_ == -1 ? Yshape.size() - 1 : dim_;
23  CAFFE_ENFORCE(
24  Yshape[split_index] % 2 == 0,
25  "Split dimension ",
26  Yshape[split_index],
27  " should be divided by two");
28  const int split_dim_size = Yshape[split_index] / 2;
29  const int M = X.size_to_dim(split_index);
30  const int N = X.size_from_dim(split_index + 1);
31  Yshape[split_index] = split_dim_size;
32  Y->Resize(Yshape);
33  ComputeGlu(
34  M,
35  split_dim_size,
36  N,
37  X.template data<T>(),
38  Y->template mutable_data<T>());
39  return true;
40  }
41 
42  protected:
43  void ComputeGlu(
44  const int M,
45  const int split_dim_size,
46  const int N,
47  const T* X,
48  T* output);
49 
50  private:
51  const int dim_;
52 };
53 } // namespace caffe2
54 
55 #endif // CAFFE2_OPERATOR_GLU_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.