Caffe2 - C++ API
A deep learning, cross platform ML framework
inline_container.cc
1 #include <cstdio>
2 #include <cstring>
3 #include <cerrno>
4 #include <istream>
5 #include <ostream>
6 #include <fstream>
7 
8 #include <c10/core/Allocator.h>
9 #include <c10/core/Backend.h>
10 
11 #include "caffe2/core/common.h"
12 #include "caffe2/core/logging.h"
13 #include "caffe2/serialize/file_adapter.h"
14 #include "caffe2/serialize/inline_container.h"
15 #include "caffe2/serialize/istream_adapter.h"
16 #include "caffe2/serialize/read_adapter_interface.h"
17 
18 #include "miniz.h"
19 
20 namespace caffe2 {
21 namespace serialize {
22 
23 size_t istream_read_func(void *pOpaque, mz_uint64 file_ofs, void *pBuf, size_t n) {
24  auto self = static_cast<PyTorchStreamReader*>(pOpaque);
25  return self->read(file_ofs, static_cast<char*>(pBuf), n);
26 }
27 
// Returns the final path component of `name` with its last extension removed.
// Both '/' and '\\' count as separators. The result is empty when `name` is
// empty or ends in a separator, and also when the final component starts with
// its only '.' (e.g. ".hidden").
static std::string basename(const std::string& name) {
  const size_t last_separator = name.find_last_of("\\/");
  const size_t stem_begin =
      (last_separator == std::string::npos) ? 0 : last_separator + 1;
  if (stem_begin >= name.size()) {
    return "";
  }

  // Only a '.' inside the final component counts as the extension marker.
  const size_t last_dot = name.rfind('.');
  const bool has_extension =
      last_dot != std::string::npos && last_dot >= stem_begin;
  const size_t stem_end = has_extension ? last_dot : name.size();
  return name.substr(stem_begin, stem_end - stem_begin);
}
48 
49 size_t PyTorchStreamReader::read(uint64_t pos, char* buf, size_t n) {
50  return in_->read(pos, buf, n, "reading file");
51 }
52 
53 PyTorchStreamReader::PyTorchStreamReader(const std::string& file_name)
54  : ar_(caffe2::make_unique<mz_zip_archive>()),
55  in_(caffe2::make_unique<FileAdapter>(file_name)) {
56  init();
57 }
58 
59 PyTorchStreamReader::PyTorchStreamReader(std::istream* in)
60  : ar_(caffe2::make_unique<mz_zip_archive>()),
61  in_(caffe2::make_unique<IStreamAdapter>(in)) {
62  init();
63 }
64 
65 PyTorchStreamReader::PyTorchStreamReader(
66  std::unique_ptr<ReadAdapterInterface> in)
67  : ar_(caffe2::make_unique<mz_zip_archive>()), in_(std::move(in)) {
68  init();
69 }
70 
71 void PyTorchStreamReader::init() {
72  AT_ASSERT(in_ != nullptr);
73  AT_ASSERT(ar_ != nullptr);
74  memset(ar_.get(), 0, sizeof(mz_zip_archive));
75 
76  size_t size = in_->size();
77 
78  // check for the old magic number,
79  constexpr size_t kMagicValueLength = 8;
80  if (size > kMagicValueLength) {
81  char buf[kMagicValueLength];
82  read(0, buf, kMagicValueLength);
83  valid("checking magic number");
84  AT_ASSERTM(
85  memcmp("PYTORCH1", buf, kMagicValueLength) != 0,
86  "File is an unsupported archive format from the preview release.");
87  }
88 
89  ar_->m_pIO_opaque = this;
90  ar_->m_pRead = istream_read_func;
91 
92  mz_zip_reader_init(ar_.get(), size, 0);
93  valid("reading zip archive");
94 
95  // figure out the archive_name (i.e. the zip folder all the other files are in)
96  // all lookups to getRecord will be prefixed by this folder
97  int n = mz_zip_reader_get_num_files(ar_.get());
98  if (n == 0) {
99  CAFFE_THROW("archive does not contain any files");
100  }
101  size_t name_size = mz_zip_reader_get_filename(ar_.get(), 0, nullptr, 0);
102  valid("getting filename");
103  std::string buf(name_size, '\0');
104  mz_zip_reader_get_filename(ar_.get(), 0, &buf[0], name_size);
105  valid("getting filename");
106  auto pos = buf.find_first_of('/');
107  if (pos == std::string::npos) {
108  CAFFE_THROW("file in archive is not in a subdirectory: ", buf);
109  }
110  archive_name_ = buf.substr(0, pos);
111 
112  // version check
113  at::DataPtr version_ptr;
114  size_t version_size;
115  std::tie(version_ptr, version_size) = getRecord("version");
116  std::string version(static_cast<const char*>(version_ptr.get()), version_size);
117  size_t version_number = caffe2::stoull(version);
118  AT_ASSERTM(
119  version_number >= kMinSupportedFileFormatVersion,
120  "Attempted to read a PyTorch file with version ",
121  c10::to_string(version_number),
122  ", but the minimum supported version for reading is ",
123  c10::to_string(kMinSupportedFileFormatVersion),
124  ". Your PyTorch script module file is too old. Please re-export it again.");
125  AT_ASSERTM(
126  version_number <= kMaxSupportedFileFormatVersion,
127  "Attempted to read a PyTorch file with version ",
128  version_number,
129  ", but the maximum supported version for reading is ",
130  kMaxSupportedFileFormatVersion,
131  ". Your PyTorch installation may be too old.");
132 }
133 
134 void PyTorchStreamReader::valid(const char* what) {
135  auto err = mz_zip_get_last_error(ar_.get());
136  if (err != MZ_ZIP_NO_ERROR) {
137  CAFFE_THROW("PytorchStreamReader failed ", what, ": ", mz_zip_get_error_string(err));
138  }
139 }
140 
// Layout of a zip local file header as used below: total size of the
// fixed-length part, and the byte offsets of the two 16-bit length fields
// (file name length, extra field length) within it.
constexpr int MZ_ZIP_LOCAL_DIR_HEADER_SIZE{30};
constexpr int MZ_ZIP_LDH_FILENAME_LEN_OFS{26};
constexpr int MZ_ZIP_LDH_EXTRA_LEN_OFS{28};
144 
145 static std::string getPadding(size_t cursor, const std::string& filename, size_t size) {
146  size_t start = cursor + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename.size() + sizeof(mz_uint16) * 2;
147  if (size >= MZ_UINT32_MAX || cursor >= MZ_UINT32_MAX) {
148  start += sizeof(mz_uint16) * 2;
149  if (size >= MZ_UINT32_MAX) {
150  start += 2*sizeof(mz_uint64);
151  }
152  if (cursor >= MZ_UINT32_MAX) {
153  start += sizeof(mz_uint64);
154  }
155  }
156  size_t mod = start % kFieldAlignment;
157  size_t next_offset = (mod == 0) ? start : (start + kFieldAlignment - mod);
158  size_t padding_size = next_offset - start;
159  std::string buf(padding_size + 4, 'Z');
160  // zip extra encoding (key, size_of_extra_bytes)
161  buf[0] = 'F';
162  buf[1] = 'B';
163  buf[2] = (uint8_t) padding_size;
164  buf[3] = (uint8_t) (padding_size >> 8);
165  return buf;
166 }
167 
168 size_t PyTorchStreamReader::getFileID(const std::string& name) {
169  std::stringstream ss;
170  ss << archive_name_ << "/" << name;
171  size_t result = mz_zip_reader_locate_file(ar_.get(), ss.str().c_str(), nullptr, 0);
172  if (ar_->m_last_error == MZ_ZIP_FILE_NOT_FOUND) {
173  CAFFE_THROW("file not found: ", ss.str());
174  }
175  valid("locating file");
176  return result;
177 }
178 
179 // return dataptr, size
180 std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(const std::string& name) {
181  size_t key = getFileID(name);
182  mz_zip_archive_file_stat stat;
183  mz_zip_reader_file_stat(ar_.get(), key, &stat);
184  valid("retrieving file meta-data");
185  void * ptr = malloc(stat.m_uncomp_size);
186  mz_zip_reader_extract_to_mem(ar_.get(), key, ptr, stat.m_uncomp_size, 0);
187  valid("reading file");
188 
189  at::DataPtr retval(ptr, ptr, free, at::kCPU);
190  return std::make_tuple(std::move(retval), stat.m_uncomp_size);
191 }
192 
// Decodes the unsigned 16-bit little-endian value stored in buf[0..1].
static int64_t read_le_16(uint8_t* buf) {
  const int low_byte = buf[0];
  const int high_byte = buf[1];
  return low_byte + (high_byte << 8);
}
196 
197 size_t PyTorchStreamReader::getRecordOffset(const std::string& name) {
198  mz_zip_archive_file_stat stat;
199  mz_zip_reader_file_stat(ar_.get(), getFileID(name), &stat);
200  valid("retriving file meta-data");
201  uint8_t local_header[MZ_ZIP_LOCAL_DIR_HEADER_SIZE];
202  in_->read(
203  stat.m_local_header_ofs,
204  local_header,
205  MZ_ZIP_LOCAL_DIR_HEADER_SIZE,
206  "reading file header");
207  size_t filename_len = read_le_16(local_header + MZ_ZIP_LDH_FILENAME_LEN_OFS);
208  size_t extra_len = read_le_16(local_header + MZ_ZIP_LDH_EXTRA_LEN_OFS);
209  return stat.m_local_header_ofs + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_len + extra_len;
210 }
211 
212 
213 PyTorchStreamReader::~PyTorchStreamReader() {
214  mz_zip_reader_end(ar_.get());
215  valid("closing reader");
216 }
217 
218 size_t ostream_write_func(void *pOpaque, mz_uint64 file_ofs, const void *pBuf, size_t n) {
219  auto self = static_cast<PyTorchStreamWriter*>(pOpaque);
220  if (self->current_pos_ != file_ofs) {
221  // xxx - windows ostringstream refuses to seek to the end of an empty string
222  // so we workaround this by not calling seek unless necessary
223  // in the case of the first write (to the empty string) file_ofs and
224  // current_pos_ will be 0 and the seek won't occur.
225  self->out_->seekp(file_ofs);
226  if(!*self->out_)
227  return 0;
228  }
229 
230  self->out_->write(static_cast<const char*>(pBuf), n);
231  if(!*self->out_)
232  return 0;
233  self->current_pos_ = file_ofs + n;
234  return n;
235 }
236 
237 PyTorchStreamWriter::PyTorchStreamWriter(
238  std::string file_name,
239  std::ostream* out)
240  : ar_(caffe2::make_unique<mz_zip_archive>()),
241  archive_name_(basename(file_name)),
242  out_(out) {
243  memset(ar_.get(), 0, sizeof(mz_zip_archive));
244 
245  if (archive_name_.size() == 0) {
246  CAFFE_THROW("invalid file name: ", file_name);
247  }
248  if (!out_) {
249  file_stream_.open(file_name, std::ofstream::out | std::ofstream::trunc | std::ofstream::binary);
250  out_ = &file_stream_;
251  valid("opening archive");
252  }
253 
254  ar_->m_pIO_opaque = this;
255  ar_->m_pWrite = ostream_write_func;
256 
257  mz_zip_writer_init_v2(ar_.get(), 0, MZ_ZIP_FLAG_WRITE_ZIP64);
258  valid("initializing archive");
259 
260  std::stringstream version;
261  version << kMaxSupportedFileFormatVersion << "\n";
262  writeRecord("version", version.str().c_str(), version.str().size());
263 }
264 
265 void PyTorchStreamWriter::writeRecord(const std::string& name, const void* data, size_t size) {
266  AT_ASSERT(!finalized_);
267  std::stringstream ss;
268  ss << archive_name_ << "/" << name;
269  const std::string& full_name = ss.str();
270  std::string padding = getPadding(ar_->m_archive_size, full_name, size);
271  uint32_t flags = 0;
272  mz_zip_writer_add_mem_ex_v2(
273  ar_.get(),
274  full_name.c_str(),
275  data,
276  size,
277  nullptr,
278  0,
279  flags,
280  0,
281  0,
282  nullptr,
283  padding.c_str(),
284  padding.size(),
285  nullptr,
286  0);
287  valid("writing file");
288 }
289 
290 void PyTorchStreamWriter::writeEndOfFile() {
291  AT_ASSERT(!finalized_);
292  finalized_ = true;
293  mz_zip_writer_finalize_archive(ar_.get());
294  mz_zip_writer_end(ar_.get());
295  valid("writing central directory");
296  if (file_stream_.is_open())
297  file_stream_.close();
298 }
299 
300 
301 void PyTorchStreamWriter::valid(const char* what) {
302  auto err = mz_zip_get_last_error(ar_.get());
303  if (err != MZ_ZIP_NO_ERROR) {
304  CAFFE_THROW("PytorchStreamWriter failed ", what, ": ", mz_zip_get_error_string(err));
305  }
306  if (!*out_) {
307  CAFFE_THROW("PytorchStreamWriter failed ", what, ".");
308  }
309 }
310 
311 PyTorchStreamWriter::~PyTorchStreamWriter() {
312  if (!finalized_) {
313  writeEndOfFile();
314  }
315 }
316 
317 } // namespace serialize
318 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13