Caffe2 - C++ API
A deep learning, cross platform ML framework
shape_op.cc
1 #include <caffe2/ideep/ideep_utils.h>
2 
3 namespace caffe2 {
4 
5 // RecordShapeOp records the shape of the input tensor to a vector of int. You
6 // mostly don't need this operator explicitly, and it is mostly used in the
7 // autodiff process.
8 class IDEEPShapeOp : public IDEEPOperator {
9  public:
10  USE_IDEEP_DEF_ALIASES();
11  USE_IDEEP_OPERATOR_FUNCTIONS();
12 
13  IDEEPShapeOp(const OperatorDef& operator_def, Workspace* ws)
14  : IDEEPOperator(operator_def, ws),
15  axes_(OperatorBase ::GetRepeatedArgument<int>("axes")) {}
16 
17  bool RunOnDevice() override {
18  int numDims = 0;
19  int numAxes = axes_.size();
20  vector<int64_t> dims;
21  const char* data_dims = nullptr;
22  auto* output = OperatorBase::Output<Tensor>(OUTPUT, CPU);
23 
24  if (OperatorBase::InputBlob(DATA).template IsType<itensor>()) {
25  auto& data = Input(DATA);
26  numDims = data.ndims();
27  auto idims = data.get_dims();
28  dims.assign(idims.begin(), idims.end());
29  data_dims = reinterpret_cast<const char*>(dims.data());
30  } else {
31  auto& data = OperatorBase::Input<Tensor>(DATA, CPU);
32  numDims = data.dim();
33  data_dims = reinterpret_cast<const char*>(data.sizes().data());
34  }
35 
36  if (numAxes == 0) {
37  output->Resize(numDims);
38  int64_t* output_data = output->template mutable_data<int64_t>();
39  context_.CopyBytesSameDevice(
40  numDims * sizeof(int64_t), data_dims, output_data);
41  return true;
42  }
43 
44  output->Resize(numAxes);
45  auto out = reinterpret_cast<char*>(output->template mutable_data<int64_t>());
46  for (int i = 0; i < numAxes; i++) {
47  auto axis = axes_[i];
48  CAFFE_ENFORCE_LT(axis, numDims, "Axis out of range");
49  CAFFE_ENFORCE_GE(axis, 0, "Each axis should be non-negative");
50  context_.CopyBytesSameDevice(
51  sizeof(int64_t), data_dims + axis * sizeof(int64_t), out);
52  out += sizeof(int64_t);
53  }
54 
55  return true;
56  }
57 
58  private:
59  vector<int> axes_;
60 
61  INPUT_TAGS(DATA);
62  OUTPUT_TAGS(OUTPUT);
63 };
64 
65 
66 REGISTER_IDEEP_OPERATOR(Shape, IDEEPShapeOp);
67 
68 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13