From 6d1d6d32b3bc412db8f4a692b867cd41087f2157 Mon Sep 17 00:00:00 2001 From: Sefik Ilkin Serengil Date: Sat, 5 Oct 2024 21:54:15 +0100 Subject: [PATCH] indexes for source urls of weights added --- deepface/commons/weight_utils.py | 51 ++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/deepface/commons/weight_utils.py b/deepface/commons/weight_utils.py index d1aecf99..067b8f96 100644 --- a/deepface/commons/weight_utils.py +++ b/deepface/commons/weight_utils.py @@ -19,6 +19,40 @@ logger = Logger() +# pylint: disable=line-too-long +WEIGHTS = { + "facial_recognition": { + "VGG-Face": "https://github.com/serengil/deepface_models/releases/download/v1.0/vgg_face_weights.h5", + "Facenet": "https://github.com/serengil/deepface_models/releases/download/v1.0/facenet_weights.h5", + "Facenet512": "https://github.com/serengil/deepface_models/releases/download/v1.0/facenet512_weights.h5", + "OpenFace": "https://github.com/serengil/deepface_models/releases/download/v1.0/openface_weights.h5", + "FbDeepFace": "https://github.com/swghosh/DeepFace/releases/download/weights-vggface2-2d-aligned/VGGFace2_DeepFace_weights_val-0.9034.h5.zip", + "ArcFace": "https://github.com/serengil/deepface_models/releases/download/v1.0/arcface_weights.h5", + "DeepID": "https://github.com/serengil/deepface_models/releases/download/v1.0/deepid_keras_weights.h5", + "SFace": "https://github.com/opencv/opencv_zoo/raw/main/models/face_recognition_sface/face_recognition_sface_2021dec.onnx", + "GhostFaceNet": "https://github.com/HamadYA/GhostFaceNets/releases/download/v1.2/GhostFaceNet_W1.3_S1_ArcFace.h5", + "Dlib": "http://dlib.net/files/dlib_face_recognition_resnet_model_v1.dat.bz2", + }, + "demography": { + "Age": "https://github.com/serengil/deepface_models/releases/download/v1.0/age_model_weights.h5", + "Gender": "https://github.com/serengil/deepface_models/releases/download/v1.0/gender_model_weights.h5", + "Emotion": "https://github.com/serengil/deepface_models/releases/download/v1.0/facial_expression_model_weights.h5", + "Race": "https://github.com/serengil/deepface_models/releases/download/v1.0/race_model_single_batch.h5", + }, + "detection": { + "ssd_model": "https://github.com/opencv/opencv/raw/3.4.0/samples/dnn/face_detector/deploy.prototxt", + "ssd_weights": "https://github.com/opencv/opencv_3rdparty/raw/dnn_samples_face_detector_20170830/res10_300x300_ssd_iter_140000.caffemodel", + "yolo": "https://drive.google.com/uc?id=1qcr9DbgsX3ryrz2uU8w4Xm3cOrRywXqb", + "yunet": "https://github.com/opencv/opencv_zoo/raw/main/models/face_detection_yunet/face_detection_yunet_2023mar.onnx", + "dlib": "http://dlib.net/files/shape_predictor_5_face_landmarks.dat.bz2", + "centerface": "https://github.com/Star-Clouds/CenterFace/raw/master/models/onnx/centerface.onnx", + }, + "spoofing": { + "MiniFASNetV2": "https://github.com/minivision-ai/Silent-Face-Anti-Spoofing/raw/master/resources/anti_spoof_models/2.7_80x80_MiniFASNetV2.pth", + "MiniFASNetV1SE": "https://github.com/minivision-ai/Silent-Face-Anti-Spoofing/raw/master/resources/anti_spoof_models/4_0_0_80x80_MiniFASNetV1SE.pth", + }, +} + ALLOWED_COMPRESS_TYPES = ["zip", "bz2"] @@ -95,3 +129,20 @@ def load_model_weights(model: Sequential, weight_file: str) -> Sequential: "and copying it to the target folder." ) from err return model + + +def retrieve_model_source(model_name: str, task: str) -> str: + """ + Find the source url of a given model name + Args: + model_name (str): given model name + Returns: + weight_url (str): source url of the given model + """ + if task not in ["facial_recognition", "detection", "demography", "spoofing"]: + raise ValueError(f"unimplemented task - {task}") + + source_url = WEIGHTS.get(task, {}).get(model_name) + if source_url is None: + raise ValueError(f"Source url cannot be found for given model {task}-{model_name}") + return source_url