-
-
Notifications
You must be signed in to change notification settings - Fork 454
/
Copy pathmodel_quant.py
130 lines (116 loc) · 4.81 KB
/
model_quant.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import sys
import diffusers
from installer import install, log
# Lazily loaded quantization backends: each name stays None until the matching
# load_bnb() / load_quanto() / load_torchao() call imports the library successfully.
bnb = None
quanto = None
ao = None
def create_bnb_config(kwargs = None, allow_bnb: bool = True):
    """Build a bitsandbytes quantization config from the current options.

    `kwargs` is handed back untouched when `allow_bnb` is false, when
    `shared.opts.bnb_quantization` is empty or does not contain 'Model', or when
    the `bnb` module global is still None after calling `load_bnb()`.

    Args:
        kwargs: optional dict; when given, the config is stored under its
            'quantization_config' key (the dict is modified in place).
        allow_bnb: pass False to skip bitsandbytes for this call.

    Returns:
        `kwargs` (possibly updated), or the bare `diffusers.BitsAndBytesConfig`
        when quantization applies and `kwargs` is None.
    """
    from modules import shared, devices
    if len(shared.opts.bnb_quantization) == 0 or not allow_bnb:
        return kwargs
    if 'Model' not in shared.opts.bnb_quantization:
        return kwargs
    load_bnb()
    if bnb is None:
        return kwargs
    quant_type = shared.opts.bnb_quantization_type
    config = diffusers.BitsAndBytesConfig(
        load_in_8bit=quant_type == 'fp8',
        load_in_4bit=quant_type in ('nf4', 'fp4'),
        bnb_4bit_quant_storage=shared.opts.bnb_quantization_storage,
        bnb_4bit_quant_type=quant_type,
        bnb_4bit_compute_dtype=devices.dtype,
    )
    shared.log.debug(f'Quantization: module=all type=bnb dtype={shared.opts.bnb_quantization_type} storage={shared.opts.bnb_quantization_storage}')
    if kwargs is not None:
        kwargs['quantization_config'] = config
        return kwargs
    return config
def create_ao_config(kwargs = None, allow_ao: bool = True):
    """Build a torchao quantization config from the current options.

    `kwargs` is handed back untouched unless `allow_ao` is true,
    `shared.opts.torchao_quantization` is non-empty and contains 'Model',
    `shared.opts.torchao_quantization_mode` equals 'pre', and the `ao` module
    global is set after calling `load_torchao()`.

    Args:
        kwargs: optional dict; when given, the config is stored under its
            'quantization_config' key (the dict is modified in place).
        allow_ao: pass False to skip torchao for this call.

    Returns:
        `kwargs` (possibly updated), or the bare `diffusers.TorchAoConfig`
        when quantization applies and `kwargs` is None.
    """
    from modules import shared
    if len(shared.opts.torchao_quantization) == 0 or shared.opts.torchao_quantization_mode != 'pre' or not allow_ao:
        return kwargs
    if 'Model' not in shared.opts.torchao_quantization:
        return kwargs
    load_torchao()
    if ao is None:
        return kwargs
    # Override the diffusers availability probe so it always reports torchao as present.
    diffusers.utils.import_utils.is_torchao_available = lambda: True
    config = diffusers.TorchAoConfig(shared.opts.torchao_quantization_type)
    shared.log.debug(f'Quantization: module=all type=torchao dtype={shared.opts.torchao_quantization_type}')
    if kwargs is not None:
        kwargs['quantization_config'] = config
        return kwargs
    return config
def load_torchao(msg='', silent=False):
    """Install and import torchao on first use, caching the module in `ao`.

    Args:
        msg: optional prefix for the error line logged when the import fails;
            nothing is logged on failure when it is empty.
        silent: when True, an import failure returns None instead of re-raising.

    Returns:
        The imported torchao module, or None if the import failed and `silent` is set.

    Raises:
        Exception: whatever the import raised, when `silent` is False.
    """
    global ao # pylint: disable=global-statement
    if ao is not None:
        return ao
    install('torchao==0.7.0', quiet=True)
    try:
        import torchao
        ao = torchao
    except Exception as e:
        if len(msg) > 0:
            log.error(f"{msg} failed to import torchao: {e}")
        ao = None
        if not silent:
            raise
        return None
    # Diagnostics live outside the import guard: a shallow call stack or a missing
    # __version__ must not be reported as an import failure or discard the module.
    try:
        fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
    except ValueError:
        fn = 'unknown'
    version = getattr(ao, '__version__', 'unknown')
    log.debug(f'Quantization: type=torchao version={version} fn={fn}')
    return ao
