Skip to content

Commit

Permalink
indexes for source urls of weights added
Browse files Browse the repository at this point in the history
  • Loading branch information
serengil committed Oct 5, 2024
1 parent cb02857 commit 6d1d6d3
Showing 1 changed file with 51 additions and 0 deletions.
51 changes: 51 additions & 0 deletions deepface/commons/weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,40 @@

logger = Logger()

# pylint: disable=line-too-long
WEIGHTS = {
"facial_recognition": {
"VGG-Face": "https://github.com/serengil/deepface_models/releases/download/v1.0/vgg_face_weights.h5",
"Facenet": "https://github.com/serengil/deepface_models/releases/download/v1.0/facenet_weights.h5",
"Facenet512": "https://github.com/serengil/deepface_models/releases/download/v1.0/facenet512_weights.h5",
"OpenFace": "https://github.com/serengil/deepface_models/releases/download/v1.0/openface_weights.h5",
"FbDeepFace": "https://github.com/swghosh/DeepFace/releases/download/weights-vggface2-2d-aligned/VGGFace2_DeepFace_weights_val-0.9034.h5.zip",
"ArcFace": "https://github.com/serengil/deepface_models/releases/download/v1.0/arcface_weights.h5",
"DeepID": "https://github.com/serengil/deepface_models/releases/download/v1.0/deepid_keras_weights.h5",
"SFace": "https://github.com/opencv/opencv_zoo/raw/main/models/face_recognition_sface/face_recognition_sface_2021dec.onnx",
"GhostFaceNet": "https://github.com/HamadYA/GhostFaceNets/releases/download/v1.2/GhostFaceNet_W1.3_S1_ArcFace.h5",
"Dlib": "http://dlib.net/files/dlib_face_recognition_resnet_model_v1.dat.bz2",
},
"demography": {
"Age": "https://github.com/serengil/deepface_models/releases/download/v1.0/age_model_weights.h5",
"Gender": "https://github.com/serengil/deepface_models/releases/download/v1.0/gender_model_weights.h5",
"Emotion": "https://github.com/serengil/deepface_models/releases/download/v1.0/facial_expression_model_weights.h5",
"Race": "https://github.com/serengil/deepface_models/releases/download/v1.0/race_model_single_batch.h5",
},
"detection": {
"ssd_model": "https://github.com/opencv/opencv/raw/3.4.0/samples/dnn/face_detector/deploy.prototxt",
"ssd_weights": "https://github.com/opencv/opencv_3rdparty/raw/dnn_samples_face_detector_20170830/res10_300x300_ssd_iter_140000.caffemodel",
"yolo": "https://drive.google.com/uc?id=1qcr9DbgsX3ryrz2uU8w4Xm3cOrRywXqb",
"yunet": "https://github.com/opencv/opencv_zoo/raw/main/models/face_detection_yunet/face_detection_yunet_2023mar.onnx",
"dlib": "http://dlib.net/files/shape_predictor_5_face_landmarks.dat.bz2",
"centerface": "https://github.com/Star-Clouds/CenterFace/raw/master/models/onnx/centerface.onnx",
},
"spoofing": {
"MiniFASNetV2": "https://github.com/minivision-ai/Silent-Face-Anti-Spoofing/raw/master/resources/anti_spoof_models/2.7_80x80_MiniFASNetV2.pth",
"MiniFASNetV1SE": "https://github.com/minivision-ai/Silent-Face-Anti-Spoofing/raw/master/resources/anti_spoof_models/4_0_0_80x80_MiniFASNetV1SE.pth",
},
}

ALLOWED_COMPRESS_TYPES = ["zip", "bz2"]


Expand Down Expand Up @@ -95,3 +129,20 @@ def load_model_weights(model: Sequential, weight_file: str) -> Sequential:
"and copying it to the target folder."
) from err
return model


def retrieve_model_source(model_name: str, task: str) -> str:
"""
Find the source url of a given model name
Args:
model_name (str): given model name
Returns:
weight_url (str): source url of the given model
"""
if task not in ["facial_recognition", "detection", "demography", "spoofing"]:
raise ValueError(f"unimplemented task - {task}")

source_url = WEIGHTS.get(task, {}).get(model_name)
if source_url is None:
raise ValueError(f"Source url cannot be found for given model {task}-{model_name}")
return source_url

0 comments on commit 6d1d6d3

Please sign in to comment.