4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 12 template <
class Context>
15 USE_OPERATOR_CONTEXT_FUNCTIONS;
16 template <
class... Args>
17 explicit ShapeOp(Args&&... args)
19 axes_(OperatorBase ::GetRepeatedArgument<int>(
"axes")) {}
21 bool RunOnDevice()
override {
22 auto& data =
Input(DATA);
24 int numDims = data.dim();
25 int numAxes = axes_.size();
27 auto* output = Output(0, {numDims}, at::dtype<int64_t>());
28 int64_t* output_data = output->template mutable_data<int64_t>();
29 context_.CopyBytesSameDevice(
30 numDims *
sizeof(int64_t), data.sizes().data(), output_data);
34 auto* output = Output(0, {numAxes}, at::dtype<int64_t>());
35 auto src =
reinterpret_cast<const char*
>(data.sizes().data());
36 auto out =
reinterpret_cast<char*
>(output->template mutable_data<int64_t>());
37 for (
int i = 0; i < numAxes; i++) {
39 CAFFE_ENFORCE_LT(axis, numDims,
"Axis out of range");
40 CAFFE_ENFORCE_GE(axis, 0,
"Each axis should be non-negative");
41 context_.CopyBytesSameDevice(
42 sizeof(int64_t), src + axis *
sizeof(int64_t), out);
43 out +=
sizeof(int64_t);
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...