1 #ifndef CAFFE2_OPERATORS_LENGTH_SPLIT_OP_H_ 2 #define CAFFE2_OPERATORS_LENGTH_SPLIT_OP_H_ 4 #include "caffe2/core/common_omp.h" 5 #include "caffe2/core/context.h" 6 #include "caffe2/core/logging.h" 7 #include "caffe2/core/operator.h" 8 #include "caffe2/utils/math.h" 12 template <
class Context>
15 USE_OPERATOR_CONTEXT_FUNCTIONS;
17 template <
class... Args>
20 n_split_(OperatorBase::GetSingleArgument<int32_t>(
"n_split", 0)) {
21 if (InputSize() == 1) {
25 "Argument `n_split` is missing and was not specified as input.");
28 "`n_split` must contain a positive value for defined behavior.");
33 bool RunOnDevice()
override {
34 const auto& L =
Input(0);
35 CAFFE_ENFORCE_EQ(L.dim(), 1,
"Input `LENGTHS` should be a 1D vector.");
37 if (InputSize() > 1) {
41 "Input `n_split` should be a vector of size 1.");
43 const auto& input1 =
Input(1);
44 context_.template CopyItems<Context, CPUContext>(
45 input1.dtype(), 1, input1.raw_data(), &n_split_);
50 "`n_split` must contain a positive value for defined behavior.");
51 const auto M = L.numel();
53 auto* Y = Output(0, {
M * n_split_}, at::dtype<int32_t>());
55 const int32_t* Ldata = L.template data<int32_t>();
56 int32_t* Ydata = Y->template mutable_data<int32_t>();
58 for (
int i = 0; i <
M; i++) {
59 int32_t mod = Ldata[i] % n_split_;
61 mod != 0 ? math::DivUp(Ldata[i], n_split_) : Ldata[i] / n_split_ + 1;
62 for (
int j = 0; j < n_split_; j++) {
63 Ydata[(i * n_split_) + j] = mod-- > 0 ? res : res - 1;
75 #endif // CAFFE2_OPERATORS_LENGTH_SPLIT_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.