From 42a5b3ec1754bc73cc369bbffc2dcac4c67fb8f9 Mon Sep 17 00:00:00 2001 From: Garfield Dai Date: Thu, 12 Oct 2023 23:13:10 +0800 Subject: [PATCH] feat: advanced prompt backend (#1301) Co-authored-by: takatost --- api/constants/model_template.py | 24 +++- api/controllers/console/__init__.py | 2 +- .../console/app/advanced_prompt_template.py | 26 ++++ api/controllers/console/app/generator.py | 30 ----- api/controllers/console/app/message.py | 2 +- api/controllers/web/message.py | 2 +- api/core/completion.py | 90 ++++--------- api/core/conversation_message_task.py | 22 ++-- api/core/generator/llm_generator.py | 108 ++++------------ api/core/indexing_runner.py | 4 +- ...versation_token_db_buffer_shared_memory.py | 2 +- .../model_providers/models/entity/message.py | 10 +- api/core/model_providers/models/llm/base.py | 119 +++++++++++++++--- .../providers/anthropic_provider.py | 7 +- .../providers/azure_openai_provider.py | 17 ++- .../providers/baichuan_provider.py | 6 +- api/core/model_providers/providers/base.py | 27 +++- .../providers/chatglm_provider.py | 7 +- .../providers/huggingface_hub_provider.py | 5 +- .../providers/localai_provider.py | 9 +- .../providers/minimax_provider.py | 7 +- .../providers/openai_provider.py | 16 ++- .../providers/openllm_provider.py | 5 +- .../providers/replicate_provider.py | 6 +- .../providers/spark_provider.py | 7 +- .../providers/tongyi_provider.py | 7 +- .../providers/wenxin_provider.py | 8 +- .../providers/xinference_provider.py | 5 +- .../providers/zhipuai_provider.py | 9 +- api/core/orchestrator_rule_parser.py | 79 ++++++------ api/core/prompt/advanced_prompt_templates.py | 79 ++++++++++++ api/core/prompt/prompt_builder.py | 38 ++---- api/core/prompt/prompt_template.py | 96 +++++--------- api/core/prompt/prompts.py | 34 +---- api/core/tool/dataset_retriever_tool.py | 14 ++- api/events/event_handlers/__init__.py | 1 - ...sation_summary_when_few_message_created.py | 14 --- api/fields/app_fields.py | 4 + api/fields/conversation_fields.py | 1 + ...09c049e8e_add_advanced_prompt_templates.py | 37 ++++++ api/models/model.py | 35 +++++- .../advanced_prompt_template_service.py | 56 +++++++++ api/services/app_model_config_service.py | 92 ++++++++++---- api/services/completion_service.py | 66 +++------- api/services/provider_service.py | 3 + .../generate_conversation_summary_task.py | 55 -------- .../models/llm/test_anthropic_model.py | 2 +- .../models/llm/test_azure_openai_model.py | 2 +- .../models/llm/test_baichuan_model.py | 6 +- .../models/llm/test_huggingface_hub_model.py | 4 +- .../models/llm/test_minimax_model.py | 2 +- .../models/llm/test_openai_model.py | 2 +- .../models/llm/test_openllm_model.py | 2 +- .../models/llm/test_replicate_model.py | 2 +- .../models/llm/test_spark_model.py | 2 +- .../models/llm/test_tongyi_model.py | 2 +- .../models/llm/test_wenxin_model.py | 2 +- .../models/llm/test_xinference_model.py | 2 +- .../models/llm/test_zhipuai_model.py | 6 +- .../model_providers/fake_model_provider.py | 7 +- .../test_base_model_provider.py | 2 +- 61 files changed, 762 insertions(+), 576 deletions(-) create mode 100644 api/controllers/console/app/advanced_prompt_template.py create mode 100644 api/core/prompt/advanced_prompt_templates.py delete mode 100644 api/events/event_handlers/generate_conversation_summary_when_few_message_created.py create mode 100644 api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py create mode 100644 api/services/advanced_prompt_template_service.py delete mode 100644 api/tasks/generate_conversation_summary_task.py diff --git a/api/constants/model_template.py b/api/constants/model_template.py index 3b8fa3fb55daba..c35a0b38d6e1e0 100644 --- a/api/constants/model_template.py +++ b/api/constants/model_template.py @@ -31,6 +31,7 @@ 'model': json.dumps({ "provider": "openai", "name": "gpt-3.5-turbo-instruct", + "mode": "completion", "completion_params": { "max_tokens": 512, "temperature": 1, @@ -81,6 +82,7 @@ 'model': json.dumps({ "provider": "openai", "name": "gpt-3.5-turbo", + "mode": "chat", "completion_params": { "max_tokens": 512, "temperature": 1, @@ -137,10 +139,11 @@ }, opening_statement='', suggested_questions=None, - pre_prompt="Please translate the following text into {{target_language}}:\n", + pre_prompt="Please translate the following text into {{target_language}}:\n{{query}}\ntranslate:", model=json.dumps({ "provider": "openai", "name": "gpt-3.5-turbo-instruct", + "mode": "completion", "completion_params": { "max_tokens": 1000, "temperature": 0, @@ -169,6 +172,13 @@ 'Italian', ] } + },{ + "paragraph": { + "label": "Query", + "variable": "query", + "required": True, + "default": "" + } } ]) ) @@ -200,6 +210,7 @@ model=json.dumps({ "provider": "openai", "name": "gpt-3.5-turbo", + "mode": "chat", "completion_params": { "max_tokens": 300, "temperature": 0.8, @@ -255,10 +266,11 @@ }, opening_statement='', suggested_questions=None, - pre_prompt="请将以下文本翻译为{{target_language}}:\n", + pre_prompt="请将以下文本翻译为{{target_language}}:\n{{query}}\n翻译:", model=json.dumps({ "provider": "openai", "name": "gpt-3.5-turbo-instruct", + "mode": "completion", "completion_params": { "max_tokens": 1000, "temperature": 0, @@ -287,6 +299,13 @@ "意大利语", ] } + },{ + "paragraph": { + "label": "文本内容", + "variable": "query", + "required": True, + "default": "" + } } ]) ) @@ -318,6 +337,7 @@ model=json.dumps({ "provider": "openai", "name": "gpt-3.5-turbo", + "mode": "chat", "completion_params": { "max_tokens": 300, "temperature": 0.8, diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 4834f84555adfb..2476d918870c6f 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -9,7 +9,7 @@ from . import setup, version, apikey, admin # Import app controllers -from .app import app, site, completion, model_config, statistic, conversation, message, generator, audio +from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio # Import auth controllers from .auth import login, oauth, data_source_oauth, activate diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py new file mode 100644 index 00000000000000..ce47e9e4d8793b --- /dev/null +++ b/api/controllers/console/app/advanced_prompt_template.py @@ -0,0 +1,26 @@ +from flask_restful import Resource, reqparse + +from controllers.console import api +from controllers.console.setup import setup_required +from controllers.console.wraps import account_initialization_required +from libs.login import login_required +from services.advanced_prompt_template_service import AdvancedPromptTemplateService + +class AdvancedPromptTemplateList(Resource): + + @setup_required + @login_required + @account_initialization_required + def get(self): + + parser = reqparse.RequestParser() + parser.add_argument('app_mode', type=str, required=True, location='args') + parser.add_argument('model_mode', type=str, required=True, location='args') + parser.add_argument('has_context', type=str, required=False, default='true', location='args') + parser.add_argument('model_name', type=str, required=True, location='args') + args = parser.parse_args() + + service = AdvancedPromptTemplateService() + return service.get_prompt(args) + +api.add_resource(AdvancedPromptTemplateList, '/app/prompt-templates') \ No newline at end of file diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 70275bb70d9028..f454426ab4b29c 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -12,35 +12,6 @@ LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError -class IntroductionGenerateApi(Resource): - @setup_required - @login_required - @account_initialization_required - def post(self): - parser = reqparse.RequestParser() - parser.add_argument('prompt_template', type=str, required=True, location='json') - args = parser.parse_args() - - account = current_user - - try: - answer = LLMGenerator.generate_introduction( - account.current_tenant_id, - args['prompt_template'] - ) - except ProviderTokenNotInitError as ex: - raise ProviderNotInitializeError(ex.description) - except QuotaExceededError: - raise ProviderQuotaExceededError() - except ModelCurrentlyNotSupportError: - raise ProviderModelCurrentlyNotSupportError() - except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, - LLMRateLimitError, LLMAuthorizationError) as e: - raise CompletionRequestError(str(e)) - - return {'introduction': answer} - - class RuleGenerateApi(Resource): @setup_required @login_required @@ -72,5 +43,4 @@ def post(self): return rules -api.add_resource(IntroductionGenerateApi, '/introduction-generate') api.add_resource(RuleGenerateApi, '/rule-generate') diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 1b2765f9bce3ee..d6f9172e57abf8 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -329,7 +329,7 @@ def get(self, app_id, message_id): message_id = str(message_id) # get app info - app_model = _get_app(app_id, 'chat') + app_model = _get_app(app_id) message = db.session.query(Message).filter( Message.id == message_id, diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 9d083f00271ba4..2adc1db45f7508 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -115,7 +115,7 @@ def get(self, app_model, end_user, message_id): streaming = args['response_mode'] == 'streaming' try: - response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming) + response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming, 'web_app') return compact_response(response) except MessageNotExistsError: raise NotFound("Message Not Exists.") diff --git a/api/core/completion.py b/api/core/completion.py index 59d589eabf3975..768231a53d941a 100644 --- a/api/core/completion.py +++ b/api/core/completion.py @@ -1,4 +1,3 @@ -import json import logging from typing import Optional, List, Union @@ -16,10 +15,8 @@ from core.model_providers.models.entity.message import PromptMessage from core.model_providers.models.llm.base import BaseLLM from core.orchestrator_rule_parser import OrchestratorRuleParser -from core.prompt.prompt_builder import PromptBuilder -from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT -from models.dataset import DocumentSegment, Dataset, Document -from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser +from core.prompt.prompt_template import PromptTemplateParser +from models.model import App, AppModelConfig, Account, Conversation, EndUser class Completion: @@ -30,7 +27,7 @@ def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, quer """ errors: ProviderTokenNotInitError """ - query = PromptBuilder.process_template(query) + query = PromptTemplateParser.remove_template_variables(query) memory = None if conversation: @@ -160,14 +157,28 @@ def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: App memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], fake_response: Optional[str]): # get llm prompt - prompt_messages, stop_words = model_instance.get_prompt( - mode=mode, - pre_prompt=app_model_config.pre_prompt, - inputs=inputs, - query=query, - context=agent_execute_result.output if agent_execute_result else None, - memory=memory - ) + if app_model_config.prompt_type == 'simple': + prompt_messages, stop_words = model_instance.get_prompt( + mode=mode, + pre_prompt=app_model_config.pre_prompt, + inputs=inputs, + query=query, + context=agent_execute_result.output if agent_execute_result else None, + memory=memory + ) + else: + prompt_messages = model_instance.get_advanced_prompt( + app_mode=mode, + app_model_config=app_model_config, + inputs=inputs, + query=query, + context=agent_execute_result.output if agent_execute_result else None, + memory=memory + ) + + model_config = app_model_config.model_dict + completion_params = model_config.get("completion_params", {}) + stop_words = completion_params.get("stop", []) cls.recale_llm_max_tokens( model_instance=model_instance, @@ -176,7 +187,7 @@ def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: App response = model_instance.run( messages=prompt_messages, - stop=stop_words, + stop=stop_words if stop_words else None, callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)], fake_response=fake_response ) @@ -266,52 +277,3 @@ def recale_llm_max_tokens(cls, model_instance: BaseLLM, prompt_messages: List[Pr model_kwargs = model_instance.get_model_kwargs() model_kwargs.max_tokens = max_tokens model_instance.set_model_kwargs(model_kwargs) - - @classmethod - def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str, - app_model_config: AppModelConfig, user: Account, streaming: bool): - - final_model_instance = ModelFactory.get_text_generation_model_from_model_config( - tenant_id=app.tenant_id, - model_config=app_model_config.model_dict, - streaming=streaming - ) - - # get llm prompt - old_prompt_messages, _ = final_model_instance.get_prompt( - mode='completion', - pre_prompt=pre_prompt, - inputs=message.inputs, - query=message.query, - context=None, - memory=None - ) - - original_completion = message.answer.strip() - - prompt = MORE_LIKE_THIS_GENERATE_PROMPT - prompt = prompt.format(prompt=old_prompt_messages[0].content, original_completion=original_completion) - - prompt_messages = [PromptMessage(content=prompt)] - - conversation_message_task = ConversationMessageTask( - task_id=task_id, - app=app, - app_model_config=app_model_config, - user=user, - inputs=message.inputs, - query=message.query, - is_override=True if message.override_model_configs else False, - streaming=streaming, - model_instance=final_model_instance - ) - - cls.recale_llm_max_tokens( - model_instance=final_model_instance, - prompt_messages=prompt_messages - ) - - final_model_instance.run( - messages=prompt_messages, - callbacks=[LLMCallbackHandler(final_model_instance, conversation_message_task)] - ) diff --git a/api/core/conversation_message_task.py b/api/core/conversation_message_task.py index ae98f91a889eee..3be6ffaee37bb0 100644 --- a/api/core/conversation_message_task.py +++ b/api/core/conversation_message_task.py @@ -10,7 +10,7 @@ from core.model_providers.models.entity.message import to_prompt_messages, MessageType from core.model_providers.models.llm.base import BaseLLM from core.prompt.prompt_builder import PromptBuilder -from core.prompt.prompt_template import JinjaPromptTemplate +from core.prompt.prompt_template import PromptTemplateParser from events.message_event import message_was_created from extensions.ext_database import db from extensions.ext_redis import redis_client @@ -74,10 +74,10 @@ def init(self): if self.mode == 'chat': introduction = self.app_model_config.opening_statement if introduction: - prompt_template = JinjaPromptTemplate.from_template(template=introduction) - prompt_inputs = {k: self.inputs[k] for k in prompt_template.input_variables if k in self.inputs} + prompt_template = PromptTemplateParser(template=introduction) + prompt_inputs = {k: self.inputs[k] for k in prompt_template.variable_keys if k in self.inputs} try: - introduction = prompt_template.format(**prompt_inputs) + introduction = prompt_template.format(prompt_inputs) except KeyError: pass @@ -150,12 +150,12 @@ def save_message(self, llm_message: LLMMessage, by_stopped: bool = False): message_tokens = llm_message.prompt_tokens answer_tokens = llm_message.completion_tokens - message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN) - message_price_unit = self.model_instance.get_price_unit(MessageType.HUMAN) + message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.USER) + message_price_unit = self.model_instance.get_price_unit(MessageType.USER) answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT) answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT) - message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN) + message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.USER) answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT) total_price = message_total_price + answer_total_price @@ -163,7 +163,7 @@ def save_message(self, llm_message: LLMMessage, by_stopped: bool = False): self.message.message_tokens = message_tokens self.message.message_unit_price = message_unit_price self.message.message_price_unit = message_price_unit - self.message.answer = PromptBuilder.process_template( + self.message.answer = PromptTemplateParser.remove_template_variables( llm_message.completion.strip()) if llm_message.completion else '' self.message.answer_tokens = answer_tokens self.message.answer_unit_price = answer_unit_price @@ -226,15 +226,15 @@ def on_agent_start(self, message_chain: MessageChain, agent_loop: AgentLoop) -> def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM, agent_loop: AgentLoop): - agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.HUMAN) - agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.HUMAN) + agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.USER) + agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.USER) agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT) agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT) loop_message_tokens = agent_loop.prompt_tokens loop_answer_tokens = agent_loop.completion_tokens - loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.HUMAN) + loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.USER) loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT) loop_total_price = loop_message_total_price + loop_answer_total_price diff --git a/api/core/generator/llm_generator.py b/api/core/generator/llm_generator.py index 93208df9605e54..a6699f32d7a127 100644 --- a/api/core/generator/llm_generator.py +++ b/api/core/generator/llm_generator.py @@ -10,9 +10,8 @@ from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser -from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTemplate -from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \ - GENERATOR_QA_PROMPT +from core.prompt.prompt_template import PromptTemplateParser +from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT class LLMGenerator: @@ -44,78 +43,19 @@ def generate_conversation_name(cls, tenant_id: str, query, answer): return answer.strip() - @classmethod - def generate_conversation_summary(cls, tenant_id: str, messages): - max_tokens = 200 - - model_instance = ModelFactory.get_text_generation_model( - tenant_id=tenant_id, - model_kwargs=ModelKwargs( - max_tokens=max_tokens - ) - ) - - prompt = CONVERSATION_SUMMARY_PROMPT - prompt_with_empty_context = prompt.format(context='') - prompt_tokens = model_instance.get_num_tokens([PromptMessage(content=prompt_with_empty_context)]) - max_context_token_length = model_instance.model_rules.max_tokens.max - max_context_token_length = max_context_token_length if max_context_token_length else 1500 - rest_tokens = max_context_token_length - prompt_tokens - max_tokens - 1 - - context = '' - for message in messages: - if not message.answer: - continue - - if len(message.query) > 2000: - query = message.query[:300] + "...[TRUNCATED]..." + message.query[-300:] - else: - query = message.query - - if len(message.answer) > 2000: - answer = message.answer[:300] + "...[TRUNCATED]..." + message.answer[-300:] - else: - answer = message.answer - - message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer - if rest_tokens - model_instance.get_num_tokens([PromptMessage(content=context + message_qa_text)]) > 0: - context += message_qa_text - - if not context: - return '[message too long, no summary]' - - prompt = prompt.format(context=context) - prompts = [PromptMessage(content=prompt)] - response = model_instance.run(prompts) - answer = response.content - return answer.strip() - - @classmethod - def generate_introduction(cls, tenant_id: str, pre_prompt: str): - prompt = INTRODUCTION_GENERATE_PROMPT - prompt = prompt.format(prompt=pre_prompt) - - model_instance = ModelFactory.get_text_generation_model( - tenant_id=tenant_id - ) - - prompts = [PromptMessage(content=prompt)] - response = model_instance.run(prompts) - answer = response.content - return answer.strip() - @classmethod def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str): output_parser = SuggestedQuestionsAfterAnswerOutputParser() format_instructions = output_parser.get_format_instructions() - prompt = JinjaPromptTemplate( - template="{{histories}}\n{{format_instructions}}\nquestions:\n", - input_variables=["histories"], - partial_variables={"format_instructions": format_instructions} + prompt_template = PromptTemplateParser( + template="{{histories}}\n{{format_instructions}}\nquestions:\n" ) - _input = prompt.format_prompt(histories=histories) + prompt = prompt_template.format({ + "histories": histories, + "format_instructions": format_instructions + }) try: model_instance = ModelFactory.get_text_generation_model( @@ -128,10 +68,10 @@ def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: st except ProviderTokenNotInitError: return [] - prompts = [PromptMessage(content=_input.to_string())] + prompt_messages = [PromptMessage(content=prompt)] try: - output = model_instance.run(prompts) + output = model_instance.run(prompt_messages) questions = output_parser.parse(output.content) except LLMError: questions = [] @@ -145,19 +85,21 @@ def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: st def generate_rule_config(cls, tenant_id: str, audiences: str, hoping_to_solve: str) -> dict: output_parser = RuleConfigGeneratorOutputParser() - prompt = OutLinePromptTemplate( - template=output_parser.get_format_instructions(), - input_variables=["audiences", "hoping_to_solve"], - partial_variables={ - "variable": '{variable}', - "lanA": '{lanA}', - "lanB": '{lanB}', - "topic": '{topic}' - }, - validate_template=False + prompt_template = PromptTemplateParser( + template=output_parser.get_format_instructions() ) - _input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve) + prompt = prompt_template.format( + inputs={ + "audiences": audiences, + "hoping_to_solve": hoping_to_solve, + "variable": "{{variable}}", + "lanA": "{{lanA}}", + "lanB": "{{lanB}}", + "topic": "{{topic}}" + }, + remove_template_variables=False + ) model_instance = ModelFactory.get_text_generation_model( tenant_id=tenant_id, @@ -167,10 +109,10 @@ def generate_rule_config(cls, tenant_id: str, audiences: str, hoping_to_solve: s ) ) - prompts = [PromptMessage(content=_input.to_string())] + prompt_messages = [PromptMessage(content=prompt)] try: - output = model_instance.run(prompts) + output = model_instance.run(prompt_messages) rule_config = output_parser.parse(output.content) except LLMError as e: raise e diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 1475c143c293cb..fcf954a9856279 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -286,7 +286,7 @@ def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], "total_segments": total_segments * 20, "tokens": total_segments * 2000, "total_price": '{:f}'.format( - text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)), + text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)), "currency": embedding_model.get_currency(), "qa_preview": document_qa_list, "preview": preview_texts @@ -383,7 +383,7 @@ def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_p "total_segments": total_segments * 20, "tokens": total_segments * 2000, "total_price": '{:f}'.format( - text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)), + text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)), "currency": embedding_model.get_currency(), "qa_preview": document_qa_list, "preview": preview_texts diff --git a/api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py b/api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py index 55d70d38ad2cf0..755df1201a5255 100644 --- a/api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py +++ b/api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py @@ -31,7 +31,7 @@ def buffer(self) -> List[BaseMessage]: chat_messages: List[PromptMessage] = [] for message in messages: - chat_messages.append(PromptMessage(content=message.query, type=MessageType.HUMAN)) + chat_messages.append(PromptMessage(content=message.query, type=MessageType.USER)) chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT)) if not chat_messages: diff --git a/api/core/model_providers/models/entity/message.py b/api/core/model_providers/models/entity/message.py index c37e88fac9ee4e..1ae04d67f51af2 100644 --- a/api/core/model_providers/models/entity/message.py +++ b/api/core/model_providers/models/entity/message.py @@ -13,13 +13,13 @@ class LLMRunResult(BaseModel): class MessageType(enum.Enum): - HUMAN = 'human' + USER = 'user' ASSISTANT = 'assistant' SYSTEM = 'system' class PromptMessage(BaseModel): - type: MessageType = MessageType.HUMAN + type: MessageType = MessageType.USER content: str = '' function_call: dict = None @@ -27,7 +27,7 @@ class PromptMessage(BaseModel): def to_lc_messages(messages: list[PromptMessage]): lc_messages = [] for message in messages: - if message.type == MessageType.HUMAN: + if message.type == MessageType.USER: lc_messages.append(HumanMessage(content=message.content)) elif message.type == MessageType.ASSISTANT: additional_kwargs = {} @@ -44,7 +44,7 @@ def to_prompt_messages(messages: list[BaseMessage]): prompt_messages = [] for message in messages: if isinstance(message, HumanMessage): - prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN)) + prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER)) elif isinstance(message, AIMessage): message_kwargs = { 'content': message.content, @@ -58,7 +58,7 @@ def to_prompt_messages(messages: list[BaseMessage]): elif isinstance(message, SystemMessage): prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM)) elif isinstance(message, FunctionMessage): - prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN)) + prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER)) return prompt_messages diff --git a/api/core/model_providers/models/llm/base.py b/api/core/model_providers/models/llm/base.py index 7224bf714191db..3a6e8b41ca7a52 100644 --- a/api/core/model_providers/models/llm/base.py +++ b/api/core/model_providers/models/llm/base.py @@ -18,7 +18,7 @@ from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules from core.model_providers.providers.base import BaseModelProvider from core.prompt.prompt_builder import PromptBuilder -from core.prompt.prompt_template import JinjaPromptTemplate +from core.prompt.prompt_template import PromptTemplateParser from core.third_party.langchain.llms.fake import FakeLLM import logging @@ -232,7 +232,7 @@ def calc_tokens_price(self, tokens: int, message_type: MessageType) -> decimal.D :param message_type: :return: """ - if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM: + if message_type == MessageType.USER or message_type == MessageType.SYSTEM: unit_price = self.price_config['prompt'] else: unit_price = self.price_config['completion'] @@ -250,7 +250,7 @@ def get_tokens_unit_price(self, message_type: MessageType) -> decimal.Decimal: :param message_type: :return: decimal.Decimal('0.0001') """ - if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM: + if message_type == MessageType.USER or message_type == MessageType.SYSTEM: unit_price = self.price_config['prompt'] else: unit_price = self.price_config['completion'] @@ -265,7 +265,7 @@ def get_price_unit(self, message_type: MessageType) -> decimal.Decimal: :param message_type: :return: decimal.Decimal('0.000001') """ - if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM: + if message_type == MessageType.USER or message_type == MessageType.SYSTEM: price_unit = self.price_config['unit'] else: price_unit = self.price_config['unit'] @@ -330,6 +330,85 @@ def get_prompt(self, mode: str, prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory) return [PromptMessage(content=prompt)], stops + def get_advanced_prompt(self, app_mode: str, + app_model_config: str, inputs: dict, + query: str, + context: Optional[str], + memory: Optional[BaseChatMemory]) -> List[PromptMessage]: + + model_mode = app_model_config.model_dict['mode'] + conversation_histories_role = {} + + raw_prompt_list = [] + prompt_messages = [] + + if app_mode == 'chat' and model_mode == ModelMode.COMPLETION.value: + prompt_text = app_model_config.completion_prompt_config_dict['prompt']['text'] + raw_prompt_list = [{ + 'role': MessageType.USER.value, + 'text': prompt_text + }] + conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role'] + elif app_mode == 'chat' and model_mode == ModelMode.CHAT.value: + raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt'] + elif app_mode == 'completion' and model_mode == ModelMode.CHAT.value: + raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt'] + elif app_mode == 'completion' and model_mode == ModelMode.COMPLETION.value: + prompt_text = app_model_config.completion_prompt_config_dict['prompt']['text'] + raw_prompt_list = [{ + 'role': MessageType.USER.value, + 'text': prompt_text + }] + else: + raise Exception("app_mode or model_mode not support") + + for prompt_item in raw_prompt_list: + prompt = prompt_item['text'] + + # set prompt template variables + prompt_template = PromptTemplateParser(template=prompt) + prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} + + if '#context#' in prompt: + if context: + prompt_inputs['#context#'] = context + else: + prompt_inputs['#context#'] = '' + + if '#query#' in prompt: + if query: + prompt_inputs['#query#'] = query + else: + prompt_inputs['#query#'] = '' + + if '#histories#' in prompt: + if memory and app_mode == 'chat' and model_mode == ModelMode.COMPLETION.value: + memory.human_prefix = conversation_histories_role['user_prefix'] + memory.ai_prefix = conversation_histories_role['assistant_prefix'] + histories = self._get_history_messages_from_memory(memory, 2000) + prompt_inputs['#histories#'] = histories + else: + prompt_inputs['#histories#'] = '' + + prompt = prompt_template.format( + prompt_inputs + ) + + prompt = re.sub(r'<\|.*?\|>', '', prompt) + + prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt)) + + if memory and app_mode == 'chat' and model_mode == ModelMode.CHAT.value: + memory.human_prefix = MessageType.USER.value + memory.ai_prefix = MessageType.ASSISTANT.value + histories = self._get_history_messages_list_from_memory(memory, 2000) + prompt_messages.extend(histories) + + if app_mode == 'chat' and model_mode == ModelMode.CHAT.value: + prompt_messages.append(PromptMessage(type = MessageType.USER ,content=query)) + + return prompt_messages + def prompt_file_name(self, mode: str) -> str: if mode == 'completion': return 'common_completion' @@ -342,17 +421,17 @@ def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]: context_prompt_content = '' if context and 'context_prompt' in prompt_rules: - prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['context_prompt']) + prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt']) context_prompt_content = prompt_template.format( - context=context + {'context': context} ) pre_prompt_content = '' if pre_prompt: - prompt_template = JinjaPromptTemplate.from_template(template=pre_prompt) - prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs} + prompt_template = PromptTemplateParser(template=pre_prompt) + prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} pre_prompt_content = prompt_template.format( - **prompt_inputs + prompt_inputs ) prompt = '' @@ -385,10 +464,8 @@ def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant' histories = self._get_history_messages_from_memory(memory, rest_tokens) - prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['histories_prompt']) - histories_prompt_content = prompt_template.format( - histories=histories - ) + prompt_template = PromptTemplateParser(template=prompt_rules['histories_prompt']) + histories_prompt_content = prompt_template.format({'histories': histories}) prompt = '' for order in prompt_rules['system_prompt_orders']: @@ -399,10 +476,8 @@ def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict elif order == 'histories_prompt': prompt += histories_prompt_content - prompt_template = JinjaPromptTemplate.from_template(template=query_prompt) - query_prompt_content = prompt_template.format( - query=query - ) + prompt_template = PromptTemplateParser(template=query_prompt) + query_prompt_content = prompt_template.format({'query': query}) prompt += query_prompt_content @@ -433,6 +508,16 @@ def _get_history_messages_from_memory(self, memory: BaseChatMemory, external_context = memory.load_memory_variables({}) return external_context[memory_key] + def _get_history_messages_list_from_memory(self, memory: BaseChatMemory, + max_token_limit: int) -> List[PromptMessage]: + """Get memory messages.""" + memory.max_token_limit = max_token_limit + memory.return_messages = True + memory_key = memory.memory_variables[0] + external_context = memory.load_memory_variables({}) + memory.return_messages = False + return to_prompt_messages(external_context[memory_key]) + def _get_prompt_from_messages(self, messages: List[PromptMessage], model_mode: Optional[ModelMode] = None) -> Union[str | List[BaseMessage]]: if not model_mode: diff --git a/api/core/model_providers/providers/anthropic_provider.py b/api/core/model_providers/providers/anthropic_provider.py index 35532b0ec40da8..eab61c60ccc2f0 100644 --- a/api/core/model_providers/providers/anthropic_provider.py +++ b/api/core/model_providers/providers/anthropic_provider.py @@ -9,7 +9,7 @@ from core.helper import encrypter from core.model_providers.models.base import BaseProviderModel -from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule +from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelMode from core.model_providers.models.entity.provider import ModelFeature from core.model_providers.models.llm.anthropic_model import AnthropicModel from core.model_providers.models.llm.base import ModelType @@ -34,10 +34,12 @@ def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: { 'id': 'claude-instant-1', 'name': 'claude-instant-1', + 'mode': ModelMode.CHAT.value, }, { 'id': 'claude-2', 'name': 'claude-2', + 'mode': ModelMode.CHAT.value, 'features': [ ModelFeature.AGENT_THOUGHT.value ] @@ -46,6 +48,9 @@ def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: else: return [] + def _get_text_generation_model_mode(self, model_name) -> str: + return ModelMode.CHAT.value + def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: """ Returns the model class. diff --git a/api/core/model_providers/providers/azure_openai_provider.py b/api/core/model_providers/providers/azure_openai_provider.py index 4f7c8b717c2e78..a34b463286a76a 100644 --- a/api/core/model_providers/providers/azure_openai_provider.py +++ b/api/core/model_providers/providers/azure_openai_provider.py @@ -12,7 +12,7 @@ from core.model_providers.models.base import BaseProviderModel from core.model_providers.models.embedding.azure_openai_embedding import AzureOpenAIEmbedding, \ AZURE_OPENAI_API_VERSION -from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, KwargRule +from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, KwargRule, ModelMode from core.model_providers.models.entity.provider import ModelFeature from core.model_providers.models.llm.azure_openai_model import AzureOpenAIModel from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError @@ -61,6 +61,10 @@ def get_supported_model_list(self, model_type: ModelType) -> list[dict]: } credentials = json.loads(provider_model.encrypted_config) + + if provider_model.model_type == ModelType.TEXT_GENERATION.value: + model_dict['mode'] = self._get_text_generation_model_mode(credentials['base_model_name']) + if credentials['base_model_name'] in [ 'gpt-4', 'gpt-4-32k', @@ -77,12 +81,19 @@ def get_supported_model_list(self, model_type: ModelType) -> list[dict]: return model_list + def _get_text_generation_model_mode(self, model_name) -> str: + if model_name == 'text-davinci-003': + return ModelMode.COMPLETION.value + else: + return ModelMode.CHAT.value + def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: if model_type == ModelType.TEXT_GENERATION: models = [ { 'id': 'gpt-3.5-turbo', 'name': 'gpt-3.5-turbo', + 'mode': ModelMode.CHAT.value, 'features': [ ModelFeature.AGENT_THOUGHT.value ] @@ -90,6 +101,7 @@ def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: { 'id': 'gpt-3.5-turbo-16k', 'name': 'gpt-3.5-turbo-16k', + 'mode': ModelMode.CHAT.value, 'features': [ ModelFeature.AGENT_THOUGHT.value ] @@ -97,6 +109,7 @@ def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: { 'id': 'gpt-4', 'name': 'gpt-4', + 'mode': ModelMode.CHAT.value, 'features': [ ModelFeature.AGENT_THOUGHT.value ] @@ -104,6 +117,7 @@ def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: { 'id': 'gpt-4-32k', 'name': 'gpt-4-32k', + 'mode': ModelMode.CHAT.value, 'features': [ ModelFeature.AGENT_THOUGHT.value ] @@ -111,6 +125,7 @@ def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: { 'id': 'text-davinci-003', 'name': 'text-davinci-003', + 'mode': ModelMode.COMPLETION.value, } ] diff --git a/api/core/model_providers/providers/baichuan_provider.py b/api/core/model_providers/providers/baichuan_provider.py index 12c475f92daf2c..784c9df2c6582b 100644 --- a/api/core/model_providers/providers/baichuan_provider.py +++ b/api/core/model_providers/providers/baichuan_provider.py @@ -6,7 +6,7 @@ from core.helper import encrypter from core.model_providers.models.base import BaseProviderModel -from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType +from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode from core.model_providers.models.llm.baichuan_model import BaichuanModel from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM @@ -21,6 +21,9 @@ def provider_name(self): Returns the name of a provider. """ return 'baichuan' + + def _get_text_generation_model_mode(self, model_name) -> str: + return ModelMode.CHAT.value def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: if model_type == ModelType.TEXT_GENERATION: @@ -28,6 +31,7 @@ def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: { 'id': 'baichuan2-53b', 'name': 'Baichuan2-53B', + 'mode': ModelMode.CHAT.value, } ] else: diff --git a/api/core/model_providers/providers/base.py b/api/core/model_providers/providers/base.py index f10aa9f99d95aa..9b05b4f5fd63b2 100644 --- a/api/core/model_providers/providers/base.py +++ b/api/core/model_providers/providers/base.py @@ -61,10 +61,19 @@ def get_supported_model_list(self, model_type: ModelType) -> list[dict]: ProviderModel.is_valid == True ).order_by(ProviderModel.created_at.asc()).all() - return [{ - 'id': provider_model.model_name, - 'name': provider_model.model_name - } for provider_model in provider_models] + provider_model_list = [] + for provider_model in provider_models: + provider_model_dict = { + 'id': provider_model.model_name, + 'name': provider_model.model_name + } + + if model_type == ModelType.TEXT_GENERATION: + provider_model_dict['mode'] = self._get_text_generation_model_mode(provider_model.model_name) + + provider_model_list.append(provider_model_dict) + + return provider_model_list @abstractmethod def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: @@ -76,6 +85,16 @@ def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: """ raise NotImplementedError + @abstractmethod + def _get_text_generation_model_mode(self, model_name) -> str: + """ + get text generation model mode. + + :param model_name: + :return: + """ + raise NotImplementedError + @abstractmethod def get_model_class(self, model_type: ModelType) -> Type: """ diff --git a/api/core/model_providers/providers/chatglm_provider.py b/api/core/model_providers/providers/chatglm_provider.py index 4b2a46ad428c05..d3c83e37ce5dfa 100644 --- a/api/core/model_providers/providers/chatglm_provider.py +++ b/api/core/model_providers/providers/chatglm_provider.py @@ -6,7 +6,7 @@ from core.helper import encrypter from core.model_providers.models.base import BaseProviderModel -from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType +from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode from core.model_providers.models.llm.chatglm_model import ChatGLMModel from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from models.provider import ProviderType @@ -27,15 +27,20 @@ def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: { 'id': 'chatglm2-6b', 'name': 'ChatGLM2-6B', + 'mode': ModelMode.COMPLETION.value, }, { 'id': 'chatglm-6b', 'name': 'ChatGLM-6B', + 'mode': ModelMode.COMPLETION.value, } ] else: return [] + def _get_text_generation_model_mode(self, model_name) -> str: + return ModelMode.COMPLETION.value + def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: """ Returns the model class. diff --git a/api/core/model_providers/providers/huggingface_hub_provider.py b/api/core/model_providers/providers/huggingface_hub_provider.py index deae4e35df3abe..2cb7ff120a84c8 100644 --- a/api/core/model_providers/providers/huggingface_hub_provider.py +++ b/api/core/model_providers/providers/huggingface_hub_provider.py @@ -5,7 +5,7 @@ from huggingface_hub import HfApi from core.helper import encrypter -from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType +from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError @@ -29,6 +29,9 @@ def provider_name(self): def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: return [] + def _get_text_generation_model_mode(self, model_name) -> str: + return ModelMode.COMPLETION.value + def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: """ Returns the model class. diff --git a/api/core/model_providers/providers/localai_provider.py b/api/core/model_providers/providers/localai_provider.py index f5b07b1e6cc7b1..89279996f8a50f 100644 --- a/api/core/model_providers/providers/localai_provider.py +++ b/api/core/model_providers/providers/localai_provider.py @@ -6,7 +6,7 @@ from core.helper import encrypter from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding -from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule +from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule, ModelMode from core.model_providers.models.llm.localai_model import LocalAIModel from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError @@ -27,6 +27,13 @@ def provider_name(self): def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: return [] + def _get_text_generation_model_mode(self, model_name) -> str: + credentials = self.get_model_credentials(model_name, ModelType.TEXT_GENERATION) + if credentials['completion_type'] == 'chat_completion': + return ModelMode.CHAT.value + else: + return ModelMode.COMPLETION.value + def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: """ Returns the model class. diff --git a/api/core/model_providers/providers/minimax_provider.py b/api/core/model_providers/providers/minimax_provider.py index c13165d6026e75..f643e1e805d262 100644 --- a/api/core/model_providers/providers/minimax_provider.py +++ b/api/core/model_providers/providers/minimax_provider.py @@ -7,7 +7,7 @@ from core.helper import encrypter from core.model_providers.models.base import BaseProviderModel from core.model_providers.models.embedding.minimax_embedding import MinimaxEmbedding -from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType +from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode from core.model_providers.models.llm.minimax_model import MinimaxModel from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.third_party.langchain.llms.minimax_llm import MinimaxChatLLM @@ -29,10 +29,12 @@ def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: { 'id': 'abab5.5-chat', 'name': 'abab5.5-chat', + 'mode': ModelMode.COMPLETION.value, }, { 'id': 'abab5-chat', 'name': 'abab5-chat', + 'mode': ModelMode.COMPLETION.value, } ] elif model_type == ModelType.EMBEDDINGS: @@ -45,6 +47,9 @@ def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: else: return [] + def _get_text_generation_model_mode(self, model_name) -> str: + return ModelMode.COMPLETION.value + def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: """ Returns the model class. diff --git a/api/core/model_providers/providers/openai_provider.py b/api/core/model_providers/providers/openai_provider.py index 01b2adcedd7964..de5de280252c11 100644 --- a/api/core/model_providers/providers/openai_provider.py +++ b/api/core/model_providers/providers/openai_provider.py @@ -13,8 +13,8 @@ from core.model_providers.models.speech2text.openai_whisper import OpenAIWhisper from core.model_providers.models.base import BaseProviderModel from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding -from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType -from core.model_providers.models.llm.openai_model import OpenAIModel +from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode +from core.model_providers.models.llm.openai_model import OpenAIModel, COMPLETION_MODELS from core.model_providers.models.moderation.openai_moderation import OpenAIModeration from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.model_providers.providers.hosted import hosted_model_providers @@ -36,6 +36,7 @@ def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: { 'id': 'gpt-3.5-turbo', 'name': 'gpt-3.5-turbo', + 'mode': ModelMode.CHAT.value, 'features': [ ModelFeature.AGENT_THOUGHT.value ] @@ -43,10 +44,12 @@ def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: { 'id': 'gpt-3.5-turbo-instruct', 'name': 'GPT-3.5-Turbo-Instruct', + 'mode': ModelMode.COMPLETION.value, }, { 'id': 'gpt-3.5-turbo-16k', 'name': 'gpt-3.5-turbo-16k', + 'mode': ModelMode.CHAT.value, 'features': [ ModelFeature.AGENT_THOUGHT.value ] @@ -54,6 +57,7 @@ def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: { 'id': 'gpt-4', 'name': 'gpt-4', + 'mode': ModelMode.CHAT.value, 'features': [ ModelFeature.AGENT_THOUGHT.value ] @@ -61,6 +65,7 @@ def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: { 'id': 'gpt-4-32k', 'name': 'gpt-4-32k', + 'mode': ModelMode.CHAT.value, 'features': [ ModelFeature.AGENT_THOUGHT.value ] @@ -68,6 +73,7 @@ def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: { 'id': 'text-davinci-003', 'name': 'text-davinci-003', + 'mode': ModelMode.COMPLETION.value, } ] @@ -100,6 +106,12 @@ def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: else: return [] + def _get_text_generation_model_mode(self, model_name) -> str: + if model_name in COMPLETION_MODELS: + return ModelMode.COMPLETION.value + else: + return ModelMode.CHAT.value + def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: """ Returns the model class. diff --git a/api/core/model_providers/providers/openllm_provider.py b/api/core/model_providers/providers/openllm_provider.py index a691507b9f69b6..ea0e0b860d9f3e 100644 --- a/api/core/model_providers/providers/openllm_provider.py +++ b/api/core/model_providers/providers/openllm_provider.py @@ -3,7 +3,7 @@ from core.helper import encrypter from core.model_providers.models.embedding.openllm_embedding import OpenLLMEmbedding -from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType +from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode from core.model_providers.models.llm.openllm_model import OpenLLMModel from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError @@ -24,6 +24,9 @@ def provider_name(self): def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: return [] + def _get_text_generation_model_mode(self, model_name) -> str: + return ModelMode.COMPLETION.value + def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: """ Returns the model class. diff --git a/api/core/model_providers/providers/replicate_provider.py b/api/core/model_providers/providers/replicate_provider.py index 9324d432a41058..be9a7aa7aebe3c 100644 --- a/api/core/model_providers/providers/replicate_provider.py +++ b/api/core/model_providers/providers/replicate_provider.py @@ -6,7 +6,8 @@ from replicate.exceptions import ReplicateError from core.helper import encrypter -from core.model_providers.models.entity.model_params import KwargRule, KwargRuleType, ModelKwargsRules, ModelType +from core.model_providers.models.entity.model_params import KwargRule, KwargRuleType, ModelKwargsRules, ModelType, \ + ModelMode from core.model_providers.models.llm.replicate_model import ReplicateModel from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError @@ -26,6 +27,9 @@ def provider_name(self): def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: return [] + def _get_text_generation_model_mode(self, model_name) -> str: + return ModelMode.CHAT.value if model_name.endswith('-chat') else ModelMode.COMPLETION.value + def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: """ Returns the model class. diff --git a/api/core/model_providers/providers/spark_provider.py b/api/core/model_providers/providers/spark_provider.py index 89ed5d30b708ee..4174c0116352ad 100644 --- a/api/core/model_providers/providers/spark_provider.py +++ b/api/core/model_providers/providers/spark_provider.py @@ -7,7 +7,7 @@ from core.helper import encrypter from core.model_providers.models.base import BaseProviderModel -from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType +from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode from core.model_providers.models.llm.spark_model import SparkModel from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.third_party.langchain.llms.spark import ChatSpark @@ -30,15 +30,20 @@ def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: { 'id': 'spark', 'name': 'Spark V1.5', + 'mode': ModelMode.CHAT.value, }, { 'id': 'spark-v2', 'name': 'Spark V2.0', + 'mode': ModelMode.CHAT.value, } ] else: return [] + def _get_text_generation_model_mode(self, model_name) -> str: + return ModelMode.CHAT.value + def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: """ Returns the model class. diff --git a/api/core/model_providers/providers/tongyi_provider.py b/api/core/model_providers/providers/tongyi_provider.py index d48b4447f8a457..49ff731ac5cb5a 100644 --- a/api/core/model_providers/providers/tongyi_provider.py +++ b/api/core/model_providers/providers/tongyi_provider.py @@ -4,7 +4,7 @@ from core.helper import encrypter from core.model_providers.models.base import BaseProviderModel -from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType +from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode from core.model_providers.models.llm.tongyi_model import TongyiModel from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.third_party.langchain.llms.tongyi_llm import EnhanceTongyi @@ -26,15 +26,20 @@ def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: { 'id': 'qwen-turbo', 'name': 'qwen-turbo', + 'mode': ModelMode.COMPLETION.value, }, { 'id': 'qwen-plus', 'name': 'qwen-plus', + 'mode': ModelMode.COMPLETION.value, } ] else: return [] + def _get_text_generation_model_mode(self, model_name) -> str: + return ModelMode.COMPLETION.value + def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: """ Returns the model class. diff --git a/api/core/model_providers/providers/wenxin_provider.py b/api/core/model_providers/providers/wenxin_provider.py index d6d18163233288..e729358c0a99d7 100644 --- a/api/core/model_providers/providers/wenxin_provider.py +++ b/api/core/model_providers/providers/wenxin_provider.py @@ -4,7 +4,7 @@ from core.helper import encrypter from core.model_providers.models.base import BaseProviderModel -from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType +from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode from core.model_providers.models.llm.wenxin_model import WenxinModel from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.third_party.langchain.llms.wenxin import Wenxin @@ -26,19 +26,25 @@ def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: { 'id': 'ernie-bot', 'name': 'ERNIE-Bot', + 'mode': ModelMode.COMPLETION.value, }, { 'id': 'ernie-bot-turbo', 'name': 'ERNIE-Bot-turbo', + 'mode': ModelMode.COMPLETION.value, }, { 'id': 'bloomz-7b', 'name': 'BLOOMZ-7B', + 'mode': ModelMode.COMPLETION.value, } ] else: return [] + def _get_text_generation_model_mode(self, model_name) -> str: + return ModelMode.COMPLETION.value + def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: """ Returns the model class. diff --git a/api/core/model_providers/providers/xinference_provider.py b/api/core/model_providers/providers/xinference_provider.py index f56c5fb59dc7c4..fff0119eaf09d6 100644 --- a/api/core/model_providers/providers/xinference_provider.py +++ b/api/core/model_providers/providers/xinference_provider.py @@ -6,7 +6,7 @@ from core.helper import encrypter from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding -from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType +from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode from core.model_providers.models.llm.xinference_model import XinferenceModel from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError @@ -26,6 +26,9 @@ def provider_name(self): def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: return [] + def _get_text_generation_model_mode(self, model_name) -> str: + return ModelMode.COMPLETION.value + def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: """ Returns the model class. diff --git a/api/core/model_providers/providers/zhipuai_provider.py b/api/core/model_providers/providers/zhipuai_provider.py index 0f7dae5f4fab24..9b56851688b8f4 100644 --- a/api/core/model_providers/providers/zhipuai_provider.py +++ b/api/core/model_providers/providers/zhipuai_provider.py @@ -7,7 +7,7 @@ from core.helper import encrypter from core.model_providers.models.base import BaseProviderModel from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding -from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType +from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM @@ -29,18 +29,22 @@ def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: { 'id': 'chatglm_pro', 'name': 'chatglm_pro', + 'mode': ModelMode.CHAT.value, }, { 'id': 'chatglm_std', 'name': 'chatglm_std', + 'mode': ModelMode.CHAT.value, }, { 'id': 'chatglm_lite', 'name': 'chatglm_lite', + 'mode': ModelMode.CHAT.value, }, { 'id': 'chatglm_lite_32k', 'name': 'chatglm_lite_32k', + 'mode': ModelMode.CHAT.value, } ] elif model_type == ModelType.EMBEDDINGS: @@ -53,6 +57,9 @@ def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: else: return [] + def _get_text_generation_model_mode(self, model_name) -> str: + return ModelMode.CHAT.value + def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: """ Returns the model class. diff --git a/api/core/orchestrator_rule_parser.py b/api/core/orchestrator_rule_parser.py index f359cf82fd2629..2ba732ee3dd613 100644 --- a/api/core/orchestrator_rule_parser.py +++ b/api/core/orchestrator_rule_parser.py @@ -1,4 +1,3 @@ -import math from typing import Optional from langchain import WikipediaAPIWrapper @@ -50,6 +49,7 @@ def to_agent_executor(self, conversation_message_task: ConversationMessageTask, tool_configs = agent_mode_config.get('tools', []) agent_provider_name = model_dict.get('provider', 'openai') agent_model_name = model_dict.get('name', 'gpt-4') + dataset_configs = self.app_model_config.dataset_configs_dict agent_model_instance = ModelFactory.get_text_generation_model( tenant_id=self.tenant_id, @@ -96,13 +96,14 @@ def to_agent_executor(self, conversation_message_task: ConversationMessageTask, summary_model_instance = None tools = self.to_tools( - agent_model_instance=agent_model_instance, tool_configs=tool_configs, + callbacks=[agent_callback, DifyStdOutCallbackHandler()], + agent_model_instance=agent_model_instance, conversation_message_task=conversation_message_task, rest_tokens=rest_tokens, - callbacks=[agent_callback, DifyStdOutCallbackHandler()], return_resource=return_resource, - retriever_from=retriever_from + retriever_from=retriever_from, + dataset_configs=dataset_configs ) if len(tools) == 0: @@ -170,20 +171,12 @@ def to_sensitive_word_avoidance_chain(self, model_instance: BaseLLM, callbacks: return None - def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list, - conversation_message_task: ConversationMessageTask, - rest_tokens: int, callbacks: Callbacks = None, return_resource: bool = False, - retriever_from: str = 'dev') -> list[BaseTool]: + def to_tools(self, tool_configs: list, callbacks: Callbacks = None, **kwargs) -> list[BaseTool]: """ Convert app agent tool configs to tools - :param agent_model_instance: - :param rest_tokens: :param tool_configs: app agent tool configs - :param conversation_message_task: :param callbacks: - :param return_resource: - :param retriever_from: :return: """ tools = [] @@ -195,15 +188,15 @@ def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list, tool = None if tool_type == "dataset": - tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens, return_resource, retriever_from) + tool = self.to_dataset_retriever_tool(tool_config=tool_val, **kwargs) elif tool_type == "web_reader": - tool = self.to_web_reader_tool(agent_model_instance) + tool = self.to_web_reader_tool(tool_config=tool_val, **kwargs) elif tool_type == "google_search": - tool = self.to_google_search_tool() + tool = self.to_google_search_tool(tool_config=tool_val, **kwargs) elif tool_type == "wikipedia": - tool = self.to_wikipedia_tool() + tool = self.to_wikipedia_tool(tool_config=tool_val, **kwargs) elif tool_type == "current_datetime": - tool = self.to_current_datetime_tool() + tool = self.to_current_datetime_tool(tool_config=tool_val, **kwargs) if tool: if tool.callbacks is not None: @@ -215,12 +208,15 @@ def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list, return tools def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask, - rest_tokens: int, return_resource: bool = False, retriever_from: str = 'dev') \ + dataset_configs: dict, rest_tokens: int, + return_resource: bool = False, retriever_from: str = 'dev', + **kwargs) \ -> Optional[BaseTool]: """ A dataset tool is a tool that can be used to retrieve information from a dataset :param rest_tokens: :param tool_config: + :param dataset_configs: :param conversation_message_task: :param return_resource: :param retriever_from: @@ -238,10 +234,20 @@ def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0: return None - k = self._dynamic_calc_retrieve_k(dataset, rest_tokens) + top_k = dataset_configs.get("top_k", 2) + + # dynamically adjust top_k when the remaining token number is not enough to support top_k + top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens) + + score_threshold = None + score_threshold_config = dataset_configs.get("score_threshold") + if score_threshold_config and score_threshold_config.get("enable"): + score_threshold = score_threshold_config.get("value") + tool = DatasetRetrieverTool.from_dataset( dataset=dataset, - k=k, + top_k=top_k, + score_threshold=score_threshold, callbacks=[DatasetToolCallbackHandler(conversation_message_task)], conversation_message_task=conversation_message_task, return_resource=return_resource, @@ -250,7 +256,7 @@ def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task return tool - def to_web_reader_tool(self, agent_model_instance: BaseLLM) -> Optional[BaseTool]: + def to_web_reader_tool(self, tool_config: dict, agent_model_instance: BaseLLM, **kwargs) -> Optional[BaseTool]: """ A tool for reading web pages @@ -278,7 +284,7 @@ def to_web_reader_tool(self, agent_model_instance: BaseLLM) -> Optional[BaseTool return tool - def to_google_search_tool(self) -> Optional[BaseTool]: + def to_google_search_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]: tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id) func_kwargs = tool_provider.credentials_to_func_kwargs() if not func_kwargs: @@ -296,12 +302,12 @@ def to_google_search_tool(self) -> Optional[BaseTool]: return tool - def to_current_datetime_tool(self) -> Optional[BaseTool]: + def to_current_datetime_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]: tool = DatetimeTool() return tool - def to_wikipedia_tool(self) -> Optional[BaseTool]: + def to_wikipedia_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]: class WikipediaInput(BaseModel): query: str = Field(..., description="search query.") @@ -312,22 +318,18 @@ class WikipediaInput(BaseModel): ) @classmethod - def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int: - DEFAULT_K = 2 - CONTEXT_TOKENS_PERCENT = 0.3 - MAX_K = 10 - + def _dynamic_calc_retrieve_k(cls, dataset: Dataset, top_k: int, rest_tokens: int) -> int: if rest_tokens == -1: - return DEFAULT_K + return top_k processing_rule = dataset.latest_process_rule if not processing_rule: - return DEFAULT_K + return top_k if processing_rule.mode == "custom": rules = processing_rule.rules_dict if not rules: - return DEFAULT_K + return top_k segmentation = rules["segmentation"] segment_max_tokens = segmentation["max_tokens"] @@ -335,14 +337,7 @@ def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int: segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'] # when rest_tokens is less than default context tokens - if rest_tokens < segment_max_tokens * DEFAULT_K: + if rest_tokens < segment_max_tokens * top_k: return rest_tokens // segment_max_tokens - context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT) - - # when context_limit_tokens is less than default context tokens, use default_k - if context_limit_tokens <= segment_max_tokens * DEFAULT_K: - return DEFAULT_K - - # Expand the k value when there's still some room left in the 30% rest tokens space, but less than the MAX_K - return min(context_limit_tokens // segment_max_tokens, MAX_K) + return min(top_k, 10) diff --git a/api/core/prompt/advanced_prompt_templates.py b/api/core/prompt/advanced_prompt_templates.py new file mode 100644 index 00000000000000..c5eee005b6faab --- /dev/null +++ b/api/core/prompt/advanced_prompt_templates.py @@ -0,0 +1,79 @@ +CONTEXT = "Use the following context as your learned knowledge, inside XML tags.\n\n\n{{#context#}}\n\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n" + +BAICHUAN_CONTEXT = "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n\n" + +CHAT_APP_COMPLETION_PROMPT_CONFIG = { + "completion_prompt_config": { + "prompt": { + "text": "{{#pre_prompt#}}\nHere is the chat histories between human and assistant, inside XML tags.\n\n\n{{#histories#}}\n\n\n\nHuman: {{#query#}}\n\nAssistant: " + }, + "conversation_histories_role": { + "user_prefix": "Human", + "assistant_prefix": "Assistant" + } + } +} + +CHAT_APP_CHAT_PROMPT_CONFIG = { + "chat_prompt_config": { + "prompt": [{ + "role": "system", + "text": "{{#pre_prompt#}}" + }] + } +} + +COMPLETION_APP_CHAT_PROMPT_CONFIG = { + "chat_prompt_config": { + "prompt": [{ + "role": "user", + "text": "{{#pre_prompt#}}" + }] + } +} + +COMPLETION_APP_COMPLETION_PROMPT_CONFIG = { + "completion_prompt_config": { + "prompt": { + "text": "{{#pre_prompt#}}" + } + } +} + +BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG = { + "completion_prompt_config": { + "prompt": { + "text": "{{#pre_prompt#}}\n\n用户和助手的历史对话内容如下:\n```\n{{#histories#}}\n```\n\n\n\n用户:{{#query#}}" + }, + "conversation_histories_role": { + "user_prefix": "用户", + "assistant_prefix": "助手" + } + } +} + +BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = { + "chat_prompt_config": { + "prompt": [{ + "role": "system", + "text": "{{#pre_prompt#}}" + }] + } +} + +BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG = { + "chat_prompt_config": { + "prompt": [{ + "role": "user", + "text": "{{#pre_prompt#}}" + }] + } +} + +BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG = { + "completion_prompt_config": { + "prompt": { + "text": "{{#pre_prompt#}}" + } + } +} diff --git a/api/core/prompt/prompt_builder.py b/api/core/prompt/prompt_builder.py index 073cf2ce2567d8..cc2a11a78f1c5b 100644 --- a/api/core/prompt/prompt_builder.py +++ b/api/core/prompt/prompt_builder.py @@ -1,38 +1,24 @@ -import re +from langchain.schema import BaseMessage, SystemMessage, AIMessage, HumanMessage -from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, AIMessagePromptTemplate -from langchain.schema import BaseMessage - -from core.prompt.prompt_template import JinjaPromptTemplate +from core.prompt.prompt_template import PromptTemplateParser class PromptBuilder: + @classmethod + def parse_prompt(cls, prompt: str, inputs: dict) -> str: + prompt_template = PromptTemplateParser(prompt) + prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs} + prompt = prompt_template.format(prompt_inputs) + return prompt + @classmethod def to_system_message(cls, prompt_content: str, inputs: dict) -> BaseMessage: - prompt_template = JinjaPromptTemplate.from_template(prompt_content) - system_prompt_template = SystemMessagePromptTemplate(prompt=prompt_template) - prompt_inputs = {k: inputs[k] for k in system_prompt_template.input_variables if k in inputs} - system_message = system_prompt_template.format(**prompt_inputs) - return system_message + return SystemMessage(content=cls.parse_prompt(prompt_content, inputs)) @classmethod def to_ai_message(cls, prompt_content: str, inputs: dict) -> BaseMessage: - prompt_template = JinjaPromptTemplate.from_template(prompt_content) - ai_prompt_template = AIMessagePromptTemplate(prompt=prompt_template) - prompt_inputs = {k: inputs[k] for k in ai_prompt_template.input_variables if k in inputs} - ai_message = ai_prompt_template.format(**prompt_inputs) - return ai_message + return AIMessage(content=cls.parse_prompt(prompt_content, inputs)) @classmethod def to_human_message(cls, prompt_content: str, inputs: dict) -> BaseMessage: - prompt_template = JinjaPromptTemplate.from_template(prompt_content) - human_prompt_template = HumanMessagePromptTemplate(prompt=prompt_template) - human_message = human_prompt_template.format(**inputs) - return human_message - - @classmethod - def process_template(cls, template: str): - processed_template = re.sub(r'\{{2}(.+)\}{2}', r'{\1}', template) - # processed_template = re.sub(r'\{([a-zA-Z_]\w+?)\}', r'\1', template) - # processed_template = re.sub(r'\{\{([a-zA-Z_]\w+?)\}\}', r'{\1}', processed_template) - return processed_template + return HumanMessage(content=cls.parse_prompt(prompt_content, inputs)) diff --git a/api/core/prompt/prompt_template.py b/api/core/prompt/prompt_template.py index c51c4700c1c76a..fbf09d2c6424f7 100644 --- a/api/core/prompt/prompt_template.py +++ b/api/core/prompt/prompt_template.py @@ -1,79 +1,39 @@ import re -from typing import Any -from jinja2 import Environment, meta -from langchain import PromptTemplate -from langchain.formatting import StrictFormatter +REGEX = re.compile(r"\{\{([a-zA-Z_][a-zA-Z0-9_]{1,29}|#histories#|#query#|#context#)\}\}") -class JinjaPromptTemplate(PromptTemplate): - template_format: str = "jinja2" - """The format of the prompt template. Options are: 'f-string', 'jinja2'.""" +class PromptTemplateParser: + """ + Rules: - @classmethod - def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate: - """Load a prompt template from a template.""" - env = Environment() - template = template.replace("{{}}", "{}") - ast = env.parse(template) - input_variables = meta.find_undeclared_variables(ast) - - if "partial_variables" in kwargs: - partial_variables = kwargs["partial_variables"] - input_variables = { - var for var in input_variables if var not in partial_variables - } - - return cls( - input_variables=list(sorted(input_variables)), template=template, **kwargs - ) - - -class OutLinePromptTemplate(PromptTemplate): - @classmethod - def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate: - """Load a prompt template from a template.""" - input_variables = { - v for _, v, _, _ in OneLineFormatter().parse(template) if v is not None - } - return cls( - input_variables=list(sorted(input_variables)), template=template, **kwargs - ) - - def format(self, **kwargs: Any) -> str: - """Format the prompt with the inputs. + 1. Template variables must be enclosed in `{{}}`. + 2. The template variable Key can only be: letters + numbers + underscore, with a maximum length of 16 characters, + and can only start with letters and underscores. + 3. The template variable Key cannot contain new lines or spaces, and must comply with rule 2. + 4. In addition to the above, 3 types of special template variable Keys are accepted: + `{{#histories#}}` `{{#query#}}` `{{#context#}}`. No other `{{##}}` template variables are allowed. + """ - Args: - kwargs: Any arguments to be passed to the prompt template. + def __init__(self, template: str): + self.template = template + self.variable_keys = self.extract() - Returns: - A formatted string. + def extract(self) -> list: + # Regular expression to match the template rules + return re.findall(REGEX, self.template) - Example: + def format(self, inputs: dict, remove_template_variables: bool = True) -> str: + def replacer(match): + key = match.group(1) + value = inputs.get(key, match.group(0)) # return original matched string if key not found - .. code-block:: python + if remove_template_variables: + return PromptTemplateParser.remove_template_variables(value) + return value - prompt.format(variable1="foo") - """ - kwargs = self._merge_partial_and_user_variables(**kwargs) - return OneLineFormatter().format(self.template, **kwargs) + return re.sub(REGEX, replacer, self.template) - -class OneLineFormatter(StrictFormatter): - def parse(self, format_string): - last_end = 0 - results = [] - for match in re.finditer(r"{([a-zA-Z_]\w*)}", format_string): - field_name = match.group(1) - start, end = match.span() - - literal_text = format_string[last_end:start] - last_end = end - - results.append((literal_text, field_name, '', None)) - - remaining_literal_text = format_string[last_end:] - if remaining_literal_text: - results.append((remaining_literal_text, None, None, None)) - - return results + @classmethod + def remove_template_variables(cls, text: str): + return re.sub(REGEX, r'{\1}', text) diff --git a/api/core/prompt/prompts.py b/api/core/prompt/prompts.py index 979fe9be96b8c8..44cf954d3e9b84 100644 --- a/api/core/prompt/prompts.py +++ b/api/core/prompt/prompts.py @@ -61,36 +61,6 @@ User Input: """ -CONVERSATION_SUMMARY_PROMPT = ( - "Please generate a short summary of the following conversation.\n" - "If the following conversation communicating in English, you should only return an English summary.\n" - "If the following conversation communicating in Chinese, you should only return a Chinese summary.\n" - "[Conversation Start]\n" - "{context}\n" - "[Conversation End]\n\n" - "summary:" -) - -INTRODUCTION_GENERATE_PROMPT = ( - "I am designing a product for users to interact with an AI through dialogue. " - "The Prompt given to the AI before the conversation is:\n\n" - "```\n{prompt}\n```\n\n" - "Please generate a brief introduction of no more than 50 words that greets the user, based on this Prompt. " - "Do not reveal the developer's motivation or deep logic behind the Prompt, " - "but focus on building a relationship with the user:\n" -) - -MORE_LIKE_THIS_GENERATE_PROMPT = ( - "-----\n" - "{original_completion}\n" - "-----\n\n" - "Please use the above content as a sample for generating the result, " - "and include key information points related to the original sample in the result. " - "Try to rephrase this information in different ways and predict according to the rules below.\n\n" - "-----\n" - "{prompt}\n" -) - SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = ( "Please help me predict the three most likely questions that human would ask, " "and keeping each question under 20 characters.\n" @@ -157,10 +127,10 @@ ``` << MY INTENDED AUDIENCES >> -{audiences} +{{audiences}} << HOPING TO SOLVE >> -{hoping_to_solve} +{{hoping_to_solve}} << OUTPUT >> """ \ No newline at end of file diff --git a/api/core/tool/dataset_retriever_tool.py b/api/core/tool/dataset_retriever_tool.py index 33fec157eaa36b..2c14f40d15bfd1 100644 --- a/api/core/tool/dataset_retriever_tool.py +++ b/api/core/tool/dataset_retriever_tool.py @@ -1,5 +1,5 @@ import json -from typing import Type +from typing import Type, Optional from flask import current_app from langchain.tools import BaseTool @@ -28,7 +28,8 @@ class DatasetRetrieverTool(BaseTool): tenant_id: str dataset_id: str - k: int = 3 + top_k: int = 2 + score_threshold: Optional[float] = None conversation_message_task: ConversationMessageTask return_resource: bool retriever_from: str @@ -66,7 +67,7 @@ def _run(self, query: str) -> str: ) ) - documents = kw_table_index.search(query, search_kwargs={'k': self.k}) + documents = kw_table_index.search(query, search_kwargs={'k': self.top_k}) return str("\n".join([document.page_content for document in documents])) else: @@ -80,20 +81,21 @@ def _run(self, query: str) -> str: return '' except ProviderTokenNotInitError: return '' - embeddings = CacheEmbedding(embedding_model) + embeddings = CacheEmbedding(embedding_model) vector_index = VectorIndex( dataset=dataset, config=current_app.config, embeddings=embeddings ) - if self.k > 0: + if self.top_k > 0: documents = vector_index.search( query, search_type='similarity_score_threshold', search_kwargs={ - 'k': self.k, + 'k': self.top_k, + 'score_threshold': self.score_threshold, 'filter': { 'group_id': [dataset.id] } diff --git a/api/events/event_handlers/__init__.py b/api/events/event_handlers/__init__.py index 02020e9192a876..2414950a547a9d 100644 --- a/api/events/event_handlers/__init__.py +++ b/api/events/event_handlers/__init__.py @@ -4,5 +4,4 @@ from .clean_when_dataset_deleted import handle from .update_app_dataset_join_when_app_model_config_updated import handle from .generate_conversation_name_when_first_message_created import handle -from .generate_conversation_summary_when_few_message_created import handle from .create_document_index import handle diff --git a/api/events/event_handlers/generate_conversation_summary_when_few_message_created.py b/api/events/event_handlers/generate_conversation_summary_when_few_message_created.py deleted file mode 100644 index df62a90b8e2f41..00000000000000 --- a/api/events/event_handlers/generate_conversation_summary_when_few_message_created.py +++ /dev/null @@ -1,14 +0,0 @@ -from events.message_event import message_was_created -from tasks.generate_conversation_summary_task import generate_conversation_summary_task - - -@message_was_created.connect -def handle(sender, **kwargs): - message = sender - conversation = kwargs.get('conversation') - is_first_message = kwargs.get('is_first_message') - - if not is_first_message and conversation.mode == 'chat' and not conversation.summary: - history_message_count = conversation.message_count - if history_message_count >= 5: - generate_conversation_summary_task.delay(conversation.id) diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index b370ac41e058e0..fccfa5df306f5d 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -28,6 +28,10 @@ 'dataset_query_variable': fields.String, 'pre_prompt': fields.String, 'agent_mode': fields.Raw(attribute='agent_mode_dict'), + 'prompt_type': fields.String, + 'chat_prompt_config': fields.Raw(attribute='chat_prompt_config_dict'), + 'completion_prompt_config': fields.Raw(attribute='completion_prompt_config_dict'), + 'dataset_configs': fields.Raw(attribute='dataset_configs_dict') } app_detail_fields = { diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index dcfbe8a0694a98..df43a62fb6b976 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -123,6 +123,7 @@ def format(self, value): 'from_end_user_id': fields.String, 'from_end_user_session_id': fields.String, 'from_account_id': fields.String, + 'name': fields.String, 'summary': fields.String(attribute='summary_or_query'), 'read_at': TimestampField, 'created_at': TimestampField, diff --git a/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py b/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py new file mode 100644 index 00000000000000..cbb04bb01eeca1 --- /dev/null +++ b/api/migrations/versions/b3a09c049e8e_add_advanced_prompt_templates.py @@ -0,0 +1,37 @@ +"""add advanced prompt templates + +Revision ID: b3a09c049e8e +Revises: 2e9819ca5b28 +Create Date: 2023-10-10 15:23:23.395420 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = 'b3a09c049e8e' +down_revision = '2e9819ca5b28' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.add_column(sa.Column('prompt_type', sa.String(length=255), nullable=False, server_default='simple')) + batch_op.add_column(sa.Column('chat_prompt_config', sa.Text(), nullable=True)) + batch_op.add_column(sa.Column('completion_prompt_config', sa.Text(), nullable=True)) + batch_op.add_column(sa.Column('dataset_configs', sa.Text(), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('app_model_configs', schema=None) as batch_op: + batch_op.drop_column('dataset_configs') + batch_op.drop_column('completion_prompt_config') + batch_op.drop_column('chat_prompt_config') + batch_op.drop_column('prompt_type') + + # ### end Alembic commands ### diff --git a/api/models/model.py b/api/models/model.py index f372f516da1e8f..d3f5c8135f1f99 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -93,6 +93,10 @@ class AppModelConfig(db.Model): agent_mode = db.Column(db.Text) sensitive_word_avoidance = db.Column(db.Text) retriever_resource = db.Column(db.Text) + prompt_type = db.Column(db.String(255), nullable=False, default='simple') + chat_prompt_config = db.Column(db.Text) + completion_prompt_config = db.Column(db.Text) + dataset_configs = db.Column(db.Text) @property def app(self): @@ -139,6 +143,18 @@ def user_input_form_list(self) -> dict: def agent_mode_dict(self) -> dict: return json.loads(self.agent_mode) if self.agent_mode else {"enabled": False, "strategy": None, "tools": []} + @property + def chat_prompt_config_dict(self) -> dict: + return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {} + + @property + def completion_prompt_config_dict(self) -> dict: + return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {} + + @property + def dataset_configs_dict(self) -> dict: + return json.loads(self.dataset_configs) if self.dataset_configs else {"top_k": 2, "score_threshold": {"enable": False}} + def to_dict(self) -> dict: return { "provider": "", @@ -155,7 +171,11 @@ def to_dict(self) -> dict: "user_input_form": self.user_input_form_list, "dataset_query_variable": self.dataset_query_variable, "pre_prompt": self.pre_prompt, - "agent_mode": self.agent_mode_dict + "agent_mode": self.agent_mode_dict, + "prompt_type": self.prompt_type, + "chat_prompt_config": self.chat_prompt_config_dict, + "completion_prompt_config": self.completion_prompt_config_dict, + "dataset_configs": self.dataset_configs_dict } def from_model_config_dict(self, model_config: dict): @@ -177,6 +197,13 @@ def from_model_config_dict(self, model_config: dict): self.agent_mode = json.dumps(model_config['agent_mode']) self.retriever_resource = json.dumps(model_config['retriever_resource']) \ if model_config.get('retriever_resource') else None + self.prompt_type = model_config.get('prompt_type', 'simple') + self.chat_prompt_config = json.dumps(model_config.get('chat_prompt_config')) \ + if model_config.get('chat_prompt_config') else None + self.completion_prompt_config = json.dumps(model_config.get('completion_prompt_config')) \ + if model_config.get('completion_prompt_config') else None + self.dataset_configs = json.dumps(model_config.get('dataset_configs')) \ + if model_config.get('dataset_configs') else None return self def copy(self): @@ -197,7 +224,11 @@ def copy(self): dataset_query_variable=self.dataset_query_variable, pre_prompt=self.pre_prompt, agent_mode=self.agent_mode, - retriever_resource=self.retriever_resource + retriever_resource=self.retriever_resource, + prompt_type=self.prompt_type, + chat_prompt_config=self.chat_prompt_config, + completion_prompt_config=self.completion_prompt_config, + dataset_configs=self.dataset_configs ) return new_app_model_config diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py new file mode 100644 index 00000000000000..3ef2b6059e5ff4 --- /dev/null +++ b/api/services/advanced_prompt_template_service.py @@ -0,0 +1,56 @@ + +import copy + +from core.prompt.advanced_prompt_templates import CHAT_APP_COMPLETION_PROMPT_CONFIG, CHAT_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_CHAT_PROMPT_CONFIG, COMPLETION_APP_COMPLETION_PROMPT_CONFIG, \ + BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, CONTEXT, BAICHUAN_CONTEXT + +class AdvancedPromptTemplateService: + + def get_prompt(self, args: dict) -> dict: + app_mode = args['app_mode'] + model_mode = args['model_mode'] + model_name = args['model_name'] + has_context = args['has_context'] + + if 'baichuan' in model_name: + return self.get_baichuan_prompt(app_mode, model_mode, has_context) + else: + return self.get_common_prompt(app_mode, model_mode, has_context) + + def get_common_prompt(self, app_mode: str, model_mode:str, has_context: bool) -> dict: + if app_mode == 'chat': + if model_mode == 'completion': + return self.get_completion_prompt(copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, CONTEXT) + elif model_mode == 'chat': + return self.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, CONTEXT) + elif app_mode == 'completion': + if model_mode == 'completion': + return self.get_completion_prompt(copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, CONTEXT) + elif model_mode == 'chat': + return self.get_chat_prompt(copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, CONTEXT) + + def get_completion_prompt(self, prompt_template: str, has_context: bool, context: str) -> dict: + if has_context == 'true': + prompt_template['completion_prompt_config']['prompt']['text'] = context + prompt_template['completion_prompt_config']['prompt']['text'] + + return prompt_template + + + def get_chat_prompt(self, prompt_template: str, has_context: bool, context: str) -> dict: + if has_context == 'true': + prompt_template['chat_prompt_config']['prompt'][0]['text'] = context + prompt_template['chat_prompt_config']['prompt'][0]['text'] + + return prompt_template + + + def get_baichuan_prompt(self, app_mode: str, model_mode:str, has_context: bool) -> dict: + if app_mode == 'chat': + if model_mode == 'completion': + return self.get_completion_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT) + elif model_mode == 'chat': + return self.get_chat_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT) + elif app_mode == 'completion': + if model_mode == 'completion': + return self.get_completion_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT) + elif model_mode == 'chat': + return self.get_chat_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, BAICHUAN_CONTEXT) \ No newline at end of file diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 916a1078e5baf7..4acb2f346fbfc2 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -3,7 +3,7 @@ from core.agent.agent_executor import PlanningStrategy from core.model_providers.model_provider_factory import ModelProviderFactory -from core.model_providers.models.entity.model_params import ModelType +from core.model_providers.models.entity.model_params import ModelType, ModelMode from models.account import Account from services.dataset_service import DatasetService @@ -34,40 +34,28 @@ def validate_model_completion_params(cp: dict, model_name: str) -> dict: # max_tokens if 'max_tokens' not in cp: cp["max_tokens"] = 512 - # - # if not isinstance(cp["max_tokens"], int) or cp["max_tokens"] <= 0 or cp["max_tokens"] > \ - # llm_constant.max_context_token_length[model_name]: - # raise ValueError( - # "max_tokens must be an integer greater than 0 " - # "and not exceeding the maximum value of the corresponding model") - # + # temperature if 'temperature' not in cp: cp["temperature"] = 1 - # - # if not isinstance(cp["temperature"], (float, int)) or cp["temperature"] < 0 or cp["temperature"] > 2: - # raise ValueError("temperature must be a float between 0 and 2") - # + # top_p if 'top_p' not in cp: cp["top_p"] = 1 - # if not isinstance(cp["top_p"], (float, int)) or cp["top_p"] < 0 or cp["top_p"] > 2: - # raise ValueError("top_p must be a float between 0 and 2") - # # presence_penalty if 'presence_penalty' not in cp: cp["presence_penalty"] = 0 - # if not isinstance(cp["presence_penalty"], (float, int)) or cp["presence_penalty"] < -2 or cp["presence_penalty"] > 2: - # raise ValueError("presence_penalty must be a float between -2 and 2") - # # presence_penalty if 'frequency_penalty' not in cp: cp["frequency_penalty"] = 0 - # if not isinstance(cp["frequency_penalty"], (float, int)) or cp["frequency_penalty"] < -2 or cp["frequency_penalty"] > 2: - # raise ValueError("frequency_penalty must be a float between -2 and 2") + # stop + if 'stop' not in cp: + cp["stop"] = [] + elif not isinstance(cp["stop"], list): + raise ValueError("stop in model.completion_params must be of list type") # Filter out extra parameters filtered_cp = { @@ -75,7 +63,8 @@ def validate_model_completion_params(cp: dict, model_name: str) -> dict: "temperature": cp["temperature"], "top_p": cp["top_p"], "presence_penalty": cp["presence_penalty"], - "frequency_penalty": cp["frequency_penalty"] + "frequency_penalty": cp["frequency_penalty"], + "stop": cp["stop"] } return filtered_cp @@ -211,6 +200,10 @@ def validate_configuration(tenant_id: str, account: Account, config: dict, mode: model_ids = [m['id'] for m in model_list] if config["model"]["name"] not in model_ids: raise ValueError("model.name must be in the specified model list") + + # model.mode + if 'mode' not in config['model'] or not config['model']["mode"]: + config['model']["mode"] = "" # model.completion_params if 'completion_params' not in config["model"]: @@ -339,6 +332,9 @@ def validate_configuration(tenant_id: str, account: Account, config: dict, mode: # dataset_query_variable AppModelConfigService.is_dataset_query_variable_valid(config, mode) + # advanced prompt validation + AppModelConfigService.is_advanced_prompt_valid(config, mode) + # Filter out extra parameters filtered_config = { "opening_statement": config["opening_statement"], @@ -351,12 +347,17 @@ def validate_configuration(tenant_id: str, account: Account, config: dict, mode: "model": { "provider": config["model"]["provider"], "name": config["model"]["name"], + "mode": config['model']["mode"], "completion_params": config["model"]["completion_params"] }, "user_input_form": config["user_input_form"], "dataset_query_variable": config.get('dataset_query_variable'), "pre_prompt": config["pre_prompt"], - "agent_mode": config["agent_mode"] + "agent_mode": config["agent_mode"], + "prompt_type": config["prompt_type"], + "chat_prompt_config": config["chat_prompt_config"], + "completion_prompt_config": config["completion_prompt_config"], + "dataset_configs": config["dataset_configs"] } return filtered_config @@ -375,4 +376,51 @@ def is_dataset_query_variable_valid(config: dict, mode: str) -> None: if dataset_exists and not dataset_query_variable: raise ValueError("Dataset query variable is required when dataset is exist") + + + @staticmethod + def is_advanced_prompt_valid(config: dict, app_mode: str) -> None: + # prompt_type + if 'prompt_type' not in config or not config["prompt_type"]: + config["prompt_type"] = "simple" + + if config['prompt_type'] not in ['simple', 'advanced']: + raise ValueError("prompt_type must be in ['simple', 'advanced']") + + # chat_prompt_config + if 'chat_prompt_config' not in config or not config["chat_prompt_config"]: + config["chat_prompt_config"] = {} + if not isinstance(config["chat_prompt_config"], dict): + raise ValueError("chat_prompt_config must be of object type") + + # completion_prompt_config + if 'completion_prompt_config' not in config or not config["completion_prompt_config"]: + config["completion_prompt_config"] = {} + + if not isinstance(config["completion_prompt_config"], dict): + raise ValueError("completion_prompt_config must be of object type") + + # dataset_configs + if 'dataset_configs' not in config or not config["dataset_configs"]: + config["dataset_configs"] = {"top_k": 2, "score_threshold": {"enable": False}} + + if not isinstance(config["dataset_configs"], dict): + raise ValueError("dataset_configs must be of object type") + + if config['prompt_type'] == 'advanced': + if not config['chat_prompt_config'] and not config['completion_prompt_config']: + raise ValueError("chat_prompt_config or completion_prompt_config is required when prompt_type is advanced") + + if config['model']["mode"] not in ['chat', 'completion']: + raise ValueError("model.mode must be in ['chat', 'completion'] when prompt_type is advanced") + + if app_mode == 'chat' and config['model']["mode"] == ModelMode.COMPLETION.value: + user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix'] + assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] + + if not user_prefix: + config['completion_prompt_config']['conversation_histories_role']['user_prefix'] = 'Human' + + if not assistant_prefix: + config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant' diff --git a/api/services/completion_service.py b/api/services/completion_service.py index c95905c6c82b13..e2a28357cb2903 100644 --- a/api/services/completion_service.py +++ b/api/services/completion_service.py @@ -244,7 +244,8 @@ def close_pubsub(): @classmethod def generate_more_like_this(cls, app_model: App, user: Union[Account | EndUser], - message_id: str, streaming: bool = True) -> Union[dict | Generator]: + message_id: str, streaming: bool = True, + retriever_from: str = 'dev') -> Union[dict | Generator]: if not user: raise ValueError('user cannot be None') @@ -266,14 +267,11 @@ def generate_more_like_this(cls, app_model: App, user: Union[Account | EndUser], raise MoreLikeThisDisabledError() app_model_config = message.app_model_config - - if message.override_model_configs: - override_model_configs = json.loads(message.override_model_configs) - pre_prompt = override_model_configs.get("pre_prompt", '') - elif app_model_config: - pre_prompt = app_model_config.pre_prompt - else: - raise AppModelConfigBrokenError() + model_dict = app_model_config.model_dict + completion_params = model_dict.get('completion_params') + completion_params['temperature'] = 0.9 + model_dict['completion_params'] = completion_params + app_model_config.model = json.dumps(model_dict) generate_task_id = str(uuid.uuid4()) @@ -282,58 +280,28 @@ def generate_more_like_this(cls, app_model: App, user: Union[Account | EndUser], user = cls.get_real_user_instead_of_proxy_obj(user) - generate_worker_thread = threading.Thread(target=cls.generate_more_like_this_worker, kwargs={ + generate_worker_thread = threading.Thread(target=cls.generate_worker, kwargs={ 'flask_app': current_app._get_current_object(), 'generate_task_id': generate_task_id, 'detached_app_model': app_model, 'app_model_config': app_model_config, - 'detached_message': message, - 'pre_prompt': pre_prompt, + 'query': message.query, + 'inputs': message.inputs, 'detached_user': user, - 'streaming': streaming + 'detached_conversation': None, + 'streaming': streaming, + 'is_model_config_override': True, + 'retriever_from': retriever_from }) generate_worker_thread.start() - cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id) + # wait for 10 minutes to close the thread + cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, + generate_task_id) return cls.compact_response(pubsub, streaming) - @classmethod - def generate_more_like_this_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App, - app_model_config: AppModelConfig, detached_message: Message, pre_prompt: str, - detached_user: Union[Account, EndUser], streaming: bool): - with flask_app.app_context(): - # fixed the state of the model object when it detached from the original session - user = db.session.merge(detached_user) - app_model = db.session.merge(detached_app_model) - message = db.session.merge(detached_message) - - try: - # run - Completion.generate_more_like_this( - task_id=generate_task_id, - app=app_model, - user=user, - message=message, - pre_prompt=pre_prompt, - app_model_config=app_model_config, - streaming=streaming - ) - except ConversationTaskStoppedException: - pass - except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, - LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, - ModelCurrentlyNotSupportError) as e: - PubHandler.pub_error(user, generate_task_id, e) - except LLMAuthorizationError: - PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided')) - except Exception as e: - logging.exception("Unknown Error in completion") - PubHandler.pub_error(user, generate_task_id, e) - finally: - db.session.commit() - @classmethod def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig): if user_inputs is None: diff --git a/api/services/provider_service.py b/api/services/provider_service.py index 34064d0c33cae7..f9acedf8c24f54 100644 --- a/api/services/provider_service.py +++ b/api/services/provider_service.py @@ -482,6 +482,9 @@ def get_valid_model_list(self, tenant_id: str, model_type: str) -> list: 'features': [] } + if 'mode' in model: + valid_model_dict['model_mode'] = model['mode'] + if 'features' in model: valid_model_dict['features'] = model['features'] diff --git a/api/tasks/generate_conversation_summary_task.py b/api/tasks/generate_conversation_summary_task.py deleted file mode 100644 index 791f141d5be3c4..00000000000000 --- a/api/tasks/generate_conversation_summary_task.py +++ /dev/null @@ -1,55 +0,0 @@ -import logging -import time - -import click -from celery import shared_task -from werkzeug.exceptions import NotFound - -from core.generator.llm_generator import LLMGenerator -from core.model_providers.error import LLMError, ProviderTokenNotInitError -from extensions.ext_database import db -from models.model import Conversation, Message - - -@shared_task(queue='generation') -def generate_conversation_summary_task(conversation_id: str): - """ - Async Generate conversation summary - :param conversation_id: - - Usage: generate_conversation_summary_task.delay(conversation_id) - """ - logging.info(click.style('Start generate conversation summary: {}'.format(conversation_id), fg='green')) - start_at = time.perf_counter() - - conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() - if not conversation: - raise NotFound('Conversation not found') - - try: - # get conversation messages count - history_message_count = conversation.message_count - if history_message_count >= 5 and not conversation.summary: - app_model = conversation.app - if not app_model: - return - - history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \ - .order_by(Message.created_at.asc()).all() - - conversation.summary = LLMGenerator.generate_conversation_summary(app_model.tenant_id, history_messages) - db.session.add(conversation) - db.session.commit() - except (LLMError, ProviderTokenNotInitError): - conversation.summary = '[No Summary]' - db.session.commit() - pass - except Exception as e: - conversation.summary = '[No Summary]' - db.session.commit() - logging.exception(e) - - end_at = time.perf_counter() - logging.info( - click.style('Conversation summary generated: {} latency: {}'.format(conversation_id, end_at - start_at), - fg='green')) diff --git a/api/tests/integration_tests/models/llm/test_anthropic_model.py b/api/tests/integration_tests/models/llm/test_anthropic_model.py index 32013b27aa64ec..f0636f6e796ecf 100644 --- a/api/tests/integration_tests/models/llm/test_anthropic_model.py +++ b/api/tests/integration_tests/models/llm/test_anthropic_model.py @@ -44,7 +44,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key): def test_get_num_tokens(mock_decrypt): model = get_mock_model('claude-2') rst = model.get_num_tokens([ - PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') + PromptMessage(type=MessageType.USER, content='Who is your manufacturer?') ]) assert rst == 6 diff --git a/api/tests/integration_tests/models/llm/test_azure_openai_model.py b/api/tests/integration_tests/models/llm/test_azure_openai_model.py index 1df272d1cccfb9..9d289f404dca05 100644 --- a/api/tests/integration_tests/models/llm/test_azure_openai_model.py +++ b/api/tests/integration_tests/models/llm/test_azure_openai_model.py @@ -69,7 +69,7 @@ def test_chat_get_num_tokens(mock_decrypt, mocker): openai_model = get_mock_azure_openai_model('gpt-35-turbo', mocker) rst = openai_model.get_num_tokens([ PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'), - PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') + PromptMessage(type=MessageType.USER, content='Who is your manufacturer?') ]) assert rst == 22 diff --git a/api/tests/integration_tests/models/llm/test_baichuan_model.py b/api/tests/integration_tests/models/llm/test_baichuan_model.py index 15610e1d1d6ab4..c70b14ce2b2950 100644 --- a/api/tests/integration_tests/models/llm/test_baichuan_model.py +++ b/api/tests/integration_tests/models/llm/test_baichuan_model.py @@ -48,7 +48,7 @@ def test_chat_get_num_tokens(mock_decrypt): model = get_mock_model('baichuan2-53b') rst = model.get_num_tokens([ PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'), - PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') + PromptMessage(type=MessageType.USER, content='Who is your manufacturer?') ]) assert rst > 0 @@ -59,7 +59,7 @@ def test_chat_run(mock_decrypt, mocker): model = get_mock_model('baichuan2-53b') messages = [ - PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?') + PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?') ] rst = model.run( messages, @@ -73,7 +73,7 @@ def test_chat_stream_run(mock_decrypt, mocker): model = get_mock_model('baichuan2-53b', streaming=True) messages = [ - PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?') + PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?') ] rst = model.run( messages diff --git a/api/tests/integration_tests/models/llm/test_huggingface_hub_model.py b/api/tests/integration_tests/models/llm/test_huggingface_hub_model.py index eda95102c9121a..2c8c4556bc1b4d 100644 --- a/api/tests/integration_tests/models/llm/test_huggingface_hub_model.py +++ b/api/tests/integration_tests/models/llm/test_huggingface_hub_model.py @@ -71,7 +71,7 @@ def test_hosted_inference_api_get_num_tokens(mock_decrypt, mock_model_info, mock mocker ) rst = model.get_num_tokens([ - PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') + PromptMessage(type=MessageType.USER, content='Who is your manufacturer?') ]) assert rst == 5 @@ -88,7 +88,7 @@ def test_inference_endpoints_get_num_tokens(mock_decrypt, mock_model_info, mocke mocker ) rst = model.get_num_tokens([ - PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') + PromptMessage(type=MessageType.USER, content='Who is your manufacturer?') ]) assert rst == 5 diff --git a/api/tests/integration_tests/models/llm/test_minimax_model.py b/api/tests/integration_tests/models/llm/test_minimax_model.py index d93f8ad7353acc..43634f34995877 100644 --- a/api/tests/integration_tests/models/llm/test_minimax_model.py +++ b/api/tests/integration_tests/models/llm/test_minimax_model.py @@ -48,7 +48,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key): def test_get_num_tokens(mock_decrypt): model = get_mock_model('abab5.5-chat') rst = model.get_num_tokens([ - PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') + PromptMessage(type=MessageType.USER, content='Who is your manufacturer?') ]) assert rst == 5 diff --git a/api/tests/integration_tests/models/llm/test_openai_model.py b/api/tests/integration_tests/models/llm/test_openai_model.py index 3deeb2f02c5923..e6044c0bb56f3f 100644 --- a/api/tests/integration_tests/models/llm/test_openai_model.py +++ b/api/tests/integration_tests/models/llm/test_openai_model.py @@ -52,7 +52,7 @@ def test_chat_get_num_tokens(mock_decrypt): openai_model = get_mock_openai_model('gpt-3.5-turbo') rst = openai_model.get_num_tokens([ PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'), - PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') + PromptMessage(type=MessageType.USER, content='Who is your manufacturer?') ]) assert rst == 22 diff --git a/api/tests/integration_tests/models/llm/test_openllm_model.py b/api/tests/integration_tests/models/llm/test_openllm_model.py index d515f350481ce3..8a70e6ace4fd53 100644 --- a/api/tests/integration_tests/models/llm/test_openllm_model.py +++ b/api/tests/integration_tests/models/llm/test_openllm_model.py @@ -55,7 +55,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key): def test_get_num_tokens(mock_decrypt, mocker): model = get_mock_model('facebook/opt-125m', mocker) rst = model.get_num_tokens([ - PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') + PromptMessage(type=MessageType.USER, content='Who is your manufacturer?') ]) assert rst == 5 diff --git a/api/tests/integration_tests/models/llm/test_replicate_model.py b/api/tests/integration_tests/models/llm/test_replicate_model.py index 13efc198814b3b..d5e55def41f75f 100644 --- a/api/tests/integration_tests/models/llm/test_replicate_model.py +++ b/api/tests/integration_tests/models/llm/test_replicate_model.py @@ -58,7 +58,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key): def test_get_num_tokens(mock_decrypt, mocker): model = get_mock_model('a16z-infra/llama-2-13b-chat', '2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', mocker) rst = model.get_num_tokens([ - PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') + PromptMessage(type=MessageType.USER, content='Who is your manufacturer?') ]) assert rst == 7 diff --git a/api/tests/integration_tests/models/llm/test_spark_model.py b/api/tests/integration_tests/models/llm/test_spark_model.py index d07bfb279a7c13..e6fa45f0cbeb3a 100644 --- a/api/tests/integration_tests/models/llm/test_spark_model.py +++ b/api/tests/integration_tests/models/llm/test_spark_model.py @@ -52,7 +52,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key): def test_get_num_tokens(mock_decrypt): model = get_mock_model('spark') rst = model.get_num_tokens([ - PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') + PromptMessage(type=MessageType.USER, content='Who is your manufacturer?') ]) assert rst == 6 diff --git a/api/tests/integration_tests/models/llm/test_tongyi_model.py b/api/tests/integration_tests/models/llm/test_tongyi_model.py index 8c34497ac7388a..b448c29f47464d 100644 --- a/api/tests/integration_tests/models/llm/test_tongyi_model.py +++ b/api/tests/integration_tests/models/llm/test_tongyi_model.py @@ -46,7 +46,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key): def test_get_num_tokens(mock_decrypt): model = get_mock_model('qwen-turbo') rst = model.get_num_tokens([ - PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') + PromptMessage(type=MessageType.USER, content='Who is your manufacturer?') ]) assert rst == 5 diff --git a/api/tests/integration_tests/models/llm/test_wenxin_model.py b/api/tests/integration_tests/models/llm/test_wenxin_model.py index 29a0de3262001a..8cc4779160b660 100644 --- a/api/tests/integration_tests/models/llm/test_wenxin_model.py +++ b/api/tests/integration_tests/models/llm/test_wenxin_model.py @@ -46,7 +46,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key): def test_get_num_tokens(mock_decrypt): model = get_mock_model('ernie-bot') rst = model.get_num_tokens([ - PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') + PromptMessage(type=MessageType.USER, content='Who is your manufacturer?') ]) assert rst == 5 diff --git a/api/tests/integration_tests/models/llm/test_xinference_model.py b/api/tests/integration_tests/models/llm/test_xinference_model.py index aab075fae28339..01d5fcdd9f8892 100644 --- a/api/tests/integration_tests/models/llm/test_xinference_model.py +++ b/api/tests/integration_tests/models/llm/test_xinference_model.py @@ -57,7 +57,7 @@ def decrypt_side_effect(tenant_id, encrypted_api_key): def test_get_num_tokens(mock_decrypt, mocker): model = get_mock_model('llama-2-chat', mocker) rst = model.get_num_tokens([ - PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') + PromptMessage(type=MessageType.USER, content='Who is your manufacturer?') ]) assert rst == 5 diff --git a/api/tests/integration_tests/models/llm/test_zhipuai_model.py b/api/tests/integration_tests/models/llm/test_zhipuai_model.py index 4bc47bec9b539a..8f1a60e8f25d64 100644 --- a/api/tests/integration_tests/models/llm/test_zhipuai_model.py +++ b/api/tests/integration_tests/models/llm/test_zhipuai_model.py @@ -46,7 +46,7 @@ def test_chat_get_num_tokens(mock_decrypt): model = get_mock_model('chatglm_lite') rst = model.get_num_tokens([ PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'), - PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') + PromptMessage(type=MessageType.USER, content='Who is your manufacturer?') ]) assert rst > 0 @@ -57,7 +57,7 @@ def test_chat_run(mock_decrypt, mocker): model = get_mock_model('chatglm_lite') messages = [ - PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?') + PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?') ] rst = model.run( messages, @@ -71,7 +71,7 @@ def test_chat_stream_run(mock_decrypt, mocker): model = get_mock_model('chatglm_lite', streaming=True) messages = [ - PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?') + PromptMessage(type=MessageType.USER, content='Are you Human? you MUST only answer `y` or `n`?') ] rst = model.run( messages diff --git a/api/tests/unit_tests/model_providers/fake_model_provider.py b/api/tests/unit_tests/model_providers/fake_model_provider.py index 4e14d5924efe8b..35c44061dccf8c 100644 --- a/api/tests/unit_tests/model_providers/fake_model_provider.py +++ b/api/tests/unit_tests/model_providers/fake_model_provider.py @@ -1,7 +1,7 @@ from typing import Type from core.model_providers.models.base import BaseProviderModel -from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules +from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, ModelMode from core.model_providers.models.llm.openai_model import OpenAIModel from core.model_providers.providers.base import BaseModelProvider @@ -12,7 +12,10 @@ def provider_name(self): return 'fake' def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: - return [{'id': 'test_model', 'name': 'Test Model'}] + return [{'id': 'test_model', 'name': 'Test Model', 'mode': 'completion'}] + + def _get_text_generation_model_mode(self, model_name) -> str: + return ModelMode.COMPLETION.value def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: return OpenAIModel diff --git a/api/tests/unit_tests/model_providers/test_base_model_provider.py b/api/tests/unit_tests/model_providers/test_base_model_provider.py index 7d6e56eb0ac0bd..534599c3199810 100644 --- a/api/tests/unit_tests/model_providers/test_base_model_provider.py +++ b/api/tests/unit_tests/model_providers/test_base_model_provider.py @@ -24,7 +24,7 @@ def test_get_supported_model_list(mocker): provider = FakeModelProvider(provider=Provider()) result = provider.get_supported_model_list(ModelType.TEXT_GENERATION) - assert result == [{'id': 'test_model', 'name': 'test_model'}] + assert result == [{'id': 'test_model', 'name': 'test_model', 'mode': 'completion'}] def test_check_quota_over_limit(mocker):