Skip to content

Commit

Permalink
Wire DPO (#67)
Browse files Browse the repository at this point in the history
* route dpos through alignment routine

* dpo version bump
  • Loading branch information
Jacobsolawetz authored Aug 1, 2024
1 parent 083ed0f commit 0574c60
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
2 changes: 1 addition & 1 deletion arcee/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "1.3.4"
__version__ = "1.3.5"

import os

Expand Down
12 changes: 9 additions & 3 deletions arcee/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,29 +349,34 @@ def corpus_status(corpus: str) -> Dict[str, str]:

def start_alignment(
alignment_name: str,
qa_set: str,
qa_set: Optional[str] = None,
pretrained_model: Optional[str] = None,
merging_model: Optional[str] = None,
alignment_model: Optional[str] = None,
hf_model: Optional[str] = None,
target_compute: Optional[str] = None,
capacity_id: Optional[str] = None,
full_or_peft: Optional[str] = "full",
alignment_type: str = "sft",
full_or_peft: Optional[str] = "full"
) -> Dict[str, str]:
"""
Start the alignment of a model.
Args:
alignment_name (str): The name of the alignment job.
qa_set (str): The name of the QA set to use.
qa_set (Optional[str]): The name of the QA set to use. Required if alignment_type is "sft".
pretrained_model (Optional[str]): The name of the pretrained model to use, if any.
merging_model (Optional[str]): The name of the merging model to use, if any.
alignment_model (Optional[str]): The name of the final alignment model to use, if any.
hf_model (Optional[str]): The name of the Hugging Face model to use, if any.
target_compute (Optional[str]): The name of the compute to use, e.g., "g5.2xlarge" or
"capacity". If omitted, the default compute will be used.
capacity_id (Optional[str]): The name of the capacity block ID to use. If omitted, an
instance will be launched to perform training.
alignment_type (str): The type of alignment to perform. Can be "sft" or "dpo". Defaults to "sft".
"""
if alignment_type == "sft" and qa_set is None:
raise ValueError("qa_set is required when alignment_type is 'sft'")

data = {
"alignment_name": alignment_name,
Expand All @@ -383,6 +388,7 @@ def start_alignment(
"hf_model": hf_model,
"target_compute": target_compute,
"capacity_id": capacity_id,
"alignment_type": alignment_type,
}

# Assuming make_request is a function that handles the request, it's called here
Expand Down

0 comments on commit 0574c60

Please sign in to comment.