Skip to content

Commit

Permalink
extend type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
sincRK committed Apr 30, 2024
1 parent e2dadee commit 2e51494
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 71 deletions.
162 changes: 91 additions & 71 deletions goldenverba/server/api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from fastapi import FastAPI, Request, WebSocket, status
from fastapi import FastAPI, WebSocket, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
Expand All @@ -7,7 +7,6 @@
from pathlib import Path

from dotenv import load_dotenv
from pydantic import BaseModel
from starlette.websockets import WebSocketDisconnect
from wasabi import msg # type: ignore[import]

Expand All @@ -18,6 +17,15 @@
from goldenverba.components.reader.interface import Reader
from goldenverba.components.retriever.interface import Retriever
from goldenverba.server.ConfigManager import ConfigManager
from goldenverba.server.types import (
GetComponentPayload,
SetComponentPayload,
LoadPayload,
QueryPayload,
GeneratePayload,
SearchQueryPayload,
GetDocumentPayload,
)
from goldenverba.server.util import setup_managers

load_dotenv()
Expand All @@ -40,7 +48,13 @@
generators = manager.generator_get_generator()

setup_managers(
manager, config_manager, readers, chunker, embedders, retrievers, generators
manager,
config_manager,
readers,
chunker,
embedders,
retrievers,
generators
)
config_manager.save_config()

Expand Down Expand Up @@ -134,53 +148,11 @@ def create_generator_payload(key: str, generator: Generator) -> dict:
)

# Serve the main page and other static files
app.mount("/static", StaticFiles(directory=BASE_DIR / "frontend/out"), name="app")


class QueryPayload(BaseModel):
query: str


class ConversationItem(BaseModel):
type: str
content: str
typewriter: bool


class GeneratePayload(BaseModel):
query: str
context: str
conversation: list[ConversationItem]


class SearchQueryPayload(BaseModel):
query: str
doc_type: str


class GetDocumentPayload(BaseModel):
document_id: str


class LoadPayload(BaseModel):
reader: str
chunker: str
embedder: str
fileBytes: list[str]
fileNames: list[str]
filePath: str
document_type: str
chunkUnits: int
chunkOverlap: int


class GetComponentPayload(BaseModel):
component: str


class SetComponentPayload(BaseModel):
component: str
selected_component: str
app.mount(
"/static",
StaticFiles(directory=BASE_DIR / "frontend/out"),
name="app"
)


@app.get("/")
Expand All @@ -191,7 +163,8 @@ async def serve_frontend():

@app.get("/status")
async def catch_status():
# Check if the path corresponds to a file that exists in the static directory
# Check if the path corresponds to a
# file that exists in the static directory
file_path = BASE_DIR / "frontend/out" / "status.html"
if os.path.isfile(file_path):
return FileResponse(file_path)
Expand All @@ -201,7 +174,8 @@ async def catch_status():

@app.get("/document_explorer")
async def catch_explorer():
# Check if the path corresponds to a file that exists in the static directory
# Check if the path corresponds to a
# file that exists in the static directory
file_path = BASE_DIR / "frontend/out" / "document_explorer.html"
if os.path.isfile(file_path):
return FileResponse(file_path)
Expand Down Expand Up @@ -285,13 +259,16 @@ async def get_components():
try:
data["default_values"] = {
"last_reader": create_reader_payload(
config_manager.get_reader(), readers[config_manager.get_reader()]
config_manager.get_reader(),
readers[config_manager.get_reader()]
),
"last_chunker": create_chunker_payload(
config_manager.get_chunker(), chunker[config_manager.get_chunker()]
config_manager.get_chunker(),
chunker[config_manager.get_chunker()]
),
"last_embedder": create_embedder_payload(
config_manager.get_embedder(), embedders[config_manager.get_embedder()]
config_manager.get_embedder(),
embedders[config_manager.get_embedder()]
),
"last_document_type": "Documentation",
}
Expand All @@ -301,18 +278,27 @@ async def get_components():
config_manager.default_config()
config_manager.save_config()
setup_managers(
manager, config_manager, readers, chunker, embedders, retrievers, generators
manager,
config_manager,
readers,
chunker,
embedders,
retrievers,
generators
)
config_manager.save_config()
data["default_values"] = {
"last_reader": create_reader_payload(
config_manager.get_reader(), readers[config_manager.get_reader()]
config_manager.get_reader(),
readers[config_manager.get_reader()]
),
"last_chunker": create_chunker_payload(
config_manager.get_chunker(), chunker[config_manager.get_chunker()]
config_manager.get_chunker(),
chunker[config_manager.get_chunker()]
),
"last_embedder": create_embedder_payload(
config_manager.get_embedder(), embedders[config_manager.get_embedder()]
config_manager.get_embedder(),
embedders[config_manager.get_embedder()]
),
"last_document_type": "Documentation",
}
Expand All @@ -334,7 +320,9 @@ async def get_component(payload: GetComponentPayload):

