diff --git a/dsp/modules/hf_client.py b/dsp/modules/hf_client.py index 3f77a995d..22729012c 100644 --- a/dsp/modules/hf_client.py +++ b/dsp/modules/hf_client.py @@ -217,6 +217,8 @@ def __init__(self, model, **kwargs): if "inst" in self.model.lower() or "instruct" in model.lower(): self.use_inst_template = True + stop_default = "\n\n---" + self.kwargs = { "temperature": 0.0, "max_tokens": 512, @@ -224,7 +226,7 @@ def __init__(self, model, **kwargs): "top_k": 20, "repetition_penalty": 1, "n": 1, - "stop": ["\n\n", "---", "[/INST]"], + "stop": stop_default if "stop" not in kwargs else kwargs["stop"], **kwargs } @@ -247,7 +249,6 @@ def _generate(self, prompt, use_chat_api=False, **kwargs): repetition_penalty = kwargs.get("repetition_penalty", 1) prompt = f"[INST]{prompt}[/INST]" if self.use_inst_template else prompt - if use_chat_api: url = f"{self.api_base}/chat/completions" messages = [