3 #include "caffe2/core/operator.h" 4 #include "caffe2/utils/eigen_utils.h" 5 #include "caffe2/utils/math.h" 9 template <
typename T,
class Context>
12 USE_OPERATOR_CONTEXT_FUNCTIONS;
14 template <
class... Args>
17 min_(std::numeric_limits<T>::lowest()),
18 max_(std::numeric_limits<T>::max()) {
20 min_ =
static_cast<T>(this->
template GetSingleArgument<float>(
"min", 0));
23 max_ =
static_cast<T>(this->
template GetSingleArgument<float>(
"max", 0));
27 bool RunOnDevice()
override {
28 if (InputSize() > INDICES) {
31 Input(PARAM).size_from_dim(1),
32 Input(GRAD).size_from_dim(
Input(INDICES).dim()));
34 this,
Input(INDICES));
36 auto& X =
Input(PARAM);
38 auto* Y = Output(OUTPUT_PARAM, X.sizes(), at::dtype<float>());
39 EigenVectorMap<float>(Y->template mutable_data<float>(), Y->numel()) =
40 ConstEigenVectorMap<float>(X.template data<float>(), X.numel())
47 template <
typename SIndex>
53 INPUT_TAGS(PARAM, INDICES, GRAD);
54 OUTPUT_TAGS(OUTPUT_PARAM);
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.