Caffe2 - C++ API
A deep learning, cross platform ML framework
length_split_op.h
1 #ifndef CAFFE2_OPERATORS_LENGTH_SPLIT_OP_H_
2 #define CAFFE2_OPERATORS_LENGTH_SPLIT_OP_H_
3 
4 #include "caffe2/core/common_omp.h"
5 #include "caffe2/core/context.h"
6 #include "caffe2/core/logging.h"
7 #include "caffe2/core/operator.h"
8 #include "caffe2/utils/math.h"
9 
10 namespace caffe2 {
11 
12 template <class Context>
13 class LengthsSplitOp final : public Operator<Context> {
14  public:
15  USE_OPERATOR_CONTEXT_FUNCTIONS;
16 
17  template <class... Args>
18  explicit LengthsSplitOp(Args&&... args)
19  : Operator<Context>(std::forward<Args>(args)...),
20  n_split_(OperatorBase::GetSingleArgument<int32_t>("n_split", 0)) {
21  if (InputSize() == 1) {
22  // If not specified, then must have this argument
23  CAFFE_ENFORCE(
24  OperatorBase::HasArgument("n_split"),
25  "Argument `n_split` is missing and was not specified as input.");
26  CAFFE_ENFORCE(
27  n_split_ > 0,
28  "`n_split` must contain a positive value for defined behavior.");
29  }
30  }
31  ~LengthsSplitOp() {}
32 
33  bool RunOnDevice() override {
34  const auto& L = Input(0);
35  CAFFE_ENFORCE_EQ(L.dim(), 1, "Input `LENGTHS` should be a 1D vector.");
36 
37  if (InputSize() > 1) {
38  // We potentially have n_split specified as inputs as well
39  CAFFE_ENFORCE(
40  Input(1).dim() == 1 && Input(1).numel() == 1,
41  "Input `n_split` should be a vector of size 1.");
42 
43  const auto& input1 = Input(1);
44  context_.template CopyItems<Context, CPUContext>(
45  input1.dtype(), 1, input1.raw_data(), &n_split_);
46  }
47 
48  CAFFE_ENFORCE(
49  n_split_ > 0,
50  "`n_split` must contain a positive value for defined behavior.");
51  const auto M = L.numel();
52 
53  auto* Y = Output(0, {M * n_split_}, at::dtype<int32_t>());
54 
55  const int32_t* Ldata = L.template data<int32_t>();
56  int32_t* Ydata = Y->template mutable_data<int32_t>();
57 
58  for (int i = 0; i < M; i++) {
59  int32_t mod = Ldata[i] % n_split_;
60  int32_t res =
61  mod != 0 ? math::DivUp(Ldata[i], n_split_) : Ldata[i] / n_split_ + 1;
62  for (int j = 0; j < n_split_; j++) {
63  Ydata[(i * n_split_) + j] = mod-- > 0 ? res : res - 1;
64  }
65  }
66  return true;
67  }
68 
69  private:
70  int32_t n_split_;
71 };
72 
73 } // namespace caffe2
74 
75 #endif // CAFFE2_OPERATORS_LENGTH_SPLIT_OP_H_
Definition: any.cpp:108
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:70