3 #include "caffe2/core/context.h" 4 #include "caffe2/core/logging.h" 5 #include "caffe2/core/operator.h" 6 #include "caffe2/operators/filler_op.h" 7 #include "caffe2/utils/cast.h" 8 #include "caffe2/utils/math.h" 12 template <
typename T,
class Context>
15 USE_OPERATOR_CONTEXT_FUNCTIONS;
23 if (!std::is_same<T, float>::value || !helper.HasArgument(
"dtype")) {
26 auto dtype = cast::GetCastDataType(helper,
"dtype");
28 case TensorProto_DataType_FLOAT:
29 ExtractValues<float>();
31 case TensorProto_DataType_DOUBLE:
32 ExtractValues<double>();
34 case TensorProto_DataType_BOOL:
35 ExtractValues<bool>();
37 case TensorProto_DataType_INT32:
40 case TensorProto_DataType_INT64:
41 ExtractValues<int64_t>();
43 case TensorProto_DataType_STRING:
44 ExtractValues<std::string>();
46 case TensorProto_DataType_UNDEFINED:
47 CAFFE_THROW(
"Cannot have undefined 'dtype' argument");
49 CAFFE_THROW(
"Unexpected 'dtype' argument value: ", dtype);
54 bool Fill(
Tensor* output)
override {
55 return (this->*body_)(output);
59 template <
typename Type>
60 void ExtractValues() {
62 this->
template GetRepeatedArgument<Type>(
"values");
63 ReinitializeTensor(&values_, {
static_cast<int64_t
>(source_values.size())}, at::dtype<Type>().device(CPU));
64 Type* values_data = values_.template mutable_data<Type>();
65 for (
int i = 0; i < source_values.size(); i++) {
66 values_data[i] =
static_cast<Type>(source_values[i]);
68 body_ = &GivenTensorFillOp::FillWithType<Type>;
71 template <
typename Type>
72 bool FillWithType(
Tensor* output) {
73 DCHECK_EQ(output->numel(), values_.numel())
74 <<
"output size: " << output->numel()
75 <<
" given size: " << values_.numel();
76 auto* data = output->template mutable_data<Type>();
77 const Type* values_data = values_.template data<Type>();
78 if (output->numel()) {
79 context_.CopyItemsFromCPU(
80 TypeMeta::Make<Type>(), output->numel(), values_data, data);
void ReinitializeTensor(Tensor *tensor, at::IntArrayRef dims, at::TensorOptions options)
Reinitialize a Tensor to the given dims and options if necessary; note that this will not do anything if ...
A helper class to index into arguments.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...