8 #include <c10/core/Allocator.h> 9 #include <c10/core/Backend.h> 11 #include "caffe2/core/common.h" 12 #include "caffe2/core/logging.h" 13 #include "caffe2/serialize/file_adapter.h" 14 #include "caffe2/serialize/inline_container.h" 15 #include "caffe2/serialize/istream_adapter.h" 16 #include "caffe2/serialize/read_adapter_interface.h" 23 size_t istream_read_func(
void *pOpaque, mz_uint64 file_ofs,
void *pBuf,
size_t n) {
24 auto self =
static_cast<PyTorchStreamReader*
>(pOpaque);
25 return self->read(file_ofs, static_cast<char*>(pBuf), n);
// Returns the final path component of `name` with its last extension removed.
// Both '/' and '\\' are treated as directory separators. Only a '.' located
// inside the final component counts as an extension separator.
// Returns an empty string when `name` is empty, ends with a separator, or the
// final component begins with its only '.' (e.g. ".hidden").
static std::string basename(const std::string& name) {
  // The component starts just after the last separator (or at 0 if none).
  const size_t sep = name.find_last_of("\\/");
  const size_t start = (sep == std::string::npos) ? 0 : sep + 1;
  if (start >= name.size()) {
    return "";
  }

  // Cut at the last '.', but ignore dots that belong to a directory name.
  const size_t dot = name.rfind('.');
  const size_t end =
      (dot != std::string::npos && dot >= start) ? dot : name.size();
  return name.substr(start, end - start);
}
49 size_t PyTorchStreamReader::read(uint64_t pos,
char* buf,
size_t n) {
50 return in_->read(pos, buf, n,
"reading file");
53 PyTorchStreamReader::PyTorchStreamReader(
const std::string& file_name)
54 : ar_(
caffe2::make_unique<mz_zip_archive>()),
55 in_(
caffe2::make_unique<FileAdapter>(file_name)) {
59 PyTorchStreamReader::PyTorchStreamReader(std::istream* in)
60 : ar_(
caffe2::make_unique<mz_zip_archive>()),
61 in_(
caffe2::make_unique<IStreamAdapter>(in)) {
65 PyTorchStreamReader::PyTorchStreamReader(
66 std::unique_ptr<ReadAdapterInterface> in)
67 : ar_(
caffe2::make_unique<mz_zip_archive>()), in_(
std::move(in)) {
71 void PyTorchStreamReader::init() {
72 AT_ASSERT(in_ !=
nullptr);
73 AT_ASSERT(ar_ !=
nullptr);
74 memset(ar_.get(), 0,
sizeof(mz_zip_archive));
76 size_t size = in_->size();
79 constexpr
size_t kMagicValueLength = 8;
80 if (size > kMagicValueLength) {
81 char buf[kMagicValueLength];
82 read(0, buf, kMagicValueLength);
83 valid(
"checking magic number");
85 memcmp(
"PYTORCH1", buf, kMagicValueLength) != 0,
86 "File is an unsupported archive format from the preview release.");
89 ar_->m_pIO_opaque =
this;
90 ar_->m_pRead = istream_read_func;
92 mz_zip_reader_init(ar_.get(), size, 0);
93 valid(
"reading zip archive");
97 int n = mz_zip_reader_get_num_files(ar_.get());
99 CAFFE_THROW(
"archive does not contain any files");
101 size_t name_size = mz_zip_reader_get_filename(ar_.get(), 0,
nullptr, 0);
102 valid(
"getting filename");
103 std::string buf(name_size,
'\0');
104 mz_zip_reader_get_filename(ar_.get(), 0, &buf[0], name_size);
105 valid(
"getting filename");
106 auto pos = buf.find_first_of(
'/');
107 if (pos == std::string::npos) {
108 CAFFE_THROW(
"file in archive is not in a subdirectory: ", buf);
110 archive_name_ = buf.substr(0, pos);
115 std::tie(version_ptr, version_size) = getRecord(
"version");
116 std::string version(static_cast<const char*>(version_ptr.get()), version_size);
117 size_t version_number = caffe2::stoull(version);
119 version_number >= kMinSupportedFileFormatVersion,
120 "Attempted to read a PyTorch file with version ",
121 c10::to_string(version_number),
122 ", but the minimum supported version for reading is ",
123 c10::to_string(kMinSupportedFileFormatVersion),
124 ". Your PyTorch script module file is too old. Please re-export it again.");
126 version_number <= kMaxSupportedFileFormatVersion,
127 "Attempted to read a PyTorch file with version ",
129 ", but the maximum supported version for reading is ",
130 kMaxSupportedFileFormatVersion,
131 ". Your PyTorch installation may be too old.");
134 void PyTorchStreamReader::valid(
const char* what) {
135 auto err = mz_zip_get_last_error(ar_.get());
136 if (err != MZ_ZIP_NO_ERROR) {
137 CAFFE_THROW(
"PytorchStreamReader failed ", what,
": ", mz_zip_get_error_string(err));
// Layout of a ZIP local file header, as used by getPadding and
// getRecordOffset in this file.
// Size in bytes of the fixed part of the header.
constexpr int MZ_ZIP_LOCAL_DIR_HEADER_SIZE = 30;
// Byte offset of the 16-bit filename-length field within the header.
constexpr int MZ_ZIP_LDH_FILENAME_LEN_OFS = 26;
// Byte offset of the 16-bit extra-field-length field within the header.
constexpr int MZ_ZIP_LDH_EXTRA_LEN_OFS = 28;
145 static std::string getPadding(
size_t cursor,
const std::string& filename,
size_t size) {
146 size_t start = cursor + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename.size() +
sizeof(mz_uint16) * 2;
147 if (size >= MZ_UINT32_MAX || cursor >= MZ_UINT32_MAX) {
148 start +=
sizeof(mz_uint16) * 2;
149 if (size >= MZ_UINT32_MAX) {
150 start += 2*
sizeof(mz_uint64);
152 if (cursor >= MZ_UINT32_MAX) {
153 start +=
sizeof(mz_uint64);
156 size_t mod = start % kFieldAlignment;
157 size_t next_offset = (mod == 0) ? start : (start + kFieldAlignment - mod);
158 size_t padding_size = next_offset - start;
159 std::string buf(padding_size + 4,
'Z');
163 buf[2] = (uint8_t) padding_size;
164 buf[3] = (uint8_t) (padding_size >> 8);
168 size_t PyTorchStreamReader::getFileID(
const std::string& name) {
169 std::stringstream ss;
170 ss << archive_name_ <<
"/" << name;
171 size_t result = mz_zip_reader_locate_file(ar_.get(), ss.str().c_str(),
nullptr, 0);
172 if (ar_->m_last_error == MZ_ZIP_FILE_NOT_FOUND) {
173 CAFFE_THROW(
"file not found: ", ss.str());
175 valid(
"locating file");
180 std::tuple<at::DataPtr, size_t> PyTorchStreamReader::getRecord(
const std::string& name) {
181 size_t key = getFileID(name);
182 mz_zip_archive_file_stat stat;
183 mz_zip_reader_file_stat(ar_.get(), key, &stat);
184 valid(
"retrieving file meta-data");
185 void * ptr = malloc(stat.m_uncomp_size);
186 mz_zip_reader_extract_to_mem(ar_.get(), key, ptr, stat.m_uncomp_size, 0);
187 valid(
"reading file");
190 return std::make_tuple(std::move(retval), stat.m_uncomp_size);
// Decodes an unsigned 16-bit little-endian integer from the two bytes at
// `buf` (buf[0] is the low byte, buf[1] the high byte).
static int64_t read_le_16(const uint8_t* buf) {
  return buf[0] + (buf[1] << 8);
}
197 size_t PyTorchStreamReader::getRecordOffset(
const std::string& name) {
198 mz_zip_archive_file_stat stat;
199 mz_zip_reader_file_stat(ar_.get(), getFileID(name), &stat);
200 valid(
"retriving file meta-data");
201 uint8_t local_header[MZ_ZIP_LOCAL_DIR_HEADER_SIZE];
203 stat.m_local_header_ofs,
205 MZ_ZIP_LOCAL_DIR_HEADER_SIZE,
206 "reading file header");
207 size_t filename_len = read_le_16(local_header + MZ_ZIP_LDH_FILENAME_LEN_OFS);
208 size_t extra_len = read_le_16(local_header + MZ_ZIP_LDH_EXTRA_LEN_OFS);
209 return stat.m_local_header_ofs + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_len + extra_len;
213 PyTorchStreamReader::~PyTorchStreamReader() {
214 mz_zip_reader_end(ar_.get());
215 valid(
"closing reader");
218 size_t ostream_write_func(
void *pOpaque, mz_uint64 file_ofs,
const void *pBuf,
size_t n) {
219 auto self =
static_cast<PyTorchStreamWriter*
>(pOpaque);
220 if (self->current_pos_ != file_ofs) {
225 self->out_->seekp(file_ofs);
230 self->out_->write(static_cast<const char*>(pBuf), n);
233 self->current_pos_ = file_ofs + n;
237 PyTorchStreamWriter::PyTorchStreamWriter(
238 std::string file_name,
240 : ar_(
caffe2::make_unique<mz_zip_archive>()),
241 archive_name_(basename(file_name)),
243 memset(ar_.get(), 0,
sizeof(mz_zip_archive));
245 if (archive_name_.size() == 0) {
246 CAFFE_THROW(
"invalid file name: ", file_name);
249 file_stream_.open(file_name, std::ofstream::out | std::ofstream::trunc | std::ofstream::binary);
250 out_ = &file_stream_;
251 valid(
"opening archive");
254 ar_->m_pIO_opaque =
this;
255 ar_->m_pWrite = ostream_write_func;
257 mz_zip_writer_init_v2(ar_.get(), 0, MZ_ZIP_FLAG_WRITE_ZIP64);
258 valid(
"initializing archive");
260 std::stringstream version;
261 version << kMaxSupportedFileFormatVersion <<
"\n";
262 writeRecord(
"version", version.str().c_str(), version.str().size());
265 void PyTorchStreamWriter::writeRecord(
const std::string& name,
const void* data,
size_t size) {
266 AT_ASSERT(!finalized_);
267 std::stringstream ss;
268 ss << archive_name_ <<
"/" << name;
269 const std::string& full_name = ss.str();
270 std::string padding = getPadding(ar_->m_archive_size, full_name, size);
272 mz_zip_writer_add_mem_ex_v2(
287 valid(
"writing file");
290 void PyTorchStreamWriter::writeEndOfFile() {
291 AT_ASSERT(!finalized_);
293 mz_zip_writer_finalize_archive(ar_.get());
294 mz_zip_writer_end(ar_.get());
295 valid(
"writing central directory");
296 if (file_stream_.is_open())
297 file_stream_.close();
301 void PyTorchStreamWriter::valid(
const char* what) {
302 auto err = mz_zip_get_last_error(ar_.get());
303 if (err != MZ_ZIP_NO_ERROR) {
304 CAFFE_THROW(
"PytorchStreamWriter failed ", what,
": ", mz_zip_get_error_string(err));
307 CAFFE_THROW(
"PytorchStreamWriter failed ", what,
".");
311 PyTorchStreamWriter::~PyTorchStreamWriter() {
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...