1 #include <torch/csrc/jit/pickler.h> 8 PicklerClass getClass(
const std::string& str) {
9 if (str ==
"TensorID") {
10 return PicklerClass::TENSOR;
11 }
else if (str ==
"IntList") {
12 return PicklerClass::INTLIST;
14 AT_ERROR(
"Unknown class name for unpickler: ", str);
17 const std::string& getClassName(PicklerClass cls) {
18 static const std::string tensor_class(
"TensorID\n");
19 static const std::string intlist_class(
"IntList\n");
21 case PicklerClass::TENSOR:
23 case PicklerClass::INTLIST:
26 AT_ERROR(
"Unknown class for pickler");
// Returns the newline-terminated module name ("__main__\n") as a reference
// to a function-local static string.
const std::string& getModuleName() {
  static const std::string main_module = "__main__\n";
  return main_module;
}
// Read-only accessor returning a const reference to a std::vector<char>.
// NOTE(review): the body is not present in this extract; presumably it returns
// the stack_ byte buffer that the push* helpers append to - confirm against
// the full file.
35 const std::vector<char>& Pickler::stack() {
// Begins a pickle program: emits PROTO, then EMPTY_LIST followed by MARK so
// that subsequently pushed values sit above a mark on top of an empty list.
// NOTE(review): original lines 41-44 are missing from this extract. The
// Unpickler reads a one-byte protocol number right after PROTO, so a version
// byte is presumably written between PROTO and EMPTY_LIST - confirm.
39 void Pickler::start() {
40 pushOpCode(OpCode::PROTO);
45 pushOpCode(OpCode::EMPTY_LIST);
46 pushOpCode(OpCode::MARK);
49 void Pickler::finish() {
50 pushOpCode(OpCode::APPENDS);
51 pushOpCode(OpCode::STOP);
// Serializes one IValue onto the pickle stream, dispatching on its runtime
// type. A value whose getPointer() address is already in memo_ is emitted as
// a back-reference via pushBinGet instead of being written again.
// NOTE(review): several branch bodies (tensor, tuple, double, generic list,
// generic dict, int list) and the final else are missing from this extract;
// presumably each forwards to the matching push* helper - confirm against the
// full file.
54 void Pickler::addIValue(
const IValue& ivalue) {
// Address that identifies this value in the memo table.
56 const void* ivalue_ptr = getPointer(ivalue);
58 auto memo_entry = memo_.find(ivalue_ptr);
59 if (memo_entry != memo_.end()) {
// Already written once: reference the stored memo id.
61 pushBinGet(memo_entry->second);
66 if (ivalue.isTensor()) {
68 }
else if (ivalue.isTuple()) {
70 }
else if (ivalue.isDouble()) {
72 }
else if (ivalue.isInt()) {
// Range check: the value is written with pushInt32 below, so it has to fit
// in an int32_t.
75 ivalue.toInt() <= std::numeric_limits<int32_t>::max() &&
76 ivalue.toInt() >= std::numeric_limits<int32_t>::min());
77 pushOpCode(OpCode::BININT);
78 pushInt32(ivalue.toInt());
79 }
else if (ivalue.isBool()) {
// Booleans have dedicated opcodes and carry no operand.
80 if (ivalue.toBool()) {
81 pushOpCode(OpCode::NEWTRUE);
83 pushOpCode(OpCode::NEWFALSE);
85 }
else if (ivalue.isString()) {
86 pushMemoizedString(ivalue);
87 }
else if (ivalue.isGenericList()) {
89 }
else if (ivalue.isGenericDict()) {
91 }
else if (ivalue.isNone()) {
92 pushOpCode(OpCode::NONE);
93 }
else if (ivalue.isIntList()) {
// Any type not handled above is rejected, reporting its tag kind.
96 AT_ERROR(
"Unknown IValue type for pickling: ", ivalue.tagKind());
// Returns the address used as the memoization key for an IValue. Handled
// kinds: generic dict, generic list, tuple (address of its elements
// container), string and int list.
// NOTE(review): the fall-through for all other kinds is not visible in this
// extract; pushMemoization(const void*) asserts a non-null pointer, so the
// default is presumably nullptr - confirm against the full file.
100 const void* Pickler::getPointer(
const IValue& ivalue) {
101 if (ivalue.isGenericDict()) {
102 return &(ivalue.toGenericDictRef());
103 }
else if (ivalue.isGenericList()) {
104 return &(ivalue.toGenericListRef());
105 }
else if (ivalue.isTuple()) {
106 return &(ivalue.toTuple()->elements());
107 }
else if (ivalue.isString()) {
108 return &(ivalue.toStringRef());
109 }
else if (ivalue.isIntList()) {
110 return &(ivalue.toIntListRef());
// Emits a memo back-reference for memo_id: BINGET when the id fits in one
// unsigned byte, LONG_BINGET otherwise.
// NOTE(review): the lines that write the id operand after each opcode and the
// else separating the two branches are missing from this extract - confirm
// the operand widths against the full file.
116 void Pickler::pushBinGet(uint32_t memo_id) {
117 if (memo_id <= std::numeric_limits<uint8_t>::max()) {
118 pushOpCode(OpCode::BINGET);
122 pushOpCode(OpCode::LONG_BINGET);
// Writes a string IValue as BINUNICODE followed by its size as a uint32, then
// records the value in the memo via pushMemoization.
// NOTE(review): original line 132 is missing from this extract; presumably it
// writes the string bytes themselves between the length and the memo entry -
// confirm. string.size() is narrowed to uint32_t by pushUint32.
127 void Pickler::pushMemoizedString(
const IValue& ivalue) {
128 const auto&
string = ivalue.toStringRef();
130 pushOpCode(OpCode::BINUNICODE);
131 pushUint32(
string.size());
133 pushMemoization(ivalue);
136 void Pickler::pushString(
const std::string&
string) {
137 stack_.insert(stack_.end(),
string.begin(),
string.end());
// Emits the opcodes that instantiate one of the pickler's helper classes.
// The class reference is memoized under the address of the static name string
// returned by getClassName: the first use writes GLOBAL plus the module name,
// later uses emit a back-reference with pushBinGet. An EMPTY_TUPLE and NEWOBJ
// follow in either case.
// NOTE(review): original lines 146, 148-149, 151 and 153-154 are missing from
// this extract; presumably the class name is written after the module name
// and an else separates the memo-miss and memo-hit paths - confirm.
140 void Pickler::pushClass(PicklerClass cls) {
141 const auto& name = getClassName(cls);
// Look the class up by the address of its name string.
143 auto memo_entry = memo_.find(&name);
144 if (memo_entry == memo_.end()) {
145 pushOpCode(OpCode::GLOBAL);
147 pushString(getModuleName());
// Remember the class under the same address used for the lookup above.
150 pushMemoization((
void*)&name);
152 pushBinGet(memo_entry->second);
155 pushOpCode(OpCode::EMPTY_TUPLE);
156 pushOpCode(OpCode::NEWOBJ);
159 void Pickler::pushTensor(
const IValue& ivalue) {
160 pushClass(PicklerClass::TENSOR);
162 tensor_table_->push_back(ivalue.toTensor());
163 auto tensor_id = tensor_table_->size() - 1;
164 pushOpCode(OpCode::BININT);
165 pushUint32(tensor_id);
167 pushOpCode(OpCode::BUILD);
// Serializes an int list as an INTLIST helper-class object: pushClass, then
// an EMPTY_LIST (memoized under this IValue), a MARK, the items, APPENDS to
// fold the items into the list, and BUILD to finish the object.
// NOTE(review): the loop body (original lines 178-179) is missing from this
// extract; presumably each item is written as an integer - confirm against
// the full file.
170 void Pickler::pushIntList(
const IValue& ivalue) {
171 pushClass(PicklerClass::INTLIST);
173 pushOpCode(OpCode::EMPTY_LIST);
174 pushMemoization(ivalue);
175 pushOpCode(OpCode::MARK);
177 for (
const auto& item : ivalue.toIntListRef()) {
181 pushOpCode(OpCode::APPENDS);
182 pushOpCode(OpCode::BUILD);
185 void Pickler::pushDouble(
const IValue& ivalue) {
186 double value = ivalue.toDouble();
187 AT_ASSERT(
sizeof(
double) == 8);
188 char* bytes =
reinterpret_cast<char*
>(&value);
190 pushOpCode(OpCode::BINFLOAT);
191 for (
size_t i = 0; i < 8; ++i) {
192 pushUint8(bytes[8 - i - 1]);
// Pair of IValues; the comparator below orders such pairs by their `first`
// member.
196 using ivalue_pair = std::pair<IValue, IValue>;
199 bool operator()(
const ivalue_pair& lhs,
const ivalue_pair& rhs)
const {
200 if (lhs.first.isString()) {
201 return lhs.first.toStringRef() < rhs.first.toStringRef();
203 if (lhs.first.isInt()) {
204 return lhs.first.toInt() < rhs.first.toInt();
206 if (lhs.first.isDouble()) {
207 return lhs.first.toDouble() < rhs.first.toDouble();
209 AT_ERROR(
"Uncomparable IValue types");
// Serializes a generic dict: EMPTY_DICT (memoized under this IValue), MARK,
// then every key followed by its value via addIValue, and finally SETITEMS to
// fold the pairs into the dict.
// NOTE(review): original line 223 is missing from this extract; the entries
// are copied into a vector first, presumably so they can be sorted for a
// stable output order - confirm against the full file.
213 void Pickler::pushDict(
const IValue& ivalue) {
// NOTE(review): `auto` (not `const auto&`) looks like it copies the whole
// container returned by toGenericDictRef - confirm whether that is intended.
214 auto dict = ivalue.toGenericDictRef();
216 pushOpCode(OpCode::EMPTY_DICT);
217 pushMemoization(ivalue);
219 pushOpCode(OpCode::MARK);
// Snapshot of the (key, value) entries in a vector.
222 std::vector<std::pair<IValue, IValue>> dict_items(dict.begin(), dict.end());
225 for (
const auto& pair : dict_items) {
226 addIValue(pair.first);
227 addIValue(pair.second);
230 pushOpCode(OpCode::SETITEMS);
// Records the object on top of the pickle stack in the memo under the current
// memo_id and maps `item` (an address used as identity) to that id in memo_.
// BINPUT is used while the id fits in one unsigned byte, LONG_BINPUT after.
// NOTE(review): memo_id is not declared in this block (presumably a member
// counter), and the lines that write the id operand, the else, and the
// counter increment are missing from this extract - confirm. If memo_id is a
// uint32_t, the final assert can never fail.
233 void Pickler::pushMemoization(
const void* item) {
234 AT_ASSERT(item !=
nullptr);
235 if (memo_id <= std::numeric_limits<uint8_t>::max()) {
236 pushOpCode(OpCode::BINPUT);
240 pushOpCode(OpCode::LONG_BINPUT);
243 memo_[item] = memo_id;
244 AT_ASSERT(memo_id <= std::numeric_limits<uint32_t>::max());
248 void Pickler::pushMemoization(
const IValue& ivalue) {
249 pushMemoization(getPointer(ivalue));
// Serializes a generic list: EMPTY_LIST (memoized under this IValue), MARK,
// the items, then APPENDS to fold the items into the list.
// NOTE(review): the loop body (original lines 260-261) is missing from this
// extract; presumably each item is written with addIValue - confirm.
252 void Pickler::pushList(
const IValue& ivalue) {
// NOTE(review): `auto` (not `const auto&`) looks like it copies the container
// returned by toGenericListRef - confirm whether that is intended.
253 auto list = ivalue.toGenericListRef();
254 pushOpCode(OpCode::EMPTY_LIST);
255 pushMemoization(ivalue);
257 pushOpCode(OpCode::MARK);
259 for (
const auto& item : list) {
263 pushOpCode(OpCode::APPENDS);
// Serializes a tuple: MARK, the elements, then TUPLE to collect everything
// above the mark. The finished tuple is memoized afterwards, since it only
// exists on the pickle stack once TUPLE has run.
// NOTE(review): the loop body (original lines 272-273) is missing from this
// extract; presumably each element is written with addIValue - confirm.
266 void Pickler::pushTuple(
const IValue& ivalue) {
268 pushOpCode(OpCode::MARK);
// NOTE(review): `auto` looks like it copies the element container - confirm
// whether that is intended.
269 auto tuple = ivalue.toTuple()->elements();
271 for (
const auto& item : tuple) {
275 pushOpCode(OpCode::TUPLE);
276 pushMemoization(ivalue);
279 void Pickler::pushUint8(uint8_t value) {
280 const char* begin =
reinterpret_cast<const char*
>(&value);
281 stack_.insert(stack_.end(), begin, begin +
sizeof(uint8_t));
284 void Pickler::pushOpCode(OpCode value) {
285 const char* begin =
reinterpret_cast<const char*
>(&value);
286 stack_.insert(stack_.end(), begin, begin +
sizeof(OpCode));
289 void Pickler::pushUint32(uint32_t value) {
290 const char* begin =
reinterpret_cast<const char*
>(&value);
291 stack_.insert(stack_.end(), begin, begin +
sizeof(uint32_t));
294 void Pickler::pushInt32(int32_t value) {
295 const char* begin =
reinterpret_cast<const char*
>(&value);
296 stack_.insert(stack_.end(), begin, begin +
sizeof(int32_t));
// Returns the unpickled top-level list: exactly one value must be left on the
// stack, and its generic-list contents are returned by value.
// NOTE(review): original line 300 is missing from this extract; presumably it
// runs the unpickler before the stack is inspected - confirm.
299 std::vector<IValue> Unpickler::parse_ivalue_list() {
301 AT_ASSERT(stack_.size() == 1);
302 return stack_[0].toGenericListRef();
// Reads an 8-byte floating point value from the input at bytes_.
// NOTE(review): the call that consumes the three arguments below, the advance
// of bytes_ and the return statement are missing from this extract. The
// arguments span the 8 input bytes and the storage of `result`; the Pickler
// writes doubles byte-reversed, so this is presumably a reversing copy -
// confirm against the full file.
305 double Unpickler::readFloat() {
306 AT_ASSERT(
sizeof(
double) == 8);
// Bounds check. The strict `<` also rejects a float that ends exactly at
// end_ptr_; NOTE(review): presumably fine because a STOP opcode must follow,
// but confirm `<=` was not intended.
307 AT_ASSERT(bytes_ + 8 < end_ptr_);
312 reinterpret_cast<const char*>(bytes_),
313 reinterpret_cast<const char*>(bytes_ + 8),
314 reinterpret_cast<char*>(&result));
// Drives the unpickler: checks the PROTO header and its one-byte protocol
// number, then executes instructions until STOP is seen. Reaching the end of
// the buffer without a STOP is reported through AT_ERROR.
// NOTE(review): the protocol check call (original lines 324-328) and the
// statement taken on STOP are only partly visible in this extract. The
// readOpCode() call sits inside AT_ASSERT, so the header byte is only
// consumed if the macro always evaluates its argument - confirm.
320 void Unpickler::run() {
322 AT_ASSERT(readOpCode() == OpCode::PROTO);
323 uint8_t protocol = read<uint8_t>();
326 "Only Pickle protocol 2 is supported, found protocol = ",
// Main interpreter loop over the remaining input bytes.
329 while (bytes_ < end_ptr_) {
330 OpCode opcode = readInstruction();
331 if (opcode == OpCode::STOP) {
// Remembered so that readInstruction can inspect the previous opcode.
334 last_opcode_ = opcode;
337 AT_ERROR(
"Overran buffer while unpickling data, didn't find STOP opcode");
// Reads one opcode from the input and applies its effect to stack_, memo_ and
// marks_. An unrecognized opcode is reported through AT_ERROR.
// NOTE(review): many lines of this function (the switch header, break
// statements, several case bodies and the final return) are missing from this
// extract; the comments below describe only what is visible.
340 OpCode Unpickler::readInstruction() {
341 auto opcode = readOpCode();
// EMPTY_LIST: when the previous opcode was NEWOBJ, the stack top is read as a
// class id; an INTLIST class gets an empty int64_t vector. An empty IValue
// vector is pushed otherwise (the exact branching is partly missing here).
343 case OpCode::EMPTY_LIST: {
345 if (last_opcode_ == OpCode::NEWOBJ) {
348 static_cast<PicklerClass
>(uint8_t(stack_.back().toInt()));
349 if (cls == PicklerClass::INTLIST) {
350 stack_.emplace_back(std::vector<int64_t>());
353 stack_.emplace_back(std::vector<IValue>());
// EMPTY_TUPLE: push a tuple with no elements.
356 case OpCode::EMPTY_TUPLE: {
357 stack_.emplace_back(c10::ivalue::Tuple::create({}));
// BINPUT: read a one-byte memo id and append the current stack top to memo_.
// NOTE(review): push_back stores at index memo_.size(), which matches memo_id
// only if ids arrive sequentially starting at 0 - confirm.
359 case OpCode::BINPUT: {
360 size_t memo_id = read<uint8_t>();
361 if (memo_.size() <= memo_id) {
362 memo_.reserve(1 + 2 * memo_id);
364 memo_.push_back(stack_.back());
// Records the current stack height as a mark (the case label for this line is
// missing from the extract; presumably MARK).
368 marks_.push_back(stack_.size());
// BININT: read a signed 32-bit value and push it widened to int64_t.
370 case OpCode::BININT: {
371 int32_t value = read<int32_t>();
372 stack_.emplace_back(int64_t(value));
// BINUNICODE: read a 32-bit length, then push that many input bytes as a
// string. NOTE(review): the advance of bytes_ past the string is not visible
// in this extract - confirm.
374 case OpCode::BINUNICODE: {
375 uint32_t length = read<uint32_t>();
376 const char* characters =
reinterpret_cast<const char*
>(bytes_);
377 AT_ASSERT(bytes_ + length < end_ptr_);
379 stack_.emplace_back(std::string(characters, length));
// BINFLOAT: push the double decoded by readFloat.
381 case OpCode::BINFLOAT:
382 stack_.emplace_back(readFloat());
// TUPLE: everything from the last mark to the stack top becomes one tuple,
// which replaces those elements on the stack.
384 case OpCode::TUPLE: {
385 size_t start = marks_.back();
387 IValue tup = c10::ivalue::Tuple::create(
388 std::vector<IValue>(stack_.begin() + start, stack_.end()));
389 stack_.resize(start);
390 stack_.push_back(tup);
// EMPTY_DICT: push an empty unordered map.
392 case OpCode::EMPTY_DICT:
393 stack_.emplace_back(c10::ivalue::UnorderedMap());
// APPENDS: body missing from this extract (presumably delegates to readList).
395 case OpCode::APPENDS: {
// SETITEMS: the dict sits just below the last mark; entries above the mark
// alternate key, value and are stored into it, then the stack is truncated
// back to the mark. NOTE(review): stack_[i + 1] is unchecked, so an odd
// number of entries above the mark would index past the end - confirm this
// cannot happen.
398 case OpCode::SETITEMS: {
399 size_t start = marks_.back();
401 auto dict = stack_.at(start - 1).toGenericDict();
402 for (
size_t i = start; i < stack_.size(); i += 2) {
403 dict->elements()[stack_[i]] = stack_[i + 1];
405 stack_.resize(start);
// BINGET: push a copy of the memo entry selected by a one-byte id (bounds
// checked by at()).
407 case OpCode::BINGET: {
408 stack_.push_back(memo_.at(read<uint8_t>()));
// GLOBAL: the module name must be "__main__"; the class name that follows is
// mapped through getClass and pushed as a small integer id.
// NOTE(review): the first readString() runs inside AT_ASSERT, so the module
// name is only consumed if the macro always evaluates its argument - confirm.
412 case OpCode::GLOBAL: {
413 AT_ASSERT(readString() ==
"__main__");
415 stack_.emplace_back(static_cast<uint8_t>(getClass(readString())));
// NEWOBJ: body missing from this extract.
417 case OpCode::NEWOBJ: {
// BUILD: the stack top is the payload and the stack top read a second time is
// the class id (a pop between the two reads is presumably on a missing line).
// TENSOR payloads index into tensor_table_; INTLIST payloads are pushed as
// they are.
421 case OpCode::BUILD: {
422 auto setitem_data = stack_.back();
426 static_cast<PicklerClass
>(uint8_t(stack_.back().toInt()));
429 switch (class_name) {
430 case PicklerClass::TENSOR:
431 stack_.emplace_back(tensor_table_->at(setitem_data.toInt()));
433 case PicklerClass::INTLIST:
434 stack_.push_back(setitem_data);
437 AT_ERROR(
"Unknown pickler class id");
441 AT_ERROR(
"Unknown opcode for unpickling");
// Folds every stack entry above the last mark into the list stored just below
// the mark, then truncates the stack back to the mark. Int lists receive each
// entry converted with toInt(); generic lists receive the entries unchanged.
// NOTE(review): the else between the two branches and the removal of the
// consumed mark from marks_ are not visible in this extract - confirm.
446 void Unpickler::readList() {
447 size_t start = marks_.back();
// The target list is the element directly beneath the mark.
449 auto list_ivalue = stack_.at(start - 1);
450 if (list_ivalue.isIntList()) {
451 auto list = stack_.at(start - 1).toIntList();
452 auto num_elements = stack_.size() - start;
453 list->elements().reserve(num_elements);
454 for (
auto it = stack_.begin() + start; it != stack_.end(); ++it) {
455 list->elements().emplace_back(it->toInt());
458 auto list = stack_.at(start - 1).toGenericList();
459 list->elements().insert(
460 list->elements().end(), stack_.begin() + start, stack_.end());
462 stack_.resize(start);
// Reads a string starting at bytes_ and returns its first n characters.
// NOTE(review): the scan that determines n and advances bytes_ (original
// lines 468-483) is missing from this extract. The Pickler writes
// newline-terminated names, so the scan presumably stops at '\n' - confirm.
// The visible assert restricts each scanned character c to '0'..'z'.
466 std::string Unpickler::readString() {
467 const char* chars =
reinterpret_cast<const char*
>(bytes_);
476 AT_ASSERT(c >=
'0' && c <=
'z');
484 return std::string(chars, n);
487 OpCode Unpickler::readOpCode() {
488 return static_cast<OpCode
>(read<uint8_t>());