# merge_block_weighted.py
# Forked from Linaqruf/kohya-trainer.
# original code: https://github.com/eyriewow/merge-models
import os
import argparse
import re
import torch
from tqdm import tqdm
# Number of U-Net blocks that can each receive their own merge weight.
NUM_INPUT_BLOCKS = 12
NUM_MID_BLOCK = 1
NUM_OUTPUT_BLOCKS = 12

# Length the per-block weight list must have: input blocks first, then the
# middle block, then the output blocks.
NUM_TOTAL_BLOCKS = (
    NUM_INPUT_BLOCKS
    + NUM_MID_BLOCK
    + NUM_OUTPUT_BLOCKS
)
# Key patterns that identify which U-Net block a state-dict entry belongs to.
_RE_INPUT_BLOCK = re.compile(r'\.input_blocks\.(\d+)\.')  # 12
_RE_MIDDLE_BLOCK = re.compile(r'\.middle_block\.(\d+)\.')  # 1
_RE_OUTPUT_BLOCK = re.compile(r'\.output_blocks\.(\d+)\.')  # 12


def _block_weight_index(key):
    """Return the per-block weight position for a U-Net key, or None if no block matches."""
    if 'time_embed' in key:
        return 0  # before input blocks
    if '.out.' in key:
        return NUM_TOTAL_BLOCKS - 1  # after output blocks

    match = _RE_INPUT_BLOCK.search(key)
    if match:
        return int(match.group(1))

    if _RE_MIDDLE_BLOCK.search(key):
        return NUM_INPUT_BLOCKS

    match = _RE_OUTPUT_BLOCK.search(key)
    if match:
        return NUM_INPUT_BLOCKS + NUM_MID_BLOCK + int(match.group(1))

    return None


def _confirm_overwrite(path):
    """Return True if `path` does not exist or the user answers "y" on stdin."""
    if not os.path.isfile(path):
        return True

    print("Output file already exists. Overwrite? (y/n)")
    while True:
        answer = input()
        if answer == "y":
            return True
        if answer == "n":
            print("Exiting...")
            return False
        print("Please enter y or n")


def merge(args):
    """Merge two checkpoints into one, optionally with a separate ratio per U-Net block.

    Every key containing "model" that exists in both state dicts is replaced by
    ``(1 - a) * theta_0[key] + a * theta_1[key]``. ``a`` is ``args.base_alpha``
    unless ``args.weights`` is given and the key is a ``model.diffusion_model.``
    entry that maps to a block, in which case ``a`` is that block's weight.
    Keys containing "model" that exist only in model 1 are copied over as-is.
    The result is saved as ``{"state_dict": ...}`` to
    ``<output>-<alpha digits>-bw.ckpt``.

    Args:
        args: Namespace providing ``model_0``, ``model_1``, ``base_alpha``,
            ``output``, ``device``, ``weights`` (comma separated string or
            None) and ``verbose``.

    Returns:
        None. Returns early, without saving, when the number of weights is
        wrong or the user declines to overwrite an existing output file.
    """
    if args.weights is None:
        weights = None
    else:
        weights = [float(w) for w in args.weights.split(',')]
        if len(weights) != NUM_TOTAL_BLOCKS:
            print(f"weights value must be {NUM_TOTAL_BLOCKS}.")
            return

    alpha = args.base_alpha
    # The slice drops the first two characters of str(alpha) ("0." for values
    # such as 0.5), so 0.5 yields "50". NOTE(review): other spellings of alpha
    # (e.g. 1.0 or 1e-05) produce less readable names - kept for compatibility.
    output_file = f'{args.output}-{str(alpha)[2:] + "0"}-bw.ckpt'

    # Ask before loading anything, so declining does not cost two checkpoint loads.
    if not _confirm_overwrite(output_file):
        return

    device = args.device
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    print("loading", args.model_0)
    model_0 = torch.load(args.model_0, map_location=device)
    print("loading", args.model_1)
    model_1 = torch.load(args.model_1, map_location=device)
    theta_0 = model_0["state_dict"]
    theta_1 = model_1["state_dict"]

    # In verbose mode one line is printed per weighted key, so the progress bar is skipped.
    stage_1_keys = theta_0.keys() if args.verbose else tqdm(theta_0.keys(), desc="Stage 1/2")
    for key in stage_1_keys:
        if "model" not in key or key not in theta_1:
            continue

        current_alpha = alpha

        # Per-block weights only apply to U-Net entries.
        if weights is not None and 'model.diffusion_model.' in key:
            weight_index = _block_weight_index(key)
            if weight_index is not None:
                if weight_index < len(weights):
                    current_alpha = weights[weight_index]
                    if args.verbose:
                        print(f"weighted '{key}': {current_alpha}")
                else:
                    # Report the key and keep the base alpha instead of
                    # indexing past the end of the weight list.
                    print(f"error. illegal block index: {key}")

        theta_0[key] = (1 - current_alpha) * theta_0[key] + current_alpha * theta_1[key]

    # Carry over entries that only model 1 has.
    for key in tqdm(theta_1.keys(), desc="Stage 2/2"):
        if "model" in key and key not in theta_0:
            theta_0[key] = theta_1[key]

    print("Saving...")
    torch.save({"state_dict": theta_0}, output_file)
    print("Done!")
if __name__ == '__main__':
    # Command-line entry point: parse the arguments and run the merge.
    # Alpha and the per-block weights are the share taken from model 1;
    # model 0 contributes (1 - value). The help texts state this explicitly.
    parser = argparse.ArgumentParser(description="Merge two models with weights for each block")
    parser.add_argument("model_0", type=str, help="Path to model 0")
    parser.add_argument("model_1", type=str, help="Path to model 1")
    parser.add_argument(
        "--base_alpha",
        type=float,
        default=0.5,
        required=False,
        help="Alpha value (weight of model 1; model 0 gets 1 - alpha) used for every key "
             "not covered by --weights, optional, defaults to 0.5",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="merged",
        required=False,
        help="Output file name, without extension",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        required=False,
        help="Device to use, defaults to cpu",
    )
    parser.add_argument(
        "--weights",
        type=str,
        default=None,
        required=False,
        help=f"comma separated {NUM_TOTAL_BLOCKS} weights value (weight of model 1; "
             f"model 0 gets 1 - weight) for each U-Net block",
    )
    parser.add_argument("--verbose", action='store_true', help="show each block weight", required=False)

    args = parser.parse_args()
    merge(args)