Caffe2 - C++ API
A deep learning, cross platform ML framework
pack_segments.h
1 #ifndef CAFFE2_OPERATORS_PACK_SEGMENTS_H_
2 #define CAFFE2_OPERATORS_PACK_SEGMENTS_H_
3 
4 #include <atomic>
5 #include <limits>
6 #include <mutex>
7 #include <unordered_map>
8 #include <vector>
9 #include "caffe2/core/operator.h"
10 #include "caffe2/core/tensor.h"
11 #include "caffe2/utils/math.h"
12 
13 namespace caffe2 {
14 
15 template <class Context>
16 class PackSegmentsOp final : public Operator<Context> {
17  public:
18  USE_OPERATOR_CONTEXT_FUNCTIONS;
19  USE_DISPATCH_HELPER;
20 
21  template <class... Args>
22  explicit PackSegmentsOp(Args&&... args)
23  : Operator<Context>(std::forward<Args>(args)...),
24  max_length_(this->template GetSingleArgument<int>("max_length", -1)),
25  pad_minf_(this->template GetSingleArgument<bool>("pad_minf", false)),
26  return_presence_mask_(this->template GetSingleArgument<bool>(
27  "return_presence_mask",
28  false)) {
29  if (pad_minf_) {
30  padding_ = -1.0 * std::numeric_limits<float>::infinity();
31  } else {
32  padding_ = 0;
33  }
34  }
35 
36  bool RunOnDevice() {
37  return DispatchHelper<TensorTypes<int, long>>::call(this, Input(LENGTHS));
38  }
39 
40  template <typename T>
41  bool DoRunWithType();
42 
43  template <typename T, typename Data_T>
44  bool DoRunWithType2();
45 
46  INPUT_TAGS(LENGTHS, DATA);
47 
48  private:
49  int64_t max_length_;
50  bool pad_minf_;
51  float padding_;
52  bool return_presence_mask_;
53 
54  // Scratch space required by the CUDA version
55  Tensor dev_buffer_{Context::GetDeviceType()};
56  Tensor dev_lengths_prefix_sum_{Context::GetDeviceType()};
57  Tensor dev_max_length_{Context::GetDeviceType()};
58  Tensor host_max_length_{CPU};
59 };
60 
61 template <class Context>
62 class UnpackSegmentsOp final : public Operator<Context> {
63  public:
64  USE_OPERATOR_CONTEXT_FUNCTIONS;
65  USE_DISPATCH_HELPER;
66 
67  template <class... Args>
68  explicit UnpackSegmentsOp(Args&&... args)
69  : Operator<Context>(std::forward<Args>(args)...),
70  max_length_(this->template GetSingleArgument<int>("max_length", -1)) {}
71 
72  bool RunOnDevice() override {
73  return DispatchHelper<TensorTypes<int, long>>::call(this, Input(LENGTHS));
74  }
75 
76  template <typename T>
77  bool DoRunWithType();
78 
79  template <typename T, typename Data_T>
80  bool DoRunWithType2();
81 
82  INPUT_TAGS(LENGTHS, DATA);
83 
84  private:
85  int64_t max_length_;
86  Tensor dev_buffer_{Context::GetDeviceType()};
87  Tensor dev_lengths_prefix_sum_{Context::GetDeviceType()};
88  Tensor dev_max_length_{Context::GetDeviceType()};
89  Tensor dev_num_cell_{Context::GetDeviceType()};
90  Tensor host_max_length_{CPU};
91  Tensor host_num_cell_{CPU};
92 };
93 
94 } // namespace caffe2
95 #endif // CAFFE2_OPERATORS_PACK_SEGMENTS_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13