Caffe2 - C++ API
A deep learning, cross platform ML framework
filler_op.h
1 #ifndef CAFFE2_OPERATORS_FILLER_OP_H_
2 #define CAFFE2_OPERATORS_FILLER_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/logging.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/utils/math.h"
8 
9 namespace caffe2 {
10 
11 // FillerOp takes in either zero or one input.
12 //
13 // If the number of input is 1, the shape will be identical to that of the input
14 // at run time with optional additional dimensions appended at the end as
15 // specified by "extra_shape" argument. In that case the "shape" parameter
16 // should not be set.
17 //
18 // If the number of inputs is 0, the full shape must be provided via "shape"
19 // argument
20 template <class Context>
21 class FillerOp : public Operator<Context> {
22  public:
23  template <class... Args>
24  explicit FillerOp(Args&&... args)
25  : Operator<Context>(std::forward<Args>(args)...),
26  shape_(this->template GetRepeatedArgument<int64_t>("shape")),
27  extra_shape_(ToVectorint64_t(
28  this->template GetRepeatedArgument<int>("extra_shape"))),
29  input_as_shape_(
30  this->template GetSingleArgument<bool>("input_as_shape", false)) {
31  if (InputSize()) {
32  if (shape_.size() != 0) {
33  CAFFE_THROW(
34  "Cannot set the shape argument and pass in an input at "
35  "the same time");
36  }
37  } else {
38  if (!extra_shape_.empty()) {
39  CAFFE_THROW("Cannot set extra_shape when there is no input");
40  }
41  if (input_as_shape_) {
42  CAFFE_THROW("An input must be given if input_as_shape is true");
43  }
44  if (shape_.size() == 0 &&
45  this->template HasSingleArgumentOfType<int>("shape")) {
46  CAFFE_THROW("Fill 'shape' argument was a scalar, list expected");
47  }
48  }
49  }
50 
51  virtual ~FillerOp() {}
52  USE_OPERATOR_CONTEXT_FUNCTIONS;
53 
54  bool RunOnDevice() override {
55  auto* output = Operator<Context>::Output(0);
56  if (InputSize()) {
57  auto shape = vector<int64_t>{};
58  if (input_as_shape_) {
59  if (this->InputIsTensorType(0, CPU)) {
60  // originally, shape input must be in CPU context
61  auto& input = this->template Input<Tensor>(0, CPU);
62  CAFFE_ENFORCE_EQ(
63  input.dim(),
64  1,
65  "When input_as_shape is true, the input must be a 1D tensor of "
66  "data type int64_t");
67  CAFFE_ENFORCE(input.numel() > 0);
68  auto* shape_data = input.template data<int64_t>();
69  shape.insert(shape.end(), shape_data, shape_data + input.dim32(0));
70  } else {
71  // in ONNX case, we allow shape to be in CUDA context
72  auto& input = Input(0);
73  CAFFE_ENFORCE_EQ(
74  input.dim(),
75  1,
76  "When input_as_shape is true, the input must be a 1D tensor of "
77  "data type int64_t");
78  CAFFE_ENFORCE(input.numel() > 0);
79  auto* shape_data = input.template data<int64_t>();
80  std::unique_ptr<int64_t[]> shape_data_copy = caffe2::make_unique<int64_t[]>(input.dim32(0));
81  context_.template CopyToCPU<int64_t>(input.dim32(0), shape_data, shape_data_copy.get());
82  shape.insert(shape.end(), shape_data_copy.get(), shape_data_copy.get() + input.dim32(0));
83  }
84  } else {
85  auto& input = Input(0);
86  shape.insert(shape.end(), input.sizes().begin(), input.sizes().end());
87  }
88  shape.insert(shape.end(), extra_shape_.begin(), extra_shape_.end());
89  output->Resize(shape);
90  } else {
91  output->Resize(shape_);
92  }
93  return Fill(output);
94  }
95 
96  virtual bool Fill(Tensor* output) = 0;
97 
98  protected:
99  vector<int64_t> shape_;
100  vector<int64_t> extra_shape_;
101  bool input_as_shape_;
102 };
103 
104 template <typename T, class Context>
105 class UniformFillOp final : public FillerOp<Context> {
106  public:
107  USE_OPERATOR_CONTEXT_FUNCTIONS;
108  template <class... Args>
109  explicit UniformFillOp(Args&&... args)
110  : FillerOp<Context>(std::forward<Args>(args)...),
111  min_(this->template GetSingleArgument<T>("min", 0)),
112  max_(this->template GetSingleArgument<T>("max", 1)) {
113  if (InputSize() == 3) {
114  CAFFE_ENFORCE(
115  !this->template HasSingleArgumentOfType<T>("min"),
116  "Cannot set both min arg and min input blob");
117  CAFFE_ENFORCE(
118  !this->template HasSingleArgumentOfType<T>("max"),
119  "Cannot set both max arg and max input blob");
120  } else {
121  CAFFE_ENFORCE_LT(
122  min_, max_, "Max value should be bigger than min value.");
123  }
124  }
125 
126  bool Fill(Tensor* output) override {
127  T min = min_;
128  T max = max_;
129  if (InputSize() == 3) {
130  CAFFE_ENFORCE_EQ(1, Input(1).numel(), "min blob must be scalar");
131  CAFFE_ENFORCE_EQ(1, Input(2).numel(), "max blob must be scalar");
132  min = *Input(1).template data<T>();
133  max = *Input(2).template data<T>();
134  if (min > max) {
135  auto shape = output->sizes().vec();
136  shape[0] = 0;
137  output->Resize(shape);
138  output->template mutable_data<T>();
139  return true;
140  }
141  }
142  math::RandUniform<T, Context>(
143  output->numel(),
144  min,
145  max,
146  output->template mutable_data<T>(),
147  &context_);
148  return true;
149  }
150 
151  private:
152  T min_;
153  T max_;
154 };
155 
156 template <class Context>
157 class UniqueUniformFillOp final : public FillerOp<Context> {
158  public:
159  USE_OPERATOR_CONTEXT_FUNCTIONS;
160  template <class... Args>
161  explicit UniqueUniformFillOp(Args&&... args)
162  : FillerOp<Context>(std::forward<Args>(args)...) {
163  TensorProto_DataType dtype =
164  static_cast<TensorProto_DataType>(this->template GetSingleArgument<int>(
165  "dtype", TensorProto_DataType_INT32));
166 
167  switch (dtype) {
168  case TensorProto_DataType_INT32:
169  CheckRange<int>();
170  body_ = &UniqueUniformFillOp::FillWithType<int>;
171  break;
172  case TensorProto_DataType_INT64:
173  CheckRange<int64_t>();
174  body_ = &UniqueUniformFillOp::FillWithType<int64_t>;
175  break;
176  case TensorProto_DataType_UNDEFINED:
177  CAFFE_THROW(
178  "UniqueUniformFill op cannot have undefined 'dtype' argument");
179  // break;
180  default:
181  CAFFE_THROW("Unexpected 'dtype' argument value: ", dtype);
182  }
183  }
184 
185  bool Fill(Tensor* output) override {
186  return (this->*body_)(output);
187  }
188 
189  private:
190  template <typename T>
191  void CheckRange() {
192  CAFFE_ENFORCE(this->template HasSingleArgumentOfType<T>("min"));
193  CAFFE_ENFORCE(this->template HasSingleArgumentOfType<T>("max"));
194  CAFFE_ENFORCE_LT(
195  this->template GetSingleArgument<T>("min", 0),
196  this->template GetSingleArgument<T>("max", 0),
197  "Max value should be bigger than min value.");
198  }
199 
200  template <typename T>
201  bool FillWithType(Tensor* output) {
202  T min = this->template GetSingleArgument<T>("min", 0);
203  T max = this->template GetSingleArgument<T>("max", 0);
204 
205  const T* avoid_data = nullptr;
206  size_t avoid_size = 0;
207  if (InputSize() >= 2) {
208  auto& avoid = Input(1);
209  avoid_data = avoid.template data<T>();
210  avoid_size = avoid.numel();
211  }
212  math::RandUniformUnique<T, Context>(
213  output->numel(),
214  min,
215  max,
216  output->template mutable_data<T>(),
217  avoid_size,
218  avoid_data,
219  &context_);
220  return true;
221  }
222 
223  bool (UniqueUniformFillOp::*body_)(Tensor* output);
224 };
225 
226 template <class Context>
227 class ConstantFillOp final : public FillerOp<Context> {
228  public:
229  USE_OPERATOR_CONTEXT_FUNCTIONS;
230  template <class... Args>
231  explicit ConstantFillOp(Args&&... args)
232  : FillerOp<Context>(std::forward<Args>(args)...) {
233  TensorProto_DataType dtype =
234  static_cast<TensorProto_DataType>(this->template GetSingleArgument<int>(
235  "dtype", TensorProto_DataType_FLOAT));
236 
237  if (!OperatorBase::HasArgument("dtype") &&
238  OperatorBase::HasArgument("value")) {
239  // If 'dtype' is not provided, infer type based on the type of 'value'
240  // Currently, single argument contains either float, int64 or bytes
241  if (this->template HasSingleArgumentOfType<float>("value")) {
242  dtype = TensorProto_DataType_FLOAT;
243  } else if (this->template HasSingleArgumentOfType<int64_t>("value")) {
244  dtype = TensorProto_DataType_INT64;
245  } else {
246  CAFFE_THROW("Argument 'value' is of unexpected type");
247  }
248  VLOG(1) << "Argument 'dtype' is not provided. Assume the data type is "
249  << "the same as that of argument 'value': " << dtype;
250  }
251 
252  switch (dtype) {
253  case TensorProto_DataType_FLOAT:
254  body_ = &ConstantFillOp::FillWithType<float>;
255  break;
256  case TensorProto_DataType_DOUBLE:
257  body_ = &ConstantFillOp::FillWithType<double>;
258  break;
259  case TensorProto_DataType_BOOL:
260  body_ = &ConstantFillOp::FillWithType<bool>;
261  break;
262  case TensorProto_DataType_INT8:
263  body_ = &ConstantFillOp::FillWithType<int8_t>;
264  break;
265  case TensorProto_DataType_INT16:
266  body_ = &ConstantFillOp::FillWithType<int16_t>;
267  break;
268  case TensorProto_DataType_INT32:
269  body_ = &ConstantFillOp::FillWithType<int>;
270  break;
271  case TensorProto_DataType_INT64:
272  body_ = &ConstantFillOp::FillWithType<int64_t>;
273  break;
274  case TensorProto_DataType_UINT8:
275  body_ = &ConstantFillOp::FillWithType<uint8_t>;
276  break;
277  case TensorProto_DataType_UINT16:
278  body_ = &ConstantFillOp::FillWithType<uint16_t>;
279  break;
280  case TensorProto_DataType_STRING:
281  body_ = &ConstantFillOp::FillWithString;
282  break;
283  case TensorProto_DataType_UNDEFINED:
284  CAFFE_THROW("ConstantFill op cannot have undefined 'dtype' argument");
285  // break;
286  default:
287  CAFFE_THROW("Unexpected 'dtype' argument value: ", dtype);
288  }
289  }
290 
291  bool Fill(Tensor* output) override {
292  return (this->*body_)(output);
293  }
294 
295  template <typename T>
296  bool FillWithType(Tensor* output) {
297  T value = this->template GetSingleArgument<T>("value", 0);
298  auto* data = output->template mutable_data<T>();
299  if (output->numel()) {
300  math::Set<T, Context>(output->numel(), value, data, &context_);
301  }
302  return true;
303  }
304 
305  bool FillWithString(Tensor* output) {
306  auto value = this->template GetSingleArgument<std::string>("value", "");
307  auto* data = output->template mutable_data<std::string>();
308  for (int i = 0; i < output->numel(); ++i) {
309  data[i] = value;
310  }
311  return true;
312  }
313 
314  private:
315  bool (ConstantFillOp::*body_)(Tensor* output);
316 };
317 
318 template <class Context>
319 class DiagonalFillOp final : public FillerOp<Context> {
320  public:
321  USE_OPERATOR_CONTEXT_FUNCTIONS;
322  template <class... Args>
323  explicit DiagonalFillOp(Args&&... args)
324  : FillerOp<Context>(std::forward<Args>(args)...) {
325  TensorProto_DataType dtype =
326  static_cast<TensorProto_DataType>(this->template GetSingleArgument<int>(
327  "dtype", TensorProto_DataType_FLOAT));
328 
329  if (!OperatorBase::HasArgument("dtype") &&
330  OperatorBase::HasArgument("value")) {
331  // If 'dtype' is not provided, infer type based on the type of 'value'
332  // Currently, single argument contains either float, int64 or bytes
333  if (this->template HasSingleArgumentOfType<float>("value")) {
334  dtype = TensorProto_DataType_FLOAT;
335  } else if (this->template HasSingleArgumentOfType<int64_t>("value")) {
336  dtype = TensorProto_DataType_INT64;
337  } else {
338  CAFFE_THROW("Argument 'value' is of unexpected type");
339  }
340  VLOG(1) << "Argument 'dtype' is not provided. Assume the data type is "
341  << "the same as that of argument 'value': " << dtype;
342  }
343 
344  switch (dtype) {
345  case TensorProto_DataType_FLOAT:
346  body_ = &DiagonalFillOp::FillWithType<float>;
347  break;
348  case TensorProto_DataType_DOUBLE:
349  body_ = &DiagonalFillOp::FillWithType<double>;
350  break;
351  case TensorProto_DataType_BOOL:
352  body_ = &DiagonalFillOp::FillWithType<bool>;
353  break;
354  case TensorProto_DataType_INT8:
355  body_ = &DiagonalFillOp::FillWithType<int8_t>;
356  break;
357  case TensorProto_DataType_INT16:
358  body_ = &DiagonalFillOp::FillWithType<int16_t>;
359  break;
360  case TensorProto_DataType_INT32:
361  body_ = &DiagonalFillOp::FillWithType<int>;
362  break;
363  case TensorProto_DataType_INT64:
364  body_ = &DiagonalFillOp::FillWithType<int64_t>;
365  break;
366  case TensorProto_DataType_UINT8:
367  body_ = &DiagonalFillOp::FillWithType<uint8_t>;
368  break;
369  case TensorProto_DataType_UINT16:
370  body_ = &DiagonalFillOp::FillWithType<uint16_t>;
371  break;
372  case TensorProto_DataType_UNDEFINED:
373  CAFFE_THROW("Cannot have undefined 'dtype' argument");
374  default:
375  CAFFE_THROW("Unexpected 'dtype' argument value: ", dtype);
376  }
377  }
378 
379  bool Fill(Tensor* output) override {
380  return (this->*body_)(output);
381  }
382 
383  template <typename T>
384  bool FillWithType(Tensor* output);
385 
386  private:
387  void VerifyOutputShape(Tensor* output) {
388  CAFFE_ENFORCE(output->dim() >= 2, "Input shape must be >= 2D");
389  }
390 
391  int64_t GetStepSize(Tensor* output) {
392  int64_t step;
393  if (output->dim() == 2) {
394  step = output->size(1) + 1;
395  } else {
396  int64_t prev_i = output->size(0);
397  for (auto i : output->sizes()) {
398  if (i != prev_i) {
399  CAFFE_THROW("All dimensions of input must be of equal length");
400  }
401  }
402  vector<int64_t> cumprod(output->dim());
403  auto dims = output->sizes();
404  std::partial_sum(
405  dims.begin(),
406  dims.end() - 1,
407  cumprod.begin(),
408  std::multiplies<int64_t>());
409  step = 1 +
410  std::accumulate(
411  cumprod.begin(), cumprod.end(), static_cast<int64_t>(0));
412  VLOG(0) << step;
413  }
414  return step;
415  }
416 
417  bool (DiagonalFillOp::*body_)(Tensor* output);
418 };
419 
420 template <typename T, class Context>
421 class GaussianFillOp final : public FillerOp<Context> {
422  public:
423  USE_OPERATOR_CONTEXT_FUNCTIONS;
424  template <class... Args>
425  explicit GaussianFillOp(Args&&... args)
426  : FillerOp<Context>(std::forward<Args>(args)...),
427  mean_(this->template GetSingleArgument<float>("mean", 0)),
428  std_(this->template GetSingleArgument<float>("std", 1)) {
429  DCHECK_GT(std_, 0) << "Standard deviation should be nonnegative.";
430  }
431 
432  bool Fill(Tensor* output) override {
433  math::RandGaussian<T, Context>(
434  output->numel(),
435  mean_,
436  std_,
437  output->template mutable_data<T>(),
438  &context_);
439  return true;
440  }
441 
442  private:
443  T mean_;
444  T std_;
445 };
446 
447 template <typename T, class Context>
448 class XavierFillOp final : public FillerOp<Context> {
449  public:
450  USE_OPERATOR_CONTEXT_FUNCTIONS;
451  template <class... Args>
452  explicit XavierFillOp(Args&&... args)
453  : FillerOp<Context>(std::forward<Args>(args)...) {}
454 
455  bool Fill(Tensor* output) override {
456  const int fan_in = output->numel() / output->dim32(0);
457  T scale = std::sqrt(T(3) / fan_in);
458  math::RandUniform<T, Context>(
459  output->numel(),
460  -scale,
461  scale,
462  output->template mutable_data<T>(),
463  &context_);
464  return true;
465  }
466 };
467 
468 template <typename T, class Context>
469 class MSRAFillOp final : public FillerOp<Context> {
470  public:
471  USE_OPERATOR_CONTEXT_FUNCTIONS;
472  template <class... Args>
473  explicit MSRAFillOp(Args&&... args)
474  : FillerOp<Context>(std::forward<Args>(args)...) {}
475 
476  bool Fill(Tensor* output) override {
477  const int fan_out = output->numel() / output->dim32(1);
478  T scale = std::sqrt(T(2) / fan_out);
479  math::RandGaussian<T, Context>(
480  output->numel(),
481  0.0,
482  scale,
483  output->template mutable_data<T>(),
484  &context_);
485  return true;
486  }
487 };
488 
489 // This is mostly used just as a debugging purpose stuff: it fills a tensor
490 // sequentially with values 0, 1, 2..., which can then be used to check e.g.
491 // reshape operations by allowing one to read the indices more easily.
492 template <typename T, class Context>
493 class RangeFillOp final : public FillerOp<Context> {
494  public:
495  USE_OPERATOR_CONTEXT_FUNCTIONS;
496  template <class... Args>
497  explicit RangeFillOp(Args&&... args)
498  : FillerOp<Context>(std::forward<Args>(args)...) {}
499 
500  bool Fill(Tensor* output) override;
501 };
502 
503 template <class Context>
504 class LengthsRangeFillOp : public Operator<Context> {
505  public:
506  USE_OPERATOR_CONTEXT_FUNCTIONS;
507  USE_SIMPLE_CTOR_DTOR(LengthsRangeFillOp);
508 
509  bool RunOnDevice() override {
510  auto& input = Input(0);
511 
512  auto* input_data = input.template data<int32_t>();
513 
514  CAFFE_ENFORCE_EQ(input.dim(), 1, "Input must be a vector.");
515 
516  auto len_sum = std::accumulate(input_data, input_data + input.numel(), 0);
517 
518  auto* output = Output(0, {len_sum}, at::dtype<int32_t>());
519  auto* output_data = output->template mutable_data<int32_t>();
520 
521  int32_t offset = 0;
522  for (int i = 0; i < input.numel(); ++i) {
523  auto len = input_data[i];
524  auto start = output_data + offset;
525  std::iota(
526  start,
527  start + len,
528  0); // make the third argument the arg of this operator
529  offset += len;
530  }
531  return true;
532  }
533 };
534 
535 template <int VALUE_TYPE = TensorProto_DataType_FLOAT>
536 inline std::vector<TensorShape> FillerTensorInference(
537  const OperatorDef& def,
538  const vector<TensorShape>& in) {
539  vector<TensorShape> out(1);
540  ArgumentHelper helper(def);
541  out[0].set_data_type(static_cast<TensorProto_DataType>(
542  helper.GetSingleArgument<int>("dtype", VALUE_TYPE)));
543 
544  if (in.size()) {
545  // TODO
546  bool input_as_shape =
547  helper.GetSingleArgument<bool>("input_as_shape", false);
548  if (input_as_shape) {
549  out[0].set_unknown_shape(true);
550  return out;
551  }
552  for (auto d : in[0].dims()) {
553  out[0].add_dims(d);
554  }
555  } else {
556  auto shape = helper.GetRepeatedArgument<int64_t>("shape");
557  for (auto d : shape) {
558  out[0].add_dims(d);
559  }
560  }
561  return out;
562 }
563 
564 } // namespace caffe2
565 
566 #endif // CAFFE2_OPERATORS_FILLER_OP_H_
std::vector< int64_t > ToVectorint64_t(ArrayRef< int > src)
A utility function to convert vector<int> to vector<int64_t>.
Definition: TensorImpl.h:46
A helper class to index into arguments.
Definition: proto_utils.h:200
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:70