Caffe2 - C++ API
A deep learning, cross-platform ML framework
pack_segments.h
1 
17 #ifndef CAFFE2_OPERATORS_PACK_SEGMENTS_H_
18 #define CAFFE2_OPERATORS_PACK_SEGMENTS_H_
19 
#include <atomic>
#include <cstdint>
#include <limits>
#include <mutex>
#include <unordered_map>
#include <vector>

#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
#include "caffe2/utils/math.h"
28 
29 namespace caffe2 {
30 
31 template <class Context>
32 class PackSegmentsOp final : public Operator<Context> {
33  public:
34  USE_OPERATOR_CONTEXT_FUNCTIONS;
35  // USE_SIMPLE_CTOR_DTOR(PackSegmentsOp)
36  USE_DISPATCH_HELPER;
37 
38  PackSegmentsOp(const OperatorDef& operator_def, Workspace* ws)
39  : Operator<Context>(operator_def, ws),
40  pad_minf_(OperatorBase::GetSingleArgument<bool>("pad_minf", false)),
41  return_presence_mask_(OperatorBase::GetSingleArgument<bool>(
42  "return_presence_mask",
43  false)) {
44  if (pad_minf_) {
45  padding_ = -1.0 * std::numeric_limits<float>::infinity();
46  } else {
47  padding_ = 0;
48  }
49  }
50 
51  bool RunOnDevice() {
52  return DispatchHelper<TensorTypes<int, long>>::call(this, Input(LENGTHS));
53  }
54 
55  template <typename T>
56  bool DoRunWithType();
57 
58  template <typename T, typename Data_T>
59  bool DoRunWithType2();
60 
61  INPUT_TAGS(LENGTHS, DATA);
62 
63  private:
64  bool pad_minf_;
65  float padding_;
66  bool return_presence_mask_;
67 
68  // Scratch space required by the CUDA version
69  Tensor<Context> dev_buffer_;
70  Tensor<Context> dev_lengths_prefix_sum_;
71  Tensor<Context> dev_max_length_;
72  Tensor<CPUContext> host_max_length_;
73 };
74 
75 template <class Context>
76 class UnpackSegmentsOp final : public Operator<Context> {
77  public:
78  USE_OPERATOR_CONTEXT_FUNCTIONS;
79  USE_SIMPLE_CTOR_DTOR(UnpackSegmentsOp)
80  USE_DISPATCH_HELPER;
81 
82  bool RunOnDevice() override {
83  return DispatchHelper<TensorTypes<int, long>>::call(this, Input(LENGTHS));
84  }
85 
86  template <typename T>
87  bool DoRunWithType();
88 
89  template <typename T, typename Data_T>
90  bool DoRunWithType2();
91 
92  INPUT_TAGS(LENGTHS, DATA);
93 
94  private:
95  Tensor<Context> dev_buffer_;
96  Tensor<Context> dev_lengths_prefix_sum_;
97  Tensor<Context> dev_max_length_;
98  Tensor<Context> dev_num_cell_;
99  Tensor<CPUContext> host_max_length_;
100  Tensor<CPUContext> host_num_cell_;
101 };
102 
103 } // namespace caffe2
104 #endif // CAFFE2_OPERATORS_PACK_SEGMENTS_H_
Tensor is the basic class in Caffe2 that stores a contiguous block of memory along with its shape information...
Definition: tensor.h:109
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.