implement hypernetwork
ddPn08 committed May 29, 2023
1 parent c1e2428 commit 53d60b5
Showing 11 changed files with 137 additions and 28 deletions.
2 changes: 2 additions & 0 deletions models/vae/.gitignore
@@ -0,0 +1,2 @@
*
!.gitignore
9 changes: 7 additions & 2 deletions modules/diffusion/networks/__init__.py
@@ -3,16 +3,17 @@
from glob import glob
from typing import *

import losalina.hypernetwork
import torch

from api.events import event_handler
from api.events.generation import LoadResourceEvent
from api.models.diffusion import ImageGenerationOptions
from modules.diffusion.attentions import replace_attentions_for_hypernetwork
from modules.logger import logger
from modules.shared import ROOT_DIR

from . import lora, lyco
from .hypernetwork import hijack_hypernetwork

latest_networks: List[Tuple[str, torch.nn.Module]] = []

@@ -87,6 +88,9 @@ def load_network_modules(e: LoadResourceEvent):
elif module_type == "lyco":
filepath = find_network_filepath(basename, "lycoris")
network_module = lyco
elif module_type == "hypernet":
filepath = find_network_filepath(basename, "hypernetwork")
network_module = losalina.hypernetwork
else:
continue

@@ -97,6 +101,7 @@ def load_network_modules(e: LoadResourceEvent):
network, weights_sd = network_module.create_network_from_weights(
multiplier,
filepath,
e.pipe.vae,
e.pipe.text_encoder,
e.pipe.unet,
)
@@ -116,4 +121,4 @@ def load_network_modules(e: LoadResourceEvent):


def init():
replace_attentions_for_hypernetwork()
hijack_hypernetwork()
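
Note: the loader change above gives all three network backends the same entry point, create_network_from_weights(multiplier, filepath, vae, text_encoder, unet), with the vae argument newly threaded through. A minimal sketch of that dispatch, assuming a pipe object with vae/text_encoder/unet attributes and the models/<subdir> layout used by the module's find_network_filepath helper:

# Sketch only: mirrors the dispatch added in load_network_modules above.
# `pipe`, the subdirectory names, and the find_network_filepath import are
# assumptions for illustration, not part of this commit.
import losalina.hypernetwork

from modules.diffusion.networks import find_network_filepath, lora, lyco

BACKENDS = {
    "lora": ("lora", lora),
    "lyco": ("lycoris", lyco),
    "hypernet": ("hypernetwork", losalina.hypernetwork),
}


def load_one_network(module_type: str, basename: str, multiplier: float, pipe):
    subdir, network_module = BACKENDS[module_type]
    filepath = find_network_filepath(basename, subdir)
    # every backend now takes the same five positional arguments
    network, weights_sd = network_module.create_network_from_weights(
        multiplier,
        filepath,
        pipe.vae,
        pipe.text_encoder,
        pipe.unet,
    )
    return network, weights_sd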
50 changes: 50 additions & 0 deletions modules/diffusion/networks/hypernetwork/__init__.py
@@ -0,0 +1,50 @@
from losalina import hypernetwork

loaded_networks = []  # Hypernetwork instances registered via Hypernetwork.apply_to()


def apply_single_hypernetwork(
    hypernetwork: hypernetwork.Hypernetwork, hidden_states, encoder_hidden_states
):
    # derive key/value contexts from the hidden states and the text-encoder states
    context_k, context_v = hypernetwork.forward(hidden_states, encoder_hidden_states)
    return context_k, context_v


def apply_hypernetworks(context_k, context_v, layer=None):
    # no hypernetwork loaded: use the plain cross-attention context for key and value
    if len(loaded_networks) == 0:
        return context_v, context_v

    dtype = context_k.dtype
    for hypernetwork in loaded_networks:
        context_k, context_v = hypernetwork.forward(context_k, context_v)

    # keep the contexts in the dtype of the incoming hidden states
    context_k = context_k.to(dtype=dtype)
    context_v = context_v.to(dtype=dtype)

    return context_k, context_v


def apply_to(self: hypernetwork.Hypernetwork):
    # register this hypernetwork so the patched attention processors pick it up
    loaded_networks.append(self)


def restore(self: hypernetwork.Hypernetwork, *args):
    # detach all registered hypernetworks from the attention processors
    loaded_networks.clear()


def hijack_hypernetwork():
hypernetwork.Hypernetwork.apply_to = apply_to
hypernetwork.Hypernetwork.restore = restore

import diffusers.models.attention_processor

from . import attentions

# replace the forward function of the attention processors for the hypernetworks
diffusers.models.attention_processor.XFormersAttnProcessor.__call__ = (
attentions.xformers_forward
)
diffusers.models.attention_processor.SlicedAttnProcessor.__call__ = (
attentions.sliced_attn_forward
)
diffusers.models.attention_processor.AttnProcessor2_0.__call__ = (
attentions.v2_0_forward
)
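
Note on the new module: apply_to/restore no longer touch the attention modules directly; they only register or clear entries in the module-level loaded_networks list, which the patched processors consult on every call. A self-contained sketch of that flow, using a stand-in class instead of losalina's Hypernetwork (the stand-in and the tensor shapes are illustrative assumptions):

# Sketch of the registration flow; DummyHypernetwork is NOT part of this
# commit, it only stands in for losalina.hypernetwork.Hypernetwork.
import torch

from modules.diffusion.networks.hypernetwork import (
    apply_hypernetworks,
    loaded_networks,
)


class DummyHypernetwork:
    def forward(self, context_k, context_v):
        # a real hypernetwork derives new key/value contexts here
        return context_k, context_v


hidden_states = torch.randn(1, 64, 320)
encoder_hidden_states = torch.randn(1, 77, 768)

# nothing registered -> both contexts fall back to the encoder hidden states
k, v = apply_hypernetworks(hidden_states, encoder_hidden_states)
assert k is encoder_hidden_states and v is encoder_hidden_states

# Hypernetwork.apply_to() appends the network here; afterwards every patched
# attention processor routes its key/value context through it
loaded_networks.append(DummyHypernetwork())
k, v = apply_hypernetworks(hidden_states, encoder_hidden_states)
loaded_networks.clear()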
@@ -1,6 +1,8 @@
import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import (
Attention,
AttnProcessor2_0,
SlicedAttnProcessor,
XFormersAttnProcessor,
)
@@ -10,6 +12,8 @@
except:
xformers = None

from . import apply_hypernetworks


def xformers_forward(
self: XFormersAttnProcessor,
@@ -35,16 +39,7 @@ def xformers_forward(
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

# Apply hypernetwork if present
if hasattr(attn, "hypernetwork") and attn.hypernetwork is not None:
context_k, context_v = attn.hypernetwork.forward(
hidden_states, encoder_hidden_states
)
context_k = context_k.to(hidden_states.dtype)
context_v = context_v.to(hidden_states.dtype)
else:
context_k = encoder_hidden_states
context_v = encoder_hidden_states
context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)

key = attn.to_k(context_k)
value = attn.to_v(context_v)
@@ -96,16 +91,7 @@ def sliced_attn_forward(
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

# Apply hypernetwork if present
if hasattr(attn, "hypernetwork") and attn.hypernetwork is not None:
context_k, context_v = attn.hypernetwork.forward(
hidden_states, encoder_hidden_states
)
context_k = context_k.to(hidden_states.dtype)
context_v = context_v.to(hidden_states.dtype)
else:
context_k = encoder_hidden_states
context_v = encoder_hidden_states
context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)

key = attn.to_k(context_k)
value = attn.to_v(context_v)
@@ -145,6 +131,65 @@ def sliced_attn_forward(
return hidden_states


def v2_0_forward(
self: AttnProcessor2_0,
attn: Attention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
):
batch_size, sequence_length, _ = (
hidden_states.shape
if encoder_hidden_states is None
else encoder_hidden_states.shape
)
inner_dim = hidden_states.shape[-1]

if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(
attention_mask, sequence_length, batch_size
)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(
batch_size, attn.heads, -1, attention_mask.shape[-1]
)

query = attn.to_q(hidden_states)

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)

key = attn.to_k(context_k)
value = attn.to_v(context_v)

head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

hidden_states = hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
hidden_states = hidden_states.to(query.dtype)

# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states


def replace_attentions_for_hypernetwork():
import diffusers.models.attention_processor

@@ -154,3 +199,4 @@ def replace_attentions_for_hypernetwork():
diffusers.models.attention_processor.SlicedAttnProcessor.__call__ = (
sliced_attn_forward
)
diffusers.models.attention_processor.AttnProcessor2_0.__call__ = v2_0_forward
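
Design note on the hook: rather than attaching a hypernetwork attribute to every Attention module (the approach removed above), the commit swaps out __call__ on the processor classes themselves. Because Python resolves __call__ on the type, processors the pipeline has already instantiated pick up the patch immediately. A toy demonstration of that mechanism (the classes below are illustrative, not the diffusers processors):

# Toy example of class-level monkey-patching, the mechanism used by
# hijack_hypernetwork() / replace_attentions_for_hypernetwork() above.
class Processor:
    def __call__(self, hidden_states):
        return hidden_states


proc = Processor()          # instantiated before the patch
print(proc(1))              # 1


def patched_call(self, hidden_states):
    return hidden_states + 100


Processor.__call__ = patched_call
print(proc(1))              # 101 -- the existing instance now uses the patch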
3 changes: 2 additions & 1 deletion modules/diffusion/networks/lora.py
@@ -9,6 +9,7 @@
def create_network_from_weights(
multiplier: float,
file: str,
vae,
text_encoder,
unet,
weights_sd: torch.Tensor = None,
@@ -256,7 +257,7 @@ def set_multiplier(self, multiplier):
for lora in self.text_encoder_loras + self.unet_loras:
lora.multiplier = self.multiplier

def apply_to(self, **kwargs):
def apply_to(self):
for lora in self.text_encoder_loras + self.unet_loras:
lora.apply_to()

1 change: 1 addition & 0 deletions modules/diffusion/networks/lyco.py
@@ -12,6 +12,7 @@
def create_network_from_weights(
multiplier: float,
file: str,
vae,
text_encoder,
unet,
weights_sd: torch.Tensor = None,
3 changes: 1 addition & 2 deletions modules/diffusion/pipelines/diffusers.py
@@ -1,4 +1,3 @@
from typing import *
import gc
import inspect
import os
@@ -7,7 +6,6 @@

import numpy as np
import PIL.Image
from safetensors.torch import load_file
import torch
from diffusers import (
AutoencoderKL,
@@ -20,6 +18,7 @@
convert_from_ckpt,
)
from diffusers.utils import PIL_INTERPOLATION, numpy_to_pil, randn_tensor
from safetensors.torch import load_file
from tqdm import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

6 changes: 5 additions & 1 deletion modules/model.py
@@ -6,6 +6,7 @@
from typing import *

import torch
from packaging.version import Version

from api.models.diffusion import ImageGenerationOptions
from lib.diffusers.scheduler import SCHEDULERS, parser_schedulers_config
@@ -95,7 +96,10 @@ def activate(self):
torch_dtype=torch_dtype,
cache_dir=hf_diffusers_cache_dir(),
).to(device=device)
self.pipe.enable_attention_slicing()

if Version(torch.__version__) < Version("2"):
self.pipe.enable_attention_slicing()

if (
utils.is_installed("xformers")
and config.get("xformers")
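
The modules/model.py change gates attention slicing on the running torch version: with torch 2.x the patched AttnProcessor2_0 path already uses F.scaled_dot_product_attention, so slicing is only enabled on older installs. A quick check of how packaging compares real torch version strings (the version strings are examples):

# Quick check of the version gate added above; packaging handles local build tags.
from packaging.version import Version

for v in ["1.13.1+cu117", "2.0.1+cu118"]:
    print(v, Version(v) < Version("2"))
# 1.13.1+cu117 True   -> enable_attention_slicing() is called
# 2.0.1+cu118 False   -> rely on torch 2's scaled_dot_product_attention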
2 changes: 1 addition & 1 deletion modules/vae.py
@@ -8,6 +8,6 @@ def list_vae_models():
os.makedirs(dir, exist_ok=True)
return os.listdir(dir)


def resolve_vae(name: str):
return os.path.join(ROOT_DIR, "models", "vae", name)

2 changes: 1 addition & 1 deletion readme.md
@@ -30,7 +30,7 @@
- [ ] ~Acceleration using AITemplate~ -> Use [`VoltaML fast stable diffusion webui`](https://github.com/VoltaML/voltaML-fast-stable-diffusion)
- [x] ControlNet -> [plugin](https://github.com/ddPn08/radiata-controlnet-plugin)
- [x] Lora & Lycoris
- [ ] Hypernetwork
- [x] Hypernetwork
- [ ] Composable lora
- [ ] Latent couple

1 change: 1 addition & 0 deletions requirements/base.txt
@@ -5,6 +5,7 @@ accelerate==0.18.0
diffusers==0.16.1
transformers==4.28.1
sentencepiece~=0.1
losalina==1.0.0
omegaconf
safetensors

