Skip to content

Commit

Permalink
Added functionality for custom prompts
Browse files Browse the repository at this point in the history
  • Loading branch information
KennethEnevoldsen committed Sep 13, 2023
1 parent 7dadca5 commit fd51a6e
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 3 deletions.
4 changes: 4 additions & 0 deletions skllm/models/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ class _BaseZeroShotGPTClassifier(BaseClassifier, _OAIMixin):
default_label : Optional[Union[List[str], str]] , default : 'Random'
The default label to use if the LLM could not generate a response for a sample. If set to 'Random' a random
label will be chosen based on probabilities from the training set.
prompt_template : Optional[str] , default : None
    A formattable string with the following placeholders: {x} - the sample to classify, {labels} - the list of labels.
    If None, the default prompt template will be used.
"""

def __init__(
Expand All @@ -138,10 +140,12 @@ def __init__(
openai_org: Optional[str] = None,
openai_model: str = "gpt-3.5-turbo",
default_label: Optional[Union[List[str], str]] = "Random",
prompt_template: Optional[str] = None,
):
self._set_keys(openai_key, openai_org)
self.openai_model = openai_model
self.default_label = default_label
self.prompt_template = prompt_template

@abstractmethod
def _get_prompt(self, x: str) -> str:
Expand Down
19 changes: 18 additions & 1 deletion skllm/models/gpt/gpt_dyn_few_shot_clf.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class DynamicFewShotGPTClassifier(_BaseZeroShotGPTClassifier):
label will be chosen based on probabilities from the training set.
memory_index : Optional[IndexConstructor], default : None
The memory index constructor to use. If None, a SklearnMemoryIndex will be used.
prompt_template : Optional[str] , default : None
    A formattable string with the following placeholders: {x} - the sample to classify, {labels} - the list of labels.
    If None, the default prompt template will be used.
"""

def __init__(
Expand All @@ -48,10 +50,12 @@ def __init__(
openai_model: str = "gpt-3.5-turbo",
default_label: str | None = "Random",
memory_index: IndexConstructor | None = None,
prompt_template: str | None = None,
):
super().__init__(openai_key, openai_org, openai_model, default_label)
self.n_examples = n_examples
self.memory_index = memory_index
self.prompt_template = prompt_template

def fit(
self,
Expand Down Expand Up @@ -96,6 +100,18 @@ def fit(

return self

def _get_prompt_template(self) -> str:
"""Returns the prompt template to use.
Returns
-------
str
prompt template
"""
if self.prompt_template is None:
return _TRAINING_SAMPLE_PROMPT_TEMPLATE
return self.prompt_template

def _get_prompt(self, x: str) -> str:
"""Generates the prompt for the given input.
Expand All @@ -109,6 +125,7 @@ def _get_prompt(self, x: str) -> str:
str
final prompt
"""
prompt_template = self._get_prompt_template()
embedding = self.embedding_model_.transform([x])
training_data = []
for cls in self.classes_:
Expand All @@ -118,7 +135,7 @@ def _get_prompt(self, x: str) -> str:
neighbors = [partition[i] for i in neighbors[0]]
training_data.extend(
[
_TRAINING_SAMPLE_PROMPT_TEMPLATE.format(x=neighbor, label=cls)
prompt_template.format(x=neighbor, label=cls)
for neighbor in neighbors
]
)
Expand Down
15 changes: 14 additions & 1 deletion skllm/models/gpt/gpt_few_shot_clf.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,18 @@ def fit(
self.classes_, self.probabilities_ = self._get_unique_targets(y)
return self

def _get_prompt_template(self) -> str:
"""Returns the prompt template to use.
Returns
-------
str
prompt template
"""
if self.prompt_template is None:
return _TRAINING_SAMPLE_PROMPT_TEMPLATE
return self.prompt_template

def _get_prompt(self, x: str) -> str:
"""Generates the prompt for the given input.
Expand All @@ -58,10 +70,11 @@ def _get_prompt(self, x: str) -> str:
str
final prompt
"""
prompt_template = self._get_prompt_template()
training_data = []
for xt, yt in zip(*self.training_data_):
training_data.append(
_TRAINING_SAMPLE_PROMPT_TEMPLATE.format(x=xt, label=yt)
prompt_template.format(x=xt, label=yt)
)

training_data_str = "\n".join(training_data)
Expand Down
11 changes: 10 additions & 1 deletion skllm/models/gpt/gpt_zero_shot_clf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ class ZeroShotGPTClassifier(_BaseZeroShotGPTClassifier):
default_label : Optional[str] , default : 'Random'
The default label to use if the LLM could not generate a response for a sample. If set to 'Random' a random
label will be chosen based on probabilities from the training set.
prompt_template : Optional[str] , default : None
    A formattable string with the following placeholders: {x} - the sample to classify, {labels} - the list of labels.
    If None, the default prompt template will be used.
"""

def __init__(
Expand All @@ -36,11 +39,17 @@ def __init__(
openai_org: Optional[str] = None,
openai_model: str = "gpt-3.5-turbo",
default_label: Optional[str] = "Random",
prompt_template: Optional[str] = None,
):
super().__init__(openai_key, openai_org, openai_model, default_label)
self.prompt_template = prompt_template

def _get_prompt(self, x) -> str:
    """Build the zero-shot classification prompt for a single sample.

    Uses the user-supplied ``prompt_template`` when one was given at
    construction time; otherwise the library default is used.

    Parameters
    ----------
    x : str
        the sample to classify

    Returns
    -------
    str
        final prompt
    """
    # Fix: the diff residue left an unconditional return here that made
    # the custom-template branch below unreachable; it has been removed.
    if self.prompt_template is None:
        return build_zero_shot_prompt_slc(x, repr(self.classes_))
    return build_zero_shot_prompt_slc(
        x, repr(self.classes_), template=self.prompt_template
    )

def fit(
self,
Expand Down

0 comments on commit fd51a6e

Please sign in to comment.