Commit 42a5b3e
feat: advanced prompt backend (langgenius#1301)
Co-authored-by: takatost <[email protected]>
GarfieldDai and takatost authored Oct 12, 2023
1 parent 2d1cb07 commit 42a5b3e
Showing 61 changed files with 762 additions and 576 deletions.
24 changes: 22 additions & 2 deletions api/constants/model_template.py
@@ -31,6 +31,7 @@
'model': json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo-instruct",
"mode": "completion",
"completion_params": {
"max_tokens": 512,
"temperature": 1,
@@ -81,6 +82,7 @@
'model': json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {
"max_tokens": 512,
"temperature": 1,
@@ -137,10 +139,11 @@
},
opening_statement='',
suggested_questions=None,
pre_prompt="Please translate the following text into {{target_language}}:\n",
pre_prompt="Please translate the following text into {{target_language}}:\n{{query}}\ntranslate:",
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo-instruct",
"mode": "completion",
"completion_params": {
"max_tokens": 1000,
"temperature": 0,
@@ -169,6 +172,13 @@
'Italian',
]
}
+ },{
+ "paragraph": {
+ "label": "Query",
+ "variable": "query",
+ "required": True,
+ "default": ""
+ }
}
])
)
@@ -200,6 +210,7 @@
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {
"max_tokens": 300,
"temperature": 0.8,
@@ -255,10 +266,11 @@
},
opening_statement='',
suggested_questions=None,
pre_prompt="请将以下文本翻译为{{target_language}}:\n",
pre_prompt="请将以下文本翻译为{{target_language}}:\n{{query}}\n翻译:",
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo-instruct",
"mode": "completion",
"completion_params": {
"max_tokens": 1000,
"temperature": 0,
@@ -287,6 +299,13 @@
"意大利语",
]
}
+ },{
+ "paragraph": {
+ "label": "文本内容",
+ "variable": "query",
+ "required": True,
+ "default": ""
+ }
}
])
)
@@ -318,6 +337,7 @@
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {
"max_tokens": 300,
"temperature": 0.8,
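Note: the translation templates above now collect the source text through a dedicated `query` form field instead of an implicit suffix. A minimal sketch of how the decoded form item and the new pre_prompt fit together, assuming plain `{{variable}}` substitution (the rendering loop below is illustrative, not code from this commit):

```python
# The form item added by this commit, decoded from the user_input_form JSON.
query_field = {
    "paragraph": {
        "label": "Query",        # shown to the end user
        "variable": "query",     # key referenced by the prompt template
        "required": True,
        "default": ""
    }
}

# At run time, end-user inputs are keyed by "variable" and substituted
# into the pre_prompt's {{placeholders}}.
inputs = {"target_language": "French", "query": "Good morning"}
pre_prompt = "Please translate the following text into {{target_language}}:\n{{query}}\ntranslate:"
for key, value in inputs.items():
    pre_prompt = pre_prompt.replace("{{%s}}" % key, value)
print(pre_prompt)
# Please translate the following text into French:
# Good morning
# translate:
```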
2 changes: 1 addition & 1 deletion api/controllers/console/__init__.py
@@ -9,7 +9,7 @@
from . import setup, version, apikey, admin

# Import app controllers
- from .app import app, site, completion, model_config, statistic, conversation, message, generator, audio
+ from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio

# Import auth controllers
from .auth import login, oauth, data_source_oauth, activate
26 changes: 26 additions & 0 deletions api/controllers/console/app/advanced_prompt_template.py
@@ -0,0 +1,26 @@
from flask_restful import Resource, reqparse

from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from libs.login import login_required
from services.advanced_prompt_template_service import AdvancedPromptTemplateService

class AdvancedPromptTemplateList(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def get(self):

        parser = reqparse.RequestParser()
        parser.add_argument('app_mode', type=str, required=True, location='args')
        parser.add_argument('model_mode', type=str, required=True, location='args')
        parser.add_argument('has_context', type=str, required=False, default='true', location='args')
        parser.add_argument('model_name', type=str, required=True, location='args')
        args = parser.parse_args()

        service = AdvancedPromptTemplateService()
        return service.get_prompt(args)

api.add_resource(AdvancedPromptTemplateList, '/app/prompt-templates')
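A quick sketch of how a client could call the new endpoint. The host, API prefix, and session cookie are assumptions for illustration; only the route `/app/prompt-templates` and its four query parameters come from this diff:

```python
import requests

# Illustrative request; host, URL prefix, and auth cookie are assumptions.
resp = requests.get(
    "http://localhost:5001/console/api/app/prompt-templates",
    params={
        "app_mode": "chat",                      # required
        "model_mode": "completion",              # required
        "has_context": "true",                   # optional, defaults to 'true'
        "model_name": "gpt-3.5-turbo-instruct",  # required
    },
    cookies={"session": "<console-session-cookie>"},
)
print(resp.json())  # the advanced prompt template for this combination
```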
30 changes: 0 additions & 30 deletions api/controllers/console/app/generator.py
@@ -12,35 +12,6 @@
LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError


- class IntroductionGenerateApi(Resource):
- @setup_required
- @login_required
- @account_initialization_required
- def post(self):
- parser = reqparse.RequestParser()
- parser.add_argument('prompt_template', type=str, required=True, location='json')
- args = parser.parse_args()
-
- account = current_user
-
- try:
- answer = LLMGenerator.generate_introduction(
- account.current_tenant_id,
- args['prompt_template']
- )
- except ProviderTokenNotInitError as ex:
- raise ProviderNotInitializeError(ex.description)
- except QuotaExceededError:
- raise ProviderQuotaExceededError()
- except ModelCurrentlyNotSupportError:
- raise ProviderModelCurrentlyNotSupportError()
- except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
- LLMRateLimitError, LLMAuthorizationError) as e:
- raise CompletionRequestError(str(e))
-
- return {'introduction': answer}


class RuleGenerateApi(Resource):
@setup_required
@login_required
@@ -72,5 +43,4 @@ def post(self):
return rules


- api.add_resource(IntroductionGenerateApi, '/introduction-generate')
api.add_resource(RuleGenerateApi, '/rule-generate')
2 changes: 1 addition & 1 deletion api/controllers/console/app/message.py
@@ -329,7 +329,7 @@ def get(self, app_id, message_id):
message_id = str(message_id)

# get app info
- app_model = _get_app(app_id, 'chat')
+ app_model = _get_app(app_id)

message = db.session.query(Message).filter(
Message.id == message_id,
2 changes: 1 addition & 1 deletion api/controllers/web/message.py
@@ -115,7 +115,7 @@ def get(self, app_model, end_user, message_id):
streaming = args['response_mode'] == 'streaming'

try:
- response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming)
+ response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming, 'web_app')
return compact_response(response)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")
90 changes: 26 additions & 64 deletions api/core/completion.py
@@ -1,4 +1,3 @@
- import json
import logging
from typing import Optional, List, Union

@@ -16,10 +15,8 @@
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
- from core.prompt.prompt_builder import PromptBuilder
- from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
- from models.dataset import DocumentSegment, Dataset, Document
- from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser
+ from core.prompt.prompt_template import PromptTemplateParser
+ from models.model import App, AppModelConfig, Account, Conversation, EndUser


class Completion:
Expand All @@ -30,7 +27,7 @@ def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, quer
"""
errors: ProviderTokenNotInitError
"""
- query = PromptBuilder.process_template(query)
+ query = PromptTemplateParser.remove_template_variables(query)

memory = None
if conversation:
@@ -160,14 +157,28 @@ def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: App
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
fake_response: Optional[str]):
# get llm prompt
- prompt_messages, stop_words = model_instance.get_prompt(
- mode=mode,
- pre_prompt=app_model_config.pre_prompt,
- inputs=inputs,
- query=query,
- context=agent_execute_result.output if agent_execute_result else None,
- memory=memory
- )
+ if app_model_config.prompt_type == 'simple':
+ prompt_messages, stop_words = model_instance.get_prompt(
+ mode=mode,
+ pre_prompt=app_model_config.pre_prompt,
+ inputs=inputs,
+ query=query,
+ context=agent_execute_result.output if agent_execute_result else None,
+ memory=memory
+ )
+ else:
+ prompt_messages = model_instance.get_advanced_prompt(
+ app_mode=mode,
+ app_model_config=app_model_config,
+ inputs=inputs,
+ query=query,
+ context=agent_execute_result.output if agent_execute_result else None,
+ memory=memory
+ )
+
+ model_config = app_model_config.model_dict
+ completion_params = model_config.get("completion_params", {})
+ stop_words = completion_params.get("stop", [])

cls.recale_llm_max_tokens(
model_instance=model_instance,
Expand All @@ -176,7 +187,7 @@ def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: App

response = model_instance.run(
messages=prompt_messages,
- stop=stop_words,
+ stop=stop_words if stop_words else None,
callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
fake_response=fake_response
)
@@ -266,52 +277,3 @@ def recale_llm_max_tokens(cls, model_instance: BaseLLM, prompt_messages: List[Pr
model_kwargs = model_instance.get_model_kwargs()
model_kwargs.max_tokens = max_tokens
model_instance.set_model_kwargs(model_kwargs)

- @classmethod
- def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
- app_model_config: AppModelConfig, user: Account, streaming: bool):
-
- final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
- tenant_id=app.tenant_id,
- model_config=app_model_config.model_dict,
- streaming=streaming
- )
-
- # get llm prompt
- old_prompt_messages, _ = final_model_instance.get_prompt(
- mode='completion',
- pre_prompt=pre_prompt,
- inputs=message.inputs,
- query=message.query,
- context=None,
- memory=None
- )
-
- original_completion = message.answer.strip()
-
- prompt = MORE_LIKE_THIS_GENERATE_PROMPT
- prompt = prompt.format(prompt=old_prompt_messages[0].content, original_completion=original_completion)
-
- prompt_messages = [PromptMessage(content=prompt)]
-
- conversation_message_task = ConversationMessageTask(
- task_id=task_id,
- app=app,
- app_model_config=app_model_config,
- user=user,
- inputs=message.inputs,
- query=message.query,
- is_override=True if message.override_model_configs else False,
- streaming=streaming,
- model_instance=final_model_instance
- )
-
- cls.recale_llm_max_tokens(
- model_instance=final_model_instance,
- prompt_messages=prompt_messages
- )
-
- final_model_instance.run(
- messages=prompt_messages,
- callbacks=[LLMCallbackHandler(final_model_instance, conversation_message_task)]
- )
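For clarity, run_final_llm now reads stop words from the model config instead of receiving them from get_prompt. A minimal sketch of that extraction, assuming a model_dict shaped like the templates in api/constants/model_template.py above:

```python
# Illustrative model_dict, mirroring the template shape in this commit.
model_config = {
    "provider": "openai",
    "name": "gpt-3.5-turbo",
    "mode": "chat",
    "completion_params": {"max_tokens": 512, "temperature": 1, "stop": []},
}

completion_params = model_config.get("completion_params", {})
stop_words = completion_params.get("stop", [])

# An empty stop list is normalized to None before the model call.
stop = stop_words if stop_words else None
print(stop)  # None
```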
22 changes: 11 additions & 11 deletions api/core/conversation_message_task.py
@@ -10,7 +10,7 @@
from core.model_providers.models.entity.message import to_prompt_messages, MessageType
from core.model_providers.models.llm.base import BaseLLM
from core.prompt.prompt_builder import PromptBuilder
- from core.prompt.prompt_template import JinjaPromptTemplate
+ from core.prompt.prompt_template import PromptTemplateParser
from events.message_event import message_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@@ -74,10 +74,10 @@ def init(self):
if self.mode == 'chat':
introduction = self.app_model_config.opening_statement
if introduction:
- prompt_template = JinjaPromptTemplate.from_template(template=introduction)
- prompt_inputs = {k: self.inputs[k] for k in prompt_template.input_variables if k in self.inputs}
+ prompt_template = PromptTemplateParser(template=introduction)
+ prompt_inputs = {k: self.inputs[k] for k in prompt_template.variable_keys if k in self.inputs}
try:
- introduction = prompt_template.format(**prompt_inputs)
+ introduction = prompt_template.format(prompt_inputs)
except KeyError:
pass

@@ -150,20 +150,20 @@ def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
message_tokens = llm_message.prompt_tokens
answer_tokens = llm_message.completion_tokens

- message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN)
- message_price_unit = self.model_instance.get_price_unit(MessageType.HUMAN)
+ message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.USER)
+ message_price_unit = self.model_instance.get_price_unit(MessageType.USER)
answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT)

- message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN)
+ message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.USER)
answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
total_price = message_total_price + answer_total_price

self.message.message = llm_message.prompt
self.message.message_tokens = message_tokens
self.message.message_unit_price = message_unit_price
self.message.message_price_unit = message_price_unit
- self.message.answer = PromptBuilder.process_template(
+ self.message.answer = PromptTemplateParser.remove_template_variables(
llm_message.completion.strip()) if llm_message.completion else ''
self.message.answer_tokens = answer_tokens
self.message.answer_unit_price = answer_unit_price
@@ -226,15 +226,15 @@ def on_agent_start(self, message_chain: MessageChain, agent_loop: AgentLoop) ->

def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
agent_loop: AgentLoop):
- agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.HUMAN)
- agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.HUMAN)
+ agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.USER)
+ agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.USER)
agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT)

loop_message_tokens = agent_loop.prompt_tokens
loop_answer_tokens = agent_loop.completion_tokens

- loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
+ loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.USER)
loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
loop_total_price = loop_message_total_price + loop_answer_total_price

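For context, PromptTemplateParser (which replaces the Jinja-based JinjaPromptTemplate throughout this commit) exposes `variable_keys` instead of `input_variables` and takes a dict in `format()` instead of keyword arguments. A minimal stand-in showing that interface as used in this diff; the real implementation lives in core/prompt/prompt_template.py and may differ:

```python
import re

class PromptTemplateParser:
    """Illustrative stand-in for core.prompt.prompt_template.PromptTemplateParser."""

    def __init__(self, template: str):
        self.template = template
        # Names of the {{variables}} referenced by the template.
        self.variable_keys = re.findall(r"\{\{(\w+)\}\}", template)

    def format(self, inputs: dict) -> str:
        # Takes a single dict, unlike JinjaPromptTemplate.format(**kwargs).
        return re.sub(
            r"\{\{(\w+)\}\}",
            lambda m: str(inputs.get(m.group(1), "")),
            self.template,
        )

parser = PromptTemplateParser("Welcome to {{app_name}}!")
print(parser.variable_keys)                 # ['app_name']
print(parser.format({"app_name": "Dify"}))  # Welcome to Dify!
```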