Caffe2 - C++ API
A deep learning, cross-platform ML framework
sparse_normalize_op.h
1 
17 #pragma once
18 
19 #include "caffe2/core/operator.h"
20 #include "caffe2/utils/math.h"
21 
22 namespace caffe2 {
23 
24 template <typename T, class Context>
25 class SparseNormalizeOp final : public Operator<Context> {
26  public:
27  USE_OPERATOR_CONTEXT_FUNCTIONS;
28  SparseNormalizeOp(const OperatorDef& operator_def, Workspace* ws)
29  : Operator<Context>(operator_def, ws),
30  use_max_norm_(
31  OperatorBase::GetSingleArgument<bool>("use_max_norm", true)),
32  norm_(OperatorBase::GetSingleArgument<float>("norm", 1.0)) {
33  CAFFE_ENFORCE_GE(norm_, 0, "norm should be bigger than 0");
34  }
35 
36  bool RunOnDevice() override {
37  CAFFE_ENFORCE_EQ(
38  Input(PARAM).size_from_dim(1),
39  Input(GRAD).size_from_dim(Input(INDICES).ndim()));
40 
42  this, Input(INDICES));
43  }
44 
45  template <typename SIndex>
46  bool DoRunWithType();
47 
48  protected:
49  bool use_max_norm_;
50  float norm_;
51  INPUT_TAGS(PARAM, INDICES, GRAD);
52  OUTPUT_TAGS(OUTPUT_PARAM);
53 };
54 
55 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.