Skip to content

Commit

Permalink
add factory test
Browse files Browse the repository at this point in the history
  • Loading branch information
neoguojing committed Nov 9, 2024
1 parent 72a4cbf commit bcc4cd1
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 12 deletions.
6 changes: 2 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,14 @@ langchain-chroma==0.1.4
langchain_openai==0.2.5
langgraph==0.2.44
langchain_community==0.3.5
langchain_ollama==0.2.0
transformers==4.45.0
spacy[ja]==3.8.2
# --index-url https://download.pytorch.org/whl/cu121
# torch==2.5.1+cu121
torch==2.4.0
torchaudio==2.4.0
# TTS==0.22.0
librosa==0.10.1
git+https://github.com/neoguojing/TTS.git
ctranslate2==4.4.0
# nvidia-cudnn-cu12==8.*
faster-whisper==1.0.3
diffusers==0.21.2
accelerate==1.1.1
Expand Down
49 changes: 41 additions & 8 deletions tests/llms/test_model_factory.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,53 @@
import unittest
from agi.llms.model_factory import ModelFactory
from agi.llms.base import build_multi_modal_message,ImageType,AudioType
from langchain_core.messages import AIMessage, HumanMessage


class TestModelFactory(unittest.TestCase):

def test_get_model(self):
ollama_model = ModelFactory.get_model("ollama")
resp = ollama_model.invoke("介绍下美国")
print(type(resp))
self.assertIsNotNone(resp.content)
self.assertEqual(len(ModelFactory._instances),1)
ModelFactory.destroy("ollama")
self.assertEqual(len(ModelFactory._instances),0)

instance = ModelFactory.get_model("text2image")
input = HumanMessage(content="a chinese leader")
resp = instance.invoke(input)
self.assertIsNotNone(resp.content)
self.assertEqual(len(ModelFactory._instances),1)
ModelFactory.destroy("text2image")
self.assertEqual(len(ModelFactory._instances),0)

instance = ModelFactory.get_model("image2image")
input = build_multi_modal_message("as a cat woman","tests/cat.jpg",ImageType.FILE_PATH)
resp = instance.invoke(input)
self.assertIsNotNone(resp.content)
self.assertEqual(len(ModelFactory._instances),1)
ModelFactory.destroy("image2image")
self.assertEqual(len(ModelFactory._instances),0)

instance = ModelFactory.get_model("speech2text")
input = self.input = build_multi_modal_message("","tests/1730604079.wav",AudioType.FILE_PATH)
resp = instance.invoke(input)
self.assertIsNotNone(resp.content)
print(resp.content)
self.assertEqual(len(ModelFactory._instances),1)
ModelFactory.destroy("speech2text")
self.assertEqual(len(ModelFactory._instances),0)

instance = ModelFactory.get_model("text2speech")
input = HumanMessage(content="岁的思考的加快速度为空军党委科技")
resp = instance.invoke(input)
self.assertIsNotNone(resp.content)
print(resp.content)
self.assertEqual(len(ModelFactory._instances),1)
ModelFactory.destroy("text2speech")
self.assertEqual(len(ModelFactory._instances),0)

# def test_release_model(self):
# self.assertIsNotNone(output)
# self.assertIsNotNone(output.image)
# self.assertIsNotNone(output.image.pil_image)
# self.assertIsNotNone(output.content)
# print(output.content)


if __name__ == "__main__":
unittest.main()

0 comments on commit bcc4cd1

Please sign in to comment.