forked from mmz-001/knowledge_gpt
Merge pull request mmz-001#6 from mmz-001/stage
[Fix] add exponential backoff for embeddings
Showing 6 changed files with 152 additions and 9 deletions.
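The core of the change is a tenacity retry decorator around the per-text embedding call in the new file below. As a rough standalone illustration of what the chosen policy means (a sketch, not part of the diff): wait_exponential(multiplier=1, min=10, max=60) waits at least 10 seconds between attempts and grows the wait exponentially up to a 60-second cap, stop_after_attempt(100) gives up after 100 attempts, and reraise=True re-raises the last underlying error rather than wrapping it in a RetryError.

import random

from openai.error import RateLimitError
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)


# Illustrative only: the same retry policy from the diff, applied to a dummy
# call that fails with a simulated rate-limit error about half of the time.
@retry(
    reraise=True,  # surface the original exception if every attempt fails
    stop=stop_after_attempt(100),  # give up after 100 attempts
    wait=wait_exponential(multiplier=1, min=10, max=60),  # >=10s waits, doubling up to a 60s cap
    retry=retry_if_exception_type(RateLimitError),  # only retry rate limits in this sketch
)
def flaky_call() -> str:
    if random.random() < 0.5:
        raise RateLimitError("simulated rate limit")
    return "ok"


print(flaky_call())  # may sleep 10s or more between simulated failures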
@@ -0,0 +1,122 @@
"""Wrapper around OpenAI embedding models."""
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Extra, root_validator

from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env

from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)
from openai.error import Timeout, APIError, APIConnectionError, RateLimitError


class OpenAIEmbeddings(BaseModel, Embeddings):
    """Wrapper around OpenAI embedding models.

    To use, you should have the ``openai`` python package installed, and the
    environment variable ``OPENAI_API_KEY`` set with your API key or pass it
    as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain.embeddings import OpenAIEmbeddings
            openai = OpenAIEmbeddings(openai_api_key="my-api-key")
    """

    client: Any  #: :meta private:
    document_model_name: str = "text-embedding-ada-002"
    query_model_name: str = "text-embedding-ada-002"
    openai_api_key: Optional[str] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    # TODO: deprecate this
    @root_validator(pre=True, allow_reuse=True)
    def get_model_names(cls, values: Dict) -> Dict:
        """Get model names from just old model name."""
        if "model_name" in values:
            if "document_model_name" in values:
                raise ValueError(
                    "Both `model_name` and `document_model_name` were provided, "
                    "but only one should be."
                )
            if "query_model_name" in values:
                raise ValueError(
                    "Both `model_name` and `query_model_name` were provided, "
                    "but only one should be."
                )
            model_name = values.pop("model_name")
            values["document_model_name"] = f"text-search-{model_name}-doc-001"
            values["query_model_name"] = f"text-search-{model_name}-query-001"
        return values

    @root_validator(allow_reuse=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        openai_api_key = get_from_dict_or_env(
            values, "openai_api_key", "OPENAI_API_KEY"
        )
        try:
            import openai

            openai.api_key = openai_api_key
            values["client"] = openai.Embedding
        except ImportError:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        return values

    @retry(
        reraise=True,
        stop=stop_after_attempt(100),
        wait=wait_exponential(multiplier=1, min=10, max=60),
        retry=(
            retry_if_exception_type(Timeout)
            | retry_if_exception_type(APIError)
            | retry_if_exception_type(APIConnectionError)
            | retry_if_exception_type(RateLimitError)
        ),
    )
    def _embedding_func(self, text: str, *, engine: str) -> List[float]:
        """Call out to OpenAI's embedding endpoint with exponential backoff."""
        # Replace newlines, which can negatively affect performance.
        text = text.replace("\n", " ")
        return self.client.create(input=[text], engine=engine)["data"][0]["embedding"]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Call out to OpenAI's embedding endpoint for embedding search docs.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        responses = [
            self._embedding_func(text, engine=self.document_model_name)
            for text in texts
        ]
        return responses

    def embed_query(self, text: str) -> List[float]:
        """Call out to OpenAI's embedding endpoint for embedding query text.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        embedding = self._embedding_func(text, engine=self.query_model_name)
        return embedding
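For context, a minimal usage sketch of the new wrapper (assuming the openai package is installed and a valid API key is supplied; the import path follows the class docstring and may differ in this fork):

from langchain.embeddings import OpenAIEmbeddings  # import path as in the docstring; adjust to where this file lives

embeddings = OpenAIEmbeddings(openai_api_key="my-api-key")

# Both methods funnel through _embedding_func, so every request gets the retry/backoff behavior.
doc_vectors = embeddings.embed_documents(["first document", "second document"])
query_vector = embeddings.embed_query("what does the first document say?")

print(len(doc_vectors), len(doc_vectors[0]))  # 2 vectors; text-embedding-ada-002 vectors have 1536 dimensions
print(len(query_vector))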