Skip to content

Commit

Permalink
Merge branch 'main' into feature_focused_summarizer
Browse files Browse the repository at this point in the history
  • Loading branch information
iryna-kondr authored Jun 15, 2023
2 parents cff196a + 66de4ab commit 4a90774
Show file tree
Hide file tree
Showing 28 changed files with 1,096 additions and 119 deletions.
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ repos:
- id: check-executables-have-shebangs
- id: check-case-conflict
- id: check-added-large-files
- id: detect-aws-credentials
- id: detect-private-key
# Formatter for Json and Yaml files
- repo: https://github.com/pre-commit/mirrors-prettier
Expand Down
3 changes: 2 additions & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@ There are several ways you can contribute to this project:
**Important:** before contributing, we recommend that you open an issue to discuss your planned changes. This allows us to align our goals, provide guidance, and potentially find other contributors interested in collaborating on the same feature or bug fix.

> ### Legal Notice <!-- omit in toc -->
>
> When contributing to this project, you must agree that you have authored 100% of the content, that you have the necessary rights to the content and that the content you contribute may be provided under the project license.
## Development dependencies

In order to install all development dependencies, run the following command:

```shell
pip install -e ".[dev]"
pip install -r requirements-dev.txt
```

To ensure that you follow the development workflow, please setup the pre-commit hooks:
Expand Down
74 changes: 58 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ You can support the project in the following ways:

