Skip to content

Commit

Permalink
up
Browse files Browse the repository at this point in the history
  • Loading branch information
neoguojing committed Nov 3, 2024
1 parent ac40b00 commit 49962f9
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 113 deletions.
21 changes: 8 additions & 13 deletions agi/llms/image2image.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import time
from datetime import date
from pathlib import Path
from diffusers import AutoPipelineForText2Image
from diffusers import AutoPipelineForImage2Image
import torch
from typing import Any, List, Mapping, Optional,Union
Expand All @@ -26,18 +25,14 @@ class Image2Image(CustomerLLM):
save_image: bool = True

def __init__(self, model_path: str=os.path.join(model_root,"sdxl-turbo"),**kwargs):
if model_path is not None:
super(Image2Image, self).__init__(
llm=AutoPipelineForImage2Image.from_pretrained(
os.path.join(model_root,"sdxl-turbo"), torch_dtype=torch.float16, variant="fp16"
))
# self.model.save_pretrained(os.path.join(model_root,"sdxl-turbo"))
self.model_path = model_path
else:
super(StableDiff, self).__init__(
llm=AutoPipelineForText2Image.from_pretrained(
"stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16"
))

super(Image2Image, self).__init__(
llm=AutoPipelineForImage2Image.from_pretrained(
os.path.join(model_root,"sdxl-turbo"), torch_dtype=torch.float16, variant="fp16"
))
# self.model.save_pretrained(os.path.join(model_root,"sdxl-turbo"))
self.model_path = model_path

# self.model.to(self.device)
# 使用cpu和to('cuda')互斥,内存减小一半
self.model.enable_model_cpu_offload()
Expand Down
9 changes: 0 additions & 9 deletions agi/llms/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,6 @@

import threading
import gc
from langchain_community.chat_models import QianfanChatEndpoint
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# openai
os.environ['OPENAI_API_KEY'] = ''
# qianfan
os.environ["QIANFAN_AK"] = "your_ak"
os.environ["QIANFAN_SK"] = "your_sk"
# tongyi
os.environ["DASHSCOPE_API_KEY"] = ""

from urllib.parse import urljoin
class ModelFactory:
Expand Down
2 changes: 0 additions & 2 deletions agi/tasks/llm_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
from langchain.retrievers import EnsembleRetriever
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_ollama import OllamaEmbeddings
from langchain_ollama.llms import OllamaLLM
from langchain_core.messages.ai import AIMessage,AIMessageChunk
from langchain_core.runnables.utils import AddableDict
from langchain_core.runnables.base import Runnable
Expand Down
100 changes: 11 additions & 89 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,90 +1,12 @@
fastapi==0.111.0
uvicorn[standard]==0.30.6
pydantic==2.8.2
python-multipart==0.0.9

Flask==3.0.3
Flask-Cors==5.0.0

python-socketio==5.11.3
python-jose==3.3.0
passlib[bcrypt]==1.7.4

requests==2.32.3
aiohttp==3.10.5

sqlalchemy==2.0.32
alembic==1.13.2
peewee==3.17.6
peewee-migrate==1.12.2
psycopg2-binary==2.9.9
PyMySQL==1.1.1
bcrypt==4.2.0

pymongo
redis
boto3==1.35.0

argon2-cffi==23.1.0
APScheduler==3.10.4

# AI libraries
openai
anthropic
google-generativeai==0.7.2
tiktoken

langchain==0.2.15
langchain-community==0.2.12
langchain-chroma==0.1.2
langchain_openai
langchain_ollama
langchain_chroma
langgraph
TTS==0.21.3
duckduckgo-search~=6.2.11

fake-useragent==1.5.1
chromadb==0.5.5
sentence-transformers==3.0.1
pypdf==4.3.1
docx2txt==0.8
python-pptx==1.0.0
unstructured==0.15.9
nltk==3.9.1
Markdown==3.7
pypandoc==1.13
pandas==2.2.2
openpyxl==3.1.5
pyxlsb==1.0.10
xlrd==2.0.1
validators==0.33.0
psutil

opencv-python-headless==4.10.0.84
rapidocr-onnxruntime==1.3.24

fpdf2==2.7.9
rank-bm25==0.2.2

faster-whisper==1.0.3

PyJWT[crypto]==2.9.0
authlib==1.3.2

black==24.8.0
langfuse==2.44.0
youtube-transcript-api==0.6.2
pytube==15.0.0

extract_msg
pydub


## Tests
docker~=7.1.0
pytest~=8.3.2
pytest-docker~=3.1.1

playwright
langchain==0.3.7
langchain-chroma==0.1.4
langchain_openai==0.2.5
langgraph==0.2.44
langchain_community==
TTS==0.22.0
faster-whisper
diffusers
torch
transformers
spacy

20 changes: 20 additions & 0 deletions tests/llms/test_model_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import unittest
from agi.llms.model_factory import ModelFactory

class TestModelFactory(unittest.TestCase):

def test_get_model(self):
ollama_model = ModelFactory.get_model("ollama")
resp = ollama_model.invoke("介绍下美国")
print(type(resp))

# def test_release_model(self):
# self.assertIsNotNone(output)
# self.assertIsNotNone(output.image)
# self.assertIsNotNone(output.image.pil_image)
# self.assertIsNotNone(output.content)
# print(output.content)


if __name__ == "__main__":
unittest.main()

0 comments on commit 49962f9

Please sign in to comment.