1 #ifndef CAFFE2_OPERATORS_FILLER_OP_H_ 2 #define CAFFE2_OPERATORS_FILLER_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/logging.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/utils/math.h" 20 template <
class Context>
23 template <
class... Args>
26 shape_(this->
template GetRepeatedArgument<int64_t>(
"shape")),
28 this->
template GetRepeatedArgument<int>(
"extra_shape"))),
30 this->
template GetSingleArgument<bool>(
"input_as_shape",
false)) {
32 if (shape_.size() != 0) {
34 "Cannot set the shape argument and pass in an input at " 38 if (!extra_shape_.empty()) {
39 CAFFE_THROW(
"Cannot set extra_shape when there is no input");
41 if (input_as_shape_) {
42 CAFFE_THROW(
"An input must be given if input_as_shape is true");
44 if (shape_.size() == 0 &&
45 this->
template HasSingleArgumentOfType<int>(
"shape")) {
46 CAFFE_THROW(
"Fill 'shape' argument was a scalar, list expected");
51 virtual ~FillerOp() {}
52 USE_OPERATOR_CONTEXT_FUNCTIONS;
54 bool RunOnDevice()
override {
57 auto shape = vector<int64_t>{};
58 if (input_as_shape_) {
59 if (this->InputIsTensorType(0, CPU)) {
61 auto& input = this->
template Input<Tensor>(0, CPU);
65 "When input_as_shape is true, the input must be a 1D tensor of " 67 CAFFE_ENFORCE(input.numel() > 0);
68 auto* shape_data = input.template data<int64_t>();
69 shape.insert(shape.end(), shape_data, shape_data + input.dim32(0));
72 auto& input =
Input(0);
76 "When input_as_shape is true, the input must be a 1D tensor of " 78 CAFFE_ENFORCE(input.numel() > 0);
79 auto* shape_data = input.template data<int64_t>();
80 std::unique_ptr<int64_t[]> shape_data_copy = caffe2::make_unique<int64_t[]>(input.dim32(0));
81 context_.template CopyToCPU<int64_t>(input.dim32(0), shape_data, shape_data_copy.get());
82 shape.insert(shape.end(), shape_data_copy.get(), shape_data_copy.get() + input.dim32(0));
85 auto& input =
Input(0);
86 shape.insert(shape.end(), input.sizes().begin(), input.sizes().end());
88 shape.insert(shape.end(), extra_shape_.begin(), extra_shape_.end());
89 output->Resize(shape);
91 output->Resize(shape_);
96 virtual bool Fill(
Tensor* output) = 0;
99 vector<int64_t> shape_;
100 vector<int64_t> extra_shape_;
101 bool input_as_shape_;
104 template <
typename T,
class Context>
107 USE_OPERATOR_CONTEXT_FUNCTIONS;
108 template <
class... Args>
111 min_(this->
template GetSingleArgument<T>(
"min", 0)),
112 max_(this->
template GetSingleArgument<T>(
"max", 1)) {
113 if (InputSize() == 3) {
115 !this->
template HasSingleArgumentOfType<T>(
"min"),
116 "Cannot set both min arg and min input blob");
118 !this->
template HasSingleArgumentOfType<T>(
"max"),
119 "Cannot set both max arg and max input blob");
122 min_, max_,
"Max value should be bigger than min value.");
126 bool Fill(
Tensor* output)
override {
129 if (InputSize() == 3) {
130 CAFFE_ENFORCE_EQ(1,
Input(1).numel(),
"min blob must be scalar");
131 CAFFE_ENFORCE_EQ(1,
Input(2).numel(),
"max blob must be scalar");
132 min = *
Input(1).template data<T>();
133 max = *
Input(2).template data<T>();
135 auto shape = output->sizes().vec();
137 output->Resize(shape);
138 output->template mutable_data<T>();
142 math::RandUniform<T, Context>(
146 output->template mutable_data<T>(),
156 template <
class Context>
159 USE_OPERATOR_CONTEXT_FUNCTIONS;
160 template <
class... Args>
163 TensorProto_DataType dtype =
164 static_cast<TensorProto_DataType
>(this->
template GetSingleArgument<int>(
165 "dtype", TensorProto_DataType_INT32));
168 case TensorProto_DataType_INT32:
170 body_ = &UniqueUniformFillOp::FillWithType<int>;
172 case TensorProto_DataType_INT64:
173 CheckRange<int64_t>();
174 body_ = &UniqueUniformFillOp::FillWithType<int64_t>;
176 case TensorProto_DataType_UNDEFINED:
178 "UniqueUniformFill op cannot have undefined 'dtype' argument");
181 CAFFE_THROW(
"Unexpected 'dtype' argument value: ", dtype);
185 bool Fill(
Tensor* output)
override {
186 return (this->*body_)(output);
190 template <
typename T>
192 CAFFE_ENFORCE(this->
template HasSingleArgumentOfType<T>(
"min"));
193 CAFFE_ENFORCE(this->
template HasSingleArgumentOfType<T>(
"max"));
195 this->
template GetSingleArgument<T>(
"min", 0),
196 this->
template GetSingleArgument<T>(
"max", 0),
197 "Max value should be bigger than min value.");
200 template <
typename T>
201 bool FillWithType(
Tensor* output) {
202 T min = this->
template GetSingleArgument<T>(
"min", 0);
203 T max = this->
template GetSingleArgument<T>(
"max", 0);
205 const T* avoid_data =
nullptr;
206 size_t avoid_size = 0;
207 if (InputSize() >= 2) {
208 auto& avoid =
Input(1);
209 avoid_data = avoid.template data<T>();
210 avoid_size = avoid.numel();
212 math::RandUniformUnique<T, Context>(
216 output->template mutable_data<T>(),
223 bool (UniqueUniformFillOp::*body_)(
Tensor* output);
226 template <
class Context>
229 USE_OPERATOR_CONTEXT_FUNCTIONS;
230 template <
class... Args>
233 TensorProto_DataType dtype =
234 static_cast<TensorProto_DataType
>(this->
template GetSingleArgument<int>(
235 "dtype", TensorProto_DataType_FLOAT));
241 if (this->
template HasSingleArgumentOfType<float>(
"value")) {
242 dtype = TensorProto_DataType_FLOAT;
243 }
else if (this->
template HasSingleArgumentOfType<int64_t>(
"value")) {
244 dtype = TensorProto_DataType_INT64;
246 CAFFE_THROW(
"Argument 'value' is of unexpected type");
248 VLOG(1) <<
"Argument 'dtype' is not provided. Assume the data type is " 249 <<
"the same as that of argument 'value': " << dtype;
253 case TensorProto_DataType_FLOAT:
254 body_ = &ConstantFillOp::FillWithType<float>;
256 case TensorProto_DataType_DOUBLE:
257 body_ = &ConstantFillOp::FillWithType<double>;
259 case TensorProto_DataType_BOOL:
260 body_ = &ConstantFillOp::FillWithType<bool>;
262 case TensorProto_DataType_INT8:
263 body_ = &ConstantFillOp::FillWithType<int8_t>;
265 case TensorProto_DataType_INT16:
266 body_ = &ConstantFillOp::FillWithType<int16_t>;
268 case TensorProto_DataType_INT32:
269 body_ = &ConstantFillOp::FillWithType<int>;
271 case TensorProto_DataType_INT64:
272 body_ = &ConstantFillOp::FillWithType<int64_t>;
274 case TensorProto_DataType_UINT8:
275 body_ = &ConstantFillOp::FillWithType<uint8_t>;
277 case TensorProto_DataType_UINT16:
278 body_ = &ConstantFillOp::FillWithType<uint16_t>;
280 case TensorProto_DataType_STRING:
281 body_ = &ConstantFillOp::FillWithString;
283 case TensorProto_DataType_UNDEFINED:
284 CAFFE_THROW(
"ConstantFill op cannot have undefined 'dtype' argument");
287 CAFFE_THROW(
"Unexpected 'dtype' argument value: ", dtype);
291 bool Fill(
Tensor* output)
override {
292 return (this->*body_)(output);
295 template <
typename T>
296 bool FillWithType(
Tensor* output) {
297 T value = this->
template GetSingleArgument<T>(
"value", 0);
298 auto* data = output->template mutable_data<T>();
299 if (output->numel()) {
300 math::Set<T, Context>(output->numel(), value, data, &context_);
305 bool FillWithString(
Tensor* output) {
306 auto value = this->
template GetSingleArgument<std::string>(
"value",
"");
307 auto* data = output->template mutable_data<std::string>();
308 for (
int i = 0; i < output->numel(); ++i) {
315 bool (ConstantFillOp::*body_)(
Tensor* output);
318 template <
class Context>
321 USE_OPERATOR_CONTEXT_FUNCTIONS;
322 template <
class... Args>
325 TensorProto_DataType dtype =
326 static_cast<TensorProto_DataType
>(this->
template GetSingleArgument<int>(
327 "dtype", TensorProto_DataType_FLOAT));
333 if (this->
template HasSingleArgumentOfType<float>(
"value")) {
334 dtype = TensorProto_DataType_FLOAT;
335 }
else if (this->
template HasSingleArgumentOfType<int64_t>(
"value")) {
336 dtype = TensorProto_DataType_INT64;
338 CAFFE_THROW(
"Argument 'value' is of unexpected type");
340 VLOG(1) <<
"Argument 'dtype' is not provided. Assume the data type is " 341 <<
"the same as that of argument 'value': " << dtype;
345 case TensorProto_DataType_FLOAT:
346 body_ = &DiagonalFillOp::FillWithType<float>;
348 case TensorProto_DataType_DOUBLE:
349 body_ = &DiagonalFillOp::FillWithType<double>;
351 case TensorProto_DataType_BOOL:
352 body_ = &DiagonalFillOp::FillWithType<bool>;
354 case TensorProto_DataType_INT8:
355 body_ = &DiagonalFillOp::FillWithType<int8_t>;
357 case TensorProto_DataType_INT16:
358 body_ = &DiagonalFillOp::FillWithType<int16_t>;
360 case TensorProto_DataType_INT32:
361 body_ = &DiagonalFillOp::FillWithType<int>;
363 case TensorProto_DataType_INT64:
364 body_ = &DiagonalFillOp::FillWithType<int64_t>;
366 case TensorProto_DataType_UINT8:
367 body_ = &DiagonalFillOp::FillWithType<uint8_t>;
369 case TensorProto_DataType_UINT16:
370 body_ = &DiagonalFillOp::FillWithType<uint16_t>;
372 case TensorProto_DataType_UNDEFINED:
373 CAFFE_THROW(
"Cannot have undefined 'dtype' argument");
375 CAFFE_THROW(
"Unexpected 'dtype' argument value: ", dtype);
379 bool Fill(
Tensor* output)
override {
380 return (this->*body_)(output);
383 template <
typename T>
384 bool FillWithType(
Tensor* output);
387 void VerifyOutputShape(
Tensor* output) {
388 CAFFE_ENFORCE(output->dim() >= 2,
"Input shape must be >= 2D");
391 int64_t GetStepSize(
Tensor* output) {
393 if (output->dim() == 2) {
394 step = output->size(1) + 1;
396 int64_t prev_i = output->size(0);
397 for (
auto i : output->sizes()) {
399 CAFFE_THROW(
"All dimensions of input must be of equal length");
402 vector<int64_t> cumprod(output->dim());
403 auto dims = output->sizes();
408 std::multiplies<int64_t>());
411 cumprod.begin(), cumprod.end(),
static_cast<int64_t
>(0));
417 bool (DiagonalFillOp::*body_)(
Tensor* output);
420 template <
typename T,
class Context>
423 USE_OPERATOR_CONTEXT_FUNCTIONS;
424 template <
class... Args>
427 mean_(this->
template GetSingleArgument<float>(
"mean", 0)),
428 std_(this->
template GetSingleArgument<float>(
"std", 1)) {
429 DCHECK_GT(std_, 0) <<
"Standard deviation should be nonnegative.";
432 bool Fill(
Tensor* output)
override {
433 math::RandGaussian<T, Context>(
437 output->template mutable_data<T>(),
447 template <
typename T,
class Context>
450 USE_OPERATOR_CONTEXT_FUNCTIONS;
451 template <
class... Args>
455 bool Fill(
Tensor* output)
override {
456 const int fan_in = output->numel() / output->dim32(0);
457 T scale = std::sqrt(
T(3) / fan_in);
458 math::RandUniform<T, Context>(
462 output->template mutable_data<T>(),
468 template <
typename T,
class Context>
471 USE_OPERATOR_CONTEXT_FUNCTIONS;
472 template <
class... Args>
476 bool Fill(
Tensor* output)
override {
477 const int fan_out = output->numel() / output->dim32(1);
478 T scale = std::sqrt(
T(2) / fan_out);
479 math::RandGaussian<T, Context>(
483 output->template mutable_data<T>(),
492 template <
typename T,
class Context>
495 USE_OPERATOR_CONTEXT_FUNCTIONS;
496 template <
class... Args>
500 bool Fill(
Tensor* output)
override;
503 template <
class Context>
506 USE_OPERATOR_CONTEXT_FUNCTIONS;
509 bool RunOnDevice()
override {
510 auto& input =
Input(0);
512 auto* input_data = input.template data<int32_t>();
514 CAFFE_ENFORCE_EQ(input.dim(), 1,
"Input must be a vector.");
516 auto len_sum = std::accumulate(input_data, input_data + input.numel(), 0);
518 auto* output = Output(0, {len_sum}, at::dtype<int32_t>());
519 auto* output_data = output->template mutable_data<int32_t>();
522 for (
int i = 0; i < input.numel(); ++i) {
523 auto len = input_data[i];
524 auto start = output_data + offset;
535 template <
int VALUE_TYPE = TensorProto_DataType_FLOAT>
536 inline std::vector<TensorShape> FillerTensorInference(
537 const OperatorDef& def,
538 const vector<TensorShape>& in) {
539 vector<TensorShape> out(1);
541 out[0].set_data_type(static_cast<TensorProto_DataType>(
542 helper.GetSingleArgument<
int>(
"dtype", VALUE_TYPE)));
546 bool input_as_shape =
547 helper.GetSingleArgument<
bool>(
"input_as_shape",
false);
548 if (input_as_shape) {
549 out[0].set_unknown_shape(
true);
552 for (
auto d : in[0].dims()) {
556 auto shape = helper.GetRepeatedArgument<int64_t>(
"shape");
557 for (
auto d : shape) {
566 #endif // CAFFE2_OPERATORS_FILLER_OP_H_
std::vector< int64_t > ToVectorint64_t(ArrayRef< int > src)
A utility function to convert vector<int> to vector<int64_t>.
A helper class to index into arguments.
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.