def load_bnb(msg='', silent=False):
    """Install and import bitsandbytes on first use, caching the module in `bnb`.

    On success the diffusers import-utils flags are also set so that diffusers
    treats bitsandbytes as available.

    Args:
        msg: optional prefix for the error line logged when the import fails;
            nothing is logged on failure when it is empty.
        silent: when True, an import failure returns None instead of re-raising.

    Returns:
        The imported bitsandbytes module, or None if the import failed and
        `silent` is set.

    Raises:
        Exception: whatever the import raised, when `silent` is False.
    """
    global bnb # pylint: disable=global-statement
    if bnb is not None:
        return bnb
    install('bitsandbytes==0.45.0', quiet=True)
    try:
        import bitsandbytes
        bnb = bitsandbytes
        diffusers.utils.import_utils._bitsandbytes_available = True # pylint: disable=protected-access
        # NOTE(review): the version reported to diffusers is fixed at '0.43.3' although
        # 0.45.0 is installed above - presumably deliberate to satisfy a version check; confirm.
        diffusers.utils.import_utils._bitsandbytes_version = '0.43.3' # pylint: disable=protected-access
    except Exception as e:
        if len(msg) > 0:
            log.error(f"{msg} failed to import bitsandbytes: {e}")
        bnb = None
        if not silent:
            raise
        return None
    # Diagnostics live outside the import guard: a shallow call stack or a missing
    # __version__ must not be reported as an import failure or discard the module.
    try:
        fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
    except ValueError:
        fn = 'unknown'
    version = getattr(bnb, '__version__', 'unknown')
    log.debug(f'Quantization: type=bitsandbytes version={version} fn={fn}')
    return bnb
def load_quanto(msg='', silent=False):
    """Install and import optimum.quanto on first use, caching the module in `quanto`.

    After a successful import an error is logged (but the module is still
    returned) when `shared.opts.diffusers_offload_mode` is 'balanced' or
    'sequential'.

    Args:
        msg: optional prefix for the error line logged when the import fails;
            nothing is logged on failure when it is empty.
        silent: when True, an import failure returns None instead of re-raising.

    Returns:
        The imported optimum.quanto module, or None if the import failed and
        `silent` is set.

    Raises:
        Exception: whatever the import raised, when `silent` is False.
    """
    from modules import shared
    global quanto # pylint: disable=global-statement
    if quanto is not None:
        return quanto
    install('optimum-quanto==0.2.6', quiet=True)
    try:
        from optimum import quanto as optimum_quanto # pylint: disable=no-name-in-module
        quanto = optimum_quanto
    except Exception as e:
        if len(msg) > 0:
            log.error(f"{msg} failed to import optimum.quanto: {e}")
        quanto = None
        if not silent:
            raise
        return None
    # Diagnostics live outside the import guard: a shallow call stack or a missing
    # __version__ must not be reported as an import failure or discard the module.
    try:
        fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
    except ValueError:
        fn = 'unknown'
    version = getattr(quanto, '__version__', 'unknown')
    log.debug(f'Quantization: type=quanto version={version} fn={fn}')
    if shared.opts.diffusers_offload_mode in {'balanced', 'sequential'}:
        shared.log.error(f'Quantization: type=quanto offload={shared.opts.diffusers_offload_mode} not supported')
    return quanto
def get_quant(name):
    """Guess the quantization type from a model name or file name.

    Matching is case-insensitive. The tags 'qint8', 'qint4', 'fp8', 'fp4' and
    'nf4' are searched for anywhere in the name, in that priority order; if
    none is present, a '.gguf' suffix yields 'gguf'.

    Args:
        name: model name or path as a string; None or an empty string is
            treated as "no quantization".

    Returns:
        One of 'qint8', 'qint4', 'fp8', 'fp4', 'nf4', 'gguf' or 'none'.
    """
    if not name:
        return 'none'
    lowered = name.lower()
    for tag in ('qint8', 'qint4', 'fp8', 'fp4', 'nf4'):
        if tag in lowered:
            return tag
    # Compare the suffix on the lowered name so '.GGUF' is recognised like the tags above.
    if lowered.endswith('.gguf'):
        return 'gguf'
    return 'none'