Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
OKUA1 committed May 29, 2023
2 parents de77aa5 + c8f23c7 commit b8d3155
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 68 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ ignore = [
"D205",
"D401",
"E501",
"N803",
"N806",
]
extend-exclude = ["tests/*.py", "setup.py"]
target-version = "py38"
Expand Down
33 changes: 19 additions & 14 deletions skllm/models/gpt_zero_shot_clf.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
from typing import Optional, Union, List, Any
import numpy as np
import pandas as pd
from collections import Counter
import random
from tqdm import tqdm
from abc import ABC, abstractmethod
from collections import Counter
from typing import Any, List, Optional, Union

import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, ClassifierMixin
from skllm.openai.prompts import get_zero_shot_prompt_slc, get_zero_shot_prompt_mlc
from tqdm import tqdm

from skllm.openai.chatgpt import (
construct_message,
get_chat_completion,
extract_json_key,
get_chat_completion,
)
from skllm.config import SKLLMConfig as _Config
from skllm.utils import to_numpy as _to_numpy
from skllm.openai.mixin import OpenAIMixin as _OAIMixin
from skllm.prompts.builders import (
build_zero_shot_prompt_mlc,
build_zero_shot_prompt_slc,
)
from skllm.utils import to_numpy as _to_numpy


class _BaseZeroShotGPTClassifier(ABC, BaseEstimator, ClassifierMixin, _OAIMixin):
def __init__(
Expand All @@ -34,7 +39,7 @@ def fit(
X: Optional[Union[np.ndarray, pd.Series, List[str]]],
y: Union[np.ndarray, pd.Series, List[str], List[List[str]]],
):
X = self._to_np(X)
X = self._to_np(X)
self.classes_, self.probabilities_ = self._get_unique_targets(y)
return self

Expand Down Expand Up @@ -91,15 +96,15 @@ def _extract_labels(self, y: Any) -> List[str]:
return labels

def _get_prompt(self, x) -> str:
return get_zero_shot_prompt_slc(x, self.classes_)
return build_zero_shot_prompt_slc(x, repr(self.classes_))

def _predict_single(self, x):
completion = self._get_chat_completion(x)
try:
label = str(
extract_json_key(completion.choices[0].message["content"], "label")
)
except Exception as e:
except Exception:
label = ""

if label not in self.classes_:
Expand Down Expand Up @@ -137,15 +142,15 @@ def _extract_labels(self, y) -> List[str]:
return labels

def _get_prompt(self, x) -> str:
return get_zero_shot_prompt_mlc(x, self.classes_, self.max_labels)
return build_zero_shot_prompt_mlc(x, repr(self.classes_), self.max_labels)

def _predict_single(self, x):
completion = self._get_chat_completion(x)
try:
labels = extract_json_key(completion.choices[0].message["content"], "label")
if not isinstance(labels, list):
raise RuntimeError("Invalid labels type, expected list")
except Exception as e:
except Exception:
labels = []

labels = list(filter(lambda l: l in self.classes_, labels))
Expand Down
50 changes: 0 additions & 50 deletions skllm/openai/prompts.py

This file was deleted.

9 changes: 5 additions & 4 deletions skllm/preprocessing/gpt_summarizer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from skllm.openai.prompts import get_summary_prompt
from typing import Optional

from skllm.openai.base_gpt import BaseZeroShotGPTTransformer as _BaseGPT
from skllm.prompts.builders import build_summary_prompt

class GPTSummarizer(_BaseGPT):

class GPTSummarizer(_BaseGPT):
system_msg = "You are a text summarizer."
default_output = "Summary is unavailable."

Expand All @@ -17,6 +18,6 @@ def __init__(
self._set_keys(openai_key, openai_org)
self.openai_model = openai_model
self.max_words = max_words

def _get_prompt(self, X) -> str:
return get_summary_prompt(X, self.max_words)
return build_summary_prompt(X, self.max_words)
80 changes: 80 additions & 0 deletions skllm/prompts/builders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from typing import Union

from skllm.prompts.templates import (
SUMMARY_PROMPT_TEMPLATE,
ZERO_SHOT_CLF_PROMPT_TEMPLATE,
ZERO_SHOT_MLCLF_PROMPT_TEMPLATE,
)

# TODO add validators


def build_zero_shot_prompt_slc(
x: str, labels: str, template: str = ZERO_SHOT_CLF_PROMPT_TEMPLATE
) -> str:
"""Builds a prompt for zero-shot single-label classification.
Parameters
----------
x : str
sample to classify
labels : str
candidate labels in a list-like representation
template : str
prompt template to use, must contain placeholders for all variables, by default ZERO_SHOT_CLF_PROMPT_TEMPLATE
Returns
-------
str
prepared prompt
"""
return template.format(x=x, labels=labels)


def build_zero_shot_prompt_mlc(
x: str,
labels: str,
max_cats: Union[int, str],
template: str = ZERO_SHOT_MLCLF_PROMPT_TEMPLATE,
) -> str:
"""Builds a prompt for zero-shot multi-label classification.
Parameters
----------
x : str
sample to classify
labels : str
candidate labels in a list-like representation
max_cats : Union[int,str]
maximum number of categories to assign
template : str
prompt template to use, must contain placeholders for all variables, by default ZERO_SHOT_MLCLF_PROMPT_TEMPLATE
Returns
-------
str
prepared prompt
"""
return template.format(x=x, labels=labels, max_cats=max_cats)


def build_summary_prompt(
x: str, max_words: Union[int, str], template: str = SUMMARY_PROMPT_TEMPLATE
) -> str:
"""Builds a prompt for text summarization.
Parameters
----------
x : str
sample to summarize
max_words : Union[int,str]
maximum number of words to use in the summary
template : str
prompt template to use, must contain placeholders for all variables, by default SUMMARY_PROMPT_TEMPLATE
Returns
-------
str
prepared prompt
"""
return template.format(x=x, max_words=max_words)
41 changes: 41 additions & 0 deletions skllm/prompts/templates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
ZERO_SHOT_CLF_PROMPT_TEMPLATE = """
You will be provided with the following information:
1. An arbitrary text sample. The sample is delimited with triple backticks.
2. List of categories the text sample can be assigned to. The list is delimited with square brackets. The categories in the list are enclosed in the single quotes and comma separated.
Perform the following tasks:
1. Identify to which category the provided text belongs to with the highest probability.
2. Assign the provided text to that category.
3. Provide your response in a JSON format containing a single key `label` and a value corresponding to the assigned category. Do not provide any additional information except the JSON.
List of categories: {labels}
Text sample: ```{x}```
Your JSON response:
"""

ZERO_SHOT_MLCLF_PROMPT_TEMPLATE = """
You will be provided with the following information:
1. An arbitrary text sample. The sample is delimited with triple backticks.
2. List of categories the text sample can be assigned to. The list is delimited with square brackets. The categories in the list are enclosed in the single quotes and comma separated. The text sample belongs to at least one category but cannot exceed {max_cats}.
Perform the following tasks:
1. Identify to which categories the provided text belongs to with the highest probability.
2. Assign the text sample to at least 1 but up to {max_cats} categories based on the probabilities.
3. Provide your response in a JSON format containing a single key `label` and a value corresponding to the list of assigned categories. Do not provide any additional information except the JSON.
List of categories: {labels}
Text sample: ```{x}```
Your JSON response:
"""

SUMMARY_PROMPT_TEMPLATE = """
Your task is to generate a summary of the text sample.
Summarize the text sample provided below, delimited by triple backticks, in at most {max_words} words.
Text sample: ```{x}```
Summarized text:
"""

0 comments on commit b8d3155

Please sign in to comment.