-
Notifications
You must be signed in to change notification settings - Fork 409
/
buckets.py
174 lines (151 loc) · 5.74 KB
/
buckets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
from typing import Type, List, Union, TypedDict
# A bucket size in pixels: a dict with exactly two integer keys,
# "width" and "height".
BucketResolution = TypedDict(
    "BucketResolution",
    {
        "width": int,   # horizontal size in pixels
        "height": int,  # vertical size in pixels
    },
)
# Resolutions SDXL was trained on with a 1024x1024 base resolution, followed
# by two extra-wide additions. Written as (width, height) pairs and expanded
# into one fresh dict per entry, in the order listed.
resolutions_1024: List[BucketResolution] = [
    {"width": pair_w, "height": pair_h}
    for pair_w, pair_h in (
        # SDXL base resolution
        (1024, 1024),
        # SDXL resolutions, widescreen
        (2048, 512),
        (1984, 512),
        (1920, 512),
        (1856, 512),
        (1792, 576),
        (1728, 576),
        (1664, 576),
        (1600, 640),
        (1536, 640),
        (1472, 704),
        (1408, 704),
        (1344, 704),
        (1344, 768),
        (1280, 768),
        (1216, 832),
        (1152, 832),
        (1152, 896),
        (1088, 896),
        (1088, 960),
        (1024, 960),
        # SDXL resolutions, portrait
        (960, 1024),
        (960, 1088),
        (896, 1088),
        (896, 1152),  # 7:9
        (832, 1152),
        (832, 1216),
        (768, 1280),
        (768, 1344),
        (704, 1408),
        (704, 1472),
        (640, 1536),
        (640, 1600),
        (576, 1664),
        (576, 1728),
        (576, 1792),
        (512, 1856),
        (512, 1920),
        (512, 1984),
        (512, 2048),
        # extra wides
        (8192, 128),
        (128, 8192),
    )
]
# Even numbers so they can be patched more easily.
# Written as (width, height) pairs and expanded into one fresh dict per
# entry, in the order listed.
resolutions_dit_1024: List[BucketResolution] = [
    {"width": pair_w, "height": pair_h}
    for pair_w, pair_h in (
        # base resolution
        (1024, 1024),
        # widescreen
        (2048, 512),
        (1792, 576),
        (1728, 576),
        (1664, 576),
        (1600, 640),
        (1536, 640),
        (1472, 704),
        (1408, 704),
        (1344, 704),
        (1344, 768),
        (1280, 768),
        (1216, 832),
        (1152, 832),
        (1152, 896),
        (1088, 896),
        (1088, 960),
        (1024, 960),
        # portrait
        (960, 1024),
        (960, 1088),
        (896, 1088),
        (896, 1152),  # 7:9
        (832, 1152),
        (832, 1216),
        (768, 1280),
        (768, 1344),
        (704, 1408),
        (704, 1472),
        (640, 1536),
        (640, 1600),
        (576, 1664),
        (576, 1728),
        (576, 1792),
        (512, 1856),
        (512, 1920),
        (512, 1984),
        (512, 2048),
    )
]
def get_bucket_sizes(resolution: int = 512, divisibility: int = 8) -> List[BucketResolution]:
    """Return the ``resolutions_1024`` buckets rescaled to ``resolution``.

    Every side of every bucket in ``resolutions_1024`` is multiplied by
    ``resolution / 1024``, truncated to an integer, and then reduced to the
    nearest multiple of ``divisibility`` at or below it.

    Args:
        resolution: Target base resolution; 1024 leaves the table unscaled.
        divisibility: Each returned side is a multiple of this value.

    Returns:
        A new list with one ``{"width", "height"}`` dict per entry of
        ``resolutions_1024``, in the same order.
    """
    # Scale factor from the 1024 base table to the requested resolution.
    ratio = resolution / 1024

    def _snap(side: int) -> int:
        # Scale, truncate, then drop the remainder so the side is divisible
        # by `divisibility`.
        scaled = int(side * ratio)
        return scaled - scaled % divisibility

    return [
        {"width": _snap(base["width"]), "height": _snap(base["height"])}
        for base in resolutions_1024
    ]
def get_resolution(width, height):
    """Return the side of the square holding the same pixel count as the image.

    The result is the square root of ``width * height`` truncated to an int.
    """
    return int((width * height) ** 0.5)
def get_bucket_for_image_size(
        width: int,
        height: int,
        bucket_size_list: Union[List[BucketResolution], None] = None,
        resolution: Union[int, None] = None,
        divisibility: int = 8
) -> BucketResolution:
    """Pick the bucket that fits an image of ``width`` x ``height`` best.

    A bucket whose size equals the image size is returned immediately.
    Otherwise the image is scaled (keeping its aspect ratio) so that it just
    covers each candidate bucket, and the bucket for which the fewest pixels
    would have to be cropped away is returned. Ties go to the bucket that
    appears first in the list.

    Args:
        width: Image width in pixels.
        height: Image height in pixels.
        bucket_size_list: Candidate buckets. When ``None``, the list is built
            with ``get_bucket_sizes``.
        resolution: Base resolution used to build the bucket list. Only used
            when ``bucket_size_list`` is ``None``; it is capped at the image's
            own resolution (see ``get_resolution``), which is also the value
            used when this is ``None``.
        divisibility: Passed to ``get_bucket_sizes`` when the list is built.

    Returns:
        The selected bucket (the same dict object that is in the list).

    Raises:
        ValueError: If there is no bucket to choose from.
        ZeroDivisionError: If ``width`` or ``height`` is zero and no bucket
            matches the image size exactly.
    """
    if bucket_size_list is None:
        # If the real resolution is smaller than the requested one, use that
        # instead.
        real_resolution = get_resolution(width, height)
        if resolution is None:
            resolution = real_resolution
        else:
            resolution = min(resolution, real_resolution)
        bucket_size_list = get_bucket_sizes(resolution=resolution, divisibility=divisibility)

    # Check for an exact match first.
    for bucket in bucket_size_list:
        if bucket["width"] == width and bucket["height"] == height:
            return bucket

    # No exact match: find the bucket that needs the least cropping.
    closest_bucket = None
    min_removed_pixels = float("inf")
    for bucket in bucket_size_list:
        bucket_w = bucket["width"]
        bucket_h = bucket["height"]

        # Size of the other side when the image is scaled to match the bucket
        # width / height exactly. Integer arithmetic is used on purpose: with
        # floats, width * (bucket_w / width) can come out a hair below
        # bucket_w and truncate to bucket_w - 1, producing a negative crop
        # count that wrongly favours the bucket.
        height_at_bucket_w = height * bucket_w // width
        width_at_bucket_h = width * bucket_h // height

        # Use the larger of the two scale factors so the scaled image covers
        # the bucket. bucket_w / width >= bucket_h / height is compared by
        # cross-multiplication to stay exact.
        if bucket_w * height >= bucket_h * width:
            new_width, new_height = bucket_w, height_at_bucket_w
        else:
            new_width, new_height = width_at_bucket_h, bucket_h

        # One of the two terms is zero; the other is the cropped strip.
        removed_pixels = (new_width - bucket_w) * new_height + (new_height - bucket_h) * new_width
        if removed_pixels < min_removed_pixels:
            min_removed_pixels = removed_pixels
            closest_bucket = bucket

    if closest_bucket is None:
        raise ValueError("No suitable bucket found")

    return closest_bucket