Skip to content

Commit

Permalink
finish multimodal code
Browse files Browse the repository at this point in the history
  • Loading branch information
neoguojing committed Nov 30, 2024
1 parent ea7c246 commit f9d24e8
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 11 deletions.
2 changes: 1 addition & 1 deletion agi/llms/text2image.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def handle_output(self, image: Any) -> AIMessage:
# Format the result as HTML with embedded image and prompt
formatted_result = f'<img src="{image_source}" {style}>\n'
result = AIMessage(content=[{"type": "text", "text": formatted_result},
{"type": ImageType.PIL_IMAGE, ImageType.PIL_IMAGE: image}])
{"type": "media", "media": image}])
return result

def _save_or_resize_image(self, image: Any) -> str:
Expand Down
14 changes: 12 additions & 2 deletions agi/tasks/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,21 @@
from langchain_core.prompt_values import StringPromptValue

def build_messages(input :dict):
if input.get("media") is None:

media = None
media_dict = input.get("media")
if media_dict.get('data'): # 首先检查 'data' 字段
media = media_dict['data']
elif media_dict.get('path'): # 其次检查 'path' 字段
media = media_dict['path']
elif media_dict.get('url'): # 最后检查 'url' 字段
media = media_dict['url']
else:
return HumanMessage(content=input.get("text"))

return HumanMessage(content=[
{"type": "text", "text": input.get("text")},
{"type": "media", "media": input.get("media")},
{"type": "media", "media": media},
])

def parse_input(input: StringPromptValue) -> dict:
Expand Down
7 changes: 5 additions & 2 deletions agi/tasks/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,10 @@ def stock_code_prompt(input_text):

# return _message

# TODO media只能是字符串不能是对象
# 使用字典对象,在传入渲染参数是将对象或者字符串包装在字典中
multimodal_input_template = PromptTemplate(
template='{{"text":"{text}","media":{media}}}',
partial_variables={"text":None,"media":"null"}
template='{"media":{"url":"{{url}}","path":"{{path}}","data":{{data}}},"text":"{{text}}"}',
partial_variables={"text":None,"url":None,"path":None,"data":"null"},
template_format="mustache"
)
33 changes: 27 additions & 6 deletions tests/tasks/task_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,38 @@

class TestTaskFactory(unittest.TestCase):

# def test_translate_chain(self):
# # Test for TASK_LLM
# llm_task = TaskFactory.create_task(TASK_TRANSLATE)
# resp = llm_task.invoke({"text":"我爱北京天安门"})
# self.assertIsInstance(resp,str)
# self.assertIsNotNone(resp)
def test_translate_chain(self):
# Test for TASK_LLM
llm_task = TaskFactory.create_task(TASK_TRANSLATE)
resp = llm_task.invoke({"text":"我爱北京天安门"})
self.assertIsInstance(resp,str)
self.assertIsNotNone(resp)
def test_text2speech_chain(self):
# Test for TASK_LLM
llm_task = TaskFactory.create_task(TASK_TTS)
resp = llm_task.invoke({"text":"These prompt templates are used to format a single string, and generally are used for simpler inputs"})
self.assertIsInstance(resp,AIMessage)
self.assertIsInstance(resp.content,list)
self.assertIsNotNone(resp.content[1].get("media"))
self.assertEqual(resp.content[1].get("type"),"media")
def test_speech2text_chain(self):
llm_task = TaskFactory.create_task(TASK_SPEECH_TEXT)
resp = llm_task.invoke({"path":"tests/1730604079.wav"})
self.assertIsInstance(resp,AIMessage)
self.assertIsInstance(resp.content,str)
self.assertIsNotNone(resp.content)
def test_text2image_chain(self):
llm_task = TaskFactory.create_task(TASK_IMAGE_GEN)
resp = llm_task.invoke({"text":"星辰大海"})
print(resp)
self.assertIsInstance(resp,AIMessage)
self.assertIsInstance(resp.content,list)
self.assertIsNotNone(resp.content[1].get("media"))
self.assertEqual(resp.content[1].get("type"),"media")
# self.assertIsNotNone(resp.content)
resp = llm_task.invoke({"text":"猫咪在游泳","path":"tests/cat.jpg"})
self.assertIsInstance(resp,AIMessage)
self.assertIsInstance(resp.content,list)
self.assertIsNotNone(resp.content[1].get("media"))
self.assertEqual(resp.content[1].get("type"),"media")

Expand Down

0 comments on commit f9d24e8

Please sign in to comment.