Save/Load config as yaml
maxjeblick authored Apr 26, 2023
1 parent 019f8b1 commit bf1cbf6
Showing 21 changed files with 563 additions and 369 deletions.
3 changes: 3 additions & 0 deletions Makefile
@@ -27,6 +27,9 @@ clean-env:
clean-data:
rm -rf data

clean-output:
rm -rf output

reports:
mkdir -p reports

10 changes: 10 additions & 0 deletions README.md
@@ -166,5 +166,15 @@ python prompt.py -e examples/output_oasst1

All open-source datasets and models are posted on [H2O.ai's Hugging Face page](https://huggingface.co/h2oai/).


## Changelog
The field is rapidly evolving, and we are constantly adding new features and fixing bugs.
While we strive to converge to a stable framework, certain changes may break your existing experiments at this early stage of development.
We thus recommend pinning the framework version to the one you used for your experiments.
Below, we list a summary of the changes that may affect older experiments:
- [PR 12](https://github.com/h2oai/h2o-llmstudio/pull/12). Experiment configurations are now stored in YAML format, allowing for more flexibility in the configuration while making backward compatibility much easier to maintain. Old experiment configurations stored in pickle format are converted to YAML automatically; a minimal loading sketch follows this list.
- [PR 40](https://github.com/h2oai/h2o-llmstudio/pull/40). Datasets can now use past chat history as input. This feature can be enabled by setting `Parent Id Column` in the dataset configuration.
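For example, loading a migrated experiment configuration might look like the following minimal sketch (`load_config_yaml` comes from `llm_studio/src/utils/config_utils.py`; the output path and the shape of the returned object are assumptions):

```python
# Minimal sketch: read back a config that was migrated to YAML.
# The path below is illustrative, not a real experiment.
from llm_studio.src.utils.config_utils import load_config_yaml

cfg = load_config_yaml("output/user/my_experiment/cfg.yaml")
print(cfg.dataset.prompt_column)  # e.g. ("instruction",)
```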

## License
H2O LLM Studio is licensed under the Apache 2.0 license. Please see the [LICENSE](LICENSE) file for more information.
1 change: 0 additions & 1 deletion app.py
@@ -29,7 +29,6 @@ async def serve(q: Q):
copy_expando(q.args, q.client)

await initialize_client(q)

await handle(q)

if not q.args["experiment/display/chat/chatbot"]:
3 changes: 3 additions & 0 deletions app_utils/db.py
@@ -149,3 +149,6 @@ def rename_experiment(self, id: int, new_name: str, new_path: str) -> None:
experiment.name = new_name
experiment.path = new_path
self._session.commit()

def update(self) -> None:
self._session.commit()
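The new `update()` method simply commits the current SQLAlchemy session, so callers can mutate a tracked ORM object in place and persist the change with one call. A minimal usage sketch, assuming a `Database` instance wired up as in `app_utils/db.py` (the constructor argument is illustrative):

```python
# Sketch: persist an in-place edit to a tracked ORM object.
# get_dataset returns a Dataset row attached to the session,
# so update() only needs to commit.
db = Database("/path/to/app.db")  # illustrative path
dataset = db.get_dataset(1)
dataset.config_file = dataset.config_file.replace(".p", ".yaml")
db.update()
```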
18 changes: 10 additions & 8 deletions app_utils/initializers.py
@@ -7,17 +7,19 @@
from h2o_wave import Q

from app_utils.sections.common import interface
from llm_studio.src.utils.config_utils import load_config

from llm_studio.src.utils.config_utils import (
load_config_py,
save_config_yaml,
)
from .config import default_cfg
from .db import Database, Dataset
from .migration import migrate_app
from .utils import (
get_data_dir,
get_db_path,
get_user_name,
load_user_settings,
prepare_default_dataset,
save_dill,
)

logger = logging.getLogger(__name__)
@@ -38,21 +40,21 @@ def import_data(q: Q):

df = prepare_default_dataset(path)

cfg = load_config(
cfg = load_config_py(
config_path=os.path.join(
"llm_studio/python_configs", default_cfg.cfg_file
),
config_name="ConfigProblemBase",
)

cfg.dataset.train_dataframe = os.path.join(path, "train_full.pq")
cfg.dataset.prompt_column = "instruction"
cfg.dataset.prompt_column = ("instruction",)
cfg.dataset.answer_column = "output"
cfg.dataset.parent_id_column = "None"

cfg_path = os.path.join(path, f"{default_cfg.cfg_file}.p")
cfg_path = os.path.join(path, f"{default_cfg.cfg_file}.yaml")

save_dill(cfg_path, cfg)
save_config_yaml(cfg_path, cfg)

dataset = Dataset(
id=1,
@@ -96,6 +98,7 @@ async def initialize_client(q: Q) -> None:

load_user_settings(q)

await migrate_app(q)
await interface(q)

q.args[default_cfg.start_page] = True
@@ -128,7 +131,6 @@ async def initialize_app(q: Q) -> None:
script_sources.append(url)

q.app["script_sources"] = script_sources

q.app["initialized"] = True

logger.info("Initializing app ... done")
62 changes: 62 additions & 0 deletions app_utils/migration.py
@@ -0,0 +1,62 @@
import logging
import os

import dill
from h2o_wave import Q

from app_utils.db import Database
from app_utils.utils import get_data_dir, get_output_dir
from llm_studio.src.utils.config_utils import save_config_yaml

logger = logging.getLogger(__name__)


async def migrate_app(q: Q) -> None:
"""
Migration scripts for the app.
"""
migrate_pickle_to_yaml(q)
migrate_database_pickle_to_yaml(q)


def migrate_pickle_to_yaml(q: Q) -> None:
"""
Migrate config files from pickle to YAML.
Introduced in https://github.com/h2oai/h2o-llmstudio/pull/12.
"""
data_dir = get_data_dir(q)
output_dir = get_output_dir(q)

for dir in [data_dir, output_dir]:
if os.path.exists(dir):
for root, dirs, files in os.walk(dir):
for file in files:
if file.endswith(".p") and not os.path.exists(
os.path.join(root, file.replace(".p", ".yaml"))
):
try:
with open(os.path.join(root, file), "rb") as f:
cfg = dill.load(f)
save_config_yaml(
os.path.join(root, file.replace(".p", ".yaml")), cfg
)
logger.info(
f"migrated {os.path.join(root, file)} to yaml"
)
except Exception as e:
logger.error(
f"Could not migrate {os.path.join(root, file)} "
f"to yaml: {e}"
)


def migrate_database_pickle_to_yaml(q: Q) -> None:
"""
Update database records to point at the YAML config files.
Introduced in https://github.com/h2oai/h2o-llmstudio/pull/12.
"""
db: Database = q.client.app_db
for dataset_id in db.get_datasets_df()["id"]:
dataset = db.get_dataset(dataset_id)
dataset.config_file = dataset.config_file.replace(".p", ".yaml")
db.update()
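The `save_config_yaml` and `load_config_yaml` helpers live in `llm_studio/src/utils/config_utils.py`, which is not expanded in this view. A plausible minimal implementation, assuming the config is a nested dataclass — the actual code may differ, for instance in how the dataclass is reconstructed on load:

```python
# Hedged sketch of the YAML helpers referenced above; the real
# implementations in config_utils.py may differ in detail.
import dataclasses
from typing import Any

import yaml


def save_config_yaml(path: str, cfg: Any) -> None:
    # Serialize a (nested) config dataclass to YAML.
    with open(path, "w") as f:
        yaml.safe_dump(dataclasses.asdict(cfg), f)


def load_config_yaml(path: str) -> dict:
    # Load the raw YAML; rebuilding the dataclass is omitted here.
    with open(path) as f:
        return yaml.safe_load(f)
```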
27 changes: 15 additions & 12 deletions app_utils/sections/dataset.py
@@ -25,17 +25,19 @@
get_problem_types,
get_unique_dataset_name,
kaggle_download,
load_dill,
local_download,
make_label,
parse_ui_elements,
remove_temp_files,
s3_download,
s3_file_options,
save_dill,
)
from app_utils.wave_utils import ui_table_from_df
from llm_studio.src.utils.config_utils import load_config
from llm_studio.src.utils.config_utils import (
load_config_py,
load_config_yaml,
save_config_yaml,
)
from llm_studio.src.utils.data_utils import (
get_fill_columns,
read_dataframe,
@@ -417,7 +419,7 @@ async def dataset_import(
q.client["dataset/import/cfg_file"], model_types[0][0]
)
if not edit:
q.client["dataset/import/cfg"] = load_config(
q.client["dataset/import/cfg"] = load_config_py(
config_path=(
f"llm_studio/python_configs/"
f"{q.client['dataset/import/cfg_file']}"
@@ -586,7 +588,9 @@ async def dataset_import(
# change the default validation strategy if validation df set
if cfg.dataset.validation_dataframe != "None":
cfg.dataset.validation_strategy = "custom"
save_dill(f"{new_path}/{q.client['dataset/import/cfg_file']}.p", cfg)
save_config_yaml(
f"{new_path}/{q.client['dataset/import/cfg_file']}.yaml", cfg
)

train_rows = None
if os.path.exists(cfg.dataset.train_dataframe):
@@ -603,7 +607,7 @@
id=q.client["dataset/import/id"],
name=q.client["dataset/import/name"],
path=new_path,
config_file=f"{new_path}/{q.client['dataset/import/cfg_file']}.p",
config_file=f"{new_path}/{q.client['dataset/import/cfg_file']}.yaml",
train_rows=train_rows,
validation_rows=validation_rows,
)
@@ -681,7 +685,7 @@ async def dataset_merge(q: Q, step, error=""):
has_experiment = False

current_files = os.listdir(current_dir)
current_files = [x for x in current_files if not x.endswith(".p")]
current_files = [x for x in current_files if not x.endswith(".yaml")]
target_files = os.listdir(target_dir)
overlapping_files = list(set(current_files).intersection(set(target_files)))
rename_map = {}
@@ -834,7 +838,7 @@ async def dataset_newexperiment(q: Q, dataset_id: int):
dataset = q.client.app_db.get_dataset(dataset_id)

q.client["experiment/start/cfg_file"] = dataset.config_file.split("/")[-1].replace(
".p", ""
".yaml", ""
)
q.client["experiment/start/cfg_category"] = q.client[
"experiment/start/cfg_file"
@@ -877,15 +881,15 @@ async def dataset_edit(
q.client["dataset/import/id"] = dataset_id

q.client["dataset/import/cfg_file"] = dataset.config_file.split("/")[-1].replace(
".p", ""
".yaml", ""
)
q.client["dataset/import/cfg_category"] = q.client["dataset/import/cfg_file"].split(
"_"
)[0]
q.client["dataset/import/path"] = dataset.path
q.client["dataset/import/name"] = dataset.name
q.client["dataset/import/original_name"] = dataset.name
q.client["dataset/import/cfg"] = load_dill(dataset.config_file)
q.client["dataset/import/cfg"] = load_config_yaml(dataset.config_file)

if allow_merge and experiments_df.shape[0]:
allow_merge = False
@@ -961,8 +965,7 @@ async def dataset_display(q: Q) -> None:
q.client["dataset/display/id"]
]
dataset = q.client.app_db.get_dataset(dataset_id)
config_file = dataset.config_file
cfg = load_dill(config_file)
cfg = load_config_yaml(dataset.config_file)

has_train_df = cfg.dataset.train_dataframe != "None"

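Several call sites above now derive the config stem and problem category from the stored file name; condensed, with an illustrative path:

```python
# Sketch of the stem/category derivation used in dataset_edit and
# dataset_newexperiment above; the path is illustrative.
config_file = "data/user/my_ds/text_causal_language_modeling_config.yaml"
cfg_file = config_file.split("/")[-1].replace(".yaml", "")  # config stem
cfg_category = cfg_file.split("_")[0]  # e.g. "text"
```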
