# utils.py -- forked from salesforce/UniControl
'''
* Copyright (c) 2023 Salesforce, Inc.
* All rights reserved.
* SPDX-License-Identifier: Apache License 2.0
* For full license text, see LICENSE.txt file in the repo root or http://www.apache.org/licenses/
* By Can Qin
* Modified from ControlNet repo: https://github.com/lllyasviel/ControlNet
* Copyright (c) 2023 Lvmin Zhang and Maneesh Agrawala
'''
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
from PIL import Image
import cv2
import numpy as np
def numpy_to_pil(images):
    """
    Convert a numpy image, or a batch of images, into a list of PIL images.

    A 3-D array is treated as a single image and is given a leading batch
    axis; any other array is iterated along its first axis as-is. Values are
    multiplied by 255, rounded and cast to ``uint8`` before conversion, so the
    input is presumably float data in [0, 1] (not checked here -- TODO confirm
    against callers).

    Returns:
        A list with one ``PIL.Image.Image`` per entry along the first axis.
    """
    batch = images[None, ...] if images.ndim == 3 else images
    as_bytes = (batch * 255).round().astype("uint8")
    return [Image.fromarray(frame) for frame in as_bytes]
def put_watermark(img, wm_encoder=None):
    """
    Embed a watermark into ``img`` when an encoder is supplied.

    Args:
        img: The image to mark. It is converted with ``np.array`` and treated
            as RGB for the colour conversion.
        wm_encoder: Object exposing ``encode(image, method)``; presumably an
            invisible-watermark encoder (verify against the caller). When it
            is ``None`` the image is handed back untouched.

    Returns:
        ``img`` itself when ``wm_encoder`` is ``None``; otherwise a new PIL
        image built from the encoder's output.
    """
    if wm_encoder is None:
        return img

    # The encoder is fed BGR data, so swap the channel order before encoding.
    bgr_pixels = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    marked = wm_encoder.encode(bgr_pixels, 'dwtDct')
    # Reverse the last axis to restore the RGB channel order for PIL.
    return Image.fromarray(marked[:, :, ::-1])
def load_replacement(x):
    """
    Return a placeholder image shaped like ``x``, or ``x`` itself on failure.

    The placeholder is read from ``assets/rick.jpeg`` (relative to the current
    working directory), converted to RGB, resized to ``(x.shape[1],
    x.shape[0])``, scaled to [0, 1] by dividing by 255 and cast to ``x.dtype``.
    Because the placeholder is always RGB, only an ``x`` of shape
    ``(H, W, 3)`` can actually be replaced.

    This is deliberately best-effort: any error while building the
    placeholder (missing asset, ``x`` without a usable ``shape``/``dtype``,
    ...) or a shape mismatch results in ``x`` being returned unchanged.

    Args:
        x: Array-like image exposing ``shape`` and ``dtype``.

    Returns:
        The placeholder array with the same shape and dtype as ``x``, or
        ``x`` unchanged when the placeholder cannot be produced.
    """
    try:
        height, width = x.shape[0], x.shape[1]
        # Context manager guarantees the file handle is released promptly.
        with Image.open("assets/rick.jpeg") as source:
            resized = source.convert("RGB").resize((width, height))
        y = (np.array(resized) / 255.0).astype(x.dtype)
    except Exception:
        # Best-effort by design: never fail the caller, but leave a trace so
        # a missing or unreadable asset does not go unnoticed.
        import logging

        logging.getLogger(__name__).warning(
            "load_replacement: could not build replacement image; "
            "returning input unchanged",
            exc_info=True,
        )
        return x

    # Explicit check instead of ``assert`` so the guard survives ``python -O``.
    if y.shape != x.shape:
        return x
    return y
_SAFETY_MODEL_ID = "CompVis/stable-diffusion-safety-checker"
# Populated on first use; holds the "extractor" and "checker" objects.
_safety_components = {}


def _get_safety_components():
    """Return the (feature extractor, safety checker) pair, loading it once."""
    if not _safety_components:
        # Both loads are evaluated before ``update`` runs, so a failure in
        # either one leaves the cache empty rather than half-populated.
        _safety_components.update(
            extractor=AutoFeatureExtractor.from_pretrained(_SAFETY_MODEL_ID),
            checker=StableDiffusionSafetyChecker.from_pretrained(_SAFETY_MODEL_ID),
        )
    return _safety_components["extractor"], _safety_components["checker"]


def check_safety(x_image):
    """
    Run the Stable Diffusion safety checker and replace flagged images.

    The feature extractor and checker are loaded from
    ``CompVis/stable-diffusion-safety-checker`` on the first call and reused
    afterwards, instead of being re-instantiated on every call. They stay
    referenced by this module for the lifetime of the process.

    Args:
        x_image: Batch of images. It is passed to ``numpy_to_pil`` to build
            the checker's CLIP input and to the checker itself as ``images``;
            presumably a float numpy batch in [0, 1] (TODO confirm against
            callers).

    Returns:
        A tuple ``(x_checked_image, has_nsfw_concept)``: the checker's image
        output, with every flagged entry replaced by the result of
        ``load_replacement``, and the per-image flags from the checker.

    Raises:
        AssertionError: if the checker returns a different number of images
            than flags.
    """
    feature_extractor, safety_checker = _get_safety_components()
    checker_input = feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
    x_checked_image, has_nsfw_concept = safety_checker(
        images=x_image, clip_input=checker_input.pixel_values
    )
    assert x_checked_image.shape[0] == len(has_nsfw_concept)
    # Entries are overwritten in place in the checker's returned batch;
    # whether that aliases ``x_image`` depends on the checker implementation.
    for index, flagged in enumerate(has_nsfw_concept):
        if flagged:
            x_checked_image[index] = load_replacement(x_checked_image[index])
    return x_checked_image, has_nsfw_concept