Caffe2 - Python API
A deep learning, cross platform ML framework
quantized.py
1 import torch
2 import copy
3 import numbers
4 from typing import Tuple, Optional
5 from torch import Tensor
6 from torch.jit import ScriptModule
7 
8 from torch.nn.utils.rnn import PackedSequence
9 from torch.nn import _VF
10 
11 
# Attribute names declared as constants for the enclosing class; both names
# are assigned in __init__ just below. The class header is not visible in this
# excerpt - presumably this list is read by torch.jit.ScriptModule; confirm.
13  __constants__ = ['scale', 'zero_point']
14 
def __init__(self, other):
    """Build an int8-weight replacement for the linear layer ``other``.

    Args:
        other: source layer. ``in_features``, ``out_features``, ``weight``
            and ``bias`` are read from it; ``bias`` must not be ``None``.

    Raises:
        AssertionError: if ``other.bias`` is ``None``.
    """
    # Validate before doing any work. An explicit raise is used rather than
    # an ``assert`` statement so the check also runs under ``python -O``;
    # the exception type and message are the ones callers already see.
    if other.bias is None:
        raise AssertionError('QuantizedLinear requires a bias')

    super(QuantizedLinear, self).__init__()
    self.in_features = other.in_features
    self.out_features = other.out_features

    # Quantize a float copy of the weight; the original tensor is not kept.
    weight, col_offsets, self.scale, self.zero_point = \
        torch.fbgemm_linear_quantize_weight(other.weight.clone().float())
    self.weight = torch.nn.Parameter(weight, requires_grad=False)
    self.col_offsets = torch.nn.Parameter(col_offsets, requires_grad=False)
    self.bias = torch.nn.Parameter(other.bias.clone().float())

    # Pack a copy of the quantized weight up front; forward() hands this
    # buffer to the fbgemm kernel together with the unpacked weight.
    self.register_buffer(
        'packed_tensor_ptr',
        torch.fbgemm_pack_quantized_matrix(
            self.weight.clone(), self.weight.size(1), self.weight.size(0)))
30 
# Packs ``self.weight`` again with fbgemm_pack_quantized_matrix and stores the
# result into the ``packed_tensor_ptr`` buffer in place.
@torch.jit.script_method
def _unpack(self):
    weight = self.weight
    repacked = torch.fbgemm_pack_quantized_matrix(
        weight, weight.size(1), weight.size(0))
    self.packed_tensor_ptr.set_(repacked)
36 
# Points ``packed_tensor_ptr`` at a uint8 zeros tensor created from an empty
# size list, replacing whatever it referenced before.
# NOTE(review): presumably this drops the packed weight data so it is not
# carried along with the module (e.g. when saving) - confirm.
@torch.jit.script_method
def _pack(self):
    no_dims = torch.jit.annotate(List[int], [])
    placeholder = torch.zeros(no_dims, dtype=torch.uint8).detach()
    self.packed_tensor_ptr.set_(placeholder)
41 
# Runs the fbgemm int8-weight linear kernel on a float view of ``input`` and
# converts the result back to the type of ``input``.
@torch.jit.script_method
def forward(self, input):
    result = torch.fbgemm_linear_int8_weight(
        input.float(),
        self.weight,
        self.packed_tensor_ptr,
        self.col_offsets,
        self.scale,
        self.zero_point,
        self.bias)
    return result.type_as(input)
48 
def extra_repr(self):
    """Return the feature sizes and quantization parameters as text.

    Returns:
        str: ``'in_features=..., out_features=..., scale=..., zero_point=...'``
        filled in from the attributes of the same names.
    """
    # Read the attributes directly instead of unpacking ``__dict__`` so the
    # lookup goes through normal attribute access, and avoid binding the
    # result to a local that would shadow the builtin ``repr``.
    return (
        'in_features={}, out_features={}, '
        'scale={}, zero_point={}'.format(
            self.in_features, self.out_features, self.scale, self.zero_point))
