Skip to content

Commit

Permalink
Refactoring vectordb naming convention in embedchain.config (#1469)
Browse files (browse the repository at this point in the history)
  • Loading branch information
vatsalrathod16 authored Jul 8, 2024
1 parent 1a5d0d2 commit 83e8c97
Show file tree
Hide file tree
Showing 20 changed files with 124 additions and 124 deletions.
8 changes: 4 additions & 4 deletions embedchain/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from .embedder.base import BaseEmbedderConfig as EmbedderConfig
from .embedder.ollama import OllamaEmbedderConfig
from .llm.base import BaseLlmConfig
from .vectordb.chroma import ChromaDbConfig
from .vectordb.elasticsearch import ElasticsearchDBConfig
from .vectordb.opensearch import OpenSearchDBConfig
from .vectordb.zilliz import ZillizDBConfig
from .vector_db.chroma import ChromaDbConfig
from .vector_db.elasticsearch import ElasticsearchDBConfig
from .vector_db.opensearch import OpenSearchDBConfig
from .vector_db.zilliz import ZillizDBConfig
from .mem0_config import Mem0Config
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional

from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable


Expand Down
Original file line number Diff line number Diff line change
@@ -1,56 +1,56 @@
import os
from typing import Optional, Union

from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable


@register_deserializable
class ElasticsearchDBConfig(BaseVectorDbConfig):
    """Connection and ingestion settings for an Elasticsearch-backed vector store."""

    # NOTE(review): the ``**ES_EXTRA_PARAMS`` annotation below references the
    # builtin ``any`` function rather than ``typing.Any``; it is left untouched
    # here because annotations do not affect runtime behaviour.
    def __init__(
        self,
        collection_name: Optional[str] = None,
        dir: Optional[str] = None,
        es_url: Union[str, list[str]] = None,
        cloud_id: Optional[str] = None,
        batch_size: Optional[int] = 100,
        **ES_EXTRA_PARAMS: dict[str, any],
    ):
        """
        Build the configuration used to connect to Elasticsearch.

        Exactly one endpoint source is expected: either ``es_url`` or ``cloud_id``.
        When neither is passed explicitly, the ``ELASTICSEARCH_URL`` and
        ``ELASTICSEARCH_CLOUD_ID`` environment variables are consulted instead.

        :param collection_name: Default name for the collection, defaults to None
        :type collection_name: Optional[str], optional
        :param dir: Path to the database directory, forwarded to the base config, defaults to None
        :type dir: Optional[str], optional
        :param es_url: Elasticsearch URL, or a list of node URLs, defaults to None
        :type es_url: Union[str, list[str]], optional
        :param cloud_id: Cloud id of the Elasticsearch cluster, defaults to None
        :type cloud_id: Optional[str], optional
        :param batch_size: Number of items to insert in one batch, defaults to 100
        :type batch_size: Optional[int], optional
        :param ES_EXTRA_PARAMS: Additional keyword arguments, stored as-is on
            ``self.ES_EXTRA_PARAMS`` (intended for the Elasticsearch client).
        :type ES_EXTRA_PARAMS: dict[str, Any], optional
        :raises ValueError: If both ``es_url`` and ``cloud_id`` are given.
        :raises AttributeError: If no URL or cloud id can be resolved from the
            arguments or the environment.
        """
        # Guard: the two endpoint arguments are mutually exclusive.
        if es_url and cloud_id:
            raise ValueError("Only one of `es_url` and `cloud_id` can be set.")

        env = os.environ

        # Explicit arguments take precedence; the environment is only a fallback.
        self.ES_URL = es_url or env.get("ELASTICSEARCH_URL")
        self.CLOUD_ID = cloud_id or env.get("ELASTICSEARCH_CLOUD_ID")
        if not (self.ES_URL or self.CLOUD_ID):
            raise AttributeError(
                "Elasticsearch needs a URL or CLOUD_ID attribute, "
                "this can either be passed to `ElasticsearchDBConfig` or as `ELASTICSEARCH_URL` or `ELASTICSEARCH_CLOUD_ID` in `.env`"  # noqa: E501
            )

        self.ES_EXTRA_PARAMS = ES_EXTRA_PARAMS

        # Only one of these credentials may be set. If the caller supplied none
        # of them, fall back to the API key from the environment (may be None).
        credential_keys = ("api_key", "basic_auth", "bearer_auth")
        if not any(self.ES_EXTRA_PARAMS.get(key) for key in credential_keys):
            self.ES_EXTRA_PARAMS["api_key"] = env.get("ELASTICSEARCH_API_KEY")

        self.batch_size = batch_size
        super().__init__(collection_name=collection_name, dir=dir)
import os
from typing import Optional, Union

from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable


@register_deserializable
class ElasticsearchDBConfig(BaseVectorDbConfig):
    """Connection and ingestion settings for an Elasticsearch-backed vector store."""

    # NOTE(review): the ``**ES_EXTRA_PARAMS`` annotation below references the
    # builtin ``any`` function rather than ``typing.Any``; it is left untouched
    # here because annotations do not affect runtime behaviour.
    def __init__(
        self,
        collection_name: Optional[str] = None,
        dir: Optional[str] = None,
        es_url: Union[str, list[str]] = None,
        cloud_id: Optional[str] = None,
        batch_size: Optional[int] = 100,
        **ES_EXTRA_PARAMS: dict[str, any],
    ):
        """
        Build the configuration used to connect to Elasticsearch.

        Exactly one endpoint source is expected: either ``es_url`` or ``cloud_id``.
        When neither is passed explicitly, the ``ELASTICSEARCH_URL`` and
        ``ELASTICSEARCH_CLOUD_ID`` environment variables are consulted instead.

        :param collection_name: Default name for the collection, defaults to None
        :type collection_name: Optional[str], optional
        :param dir: Path to the database directory, forwarded to the base config, defaults to None
        :type dir: Optional[str], optional
        :param es_url: Elasticsearch URL, or a list of node URLs, defaults to None
        :type es_url: Union[str, list[str]], optional
        :param cloud_id: Cloud id of the Elasticsearch cluster, defaults to None
        :type cloud_id: Optional[str], optional
        :param batch_size: Number of items to insert in one batch, defaults to 100
        :type batch_size: Optional[int], optional
        :param ES_EXTRA_PARAMS: Additional keyword arguments, stored as-is on
            ``self.ES_EXTRA_PARAMS`` (intended for the Elasticsearch client).
        :type ES_EXTRA_PARAMS: dict[str, Any], optional
        :raises ValueError: If both ``es_url`` and ``cloud_id`` are given.
        :raises AttributeError: If no URL or cloud id can be resolved from the
            arguments or the environment.
        """
        # Guard: the two endpoint arguments are mutually exclusive.
        if es_url and cloud_id:
            raise ValueError("Only one of `es_url` and `cloud_id` can be set.")

        env = os.environ

        # Explicit arguments take precedence; the environment is only a fallback.
        self.ES_URL = es_url or env.get("ELASTICSEARCH_URL")
        self.CLOUD_ID = cloud_id or env.get("ELASTICSEARCH_CLOUD_ID")
        if not (self.ES_URL or self.CLOUD_ID):
            raise AttributeError(
                "Elasticsearch needs a URL or CLOUD_ID attribute, "
                "this can either be passed to `ElasticsearchDBConfig` or as `ELASTICSEARCH_URL` or `ELASTICSEARCH_CLOUD_ID` in `.env`"  # noqa: E501
            )

        self.ES_EXTRA_PARAMS = ES_EXTRA_PARAMS

        # Only one of these credentials may be set. If the caller supplied none
        # of them, fall back to the API key from the environment (may be None).
        credential_keys = ("api_key", "basic_auth", "bearer_auth")
        if not any(self.ES_EXTRA_PARAMS.get(key) for key in credential_keys):
            self.ES_EXTRA_PARAMS["api_key"] = env.get("ELASTICSEARCH_API_KEY")

        self.batch_size = batch_size
        super().__init__(collection_name=collection_name, dir=dir)
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional

from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable


Expand Down
Original file line number Diff line number Diff line change
@@ -1,41 +1,41 @@
from typing import Optional

from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable


@register_deserializable
class OpenSearchDBConfig(BaseVectorDbConfig):
    """Connection and ingestion settings for an OpenSearch-backed vector store."""

    def __init__(
        self,
        opensearch_url: str,
        http_auth: tuple[str, str],
        vector_dimension: int = 1536,
        collection_name: Optional[str] = None,
        dir: Optional[str] = None,
        batch_size: Optional[int] = 100,
        **extra_params: dict[str, any],
    ):
        """
        Build the configuration used to connect to OpenSearch.

        :param opensearch_url: URL of the OpenSearch domain, e.g. "http://localhost:9200"
        :type opensearch_url: str
        :param http_auth: Username/password pair, e.g. ("username", "password")
        :type http_auth: tuple[str, str]
        :param vector_dimension: Dimension of the stored vectors, defaults to 1536
            (the OpenAI embedding model size)
        :type vector_dimension: int, optional
        :param collection_name: Default name for the collection, defaults to None
        :type collection_name: Optional[str], optional
        :param dir: Path to the database directory, forwarded to the base config, defaults to None
        :type dir: Optional[str], optional
        :param batch_size: Number of items to insert in one batch, defaults to 100
        :type batch_size: Optional[int], optional
        :param extra_params: Additional keyword arguments, stored as-is on ``self.extra_params``.
        """
        # Endpoint and credentials.
        self.opensearch_url, self.http_auth = opensearch_url, http_auth

        # Vector size, pass-through options and ingestion batch size.
        self.vector_dimension, self.extra_params = vector_dimension, extra_params
        self.batch_size = batch_size

        # The base config receives the collection name and directory.
        super().__init__(collection_name=collection_name, dir=dir)
from typing import Optional

from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable


@register_deserializable
class OpenSearchDBConfig(BaseVectorDbConfig):
    """Connection and ingestion settings for an OpenSearch-backed vector store."""

    def __init__(
        self,
        opensearch_url: str,
        http_auth: tuple[str, str],
        vector_dimension: int = 1536,
        collection_name: Optional[str] = None,
        dir: Optional[str] = None,
        batch_size: Optional[int] = 100,
        **extra_params: dict[str, any],
    ):
        """
        Build the configuration used to connect to OpenSearch.

        :param opensearch_url: URL of the OpenSearch domain, e.g. "http://localhost:9200"
        :type opensearch_url: str
        :param http_auth: Username/password pair, e.g. ("username", "password")
        :type http_auth: tuple[str, str]
        :param vector_dimension: Dimension of the stored vectors, defaults to 1536
            (the OpenAI embedding model size)
        :type vector_dimension: int, optional
        :param collection_name: Default name for the collection, defaults to None
        :type collection_name: Optional[str], optional
        :param dir: Path to the database directory, forwarded to the base config, defaults to None
        :type dir: Optional[str], optional
        :param batch_size: Number of items to insert in one batch, defaults to 100
        :type batch_size: Optional[int], optional
        :param extra_params: Additional keyword arguments, stored as-is on ``self.extra_params``.
        """
        # Endpoint and credentials.
        self.opensearch_url, self.http_auth = opensearch_url, http_auth

        # Vector size, pass-through options and ingestion batch size.
        self.vector_dimension, self.extra_params = vector_dimension, extra_params
        self.batch_size = batch_size

        # The base config receives the collection name and directory.
        super().__init__(collection_name=collection_name, dir=dir)
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from typing import Optional

from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable


Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional

from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable


Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional

from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable


Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from typing import Optional

from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.helpers.json_serializable import register_deserializable


Expand Down
16 changes: 8 additions & 8 deletions embedchain/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,14 @@ class VectorDBFactory:
"zilliz": "embedchain.vectordb.zilliz.ZillizVectorDB",
}
provider_to_config_class = {
"chroma": "embedchain.config.vectordb.chroma.ChromaDbConfig",
"elasticsearch": "embedchain.config.vectordb.elasticsearch.ElasticsearchDBConfig",
"opensearch": "embedchain.config.vectordb.opensearch.OpenSearchDBConfig",
"lancedb": "embedchain.config.vectordb.lancedb.LanceDBConfig",
"pinecone": "embedchain.config.vectordb.pinecone.PineconeDBConfig",
"qdrant": "embedchain.config.vectordb.qdrant.QdrantDBConfig",
"weaviate": "embedchain.config.vectordb.weaviate.WeaviateDBConfig",
"zilliz": "embedchain.config.vectordb.zilliz.ZillizDBConfig",
"chroma": "embedchain.config.vector_db.chroma.ChromaDbConfig",
"elasticsearch": "embedchain.config.vector_db.elasticsearch.ElasticsearchDBConfig",
"opensearch": "embedchain.config.vector_db.opensearch.OpenSearchDBConfig",
"lancedb": "embedchain.config.vector_db.lancedb.LanceDBConfig",
"pinecone": "embedchain.config.vector_db.pinecone.PineconeDBConfig",
"qdrant": "embedchain.config.vector_db.qdrant.QdrantDBConfig",
"weaviate": "embedchain.config.vector_db.weaviate.WeaviateDBConfig",
"zilliz": "embedchain.config.vector_db.zilliz.ZillizDBConfig",
}

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion embedchain/vectordb/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from embedchain.config.vectordb.base import BaseVectorDbConfig
from embedchain.config.vector_db.base import BaseVectorDbConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.helpers.json_serializable import JSONSerializable

