Caffe2 - C++ API
A deep-learning, cross-platform ML framework
stop_gradient_cpu.cc
1 #include <ATen/core/dispatch/KernelRegistration.h>
2 #include "caffe2/operators/experimental/c10/schemas/stop_gradient.h"
3 #include "caffe2/utils/math.h"
4 #include "caffe2/core/tensor.h"
5 
7 using caffe2::Tensor;
8 
9 namespace caffe2 {
10 namespace {
11 template <class DataType>
12 void stop_gradient_op_cpu_impl(
13  const at::Tensor& input_,
14  const at::Tensor& output_) {
15  Tensor input{C10Tensor(input_)};
16  Tensor output{C10Tensor(output_)};
17  if (!output.is_same(input)) {
18  output.CopyFrom(input);
19  }
20 }
21 } // namespace
22 } // namespace caffe2
23 
24 namespace c10 {
25 C10_REGISTER_KERNEL(caffe2::ops::StopGradient)
26  .kernel<decltype(caffe2::stop_gradient_op_cpu_impl<float>), &caffe2::stop_gradient_op_cpu_impl<float>>()
27  .dispatchKey(CPUTensorId());
28 } // namespace c10
Tensor class holds a shared pointer to the implementation TensorImpl and redirects API calls to TensorImpl.
Definition: tensor.h:25
Virtual interface for the Context class in Caffe2.
Definition: context_base.h:32
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...
Definition: alias_info.h:7