53 
54 
55 # Quantized RNN cell implementations
# Attribute names declared as constants for the enclosing class; each of them
# is assigned in __init__ just below. The class header is not visible in this
# excerpt - presumably this list is read by torch.jit.ScriptModule; confirm.
57  __constants__ = ['input_size', 'hidden_size', 'bias', 'scale_hh', 'scale_ih',
58  'zero_point_ih', 'zero_point_hh']
59 
def __init__(self, other):
    """Quantize the ``*_ih`` / ``*_hh`` weights of the RNN cell ``other``.

    Args:
        other: source cell. ``input_size``, ``hidden_size``, ``bias``,
            ``weight_ih``, ``weight_hh``, ``bias_ih`` and ``bias_hh`` are
            read from it.

    Raises:
        ValueError: if ``other.bias`` is falsy.
    """
    super(QuantizedRNNCellBase, self).__init__()
    self.input_size = other.input_size
    self.hidden_size = other.hidden_size
    self.bias = other.bias
    if not self.bias:
        raise ValueError("Quantized RNN cells require bias terms")

    # Quantize float copies of both weight matrices. Buffers are registered
    # in the same order as before: weight_ih, col_offsets_ih, weight_hh,
    # col_offsets_hh, then the two packed matrices.
    quantized_ih = torch.fbgemm_linear_quantize_weight(
        other.weight_ih.clone().float())
    q_weight_ih, offsets_ih, self.scale_ih, self.zero_point_ih = quantized_ih
    self.register_buffer('weight_ih', q_weight_ih)
    self.register_buffer('col_offsets_ih', offsets_ih)

    quantized_hh = torch.fbgemm_linear_quantize_weight(
        other.weight_hh.clone().float())
    q_weight_hh, offsets_hh, self.scale_hh, self.zero_point_hh = quantized_hh
    self.register_buffer('weight_hh', q_weight_hh)
    self.register_buffer('col_offsets_hh', offsets_hh)

    # Pack the registered quantized weights for the fbgemm kernels.
    w_ih = self.weight_ih
    self.register_buffer(
        'packed_ih',
        torch.fbgemm_pack_quantized_matrix(w_ih, w_ih.size(1), w_ih.size(0)))
    w_hh = self.weight_hh
    self.register_buffer(
        'packed_hh',
        torch.fbgemm_pack_quantized_matrix(w_hh, w_hh.size(1), w_hh.size(0)))

    # Biases are kept as float parameters with gradients disabled.
    self.bias_ih = torch.nn.Parameter(
        other.bias_ih.clone().float(), requires_grad=False)
    self.bias_hh = torch.nn.Parameter(
        other.bias_hh.clone().float(), requires_grad=False)
86 
def extra_repr(self):
    """Return ``'<input_size>, <hidden_size>'`` plus any non-default extras.

    ``bias`` is appended when it is present in the instance dict and is not
    ``True``; ``nonlinearity`` is appended when it is present and is not
    ``"tanh"``.
    """
    attrs = self.__dict__
    template = ['{input_size}, {hidden_size}']
    if 'bias' in attrs and self.bias is not True:
        template.append(', bias={bias}')
    if 'nonlinearity' in attrs and self.nonlinearity != "tanh":
        template.append(', nonlinearity={nonlinearity}')
    return ''.join(template).format(**attrs)
94 
# Raises RuntimeError when dimension 1 of ``input`` differs from
# ``self.input_size``.
@torch.jit.script_method
def check_forward_input(self, input):
    feature_count = input.size(1)
    if feature_count != self.input_size:
        raise RuntimeError(
            "input has inconsistent input_size: got {}, expected {}".format(
                feature_count, self.input_size))
101 
# Raises RuntimeError when ``hx`` and ``input`` disagree on dimension 0, or
# when dimension 1 of ``hx`` differs from ``self.hidden_size``.
# ``hidden_label`` is only used to build the error text.
@torch.jit.script_method
def check_forward_hidden(self, input, hx, hidden_label=''):
    # type: (Tensor, Tensor, str) -> None
    input_batch = input.size(0)
    hidden_batch = hx.size(0)
    if input_batch != hidden_batch:
        raise RuntimeError(
            "Input batch size {} doesn't match hidden{} batch size {}".format(
                input_batch, hidden_label, hidden_batch))

    hidden_width = hx.size(1)
    if hidden_width != self.hidden_size:
        raise RuntimeError(
            "hidden{} has inconsistent hidden_size: got {}, expected {}".format(
                hidden_label, hidden_width, self.hidden_size))
