4 from typing
import Tuple, Optional
5 from torch
import Tensor
13 __constants__ = [
'scale',
'zero_point']
15 def __init__(self, other):
16 super(QuantizedLinear, self).__init__()
21 other.weight.clone().float())
22 self.
weight = torch.nn.Parameter(self.
weight, requires_grad=
False)
24 assert other.bias
is not None,
'QuantizedLinear requires a bias' 25 self.
bias = torch.nn.Parameter(other.bias.clone().float())
29 torch.fbgemm_pack_quantized_matrix(self.weight.clone(), self.weight.size(1), self.weight.size(0)))
31 @torch.jit.script_method
33 self.packed_tensor_ptr.set_(
34 torch.fbgemm_pack_quantized_matrix(
35 self.
weight, self.weight.size(1), self.weight.size(0)))
37 @torch.jit.script_method
39 self.packed_tensor_ptr.set_(
42 @torch.jit.script_method
43 def forward(self, input):
44 out = torch.fbgemm_linear_int8_weight(
47 return out.type_as(input)
50 repr =
'in_features={in_features}, out_features={out_features}, ' \
51 'scale={scale}, zero_point={zero_point}'.format(**self.__dict__)
57 __constants__ = [
'input_size',
'hidden_size',
'bias',
'scale_hh',
'scale_ih',
58 'zero_point_ih',
'zero_point_hh']
60 def __init__(self, other):
61 super(QuantizedRNNCellBase, self).__init__()
64 self.
bias = other.bias
66 raise ValueError(
"Quantized RNN cells require bias terms")
68 weight_ih, col_offsets_ih, self.scale_ih, self.
zero_point_ih = \
69 torch.fbgemm_linear_quantize_weight(other.weight_ih.clone().float())
70 self.register_buffer(
'weight_ih', weight_ih)
71 self.register_buffer(
'col_offsets_ih', col_offsets_ih)
72 weight_hh, col_offsets_hh, self.scale_hh, self.
zero_point_hh = \
73 torch.fbgemm_linear_quantize_weight(other.weight_hh.clone().float())
74 self.register_buffer(
'weight_hh', weight_hh)
75 self.register_buffer(
'col_offsets_hh', col_offsets_hh)
77 packed_ih = torch.fbgemm_pack_quantized_matrix(
78 self.weight_ih, self.weight_ih.size(1), self.weight_ih.size(0))
79 self.register_buffer(
'packed_ih', packed_ih)
80 packed_hh = torch.fbgemm_pack_quantized_matrix(
81 self.weight_hh, self.weight_hh.size(1), self.weight_hh.size(0))
82 self.register_buffer(
'packed_hh', packed_hh)
84 self.
bias_ih = torch.nn.Parameter(other.bias_ih.clone().float(), requires_grad=
False)
85 self.
bias_hh = torch.nn.Parameter(other.bias_hh.clone().float(), requires_grad=
False)
88 s =
'{input_size}, {hidden_size}' 89 if 'bias' in self.__dict__
and self.
bias is not True:
91 if 'nonlinearity' in self.__dict__
and self.nonlinearity !=
"tanh":
92 s +=
', nonlinearity={nonlinearity}' 93 return s.format(**self.__dict__)
95 @torch.jit.script_method
96 def check_forward_input(self, input):
99 "input has inconsistent input_size: got {}, expected {}".format(
102 @torch.jit.script_method
103 def check_forward_hidden(self, input, hx, hidden_label=''):
105 if input.size(0) != hx.size(0):
107 "Input batch size {} doesn't match hidden{} batch size {}".format(
108 input.size(0), hidden_label, hx.size(0)))
112 "hidden{} has inconsistent hidden_size: got {}, expected {}".format(
119 @torch.jit.script_method
121 self.packed_ih.set_(torch.fbgemm_pack_quantized_matrix(
122 self.weight_ih, self.weight_ih.size(1), self.weight_ih.size(0)))
124 torch.fbgemm_pack_quantized_matrix(
125 self.weight_hh, self.weight_hh.size(1), self.weight_hh.size(0)))
128 @torch.jit.script_method
137 __constants__ = [
'input_size',
'hidden_size',
'bias',
'scale_hh',
'scale_ih',
138 'zero_point_ih',
'zero_point_hh',
'nonlinearity']
140 def __init__(self, other):
141 super(QuantizedRNNCell, self).__init__(other)
144 @torch.jit.script_method
145 def forward(self, input, hx=None):
149 hx = torch.zeros(input.size(0), self.
hidden_size, dtype=input.dtype, device=input.device)
152 ret = _VF.quantized_rnn_tanh_cell(
153 input, hx, self.weight_ih, self.weight_hh, self.
bias_ih,
154 self.
bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih,
155 self.col_offsets_hh, self.scale_ih, self.scale_hh, self.
zero_point_ih,
159 ret = _VF.quantized_rnn_relu_cell(
160 input, hx, self.weight_ih, self.weight_hh, self.
bias_ih,
161 self.
bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih,
162 self.col_offsets_hh, self.scale_ih, self.scale_hh, self.
zero_point_ih,
173 def __init__(self, other):
174 super(QuantizedLSTMCell, self).__init__(other)
176 @torch.jit.script_method
177 def forward(self, input, hx=None):
181 zeros = torch.zeros(input.size(0), self.
hidden_size, dtype=input.dtype, device=input.device)
185 return _VF.quantized_lstm_cell(
186 input, hx, self.weight_ih, self.weight_hh, self.
bias_ih,
187 self.
bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih,
188 self.col_offsets_hh, self.scale_ih, self.scale_hh, self.
zero_point_ih,
194 def __init__(self, other):
195 super(QuantizedGRUCell, self).__init__(other)
197 @torch.jit.script_method
198 def forward(self, input, hx=None):
202 hx = torch.zeros(input.size(0), self.
hidden_size, dtype=input.dtype, device=input.device)
204 return _VF.quantized_gru_cell(
205 input, hx, self.weight_ih, self.weight_hh, self.
bias_ih,
206 self.
bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih,
207 self.col_offsets_hh, self.scale_ih, self.scale_hh, self.
zero_point_ih,
212 def quantize_rnn_cell_modules(module):
214 for name, mod
in module.named_modules():
217 new_mod = quantize_rnn_cell_modules(mod)
218 if new_mod
is not mod:
219 reassign[name] = new_mod
220 for name, mod
in reassign.items():
221 setattr(module, name, mod)
222 if isinstance(module, torch.nn.LSTMCell):
224 if isinstance(module, torch.nn.GRUCell):
226 if isinstance(module, torch.nn.RNNCell):
232 def quantize_linear_modules(module):
234 for name, mod
in module.named_modules():
237 new_mod = quantize_linear_modules(mod)
238 if new_mod
is not mod:
239 reassign[name] = new_mod
241 for name, mod
in reassign.items():
242 setattr(module, name, mod)
243 if isinstance(mod, torch.nn.Linear):
def check_forward_input(self, input)
def check_forward_hidden(self, input, hx, hidden_label='')
def annotate(the_type, the_value)