Caffe2 - C++ API
A deep learning, cross platform ML framework
batch_box_cox_op.h
1 #ifndef CAFFE_OPERATORS_BATCH_BOX_COX_OPS_H_
2 #define CAFFE_OPERATORS_BATCH_BOX_COX_OPS_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/logging.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/utils/math.h"
8 
9 namespace caffe2 {
10 
11 template <class Context>
12 class BatchBoxCoxOp final : public Operator<Context> {
13  public:
14  USE_OPERATOR_CONTEXT_FUNCTIONS;
15  template <class... Args>
16  explicit BatchBoxCoxOp(Args&&... args)
17  : Operator<Context>(std::forward<Args>(args)...),
18  min_block_size_(
19  this->template GetSingleArgument<int>("min_block_size", 256)) {}
20 
21  bool RunOnDevice() override {
22  return DispatchHelper<TensorTypes<float, double>>::call(this, Input(DATA));
23  }
24 
25  template <typename T>
26  bool DoRunWithType();
27 
28  protected:
29  template <typename T>
30  void BoxCoxNaive(
31  int64_t N,
32  int64_t D,
33  const T* data_ptr,
34  const T* lambda1_ptr,
35  const T* lambda2_ptr,
36  T k_eps,
37  T* output_ptr);
38 
39 #ifdef CAFFE2_USE_MKL
40  template <typename T>
41  void BoxCoxNonzeroLambda(
42  int64_t D,
43  const T* data_ptr,
44  const T* lambda1,
45  const T* lambda2,
46  T k_eps,
47  T* output_ptr);
48 
49  template <typename T>
50  void BoxCoxZeroLambda(
51  int64_t D,
52  const T* data_ptr,
53  const T* lambda2,
54  T k_eps,
55  T* output_ptr);
56 
57  template <typename T>
58  void BoxCoxMixedLambda(
59  const T* data_ptr,
60  const vector<int>& nonzeros,
61  const vector<int>& zeros,
62  const T* lambda1,
63  const T* lambda2,
64  const T* lambda2_z,
65  T k_eps,
66  T* buffer,
67  T* output_ptr);
68 
69  vector<int> nonzeros_, zeros_;
70 
71  // Buffers used by the MKL version are cached across calls.
72  struct CachedBuffers {
73  virtual ~CachedBuffers() {}
74  int type_;
75  };
76  template <typename T>
77  struct TypedCachedBuffers : public CachedBuffers {
78  vector<T> lambda1_, lambda2_, lambda2_z_;
79  vector<T> accumulator_;
80  };
81  template <typename T>
82  TypedCachedBuffers<T>& GetBuffers();
83  unique_ptr<CachedBuffers> buffers_;
84 
85 #endif // CAFFE2_USE_MKL
86 
87  int min_block_size_;
88 
89  INPUT_TAGS(DATA, LAMBDA1, LAMBDA2);
90 };
91 
92 } // namespace caffe2
93 
94 #endif // CAFFE_OPERATORS_BATCH_BOX_COX_OPS_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Definition: static.cpp:70