Skip to content

Commit

Permalink
Merge pull request BeastByteAI#70 from KennethEnevoldsen/added-custom…
Browse files Browse the repository at this point in the history
…-prompt-functionality

Added functionality for custom prompts
  • Loading branch information
OKUA1 authored Sep 13, 2023
2 parents 1c93c89 + f993cab commit c3ec570
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 19 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,6 @@ test.py
tmp.ipynb
tmp.py
*.pickle

# vscode
.vscode/
4 changes: 4 additions & 0 deletions skllm/models/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ class _BaseZeroShotGPTClassifier(BaseClassifier, _OAIMixin):
default_label : Optional[Union[List[str], str]] , default : 'Random'
The default label to use if the LLM could not generate a response for a sample. If set to 'Random' a random
label will be chosen based on probabilities from the training set.
prompt_template : Optional[str] , default : None
    A formattable string with the following placeholders: {x} - the sample to classify,
    {labels} - the list of labels. If None, the default prompt template will be used.
"""

def __init__(
Expand All @@ -138,10 +140,12 @@ def __init__(
openai_org: Optional[str] = None,
openai_model: str = "gpt-3.5-turbo",
default_label: Optional[Union[List[str], str]] = "Random",
prompt_template: Optional[str] = None,
):
self._set_keys(openai_key, openai_org)
self.openai_model = openai_model
self.default_label = default_label
self.prompt_template = prompt_template

@abstractmethod
def _get_prompt(self, x: str) -> str:
Expand Down
19 changes: 18 additions & 1 deletion skllm/models/gpt/gpt_dyn_few_shot_clf.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class DynamicFewShotGPTClassifier(_BaseZeroShotGPTClassifier):
label will be chosen based on probabilities from the training set.
memory_index : Optional[IndexConstructor], default : None
The memory index constructor to use. If None, a SklearnMemoryIndex will be used.
prompt_template : Optional[str] , default : None
    A formattable string with the following placeholders: {x} - the sample to classify,
    {labels} - the list of labels. If None, the default prompt template will be used.
"""

