-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
315 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,315 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"hbv\n", | ||
"__pycache__\n", | ||
"prms\n", | ||
"Folder: hbv\n", | ||
"Folder: __pycache__\n", | ||
"Folder: prms\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from pathlib import Path\n", | ||
"import sys\n", | ||
"\n", | ||
"directory = Path('../models')\n", | ||
"\n", | ||
"# # List all files and directories \n", | ||
"# for item in directory.iterdir():\n", | ||
"# print(item.name)\n", | ||
"\n", | ||
"# List only files\n", | ||
"for file in directory.iterdir():\n", | ||
" if file.is_file():\n", | ||
" print(\"File:\", file.name)\n", | ||
"\n", | ||
"# List only directories\n", | ||
"for folder in directory.iterdir():\n", | ||
" if folder.is_dir():\n", | ||
" print(\"Folder:\", folder.name)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 26, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"from pathlib import Path\n", | ||
"directory = Path('../models')\n", | ||
"\n", | ||
"\n", | ||
"def get_directories(directory):\n", | ||
" dirs = []\n", | ||
" dir_names = []\n", | ||
" avoid_list = ['__pycache__']\n", | ||
"\n", | ||
" for item in directory.iterdir():\n", | ||
" if item.is_dir() and (item.name not in avoid_list):\n", | ||
" dirs.append(item)\n", | ||
" dir_names.append(item.name)\n", | ||
" return dirs, dir_names\n", | ||
"\n", | ||
"\n", | ||
"def get_files(directory):\n", | ||
" files = []\n", | ||
" file_names = []\n", | ||
" avoid_list = ['__init__', '.DS_Store', 'README.md', '.git']\n", | ||
" \n", | ||
" for item in directory.iterdir():\n", | ||
" if item.is_file() and (item.name not in avoid_list):\n", | ||
" files.append(item)\n", | ||
"\n", | ||
" # Remove file extension\n", | ||
" name = os.path.splitext(item.name)[0]\n", | ||
" file_names.append(name)\n", | ||
" return files, file_names\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 27, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def available_models():\n", | ||
" directory = Path('../models')\n", | ||
"\n", | ||
" dirs = []\n", | ||
" models = {}\n", | ||
"\n", | ||
" dirs, _ = get_directories(directory)\n", | ||
" for dir in dirs:\n", | ||
" _, file_names = get_files(dir)\n", | ||
" models[dir.name] = file_names\n", | ||
" \n", | ||
" return models" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 22, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import importlib\n", | ||
"\n", | ||
"def import_model(module_name: str):\n", | ||
" try:\n", | ||
" # Import the module using importlib\n", | ||
" module = importlib.import_module(module_name)\n", | ||
" return module\n", | ||
" except ModuleNotFoundError:\n", | ||
" print(f\"Module '{module_name}' not found.\")\n", | ||
" return None\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 31, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import importlib\n", | ||
"import os\n", | ||
"\n", | ||
"model = 'HBV_v1_1p'\n", | ||
"model_dir = model.split('_')[0].lower() # Get the model directory name\n", | ||
"model_subpath = os.path.join(model_dir, f'{model.lower()}.py') # Add the file name\n", | ||
"\n", | ||
"source = os.path.join('./models', model_subpath)\n", | ||
"\n", | ||
"x = importlib.util.spec_from_file_location('HBVMulTDET',source)\n", | ||
"mod = x.loader.load_module()\n", | ||
"\n", | ||
"mod.HBVMulTDET" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 48, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import importlib.util\n", | ||
"import os\n", | ||
"from torch.nn import Module\n", | ||
"\n", | ||
"def load_model(model: str, ver_name: str = None) -> Module:\n", | ||
" \"\"\"Load a model from the models directory.\n", | ||
"\n", | ||
" Each model file in `models/` directory should only contain one model class.\n", | ||
"\n", | ||
" Parameters\n", | ||
" ----------\n", | ||
" model : str\n", | ||
" The model name.\n", | ||
" ver_name : str, optional\n", | ||
" The version name (class) of the model to load within the model file.\n", | ||
" \n", | ||
" Returns\n", | ||
" -------\n", | ||
" Module\n", | ||
" The uninstantiated model.\n", | ||
" \"\"\"\n", | ||
" # Construct file path\n", | ||
" model_dir = model.split('_')[0].lower()\n", | ||
" model_subpath = os.path.join(model_dir, f'{model.lower()}.py')\n", | ||
" \n", | ||
" # Path to the module file in the models directory\n", | ||
" source = os.path.join('./models', model_subpath)\n", | ||
" \n", | ||
" # Load the model dynamically as a module.\n", | ||
" spec = importlib.util.spec_from_file_location(model, source)\n", | ||
" module = importlib.util.module_from_spec(spec)\n", | ||
" spec.loader.exec_module(module)\n", | ||
" \n", | ||
" # Retrieve the version name if specified, otherwise get the first class in the module\n", | ||
" if ver_name:\n", | ||
" cls = getattr(module, ver_name)\n", | ||
" else:\n", | ||
" # Find the first class in the module (this may not always be accurate)\n", | ||
" classes = [attr for attr in dir(module) if isinstance(getattr(module, attr), type)]\n", | ||
" if not classes:\n", | ||
" raise ImportError(f\"No class found in module '{model}'\")\n", | ||
" cls = getattr(module, classes[0])\n", | ||
" \n", | ||
" print(module)\n", | ||
" return cls" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 49, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"<module 'HBV_v1_1p' from '/data/lgl5139/project_blue_eyes/hydroDL2/src/hydroDL2/./models/hbv/hbv_v1_1p.py'>\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"x = import_class_from_file('HBV_v1_1p')\n", | ||
"x =x()\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 38, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Model 'hbv_v1_1p' not found.\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import importlib.util\n", | ||
"import os\n", | ||
"\n", | ||
"def load_models(models_path: str):\n", | ||
" model_type \n", | ||
" models = {}\n", | ||
" \n", | ||
" # Scan the models directory for Python files\n", | ||
" for filename in os.listdir(models_path):\n", | ||
" # Check if the file is a Python file\n", | ||
" if filename.endswith(\".py\") and not filename.startswith(\"__\"):\n", | ||
" # Get the module name and full file path\n", | ||
" module_name = filename[:-3] # Strip the '.py' extension\n", | ||
" file_path = os.path.join(models_path, filename)\n", | ||
" \n", | ||
" # Load the module from the file path\n", | ||
" spec = importlib.util.spec_from_file_location(module_name, file_path)\n", | ||
" module = importlib.util.module_from_spec(spec)\n", | ||
" spec.loader.exec_module(module)\n", | ||
" \n", | ||
" # Dynamically find the first class in the module\n", | ||
" classes = [attr for attr in dir(module) if isinstance(getattr(module, attr), type)]\n", | ||
" if classes:\n", | ||
" model_class = getattr(module, classes[0]) # Assumes the first class is the model\n", | ||
" models[module_name] = model_class # Store it in the dictionary by module name\n", | ||
"\n", | ||
" return models\n", | ||
"\n", | ||
"# Usage example\n", | ||
"models_directory = './models'\n", | ||
"loaded_models = load_models(models_directory)\n", | ||
"\n", | ||
"# Access a specific model\n", | ||
"model_name = 'hbv_v1_1p' # Example model name without '.py' extension\n", | ||
"if model_name in loaded_models:\n", | ||
" model_class = loaded_models[model_name]\n", | ||
" model_instance = model_class() # Instantiate the model class\n", | ||
" print(f\"Loaded model instance: {model_instance}\")\n", | ||
"else:\n", | ||
" print(f\"Model '{model_name}' not found.\")\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 35, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"{}" | ||
] | ||
}, | ||
"execution_count": 35, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"loaded_models" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "hydrodl", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.10" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |