
Commit
add huggingface
cene555 committed Nov 23, 2022
1 parent 83b124b commit 6e63e1b
Showing 21 changed files with 564 additions and 109 deletions.
68 changes: 0 additions & 68 deletions configs/inference.yaml

This file was deleted.

Empty file removed kandinsky/__init__.py
Empty file.
65 changes: 65 additions & 0 deletions kandinsky2/__init__.py
@@ -0,0 +1,65 @@
import os
from .kandinsky2_model import Kandinsky2
from huggingface_hub import hf_hub_url, cached_download
from copy import deepcopy

CONFIG = {
    'model_config': {
        'image_size': 64, 'num_channels': 384, 'num_res_blocks': 3, 'channel_mult': '', 'num_heads': 1,
        'num_head_channels': 64, 'num_heads_upsample': -1, 'attention_resolutions': '32,16,8', 'dropout': 0,
        'model_dim': 768, 'use_scale_shift_norm': True, 'resblock_updown': True, 'use_fp16': True,
        'cache_text_emb': True, 'text_encoder_in_dim1': 1024, 'text_encoder_in_dim2': 640,
        'pooling_type': 'from_model', 'in_channels': 4, 'out_channels': 8, 'up': False, 'inpainting': False,
    },
    'diffusion_config': {
        'learn_sigma': True, 'sigma_small': False, 'steps': 1000, 'noise_schedule': 'linear',
        'timestep_respacing': '', 'use_kl': False, 'predict_xstart': False,
        'rescale_timesteps': True, 'rescale_learned_sigmas': True,
    },
    'image_enc_params': {
        'name': 'AutoencoderKL', 'scale': 0.0512,
        'params': {
            'ckpt_path': '', 'embed_dim': 4,
            'ddconfig': {
                'double_z': True, 'z_channels': 4, 'resolution': 256,
                'in_channels': 3, 'out_ch': 3, 'ch': 128,
                'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2,
                'attn_resolutions': [], 'dropout': 0.0,
            },
        },
    },
    'text_enc_params1': {'model_path': '', 'model_name': 'multiclip'},
    'text_enc_params2': {'model_path': '', 'model_name': 'MT5EncoderModel'},
    'tokenizer_name1': '',
    'tokenizer_name2': '',
}

def get_kandinsky2(device, task_type='text2img', cache_dir='/tmp/kandinsky2', use_auth_token=None):
    config = deepcopy(CONFIG)
    if task_type == 'inpainting':
        model_name = 'Kandinsky-2-0-inpainting.pt'
        config_file_url = hf_hub_url(repo_id='sberbank-ai/Kandinsky_2.0', filename=model_name)
    elif task_type == 'text2img' or task_type == 'img2img':
        model_name = 'Kandinsky-2-0.pt'
        config_file_url = hf_hub_url(repo_id='sberbank-ai/Kandinsky_2.0', filename=model_name)
    else:
        raise ValueError('Only text2img, img2img and inpainting are available')

    cached_download(config_file_url, cache_dir=cache_dir, force_filename=model_name,
                    use_auth_token=use_auth_token)

    cache_dir_text_en1 = os.path.join(cache_dir, 'text_encoder1')
    cache_dir_text_en2 = os.path.join(cache_dir, 'text_encoder2')
    for name in ['config.json', 'pytorch_model.bin', 'sentencepiece.bpe.model',
                 'special_tokens_map.json', 'tokenizer.json', 'tokenizer_config.json']:
        config_file_url = hf_hub_url(repo_id='sberbank-ai/Kandinsky_2.0', filename=f'text_encoder1/{name}')
        cached_download(config_file_url, cache_dir=cache_dir_text_en1, force_filename=name,
                        use_auth_token=use_auth_token)

    for name in ['config.json', 'pytorch_model.bin', 'spiece.model',
                 'special_tokens_map.json', 'tokenizer_config.json']:
        config_file_url = hf_hub_url(repo_id='sberbank-ai/Kandinsky_2.0', filename=f'text_encoder2/{name}')
        cached_download(config_file_url, cache_dir=cache_dir_text_en2, force_filename=name,
                        use_auth_token=use_auth_token)

    config_file_url = hf_hub_url(repo_id='sberbank-ai/Kandinsky_2.0', filename='vae.ckpt')
    cached_download(config_file_url, cache_dir=cache_dir, force_filename='vae.ckpt',
                    use_auth_token=use_auth_token)

    config['text_enc_params1']['model_path'] = cache_dir_text_en1
    config['text_enc_params2']['model_path'] = cache_dir_text_en2
    config['tokenizer_name1'] = cache_dir_text_en1
    config['tokenizer_name2'] = cache_dir_text_en2
    config['image_enc_params']['params']['ckpt_path'] = os.path.join(cache_dir, 'vae.ckpt')
    unet_path = os.path.join(cache_dir, model_name)

    model = Kandinsky2(config, unet_path, device, task_type)
    return model
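
A minimal usage sketch of the new loader (the device string, prompt and output size are illustrative; weights come from sberbank-ai/Kandinsky_2.0 as in the code above):

from kandinsky2 import get_kandinsky2

# first call downloads the UNet checkpoint, both text encoders and vae.ckpt
# into cache_dir, then assembles the Kandinsky2 wrapper on the chosen device
model = get_kandinsky2('cuda', task_type='text2img', cache_dir='/tmp/kandinsky2')
images = model.generate_text2img('red cat, 4k photo', num_steps=100,
                                 batch_size=1, guidance_scale=7, h=512, w=512)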
17 changes: 7 additions & 10 deletions kandinsky/kandinsky_model.py → kandinsky2/kandinsky2_model.py
@@ -4,7 +4,6 @@
import cv2
import torch
from omegaconf import OmegaConf
import clip
import math
from .model.text_encoders import TextEncoder
from .vqgan.autoencoder import VQModelInterface, AutoencoderKL
@@ -14,9 +13,9 @@
from .utils import prepare_image, q_sample, process_images, prepare_mask


class Natalle:
def __init__(self, config_path, model_path, device, task_type='text2img', vae_path=None):
self.config = dict(OmegaConf.load(config_path))
class Kandinsky2:
def __init__(self, config, model_path, device, task_type='text2img'):
self.config = config
self.device = device
self.task_type = task_type
if task_type == 'text2img' or task_type == 'img2img':
@@ -34,8 +33,6 @@ def __init__(self, config_path, model_path, device, task_type='text2img', vae_pa
self.text_encoder2 = TextEncoder(**self.config['text_enc_params2']).to(self.device).eval()

self.use_fp16 = self.config['model_config']['use_fp16']
if vae_path is not None:
self.config['image_enc_params']['params']['ckpt_path'] = vae_path

if self.config['image_enc_params'] is not None:
self.use_image_enc = True
@@ -174,7 +171,7 @@ def denoised_fn(x_start, ):
eta=ddim_eta
)[:batch_size]
else:
print(1 / 0)
raise ValueError('Only p_sampler and ddim_sampler are available')
self.model.del_cache()
if self.use_image_enc:
if self.use_fp16:
@@ -187,7 +184,7 @@ def denoised_fn(x_start, ):
def generate_text2img(self, prompt, num_steps=100,
batch_size=1, guidance_scale=7, progress=True,
dynamic_threshold_v=99.5, denoised_type='dynamic_threshold', h=512, w=512
, sampler='ddim_sampler', ddim_eta=0.8):
, sampler='ddim_sampler', ddim_eta=0.05):
config = deepcopy(self.config)
config['diffusion_config']['timestep_respacing'] = str(num_steps)
if sampler == 'ddim_sampler':
@@ -203,7 +200,7 @@ def generate_text2img(self, prompt, num_steps=100,
def generate_img2img(self, prompt, pil_img, strength=0.7,
num_steps=100, guidance_scale=7, progress=True,
dynamic_threshold_v=99.5, denoised_type='dynamic_threshold'
, sampler='ddim_sampler', ddim_eta=0.8):
, sampler='ddim_sampler', ddim_eta=0.05):

config = deepcopy(self.config)
config['diffusion_config']['timestep_respacing'] = str(num_steps)
@@ -229,7 +226,7 @@ def generate_img2img(self, prompt, pil_img, strength=0.7,
def generate_inpainting(self, prompt, pil_img, img_mask,
num_steps=100, guidance_scale=7, progress=True,
dynamic_threshold_v=99.5, denoised_type='dynamic_threshold',
sampler='ddim_sampler', ddim_eta=0.8):
sampler='ddim_sampler', ddim_eta=0.05):
config = deepcopy(self.config)
config['diffusion_config']['timestep_respacing'] = str(num_steps)
if sampler == 'ddim_sampler':
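Continuing the sketch above, a rough illustration of the img2img path with the lowered ddim_eta default (the input file, prompt and strength are illustrative; generate_inpainting additionally takes an img_mask argument, as its hunk shows):

from PIL import Image

init = Image.open('init.png').resize((512, 512))
# ddim_eta now defaults to 0.05 instead of 0.8 for all three generate_* methods
out = model.generate_img2img('same scene, oil painting', init, strength=0.7,
                             num_steps=100, guidance_scale=7,
                             sampler='ddim_sampler')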
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -3,8 +3,8 @@
import torch.nn.functional as F
import math
from transformers import T5EncoderModel, MT5EncoderModel, BertModel, XLMRobertaModel
import clip
import transformers
import os

def attention(q, k, v, d_k):
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
@@ -88,32 +88,18 @@ def forward(self, text, mask=None):
pooled_out = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
return x, pooled_out

class MCLIPConfig(transformers.PretrainedConfig):
model_type = "M-CLIP"

def __init__(self, modelBase='xlm-roberta-large', transformerDimSize=1024, imageDimSize=768, **kwargs):
self.transformerDimensions = transformerDimSize
self.numDims = imageDimSize
self.modelBase = modelBase
super().__init__(**kwargs)

class MultilingualCLIP(transformers.PreTrainedModel):
config_class = MCLIPConfig
def __init__(self, config, *args, **kwargs):
super().__init__(config, *args, **kwargs)
self.transformer = transformers.AutoModel.from_pretrained(config.modelBase)
self.LinearTransformation = torch.nn.Linear(in_features=config.transformerDimensions,
out_features=config.numDims)
class MultilingualCLIP(nn.Module):
def __init__(self, config):
super().__init__()
self.transformer = transformers.AutoModel.from_pretrained(config)
self.LinearTransformation = torch.nn.Linear(in_features=1024,
out_features=640)


def forward(self, input_ids, attention_mask):
embs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)[0]
embs2 = (embs * attention_mask.unsqueeze(2)).sum(dim=1) / attention_mask.sum(dim=1)[:, None]
return self.LinearTransformation(embs2), embs

@classmethod
def _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or_path, _fast_init=True):
model.load_state_dict(state_dict)
return model, [], [], []

class TextEncoder(nn.Module):
def __init__(self, model_path, model_name):
@@ -129,7 +115,8 @@ def __init__(self, model_path, model_name):
elif self.model_name == 'BertModel':
self.model = BertModel.from_pretrained(model_path)
elif self.model_name == 'multiclip':
self.model = MultilingualCLIP.from_pretrained(model_path)
self.model = MultilingualCLIP(model_path)
self.model.load_state_dict(torch.load(os.path.join(model_path, 'pytorch_model.bin')), strict=False)
elif self.model_name == 'xlm_roberta':
self.model = XLMRobertaModel.from_pretrained(model_path).half()
self.model.eval()
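The refactor above drops MCLIPConfig and the transformers.PreTrainedModel base class: the projection dimensions are now hard-coded (1024 in, 640 out) and the checkpoint is loaded by hand in TextEncoder with strict=False. A rough standalone sketch of that 'multiclip' branch, assuming the text_encoder1 directory laid out by get_kandinsky2:

clip_dir = '/tmp/kandinsky2/text_encoder1'
clip = MultilingualCLIP(clip_dir)          # AutoModel.from_pretrained on the local directory
state = torch.load(os.path.join(clip_dir, 'pytorch_model.bin'), map_location='cpu')
clip.load_state_dict(state, strict=False)  # strict=False tolerates keys the slimmed-down class no longer defines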
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
