Skip to content

Commit

Permalink
resolve merge conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
leoglonz committed Nov 19, 2024
2 parents 056310b + cf8b97e commit 1a74e83
Show file tree
Hide file tree
Showing 4 changed files with 315 additions and 1 deletion.
1 change: 0 additions & 1 deletion src/hydroDL2/api/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ def load_model(model: str, ver_name: str = None) -> Module:
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
except FileNotFoundError:
print(model)
raise ImportError(f"Model '{model}' not found.")

# Retrieve the version name if specified, otherwise get the first class in the module
Expand Down
Empty file.
Empty file.
315 changes: 315 additions & 0 deletions src/hydroDL2/test.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,315 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"hbv\n",
"__pycache__\n",
"prms\n",
"Folder: hbv\n",
"Folder: __pycache__\n",
"Folder: prms\n"
]
}
],
"source": [
"from pathlib import Path\n",
"import sys\n",
"\n",
"directory = Path('../models')\n",
"\n",
"# # List all files and directories \n",
"# for item in directory.iterdir():\n",
"# print(item.name)\n",
"\n",
"# List only files\n",
"for file in directory.iterdir():\n",
" if file.is_file():\n",
" print(\"File:\", file.name)\n",
"\n",
"# List only directories\n",
"for folder in directory.iterdir():\n",
" if folder.is_dir():\n",
" print(\"Folder:\", folder.name)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from pathlib import Path\n",
"directory = Path('../models')\n",
"\n",
"\n",
"def get_directories(directory):\n",
" dirs = []\n",
" dir_names = []\n",
" avoid_list = ['__pycache__']\n",
"\n",
" for item in directory.iterdir():\n",
" if item.is_dir() and (item.name not in avoid_list):\n",
" dirs.append(item)\n",
" dir_names.append(item.name)\n",
" return dirs, dir_names\n",
"\n",
"\n",
"def get_files(directory):\n",
" files = []\n",
" file_names = []\n",
" avoid_list = ['__init__', '.DS_Store', 'README.md', '.git']\n",
" \n",
" for item in directory.iterdir():\n",
" if item.is_file() and (item.name not in avoid_list):\n",
" files.append(item)\n",
"\n",
" # Remove file extension\n",
" name = os.path.splitext(item.name)[0]\n",
" file_names.append(name)\n",
" return files, file_names\n"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"def available_models():\n",
" directory = Path('../models')\n",
"\n",
" dirs = []\n",
" models = {}\n",
"\n",
" dirs, _ = get_directories(directory)\n",
" for dir in dirs:\n",
" _, file_names = get_files(dir)\n",
" models[dir.name] = file_names\n",
" \n",
" return models"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"import importlib\n",
"\n",
"def import_model(module_name: str):\n",
" try:\n",
" # Import the module using importlib\n",
" module = importlib.import_module(module_name)\n",
" return module\n",
" except ModuleNotFoundError:\n",
" print(f\"Module '{module_name}' not found.\")\n",
" return None\n"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"import importlib\n",
"import os\n",
"\n",
"model = 'HBV_v1_1p'\n",
"model_dir = model.split('_')[0].lower() # Get the model directory name\n",
"model_subpath = os.path.join(model_dir, f'{model.lower()}.py') # Add the file name\n",
"\n",
"source = os.path.join('./models', model_subpath)\n",
"\n",
"x = importlib.util.spec_from_file_location('HBVMulTDET',source)\n",
"mod = x.loader.load_module()\n",
"\n",
"mod.HBVMulTDET"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
"import importlib.util\n",
"import os\n",
"from torch.nn import Module\n",
"\n",
"def load_model(model: str, ver_name: str = None) -> Module:\n",
" \"\"\"Load a model from the models directory.\n",
"\n",
" Each model file in `models/` directory should only contain one model class.\n",
"\n",
" Parameters\n",
" ----------\n",
" model : str\n",
" The model name.\n",
" ver_name : str, optional\n",
" The version name (class) of the model to load within the model file.\n",
" \n",
" Returns\n",
" -------\n",
" Module\n",
" The uninstantiated model.\n",
" \"\"\"\n",
" # Construct file path\n",
" model_dir = model.split('_')[0].lower()\n",
" model_subpath = os.path.join(model_dir, f'{model.lower()}.py')\n",
" \n",
" # Path to the module file in the models directory\n",
" source = os.path.join('./models', model_subpath)\n",
" \n",
" # Load the model dynamically as a module.\n",
" spec = importlib.util.spec_from_file_location(model, source)\n",
" module = importlib.util.module_from_spec(spec)\n",
" spec.loader.exec_module(module)\n",
" \n",
" # Retrieve the version name if specified, otherwise get the first class in the module\n",
" if ver_name:\n",
" cls = getattr(module, ver_name)\n",
" else:\n",
" # Find the first class in the module (this may not always be accurate)\n",
" classes = [attr for attr in dir(module) if isinstance(getattr(module, attr), type)]\n",
" if not classes:\n",
" raise ImportError(f\"No class found in module '{model}'\")\n",
" cls = getattr(module, classes[0])\n",
" \n",
" print(module)\n",
" return cls"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<module 'HBV_v1_1p' from '/data/lgl5139/project_blue_eyes/hydroDL2/src/hydroDL2/./models/hbv/hbv_v1_1p.py'>\n"
]
}
],
"source": [
"x = import_class_from_file('HBV_v1_1p')\n",
"x =x()\n"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model 'hbv_v1_1p' not found.\n"
]
}
],
"source": [
"import importlib.util\n",
"import os\n",
"\n",
"def load_models(models_path: str):\n",
" model_type \n",
" models = {}\n",
" \n",
" # Scan the models directory for Python files\n",
" for filename in os.listdir(models_path):\n",
" # Check if the file is a Python file\n",
" if filename.endswith(\".py\") and not filename.startswith(\"__\"):\n",
" # Get the module name and full file path\n",
" module_name = filename[:-3] # Strip the '.py' extension\n",
" file_path = os.path.join(models_path, filename)\n",
" \n",
" # Load the module from the file path\n",
" spec = importlib.util.spec_from_file_location(module_name, file_path)\n",
" module = importlib.util.module_from_spec(spec)\n",
" spec.loader.exec_module(module)\n",
" \n",
" # Dynamically find the first class in the module\n",
" classes = [attr for attr in dir(module) if isinstance(getattr(module, attr), type)]\n",
" if classes:\n",
" model_class = getattr(module, classes[0]) # Assumes the first class is the model\n",
" models[module_name] = model_class # Store it in the dictionary by module name\n",
"\n",
" return models\n",
"\n",
"# Usage example\n",
"models_directory = './models'\n",
"loaded_models = load_models(models_directory)\n",
"\n",
"# Access a specific model\n",
"model_name = 'hbv_v1_1p' # Example model name without '.py' extension\n",
"if model_name in loaded_models:\n",
" model_class = loaded_models[model_name]\n",
" model_instance = model_class() # Instantiate the model class\n",
" print(f\"Loaded model instance: {model_instance}\")\n",
"else:\n",
" print(f\"Model '{model_name}' not found.\")\n"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{}"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"loaded_models"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "hydrodl",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 1a74e83

Please sign in to comment.