114 
# TODO: for some reason, decorating this with weak_script_method (see the
# commented-out decorator below) causes the module to be destroyed, which in
# turn frees the packed_ih object via its DataPtr deleter. This is bizarre
# and should probably get fixed.
# @torch._jit_internal.weak_script_method
#
# Packs weight_ih and weight_hh again and stores the results into the
# packed_ih / packed_hh buffers in place.
@torch.jit.script_method
def _unpack(self):
    w_ih = self.weight_ih
    self.packed_ih.set_(
        torch.fbgemm_pack_quantized_matrix(w_ih, w_ih.size(1), w_ih.size(0)))
    w_hh = self.weight_hh
    self.packed_hh.set_(
        torch.fbgemm_pack_quantized_matrix(w_hh, w_hh.size(1), w_hh.size(0)))
126 
# @torch._jit_internal.weak_script_method
#
# Points packed_ih and packed_hh at uint8 zeros tensors created from an empty
# size list. Each buffer gets its own freshly created tensor.
@torch.jit.script_method
def _pack(self):
    placeholder_ih = torch.zeros(
        torch.jit.annotate(List[int], []), dtype=torch.uint8).detach()
    self.packed_ih.set_(placeholder_ih)
    placeholder_hh = torch.zeros(
        torch.jit.annotate(List[int], []), dtype=torch.uint8).detach()
    self.packed_hh.set_(placeholder_hh)
134 
135 
# Attribute names declared as constants for the enclosing class. Compared with
# the list used by the RNN cell base above, this one adds 'nonlinearity',
# which __init__ below assigns and forward() compares against "tanh"/"relu".
137  __constants__ = ['input_size', 'hidden_size', 'bias', 'scale_hh', 'scale_ih',
138  'zero_point_ih', 'zero_point_hh', 'nonlinearity']
139 
def __init__(self, other):
    """Initialize from ``other`` and remember its ``nonlinearity``.

    Args:
        other: source cell; passed to the base initializer, and its
            ``nonlinearity`` attribute is copied onto this instance.
    """
    super(QuantizedRNNCell, self).__init__(other)
    # forward() selects the tanh or relu kernel based on this value.
    self.nonlinearity = other.nonlinearity
143 
# Validates ``input`` (and ``hx``, which defaults to zeros when None), then
# calls the quantized tanh or relu cell kernel chosen by ``self.nonlinearity``.
# Any other nonlinearity raises RuntimeError.
@torch.jit.script_method
def forward(self, input, hx=None):
    # type: (Tensor, Optional[Tensor]) -> Tensor
    self.check_forward_input(input)
    if hx is None:
        hx = torch.zeros(
            input.size(0), self.hidden_size,
            dtype=input.dtype, device=input.device)
    self.check_forward_hidden(input, hx, '')
    if self.nonlinearity == "tanh":
        new_hidden = _VF.quantized_rnn_tanh_cell(
            input, hx,
            self.weight_ih, self.weight_hh,
            self.bias_ih, self.bias_hh,
            self.packed_ih, self.packed_hh,
            self.col_offsets_ih, self.col_offsets_hh,
            self.scale_ih, self.scale_hh,
            self.zero_point_ih, self.zero_point_hh)
    elif self.nonlinearity == "relu":
        new_hidden = _VF.quantized_rnn_relu_cell(
            input, hx,
            self.weight_ih, self.weight_hh,
            self.bias_ih, self.bias_hh,
            self.packed_ih, self.packed_hh,
            self.col_offsets_ih, self.col_offsets_hh,
            self.scale_ih, self.scale_hh,
            self.zero_point_ih, self.zero_point_hh)
    else:
        # Dummy assignment so the result is defined on every branch.
        new_hidden = input  # TODO: remove when jit supports exception flow
        raise RuntimeError(
            "Unknown nonlinearity: {}".format(self.nonlinearity))
    return new_hidden
170 
171 
def __init__(self, other):
    """Initialize from the cell ``other`` by delegating to the base class."""
    super(QuantizedLSTMCell, self).__init__(other)
