Skip to content

Commit

Permalink
up
Browse files Browse the repository at this point in the history
  • Loading branch information
neoguojing committed Oct 27, 2024
1 parent 6f85af6 commit 36ae0ff
Show file tree
Hide file tree
Showing 2 changed files with 207 additions and 27 deletions.
192 changes: 190 additions & 2 deletions agi/llms/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,198 @@
import torch
from langchain.llms.base import LLM
from pydantic import Field
from typing import Any,Union,List,Dict
from typing import Any,Union,Literal,List,Dict
from langchain_core.runnables import Runnable, RunnableSerializable,RunnableConfig
from langchain_core.messages.base import BaseMessage
from pydantic import BaseModel, HttpUrl, constr
import base64
import requests
from typing import Optional
from PIL import Image as PILImage
from io import BytesIO
from typing import Optional, List
import requests
from diffusers.utils import load_image

class Image:
url: Optional[str] = None # 图片的 URL
pil_image: Optional[PILImage.Image] = None # 使用 PIL 图像对象
filename: Optional[str] = None # 文件名
filetype: Optional[str] = None # 文件类型 (如 'image/jpeg', 'image/png')
size: Optional[int] = None # 文件大小(字节)

class CustomerLLM(LLM):
@classmethod
def new(cls, url_or_path: str):
"""从本地文件创建 Image 实例"""
pil_image = load_image(url_or_path)
filename = url_or_path.split('/')[-1]
filetype = pil_image.format # 使用 PIL 提取文件格式
size = pil_image.tobytes().__sizeof__() # 计算字节大小

instance = cls()
instance.pil_image = pil_image
instance.filename = filename
instance.filetype = filetype
instance.size = size
return instance

def save_image(self, output_path: str):
"""保存 PIL 图像对象为图片文件"""
if self.pil_image:
self.pil_image.save(output_path)

def pretty_repr(self) -> List[str]:
"""返回图片的美观表示"""
lines = [
f"URL: {self.url}" if self.url else "URL: None",
f"Filename: {self.filename}" if self.filename else "Filename: None",
f"Filetype: {self.filetype}" if self.filetype else "Filetype: None",
f"Size: {self.size} bytes" if self.size is not None else "Size: None"
]
return lines



# 使用示例
# image_from_local = Image.from_local("path/to/local/image.png") # 从本地文件创建 Image 实例
# image_from_local.display_info()

# image_from_url = Image.from_url("http://example.com/image.png") # 从 URL 创建 Image 实例
# image_from_url.display_info()

# # 解码并保存
# image_from_url.decode_image("path/to/save/decoded_image.png")

from pydantic import BaseModel, HttpUrl
import requests
import numpy as np
from typing import List, Optional

class Audio(BaseModel):
url: Optional[HttpUrl] = None # 音频的 URL
samples: Optional[List[int]] = None # 音频的样本数据
filename: Optional[str] = None # 文件名
filetype: Optional[str] = None # 文件类型 (如 'audio/mpeg', 'audio/wav')
size: Optional[int] = None # 文件大小(字节)

@classmethod
def from_local(cls, audio_path: str):
"""从本地文件创建 Audio 实例"""
with open(audio_path, "rb") as audio_file:
binary_data = audio_file.read()
# 假设音频是 16-bit PCM
samples = np.frombuffer(binary_data, dtype=np.int16).tolist()
filename = audio_path.split('/')[-1]
filetype = filename.split('.')[-1] # 简单提取文件扩展名
size = len(binary_data)

return cls(samples=samples, filename=filename, filetype=filetype, size=size)

@classmethod
def from_url(cls, url: HttpUrl):
"""从 URL 下载音频并创建 Audio 实例"""
response = requests.get(url)
if response.status_code == 200:
binary_data = response.content
# 假设音频是 16-bit PCM
samples = np.frombuffer(binary_data, dtype=np.int16).tolist()
filename = url.split('/')[-1]
filetype = filename.split('.')[-1] # 简单提取文件扩展名
size = len(binary_data)
return cls(url=url, samples=samples, filename=filename, filetype=filetype, size=size)
else:
raise Exception(f"Failed to download audio: {response.status_code}")

def to_binary(self) -> bytes:
"""将样本数据转换回二进制格式"""
return np.array(self.samples, dtype=np.int16).tobytes()

def pretty_repr(self, html: bool = False) -> List[str]:
"""返回音频的美观表示。
Args:
html: 是否返回 HTML 格式的字符串。
默认值为 False。
Returns:
音频的美观表示。
"""
lines = [
f"URL: {self.url}" if self.url else "URL: None",
f"Filename: {self.filename}" if self.filename else "Filename: None",
f"Filetype: {self.filetype}" if self.filetype else "Filetype: None",
f"Size: {self.size} bytes" if self.size is not None else "Size: None"
]

return lines

