1 #include <ATen/core/dispatch/KernelRegistration.h> 2 #include "caffe2/operators/experimental/c10/schemas/averaged_loss.h" 3 #include "caffe2/utils/math.h" 4 #include "caffe2/core/tensor.h" 17 template <
class T,
class Context>
// averaged_loss_op_cpu_impl: CPU implementation of the AveragedLoss operator.
// It sums the elements of X and then scales the sum by 1 / X.numel(), so the
// scalar output `sum` ends up holding the mean of X.
// NOTE(review): the parameter list is not visible in this listing. The body
// uses X, sum_, state and context; their exact types are TODO confirm.
18 void averaged_loss_op_cpu_impl(
// Wrap the output sum_ as a caffe2 Tensor. This presumably shares storage with
// sum_, so writes below reach the caller (confirm).
23 Tensor sum{C10Tensor(sum_)};
// An empty dims vector makes the output a 0-dim (scalar) tensor with one element.
26 sum.Resize(vector<int64_t>());
// Writable pointer to the scalar output's storage.
28 T* data = sum.template mutable_data<T>();
// Scratch tensor taken from the kernel's cached state. It is presumably reused
// across calls through the caffe2::Cache attached at registration (confirm).
30 Tensor scratch(state->scratch);
// Reduce X into the scalar. Several argument lines are not visible here:
// presumably N = X.numel(), X's data pointer, the output `data` and `scratch`
// as a temporary buffer (confirm).
31 caffe2::math::Sum<T, Context>(
35 static_cast<Context*>(&context),
// Divide the sum by the element count to turn it into an average. The input
// is the reduced sum; the output argument is not visible here (presumably
// `data`, which would make this an in-place scale).
// NOTE(review): if X.numel() == 0 this divides by zero and yields inf/NaN
// for floating-point T. Confirm that empty inputs cannot reach this point,
// or guard against them.
38 caffe2::math::Scale<T, T, Context>(
40 static_cast<T>(1.) / X.numel(),
41 sum.template data<T>(),
43 static_cast<Context*>(&context));
// Register averaged_loss_op_cpu_impl<float, caffe2::CPUContext> as the kernel
// for caffe2::ops::AveragedLoss, dispatched on CPUTensorId. Only T = float is
// registered by this statement. withCache<caffe2::Cache>() gives the kernel a
// cache that is kept across invocations; it presumably supplies the
// `state->scratch` tensor read by the implementation (confirm).
50 C10_REGISTER_KERNEL(caffe2::ops::AveragedLoss)
51 .withCache<caffe2::Cache>()
52 .kernel<decltype(caffe2::averaged_loss_op_cpu_impl<float, caffe2::CPUContext>), &caffe2::averaged_loss_op_cpu_impl<float, caffe2::CPUContext>>()
53 .dispatchKey(CPUTensorId());
The Tensor class holds a shared pointer to its implementation, TensorImpl, and redirects API calls to TensorImpl...
Virtual interface for the Context class in Caffe2.
A global dictionary that holds information about which Caffe2 modules have been loaded in the current ...
To register your own kernel for an operator, do the following in exactly one .cpp file: C10_REGISTER_KERNEL(OperatorHand...
A kernel can keep a cache around to improve performance when it is called multiple times...