Caffe2 - C++ API
A deep learning, cross-platform ML framework
tensor.cc
1 
17 #include "caffe2/core/tensor.h"
18 
19 #include "caffe2/core/blob_stats.h"
20 #include "caffe2/core/flags.h"
21 
22 CAFFE2_DEFINE_bool(
23  caffe2_keep_on_shrink,
24  true,
25  "If set, keeps memory when a tensor is shrinking its size.");
26 
27 CAFFE2_DEFINE_int64(
28  caffe2_max_keep_on_shrink_memory,
29  LLONG_MAX,
30  "The maximum memory in bytes to keep on shrink, if the difference between "
31  "tensor sizes is bigger than this then tensor will be reset.");
32 
33 namespace caffe2 {
34 // declaring it here instead of context.cc because tensor.h includes context.h
35 CAFFE_KNOWN_TYPE(Tensor<CPUContext>);
36 
37 TensorPrinter::TensorPrinter(
38  const std::string& tensor_name,
39  const std::string& file_name,
40  int limit)
41  : to_file_(!file_name.empty()),
42  limit_(limit ? limit : k_limit_default_),
43  tensor_name_(tensor_name) {
44  if (to_file_) {
45  // We will output to file instead of printing on screen.
46  // We will write each individual tensor to its individual file.
47  log_file_.reset(new std::ofstream(
48  file_name, std::ofstream::out | std::ofstream::trunc));
49  CAFFE_ENFORCE(
50  log_file_->good(),
51  "Failed to open TensorPrinter file ",
52  file_name,
53  ". rdstate() = ",
54  log_file_->rdstate());
55  }
56 }
57 
58 TensorPrinter::~TensorPrinter() {
59  if (log_file_.get()) {
60  log_file_->close();
61  }
62 }
63 
64 static CaffeMap<CaffeTypeId, TypeCall> type_call_registry_ {
65  {TypeMeta::Id<Tensor<CPUContext>>(), GetTensorType<CPUContext>}
66 };
67 
68 TypeCall GetTypeCallFunction(CaffeTypeId id) {
69  auto f = type_call_registry_.find(id);
70  if (f == type_call_registry_.end()) {
71  return nullptr;
72  }
73  return f->second;
74 }
75 
76 void RegisterTypeCallFunction(CaffeTypeId id, TypeCall c) {
77  type_call_registry_[id] = c;
78 }
79 
80 static CaffeMap<CaffeTypeId, TensorInfoCall> tensor_info_call_registry_{
81  {TypeMeta::Id<Tensor<CPUContext>>(), GetTensorInfo<CPUContext>}};
82 
83 TensorInfoCall GetTensorInfoFunction(CaffeTypeId id) {
84  auto f = tensor_info_call_registry_.find(id);
85  if (f == tensor_info_call_registry_.end()) {
86  return nullptr;
87  }
88  return f->second;
89 }
90 
91 void RegisterTensorInfoFunction(CaffeTypeId id, TensorInfoCall c) {
92  tensor_info_call_registry_[id] = c;
93 }
94 
95 namespace {
96 
97 struct TensorCPUStatGetter : BlobStatGetter {
98  size_t sizeBytes(const Blob& blob) const override {
99  const auto& tensor = blob.Get<TensorCPU>();
100  auto nbytes = tensor.nbytes();
101  if (nbytes > 0 && tensor.IsType<std::string>()) {
102  const auto* data = tensor.data<std::string>();
103  for (size_t i = 0; i < tensor.size(); ++i) {
104  nbytes += data[i].size();
105  }
106  }
107  return nbytes;
108  }
109 };
110 REGISTER_BLOB_STAT_GETTER(TensorCPU, TensorCPUStatGetter);
111 }
112 
113 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.
Copyright (c) 2016-present, Facebook, Inc.
Copyright (c) 2016-present, Facebook, Inc.