def __init__(
Expand All @@ -48,10 +50,12 @@ def __init__(
openai_model: str = "gpt-3.5-turbo",
default_label: str | None = "Random",
memory_index: IndexConstructor | None = None,
prompt_template: str | None = None,
):
super().__init__(openai_key, openai_org, openai_model, default_label)
self.n_examples = n_examples
self.memory_index = memory_index
self.prompt_template = prompt_template

def fit(
self,
Expand Down Expand Up @@ -96,6 +100,18 @@ def fit(

return self

def _get_prompt_template(self) -> str:
"""Returns the prompt template to use.
Returns
-------
str
prompt template
"""
if self.prompt_template is None:
return _TRAINING_SAMPLE_PROMPT_TEMPLATE
return self.prompt_template

def _get_prompt(self, x: str) -> str:
"""Generates the prompt for the given input.
Expand All @@ -109,6 +125,7 @@ def _get_prompt(self, x: str) -> str:
str
final prompt
"""
prompt_template = self._get_prompt_template()
embedding = self.embedding_model_.transform([x])
training_data = []
for cls in self.classes_:
Expand All @@ -118,7 +135,7 @@ def _get_prompt(self, x: str) -> str:
neighbors = [partition[i] for i in neighbors[0]]
training_data.extend(
[
_TRAINING_SAMPLE_PROMPT_TEMPLATE.format(x=neighbor, label=cls)
prompt_template.format(x=neighbor, label=cls)
for neighbor in neighbors
]
)
Expand Down
15 changes: 14 additions & 1 deletion skllm/models/gpt/gpt_few_shot_clf.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,18 @@ def fit(
self.classes_, self.probabilities_ = self._get_unique_targets(y)
return self

def _get_prompt_template(self) -> str:
"""Returns the prompt template to use.
Returns
-------
str
prompt template
"""
if self.prompt_template is None:
return _TRAINING_SAMPLE_PROMPT_TEMPLATE
return self.prompt_template

def _get_prompt(self, x: str) -> str:
"""Generates the prompt for the given input.
Expand All @@ -59,10 +71,11 @@ def _get_prompt(self, x: str) -> str:
str
final prompt
"""
prompt_template = self._get_prompt_template()
training_data = []
for xt, yt in zip(*self.training_data_):
training_data.append(
_TRAINING_SAMPLE_PROMPT_TEMPLATE.format(x=xt, label=yt)
prompt_template.format(x=xt, label=yt)
)

training_data_str = "\n".join(training_data)
Expand Down
11 changes: 10 additions & 1 deletion skllm/models/gpt/gpt_zero_shot_clf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ class ZeroShotGPTClassifier(_BaseZeroShotGPTClassifier):
default_label : Optional[str] , default : 'Random'
The default label to use if the LLM could not generate a response for a sample. If set to 'Random' a random
label will be chosen based on probabilities from the training set.
prompt_template : Optional[str] , default : None
    A formattable string with the following placeholders: {x} - the sample to classify,
    {labels} - the list of labels. If None, the default prompt template will be used.
"""

def __init__(
Expand All @@ -36,11 +39,17 @@ def __init__(
openai_org: Optional[str] = None,
openai_model: str = "gpt-3.5-turbo",
default_label: Optional[str] = "Random",
prompt_template: Optional[str] = None,
):
super().__init__(openai_key, openai_org, openai_model, default_label)
self.prompt_template = prompt_template

def _get_prompt(self, x) -> str:
    """Build the zero-shot classification prompt for a single sample.

    Parameters
    ----------
    x : str
        The sample to classify.

    Returns
    -------
    str
        The final prompt: rendered from the default template when
        ``self.prompt_template`` is None, otherwise from the
        user-supplied custom template.
    """
    # NOTE(review): presumably the template accepts {x} and {labels}
    # placeholders, matching the class docstring — confirm against
    # build_zero_shot_prompt_slc.
    if self.prompt_template is None:
        return build_zero_shot_prompt_slc(x, repr(self.classes_))
    return build_zero_shot_prompt_slc(
        x, repr(self.classes_), template=self.prompt_template
    )

def fit(
self,
Expand Down
25 changes: 15 additions & 10 deletions tests/test_chatgpt.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
import unittest
from unittest.mock import patch

from skllm.openai.chatgpt import (
construct_message,
extract_json_key,
get_chat_completion,
)
from skllm.openai.chatgpt import construct_message, get_chat_completion
from skllm.utils import extract_json_key


class TestChatGPT(unittest.TestCase):

@patch("skllm.openai.credentials.set_credentials")
@patch("openai.ChatCompletion.create")
def test_get_chat_completion(self, mock_create, mock_set_credentials):
Expand All @@ -21,9 +17,18 @@ def test_get_chat_completion(self, mock_create, mock_set_credentials):

result = get_chat_completion(messages, key, org, model)

self.assertTrue(mock_set_credentials.call_count <= 1, "set_credentials should be called at most once")
self.assertEqual(mock_create.call_count, 2, "ChatCompletion.create should be called twice due to an exception "
"on the first call")
self.assertTrue(
mock_set_credentials.call_count <= 1,
"set_credentials should be called at most once",
)
self.assertEqual(
mock_create.call_count,
2,
(
"ChatCompletion.create should be called twice due to an exception "
"on the first call"
),
)
self.assertEqual(result, "success")

def test_construct_message(self):
Expand All @@ -45,5 +50,5 @@ def test_extract_json_key(self):
self.assertEqual(result_with_invalid_key, None)


if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
3 changes: 1 addition & 2 deletions tests/test_gpt_few_shot_clf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@

import numpy as np

from skllm.models.gpt_dyn_few_shot_clf import DynamicFewShotGPTClassifier
from skllm.models.gpt_few_shot_clf import FewShotGPTClassifier
from skllm import DynamicFewShotGPTClassifier, FewShotGPTClassifier


class TestFewShotGPTClassifier(unittest.TestCase):
Expand Down
5 changes: 1 addition & 4 deletions tests/test_gpt_zero_shot_clf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@

import numpy as np

from skllm.models.gpt_zero_shot_clf import (
MultiLabelZeroShotGPTClassifier,
ZeroShotGPTClassifier,
)
from skllm import MultiLabelZeroShotGPTClassifier, ZeroShotGPTClassifier


def _get_ret(label):
Expand Down

0 comments on commit c3ec570

Please sign in to comment.