#ifndef CAFFE2_CORE_OPERATOR_GRADIENT_H_
#define CAFFE2_CORE_OPERATOR_GRADIENT_H_

#include "c10/util/Registry.h"
#include "caffe2/core/operator_schema.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/utils/proto_utils.h"

namespace caffe2 {

/**
 * @brief A struct that abstracts on top of dense and sparse blobs.
 *
 * For a dense blob, its gradient name is written into dense_; for a sparse
 * blob, the gradient names are written into indices_ (the sparse indices)
 * and values_ (the corresponding values).
 */
struct GradientWrapper {
  string dense_;
  string indices_;
  string values_;

  inline bool IsDense() const {
    return (dense_.size() != 0);
  }
  inline bool IsSparse() const {
    return (indices_.size() != 0 || values_.size() != 0);
  }
  inline bool IsEmpty() const {
    return (!IsDense() && !IsSparse());
  }
};
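// Illustrative sketch (not part of the original header): the two valid
// non-empty states of a GradientWrapper, assuming a hypothetical blob "W".
//
//   GradientWrapper dense;
//   dense.dense_ = "W_grad";             // IsDense() == true
//
//   GradientWrapper sparse;
//   sparse.indices_ = "W_grad_indices";  // IsSparse() == true
//   sparse.values_ = "W_grad_values";
//
// A wrapper must not be both dense and sparse at once; the GI*/SetDense/
// SetSparse helpers below enforce this invariant.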
/**
 * @brief A struct that holds the gradient operators and related gradient maps.
 */
struct GradientOpsMeta {
  vector<OperatorDef> ops_;
  vector<GradientWrapper> g_input_;

  GradientOpsMeta() {}
  GradientOpsMeta(
      const vector<OperatorDef>& ops,
      const vector<GradientWrapper>& v)
      : ops_(ops), g_input_(v) {}
};
class GradientMakerBase {
 public:
  GradientMakerBase(
      const OperatorDef& def,
      const vector<GradientWrapper>& g_output)
      : def_(def), g_output_(g_output), g_input_(def.input_size()) {}
  virtual ~GradientMakerBase() {}
  virtual bool CopyDeviceOption() const {
    return true;
  }
  virtual bool CopyEngine() const {
    return true;
  }
  virtual bool CopyArguments() const {
    return true;
  }

  virtual void VerifyOp() const {
    auto* schema = OpSchemaRegistry::Schema(def_.type());
    if (schema) {
      CAFFE_ENFORCE(
          schema->Verify(def_),
          "(GradientMaker) Operator def did not pass schema checking: ",
          ProtoDebugString(def_));
    }
  }
  /**
   * @brief Returns the gradient ops meta.
   *
   * If your gradient op generator only uses standard input and output
   * manipulations, you can leave Get() as implemented here and override
   * GetGradientDefs() instead.
   */
  virtual GradientOpsMeta Get() {
    VerifyOp();
    vector<OperatorDef> new_defs = GetGradientDefs();
    for (auto& opdef : new_defs) {
      opdef.set_is_gradient_op(true);
    }
    return GradientOpsMeta(new_defs, g_input_);
  }
  const OperatorDef& Def() const {
    return def_;
  }

 protected:
  virtual vector<OperatorDef> GetGradientDefs() {
    CAFFE_NOT_IMPLEMENTED;
  }
  // Helper functions to return names for the gradient computation.
  // I(idx), O(idx): return the input and output names of the operator.
  // GO(idx): return the name of the gradient for output idx.
  // GI(idx), GI_I(idx), GI_V(idx): return the name of the gradient for
  // input idx, and also register that name into the gradient map to be
  // returned.
  string I(const int i) {
    CAFFE_ENFORCE((i >= 0) && (i < def_.input().size()));
    return def_.input(i);
  }
  string O(const int i) {
    CAFFE_ENFORCE((i >= 0) && (i < def_.output().size()));
    return def_.output(i);
  }
  string GI(const int i) {
    CAFFE_ENFORCE(
        !g_input_.at(i).IsSparse(),
        "Input ",
        def_.input(i),
        " already set to sparse.");
    g_input_.at(i).dense_ = GradientName(def_.input(i));
    return GradientName(def_.input(i));
  }
  string GI_I(const int i) {
    CAFFE_ENFORCE(
        !g_input_.at(i).IsDense(),
        "Input ",
        def_.input(i),
        " already set to dense.");
    g_input_.at(i).indices_ = GradientSliceIndices(def_.input(i));
    return GradientSliceIndices(def_.input(i));
  }
  string GI_V(const int i) {
    CAFFE_ENFORCE(
        !g_input_.at(i).IsDense(),
        "Input ",
        def_.input(i),
        " already set to dense.");
    g_input_.at(i).values_ = GradientSliceValues(def_.input(i));
    return GradientSliceValues(def_.input(i));
  }
  string GO(const int i) {
    CAFFE_ENFORCE(
        g_output_.at(i).IsDense(),
        "Gradient of output ",
        def_.output(i),
        (g_output_.at(i).IsSparse() ? " is sparse (expected dense)."
                                    : " is not provided!"));
    return g_output_.at(i).dense_;
  }
  string GO_I(const int i) {
    CAFFE_ENFORCE(
        g_output_.at(i).IsSparse(),
        "Gradient of output ",
        def_.output(i),
        (g_output_.at(i).IsDense() ? " is dense (expected sparse)."
                                   : " is not provided!"));
    return g_output_.at(i).indices_;
  }
  string GO_V(const int i) {
    CAFFE_ENFORCE(
        g_output_.at(i).IsSparse(),
        "Gradient of output ",
        def_.output(i),
        (g_output_.at(i).IsDense() ? " is dense (expected sparse)."
                                   : " is not provided!"));
    return g_output_.at(i).values_;
  }
  const GradientWrapper& GradOut(int i) {
    return g_output_.at(i);
  }
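  // Illustrative sketch (an assumption, not from the original header): a
  // gradient maker for an op whose output gradient arrives sparse would pair
  // the GO_I/GO_V accessors with GI_I/GI_V on the matching input, e.g.
  //
  //   vector<OperatorDef> GetGradientDefs() override {
  //     return SingleGradientDef(
  //         "FooSparseGradient", "",             // hypothetical op type
  //         vector<string>{I(0), GO_I(0), GO_V(0)},
  //         vector<string>{GI_I(0), GI_V(0)});
  //   }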
  // Functions to add a gradient pair to the gradient map.
  void SetDense(const int i, const string& name) {
    CAFFE_ENFORCE(
        !g_input_.at(i).IsSparse(),
        "Input ",
        def_.input(i),
        " already set to sparse.");
    g_input_.at(i).dense_ = name;
  }
  void SetSparse(const int i, const string& indices, const string& values) {
    CAFFE_ENFORCE(
        !g_input_.at(i).IsDense(),
        "Input ",
        def_.input(i),
        " already set to dense.");
    g_input_.at(i).indices_ = indices;
    g_input_.at(i).values_ = values;
  }
  /**
   * @brief A helper function to create a single operator def, which is
   * usually the case for many simple operators.
   */
  template <class... Args>
  inline static vector<OperatorDef> SingleGradientDef(const Args&... args) {
    return vector<OperatorDef>{CreateOperatorDef(args...)};
  }
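  // Illustrative sketch (an assumption, not from the original header): the
  // arguments are forwarded to CreateOperatorDef, so a typical call looks like
  //
  //   SingleGradientDef(
  //       "FooGradient",                // gradient op type (hypothetical)
  //       "",                           // op name
  //       vector<string>{I(0), GO(0)},  // inputs
  //       vector<string>{GI(0)});       // outputs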
 public:
  /**
   * @brief Returns a map from gradient blob names to the parameters that
   * the gradients are for.
   */
  static CaffeMap<string, string> MatchGradsToParams(const OperatorDef& op) {
    CaffeMap<string, string> m;
    for (auto& out : op.output()) {
      if (IsGradientBlob(out)) {
        m[out] = out.substr(0, out.length() - 5);
      }
    }
    return m;
  }
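  // Illustrative sketch (an assumption, not from the original header): for an
  // op whose outputs are {"W_grad", "b_grad", "out"}, MatchGradsToParams
  // returns {"W_grad" -> "W", "b_grad" -> "b"}; "out" is skipped because it
  // does not end in "_grad".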
 private:
  // Utility functions for gradient name computation. These are kept private
  // to discourage the explicit use of such names.
  static string GradientName(const string& name) {
    return name + "_grad";
  }
  static bool IsGradientBlob(const string& name) {
    return name.length() > 5 && name.find("_grad") == name.length() - 5;
  }
  static string GradientNameToParam(const string& name) {
    CHECK(IsGradientBlob(name));
    return name.substr(0, name.length() - 5);
  }
  static string GradientSliceIndices(const string& name) {
    return name + "_grad_indices";
  }
  static string GradientSliceValues(const string& name) {
    return name + "_grad_values";
  }
 protected:
  // The member variables are protected in case someone wants to write a
  // fully custom Get() function.
  const OperatorDef& def_;
  const vector<GradientWrapper>& g_output_;
  vector<GradientWrapper> g_input_;
};
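// Illustrative sketch (an assumption, not from the original header): a
// typical gradient maker for a hypothetical unary "Foo" operator. It emits a
// single "FooGradient" op that consumes the forward output and the output
// gradient, and produces the input gradient. Registration uses the
// REGISTER_GRADIENT macro defined further below.
//
//   class GetFooGradient : public GradientMakerBase {
//     using GradientMakerBase::GradientMakerBase;
//     vector<OperatorDef> GetGradientDefs() override {
//       return SingleGradientDef(
//           "FooGradient",
//           "",
//           vector<string>{O(0), GO(0)},
//           vector<string>{GI(0)});
//     }
//   };
//   REGISTER_GRADIENT(Foo, GetFooGradient);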
/**
 * @brief A helper class to indicate that the operator does not need gradient
 * computation.
 *
 * Use the macro NO_GRADIENT to register operators that do not have gradients.
 */
class NoGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    return vector<OperatorDef>();
  }
};
/**
 * @brief A helper class to indicate that the operator should have no gradient.
 *
 * Use the macro SHOULD_NOT_DO_GRADIENT to register operators that should not
 * have gradients.
 */
class ThrowInTheTowelIfGradientIsCalled : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  GradientOpsMeta Get() override {
    CAFFE_ENFORCE(
        false, "One should not call gradient for operator ", def_.type(), ".");
  }
};
/**
 * @brief A helper class to indicate that the gradient mechanism is not ready.
 *
 * Use the macro GRADIENT_NOT_IMPLEMENTED_YET to register operators whose
 * gradients are not implemented yet.
 */
class GradientNotImplementedYet : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  GradientOpsMeta Get() override {
    CAFFE_ENFORCE(
        false,
        "Operator ",
        def_.type(),
        " should have a gradient but is not implemented yet.");
  }
};
C10_DECLARE_REGISTRY(
    GradientRegistry,
    GradientMakerBase,
    const OperatorDef&,
    const vector<GradientWrapper>&);
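// Illustrative sketch (an assumption, not from the original header): given a
// registered maker, a gradient is typically obtained through the registry as
//
//   std::unique_ptr<GradientMakerBase> maker(
//       GradientRegistry()->Create(def.type(), def, g_output));
//   GradientOpsMeta meta = maker->Get();
//
// which is roughly what GetGradientForOp (declared below) does.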
#ifdef CAFFE2_NO_GRADIENT_OPS

#define REGISTER_GRADIENT(name, ...) /* No gradients. */
#define REGISTER_GRADIENT_STR(str_name, ...) /* No gradients. */

#else

#define REGISTER_GRADIENT(name, ...) \
  C10_REGISTER_CLASS(GradientRegistry, name, __VA_ARGS__)
#define REGISTER_GRADIENT_STR(str_name, ...) \
  C10_REGISTER_TYPED_CLASS(GradientRegistry, str_name, __VA_ARGS__)

#endif

// NO_GRADIENT means that the operator does not need any gradient computation.
#define NO_GRADIENT(name) REGISTER_GRADIENT(name, NoGradient)

// SHOULD_NOT_DO_GRADIENT means that the operator is not designed to have
// gradient operators. If you attempt to get the gradient, an enforce failure
// will be triggered.
#define SHOULD_NOT_DO_GRADIENT(name) \
  REGISTER_GRADIENT(name, ThrowInTheTowelIfGradientIsCalled)

#define GRADIENT_NOT_IMPLEMENTED_YET(name) \
  REGISTER_GRADIENT(name, GradientNotImplementedYet)

/**
 * @brief Gets the GradientOpsMeta for the given operator def.
 */
GradientOpsMeta GetGradientForOp(
    const OperatorDef& def,
    const vector<GradientWrapper>& g_output);

} // namespace caffe2

#endif // CAFFE2_CORE_OPERATOR_GRADIENT_H_