for key in embedders:
current_embedder = embedders[key]
current_embedder_data = create_embedder_payload(key, current_embedder)
current_embedder_data = create_embedder_payload(
key, current_embedder
)
data["components"].append(current_embedder_data)

elif payload.component == "retrievers":
Expand All @@ -345,7 +333,9 @@ async def get_component(payload: GetComponentPayload):

for key in retrievers:
current_retriever = retrievers[key]
current_retriever_data = create_retriever_payload(key, current_retriever)
current_retriever_data = create_retriever_payload(
key, current_retriever
)
data["components"].append(current_retriever_data)

elif payload.component == "generators":
Expand All @@ -356,7 +346,9 @@ async def get_component(payload: GetComponentPayload):

for key in generators:
current_generator = generators[key]
current_generator_data = create_generator_payload(key, current_generator)
current_generator_data = create_generator_payload(
key, current_generator
)
data["components"].append(current_generator_data)

return JSONResponse(content=data)
Expand Down Expand Up @@ -464,7 +456,10 @@ async def load_data(payload: LoadPayload):
current_chunker.default_overlap = payload.chunkOverlap

msg.info(
f"Received Data to Import: READER({payload.reader}, Documents {len(payload.fileBytes)}, Type {payload.document_type}) CHUNKER ({payload.chunker}, UNITS {payload.chunkUnits}, OVERLAP {payload.chunkOverlap}), EMBEDDER ({payload.embedder})"
f"Received Data to Import: READER({payload.reader}, "
f"Documents {len(payload.fileBytes)}, Type {payload.document_type}) "
f"CHUNKER ({payload.chunker}, UNITS {payload.chunkUnits}, "
f"OVERLAP {payload.chunkOverlap}), EMBEDDER ({payload.embedder})"
)

if payload.fileBytes or payload.filePath:
Expand All @@ -488,12 +483,17 @@ async def load_data(payload: LoadPayload):
)

document_count = len(documents)
chunks_count = sum([len(document.chunks) for document in documents])
chunks_count = sum(
[len(document.chunks) for document in documents]
)

return JSONResponse(
content={
"status": 200,
"status_msg": f"Succesfully imported {document_count} documents and {chunks_count} chunks",
"status_msg": (
f"Succesfully imported {document_count} "
f"documents and {chunks_count} chunks",
)
}
)
except Exception as e:
Expand Down Expand Up @@ -657,7 +657,7 @@ async def get_document(payload: GetDocumentPayload):
)


## Retrieve all documents imported to Weaviate
# Retrieve all documents imported to Weaviate
@app.post("/api/get_all_documents")
async def get_all_documents(payload: SearchQueryPayload):
msg.info("Get all documents request received")
Expand All @@ -672,7 +672,12 @@ async def get_all_documents(payload: SearchQueryPayload):
content={
"documents": documents,
"doc_types": list(doc_types),
"current_embedder": manager.embedder_manager.selected_embedder.name,
"current_embedder": (
manager
.embedder_manager
.selected_embedder
.name,
)
}
)
except Exception as e:
Expand All @@ -681,28 +686,43 @@ async def get_all_documents(payload: SearchQueryPayload):
content={
"documents": [],
"doc_types": [],
"current_embedder": manager.embedder_manager.selected_embedder.name,
"current_embedder": (
manager
.embedder_manager
.selected_embedder
.name,
)
}
)


