Merge pull request mmz-001#6 from mmz-001/stage
[Fix] add exponential backoff for embeddings
mmz-001 authored Feb 8, 2023
2 parents d7cce7b + 1d16ffd commit a2c60ba
Showing 6 changed files with 152 additions and 9 deletions.
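The core of the change is a tenacity retry policy wrapped around every embedding request. As a quick illustration of how that policy behaves in isolation, here is a minimal, self-contained sketch; FakeRateLimit and flaky_call are illustrative stand-ins, not code from this repository.

from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential


class FakeRateLimit(Exception):
    """Stand-in for openai.error.RateLimitError in this illustration."""


calls = {"n": 0}


@retry(
    reraise=True,                                         # surface the last error if all attempts fail
    stop=stop_after_attempt(100),                         # same limit as the new _embedding_func
    wait=wait_exponential(multiplier=1, min=10, max=60),  # exponentially growing waits, clamped to 10-60 s
    retry=retry_if_exception_type(FakeRateLimit),
)
def flaky_call() -> str:
    calls["n"] += 1
    if calls["n"] < 3:
        raise FakeRateLimit("simulated 429 from the embeddings endpoint")
    return "ok"


if __name__ == "__main__":
    # Note: the two simulated failures really do sleep about 10 s each before retrying.
    print(flaky_call(), "after", calls["n"], "attempts")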
3 changes: 3 additions & 0 deletions .gitignore
@@ -7,6 +7,9 @@ data/local_data/
 # VSCode
 .vscode/
 
+# TODO
+TODO.md
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
122 changes: 122 additions & 0 deletions knowledge_gpt/embeddings.py
@@ -0,0 +1,122 @@
"""Wrapper around OpenAI embedding models."""
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Extra, root_validator

from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env

from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)
from openai.error import Timeout, APIError, APIConnectionError, RateLimitError


class OpenAIEmbeddings(BaseModel, Embeddings):
"""Wrapper around OpenAI embedding models.
To use, you should have the ``openai`` python package installed, and the
environment variable ``OPENAI_API_KEY`` set with your API key or pass it
as a named parameter to the constructor.
Example:
.. code-block:: python
from langchain.embeddings import OpenAIEmbeddings
openai = OpenAIEmbeddings(openai_api_key="my-api-key")
"""

client: Any #: :meta private:
document_model_name: str = "text-embedding-ada-002"
query_model_name: str = "text-embedding-ada-002"
openai_api_key: Optional[str] = None

class Config:
"""Configuration for this pydantic object."""

extra = Extra.forbid

# TODO: deprecate this
@root_validator(pre=True, allow_reuse=True)
def get_model_names(cls, values: Dict) -> Dict:
"""Get model names from just old model name."""
if "model_name" in values:
if "document_model_name" in values:
raise ValueError(
"Both `model_name` and `document_model_name` were provided, "
"but only one should be."
)
if "query_model_name" in values:
raise ValueError(
"Both `model_name` and `query_model_name` were provided, "
"but only one should be."
)
model_name = values.pop("model_name")
values["document_model_name"] = f"text-search-{model_name}-doc-001"
values["query_model_name"] = f"text-search-{model_name}-query-001"
return values

@root_validator(allow_reuse=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
openai_api_key = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY"
)
try:
import openai

openai.api_key = openai_api_key
values["client"] = openai.Embedding
except ImportError:
raise ValueError(
"Could not import openai python package. "
"Please it install it with `pip install openai`."
)
return values

@retry(
reraise=True,
stop=stop_after_attempt(100),
wait=wait_exponential(multiplier=1, min=10, max=60),
retry=(
retry_if_exception_type(Timeout)
| retry_if_exception_type(APIError)
| retry_if_exception_type(APIConnectionError)
| retry_if_exception_type(RateLimitError)
),
)
def _embedding_func(self, text: str, *, engine: str) -> List[float]:
"""Call out to OpenAI's embedding endpoint with exponential backoff."""
# replace newlines, which can negatively affect performance.
text = text.replace("\n", " ")
return self.client.create(input=[text], engine=engine)["data"][0]["embedding"]

def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Call out to OpenAI's embedding endpoint for embedding search docs.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""
responses = [
self._embedding_func(text, engine=self.document_model_name)
for text in texts
]
return responses

def embed_query(self, text: str) -> List[float]:
"""Call out to OpenAI's embedding endpoint for embedding query text.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
embedding = self._embedding_func(text, engine=self.query_model_name)
return embedding
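
For reference, a minimal usage sketch of the wrapper above (not part of the commit); the sample strings and the OPENAI_API_KEY environment variable are illustrative, and the import path mirrors the one used in knowledge_gpt/utils.py below.

import os

from embeddings import OpenAIEmbeddings  # same import path as in knowledge_gpt/utils.py

embeddings = OpenAIEmbeddings(openai_api_key=os.environ["OPENAI_API_KEY"])

# Both methods go through _embedding_func, so transient Timeout, APIError,
# APIConnectionError, and RateLimitError responses are retried with exponential
# backoff (10-60 s waits, up to 100 attempts) before the error is re-raised.
doc_vectors = embeddings.embed_documents(["First chunk of a PDF.", "Second chunk."])
query_vector = embeddings.embed_query("What does the document say about pricing?")

print(len(doc_vectors), len(query_vector))  # number of document vectors, query embedding dimension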
3 changes: 2 additions & 1 deletion knowledge_gpt/main.py
@@ -76,7 +76,8 @@ def set_openai_api_key(api_key: str):
         raise ValueError("File type not supported!")
     text = text_to_docs(doc)
     try:
-        index = embed_docs(text)
+        with st.spinner("Indexing document... This may take a while⏳"):
+            index = embed_docs(text)
         st.session_state["api_key_configured"] = True
     except OpenAIError as e:
         st.error(e._message)
4 changes: 2 additions & 2 deletions knowledge_gpt/utils.py
@@ -2,7 +2,7 @@
 from langchain.vectorstores.faiss import FAISS
 from langchain import OpenAI, Cohere
 from langchain.chains.qa_with_sources import load_qa_with_sources_chain
-from langchain.embeddings import CohereEmbeddings, OpenAIEmbeddings
+from embeddings import OpenAIEmbeddings
 from langchain.llms import OpenAI
 from langchain.docstore.document import Document
 from langchain.vectorstores import FAISS, VectorStore
@@ -83,7 +83,7 @@ def text_to_docs(text: str | List[str]) -> List[Document]:
     return doc_chunks
 
 
-@st.cache(allow_output_mutation=True)
+@st.cache(allow_output_mutation=True, show_spinner=False)
 def embed_docs(docs: List[Document]) -> VectorStore:
     """Embeds a list of Documents and returns a FAISS index"""
 
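
Taken together with the main.py change above, show_spinner=False hides Streamlit's default cache spinner so the custom st.spinner message is the only progress indicator the user sees. A simplified, self-contained Streamlit sketch of that pattern follows; the time.sleep stand-in and sample chunks are assumptions, not repo code.

import time

import streamlit as st


@st.cache(allow_output_mutation=True, show_spinner=False)  # suppress the default cache spinner
def embed_docs(docs):
    time.sleep(5)  # stand-in for the slow per-chunk embedding calls
    return {"chunks_indexed": len(docs)}


docs = ["chunk one", "chunk two"]
with st.spinner("Indexing document... This may take a while⏳"):  # the only spinner shown
    index = embed_docs(docs)
st.write(index)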
26 changes: 21 additions & 5 deletions poetry.lock

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
@@ -10,13 +10,14 @@ packages = [{include = "knowledge_gpt"}]
 [tool.poetry.dependencies]
 python = "^3.10"
 streamlit = "^1.17.0"
-langchain = "0.0.71"
+langchain = "0.0.79"
 cohere = "^3.2.1"
 faiss-cpu = "^1.7.3"
 openai = "^0.26.2"
 docx2txt = "^0.8"
 pillow = "^9.4.0"
 pypdf = "^3.3.0"
+tenacity = "^8.2.0"
 
 
 [tool.poetry.group.dev.dependencies]
