// SparseNormalizeOp<float, CPUContext>::RunOnDevice
//
// Entry point of the float/CPU specialization. It forwards to
// DoRunWithType<SIndex>, with SIndex chosen by DispatchHelper from the
// element type of the INDICES input. Only int32_t and int64_t are listed in
// the TensorTypes set below; other index types are presumably rejected by
// DispatchHelper (not visible here).
//
// NOTE(review): this text is a scraped rendering and some original lines are
// missing. The two sizes computed below are PARAM's elements per row
// (size_from_dim(1)) and GRAD's elements past the leading INDICES dims. They
// are presumably arguments to an enforce/equality check whose opening line is
// not visible here, as is the closing brace of this function. Confirm both
// against the canonical sparse_normalize_op.cc before relying on this block.
1 #include "caffe2/operators/sparse_normalize_op.h" 2 #include "caffe2/core/tensor.h" 3 #include "caffe2/utils/eigen_utils.h" 8 bool SparseNormalizeOp<float, CPUContext>::RunOnDevice() {
10 Input(PARAM).size_from_dim(1),
11 Input(GRAD).size_from_dim(Input(INDICES).dim()));
// Select the index element type at runtime and run the typed kernel.
13 return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
14 this, Input(INDICES));
// SparseNormalizeOp<float, CPUContext>::DoRunWithType<SIndex>
//
// Typed kernel, where SIndex is the element type of INDICES (int32_t or
// int64_t, chosen by RunOnDevice). For each of the n entries of INDICES it:
//   - locates the PARAM row starting at offset idx * block_size;
//   - views that row as an Eigen vector of block_size floats;
//   - computes the row's L2 norm.
//
// If use_max_norm_ is set and the norm is already <= norm_, the row takes the
// special branch below. Otherwise the factor norm_ / (norm + kEps) is passed
// to a call that is not visible in this extract.
//
// NOTE(review): the branch body and the call consuming that factor are
// missing from this scraped text. Presumably:
//   - max-norm rows that are already small enough are left unchanged (the
//     schema's use_max_norm description says so);
//   - other rows are scaled in place into paramOut.
// Confirm against the full file.
18 template <
typename SIndex>
19 bool SparseNormalizeOp<float, CPUContext>::DoRunWithType() {
20 const auto* indices = Input(INDICES).template data<SIndex>();
// The norm is read from paramIn while results go to paramOut. The schema's
// EnforceOneToOneInplace presumably makes these the same buffer, so rows not
// listed in INDICES keep their values. Verify.
21 const auto* paramIn = Input(PARAM).template data<float>();
22 auto* paramOut = Output(OUTPUT_PARAM)->template mutable_data<float>();
// Added to the norm before dividing, so an all-zero row does not cause a
// division by zero in norm_ / (norm + kEps).
23 const float kEps = 1e-12f;
// n: number of rows to normalize, one per entry of INDICES.
26 auto n = Input(INDICES).numel();
// block_size: floats per row, derived from GRAD holding n rows' worth of data.
// NOTE(review): this divides by n. No n == 0 guard is visible in this extract
// (some original lines before this statement are missing). Confirm that an
// early return for empty INDICES exists above this line.
32 auto block_size = Input(GRAD).numel() / n;
// NOTE(review): the loop counter is a plain int, but n comes from numel(),
// which is presumably a 64-bit count. If so, very large index tensors could
// overflow i. Consider decltype(n) / int64_t. Confirm numel()'s return type.
33 for (
int i = 0; i < n; ++i) {
// NOTE(review): idx is used as a row number into PARAM with no visible bounds
// check against PARAM's first dimension. Out-of-range or negative indices
// would read and write outside the tensor.
34 auto idx = indices[i];
35 auto offsetIdx = idx * block_size;
36 ConstEigenVectorMap<float> xVec(paramIn + offsetIdx, block_size);
37 auto norm = xVec.template lpNorm<2>();
// Max-norm mode: rows whose norm is already within norm_ take this branch.
// Its body is not visible here.
39 if (use_max_norm_ && norm <= norm_) {
// Scale factor that would bring the row's L2 norm to (approximately) norm_.
45 norm_ / (norm + kEps),
// Registers SparseNormalizeOp<float, CPUContext> as the CPU implementation of
// the "SparseNormalize" operator and declares its schema:
//   - inputs: param, indices, grad;
//   - output: output_param;
//   - arguments: use_max_norm, norm.
//
// NOTE(review): this is a scraped rendering with lines missing from the chain.
// Examples include lines between OPERATOR_SCHEMA and the first .Input, the
// .Arg( opener and "use_max_norm" name before the long description string,
// and the SetDoc opener before the final doc sentence.
//
// NOTE(review): the last line also contains page text after
// SHOULD_NOT_DO_GRADIENT(SparseNormalize); ("A global dictionary that holds
// information ..."). That text is not C++ and should be dropped when this
// file is restored from the canonical source.
53 REGISTER_CPU_OPERATOR(SparseNormalize, SparseNormalizeOp<float, CPUContext>);
54 OPERATOR_SCHEMA(SparseNormalize)
57 .Input(0,
"param",
"Parameters to be normalized")
58 .Input(1,
"indices",
"Sparse indices")
59 .Input(2,
"grad",
"Gradient computed")
60 .Output(0,
"output_param",
"Normalized parameters")
// Presumably requires output 0 to alias input 0 (in-place update of param).
// The kernel relies on this so that rows not named by indices are preserved.
61 .EnforceOneToOneInplace()
64 "A bool variable to control whether to use max norm \ 65 or constant norm. When use_max_norm = false, constant norm is used so that \ 66 all the embedding vectors are scaled to have a L2 norm equals to A \ 67 (see blow arugment norm=A). If use_max_norm = true, \ 68 max norm is used so that embedding is scaled so that its l2 norm is no larger \ 69 than A. If an embedding's norm is less than A originally, \ 70 the embedding is left unchanged.\ 71 The default is True.")
// NOTE(review): the use_max_norm description above has user-facing typos that
// should be fixed in the canonical file:
//   "blow arugment" -> "below argument"
//   "equals to"     -> "equal to"
//   "l2"            -> "L2"
// They are left untouched here because the string is runtime text.
72 .Arg(
"norm",
"L2 norm of the embedding. The default is 1.0.")
74 Given a sparse matrix, apply max_norm or constant_norm sparse regularization. 77 SHOULD_NOT_DO_GRADIENT(SparseNormalize); A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...