Skip to content

Commit

Permalink
Merge branch 'feature' into personal/yfei/update-compose
Browse files: browse the repository at this point in the history
  • Loading branch information
moria97 authored Dec 9, 2024
2 parents 073f915 + 1b27e9f commit 79904af
Show file tree
Hide file tree
Showing 30 changed files with 638 additions and 344 deletions.
1 change: 1 addition & 0 deletions src/pai_rag/app/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class RagQuery(BaseModel):
session_id: str | None = None
vector_db: VectorDbConfig | None = None
stream: bool | None = False
citation: bool | None = False
with_intent: bool | None = False
index_name: str | None = None

Expand Down
12 changes: 11 additions & 1 deletion src/pai_rag/app/web/rag_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def query(
text: str,
with_history: bool = False,
stream: bool = False,
citation: bool = False,
with_intent: bool = False,
index_name: str = None,
):
Expand All @@ -221,9 +222,11 @@ def query(
question=text,
session_id=session_id,
stream=stream,
citation=citation,
with_intent=with_intent,
index_name=index_name,
)
print(q)
r = requests.post(self.query_url, json=q, stream=True)
if r.status_code != HTTPStatus.OK:
raise RagApiError(code=r.status_code, msg=r.text)
Expand All @@ -248,10 +251,17 @@ def query_search(
self,
text: str,
with_history: bool = False,
citation: bool = False,
stream: bool = False,
):
session_id = self.session_id if with_history else None
q = dict(question=text, session_id=session_id, stream=stream, with_intent=False)
q = dict(
question=text,
session_id=session_id,
stream=stream,
with_intent=False,
citation=citation,
)
r = requests.post(self.search_url, json=q, stream=True)
if r.status_code != HTTPStatus.OK:
raise RagApiError(code=r.status_code, msg=r.text)
Expand Down
40 changes: 36 additions & 4 deletions src/pai_rag/app/web/tabs/chat_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def respond(input_elements: List[Any]):
chatbot = update_dict["chatbot"]
is_streaming = update_dict["is_streaming"]
index_name = update_dict["chat_index"]
citation = update_dict["citation"]

if chatbot is not None:
chatbot.append((msg, ""))
Expand All @@ -51,13 +52,17 @@ def respond(input_elements: List[Any]):

elif query_type == "RAG (Search Web)":
response_gen = rag_client.query_search(
msg, with_history=update_dict["include_history"], stream=is_streaming
msg,
with_history=update_dict["include_history"],
stream=is_streaming,
citation=citation,
)
else:
response_gen = rag_client.query(
msg,
with_history=update_dict["include_history"],
stream=is_streaming,
citation=citation,
index_name=index_name,
)

