Caffe2 - C++ API
A deep learning, cross platform ML framework
filler.h
1 #ifndef CAFFE2_FILLER_H_
2 #define CAFFE2_FILLER_H_
3 
4 #include <sstream>
5 
6 #include "caffe2/core/logging.h"
7 #include "caffe2/core/tensor.h"
8 #include "caffe2/utils/math.h"
9 
10 namespace caffe2 {
11 
// TODO: replace filler distribution enum with a better abstraction
//
// Distribution selector consumed by TensorFiller::Fill, which dispatches:
//   FD_UNIFORM   -> math::RandUniform(numel, min, max, ...)
//   FD_FIXEDSUM  -> math::RandFixedSum(numel, min, max, fixed_sum, ...)
//   FD_SYNTHETIC -> math::RandSyntheticData(numel, min, max, ...)
enum FillerDistribution {
  FD_UNIFORM = 0,
  FD_FIXEDSUM = 1,
  FD_SYNTHETIC = 2,
};
14 
15 class TensorFiller {
16  public:
17  template <class Type, class Context>
18  void Fill(Tensor* tensor, Context* context) const {
19  CAFFE_ENFORCE(context, "context is null");
20  CAFFE_ENFORCE(tensor, "tensor is null");
21  auto min = (min_ < std::numeric_limits<Type>::min())
22  ? std::numeric_limits<Type>::min()
23  : static_cast<Type>(min_);
24  auto max = (max_ > std::numeric_limits<Type>::max())
25  ? std::numeric_limits<Type>::max()
26  : static_cast<Type>(max_);
27  CAFFE_ENFORCE_LE(min, max);
28 
29  Tensor temp_tensor(shape_, Context::GetDeviceType());
30  std::swap(*tensor, temp_tensor);
31  Type* data = tensor->template mutable_data<Type>();
32 
33  // select distribution
34  switch (dist_) {
35  case FD_UNIFORM: {
36  math::RandUniform<Type, Context>(
37  tensor->numel(), min, max, data, context);
38  break;
39  }
40  case FD_FIXEDSUM: {
41  auto fixed_sum = static_cast<Type>(fixed_sum_);
42  CAFFE_ENFORCE_LE(min * tensor->numel(), fixed_sum);
43  CAFFE_ENFORCE_GE(max * tensor->numel(), fixed_sum);
44  math::RandFixedSum<Type, Context>(
45  tensor->numel(), min, max, fixed_sum_, data, context);
46  break;
47  }
48  case FD_SYNTHETIC: {
49  math::RandSyntheticData<Type, Context>(
50  tensor->numel(), min, max, data, context);
51  break;
52  }
53  }
54  }
55 
56  TensorFiller& Dist(FillerDistribution dist) {
57  dist_ = dist;
58  return *this;
59  }
60 
61  template <class Type>
62  TensorFiller& Min(Type min) {
63  min_ = (double)min;
64  return *this;
65  }
66 
67  template <class Type>
68  TensorFiller& Max(Type max) {
69  max_ = (double)max;
70  return *this;
71  }
72 
73  template <class Type>
74  TensorFiller& FixedSum(Type fixed_sum) {
75  dist_ = FD_FIXEDSUM;
76  fixed_sum_ = (double)fixed_sum;
77  return *this;
78  }
79 
80  // A helper function to construct the lengths vector for sparse features
81  // We try to pad least one index per batch unless the total_length is 0
82  template <class Type>
83  TensorFiller& SparseLengths(Type total_length) {
84  return FixedSum(total_length)
85  .Min(std::min(static_cast<Type>(1), total_length))
86  .Max(total_length);
87  }
88 
89  // a helper function to construct the segments vector for sparse features
90  template <class Type>
91  TensorFiller& SparseSegments(Type max_segment) {
92  CAFFE_ENFORCE(dist_ != FD_FIXEDSUM);
93  return Min(0).Max(max_segment).Dist(FD_SYNTHETIC);
94  }
95 
96  TensorFiller& Shape(const std::vector<int64_t>& shape) {
97  shape_ = shape;
98  return *this;
99  }
100 
101  template <class Type>
102  TensorFiller(const std::vector<int64_t>& shape, Type fixed_sum)
103  : shape_(shape), dist_(FD_FIXEDSUM), fixed_sum_((double)fixed_sum) {}
104 
105  TensorFiller(const std::vector<int64_t>& shape)
106  : shape_(shape), dist_(FD_UNIFORM), fixed_sum_(0) {}
107 
108  TensorFiller() : TensorFiller(std::vector<int64_t>()) {}
109 
110  std::string DebugString() const {
111  std::stringstream stream;
112  stream << "shape = [" << shape_ << "]; min = " << min_
113  << "; max = " << max_;
114  switch (dist_) {
115  case FD_FIXEDSUM:
116  stream << "; dist = FD_FIXEDSUM";
117  break;
118  case FD_SYNTHETIC:
119  stream << "; dist = FD_SYNTHETIC";
120  break;
121  default:
122  stream << "; dist = FD_UNIFORM";
123  break;
124  }
125  return stream.str();
126  }
127 
128  private:
129  std::vector<int64_t> shape_;
130  // TODO: type is unknown until a user starts to fill data;
131  // cast everything to double for now.
132  double min_ = 0.0;
133  double max_ = 1.0;
134  FillerDistribution dist_;
135  double fixed_sum_;
136 };
137 
138 } // namespace caffe2
139 
140 #endif // CAFFE2_FILLER_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13