# 使用示例
# audio_from_local = Audio.from_local("path/to/local/audio.wav") # 从本地文件创建 Audio 实例
# print(audio_from_local.pretty_repr()) # 以文本形式打印信息
# print(audio_from_local.pretty_repr(html=True)) # 以 HTML 格式打印信息

# audio_from_url = Audio.from_url("http://example.com/audio.wav") # 从 URL 创建 Audio 实例
# print(audio_from_url.pretty_repr())
# print(audio_from_url.pretty_repr(html=True))

# # 将样本数据保存为音频文件
# with open("path/to/save/decoded_audio.wav", "wb") as audio_file:
# audio_file.write(audio_from_url.to_binary())


class MultiModalMessage(BaseMessage):
image: Image = None
audio: Audio = None
"""The type of the message (used for deserialization). Defaults to "ai"."""

def __init__(
self, content: Union[str, list[Union[str, dict]]],image: Image =None,audio: Audio = None, **kwargs: Any
) -> None:
"""Pass in content as positional arg.
Args:
content: The content of the message.
kwargs: Additional arguments to pass to the parent class.
"""
super().__init__(content=content, **kwargs)
self.audio = audio
self.image = image

@classmethod
def get_lc_namespace(cls) -> list[str]:
"""Get the namespace of the langchain object.
Returns:
The namespace of the langchain object.
Defaults to ["langchain", "schema", "messages"].
"""
return ["langchain", "schema", "messages"]

@property
def lc_attributes(self) -> dict:
"""Attrs to be serialized even if they are derived from other init args."""
return {
"image": self.image,
"audio": self.audio,
}

def pretty_repr(self, html: bool = False) -> str:
"""Return a pretty representation of the message.
Args:
html: Whether to return an HTML-formatted string.
Defaults to False.
Returns:
A pretty representation of the message.
"""
base = super().pretty_repr(html=html)
lines = self.image.pretty_repr()
lines.extend(self.audio.pretty_repr())

return (base.strip() + "\n" + "\n".join(lines)).strip()

class CustomerLLM(RunnableSerializable[BaseMessage]):
device: str = Field(torch.device('cpu'))
model: Any = None
tokenizer: Any = None
Expand Down
42 changes: 17 additions & 25 deletions agi/llms/image2image.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,11 @@
import torch
from langchain.llms.base import LLM
from typing import Any, List, Mapping, Optional,Union
from langchain.callbacks.manager import (
CallbackManagerForLLMRun
)
from pydantic import Field
from agi.llms import CustomerLLM
from agi.llms.base import CustomerLLM,MultiModalMessage,Image
from agi.config import MODEL_PATH as model_root

from langchain_core.runnables import RunnableConfig

style = 'style="width: 100%; max-height: 100vh;"'

Expand Down Expand Up @@ -54,28 +52,17 @@ def _llm_type(self) -> str:
def model_name(self) -> str:
return "image2image"

def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
output = ""
if prompt == "":
def invoke(
self, input: MultiModalMessage, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> MultiModalMessage:
output = MultiModalMessage(content="")
if input.content == "" or input.image is None:
return output

image_path = kwargs.pop("image_path","")
image_obj = kwargs.pop("image_obj",None)
print(image_path)
image = None
if image_path != "":
init_image = load_image(image_path).resize((512, 512))
print(type(init_image))
image = self.model(prompt, image=init_image, num_inference_steps=2, strength=0.5, guidance_scale=0.0).images[0]
if image_obj is not None:
image_obj.resize((512, 512))
image = self.model(prompt, image=image_obj, num_inference_steps=2, strength=0.5, guidance_scale=0.0).images[0]
prompt = input.content
input_image = input.image.pil_image.resize((512, 512))
image = self.model(prompt, image=input_image, num_inference_steps=2, strength=0.5, guidance_scale=0.0).images[0]

if image is not None:
output = self.handle_output(image,prompt)
Expand All @@ -87,14 +74,18 @@ def get_inputs(self,prompt:str,batch_size=1):

return {"prompt": prompts, "generator": generator, "num_inference_steps": self.n_steps}

def handle_output(self,image,prompt):
def handle_output(self,image,prompt) -> MultiModalMessage:
img = Image()
img.pil_image =image
output = MultiModalMessage(image=image)
if self.save_image:
file = f'{date.today().strftime("%Y_%m_%d")}/{int(time.time())}' # noqa: E501
output_file = Path(f"{self.file_path}/{file}.png")
output_file.parent.mkdir(parents=True, exist_ok=True)

image.save(output_file)
image_source = f"file/{output_file}"
output.image = Image.new(image_source)
else:
# resize image to avoid huge logs
image.thumbnail((512, 512 * image.height / image.width))
Expand All @@ -109,7 +100,8 @@ def handle_output(self,image,prompt):

formatted_result = f'<img src="{image_source}" {style}>\n'
formatted_result += f'<p> {prompt} </p>'
return formatted_result
output.content = formatted_result
return output

@property
def _identifying_params(self) -> Mapping[str, Any]:
Expand Down

0 comments on commit 36ae0ff

Please sign in to comment.