Caffe2 - C++ API
A deep-learning, cross-platform ML framework
averaged_loss_cpu.cc
1 #include <ATen/core/dispatch/KernelRegistration.h>
2 #include "caffe2/operators/experimental/c10/schemas/averaged_loss.h"
3 #include "caffe2/utils/math.h"
4 #include "caffe2/core/tensor.h"
5 
7 using caffe2::Tensor;
8 using std::vector;
9 
10 namespace caffe2 {
11 namespace {
12 
13 struct Cache final : public c10::KernelCache {
14  at::Tensor scratch = at::Tensor(C10Tensor(empty({}, CPU)));
15 };
16 
17 template <class T, class Context>
18 void averaged_loss_op_cpu_impl(
19  const at::Tensor& X_,
20  const at::Tensor& sum_,
21  Cache* state) {
22  Tensor X{C10Tensor(X_)};
23  Tensor sum{C10Tensor(sum_)};
24  CPUContext context;
25 
26  sum.Resize(vector<int64_t>());
27 
28  T* data = sum.template mutable_data<T>();
29 
30  Tensor scratch(state->scratch);
31  caffe2::math::Sum<T, Context>(
32  X.numel(),
33  X.template data<T>(),
34  data,
35  static_cast<Context*>(&context),
36  &scratch);
37  if (X.numel() > 0) {
38  caffe2::math::Scale<T, T, Context>(
39  1,
40  static_cast<T>(1.) / X.numel(),
41  sum.template data<T>(),
42  data,
43  static_cast<Context*>(&context));
44  }
45 }
46 } // namespace
47 } // namespace caffe2
48 
49 namespace c10 {
50 C10_REGISTER_KERNEL(caffe2::ops::AveragedLoss)
51  .withCache<caffe2::Cache>()
52  .kernel<decltype(caffe2::averaged_loss_op_cpu_impl<float, caffe2::CPUContext>), &caffe2::averaged_loss_op_cpu_impl<float, caffe2::CPUContext>>()
53  .dispatchKey(CPUTensorId());
54 } // namespace c10
Tensor class holds a shared pointer to the implementation TensorImpl, redirects API calls to TensorIm...
Definition: tensor.h:25
Virtual interface for the Context class in Caffe2.
Definition: context_base.h:32
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...
Definition: alias_info.h:7
A kernel can keep a cache around to improve performance when it's called multiple times...
Definition: KernelCache.h:15