Caffe2 - C++ API
A deep learning, cross-platform ML framework
adjust_batch_op.h
#pragma once

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"

namespace caffe2 {

template <class Context>
class AdjustBatchOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit AdjustBatchOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        max_batch_size_(
            this->template GetSingleArgument<int64_t>("max_batch_size", -1)) {}

  bool RunOnDevice() override {
    auto& input = Input(0);
    vector<int64_t> output_dims(input.sizes().vec());
    CAFFE_ENFORCE(!output_dims.empty());
    if (InputSize() > 1) {
      // A second input holds the real batch size: shrink the output to it by
      // copying only the first real_batch_size rows of the padded input.
      // TODO: if we have a second input and we have max_batch_size set, check
      // the batch size of the two inputs for consistency
      auto& batch_size = Input(1);
      int64_t real_batch_size = *batch_size.template data<int64_t>();
      int64_t max_batch_size = output_dims[0];
      CAFFE_ENFORCE_GE(max_batch_size, real_batch_size);
      output_dims[0] = real_batch_size;
      auto* output = Output(0, output_dims, input.dtype());
      this->context_.template CopyItems<Context, Context>(
          input.dtype(),
          input.numel() * real_batch_size / max_batch_size,
          input.raw_data(),
          output->raw_mutable_data(input.dtype()));
    } else {
      // Pad to max batch size
      CAFFE_ENFORCE_GT(
          max_batch_size_,
          0,
          "max_batch_size should be larger than 0. Got ",
          max_batch_size_);

      // TODO: ideally we can support the case when input batch is larger than
      // the max_batch_size, as we can just pad to the multiple of
      // max_batch_size.
      CAFFE_ENFORCE_GE(max_batch_size_, output_dims.front());

      int64_t real_batch_size = output_dims[0];
      output_dims[0] = max_batch_size_;
      auto* output = Output(0, output_dims, input.dtype());
      // Zero-fill the whole output buffer, then copy the real rows to the
      // front; the remaining rows stay zero as padding.
      math::Set(
          output->nbytes(),
          static_cast<char>(0),
          static_cast<char*>(output->raw_mutable_data(input.dtype())),
          &context_);
      this->context_.template CopyItems<Context, Context>(
          input.dtype(),
          input.numel(),
          input.raw_data(),
          output->raw_mutable_data(input.dtype()));

      if (OutputSize() > 1) {
        // Optionally emit the original batch size as a 1-element int64 tensor
        // so downstream ops can undo the padding.
        auto* real_batch_tensor = Output(1, {1}, at::dtype<int64_t>());
        real_batch_tensor->template mutable_data<int64_t>()[0] =
            real_batch_size;
      }
    }

    return true;
  }

 private:
  int64_t max_batch_size_;
};
} // namespace caffe2
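
The header only declares the templated kernel; the operator registration and schema live in a companion .cc file. Below is a minimal sketch of what that registration could look like, assuming the operator is exported under the name AdjustBatch and instantiated for the CPU context; the input/output names and doc strings here are illustrative, not authoritative.

// adjust_batch_op.cc (sketch)
#include "caffe2/operators/adjust_batch_op.h"

namespace caffe2 {

// Assumed export name "AdjustBatch"; the CPU instantiation binds Context to
// CPUContext.
REGISTER_CPU_OPERATOR(AdjustBatch, AdjustBatchOp<CPUContext>);

OPERATOR_SCHEMA(AdjustBatch)
    .NumInputs(1, 2)
    .NumOutputs(1, 2)
    .Input(0, "Input", "Input data")
    .Input(1, "RealBatchSizeIn", "[Optional] Real batch size")
    .Output(0, "Output", "Data with adjusted batch size")
    .Output(1, "RealBatchSizeOut", "[Optional] Real batch size")
    .Arg("max_batch_size", "Maximum batch size to pad to")
    .SetDoc(R"DOC(
Pads the first (batch) dimension of the input with zeros up to max_batch_size,
or, when the optional real-batch-size input is given, slices the padded input
back down to the real batch size.
)DOC");

} // namespace caffe2

Padding to a fixed max_batch_size gives downstream consumers that expect static shapes a constant batch dimension, while the optional second output records the real batch size so a later AdjustBatch call can trim the padding off again.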