Caffe2 - C++ API
A deep learning, cross-platform ML framework
lengths_tile_op.h
1 
17 #ifndef CAFFE2_OPERATORS_LENGTHS_TILE_OP_H_
18 #define CAFFE2_OPERATORS_LENGTHS_TILE_OP_H_
19 
20 #include "caffe2/core/operator.h"
21 #include "caffe2/utils/math.h"
22 
23 namespace caffe2 {
24 
25 template <class Context>
26 class LengthsTileOp : public Operator<Context> {
27  public:
28  USE_OPERATOR_CONTEXT_FUNCTIONS;
29  USE_SIMPLE_CTOR_DTOR(LengthsTileOp);
30 
31  bool RunOnDevice() override {
32  auto& data = Input(DATA);
33  auto& lengths = Input(LENGTHS);
34  auto* output = Output(0);
35 
36  CAFFE_ENFORCE_EQ(lengths.ndim(), 1, "LENGTHS must be 1-D");
37  CAFFE_ENFORCE_GE(data.ndim(), 1, "DATA should be at least 1-D");
38  CAFFE_ENFORCE_EQ(lengths.size(), data.dim(0));
39 
40  // Context::CopyFrom and math::Sum need the same context to avoid race
41  // conditions
42  CPUContext cpuContext;
43  lengths_host_.CopyFrom(lengths, &cpuContext);
44  auto lengths_size = lengths_host_.size();
45  auto* lengths_data = lengths_host_.data<int32_t>();
46 
47  int32_t total_length = 0;
48  math::Sum<int32_t, CPUContext>(
49  lengths_size, lengths_data, &total_length, &cpuContext);
50 
51  auto shape = data.dims();
52  shape[0] = total_length;
53  output->Resize(shape);
54 
55  auto block_bytesize = data.size_from_dim(1) * data.meta().itemsize();
56  auto src = static_cast<const char*>(data.raw_data());
57  auto out = static_cast<char*>(output->raw_mutable_data(data.meta()));
58 
59  for (TIndex i = 0; i < lengths_size; ++i) {
60  auto length = lengths_data[i];
61  CAFFE_ENFORCE_GE(length, 0);
62  for (int32_t j = 0; j < length; ++j) {
63  context_.template CopyBytes<Context, Context>(block_bytesize, src, out);
64  out += block_bytesize;
65  }
66  src += block_bytesize;
67  }
68  return true;
69  }
70 
71  INPUT_TAGS(DATA, LENGTHS);
72 
73  private:
74  TensorCPU lengths_host_;
75 };
76 
77 } // namespace caffe2
78 
79 #endif // CAFFE2_OPERATORS_LENGTHS_TILE_OP_H_
const T * data() const
Returns a typed pointer to the underlying storage.
Definition: tensor.h:500
void CopyFrom(const Tensor< SrcContext > &src, ContextForCopy *context)
Copies the data from a source tensor, with a context provided to carry out the underlying memcpy operation.
Definition: tensor.h:182
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement.
Definition: context.h:82
TIndex size() const
Returns the size (i.e. the number of items) of the tensor.
Definition: tensor.h:609
Copyright (c) 2016-present, Facebook, Inc.