-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun.py
231 lines (191 loc) · 7.87 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
用户友好的命令行界面,用于运行removebg工具
"""
import os
import sys
import argparse
import time
from pathlib import Path
import warnings
# 忽略 onnxruntime 的警告
warnings.filterwarnings("ignore", category=UserWarning, module="onnxruntime")
def print_banner():
"""打印欢迎横幅"""
banner = """
╔═══════════════════════════════════════════════╗
║ ║
║ 批量图片背景移除工具 (RemoveBG) ║
║ ║
║ 基于 rembg 库开发 ║
║ https://github.com/danielgatis/rembg ║
║ ║
╚═══════════════════════════════════════════════╝
"""
print(banner)
def check_dependencies():
"""检查依赖是否已安装"""
try:
import rembg
import PIL
# 检查NumPy版本
import numpy as np
numpy_version = np.__version__
if numpy_version.startswith('2.'):
print(f"警告: 当前NumPy版本 {numpy_version} 可能与onnxruntime不兼容")
print("建议降级到 NumPy 1.24.x: pip install numpy==1.24.3")
return True
except ImportError as e:
print(f"错误: 缺少必要的依赖。请先运行 'pip install -r requirements.txt'")
print(f"详细错误: {str(e)}")
return False
def check_gpu_availability():
"""检查是否有可用的GPU"""
try:
# 使用更安静的方式检查CUDA可用性
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0' # 尝试使用第一个GPU
# 临时重定向标准错误输出
old_stderr = sys.stderr
sys.stderr = open(os.devnull, 'w')
try:
import onnxruntime as ort
providers = ort.get_available_providers()
result = 'CUDAExecutionProvider' in providers
finally:
# 恢复标准错误输出
sys.stderr.close()
sys.stderr = old_stderr
return result
except Exception:
return False
def get_user_input():
"""获取用户输入"""
print("\n请输入以下信息:")
# 获取输入目录
while True:
input_dir = input("输入图片目录路径: ").strip()
if not input_dir:
print("错误: 输入目录不能为空")
continue
if not os.path.isdir(input_dir):
print(f"错误: 目录不存在: {input_dir}")
continue
break
# 获取输出目录
while True:
output_dir = input("输出图片目录路径: ").strip()
if not output_dir:
print("错误: 输出目录不能为空")
continue
# 如果输出目录不存在,询问是否创建
if not os.path.isdir(output_dir):
create = input(f"目录不存在: {output_dir},是否创建? (y/n): ").strip().lower()
if create != 'y':
continue
try:
os.makedirs(output_dir, exist_ok=True)
except Exception as e:
print(f"错误: 无法创建目录: {str(e)}")
continue
break
# 获取线程数
while True:
workers = input("并行处理的线程数 (默认: 4): ").strip()
if not workers:
workers = 4
break
try:
workers = int(workers)
if workers < 1:
print("错误: 线程数必须大于0")
continue
break
except ValueError:
print("错误: 请输入有效的数字")
# 询问是否使用GPU
gpu_available = check_gpu_availability()
if gpu_available:
print("检测到可用的GPU")
use_gpu = input("是否使用GPU加速处理? (y/n, 默认: y): ").strip().lower()
use_gpu = use_gpu != 'n' # 默认使用GPU
else:
print("未检测到可用的GPU,将使用CPU处理")
use_gpu = False
# 选择模型
models = ['u2net', 'u2netp', 'u2net_human_seg', 'silueta', 'isnet-general-use', 'isnet-anime']
print("\n可用的模型:")
for i, model in enumerate(models):
print(f"{i+1}. {model}")
while True:
model_choice = input(f"请选择模型 (1-{len(models)}, 默认: 1): ").strip()
if not model_choice:
model_name = models[0]
break
try:
model_idx = int(model_choice) - 1
if 0 <= model_idx < len(models):
model_name = models[model_idx]
break
else:
print(f"错误: 请输入1-{len(models)}之间的数字")
except ValueError:
print("错误: 请输入有效的数字")
return input_dir, output_dir, workers, use_gpu, model_name
def main():
"""主函数"""
print_banner()
# 检查依赖
if not check_dependencies():
sys.exit(1)
# 解析命令行参数
parser = argparse.ArgumentParser(description='批量移除图片背景工具')
parser.add_argument('--input', '-i', help='输入图片目录')
parser.add_argument('--output', '-o', help='输出图片目录')
parser.add_argument('--workers', '-w', type=int, default=4, help='并行处理的工作线程数 (默认: 4)')
parser.add_argument('--gpu', '-g', action='store_true', help='使用GPU加速处理 (如果可用)')
parser.add_argument('--model', '-m', type=str, default='u2net',
choices=['u2net', 'u2netp', 'u2net_human_seg', 'silueta', 'isnet-general-use', 'isnet-anime'],
help='使用的模型 (默认: u2net)')
parser.add_argument('--interactive', '-int', action='store_true', help='使用交互模式')
args = parser.parse_args()
# 如果指定了交互模式或者没有提供输入/输出目录,则进入交互模式
if args.interactive or (not args.input or not args.output):
input_dir, output_dir, workers, use_gpu, model_name = get_user_input()
else:
input_dir, output_dir, workers = args.input, args.output, args.workers
use_gpu = args.gpu
model_name = args.model
# 检查输入目录是否存在
if not os.path.isdir(input_dir):
print(f"错误: 输入目录不存在: {input_dir}")
sys.exit(1)
# 确保输出目录存在
os.makedirs(output_dir, exist_ok=True)
# 如果指定了GPU但GPU不可用,给出提示
if use_gpu and not check_gpu_availability():
print("提示: 未检测到可用的GPU,将使用CPU处理")
try:
# 导入removebg模块
import removebg
# 显示处理信息
print("\n开始处理图片...")
print(f"输入目录: {input_dir}")
print(f"输出目录: {output_dir}")
print(f"线程数: {workers}")
print(f"使用GPU: {'是' if use_gpu else '否'}")
print(f"使用模型: {model_name}")
# 记录开始时间
start_time = time.time()
# 处理图片
removebg.process_directory(input_dir, output_dir, workers, use_gpu, model_name)
# 计算处理时间
elapsed_time = time.time() - start_time
minutes, seconds = divmod(elapsed_time, 60)
print(f"\n处理完成! 总耗时: {int(minutes)}分{int(seconds)}秒")
except Exception as e:
print(f"\n处理过程中发生错误: {str(e)}")
sys.exit(1)
if __name__ == "__main__":
main()