Expand Down
2 changes: 1 addition & 1 deletion embedchain/vectordb/lancedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
except ImportError:
raise ImportError('LanceDB is required. Install with pip install "embedchain[lancedb]"') from None

from embedchain.config.vectordb.lancedb import LanceDBConfig
from embedchain.config.vector_db.lancedb import LanceDBConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.vectordb.base import BaseVectorDB

Expand Down
2 changes: 1 addition & 1 deletion embedchain/vectordb/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from pinecone_text.sparse import BM25Encoder

from embedchain.config.vectordb.pinecone import PineconeDBConfig
from embedchain.config.vector_db.pinecone import PineconeDBConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.utils.misc import chunks
from embedchain.vectordb.base import BaseVectorDB
Expand Down
2 changes: 1 addition & 1 deletion embedchain/vectordb/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from tqdm import tqdm

from embedchain.config.vectordb.qdrant import QdrantDBConfig
from embedchain.config.vector_db.qdrant import QdrantDBConfig
from embedchain.vectordb.base import BaseVectorDB


Expand Down
2 changes: 1 addition & 1 deletion embedchain/vectordb/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"Weaviate requires extra dependencies. Install with `pip install --upgrade 'embedchain[weaviate]'`"
) from None

from embedchain.config.vectordb.weaviate import WeaviateDBConfig
from embedchain.config.vector_db.weaviate import WeaviateDBConfig
from embedchain.helpers.json_serializable import register_deserializable
from embedchain.vectordb.base import BaseVectorDB

