3 #include "caffe2/core/context.h" 4 #include "caffe2/core/operator.h" 8 template <
class Context>
11 USE_OPERATOR_CONTEXT_FUNCTIONS;
12 template <
class... Args>
16 this->
template GetSingleArgument<int64_t>(
"max_batch_size", -1)) {}
18 bool RunOnDevice()
override {
19 auto& input =
Input(0);
20 vector<int64_t> output_dims(input.sizes().vec());
21 CAFFE_ENFORCE(!output_dims.empty());
22 if (InputSize() > 1) {
25 auto& batch_size =
Input(1);
26 int64_t real_batch_size = *batch_size.template data<int64_t>();
27 int64_t max_batch_size = output_dims[0];
28 CAFFE_ENFORCE_GE(max_batch_size, real_batch_size);
29 output_dims[0] = real_batch_size;
30 auto* output = Output(0, output_dims, input.dtype());
31 this->context_.template CopyItems<Context, Context>(
33 input.numel() * real_batch_size / max_batch_size,
35 output->raw_mutable_data(input.dtype()));
41 "max_batch_size should be larger than 0. Got ",
47 CAFFE_ENFORCE_GE(max_batch_size_, output_dims.front());
49 int64_t real_batch_size = output_dims[0];
50 output_dims[0] = max_batch_size_;
51 auto* output = Output(0, output_dims, input.dtype());
55 static_cast<char*>(output->raw_data()),
57 this->context_.template CopyItems<Context, Context>(
61 output->raw_mutable_data(input.dtype()));
63 if (OutputSize() > 1) {
64 auto* real_batch_tensor = Output(1, {1}, at::dtype<int64_t>());
65 real_batch_tensor->template mutable_data<int64_t>()[0] =
74 int64_t max_batch_size_;
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...