-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
95 lines (74 loc) · 2.6 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import functools
import io
import re

from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from ultralytics import YOLO
app = FastAPI()

# Browser origins allowed to make cross-origin (credentialed) requests to this API.
origins = [f"http://{host}:5173" for host in ("localhost", "172.18.50.52")]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],  # any HTTP method
    allow_headers=["*"],  # any request header
)
@app.get("/")
def read_root():
    """Return a fixed greeting string for the API root path."""
    message = "Hola desde FastApi!"
    return message
@functools.lru_cache(maxsize=1)
def _get_model():
    """Return the YOLO classification model, loading ``best.pt`` on first use.

    Constructing ``YOLO`` reads the weights from disk, so it is done once per
    process instead of once per request.
    """
    return YOLO("best.pt")


def _open_image(upload, data):
    """Decode ``data`` (the bytes read from ``upload``) into a loaded PIL image.

    Raises:
        HTTPException: 400 when the bytes are not a readable image. The declared
            content type comes from the client, so it does not guarantee this.
    """
    try:
        image = Image.open(io.BytesIO(data))
        image.load()  # force decoding now so corrupt data fails here, not in predict()
    except OSError as err:  # PIL's UnidentifiedImageError subclasses OSError
        raise HTTPException(
            status_code=400, detail=f"Imagen no válida: {upload.filename}"
        ) from err
    return image


@app.post("/classify")
async def classify_image(file: list[UploadFile]):
    """Classify one or more uploaded images with the YOLO model.

    Each upload must declare a JPEG, PNG or GIF content type and be at most 10MB.

    Returns:
        A list of ``{"name": str, "conf": float}`` dicts. For a single image these
        are parsed from the model's ``verbose()`` output in its order. For several
        images, predictions from all images are merged, sorted by descending
        confidence, and only the highest-confidence entry per class name is kept.

    Raises:
        HTTPException: 400 if no file is sent, or a file has a disallowed content
            type, exceeds the size limit, or cannot be decoded as an image.
    """
    allowed_types = ["image/jpeg", "image/png", "image/gif"]
    max_file_size = 10 * 1024 * 1024  # 10MB

    if not file:
        raise HTTPException(status_code=400, detail="No se recibió ningún archivo.")

    for f in file:
        if f.content_type not in allowed_types:
            raise HTTPException(status_code=400, detail="Tipo de archivo no permitido.")
        # UploadFile.size may be None; the real byte count is re-checked below.
        if f.size is not None and f.size > max_file_size:
            raise HTTPException(
                status_code=400, detail=f"Archivo demasiado grande: {f.filename}"
            )

    images = []
    try:
        for f in file:
            # Read at most one byte past the limit: enough to detect an oversize
            # file without pulling an arbitrarily large upload into memory.
            data = await f.read(max_file_size + 1)
            if len(data) > max_file_size:
                raise HTTPException(
                    status_code=400, detail=f"Archivo demasiado grande: {f.filename}"
                )
            images.append(_open_image(f, data))

        # predict() is synchronous and there is no await between here and the
        # return, so concurrent requests never use the shared model at once.
        # NOTE(review): this blocks the event loop during inference; moving it to
        # a thread pool would likely need one model per thread — confirm first.
        model = _get_model()
        if len(images) > 1:
            results = model.predict(images)
            # Explicit separator so the last entry of one image can never fuse with
            # the first entry of the next; data_to_dict skips empty entries.
            joined_results = ", ".join(result.verbose() for result in results)
            # Sort before de-duplicating so the highest-confidence prediction for
            # each class name (across all images) is the one that survives.
            return filter_by_repeated_name(
                sort_by_confidence(data_to_dict(joined_results))
            )
        results = model.predict(images[0], conf=0.4)
        return data_to_dict(results[0].verbose())
    finally:
        for image in images:
            image.close()
def data_to_dict(data):
    """Parse a comma-separated prediction string into a list of dicts.

    ``data`` holds ``"<name> <confidence>"`` entries separated by commas, e.g.
    ``"cat 0.95, dog 0.03, "``. Class names may contain spaces or digits; the
    confidence is always the final numeric token of an entry.

    Args:
        data: The string to parse. Empty entries (such as from a trailing comma)
            and entries that do not end in a number are ignored.

    Returns:
        A list of ``{"name": str, "conf": float}`` dicts in input order.
    """
    result = []
    for part in data.split(","):
        entry = part.strip()
        if not entry:
            continue
        # fullmatch anchors the number at the end of the entry, so a class name
        # such as "class 1" is not mis-split into name="class", conf=1.0.
        match = re.fullmatch(r"(.+?)\s+(\d+(?:\.\d*)?|\.\d+)", entry)
        if match:
            name, number = match.groups()
            result.append({"name": name, "conf": float(number)})
    return result
def filter_by_repeated_name(data):
    """Drop items whose ``"name"`` was already seen, keeping the first occurrence.

    Args:
        data: Iterable of dicts that each have a ``"name"`` key.

    Returns:
        A new list of the surviving items (the same objects, not copies), in the
        order their names first appeared.
    """
    first_by_name = {}
    for entry in data:
        # setdefault only stores the first entry seen for a given name.
        first_by_name.setdefault(entry["name"], entry)
    return list(first_by_name.values())
def filter_by_confidence(data, threshold):
    """Return the items whose ``"conf"`` is strictly below ``threshold``.

    NOTE(review): this keeps *low*-confidence items, which is the opposite of
    what a confidence filter usually does. It is not called anywhere in the
    visible code — confirm the intended direction before relying on it.
    """
    kept = []
    for item in data:
        if item["conf"] < threshold:
            kept.append(item)
    return kept
def sort_by_confidence(data):
    """Return a new list of ``data`` ordered from highest to lowest ``"conf"``.

    The input is not modified. Python's sort is stable even with
    ``reverse=True``, so items with equal confidence keep their input order.
    """
    ranked = list(data)
    ranked.sort(key=lambda entry: entry["conf"], reverse=True)
    return ranked