Expand Down Expand Up @@ -94,6 +99,12 @@ def create_chat_tab() -> Dict[str, Any]:
elem_id="is_streaming",
value=True,
)
citation = gr.Checkbox(
label="Citation",
info="Need Citation",
elem_id="citation",
value=True,
)
need_image = gr.Checkbox(
label="Display Image",
info="Inference with multi-modal LLM.",
Expand Down Expand Up @@ -281,19 +292,35 @@ def change_retrieval_mode(retrieval_mode):
search_args = {search_api_key, search_count, search_lang}

with gr.Column(visible=True) as lc_col:
with gr.Tab("LLM Prompt"):
with gr.Tab("Prompt"):
text_qa_template = gr.Textbox(
label="Prompt Template",
value="",
elem_id="text_qa_template",
lines=10,
interactive=True,
)
with gr.Tab("MultiModal LLM Prompt"):
citation_text_qa_template = gr.Textbox(
label="Citation Prompt Template",
value="",
elem_id="citation_text_qa_template",
lines=10,
interactive=True,
)
with gr.Tab("MultiModal Prompt"):
multimodal_qa_template = gr.Textbox(
label="Multi-modal LLM Prompt Template",
label="Multi-modal Prompt Template",
value="",
elem_id="multimodal_qa_template",
lines=12,
interactive=True,
)
citation_multimodal_qa_template = gr.Textbox(
label="Citation Multi-modal Prompt Template",
value="",
elem_id="citation_multimodal_qa_template",
lines=12,
interactive=True,
)

cur_tokens = gr.Textbox(
Expand Down Expand Up @@ -367,10 +394,13 @@ def change_query_radio(query_type):
{
text_qa_template,
multimodal_qa_template,
citation_text_qa_template,
citation_multimodal_qa_template,
question,
query_type,
chatbot,
is_streaming,
citation,
need_image,
include_history,
chat_index,
Expand Down Expand Up @@ -419,6 +449,8 @@ def change_query_radio(query_type):
similarity_threshold.elem_id: similarity_threshold,
reranker_similarity_threshold.elem_id: reranker_similarity_threshold,
multimodal_qa_template.elem_id: multimodal_qa_template,
citation_multimodal_qa_template.elem_id: citation_multimodal_qa_template,
citation_text_qa_template.elem_id: citation_text_qa_template,
text_qa_template.elem_id: text_qa_template,
search_lang.elem_id: search_lang,
search_api_key.elem_id: search_api_key,
Expand Down
2 changes: 0 additions & 2 deletions src/pai_rag/app/web/ui_constants.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from pai_rag.integrations.synthesizer.pai_synthesizer import (
DEFAULT_MULTI_MODAL_IMAGE_QA_PROMPT_TMPL,
)
from pai_rag.utils.prompt_template import DEFAULT_TEXT_QA_PROMPT_TMPL

DEFAULT_TEXT_QA_PROMPT_TMPL = DEFAULT_TEXT_QA_PROMPT_TMPL
DEFAULT_MULTI_MODAL_IMAGE_QA_PROMPT_TMPL = DEFAULT_MULTI_MODAL_IMAGE_QA_PROMPT_TMPL

DA_GENERAL_PROMPTS = "给定一个输入问题,创建一个语法正确的{dialect}查询语句来执行,不要从特定的表中查询所有列,只根据问题查询几个相关的列。请注意只使用你在schema descriptions 中看到的列名。\n=====\n 小心不要查询不存在的列。请注意哪个列位于哪个表中。必要时,请使用表名限定列名。\n=====\n 你必须使用以下格式,每项占一行:\n\n Question: Question here\n SQLQuery: SQL Query to run \n\n Only use tables listed below.\n {schema}\n\n Question: {query_str} \n SQLQuery: "
Expand Down
26 changes: 22 additions & 4 deletions src/pai_rag/app/web/view_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
from pai_rag.app.web.ui_constants import (
LLM_MODEL_KEY_DICT,
MLLM_MODEL_KEY_DICT,
DEFAULT_TEXT_QA_PROMPT_TMPL,
DEFAULT_MULTI_MODAL_IMAGE_QA_PROMPT_TMPL,
)
import pandas as pd
import os
Expand Down Expand Up @@ -115,8 +113,10 @@ class ViewModel(BaseModel):

synthesizer_type: str = None

text_qa_template: str = DEFAULT_TEXT_QA_PROMPT_TMPL
multimodal_qa_template: str = DEFAULT_MULTI_MODAL_IMAGE_QA_PROMPT_TMPL
text_qa_template: str = None
multimodal_qa_template: str = None
citation_text_qa_template: str = None
citation_multimodal_qa_template: str = None

# agent
agent_api_definition: str = None # API tool definition
Expand Down Expand Up @@ -199,6 +199,12 @@ def from_app_config(config: RagConfig):

view_model.text_qa_template = config.synthesizer.text_qa_template
view_model.multimodal_qa_template = config.synthesizer.multimodal_qa_template
view_model.citation_text_qa_template = (
config.synthesizer.citation_text_qa_template
)
view_model.citation_multimodal_qa_template = (
config.synthesizer.citation_multimodal_qa_template
)

view_model.search_api_key = config.search.search_api_key or os.environ.get(
"BING_SEARCH_KEY"
Expand Down Expand Up @@ -340,6 +346,12 @@ def to_app_config(self):
config["synthesizer"]["use_multimodal_llm"] = self.use_mllm
config["synthesizer"]["text_qa_template"] = self.text_qa_template
config["synthesizer"]["multimodal_qa_template"] = self.multimodal_qa_template
config["synthesizer"][
"citation_text_qa_template"
] = self.citation_text_qa_template
config["synthesizer"][
"citation_multimodal_qa_template"
] = self.citation_multimodal_qa_template

config["search"]["search_api_key"] = self.search_api_key or os.environ.get(
"BING_SEARCH_KEY"
Expand Down Expand Up @@ -518,6 +530,12 @@ def to_component_settings(self) -> Dict[str, Dict[str, Any]]:

settings["text_qa_template"] = {"value": self.text_qa_template}
settings["multimodal_qa_template"] = {"value": self.multimodal_qa_template}
settings["citation_text_qa_template"] = {
"value": self.citation_text_qa_template
}
settings["citation_multimodal_qa_template"] = {
"value": self.citation_multimodal_qa_template
}

# search
settings["search_api_key"] = {"value": self.search_api_key}
Expand Down
6 changes: 0 additions & 6 deletions src/pai_rag/config/settings.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,3 @@ search_api_key = ""

[rag.synthesizer]
type = "SimpleSummarize"
text_qa_template = "参考内容信息如下\n---------------------\n{context_str}\n---------------------根据提供内容而非其他知识回答问题.\n问题: {query_str}\n答案: \n"

[rag.trace]
type = "pai_trace"
endpoint = "http://tracing-analysis-dc-hz.aliyuncs.com:8090"
token = ""
40 changes: 7 additions & 33 deletions src/pai_rag/core/models/config.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from enum import Enum
from typing import List, Literal
from typing import List
from pydantic import BaseModel
from llama_index.core.prompts.default_prompt_selectors import (
DEFAULT_TEXT_QA_PROMPT_SEL,
)
from llama_index.core.vector_stores.types import VectorStoreQueryMode
from pai_rag.integrations.synthesizer.pai_synthesizer import (
DEFAULT_TEXT_QA_TMPL,
DEFAULT_MULTI_MODAL_IMAGE_QA_PROMPT_TMPL,
CITATION_TEXT_QA_TMPL,
CITATION_MULTI_MODAL_IMAGE_QA_PROMPT_TMPL,
)


Expand Down Expand Up @@ -46,32 +45,7 @@ class SearchWebConfig(BaseModel):

class SynthesizerConfig(BaseModel):
use_multimodal_llm: bool = False
text_qa_template: str = DEFAULT_TEXT_QA_PROMPT_SEL
text_qa_template: str = DEFAULT_TEXT_QA_TMPL
citation_text_qa_template: str = CITATION_TEXT_QA_TMPL
multimodal_qa_template: str = DEFAULT_MULTI_MODAL_IMAGE_QA_PROMPT_TMPL


class TraceType(str, Enum):
"""Trace types."""

arize_pheonix = "arize_phoenix"
pai_trace = "pai_trace"
default = "no_trace"


class BaseTraceConfig(BaseModel):
type: str = TraceType.default

@classmethod
def get_subclasses(cls):
return tuple(cls.__subclasses__())


class PaiTraceConfig(BaseTraceConfig):
type: Literal[TraceType.pai_trace] = TraceType.pai_trace
endpoint: str | None = None
token: str | None = None
app_name: str = "PAIRAG-Service"


class ArizeTraceConfig(BaseTraceConfig):
type: Literal[TraceType.arize_pheonix] = TraceType.arize_pheonix
citation_multimodal_qa_template: str = CITATION_MULTI_MODAL_IMAGE_QA_PROMPT_TMPL
4 changes: 3 additions & 1 deletion src/pai_rag/core/rag_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,9 @@ async def aquery(
elif intent != Intents.RAG:
return ValueError(f"Invalid intent {intent}")

query_bundle = PaiQueryBundle(query_str=new_question, stream=query.stream)
query_bundle = PaiQueryBundle(
query_str=new_question, stream=query.stream, citation=query.citation
)
chat_store.add_message(
session_id, ChatMessage(role=MessageRole.USER, content=query.question)
)
Expand Down
8 changes: 0 additions & 8 deletions src/pai_rag/core/rag_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Annotated, Dict, Union
from pydantic import BaseModel, ConfigDict, Field, BeforeValidator
from pai_rag.core.models.config import (
BaseTraceConfig,
NodeEnhancementConfig,
OssStoreConfig,
RetrieverConfig,
Expand Down Expand Up @@ -133,10 +132,3 @@ class RagConfig(BaseModel):

# synthesizer
synthesizer: SynthesizerConfig

# trace
trace: Annotated[
Union[BaseTraceConfig.get_subclasses()],
Field(discriminator="type"),
BeforeValidator(validate_case_insensitive),
] | None = None
6 changes: 6 additions & 0 deletions src/pai_rag/core/rag_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,12 @@ def resolve_synthesizer(config: RagConfig) -> PaiSynthesizer:
multimodal_qa_template=PromptTemplate(
template=config.synthesizer.multimodal_qa_template
),
citation_text_qa_template=PromptTemplate(
template=config.synthesizer.citation_text_qa_template
),
citation_multimodal_qa_template=PromptTemplate(
template=config.synthesizer.citation_multimodal_qa_template
),
)
return synthesizer

Expand Down
13 changes: 7 additions & 6 deletions src/pai_rag/integrations/embeddings/clip/cnclip_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@
from llama_index.core.constants import DEFAULT_EMBED_BATCH_SIZE
from pai_rag.utils.constants import DEFAULT_MODEL_DIR

DEFAULT_CNCLIP_MODEL_DIR = os.path.join(
DEFAULT_MODEL_DIR, "chinese-clip-vit-large-patch14"
)
DEFAULT_CNCLIP_MODEL = "ViT-L-14"


Expand All @@ -29,6 +26,7 @@ def __init__(
self,
model_name: str = DEFAULT_CNCLIP_MODEL,
embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
model_path: str = DEFAULT_MODEL_DIR,
**kwargs: Any,
) -> None:
super().__init__(
Expand All @@ -41,9 +39,10 @@ def __init__(
raise ValueError(f"Unknown ChineseClip model: {model_name}.")

self._device = "cuda" if torch.cuda.is_available() else "cpu"

self._model, self._preprocess = load_from_name(
self.model_name, device=self._device, download_root=DEFAULT_CNCLIP_MODEL_DIR
self.model_name,
device=self._device,
download_root=model_path,
)
self._model.eval()

Expand Down Expand Up @@ -96,7 +95,9 @@ def _get_image_embedding(self, img_file_path: ImageType) -> Embedding:


if __name__ == "__main__":
clip_embedding = CnClipEmbedding()
clip_embedding = CnClipEmbedding(
os.path.join(DEFAULT_MODEL_DIR, "chinese-clip-vit-large-patch14")
)

image_embedding = clip_embedding.get_image_embedding(
"example_data/cn_clip/pokemon.jpeg"
Expand Down
Loading

0 comments on commit 79904af

Please sign in to comment.