## Search for documentation
# Search for documentation
@app.post("/api/search_documents")
async def search_documents(payload: SearchQueryPayload):
    """Search documents by query and document type.

    Passes ``payload.query`` and ``payload.doc_type`` to
    ``manager.search_documents`` and returns the result together with the
    name of the currently selected embedder.

    Returns:
        JSONResponse whose content has two keys:
        ``documents`` - the search result, or ``[]`` if the search raised;
        ``current_embedder`` - the selected embedder's name as a plain
        string.
    """
    try:
        documents = manager.search_documents(payload.query, payload.doc_type)
        # Bind the embedder first so the long attribute chain fits on one
        # line. Do NOT wrap the chain in "( ... ,)": a trailing comma inside
        # parentheses builds a 1-tuple, which JSON-encodes as a list rather
        # than the string the client expects.
        embedder = manager.embedder_manager.selected_embedder
        return JSONResponse(
            content={
                "documents": documents,
                "current_embedder": embedder.name,
            }
        )
    except Exception as e:
        # Best-effort endpoint: log the failure and still answer with an
        # empty result so the caller always receives the same shape.
        # NOTE(review): the message wording reads as if it was written for
        # the "get all documents" endpoint - confirm before changing it.
        msg.fail(f"All Document retrieval failed: {str(e)}")
        embedder = manager.embedder_manager.selected_embedder
        return JSONResponse(
            content={
                "documents": [],
                "current_embedder": embedder.name,
            }
        )

Expand Down
54 changes: 54 additions & 0 deletions goldenverba/server/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from pydantic import BaseModel


class QueryPayload(BaseModel):
    """Request body that carries a single query string.

    Imported by ``goldenverba/server/api.py``.
    """

    # The query text.
    query: str


class ConversationItem(BaseModel):
    """One entry of the ``conversation`` list in ``GeneratePayload``."""

    # NOTE(review): presumably identifies who produced the entry
    # (e.g. user vs. system) - confirm against the frontend.
    type: str
    # Text of the entry.
    content: str
    # NOTE(review): looks like a frontend display flag - confirm.
    typewriter: bool


class GeneratePayload(BaseModel):
    """Payload with a query, a context string and a conversation list.

    Imported by ``goldenverba/server/api.py``; the consuming endpoint is
    not shown alongside this definition.
    """

    # The query text.
    query: str
    # NOTE(review): presumably the retrieved context handed to the
    # generator - confirm against the caller.
    context: str
    # Conversation entries, each validated as a ConversationItem.
    conversation: list[ConversationItem]


class GeneratedMessage(BaseModel):
    """Model describing one generated message.

    NOTE(review): field meanings below are inferred from their names only;
    confirm against the code that produces this model.
    """

    # Generated text.
    message: str
    # Presumably why generation stopped - TODO confirm allowed values.
    finish_reason: str
    # Presumably True when the message came from a cache - TODO confirm.
    cached: bool
    # Presumably a similarity distance for a cached hit - TODO confirm.
    distance: float


class SearchQueryPayload(BaseModel):
    """Request body for the document listing and search endpoints.

    Accepted by ``get_all_documents`` and ``search_documents`` in
    ``goldenverba/server/api.py``; the latter forwards both fields to
    ``manager.search_documents``.
    """

    # Search text.
    query: str
    # Document type passed along with the query.
    doc_type: str


class GetDocumentPayload(BaseModel):
    """Request body accepted by ``get_document`` in the server API."""

    # Identifier of the requested document.
    document_id: str


class LoadPayload(BaseModel):
    """Request body accepted by ``load_data`` in the server API.

    ``load_data`` logs the reader, chunker and embedder names, the number
    of ``fileBytes`` entries, the document type and the chunk settings,
    and only proceeds with an import when ``fileBytes`` or ``filePath`` is
    non-empty.
    """

    # Names of the components to use for this import.
    reader: str
    chunker: str
    embedder: str
    # File contents as strings.
    # NOTE(review): encoding is not visible here (base64?) - confirm.
    fileBytes: list[str]
    # NOTE(review): presumably parallel to fileBytes - confirm.
    fileNames: list[str]
    # Path to load from; may be empty when fileBytes is supplied.
    filePath: str
    # Document type label, echoed in the import log message.
    document_type: str
    # Chunk size in units; logged as "UNITS" by load_data.
    chunkUnits: int
    # Chunk overlap; load_data assigns it to the chunker's default_overlap.
    chunkOverlap: int


class GetComponentPayload(BaseModel):
    """Request body accepted by ``get_component`` in the server API."""

    # Component category to list; get_component compares it against
    # names such as "retrievers" and "generators".
    component: str


class SetComponentPayload(BaseModel):
    """Payload naming a component category and the chosen component.

    Imported by ``goldenverba/server/api.py``; the consuming endpoint is
    not shown alongside this definition.
    """

    # Component category.
    # NOTE(review): presumably the same vocabulary as
    # GetComponentPayload.component - confirm.
    component: str
    # Name of the component to select within that category.
    selected_component: str

0 comments on commit 2e51494

Please sign in to comment.