1 #include "caffe2/operators/lengths_tile_op.h" 6 bool LengthsTileOp<CPUContext>::RunOnDevice() {
7 auto& data = Input(DATA);
8 auto& lengths = Input(LENGTHS);
9 auto* output = Output(0);
11 CAFFE_ENFORCE_EQ(lengths.dim(), 1,
"LENGTHS must be 1-D");
12 CAFFE_ENFORCE_GE(data.dim(), 1,
"DATA should be at least 1-D");
13 CAFFE_ENFORCE_EQ(lengths.numel(), data.size(0));
18 lengths_host_.CopyFrom(lengths);
19 auto lengths_size = lengths_host_.numel();
20 auto* lengths_data = lengths_host_.data<int32_t>();
22 int32_t total_length = 0;
23 CPUContext cpuContext;
24 math::Sum<int32_t, CPUContext>(
25 lengths_size, lengths_data, &total_length, &cpuContext);
27 auto shape = data.sizes().vec();
28 shape[0] = total_length;
29 output->Resize(shape);
31 auto block_bytesize = data.size_from_dim(1) * data.dtype().itemsize();
32 auto src =
static_cast<const char*
>(data.raw_data());
33 auto out =
static_cast<char*
>(output->raw_mutable_data(data.dtype()));
35 for (int64_t i = 0; i < lengths_size; ++i) {
36 auto length = lengths_data[i];
37 CAFFE_ENFORCE_GE(length, 0);
38 for (int32_t j = 0; j < length; ++j) {
39 context_.CopyBytesSameDevice(block_bytesize, src, out);
40 out += block_bytesize;
42 src += block_bytesize;
// Registers LengthsTileOp<CPUContext> as the CPU implementation of the
// "LengthsTile" operator.
47 REGISTER_CPU_OPERATOR(LengthsTile, LengthsTileOp<CPUContext>);
// Operator schema for LengthsTile.
// Documents input 1 (LENGTHS, an int32 tensor of rank 1) and output 0
// (OUTPUT, a tensor of rank r). The description text states that DATA has
// rank r >= 1 and that each outer-most entry of DATA is duplicated according
// to LENGTHS.
// NOTE(review): this copy of the schema appears to be incomplete. The embedded
// line numbers skip (49 -> 53, 55 -> 64, 64 -> 77, 77 -> 79), and the
// description text below is not inside a string literal. Presumably the
// NumInputs/NumOutputs/SetDoc calls, the raw-string delimiters and the opener
// of the DATA input description are missing; restore them from upstream
// before building. TODO confirm.
49 OPERATOR_SCHEMA(LengthsTile)
53 Given DATA tensor of rank r >= 1, and LENGTHS tensor of rank 1, duplicate each 54 entry of the outer-most dimension of DATA according to LENGTHS, and concatenate 55 them in an output tensor of rank r. 64 LENGTHS = [0, 1, 3, 2] 77 "Tensor of rank r >= 1. First dimension must be equal to the size of " 79 .Input(1,
"LENGTHS",
"Tensor of int32 lengths of rank 1")
80 .Output(0,
"OUTPUT",
"Tensor of rank r");
// Gradient maker for LengthsTile.
// The forward op must have exactly two inputs (DATA and LENGTHS). The single
// generated gradient op consumes the output gradient GO(0) together with the
// forward LENGTHS input I(1). It produces only the gradient for DATA, GI(0);
// no gradient is emitted for LENGTHS.
// NOTE(review): this copy is incomplete. The embedded line numbers skip
// 86 -> 91 -> 93, so the gradient op type (and any other leading arguments to
// SingleGradientDef) are missing, and the method/class closing braces never
// appear. Presumably the gradient op sums the tiled GO(0) rows back per
// LENGTHS segment; TODO confirm against upstream.
82 class GetLengthsTileGradient :
public GradientMakerBase {
83 using GradientMakerBase::GradientMakerBase;
84 vector<OperatorDef> GetGradientDefs()
override {
85 CAFFE_ENFORCE_EQ(def_.input_size(), 2);
86 return SingleGradientDef(
91 vector<string>{GO(0), I(1)},
93 vector<string>{GI(0)});
// Registers GetLengthsTileGradient as the gradient maker for the LengthsTile
// operator.
96 REGISTER_GRADIENT(LengthsTile, GetLengthsTileGradient);
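// NOTE(review): stray, truncated text that is not part of this operator
// (presumably residue from the page this file was copied from):
// "A global dictionary that holds information about what Caffe2 modules have been loaded in the current ..."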