forked from BeastByteAI/scikit-llm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
d3e8144
commit 037b747
Showing
11 changed files
with
129 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,7 +11,7 @@ dependencies = [ | |
"tqdm>=4.60.0", | ||
] | ||
name = "scikit-llm" | ||
version = "0.1.0b2" | ||
version = "0.1.0b3" | ||
authors = [ | ||
{ name="Oleg Kostromin", email="[email protected]" }, | ||
{ name="Iryna Kondrashchenko", email="[email protected]" }, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from skllm.datasets.multi_class import get_classification_dataset | ||
from skllm.datasets.multi_label import get_multilabel_classification_dataset | ||
from skllm.datasets.multi_label import get_multilabel_classification_dataset | ||
from skllm.datasets.summarization import get_summarization_dataset |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
def get_summarization_dataset(): | ||
X = [ | ||
r"The AI research company, OpenAI, has launched a new language model called GPT-4. This model is the latest in a series of transformer-based AI systems designed to perform complex tasks, such as generating human-like text, translating languages, and answering questions. According to OpenAI, GPT-4 is even more powerful and versatile than its predecessors.", | ||
r"John went to the grocery store in the morning to prepare for a small get-together at his house. He bought fresh apples, juicy oranges, and a bottle of milk. Once back home, he used the apples and oranges to make a delicious fruit salad, which he served to his guests in the evening.", | ||
r"The first Mars rover, named Sojourner, was launched by NASA in 1996. The mission was a part of the Mars Pathfinder project and was a major success. The data Sojourner provided about Martian terrain and atmosphere greatly contributed to our understanding of the Red Planet.", | ||
r"A new study suggests that regular exercise can improve memory and cognitive function in older adults. The study, which monitored the health and habits of 500 older adults, recommends 30 minutes of moderate exercise daily for the best results.", | ||
r"The Eiffel Tower, a globally recognized symbol of Paris and France, was completed in 1889 for the World's Fair. Despite its initial criticism and controversy over its unconventional design, the Eiffel Tower has become a beloved landmark and a symbol of French architectural innovation.", | ||
r"Microsoft has announced a new version of its flagship operating system, Windows. The update, which will be rolled out later this year, features improved security protocols and a redesigned user interface, promising a more streamlined and safer user experience.", | ||
r"The WHO declared a new public health emergency due to an outbreak of a previously unknown virus. As the number of cases grows globally, the organization urges nations to ramp up their disease surveillance and response systems.", | ||
r"The 2024 Olympics have been confirmed to take place in Paris, France. This marks the third time the city will host the games, with previous occasions in 1900 and 1924. Preparations are already underway to make the event a grand spectacle.", | ||
r"Apple has introduced its latest iPhone model. The new device boasts a range of features, including an improved camera system, a faster processor, and a longer battery life. It is set to hit the market later this year.", | ||
r"Scientists working in the Amazon rainforest have discovered a new species of bird. The bird, characterized by its unique bright plumage, has created excitement among the global ornithological community." | ||
] | ||
|
||
return X |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import numpy as np | ||
from numpy import ndarray | ||
import pandas as pd | ||
from skllm.utils import to_numpy as _to_numpy | ||
from typing import Any, Optional, Union, List | ||
from skllm.openai.mixin import OpenAIMixin as _OAIMixin | ||
from tqdm import tqdm | ||
from sklearn.base import ( | ||
BaseEstimator as _BaseEstimator, | ||
TransformerMixin as _TransformerMixin, | ||
) | ||
from skllm.openai.chatgpt import ( | ||
construct_message, | ||
get_chat_completion, | ||
) | ||
|
||
class BaseZeroShotGPTTransformer(_BaseEstimator, _TransformerMixin, _OAIMixin): | ||
|
||
system_msg = "You are an scikit-learn transformer." | ||
default_output = "Output is unavailable" | ||
|
||
def _get_chat_completion(self, X): | ||
prompt = self._get_prompt(X) | ||
msgs = [] | ||
msgs.append(construct_message("system", self.system_msg)) | ||
msgs.append(construct_message("user", prompt)) | ||
completion = get_chat_completion( | ||
msgs, self._get_openai_key(), self._get_openai_org(), self.openai_model | ||
) | ||
try: | ||
return completion.choices[0].message["content"] | ||
except Exception as e: | ||
print(f"Skipping a sample due to the following error: {str(e)}") | ||
return self.default_output | ||
|
||
def fit(self, X: Any = None, y: Any = None, **kwargs: Any): | ||
return self | ||
|
||
def transform(self, X: Optional[Union[np.ndarray, pd.Series, List[str]]], **kwargs: Any) -> ndarray: | ||
X = _to_numpy(X) | ||
transformed = [] | ||
for i in tqdm(range(len(X))): | ||
transformed.append( | ||
self._get_chat_completion(X[i]) | ||
) | ||
transformed = np.asarray(transformed, dtype = object) | ||
return transformed | ||
|
||
def fit_transform(self, X: Optional[Union[np.ndarray, pd.Series, List[str]]], y=None, **fit_params) -> ndarray: | ||
return self.fit(X, y).transform(X) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
from skllm.preprocessing.gpt_vectorizer import GPTVectorizer | ||
from skllm.preprocessing.gpt_vectorizer import GPTVectorizer | ||
from skllm.preprocessing.gpt_summarizer import GPTSummarizer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from skllm.openai.prompts import get_summary_prompt | ||
from typing import Optional | ||
from skllm.openai.base_gpt import BaseZeroShotGPTTransformer as _BaseGPT | ||
|
||
class GPTSummarizer(_BaseGPT): | ||
|
||
system_msg = "You are a text summarizer." | ||
default_output = "Summary is unavailable." | ||
|
||
def __init__( | ||
self, | ||
openai_key: Optional[str] = None, | ||
openai_org: Optional[str] = None, | ||
openai_model: str = "gpt-3.5-turbo", | ||
max_words: int = 15, | ||
): | ||
self._set_keys(openai_key, openai_org) | ||
self.openai_model = openai_model | ||
self.max_words = max_words | ||
|
||
def _get_prompt(self, X) -> str: | ||
return get_summary_prompt(X, self.max_words) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters