added GPTSummarizer
iryna-kondr committed May 24, 2023
1 parent d3e8144 commit 037b747
Showing 11 changed files with 129 additions and 11 deletions.
16 changes: 16 additions & 0 deletions README.md
@@ -147,6 +147,22 @@ clf.fit(X_train, y_train_encoded)
 yh = clf.predict(X_test)
 ```

+### Text Summarization
+
+GPT excels at performing summarization tasks. Therefore, we provide `GPTSummarizer` that can be used both as stand-alone estimator, or as a preprocessor (in this case we can make an analogy with a dimensionality reduction preprocessor).
+
+Example:
+```python
+from skllm.preprocessing import GPTSummarizer
+from skllm.datasets import get_summarization_dataset
+
+X = get_summarization_dataset()
+s = GPTSummarizer(openai_model = 'gpt-3.5-turbo', max_words = 15)
+summaries = s.fit_transform(X)
+```
+
+Please be aware that the `max_words` hyperparameter sets a soft limit, which is not strictly enforced outside of the prompt. Therefore, in some cases, the actual number of words might be slightly higher.
+
 ## Roadmap 🧭

 - [x] Zero-Shot Classification with OpenAI GPT 3/4
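The README text added in this hunk notes that `max_words` is only a soft limit, enforced through the prompt alone. If a hard cap is needed, one option is to truncate the returned summaries afterwards. The sketch below is not part of the commit: `truncate_words` is a hypothetical helper, and the sample strings stand in for real `GPTSummarizer` output.

```python
from typing import Iterable, List


def truncate_words(summaries: Iterable[str], max_words: int) -> List[str]:
    """Cut each summary to at most `max_words` whitespace-separated words."""
    truncated = []
    for text in summaries:
        words = text.split()
        truncated.append(" ".join(words[:max_words]))
    return truncated


# Stand-in for the array returned by GPTSummarizer.fit_transform(X).
raw = ["OpenAI has launched GPT-4, a more powerful language model", "Short one"]
capped = truncate_words(raw, max_words=5)
```

Truncation can cut a sentence mid-thought, so it trades fluency for a guaranteed length.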
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -11,7 +11,7 @@ dependencies = [
     "tqdm>=4.60.0",
 ]
 name = "scikit-llm"
-version = "0.1.0b2"
+version = "0.1.0b3"
 authors = [
     { name="Oleg Kostromin", email="[email protected]" },
     { name="Iryna Kondrashchenko", email="[email protected]" },
3 changes: 2 additions & 1 deletion skllm/datasets/__init__.py
@@ -1,2 +1,3 @@
 from skllm.datasets.multi_class import get_classification_dataset
-from skllm.datasets.multi_label import get_multilabel_classification_dataset
+from skllm.datasets.multi_label import get_multilabel_classification_dataset
+from skllm.datasets.summarization import get_summarization_dataset
15 changes: 15 additions & 0 deletions skllm/datasets/summarization.py
@@ -0,0 +1,15 @@
+def get_summarization_dataset():
+    X = [
+        r"The AI research company, OpenAI, has launched a new language model called GPT-4. This model is the latest in a series of transformer-based AI systems designed to perform complex tasks, such as generating human-like text, translating languages, and answering questions. According to OpenAI, GPT-4 is even more powerful and versatile than its predecessors.",
+        r"John went to the grocery store in the morning to prepare for a small get-together at his house. He bought fresh apples, juicy oranges, and a bottle of milk. Once back home, he used the apples and oranges to make a delicious fruit salad, which he served to his guests in the evening.",
+        r"The first Mars rover, named Sojourner, was launched by NASA in 1996. The mission was a part of the Mars Pathfinder project and was a major success. The data Sojourner provided about Martian terrain and atmosphere greatly contributed to our understanding of the Red Planet.",
+        r"A new study suggests that regular exercise can improve memory and cognitive function in older adults. The study, which monitored the health and habits of 500 older adults, recommends 30 minutes of moderate exercise daily for the best results.",
+        r"The Eiffel Tower, a globally recognized symbol of Paris and France, was completed in 1889 for the World's Fair. Despite its initial criticism and controversy over its unconventional design, the Eiffel Tower has become a beloved landmark and a symbol of French architectural innovation.",
+        r"Microsoft has announced a new version of its flagship operating system, Windows. The update, which will be rolled out later this year, features improved security protocols and a redesigned user interface, promising a more streamlined and safer user experience.",
+        r"The WHO declared a new public health emergency due to an outbreak of a previously unknown virus. As the number of cases grows globally, the organization urges nations to ramp up their disease surveillance and response systems.",
+        r"The 2024 Olympics have been confirmed to take place in Paris, France. This marks the third time the city will host the games, with previous occasions in 1900 and 1924. Preparations are already underway to make the event a grand spectacle.",
+        r"Apple has introduced its latest iPhone model. The new device boasts a range of features, including an improved camera system, a faster processor, and a longer battery life. It is set to hit the market later this year.",
+        r"Scientists working in the Amazon rainforest have discovered a new species of bird. The bird, characterized by its unique bright plumage, has created excitement among the global ornithological community."
+    ]
+
+    return X
50 changes: 50 additions & 0 deletions skllm/openai/base_gpt.py
@@ -0,0 +1,50 @@
+import numpy as np
+from numpy import ndarray
+import pandas as pd
+from skllm.utils import to_numpy as _to_numpy
+from typing import Any, Optional, Union, List
+from skllm.openai.mixin import OpenAIMixin as _OAIMixin
+from tqdm import tqdm
+from sklearn.base import (
+    BaseEstimator as _BaseEstimator,
+    TransformerMixin as _TransformerMixin,
+)
+from skllm.openai.chatgpt import (
+    construct_message,
+    get_chat_completion,
+)
+
+class BaseZeroShotGPTTransformer(_BaseEstimator, _TransformerMixin, _OAIMixin):
+
+    system_msg = "You are an scikit-learn transformer."
+    default_output = "Output is unavailable"
+
+    def _get_chat_completion(self, X):
+        prompt = self._get_prompt(X)
+        msgs = []
+        msgs.append(construct_message("system", self.system_msg))
+        msgs.append(construct_message("user", prompt))
+        completion = get_chat_completion(
+            msgs, self._get_openai_key(), self._get_openai_org(), self.openai_model
+        )
+        try:
+            return completion.choices[0].message["content"]
+        except Exception as e:
+            print(f"Skipping a sample due to the following error: {str(e)}")
+            return self.default_output
+
+    def fit(self, X: Any = None, y: Any = None, **kwargs: Any):
+        return self
+
+    def transform(self, X: Optional[Union[np.ndarray, pd.Series, List[str]]], **kwargs: Any) -> ndarray:
+        X = _to_numpy(X)
+        transformed = []
+        for i in tqdm(range(len(X))):
+            transformed.append(
+                self._get_chat_completion(X[i])
+            )
+        transformed = np.asarray(transformed, dtype = object)
+        return transformed
+
+    def fit_transform(self, X: Optional[Union[np.ndarray, pd.Series, List[str]]], y=None, **fit_params) -> ndarray:
+        return self.fit(X, y).transform(X)
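The base class above leaves `_get_prompt` to subclasses and handles the per-sample loop and the fallback output itself. A minimal sketch of that division of labour, using only the standard library and no OpenAI calls, could look as follows. All names here (`PromptTransformer`, `ToySummarizer`, `fake_completion`) are hypothetical and do not appear in the commit; the injected `complete` callable is an assumption standing in for the real chat-completion request.

```python
from typing import Callable, List


class PromptTransformer:
    """Hypothetical stand-in for the base class: one completion per sample."""

    default_output = "Output is unavailable"

    def __init__(self, complete: Callable[[str], str]):
        # `complete` replaces the real chat-completion call in this sketch.
        self.complete = complete

    def _get_prompt(self, x: str) -> str:
        # Subclasses decide how a sample is turned into a prompt.
        raise NotImplementedError

    def transform(self, X: List[str]) -> List[str]:
        out = []
        for x in X:
            try:
                out.append(self.complete(self._get_prompt(x)))
            except Exception:
                # Similar in spirit to the diff: fall back instead of failing.
                out.append(self.default_output)
        return out


class ToySummarizer(PromptTransformer):
    default_output = "Summary is unavailable."

    def _get_prompt(self, x: str) -> str:
        return f"Summarize: {x}"


def fake_completion(prompt: str) -> str:
    # Fails on empty samples to exercise the fallback path.
    if prompt.endswith(": "):
        raise ValueError("empty sample")
    return prompt.upper()


result = ToySummarizer(fake_completion).transform(["hello", ""])
```

Keeping the loop and error handling in the base class means a new transformer only needs a prompt builder and, optionally, its own `default_output`.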
3 changes: 2 additions & 1 deletion skllm/openai/chatgpt.py
@@ -1,4 +1,5 @@
 import openai
+from time import sleep
 import json
 from skllm.openai.credentials import set_credentials

@@ -20,7 +21,7 @@ def get_chat_completion(messages, key, org, model="gpt-3.5-turbo", max_retries =
         except Exception as e:
             error_msg = str(e)
             error_type = type(e).__name__
-            continue
+            sleep(3)
     print(f"Could not obtain the completion after {max_retries} retries: `{error_type} :: {error_msg}`")

 def extract_json_key(json_, key):
def extract_json_key(json_, key):
3 changes: 2 additions & 1 deletion skllm/openai/embeddings.py
@@ -1,4 +1,5 @@
 import openai
+from time import sleep
 from skllm.openai.credentials import set_credentials

 def get_embedding(
@@ -19,7 +20,7 @@ def get_embedding(
         except Exception as e:
             error_msg = str(e)
             error_type = type(e).__name__
-            continue
+            sleep(3)
     raise RuntimeError(
         f"Could not obtain the embedding after {max_retries} retries: `{error_type} :: {error_msg}`"
     )
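Both retry loops touched by this commit replace an immediate `continue` with a fixed three-second pause after each failed attempt. The pattern can be sketched in isolation as below. `call_with_retries` and `flaky` are hypothetical names introduced for illustration, and the `delay` parameter is an assumption (the diff hard-codes 3 seconds); like the diff, the sketch also sleeps after the final failed attempt.

```python
import time
from typing import Callable, List, TypeVar

T = TypeVar("T")


def call_with_retries(fn: Callable[[], T], max_retries: int = 3, delay: float = 3.0) -> T:
    """Call `fn` up to `max_retries` times, sleeping `delay` seconds after each failure."""
    error_type, error_msg = "", ""
    for _ in range(max_retries):
        try:
            return fn()
        except Exception as e:
            error_type = type(e).__name__
            error_msg = str(e)
            time.sleep(delay)
    raise RuntimeError(
        f"Could not obtain the result after {max_retries} retries: `{error_type} :: {error_msg}`"
    )


attempts: List[int] = []


def flaky() -> str:
    # Fails twice, then succeeds on the third call.
    attempts.append(1)
    if len(attempts) < 3:
        raise ConnectionError("temporary failure")
    return "ok"


value = call_with_retries(flaky, max_retries=3, delay=0)
```

A fixed delay is the simplest choice; exponential backoff would be a common alternative when rate limits are the main cause of failure, though the commit does not use it.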
17 changes: 14 additions & 3 deletions skllm/openai/prompts.py
@@ -9,7 +9,7 @@ def get_zero_shot_prompt_slc(x, labels):
         "2. Assign the provided text to that category.",
         "3. Provide your response in a JSON format containing a single key `label` and a value corresponding to the assigned category. Do not provide any additional information except the JSON.",
         "\n",
-        f"List of categories: {repr(labels)}"
+        f"List of categories: {repr(labels)}",
         "\n",
         f"Text sample: ```{x}```",
         "\n",
@@ -27,13 +27,24 @@ def get_zero_shot_prompt_mlc(x, labels, max_cats):
         "Perform the following tasks:",
         "1. Identify to which categories the provided text belongs to with the highest probability.",
         f"2. Assign the text sample to at least 1 but up to {max_cats} categories based on the probabilities.",
-        "3. Provide your response in a JSON format containing a single key `label` and a value corresponding to the list of assigned categories. Do not provide any additional information except the JSON."
+        "3. Provide your response in a JSON format containing a single key `label` and a value corresponding to the list of assigned categories. Do not provide any additional information except the JSON.",
         "\n",
-        f"List of categories: {repr(labels)}"
+        f"List of categories: {repr(labels)}",
         "\n",
         f"Text sample: ```{x}```",
         "\n",
         "Your JSON response: "
     ]
     prompt = "\n".join(lines)
     return prompt
+
+def get_summary_prompt(x, max_words):
+    lines = [
+        "Your task is to generate a summary of the text sample.",
+        f"Summarize the text sample provided below, delimited by triple backticks, in at most {max_words} words.",
+        "\n",
+        f"Text sample: ```{x}```",
+        f"Summarized text:"
+    ]
+    prompt = "\n".join(lines)
+    return prompt
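The comma fixes in the two zero-shot prompts matter because Python concatenates adjacent string literals. Without the comma, the category line and the following `"\n"` become a single list element, so the joined prompt has one newline fewer at that position. A small sketch with made-up labels and a placeholder sample text shows the difference:

```python
labels = ["positive", "negative"]  # illustrative labels, not from the commit

# Before the fix: no comma, so the f-string and "\n" merge into one element.
before = [
    f"List of categories: {repr(labels)}"
    "\n",
    "Text sample: ...",
]

# After the fix: the comma keeps them as separate elements.
after = [
    f"List of categories: {repr(labels)}",
    "\n",
    "Text sample: ...",
]

prompt_before = "\n".join(before)
prompt_after = "\n".join(after)
```

The prompt text is otherwise identical; only the amount of vertical whitespace between sections changes.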
3 changes: 2 additions & 1 deletion skllm/preprocessing/__init__.py
@@ -1 +1,2 @@
-from skllm.preprocessing.gpt_vectorizer import GPTVectorizer
+from skllm.preprocessing.gpt_vectorizer import GPTVectorizer
+from skllm.preprocessing.gpt_summarizer import GPTSummarizer
22 changes: 22 additions & 0 deletions skllm/preprocessing/gpt_summarizer.py
@@ -0,0 +1,22 @@
+from skllm.openai.prompts import get_summary_prompt
+from typing import Optional
+from skllm.openai.base_gpt import BaseZeroShotGPTTransformer as _BaseGPT
+
+class GPTSummarizer(_BaseGPT):
+
+    system_msg = "You are a text summarizer."
+    default_output = "Summary is unavailable."
+
+    def __init__(
+        self,
+        openai_key: Optional[str] = None,
+        openai_org: Optional[str] = None,
+        openai_model: str = "gpt-3.5-turbo",
+        max_words: int = 15,
+    ):
+        self._set_keys(openai_key, openai_org)
+        self.openai_model = openai_model
+        self.max_words = max_words
+
+    def _get_prompt(self, X) -> str:
+        return get_summary_prompt(X, self.max_words)
6 changes: 3 additions & 3 deletions skllm/utils.py
@@ -3,9 +3,9 @@

 def to_numpy(X):
     if isinstance(X, pd.Series):
-        X = X.to_numpy()
+        X = X.to_numpy().astype(object)
     elif isinstance(X, list):
-        X = np.asarray(X)
-    if isinstance(X, np.ndarray):
+        X = np.asarray(X, dtype = object)
+    if isinstance(X, np.ndarray) and len(X.shape) > 1:
         X = np.squeeze(X)
     return X
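The commit does not state why `to_numpy` changed, so the following is one plausible reading rather than the authors' explanation. Squeezing a one-dimensional array that holds a single sample collapses it to a zero-dimensional array, which no longer supports `len()` or indexing in a loop such as the one in `transform`. Converting with `dtype=object` also avoids NumPy's fixed-width string dtype. The sketch uses `ndim` where the diff uses `len(X.shape)`; the two are equivalent.

```python
import numpy as np

single = ["only one sample"]

# Old behaviour: squeeze every ndarray, even a 1-D one.
old = np.squeeze(np.asarray(single))

# New behaviour: object dtype, squeeze only when there is more than one axis.
new = np.asarray(single, dtype=object)
if new.ndim > 1:
    new = np.squeeze(new)

# A column-shaped input of shape (2, 1) is still flattened to shape (2,).
column = np.asarray([["a"], ["b"]], dtype=object)
flat = np.squeeze(column) if column.ndim > 1 else column
```

With the old code path, a single-element list would reach the `for i in range(len(X))` loop as a zero-dimensional array and fail there.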