- ⭐ Star Scikit-LLM on GitHub (click the star button in the top right corner)
- 🐦 Check out our related project - [Falcon AutoML](https://github.com/OKUA1/falcon)
- 💡 Provide your feedback or propose ideas in the [issues](https://github.com/iryna-kondr/scikit-llm/issues) section
- 💡 Provide your feedback or propose ideas in the [issues](https://github.com/iryna-kondr/scikit-llm/issues) section or [Discord](https://discord.gg/YDAbwuWK7V)
- 🔗 Post about Scikit-LLM on LinkedIn or other platforms

## Documentation 📚

### Configuring OpenAI API Key

At the moment Scikit-LLM is only compatible with some of the OpenAI models. Hence, a user-provided OpenAI API key is required.
At the moment the majority of the Scikit-LLM estimators are only compatible with some of the OpenAI models. Hence, a user-provided OpenAI API key is required.

```python
from skllm.config import SKLLMConfig
Expand All @@ -39,6 +39,41 @@ SKLLMConfig.set_openai_org("<YOUR_ORGANISATION>")
- If you have a free trial OpenAI account, the [rate limits](https://platform.openai.com/docs/guides/rate-limits/overview) are not sufficient (specifically 3 requests per minute). Please switch to the "pay as you go" plan first.
- When calling `SKLLMConfig.set_openai_org`, you have to provide your organization ID and **NOT** the name. You can find your ID [here](https://platform.openai.com/account/org-settings).

### Using GPT4ALL

In addition to OpenAI, some of the models can use [gpt4all](https://gpt4all.io/index.html) as a backend.

**This feature is considered higly experimental!**

In order to use gpt4all, you need to install the corresponding submodule:

```bash
pip install "scikit-llm[gpt4all]"
```

In order to switch from OpenAI to GPT4ALL model, simply provide a string of the format `gpt4all::<model_name>` as an argument. While the model runs completely locally, the estimator still treats it as an OpenAI endpoint and will try to check that the API key is present. You can provide any string as a key.

```python
SKLLMConfig.set_openai_key("any string")
SKLLMConfig.set_openai_org("any string")

ZeroShotGPTClassifier(openai_model="gpt4all::ggml-gpt4all-j-v1.3-groovy")
```

When running for the first time, the model file will be downloaded automatially.

At the moment only the following estimators support gpt4all as a backend:

- `ZeroShotGPTClassifier`
- `MultiLabelZeroShotGPTClassifier`
- `FewShotGPTClassifier`

When using gpt4all please keep the following in mind:

1. Not all gpt4all models are commercially licensable, please consult gpt4all website for more details.
2. The accuracy of the models may be much lower compared to ones provided by OpenAI (especially gpt-4).
3. Not all of the available models were tested, some may not work with scikit-llm at all.

### Zero-Shot Text Classification

One of the powerful ChatGPT features is the ability to perform text classification without being re-trained. For that, the only requirement is that the labels must be descriptive.
Expand Down Expand Up @@ -145,6 +180,27 @@ While the api remains the same as for the zero shot classifier, there are a few

Note: as the model is not being re-trained, but uses the training data during inference, one could say that this is still a (different) zero-shot approach.

### Dynamic Few-Shot Text Classification

`DynamicFewShotGPTClassifier` dynamically selects N samples per class to include in the prompt. This allows the few-shot classifier to scale to datasets that are too large for the standard context window of LLMs.

*How does it work?*

During fitting, the whole dataset is partitioned by class, vectorized, and stored.

During inference, the [annoy](https://github.com/spotify/annoy) library is used for fast neighbor lookup, which allows including only the most similar examples in the prompt.

```python
from skllm import DynamicFewShotGPTClassifier
from skllm.datasets import get_classification_dataset

X, y = get_classification_dataset()

clf = DynamicFewShotGPTClassifier(n_examples=3)
clf.fit(X, y)
labels = clf.predict(X)
```

### Text Vectorization

As an alternative to using GPT as a classifier, it can be used solely for data preprocessing. `GPTVectorizer` allows to embed a chunk of text of arbitrary length to a fixed-dimensional vector, that can be used with virtually any classification or regression model.
Expand Down Expand Up @@ -225,17 +281,3 @@ translated_text = t.fit_transform(X)
- [ ] Open source models

*The order of the elements in the roadmap is arbitrary and does not reflect the planned order of implementation.*

## Contributing

In order to install all development dependencies, run the following command:

```shell
pip install -e ".[dev]"
```

To ensure that you follow the development workflow, please setup the pre-commit hooks:

```shell
pre-commit install
```
11 changes: 6 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@ dependencies = [
"pandas>=1.5.0",
"openai>=0.27.0",
"tqdm>=4.60.0",
"annoy>=1.17.2",
]
name = "scikit-llm"
version = "0.1.0b3"
version = "0.2.0"
authors = [
{ name="Oleg Kostromin", email="[email protected]" },
{ name="Iryna Kondrashchenko", email="[email protected]" },
Expand All @@ -24,10 +25,9 @@ classifiers = [
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]
dynamic = ["optional-dependencies"]

[tool.setuptools.dynamic.optional-dependencies]
dev = { file = ["requirements-dev.txt"] }
[project.optional-dependencies]
gpt4all = ["gpt4all>=0.2.0"]

[tool.ruff]
select = [
Expand Down Expand Up @@ -80,12 +80,13 @@ target-version = ['py38', 'py39', 'py310', 'py311']
profile = "black"
filter_files = true
known_first_party = ["skllm", "skllm.*"]
skip = ["__init__.py"]

[tool.docformatter]
close-quotes-on-newline = true # D209

[tool.interrogate]
fail-under = 80
fail-under = 65
ignore-module = true
ignore-nested-functions = true
ignore-private = true
Expand Down
2 changes: 2 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@ isort
ruff
docformatter
interrogate
numpy
pandas
4 changes: 3 additions & 1 deletion skllm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from skllm.models.gpt_few_shot_clf import FewShotGPTClassifier
# ordering is important here to prevent circular imports
from skllm.models.gpt_zero_shot_clf import (
MultiLabelZeroShotGPTClassifier,
ZeroShotGPTClassifier,
)
from skllm.models.gpt_few_shot_clf import FewShotGPTClassifier
from skllm.models.gpt_dyn_few_shot_clf import DynamicFewShotGPTClassifier
16 changes: 16 additions & 0 deletions skllm/completions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from skllm.gpt4all_client import get_chat_completion as _g4a_get_chat_completion
from skllm.openai.chatgpt import get_chat_completion as _oai_get_chat_completion


def get_chat_completion(
messages: dict, openai_key: str=None, openai_org: str=None, model: str="gpt-3.5-turbo", max_retries: int=3
):
"""
Gets a chat completion from the OpenAI API.
"""
if model.startswith("gpt4all::"):
return _g4a_get_chat_completion(messages, model[9:])
else:
return _oai_get_chat_completion(
messages, openai_key, openai_org, model, max_retries
)
40 changes: 40 additions & 0 deletions skllm/gpt4all_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import Dict

try:
from gpt4all import GPT4All
except (ImportError, ModuleNotFoundError):
GPT4All = None

_loaded_models = {}


def get_chat_completion(messages: Dict, model: str="ggml-gpt4all-j-v1.3-groovy") -> Dict:
"""
Gets a chat completion from GPT4All
Parameters
----------
messages : Dict
The messages to use as a prompt for the chat completion.
model : str
The model to use for the chat completion. Defaults to "ggml-gpt4all-j-v1.3-groovy".
Returns
-------
completion : Dict
"""
if GPT4All is None:
raise ImportError(
"gpt4all is not installed, try `pip install scikit-llm[gpt4all]`"
)
if model not in _loaded_models.keys():
_loaded_models[model] = GPT4All(model)

return _loaded_models[model].chat_completion(
messages, verbose=False, streaming=False, temp=1e-10
)


def unload_models() -> None:
global _loaded_models
_loaded_models = {}
1 change: 1 addition & 0 deletions skllm/memory/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from skllm.memory._annoy import AnnoyMemoryIndex
119 changes: 119 additions & 0 deletions skllm/memory/_annoy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import os
import tempfile
from typing import Any, List

from annoy import AnnoyIndex
from numpy import ndarray

from skllm.memory.base import _BaseMemoryIndex


class AnnoyMemoryIndex(_BaseMemoryIndex):
"""Memory index using Annoy.
Parameters
----------
dim : int
dimensionality of the vectors
metric : str, optional
metric to use, by default "euclidean"
"""

def __init__(self, dim: int, metric: str = "euclidean", **kwargs: Any) -> None:
self._index = AnnoyIndex(dim, metric)
self.metric = metric
self.dim = dim
self.built = False

def add(self, id: int, vector: ndarray) -> None:
"""Adds a vector to the index.
Parameters
----------
id : Any
identifier for the vector
vector : ndarray
vector to add to the index
"""
if self.built:
raise RuntimeError("Cannot add vectors after index is built.")
self._index.add_item(id, vector)

def build(self) -> None:
"""Builds the index.
No new vectors can be added after building.
"""
self._index.build(-1)
self.built = True

def retrieve(self, vectors: ndarray, k: int) -> List[List[int]]:
"""Retrieves the k nearest neighbors for each vector.
Parameters
----------
vectors : ndarray
vectors to retrieve nearest neighbors for
k : int
number of nearest neighbors to retrieve
Returns
-------
List
ids of retrieved nearest neighbors
"""
if not self.built:
raise RuntimeError("Cannot retrieve vectors before the index is built.")
return [
self._index.get_nns_by_vector(v, k, search_k=-1, include_distances=False)
for v in vectors
]

def __getstate__(self) -> dict:
"""Returns the state of the object. To store the actual annoy index, it
has to be written to a temporary file.
Returns
-------
dict
state of the object
"""
state = self.__dict__.copy()

# save index to temporary file
with tempfile.NamedTemporaryFile(delete=False) as tmp:
temp_filename = tmp.name
self._index.save(temp_filename)

# read bytes from the file
with open(temp_filename, "rb") as tmp:
index_bytes = tmp.read()

# store bytes representation in state
state["_index"] = index_bytes

# remove temporary file
os.remove(temp_filename)

return state

def __setstate__(self, state: dict) -> None:
"""Sets the state of the object. It restores the annoy index from the
bytes representation.
Parameters
----------
state : dict
state of the object
"""
self.__dict__.update(state)
# restore index from bytes
with tempfile.NamedTemporaryFile(delete=False) as tmp:
temp_filename = tmp.name
tmp.write(self._index)

self._index = AnnoyIndex(self.dim, self.metric)
self._index.load(temp_filename)

# remove temporary file
os.remove(temp_filename)
Loading

0 comments on commit 4a90774

Please sign in to comment.