175 
# Validates ``input`` and both entries of ``hx`` (a pair that defaults to
# zeros when None), then returns the result of the quantized LSTM cell kernel.
@torch.jit.script_method
def forward(self, input, hx=None):
    # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
    self.check_forward_input(input)
    if hx is None:
        zero_state = torch.zeros(
            input.size(0), self.hidden_size,
            dtype=input.dtype, device=input.device)
        # The same zeros tensor is used for both entries of the pair.
        hx = (zero_state, zero_state)
    self.check_forward_hidden(input, hx[0], '[0]')
    self.check_forward_hidden(input, hx[1], '[1]')
    return _VF.quantized_lstm_cell(
        input, hx,
        self.weight_ih, self.weight_hh,
        self.bias_ih, self.bias_hh,
        self.packed_ih, self.packed_hh,
        self.col_offsets_ih, self.col_offsets_hh,
        self.scale_ih, self.scale_hh,
        self.zero_point_ih, self.zero_point_hh)
191 
192 
def __init__(self, other):
    """Initialize from the cell ``other`` by delegating to the base class."""
    super(QuantizedGRUCell, self).__init__(other)
196 
# Validates ``input`` and ``hx`` (which defaults to zeros when None), then
# returns the result of the quantized GRU cell kernel.
@torch.jit.script_method
def forward(self, input, hx=None):
    # type: (Tensor, Optional[Tensor]) -> Tensor
    self.check_forward_input(input)
    if hx is None:
        hx = torch.zeros(
            input.size(0), self.hidden_size,
            dtype=input.dtype, device=input.device)
    self.check_forward_hidden(input, hx, '')
    return _VF.quantized_gru_cell(
        input, hx,
        self.weight_ih, self.weight_hh,
        self.bias_ih, self.bias_hh,
        self.packed_ih, self.packed_hh,
        self.col_offsets_ih, self.col_offsets_hh,
        self.scale_ih, self.scale_hh,
        self.zero_point_ih, self.zero_point_hh)
210 
211 
def quantize_rnn_cell_modules(module):
    """Replace LSTMCell/GRUCell/RNNCell modules under ``module`` with quantized ones.

    Sub-modules are swapped in place on their parent. ``module`` itself
    cannot be swapped in place, so the caller must use the return value.

    Args:
        module: root ``torch.nn.Module`` to process.

    Returns:
        A new ``QuantizedLSTMCell`` / ``QuantizedGRUCell`` /
        ``QuantizedRNNCell`` built from ``module`` when ``module`` is an
        instance of the matching ``torch.nn`` cell class; otherwise
        ``module`` itself.
    """
    # Visit direct children only; deeper descendants are reached through the
    # recursive call. Each sub-tree is therefore walked once, and every name
    # passed to setattr is a plain attribute name with no dots.
    replacements = {}
    for name, child in module.named_children():
        quantized = quantize_rnn_cell_modules(child)
        if quantized is not child:
            replacements[name] = quantized
    for name, quantized in replacements.items():
        setattr(module, name, quantized)

    # Convert ``module`` itself. The loop variables above must not be used
    # here: for a cell that owns sub-modules they refer to a child.
    if isinstance(module, torch.nn.LSTMCell):
        return QuantizedLSTMCell(module)
    if isinstance(module, torch.nn.GRUCell):
        return QuantizedGRUCell(module)
    if isinstance(module, torch.nn.RNNCell):
        return QuantizedRNNCell(module)

    return module
230 
231 
def quantize_linear_modules(module):
    """Replace ``torch.nn.Linear`` modules under ``module`` with ``QuantizedLinear``.

    Sub-modules are swapped in place on their parent. ``module`` itself
    cannot be swapped in place, so the caller must use the return value.

    Args:
        module: root ``torch.nn.Module`` to process.

    Returns:
        A new ``QuantizedLinear`` built from ``module`` when ``module`` is a
        ``torch.nn.Linear``; otherwise ``module`` itself.
    """
    # Visit direct children only; deeper descendants are reached through the
    # recursive call. Each sub-tree is therefore walked once, and every name
    # passed to setattr is a plain attribute name with no dots.
    replacements = {}
    for name, child in module.named_children():
        quantized = quantize_linear_modules(child)
        if quantized is not child:
            replacements[name] = quantized
    for name, quantized in replacements.items():
        setattr(module, name, quantized)

    # Test and convert ``module`` itself. The loop variables above must not
    # be used here: for a module that owns sub-modules they refer to a child.
    if isinstance(module, torch.nn.Linear):
        return QuantizedLinear(module)
    return module
def check_forward_hidden(self, input, hx, hidden_label='')
Definition: quantized.py:103
def annotate(the_type, the_value)
Definition: __init__.py:1560