forked from vladmandic/sdnext
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patht_gate.py
45 lines (40 loc) · 1.93 KB
/
t_gate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import gradio as gr
from modules import scripts, processing, shared, sd_models, devices
from installer import install
class Script(scripts.Script):
def title(self):
return 'T-Gate: Accelerate via Gating Attention'
def show(self, is_img2img):
return not is_img2img if shared.native else False
# return signature is array of gradio components
def ui(self, _is_img2img):
with gr.Row():
gr.HTML('<a href="https://github.com/HaozheLiu-ST/T-GATE">  T-Gate: Accelerate via Gating Attention</a><br>')
with gr.Row():
enabled = gr.Checkbox(label="Enabled", value=True)
with gr.Row():
gate_step = gr.Slider(minimum=1, maximum=50, step=1, label="Gate step", elem_id="t_gate_steps", value=10)
return [enabled, gate_step]
def run(self, p: processing.StableDiffusionProcessing, enabled, gate_step): # pylint: disable=arguments-differ
p.gate_step = min(gate_step, p.steps) if enabled else -1
if not enabled:
return None
install('tgate')
import tgate
if shared.sd_model_type == 'sd':
cls = tgate.TgateSDLoader
elif shared.sd_model_type == 'sdxl':
cls = tgate.TgateSDXLLoader
else:
shared.log.warning(f'T-Gate: pipeline={shared.sd_model_type} required=sd or sdxl')
return None
old_pipe = shared.sd_model
shared.sd_model = cls(shared.sd_model, gate_step=p.gate_step)
sd_models.copy_diffuser_options(shared.sd_model, old_pipe)
sd_models.move_model(shared.sd_model, devices.device) # move pipeline to device
sd_models.set_diffuser_options(shared.sd_model, vae=None, op='model')
shared.log.debug(f'T-Gate: pipeline={shared.sd_model.__class__.__name__} steps={p.gate_step}')
processed = processing.process_images(p)
shared.sd_model = old_pipe
del shared.sd_model.tgate
return processed