Caffe2 - C++ API
A deep learning, cross platform ML framework
pickler.h
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <string>
#include <unordered_map>
#include <vector>

#include <ATen/core/ivalue.h>
#include <c10/util/ArrayRef.h>
#include <torch/csrc/utils/disallow_copy.h>
7 
8 namespace torch {
9 namespace jit {
10 
11 // See Python's pickletools.py for a detailed description of each of these codes
12 enum class OpCode : char {
13  MARK = '(',
14  STOP = '.',
15  POP = '0',
16  POP_MARK = '1',
17  DUP = '2',
18  FLOAT = 'F',
19  INT = 'I',
20  BININT = 'J',
21  BININT1 = 'K',
22  LONG = 'L',
23  BININT2 = 'M',
24  NONE = 'N',
25  PERSID = 'P',
26  BINPERSID = 'Q',
27  REDUCE = 'R',
28  STRING = 'S',
29  BINSTRING = 'T',
30  SHORT_BINSTRING = 'U',
31  UNICODE = 'V',
32  BINUNICODE = 'X',
33  APPEND = 'a',
34  BUILD = 'b',
35  GLOBAL = 'c',
36  DICT = 'd',
37  EMPTY_DICT = '}',
38  APPENDS = 'e',
39  GET = 'g',
40  BINGET = 'h',
41  INST = 'i',
42  LONG_BINGET = 'j',
43  LIST = 'l',
44  EMPTY_LIST = ']',
45  OBJ = 'o',
46  PUT = 'p',
47  BINPUT = 'q',
48  LONG_BINPUT = 'r',
49  SETITEM = 's',
50  TUPLE = 't',
51  EMPTY_TUPLE = ')',
52  SETITEMS = 'u',
53  BINFLOAT = 'G',
54 
55  // Protocol 2
56  PROTO = '\x80',
57  NEWOBJ = '\x81',
58  EXT1 = '\x82',
59  EXT2 = '\x83',
60  EXT4 = '\x84',
61  TUPLE1 = '\x85',
62  TUPLE2 = '\x86',
63  TUPLE3 = '\x87',
64  NEWTRUE = '\x88',
65  NEWFALSE = '\x89',
66  LONG1 = '\x8a',
67  LONG4 = '\x8b',
68 
69  // Protocol 3 (Python 3.x)
70  BINBYTES = 'B',
71  SHORT_BINBYTES = 'C',
72 
73  // Protocol 4
74  SHORT_BINUNICODE = '\x8c',
75  BINUNICODE8 = '\x8d',
76  BINBYTES8 = '\x8e',
77  EMPTY_SET = '\x8f',
78  ADDITEMS = '\x90',
79  FROZENSET = '\x91',
80  NEWOBJ_EX = '\x92',
81  STACK_GLOBAL = '\x93',
82  MEMOIZE = '\x94',
83  FRAME = '\x95'
84 };
85 
86 enum PicklerClass : uint8_t { TENSOR = 0, INTLIST = 1 };
87 
88 using ::c10::IValue;
89 
90 class Pickler {
91  TH_DISALLOW_COPY_AND_ASSIGN(Pickler);
92 
93  public:
94  Pickler(std::vector<at::Tensor>* tensor_table)
95  : tensor_table_(tensor_table) {}
96 
97  const std::vector<char>& stack();
98  void start();
99  void finish();
100  void addIValue(const IValue& ivalue);
101 
102  private:
103  void pushBinGet(uint32_t memo_id);
104  void pushMemoizedString(const IValue& ivalue);
105  void pushString(const std::string& string);
106  void pushTensor(const IValue& ivalue);
107  void pushDouble(const IValue& ivalue);
108  void pushMemoization(const void* item);
109  void pushMemoization(const IValue& ivalue);
110  void pushList(const IValue& ivalue);
111  void pushIntList(const IValue& ivalue);
112  void pushTuple(const IValue& ivalue);
113  void pushDict(const IValue& ivalue);
114  void pushClass(PicklerClass cls);
115  const void* getPointer(const IValue& ivalue);
116 
117  void pushUint8(uint8_t value);
118  void pushOpCode(OpCode value);
119  void pushUint32(uint32_t value);
120  void pushInt32(int32_t value);
121 
122  // Stack of opcodes/data
123  std::vector<char> stack_;
124 
125  // Memoization of IValues that have been written (index in table is used for
126  // BINPUT opcodes) to enable shared references
127  std::unordered_map<const void*, uint32_t> memo_;
128 
129  // External table of tensors to serialize
130  std::vector<at::Tensor>* tensor_table_;
131 
132  // TODO: only use this if necessary (add a pass to find all shared ivalues,
133  // and only memoize those)
134  uint32_t memo_id = 0;
135 };
136 
137 class Unpickler {
138  TH_DISALLOW_COPY_AND_ASSIGN(Unpickler);
139 
140  public:
141  Unpickler(
142  void* data,
143  size_t size,
144  const std::vector<at::Tensor>* tensor_table)
145  : bytes_(static_cast<const uint8_t*>(data)),
146  end_ptr_(bytes_ + size),
147  tensor_table_(tensor_table) {}
148 
149  std::vector<IValue> parse_ivalue_list();
150 
151  private:
152  // No arguments ensures that a template arugment must be specified
153  // so that the number of bytes read / type read is explicit
154  template <typename T>
155  T read() {
156  AT_CHECK(
157  bytes_ + sizeof(T) <= end_ptr_,
158  "Unpickler overran buffer while reading a value");
159  T item;
160  std::memcpy(&item, bytes_, sizeof(T));
161  bytes_ += sizeof(T);
162  return item;
163  }
164 
165  double readFloat();
166  void run();
167  OpCode readInstruction();
168  std::string readString();
169  OpCode readOpCode();
170  void readList();
171 
172  std::vector<IValue> stack_;
173  std::vector<IValue> memo_;
174  std::vector<size_t> marks_;
175  const uint8_t* bytes_;
176  const uint8_t* end_ptr_;
177  const std::vector<at::Tensor>* tensor_table_;
178  OpCode last_opcode_;
179 };
180 
181 } // namespace jit
182 } // namespace torch
Definition: jit_type.h:17