Caffe2 - C++ API
A deep learning, cross platform ML framework
stop_gradient.h
1 
17 #ifndef CAFFE2_OPERATORS_STOP_GRADIENT_H_
18 #define CAFFE2_OPERATORS_STOP_GRADIENT_H_
19 
20 #include "caffe2/core/operator.h"
21 
22 namespace caffe2 {
23 
24 template <class Context>
25 class StopGradientOp : public Operator<Context> {
26  public:
27  USE_SIMPLE_CTOR_DTOR(StopGradientOp)
28  USE_OPERATOR_CONTEXT_FUNCTIONS;
29  bool RunOnDevice() override {
30  const auto& in = Input(0);
31  auto* out = Output(0);
32  if (out != &in) {
33  out->CopyFrom(in, &context_);
34  }
35  return true;
36  }
37 };
38 
39 } // namespace caffe2
40 
41 #endif // CAFFE2_OPERATORS_STOP_GRADIENT_H_
Copyright (c) 2016-present, Facebook, Inc.