Commit

improved api system
ddPn08 committed May 28, 2023
1 parent 039019e commit 0e38a1e
Showing 23 changed files with 284 additions and 132 deletions.
168 changes: 168 additions & 0 deletions api/diffusion/pipelines/diffusers.py
@@ -0,0 +1,168 @@
from dataclasses import dataclass
import inspect
from typing import *
import numpy as np

import torch
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    UNet2DConditionModel,
)
from transformers import CLIPTextModel, CLIPTokenizer
import PIL.Image

from api.models.diffusion import ImageGenerationOptions
from api.plugin import get_plugin_id


@dataclass
class PipeSession:
    plugin_data: Dict[str, Any]
    opts: ImageGenerationOptions


class DiffusersPipelineModel:
    __mode__ = "diffusers"

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_id: str,
        use_auth_token: Optional[str] = None,
        torch_dtype: torch.dtype = torch.float32,
        cache_dir: Optional[str] = None,
        device: Optional[torch.device] = None,
        subfolder: Optional[str] = None,
    ):
        pass

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: DDPMScheduler,
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ):
        self.vae: AutoencoderKL
        self.text_encoder: CLIPTextModel
        self.tokenizer: CLIPTokenizer
        self.unet: UNet2DConditionModel
        self.scheduler: DDPMScheduler
        self.device: torch.device
        self.dtype: torch.dtype
        self.session: PipeSession
        pass

    def get_plugin_data(self):
        id = get_plugin_id(inspect.stack()[1])
        return self.session.plugin_data[id]

    def set_plugin_data(self, data):
        id = get_plugin_id(inspect.stack()[1])
        self.session.plugin_data[id] = data

    def to(self, device: torch.device = None, dtype: torch.dtype = None):
        pass

    def enterers(self):
        pass

    def load_resources(
        self,
        image_height: int,
        image_width: int,
        batch_size: int,
        num_inference_steps: int,
    ):
        pass

    def get_timesteps(self, num_inference_steps: int, strength: Optional[float]):
        pass

    def prepare_extra_step_kwargs(self, generator: torch.Generator, eta):
        pass

    def preprocess_image(self, image: PIL.Image.Image, height: int, width: int):
        pass

    def _encode_prompt(
        self,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int,
        do_classifier_free_guidance: bool,
        negative_prompt: Optional[Union[str, List[str]]] = "",
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    ):
        pass

    def prepare_latents(
        self,
        vae_scale_factor: int,
        unet_in_channels: int,
        image: Optional[torch.Tensor],
        timestep: torch.Tensor,
        batch_size: int,
        height: int,
        width: int,
        dtype: torch.dtype,
        generator: torch.Generator,
        latents: torch.Tensor = None,
    ):
        pass

    def denoise_latent(
        self,
        latents: torch.Tensor,
        timesteps: torch.Tensor,
        num_inference_steps: int,
        guidance_scale: float,
        do_classifier_free_guidance: bool,
        prompt_embeds: torch.Tensor,
        extra_step_kwargs: Dict[str, Any],
        callback: Optional[Callable],
        callback_steps: int,
        cross_attention_kwargs: Dict[str, Any],
    ):
        pass

    def decode_latents(self, latents: torch.Tensor):
        pass

    def decode_images(self, image: np.ndarray):
        pass

    def create_output(self, latents: torch.Tensor, output_type: str, return_dict: bool):
        pass

    def __call__(
        self,
        opts: ImageGenerationOptions,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        eta: float = 0.0,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        plugin_data: Optional[Dict[str, Any]] = {},
    ):
        pass

    def enable_xformers_memory_efficient_attention(
        self, attention_op: Optional[Callable] = None
    ):
        pass

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        pass
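DiffusersPipelineModel only declares the interface here (every method body is a stub), with per-request state held in PipeSession. As a rough usage sketch, assuming a concrete subclass implements the stubs and that ImageGenerationOptions can be built from just a prompt (the model id and option values below are illustrative, not part of this commit):

import torch

from api.diffusion.pipelines.diffusers import DiffusersPipelineModel
from api.models.diffusion import ImageGenerationOptions

# Illustrative only: a concrete implementation of the interface is assumed.
pipe = DiffusersPipelineModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # hypothetical model id
    torch_dtype=torch.float16,
)
pipe.to(device=torch.device("cuda"), dtype=torch.float16)

# The remaining ImageGenerationOptions fields are assumed to have defaults.
opts = ImageGenerationOptions(prompt="an astronaut riding a horse")
result = pipe(opts, plugin_data={})  # plugin_data is keyed by plugin id via get/set_plugin_data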
15 changes: 14 additions & 1 deletion api/events/__init__.py
@@ -4,6 +4,8 @@

handlers: Dict[str, List[Callable]] = {}

T = TypeVar("T", bound="BaseEvent")


class BaseEvent:
    __event_name__: ClassVar[str] = ""
@@ -15,7 +17,11 @@ def register(cls, handler):
        handlers[cls].append(handler)

    @classmethod
    def call_event(cls, event=None):
    def call_event(cls: Type[T], *args, **kwargs) -> T:
        if len(args) == 1 and type(args[0]) == cls:
            event = args[0]
        else:
            event = cls(*args, **kwargs)
        if event is None:
            event = cls()
        if not isinstance(event, BaseEvent):
@@ -33,6 +39,13 @@ def call_event(cls, event=None):

        return event

    def __call__(self):
        fields = self.__dataclass_fields__
        results = []
        for field in fields:
            results.append(getattr(self, field))
        return results


@dataclass
class CancellableEvent(BaseEvent):
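With this change, call_event accepts either an already-constructed event or the constructor arguments of the event class, and it returns the event (typed via the T TypeVar) after the registered handlers have run; the new __call__ returns the dataclass fields in declaration order, which is convenient for tuple unpacking. A sketch with a hypothetical event class:

from dataclasses import dataclass

from api.events import BaseEvent


@dataclass
class DemoEvent(BaseEvent):  # hypothetical event, not part of this commit
    name: str
    count: int = 0


def bump(e: DemoEvent):
    e.count += 1


DemoEvent.register(bump)

# Both call forms are accepted now:
event = DemoEvent.call_event("first")              # built from constructor args
event = DemoEvent.call_event(DemoEvent("second"))  # an existing instance is passed through

name, count = event()  # __call__ yields the dataclass fields in order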
3 changes: 2 additions & 1 deletion api/events/common.py
@@ -1,5 +1,6 @@
from dataclasses import dataclass
from typing import *
from fastapi import FastAPI

from gradio import Blocks

@@ -13,7 +14,7 @@ class PreAppLaunchEvent(BaseEvent):

@dataclass
class PostAppLaunchEvent(BaseEvent):
    pass
    app: FastAPI


@dataclass
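Since PostAppLaunchEvent now carries the FastAPI application, a plugin handler can attach its own routes once the app is up. A minimal sketch (the route path and handler are illustrative):

from api.events.common import PostAppLaunchEvent


def add_routes(e: PostAppLaunchEvent):
    @e.app.get("/api/plugins/ping")  # illustrative route
    def ping():
        return {"status": "ok"}


PostAppLaunchEvent.register(add_routes)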
9 changes: 6 additions & 3 deletions api/events/generation.py
@@ -1,26 +1,29 @@
from dataclasses import dataclass, field
from typing import *
from typing import Any

import torch

from api.diffusion.pipelines.diffusers import DiffusersPipelineModel

from . import BaseEvent, SkippableEvent


@dataclass
class LoadResourceEvent(BaseEvent):
    pipe: Any
    pipe: DiffusersPipelineModel


@dataclass
class PromptTokenizingEvent(BaseEvent):
    pipe: Any
    pipe: DiffusersPipelineModel
    text_tokens: List
    text_weights: List


@dataclass
class UNetDenoisingEvent(SkippableEvent):
    pipe: Any
    pipe: DiffusersPipelineModel

    latent_model_input: torch.Tensor
    step: int
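Typing pipe as DiffusersPipelineModel instead of Any documents what generation-event handlers can rely on. For example, a LoadResourceEvent handler can reach the current request options through the pipeline session, as the networks change further down now does:

from api.events.generation import LoadResourceEvent


def on_load_resource(e: LoadResourceEvent):  # illustrative handler
    opts = e.pipe.session.opts  # ImageGenerationOptions for the current request
    print(opts.prompt)


LoadResourceEvent.register(on_load_resource)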
5 changes: 3 additions & 2 deletions api/plugin.py
@@ -1,7 +1,8 @@
import inspect


def get_plugin_id():
    frm = inspect.stack()[1]
def get_plugin_id(frm=None):
    if frm is None:
        frm = inspect.stack()[1]
    mod = inspect.getmodule(frm[0])
    return mod.__name__.split(".")[1]
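get_plugin_id now takes an optional frame record, which is what the new pipeline accessors rely on: they pass inspect.stack()[1] so the id is derived from the plugin module that called them, while a bare call keeps the old behaviour of inspecting the caller directly. A sketch, assuming the code runs inside a package laid out as <root>.<plugin_name>.<module> so the second dotted component is the plugin id:

import inspect

from api.plugin import get_plugin_id

plugin_id = get_plugin_id()                    # inspects the caller's own frame
plugin_id = get_plugin_id(inspect.stack()[0])  # or pass a frame record explicitly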
1 change: 1 addition & 0 deletions launch.py
@@ -181,6 +181,7 @@ def prepare_environment():

    sys.argv, reinstall_torch = extract_arg(sys.argv, "--reinstall-torch")
    sys.argv, reinstall_xformers = extract_arg(sys.argv, "--reinstall-xformers")
    sys.argv, reinstall_tensorrt = extract_arg(sys.argv, "--reinstall-tensorrt")
    tensorrt = "--tensorrt" in sys.argv

    if reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
1 change: 0 additions & 1 deletion lib/tensorrt/utilities.py
@@ -340,7 +340,6 @@ def create_models(
    use_auth_token: Optional[str],
    device: Union[str, torch.device],
    max_batch_size: int,
    hf_cache_dir: Optional[str] = None,
    unet_in_channels: int = 4,
    embedding_dim: int = 768,
):
21 changes: 10 additions & 11 deletions modules/acceleration/tensorrt/engine.py
@@ -1,8 +1,8 @@
import gc
import os

import torch
import tensorrt
import torch

from api.models.tensorrt import BuildEngineOptions, TensorRTEngineData
from lib.tensorrt.utilities import (
@@ -19,7 +19,6 @@
    load_vae_encoder,
)
from modules.logger import logger
from modules.shared import hf_diffusers_cache_dir


def create_onnx_path(name, onnx_dir, opt=True):
@@ -35,15 +34,15 @@ def __init__(self, opts: BuildEngineOptions):

        unet = load_unet(self.model.model_id)
        text_encoder = load_text_encoder(self.model.model_id)
        self.models = create_models(
            model_id=self.model.model_id,
            device=torch.device("cuda"),
            use_auth_token=opts.hf_token,
            max_batch_size=opts.max_batch_size,
            hf_cache_dir=hf_diffusers_cache_dir(),
            unet_in_channels=unet.config.in_channels,
            embedding_dim=text_encoder.config.hidden_size,
        )
        self.model_args = {
            "model_id": self.model.model_id,
            "device": torch.device("cuda"),
            "use_auth_token": opts.hf_token,
            "max_batch_size": opts.max_batch_size,
            "unet_in_channels": unet.config.in_channels,
            "embedding_dim": text_encoder.config.hidden_size,
        }
        self.models = create_models(**self.model_args)
        if not opts.full_acceleration:
            self.models = {
                "unet": self.models["unet"],
7 changes: 3 additions & 4 deletions modules/components/gallery.py
@@ -1,13 +1,12 @@
import json
import os
from random import randint
from typing import *

import gradio as gr
import gradio.blocks
import gradio.utils


from PIL import Image
import json
import os


def outputs_gallery_info_ui(elem_classes=[], **kwargs):
2 changes: 2 additions & 0 deletions modules/components/image_generation_options.py
@@ -11,12 +11,14 @@ def prompt_ui():
        lines=3,
        placeholder="Prompt",
        show_label=False,
        elem_classes=["prompt-textbox"],
    )
    negative_prompt_textbox = gr.TextArea(
        "",
        lines=3,
        placeholder="Negative Prompt",
        show_label=False,
        elem_classes=["negative-prompt-textbox"],
    )
    return prompt_textbox, negative_prompt_textbox

2 changes: 1 addition & 1 deletion modules/diffusion/networks/__init__.py
@@ -52,7 +52,7 @@ def restore_networks(*modules: torch.nn.Module):
def load_network_modules(e: LoadResourceEvent):
    global latest_networks

    opts: ImageGenerationOptions = e.pipe.opts
    opts: ImageGenerationOptions = e.pipe.session.opts

    positive_networks, opts.prompt = get_networks_from_prompt(opts.prompt)

2 changes: 1 addition & 1 deletion modules/diffusion/pipelines/deepfloyd_if.py
@@ -8,7 +8,7 @@
from transformers import T5EncoderModel

from modules import config
from modules.shared import hf_diffusers_cache_dir, hf_transformers_cache_dir, get_device
from modules.shared import get_device, hf_diffusers_cache_dir, hf_transformers_cache_dir


class IFDiffusionPipeline: