4 from typing 
import Tuple, Optional
     5 from torch 
import Tensor
    13     __constants__ = [
'scale', 
'zero_point']
    15     def __init__(self, other):
    16         super(QuantizedLinear, self).__init__()
    21             other.weight.clone().float())
    22         self.
weight = torch.nn.Parameter(self.
weight, requires_grad=
False)
    24         assert other.bias 
is not None, 
'QuantizedLinear requires a bias'    25         self.
bias = torch.nn.Parameter(other.bias.clone().float())
    29             torch.fbgemm_pack_quantized_matrix(self.weight.clone(), self.weight.size(1), self.weight.size(0)))
    31     @torch.jit.script_method
    33         self.packed_tensor_ptr.set_(
    34             torch.fbgemm_pack_quantized_matrix(
    35                 self.
weight, self.weight.size(1), self.weight.size(0)))
    37     @torch.jit.script_method
    39         self.packed_tensor_ptr.set_(
    42     @torch.jit.script_method
    43     def forward(self, input):
    44         out = torch.fbgemm_linear_int8_weight(
    47         return out.type_as(input)
    50         repr = 
'in_features={in_features}, out_features={out_features}, ' \
    51                'scale={scale}, zero_point={zero_point}'.format(**self.__dict__)
    57     __constants__ = [
'input_size', 
'hidden_size', 
'bias', 
'scale_hh', 
'scale_ih',
    58                      'zero_point_ih', 
'zero_point_hh']
    60     def __init__(self, other):
    61         super(QuantizedRNNCellBase, self).__init__()
    64         self.
bias = other.bias
    66             raise ValueError(
"Quantized RNN cells require bias terms")
    68         weight_ih, col_offsets_ih, self.scale_ih, self.
zero_point_ih = \
    69             torch.fbgemm_linear_quantize_weight(other.weight_ih.clone().float())
    70         self.register_buffer(
'weight_ih', weight_ih)
    71         self.register_buffer(
'col_offsets_ih', col_offsets_ih)
    72         weight_hh, col_offsets_hh, self.scale_hh, self.
zero_point_hh = \
    73             torch.fbgemm_linear_quantize_weight(other.weight_hh.clone().float())
    74         self.register_buffer(
'weight_hh', weight_hh)
    75         self.register_buffer(
'col_offsets_hh', col_offsets_hh)
    77         packed_ih = torch.fbgemm_pack_quantized_matrix(
    78             self.weight_ih, self.weight_ih.size(1), self.weight_ih.size(0))
    79         self.register_buffer(
'packed_ih', packed_ih)
    80         packed_hh = torch.fbgemm_pack_quantized_matrix(
    81             self.weight_hh, self.weight_hh.size(1), self.weight_hh.size(0))
    82         self.register_buffer(
'packed_hh', packed_hh)
    84         self.
bias_ih = torch.nn.Parameter(other.bias_ih.clone().float(), requires_grad=
False)
    85         self.
bias_hh = torch.nn.Parameter(other.bias_hh.clone().float(), requires_grad=
False)
    88         s = 
'{input_size}, {hidden_size}'    89         if 'bias' in self.__dict__ 
and self.
bias is not True:
    91         if 'nonlinearity' in self.__dict__ 
and self.nonlinearity != 
"tanh":
    92             s += 
', nonlinearity={nonlinearity}'    93         return s.format(**self.__dict__)
    95     @torch.jit.script_method
    96     def check_forward_input(self, input):
    99                 "input has inconsistent input_size: got {}, expected {}".format(
   102     @torch.jit.script_method
   103     def check_forward_hidden(self, input, hx, hidden_label=''):
   105         if input.size(0) != hx.size(0):
   107                 "Input batch size {} doesn't match hidden{} batch size {}".format(
   108                     input.size(0), hidden_label, hx.size(0)))
   112                 "hidden{} has inconsistent hidden_size: got {}, expected {}".format(
   119     @torch.jit.script_method
   121         self.packed_ih.set_(torch.fbgemm_pack_quantized_matrix(
   122             self.weight_ih, self.weight_ih.size(1), self.weight_ih.size(0)))
   124             torch.fbgemm_pack_quantized_matrix(
   125                 self.weight_hh, self.weight_hh.size(1), self.weight_hh.size(0)))
   128     @torch.jit.script_method
   137     __constants__ = [
'input_size', 
'hidden_size', 
'bias', 
'scale_hh', 
'scale_ih',
   138                      'zero_point_ih', 
'zero_point_hh', 
'nonlinearity']
   140     def __init__(self, other):
   141         super(QuantizedRNNCell, self).__init__(other)
   144     @torch.jit.script_method
   145     def forward(self, input, hx=None):
   149             hx = torch.zeros(input.size(0), self.
hidden_size, dtype=input.dtype, device=input.device)
   152             ret = _VF.quantized_rnn_tanh_cell(
   153                 input, hx, self.weight_ih, self.weight_hh, self.
bias_ih,
   154                 self.
bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih,
   155                 self.col_offsets_hh, self.scale_ih, self.scale_hh, self.
zero_point_ih,
   159             ret = _VF.quantized_rnn_relu_cell(
   160                 input, hx, self.weight_ih, self.weight_hh, self.
bias_ih,
   161                 self.
bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih,
   162                 self.col_offsets_hh, self.scale_ih, self.scale_hh, self.
zero_point_ih,
   173     def __init__(self, other):
   174         super(QuantizedLSTMCell, self).__init__(other)
   176     @torch.jit.script_method
   177     def forward(self, input, hx=None):
   181             zeros = torch.zeros(input.size(0), self.
hidden_size, dtype=input.dtype, device=input.device)
   185         return _VF.quantized_lstm_cell(
   186             input, hx, self.weight_ih, self.weight_hh, self.
bias_ih,
   187             self.
bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih,
   188             self.col_offsets_hh, self.scale_ih, self.scale_hh, self.
zero_point_ih,
   194     def __init__(self, other):
   195         super(QuantizedGRUCell, self).__init__(other)
   197     @torch.jit.script_method
   198     def forward(self, input, hx=None):
   202             hx = torch.zeros(input.size(0), self.
hidden_size, dtype=input.dtype, device=input.device)
   204         return _VF.quantized_gru_cell(
   205             input, hx, self.weight_ih, self.weight_hh, self.
bias_ih,
   206             self.
bias_hh, self.packed_ih, self.packed_hh, self.col_offsets_ih,
   207             self.col_offsets_hh, self.scale_ih, self.scale_hh, self.
zero_point_ih,
   212 def quantize_rnn_cell_modules(module):
   214     for name, mod 
in module.named_modules():
   217         new_mod = quantize_rnn_cell_modules(mod)
   218         if new_mod 
is not mod:
   219             reassign[name] = new_mod
   220     for name, mod 
in reassign.items():
   221         setattr(module, name, mod)
   222     if isinstance(module, torch.nn.LSTMCell):
   224     if isinstance(module, torch.nn.GRUCell):
   226     if isinstance(module, torch.nn.RNNCell):
   232 def quantize_linear_modules(module):
   234     for name, mod 
in module.named_modules():
   237         new_mod = quantize_linear_modules(mod)
   238         if new_mod 
is not mod:
   239             reassign[name] = new_mod
   241     for name, mod 
in reassign.items():
   242         setattr(module, name, mod)
   243     if isinstance(mod, torch.nn.Linear):
 
def check_forward_input(self, input)
 
def check_forward_hidden(self, input, hx, hidden_label='')
 
def annotate(the_type, the_value)