Expand Down
2 changes: 1 addition & 1 deletion tests/vectordb/test_lancedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from embedchain import App
from embedchain.config import AppConfig
from embedchain.config.vectordb.lancedb import LanceDBConfig
from embedchain.config.vector_db.lancedb import LanceDBConfig
from embedchain.vectordb.lancedb import LanceDB

os.environ["OPENAI_API_KEY"] = "test-api-key"
Expand Down
2 changes: 1 addition & 1 deletion tests/vectordb/test_pinecone.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from embedchain.config.vectordb.pinecone import PineconeDBConfig
from embedchain.config.vector_db.pinecone import PineconeDBConfig
from embedchain.vectordb.pinecone import PineconeDB


Expand Down
2 changes: 1 addition & 1 deletion tests/vectordb/test_qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from embedchain import App
from embedchain.config import AppConfig
from embedchain.config.vectordb.pinecone import PineconeDBConfig
from embedchain.config.vector_db.pinecone import PineconeDBConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.vectordb.qdrant import QdrantDB

Expand Down
2 changes: 1 addition & 1 deletion tests/vectordb/test_weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from embedchain import App
from embedchain.config import AppConfig
from embedchain.config.vectordb.pinecone import PineconeDBConfig
from embedchain.config.vector_db.pinecone import PineconeDBConfig
from embedchain.embedder.base import BaseEmbedder
from embedchain.vectordb.weaviate import WeaviateDB

Expand Down

0 comments on commit 83e8c97

Please sign in to comment.