Caffe2 - C++ API
A deep-learning, cross-platform ML framework
batch_box_cox_op.h
1 
17 #ifndef CAFFE_OPERATORS_BATCH_BOX_COX_OPS_H_
18 #define CAFFE_OPERATORS_BATCH_BOX_COX_OPS_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/logging.h"
22 #include "caffe2/core/operator.h"
23 #include "caffe2/utils/math.h"
24 
25 namespace caffe2 {
26 
27 template <class Context>
28 class BatchBoxCoxOp final : public Operator<Context> {
29  public:
30  USE_OPERATOR_CONTEXT_FUNCTIONS;
31  BatchBoxCoxOp(const OperatorDef& operator_def, Workspace* ws)
32  : Operator<Context>(operator_def, ws),
33  min_block_size_(
34  OperatorBase::GetSingleArgument<int>("min_block_size", 256)) {}
35 
36  bool RunOnDevice() override {
37  return DispatchHelper<TensorTypes<float, double>>::call(this, Input(DATA));
38  }
39 
40  template <typename T>
41  bool DoRunWithType();
42 
43  protected:
44  template <typename T>
45  void BoxCoxNaive(
46  TIndex N,
47  TIndex D,
48  const T* data_ptr,
49  const T* lambda1_ptr,
50  const T* lambda2_ptr,
51  T k_eps,
52  T* output_ptr);
53 
54 #ifdef CAFFE2_USE_MKL
55  template <typename T>
56  void BoxCoxNonzeroLambda(
57  TIndex D,
58  const T* data_ptr,
59  const T* lambda1,
60  const T* lambda2,
61  T k_eps,
62  T* output_ptr);
63 
64  template <typename T>
65  void BoxCoxZeroLambda(
66  TIndex D,
67  const T* data_ptr,
68  const T* lambda2,
69  T k_eps,
70  T* output_ptr);
71 
72  template <typename T>
73  void BoxCoxMixedLambda(
74  const T* data_ptr,
75  const vector<int>& nonzeros,
76  const vector<int>& zeros,
77  const T* lambda1,
78  const T* lambda2,
79  const T* lambda2_z,
80  T k_eps,
81  T* buffer,
82  T* output_ptr);
83 
84  vector<int> nonzeros_, zeros_;
85 
86  // Buffers used by the MKL version are cached across calls.
87  struct CachedBuffers {
88  virtual ~CachedBuffers() {}
89  int type_;
90  };
91  template <typename T>
92  struct TypedCachedBuffers : public CachedBuffers {
93  vector<T> lambda1_, lambda2_, lambda2_z_;
94  vector<T> accumulator_;
95  };
96  template <typename T>
97  TypedCachedBuffers<T>& GetBuffers();
98  unique_ptr<CachedBuffers> buffers_;
99 
100 #endif // CAFFE2_USE_MKL
101 
102  int min_block_size_;
103 
104  INPUT_TAGS(DATA, LAMBDA1, LAMBDA2);
105 };
106 
107 } // namespace caffe2
108 
109 #endif // CAFFE_OPERATORS_BATCH_BOX_COX_OPS_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.