Caffe2 - C++ API
A deep learning, cross platform ML framework
tile_op.cc
1 #include "caffe2/operators/tile_op.h"
2 
3 #include <string>
4 
5 namespace caffe2 {
6 
7 template <>
8 bool TileOp<CPUContext>::RunOnDevice() {
9  return DispatchHelper<
10  TensorTypes<std::int32_t, std::int64_t, float, double, std::string>>::
11  call(this, Input(0));
12 }
13 
14 template <>
15 template <>
16 bool TileOp<CPUContext>::DoRunWithType<std::string>() {
17  if (InputSize() > 1) {
18  // We potentially have tiles and/or axis specified as inputs
19  // as well. We will check for them in that order. In other words:
20  // InputSize() == 2: tiles is specified
21  // InputSize() == 3: tiles is specified and axis.
22  // Anything specified as input will override the arguments
23  CAFFE_ENFORCE(
24  Input(1).dim() == 1 && Input(1).numel() == 1,
25  "Input `tiles` should be a vector of size 1.");
26  tiles_ = GetArgFromTensor(Input(1));
27 
28  // Because of a bug in original code, temporarily adds this part to keep
29  // backward compatibility.
30  // TODO(yangxm): Remove this part when prod runtime upgraded with fixed
31  // model config.
32  if (Input(1).IsType<std::int64_t>()) {
33  axis_ = 0;
34  }
35 
36  if (InputSize() > 2) {
37  CAFFE_ENFORCE(
38  Input(2).dim() == 1 && Input(2).numel() == 1,
39  "Input `axis` should be a vector of size 1.");
40  axis_ = GetArgFromTensor(Input(2));
41  } else {
42  CAFFE_ENFORCE(
44  "Argument `axis` is missing and was not specified as input.");
45  }
46  } else {
47  CAFFE_ENFORCE(
49  "Argument `tiles` is missing and was not specified as input.");
50  CAFFE_ENFORCE(
52  "Argument `axis` is missing and was not specified as input.");
53  }
54 
55  const auto& X = Input(0);
56  auto* Y = Output(0);
57  const int axis = X.canonical_axis_index(axis_);
58 
59  // reshape output to be input tiled along the axis
60  std::vector<std::int64_t> Y_dims = X.sizes().vec();
61  Y_dims[axis] *= tiles_;
62  Y->Resize(Y_dims);
63 
64  // size up to (and not including) axis
65  const int outer_size = X.size_to_dim(axis);
66  // size from axis up
67  const int inner_size = X.size_from_dim(axis);
68 
69  const TypeMeta& meta = X.dtype();
70  const int item_size = X.itemsize();
71  const char* X_ptr = reinterpret_cast<const char*>(X.raw_data());
72  char* Y_ptr = reinterpret_cast<char*>(Y->raw_mutable_data(meta));
73  for (int i = 0; i < outer_size; ++i) {
74  for (int t = 0; t < tiles_; ++t) {
75  context_.CopyItemsSameDevice(meta, inner_size, X_ptr, Y_ptr);
76  Y_ptr += inner_size * item_size;
77  }
78  X_ptr += inner_size * item_size;
79  }
80  return true;
81 }
82 
83 REGISTER_CPU_OPERATOR(Tile, TileOp<CPUContext>);
84 REGISTER_CPU_OPERATOR(TileGradient, TileGradientOp<CPUContext>);
85 
86 OPERATOR_SCHEMA(Tile)
87  .NumInputs(1, 3)
88  .NumOutputs(1)
89  .TensorInferenceFunction([](const OperatorDef& def,
90  const std::vector<TensorShape>& in) {
91  std::vector<TensorShape> out(1);
92  out[0] = TensorShape(in[0]);
93  ArgumentHelper helper(def);
94  const std::int32_t tiles =
95  helper.GetSingleArgument<std::int32_t>("tiles", 1);
96  const std::int32_t axis =
97  helper.GetSingleArgument<std::int32_t>("axis", 0);
98  if (in.size() > 1) {
99  // Tile or axis is specified as input; we can't determine
100  // the size
101  out[0].set_unknown_shape(true);
102  } else {
103  const auto canonical_axis =
104  canonical_axis_index_(axis, out[0].dims().size());
105  out[0].set_dims(
106  canonical_axis, out[0].dims().Get(canonical_axis) * tiles);
107  }
108  return out;
109  })
110  .SetDoc(R"DOC(
111 Constructs a tensor by tiling a given tensor along a specified axis. This operation creates a new tensor by replicating the input tensor a number of times specified by the `tiles` argument along the `axis` dimension. The output tensor's `axis` dimension has $(X.dims(axis) * tiles)$ elements.
112 
113 Github Links:
114 - https://github.com/pytorch/pytorch/blob/master/caffe2/operators/tile_op.cc
115 
116 <details>
117 
118 <summary> <b>Example</b> </summary>
119 
120 **Code**
121 
122 ```
123 
124 workspace.ResetWorkspace()
125 
126 op = core.CreateOperator(
127  "Tile",
128  ["X", "tiles", "axis"],
129  ["Y"]
130 )
131 
132 workspace.FeedBlob("X", np.random.randint(10, size=(5,5)))
133 workspace.FeedBlob("tiles", np.array([5]).astype(np.int32))
134 workspace.FeedBlob("axis", np.array([1]).astype(np.int32))
135 print("X:", workspace.FetchBlob("X"))
136 workspace.RunOperatorOnce(op)
137 print("Y:", workspace.FetchBlob("Y"))
138 
139 ```
140 
141 **Result**
142 
143 ```
144 
145 X:
146 [[9 1 7 1 3]
147  [2 3 6 2 5]
148  [0 9 2 6 4]
149  [5 8 1 5 9]
150  [2 0 1 3 7]]
151 Y:
152 [[9 1 7 1 3 9 1 7 1 3 9 1 7 1 3 9 1 7 1 3 9 1 7 1 3]
153  [2 3 6 2 5 2 3 6 2 5 2 3 6 2 5 2 3 6 2 5 2 3 6 2 5]
154  [0 9 2 6 4 0 9 2 6 4 0 9 2 6 4 0 9 2 6 4 0 9 2 6 4]
155  [5 8 1 5 9 5 8 1 5 9 5 8 1 5 9 5 8 1 5 9 5 8 1 5 9]
156  [2 0 1 3 7 2 0 1 3 7 2 0 1 3 7 2 0 1 3 7 2 0 1 3 7]]
157 
158 ```
159 
160 </details>
161 
162 )DOC")
163  .Arg("tiles", "(*int*): number of replicas")
164  .Arg("axis", "(*int*): axis to replicate along")
165  .Input(0, "X", "(*Tensor*): input tensor")
166  .Input(
167  1,
168  "tiles",
169  "(*Tensor`<int>`*): [OPTIONAL] number of replicas (overrides `tiles` argument)")
170  .Input(
171  2,
172  "axis",
173  "(*Tensor`<int>`*): [OPTIONAL] axis to replicate along (overrides `axis` argument)")
174  .Output(0, "Y", "(*Tensor*): output tensor")
175  .InheritOnnxSchema();
176 
177 OPERATOR_SCHEMA(TileGradient).NumInputs(1, 3).NumOutputs(1);
178 
179 namespace {
180 
181 class GetTileGradient : public GradientMakerBase {
182  using GradientMakerBase::GradientMakerBase;
183  std::vector<OperatorDef> GetGradientDefs() override {
184  // Check whether the tiles/axis information was
185  // passed through input arguments
186  std::vector<std::string> g_inputs({GO(0)});
187  if (Def().input_size() > 1) {
188  g_inputs.push_back(I(1));
189  }
190  if (Def().input_size() > 2) {
191  g_inputs.push_back(I(2));
192  }
193  return SingleGradientDef(
194  "TileGradient", "", g_inputs, std::vector<std::string>{GI(0)});
195  }
196 };
197 
198 } // namespace
199 
200 REGISTER_GRADIENT(Tile, GetTileGradient);
201 
202 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:70