Caffe2 - C++ API
A deep learning, cross platform ML framework
pickler.cpp
1 #include <torch/csrc/jit/pickler.h>
2 
3 namespace torch {
4 namespace jit {
5 
6 using ::c10::IValue;
7 
8 PicklerClass getClass(const std::string& str) {
9  if (str == "TensorID") {
10  return PicklerClass::TENSOR;
11  } else if (str == "IntList") {
12  return PicklerClass::INTLIST;
13  }
14  AT_ERROR("Unknown class name for unpickler: ", str);
15 }
16 
17 const std::string& getClassName(PicklerClass cls) {
18  static const std::string tensor_class("TensorID\n");
19  static const std::string intlist_class("IntList\n");
20  switch (cls) {
21  case PicklerClass::TENSOR:
22  return tensor_class;
23  case PicklerClass::INTLIST:
24  return intlist_class;
25  default:
26  AT_ERROR("Unknown class for pickler");
27  }
28 }
29 
// Returns the newline-terminated module name that is written in front of the
// class name in every GLOBAL record ("__main__\n").
const std::string& getModuleName() {
  static const std::string kModuleName = "__main__\n";
  return kModuleName;
}
34 
35 const std::vector<char>& Pickler::stack() {
36  return stack_;
37 }
38 
39 void Pickler::start() {
40  pushOpCode(OpCode::PROTO);
41  pushUint8(2);
42 
43  // All attributes get pushed into a list and their indices saved in the
44  // module def
45  pushOpCode(OpCode::EMPTY_LIST);
46  pushOpCode(OpCode::MARK);
47 }
48 
49 void Pickler::finish() {
50  pushOpCode(OpCode::APPENDS);
51  pushOpCode(OpCode::STOP);
52 }
53 
54 void Pickler::addIValue(const IValue& ivalue) {
55  // Check if reference ivalue has been saved before
56  const void* ivalue_ptr = getPointer(ivalue);
57  if (ivalue_ptr) {
58  auto memo_entry = memo_.find(ivalue_ptr);
59  if (memo_entry != memo_.end()) {
60  // This value has already been pushed, just do a BINGET
61  pushBinGet(memo_entry->second);
62  return;
63  }
64  }
65 
66  if (ivalue.isTensor()) {
67  pushTensor(ivalue);
68  } else if (ivalue.isTuple()) {
69  pushTuple(ivalue);
70  } else if (ivalue.isDouble()) {
71  pushDouble(ivalue);
72  } else if (ivalue.isInt()) {
73  // TODO: use BININT1/BININT2/LONG if possible/necessary
74  AT_ASSERT(
75  ivalue.toInt() <= std::numeric_limits<int32_t>::max() &&
76  ivalue.toInt() >= std::numeric_limits<int32_t>::min());
77  pushOpCode(OpCode::BININT);
78  pushInt32(ivalue.toInt());
79  } else if (ivalue.isBool()) {
80  if (ivalue.toBool()) {
81  pushOpCode(OpCode::NEWTRUE);
82  } else {
83  pushOpCode(OpCode::NEWFALSE);
84  }
85  } else if (ivalue.isString()) {
86  pushMemoizedString(ivalue);
87  } else if (ivalue.isGenericList()) {
88  pushList(ivalue);
89  } else if (ivalue.isGenericDict()) {
90  pushDict(ivalue);
91  } else if (ivalue.isNone()) {
92  pushOpCode(OpCode::NONE);
93  } else if (ivalue.isIntList()) {
94  pushIntList(ivalue);
95  } else {
96  AT_ERROR("Unknown IValue type for pickling: ", ivalue.tagKind());
97  }
98 }
99 
100 const void* Pickler::getPointer(const IValue& ivalue) {
101  if (ivalue.isGenericDict()) {
102  return &(ivalue.toGenericDictRef());
103  } else if (ivalue.isGenericList()) {
104  return &(ivalue.toGenericListRef());
105  } else if (ivalue.isTuple()) {
106  return &(ivalue.toTuple()->elements());
107  } else if (ivalue.isString()) {
108  return &(ivalue.toStringRef());
109  } else if (ivalue.isIntList()) {
110  return &(ivalue.toIntListRef());
111  }
112 
113  return nullptr;
114 }
115 
116 void Pickler::pushBinGet(uint32_t memo_id) {
117  if (memo_id <= std::numeric_limits<uint8_t>::max()) {
118  pushOpCode(OpCode::BINGET);
119  pushUint8(memo_id);
120  } else {
121  // Memoized too many items, issue a LONG_BINGET instead
122  pushOpCode(OpCode::LONG_BINGET);
123  pushUint32(memo_id);
124  }
125 }
126 
127 void Pickler::pushMemoizedString(const IValue& ivalue) {
128  const auto& string = ivalue.toStringRef();
129 
130  pushOpCode(OpCode::BINUNICODE);
131  pushUint32(string.size());
132  pushString(string);
133  pushMemoization(ivalue);
134 }
135 
136 void Pickler::pushString(const std::string& string) {
137  stack_.insert(stack_.end(), string.begin(), string.end());
138 }
139 
140 void Pickler::pushClass(PicklerClass cls) {
141  const auto& name = getClassName(cls);
142  // Write it to the tensor table
143  auto memo_entry = memo_.find(&name);
144  if (memo_entry == memo_.end()) {
145  pushOpCode(OpCode::GLOBAL);
146  // Module name + "\n"
147  pushString(getModuleName());
148  // Class name + "\n"
149  pushString(name);
150  pushMemoization((void*)&name);
151  } else {
152  pushBinGet(memo_entry->second);
153  }
154 
155  pushOpCode(OpCode::EMPTY_TUPLE);
156  pushOpCode(OpCode::NEWOBJ);
157 }
158 
159 void Pickler::pushTensor(const IValue& ivalue) {
160  pushClass(PicklerClass::TENSOR);
161 
162  tensor_table_->push_back(ivalue.toTensor());
163  auto tensor_id = tensor_table_->size() - 1;
164  pushOpCode(OpCode::BININT);
165  pushUint32(tensor_id);
166 
167  pushOpCode(OpCode::BUILD);
168 }
169 
170 void Pickler::pushIntList(const IValue& ivalue) {
171  pushClass(PicklerClass::INTLIST);
172 
173  pushOpCode(OpCode::EMPTY_LIST);
174  pushMemoization(ivalue);
175  pushOpCode(OpCode::MARK);
176 
177  for (const auto& item : ivalue.toIntListRef()) {
178  addIValue(item);
179  }
180 
181  pushOpCode(OpCode::APPENDS);
182  pushOpCode(OpCode::BUILD);
183 }
184 
185 void Pickler::pushDouble(const IValue& ivalue) {
186  double value = ivalue.toDouble();
187  AT_ASSERT(sizeof(double) == 8);
188  char* bytes = reinterpret_cast<char*>(&value);
189 
190  pushOpCode(OpCode::BINFLOAT);
191  for (size_t i = 0; i < 8; ++i) {
192  pushUint8(bytes[8 - i - 1]);
193  }
194 }
195 
196 using ivalue_pair = std::pair<IValue, IValue>;
197 
199  bool operator()(const ivalue_pair& lhs, const ivalue_pair& rhs) const {
200  if (lhs.first.isString()) {
201  return lhs.first.toStringRef() < rhs.first.toStringRef();
202  }
203  if (lhs.first.isInt()) {
204  return lhs.first.toInt() < rhs.first.toInt();
205  }
206  if (lhs.first.isDouble()) {
207  return lhs.first.toDouble() < rhs.first.toDouble();
208  }
209  AT_ERROR("Uncomparable IValue types");
210  }
211 };
212 
213 void Pickler::pushDict(const IValue& ivalue) {
214  auto dict = ivalue.toGenericDictRef();
215 
216  pushOpCode(OpCode::EMPTY_DICT);
217  pushMemoization(ivalue);
218 
219  pushOpCode(OpCode::MARK);
220 
221  // Sort the dict for deterministic keys
222  std::vector<std::pair<IValue, IValue>> dict_items(dict.begin(), dict.end());
223  std::sort(dict_items.begin(), dict_items.end(), IValuePairComparator());
224 
225  for (const auto& pair : dict_items) {
226  addIValue(pair.first);
227  addIValue(pair.second);
228  }
229 
230  pushOpCode(OpCode::SETITEMS);
231 }
232 
233 void Pickler::pushMemoization(const void* item) {
234  AT_ASSERT(item != nullptr);
235  if (memo_id <= std::numeric_limits<uint8_t>::max()) {
236  pushOpCode(OpCode::BINPUT);
237  pushUint8(memo_id);
238  } else {
239  // Memoized too many items, issue a LONG_BINPUT instead
240  pushOpCode(OpCode::LONG_BINPUT);
241  pushUint32(memo_id);
242  }
243  memo_[item] = memo_id;
244  AT_ASSERT(memo_id <= std::numeric_limits<uint32_t>::max());
245  ++memo_id;
246 }
247 
248 void Pickler::pushMemoization(const IValue& ivalue) {
249  pushMemoization(getPointer(ivalue));
250 }
251 
252 void Pickler::pushList(const IValue& ivalue) {
253  auto list = ivalue.toGenericListRef();
254  pushOpCode(OpCode::EMPTY_LIST);
255  pushMemoization(ivalue);
256 
257  pushOpCode(OpCode::MARK);
258 
259  for (const auto& item : list) {
260  addIValue(item);
261  }
262 
263  pushOpCode(OpCode::APPENDS);
264 }
265 
266 void Pickler::pushTuple(const IValue& ivalue) {
267  // TODO: Small tuple unrolling (e.g. TUPLE3)
268  pushOpCode(OpCode::MARK);
269  auto tuple = ivalue.toTuple()->elements();
270 
271  for (const auto& item : tuple) {
272  addIValue(item);
273  }
274 
275  pushOpCode(OpCode::TUPLE);
276  pushMemoization(ivalue);
277 }
278 
279 void Pickler::pushUint8(uint8_t value) {
280  const char* begin = reinterpret_cast<const char*>(&value);
281  stack_.insert(stack_.end(), begin, begin + sizeof(uint8_t));
282 }
283 
284 void Pickler::pushOpCode(OpCode value) {
285  const char* begin = reinterpret_cast<const char*>(&value);
286  stack_.insert(stack_.end(), begin, begin + sizeof(OpCode));
287 }
288 
289 void Pickler::pushUint32(uint32_t value) {
290  const char* begin = reinterpret_cast<const char*>(&value);
291  stack_.insert(stack_.end(), begin, begin + sizeof(uint32_t));
292 }
293 
294 void Pickler::pushInt32(int32_t value) {
295  const char* begin = reinterpret_cast<const char*>(&value);
296  stack_.insert(stack_.end(), begin, begin + sizeof(int32_t));
297 }
298 
299 std::vector<IValue> Unpickler::parse_ivalue_list() {
300  run();
301  AT_ASSERT(stack_.size() == 1);
302  return stack_[0].toGenericListRef();
303 }
304 
305 double Unpickler::readFloat() {
306  AT_ASSERT(sizeof(double) == 8);
307  AT_ASSERT(bytes_ + 8 < end_ptr_);
308  double result;
309 
310  // Pickle floats are big endian, so reverse the bytes
311  std::reverse_copy(
312  reinterpret_cast<const char*>(bytes_),
313  reinterpret_cast<const char*>(bytes_ + 8),
314  reinterpret_cast<char*>(&result));
315 
316  bytes_ += 8;
317  return result;
318 }
319 
320 void Unpickler::run() {
321  // Expect a PROTO opcode and protocol number at the start of blob
322  AT_ASSERT(readOpCode() == OpCode::PROTO);
323  uint8_t protocol = read<uint8_t>();
324  AT_CHECK(
325  protocol == 2,
326  "Only Pickle protocol 2 is supported, found protocol = ",
327  protocol);
328 
329  while (bytes_ < end_ptr_) {
330  OpCode opcode = readInstruction();
331  if (opcode == OpCode::STOP) {
332  return;
333  }
334  last_opcode_ = opcode;
335  }
336 
337  AT_ERROR("Overran buffer while unpickling data, didn't find STOP opcode");
338 }
339 
340 OpCode Unpickler::readInstruction() {
341  auto opcode = readOpCode();
342  switch (opcode) {
343  case OpCode::EMPTY_LIST: {
344  // Look back to see if the last opcode was an IntList class
345  if (last_opcode_ == OpCode::NEWOBJ) {
346  // It's a list specialization, the enum ID of which is on the stack
347  PicklerClass cls =
348  static_cast<PicklerClass>(uint8_t(stack_.back().toInt()));
349  if (cls == PicklerClass::INTLIST) {
350  stack_.emplace_back(std::vector<int64_t>());
351  }
352  } else {
353  stack_.emplace_back(std::vector<IValue>());
354  }
355  } break;
356  case OpCode::EMPTY_TUPLE: {
357  stack_.emplace_back(c10::ivalue::Tuple::create({}));
358  } break;
359  case OpCode::BINPUT: {
360  size_t memo_id = read<uint8_t>();
361  if (memo_.size() <= memo_id) {
362  memo_.reserve(1 + 2 * memo_id);
363  }
364  memo_.push_back(stack_.back());
365  } break;
366  case OpCode::MARK: {
367  // Mark location of the container ivalue in the stack
368  marks_.push_back(stack_.size());
369  } break;
370  case OpCode::BININT: {
371  int32_t value = read<int32_t>();
372  stack_.emplace_back(int64_t(value));
373  } break;
374  case OpCode::BINUNICODE: {
375  uint32_t length = read<uint32_t>();
376  const char* characters = reinterpret_cast<const char*>(bytes_);
377  AT_ASSERT(bytes_ + length < end_ptr_);
378  bytes_ += length;
379  stack_.emplace_back(std::string(characters, /*n=*/length));
380  } break;
381  case OpCode::BINFLOAT:
382  stack_.emplace_back(readFloat());
383  break;
384  case OpCode::TUPLE: {
385  size_t start = marks_.back();
386  marks_.pop_back();
387  IValue tup = c10::ivalue::Tuple::create(
388  std::vector<IValue>(stack_.begin() + start, stack_.end()));
389  stack_.resize(start);
390  stack_.push_back(tup);
391  } break;
392  case OpCode::EMPTY_DICT:
393  stack_.emplace_back(c10::ivalue::UnorderedMap());
394  break;
395  case OpCode::APPENDS: {
396  readList();
397  } break;
398  case OpCode::SETITEMS: {
399  size_t start = marks_.back();
400  marks_.pop_back();
401  auto dict = stack_.at(start - 1).toGenericDict();
402  for (size_t i = start; i < stack_.size(); i += 2) {
403  dict->elements()[stack_[i]] = stack_[i + 1];
404  }
405  stack_.resize(start);
406  } break;
407  case OpCode::BINGET: {
408  stack_.push_back(memo_.at(read<uint8_t>()));
409  } break;
410  case OpCode::STOP:
411  break;
412  case OpCode::GLOBAL: {
413  AT_ASSERT(readString() == "__main__");
414  // Push class name to stack
415  stack_.emplace_back(static_cast<uint8_t>(getClass(readString())));
416  } break;
417  case OpCode::NEWOBJ: {
418  // pop empty tuple
419  stack_.pop_back();
420  } break;
421  case OpCode::BUILD: {
422  auto setitem_data = stack_.back();
423  stack_.pop_back();
424 
425  auto class_name =
426  static_cast<PicklerClass>(uint8_t(stack_.back().toInt()));
427  stack_.pop_back();
428 
429  switch (class_name) {
430  case PicklerClass::TENSOR:
431  stack_.emplace_back(tensor_table_->at(setitem_data.toInt()));
432  break;
433  case PicklerClass::INTLIST:
434  stack_.push_back(setitem_data);
435  break;
436  default:
437  AT_ERROR("Unknown pickler class id");
438  }
439  } break;
440  default:
441  AT_ERROR("Unknown opcode for unpickling");
442  }
443  return opcode;
444 }
445 
446 void Unpickler::readList() {
447  size_t start = marks_.back();
448  marks_.pop_back();
449  auto list_ivalue = stack_.at(start - 1);
450  if (list_ivalue.isIntList()) {
451  auto list = stack_.at(start - 1).toIntList();
452  auto num_elements = stack_.size() - start;
453  list->elements().reserve(num_elements);
454  for (auto it = stack_.begin() + start; it != stack_.end(); ++it) {
455  list->elements().emplace_back(it->toInt());
456  }
457  } else {
458  auto list = stack_.at(start - 1).toGenericList();
459  list->elements().insert(
460  list->elements().end(), stack_.begin() + start, stack_.end());
461  }
462  stack_.resize(start);
463 }
464 
465 // Read a newline terminated string
466 std::string Unpickler::readString() {
467  const char* chars = reinterpret_cast<const char*>(bytes_);
468  size_t n = 0;
469  while (true) {
470  char c = chars[n];
471  if (c == '\n') {
472  break;
473  }
474 
475  // Simple check just in case there is no terminating '\n'
476  AT_ASSERT(c >= '0' && c <= 'z');
477 
478  // Increment after to exclude newline from string
479  ++n;
480  }
481 
482  // Increment by string length + newline char
483  bytes_ += n + 1;
484  return std::string(chars, n);
485 }
486 
487 OpCode Unpickler::readOpCode() {
488  return static_cast<OpCode>(read<uint8_t>());
489 }
490 
491 } // namespace jit
492 } // namespace torch
Definition: jit_type.h:17