Caffe2 - C++ API
A deep learning, cross platform ML framework
caffe2/operators/experimental/c10/cpu/stop_gradient_cpu.cc
#include <ATen/core/dispatch/KernelRegistration.h>
#include "caffe2/operators/experimental/c10/schemas/stop_gradient.h"
#include "caffe2/utils/math.h"
#include "caffe2/core/tensor.h"

using caffe2::BaseContext;
using caffe2::Tensor;

namespace caffe2 {
namespace {
template <class DataType>
void stop_gradient_op_cpu_impl(
    const at::Tensor& input_,
    const at::Tensor& output_) {
  Tensor input{C10Tensor(input_)};
  Tensor output{C10Tensor(output_)};
  if (!output.is_same(input)) {
    output.CopyFrom(input);
  }
}
} // namespace
} // namespace caffe2

namespace c10 {
C10_REGISTER_KERNEL(caffe2::ops::StopGradient)
    .kernel<decltype(caffe2::stop_gradient_op_cpu_impl<float>), &caffe2::stop_gradient_op_cpu_impl<float>>()
    .dispatchKey(CPUTensorId());
} // namespace c10
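
The kernel in this file copies the input into the output only when the two do not already refer to the same tensor. The sketch below illustrates that copy-unless-same pattern without depending on Caffe2. `MiniTensor` and `stop_gradient_forward` are hypothetical names invented for this illustration; the real `caffe2::Tensor::is_same` and `CopyFrom` may differ in detail, so treat this as a rough model under those assumptions rather than the library's implementation.

```cpp
#include <memory>
#include <vector>

// Hypothetical stand-in for a tensor handle (not the Caffe2 type):
// several handles may share one underlying storage buffer.
struct MiniTensor {
  std::shared_ptr<std::vector<float>> storage;

  // Assumption: "same" means both handles point at the same storage.
  bool is_same(const MiniTensor& other) const {
    return storage == other.storage;
  }

  // Overwrites this tensor's contents with a copy of src's contents.
  void CopyFrom(const MiniTensor& src) {
    *storage = *src.storage;
  }
};

// Forward pass modeled on the listing: an identity that skips the copy
// when output already shares storage with input.
void stop_gradient_forward(const MiniTensor& input, MiniTensor& output) {
  if (!output.is_same(input)) {
    output.CopyFrom(input);
  }
}
```

With two handles that share storage the call does nothing; with distinct storage the output ends up holding a copy of the input's values while keeping its own buffer.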
at::Tensor
    Definition: Tensor.h:48
caffe2::Tensor
    Tensor class holds a shared pointer to the implementation TensorImpl, redirects API calls to TensorIm...
    Definition: tensor.h:25
nom::repr::Tensor
    Definition: NeuralNet.h:158
at::BaseContext
    Virtual interface for the Context class in Caffe2.
    Definition: context_base.h:32
caffe2
    A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
    Definition: blob.h:13
c10
    To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...
    Definition: alias_info.h:7
Generated on Thu Mar 21 2019 13:06:18 for Caffe2 - C++ API by Doxygen 1.8.11