2
2
3
3
import torch
4
4
5
- from ..tensor import QBitsTensor , QTensor , absmax_scale , qbitsdtype
5
+ from ..tensor import QBitsTensor , QTensor , absmax_scale , qint2 , qint4 , qint8 , qtype
6
6
from .qmodule import QModuleMixin , register_qmodule
7
7
8
8
11
11
12
12
@register_qmodule (torch .nn .Linear )
13
13
class QLinear (QModuleMixin , torch .nn .Linear ):
14
- def __init__ (self , * args , weights : torch . dtype = torch . int8 , ** kwargs ):
14
+ def __init__ (self , * args , weights : qtype = qint8 , ** kwargs ):
15
15
super ().__init__ (* args , ** kwargs )
16
16
self .weights = weights
17
17
18
18
@classmethod
19
- def from_module (cls , module , weights = torch . int8 , activations : Optional [torch . dtype ] = None ):
19
+ def from_module (cls , module , weights = qint8 , activations : Optional [qtype ] = None ):
20
20
qmodule = cls (
21
21
module .in_features ,
22
22
module .out_features ,
@@ -36,12 +36,12 @@ def qweight(self):
36
36
if isinstance (self .weight , QTensor ):
37
37
return self .weight
38
38
# Quantize the weights per-axis
39
- if isinstance ( self .weights , torch . dtype ) :
39
+ if self .weights == qint8 :
40
40
wscale = absmax_scale (self .weight , axis = 0 )
41
- return QTensor .quantize (self .weight , itype = self .weights , scale = wscale )
42
- elif isinstance ( self .weights , qbitsdtype ):
43
- return QBitsTensor .quantize (self .weight , itype = self .weights , axis = 0 )
44
- raise ValueError ("Invalid quantized weights type" )
41
+ return QTensor .quantize (self .weight , qtype = self .weights , scale = wscale )
42
+ elif self .weights in ( qint2 , qint4 ):
43
+ return QBitsTensor .quantize (self .weight , qtype = self .weights , axis = 0 )
44
+ raise ValueError (f "Invalid quantized weights type { self . weights } " )
45
45
46
46
def qforward (self , input : torch .Tensor ) -> torch .Tensor :
47
47
if self .activations is not None and not isinstance (input , QTensor ):
0 commit comments