Caffe2 - C++ API
A deep learning, cross platform ML framework
expand_dims_cpu.cc
#include <algorithm>
#include <vector>

#include <ATen/core/dispatch/KernelRegistration.h>
#include "caffe2/operators/experimental/c10/schemas/expand_dims.h"
#include "caffe2/utils/math.h"
#include "caffe2/core/tensor.h"
5 
6 using caffe2::Tensor;
7 
8 namespace caffe2 {
9 namespace {
10 
11 struct Cache final : public c10::KernelCache {
12  std::vector<int64_t> dims;
13  bool initialized = false;
14 };
15 
16 template <class DataType>
17 void expand_dims_op_cpu_impl(
18  const at::Tensor& input_,
19  const at::Tensor& output_,
20  ArrayRef<int64_t> dims,
21  Cache* cache) {
22  Tensor input{C10Tensor(input_)};
23  Tensor output{C10Tensor(output_)};
24 
25  if (!cache->initialized) {
26  cache->dims = dims.vec();
27  auto originalSize = cache->dims.size();
28  CAFFE_ENFORCE(originalSize > 0, "Parameter `dims` must be provided.");
29  std::sort(cache->dims.begin(), cache->dims.end());
30  cache->dims.erase(
31  std::unique(cache->dims.begin(), cache->dims.end()), cache->dims.end());
32  if (cache->dims.size() < originalSize) {
33  LOG(WARNING) << "Parameter `dims` has repeated dimensions.";
34  }
35  CAFFE_ENFORCE(
36  cache->dims.front() >= 0, "Dimension ids must be non-negative.");
37  cache->initialized = true;
38  }
39 
40  output.CopyFrom(input);
41  if (cache->dims.empty()) {
42  return;
43  }
44 
45  auto newDims = input.sizes().vec();
46  CAFFE_ENFORCE_GE(
47  input.sizes().size() + cache->dims.size(),
48  cache->dims.back() + 1,
49  "Input needs at least ",
50  (1 + cache->dims.back() - cache->dims.size()),
51  " dimensions given `dims`.");
52  for (const auto dim : cache->dims) {
53  newDims.insert(newDims.begin() + dim, 1);
54  }
55  output.Reshape(newDims);
56 }
57 } // namespace
58 } // namespace caffe2
59 
60 namespace c10 {
61 C10_REGISTER_KERNEL(caffe2::ops::ExpandDims)
62  .withCache<caffe2::Cache>()
63  .kernel<decltype(caffe2::expand_dims_op_cpu_impl<float>), &caffe2::expand_dims_op_cpu_impl<float>>()
64  .dispatchKey(CPUTensorId());
65 } // namespace c10
Tensor class holds a shared pointer to the implementation TensorImpl and redirects API calls to TensorImpl.
Definition: tensor.h:25
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...
Definition: alias_info.h:7
A kernel can keep around a cache to have better performance when it's called multiple times...
Definition: KernelCache.h:15