Skip to content

Commit

Permalink
update run script for sequence classification
Browse files Browse the repository at this point in the history
  • Loading branch information
holgerroth committed Feb 6, 2025
1 parent 9941e25 commit 336ba37
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 532 deletions.
159 changes: 95 additions & 64 deletions examples/advanced/bionemo/downstream/downstream_nvflare.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -84,64 +84,6 @@
"warnings.simplefilter(\"ignore\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Download Model Checkpoints\n",
"\n",
"The following code will download the pre-trained model, `\"esm2/8m:2.0`, from the NGC registry:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Downloading data from 'nvidia/clara/esm2nv8m:2.0' to file '/root/.cache/bionemo/2957b2c36d5978d0f595d6f1b72104b312621cf0329209086537b613c1c96d16-esm2_hf_converted_8m_checkpoint.tar.gz'.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
" \"download_end\": \"2025-02-04 15:59:35\",\n",
" \"download_start\": \"2025-02-04 15:59:33\",\n",
" \"download_time\": \"1s\",\n",
" \"files_downloaded\": 1,\n",
" \"local_path\": \"/root/.cache/bionemo/tmperx2hsc3/esm2nv8m_v2.0\",\n",
" \"size_downloaded\": \"16.97 MB\",\n",
" \"status\": \"COMPLETED\"\n",
"}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Untarring contents of '/root/.cache/bionemo/2957b2c36d5978d0f595d6f1b72104b312621cf0329209086537b613c1c96d16-esm2_hf_converted_8m_checkpoint.tar.gz' to '/root/.cache/bionemo/2957b2c36d5978d0f595d6f1b72104b312621cf0329209086537b613c1c96d16-esm2_hf_converted_8m_checkpoint.tar.gz.untar'\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"/root/.cache/bionemo/2957b2c36d5978d0f595d6f1b72104b312621cf0329209086537b613c1c96d16-esm2_hf_converted_8m_checkpoint.tar.gz.untar\n"
]
}
],
"source": [
"from bionemo.core.data.load import load\n",
"\n",
"checkpoint_path = load(\"esm2/8m:2.0\")\n",
"print(checkpoint_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -276,12 +218,101 @@
"source": [
"**Run training (central, local, & FL)**\n",
"\n",
"You can change the FL job that's going to be simulated inside the `run_sim_sabdab.py` script."
"You can change the FL job that's going to be simulated by changing the arguments of `run_sim_sabdab.py` script. The ESM2 finetuning arguments such as learning rate and others can be modified inside the script itself.\n",
"\n",
"First check its arguments."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Traceback (most recent call last):\n",
" File \"/bionemo_nvflare_examples/downstream/sabdab/run_sim_sabdab.py\", line 15, in <module>\n",
" from nvflare.job_config.script_runner import BaseScriptRunner\n",
"ModuleNotFoundError: No module named 'nvflare'\n"
]
}
],
"source": [
"!cd /bionemo_nvflare_examples/downstream/sabdab && python run_sim_sabdab.py --help"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**1. Central training**\n",
"\n",
"To simulate central training, we use one client, running one round of training for several steps. Note that if the `--exp_name` argument contains `\"central\"`, the combined training dataset is used."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Traceback (most recent call last):\n",
" File \"/bionemo_nvflare_examples/downstream/sabdab/run_sim_sabdab.py\", line 15, in <module>\n",
" from nvflare.job_config.script_runner import BaseScriptRunner\n",
"ModuleNotFoundError: No module named 'nvflare'\n"
]
}
],
"source": [
"!cd /bionemo_nvflare_examples/downstream/sabdab && python run_sim_sabdab.py --num_clients=1 --num_rounds=1 --local_steps=300 --exp_name central"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**2. Local training**\n",
"\n",
"To simulate central training, we use six clients, each running one round of training for several steps."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Traceback (most recent call last):\n",
" File \"/bionemo_nvflare_examples/downstream/sabdab/run_sim_sabdab.py\", line 15, in <module>\n",
" from nvflare.job_config.script_runner import BaseScriptRunner\n",
"ModuleNotFoundError: No module named 'nvflare'\n"
]
}
],
"source": [
"!cd /bionemo_nvflare_examples/downstream/sabdab && python run_sim_sabdab.py --num_clients=6 --num_rounds=1 --local_steps=300 --exp_name local"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**3. FedAvg training**\n",
"\n",
"To simulate federated training, we use six clients, running several rounds with FedAvg, each with a smaller number of local steps."
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 12,
"metadata": {},
"outputs": [
{
Expand All @@ -296,7 +327,7 @@
}
],
"source": [
"! cd /bionemo_nvflare_examples/downstream/sabdab && python run_sim_sabdab.py"
"!cd /bionemo_nvflare_examples/downstream/sabdab && python run_sim_sabdab.py --num_clients=6 --num_rounds=30 --local_steps=10 --exp_name fedavg"
]
},
{
Expand All @@ -307,13 +338,13 @@
"| Setting | Accuracy |\n",
"|:-------:|:---------:|\n",
"| Local | 0.821 |\n",
"| FL | **0.833** |\n",
"| FedAvg | **0.833** |\n",
"\n",
"#### Results with heterogeneous data sampling (alpha=1.0)\n",
"| Setting | Accuracy |\n",
"|:-------:|:---------:|\n",
"| Local | 0.813 |\n",
"| FL | **0.835** |\n",
"| FedAvg | **0.835** |\n",
"\n",
"### Task 3. Subcellular location prediction with ESM2nv 650M\n",
"Follow the data download and preparation in [task_fitting.ipynb](../task_fitting/task_fitting.ipynb).\n",
Expand Down Expand Up @@ -343,7 +374,7 @@
"| Setting | Accuracy |\n",
"|:-------:|:---------:|\n",
"| Local | 0.773 |\n",
"| FL | **0.776** |\n",
"| FedAvg | **0.776** |\n",
"\n",
"\n",
"<img src=\"./scl/figs/scl_results.svg\" alt=\"Dirichlet sampling (alpha=1.0)\" width=\"300\"/>"
Expand Down
Loading

0 comments on commit 336ba37

Please sign in to comment.