Caffe2 - C++ API
A deep learning, cross platform ML framework
shape_op.h
1 
2 #pragma once
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 
7 namespace caffe2 {
8 
9 // RecordShapeOp records the shape of the input tensor to a vector of int. You
10 // mostly don't need this operator explicitly, and it is mostly used in the
11 // autodiff process.
12 template <class Context>
13 class ShapeOp : public Operator<Context> {
14  public:
15  USE_OPERATOR_CONTEXT_FUNCTIONS;
16  template <class... Args>
17  explicit ShapeOp(Args&&... args)
18  : Operator<Context>(std::forward<Args>(args)...),
19  axes_(OperatorBase ::GetRepeatedArgument<int>("axes")) {}
20 
21  bool RunOnDevice() override {
22  auto& data = Input(DATA);
23 
24  int numDims = data.dim();
25  int numAxes = axes_.size();
26  if (numAxes == 0) {
27  auto* output = Output(0, {numDims}, at::dtype<int64_t>());
28  int64_t* output_data = output->template mutable_data<int64_t>();
29  context_.CopyBytesSameDevice(
30  numDims * sizeof(int64_t), data.sizes().data(), output_data);
31  return true;
32  }
33 
34  auto* output = Output(0, {numAxes}, at::dtype<int64_t>());
35  auto src = reinterpret_cast<const char*>(data.sizes().data());
36  auto out = reinterpret_cast<char*>(output->template mutable_data<int64_t>());
37  for (int i = 0; i < numAxes; i++) {
38  auto axis = axes_[i];
39  CAFFE_ENFORCE_LT(axis, numDims, "Axis out of range");
40  CAFFE_ENFORCE_GE(axis, 0, "Each axis should be non-negative");
41  context_.CopyBytesSameDevice(
42  sizeof(int64_t), src + axis * sizeof(int64_t), out);
43  out += sizeof(int64_t);
44  }
45  return true;
46  }
47 
48  INPUT_TAGS(DATA);
49 
50  private:
51  vector<int> axes_;
52 };
53 
54 } // namespace caffe2
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13