Commit

fix
neoguojing committed Oct 28, 2024
1 parent 4c68e39 commit 6bf3275
Showing 4 changed files with 93 additions and 19 deletions.
16 changes: 7 additions & 9 deletions agi/llms/base.py
@@ -61,6 +61,7 @@ def pretty_repr(self) -> List[str]:

class Audio(BaseModel):
url: Optional[HttpUrl] = None  # URL of the audio
file_path: Optional[str] = None
samples: Optional[List[int]] = None  # raw audio sample data
filename: Optional[str] = None  # file name
filetype: Optional[str] = None  # file type (e.g. 'audio/mpeg', 'audio/wav')
@@ -77,7 +78,7 @@ def from_local(cls, audio_path: str):
filetype = filename.split('.')[-1]  # naive extraction of the file extension
size = len(binary_data)

return cls(samples=samples, filename=filename, filetype=filetype, size=size)
return cls(samples=samples,file_path=audio_path, filename=filename, filetype=filetype, size=size)

@classmethod
def from_url(cls, url: HttpUrl):
@@ -174,17 +175,14 @@ def pretty_repr(self, html: bool = False) -> str:
return (base.strip() + "\n" + "\n".join(lines)).strip()


class CustomerLLM(RunnableSerializable[BaseMessage,BaseMessage]):
device: str = Field(torch.device('cpu'))
class CustomerLLM(RunnableSerializable[BaseMessage, BaseMessage]):
device: str = Field(default_factory=lambda: str(torch.device('cpu')))
model: Any = None
tokenizer: Any = None

def __init__(self,llm,**kwargs):
super(CustomerLLM, self).__init__()
if torch.cuda.is_available():
self.device = torch.device("cuda")
else:
self.device = torch.device('cpu')
def __init__(self, llm, **kwargs):
super(CustomerLLM, self).__init__(**kwargs)
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device('cpu')
self.model = llm

def destroy(self):
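The net effect in base.py is twofold: Audio.from_local now records the source path in the new file_path field, and CustomerLLM picks its device inside __init__ (forwarding **kwargs to the parent constructor) rather than relying on a class-level default. A minimal sketch of how the new field might be consumed, assuming a placeholder clip at sample.wav (not part of the commit):

from agi.llms.base import Audio

audio = Audio.from_local("sample.wav")  # placeholder path; any local audio file
print(audio.filename, audio.filetype, audio.size)
print(audio.file_path)  # now populated, so file-based consumers can reuse the original clip
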
29 changes: 19 additions & 10 deletions agi/llms/tts.py
@@ -1,34 +1,43 @@
import os

from TTS.api import TTS
from agi.config import MODEL_PATH as model_root
from agi.config import MODEL_PATH as model_root,CACHE_DIR
from agi.llms.base import CustomerLLM,MultiModalMessage,Audio
from langchain_core.runnables import RunnableConfig
from typing import Any, List, Mapping, Optional,Union

class TextToSpeech(CustomerLLM):
def __init__(self, model_path: str = os.path.join(model_root,"tts_models--zh-CN--baker--tacotron2-DDC-GST"),
def __init__(self, model_path: str = os.path.join(model_root,"tts_models--multilingual--multi-dataset--xtts_v2"),
speaker_wav: str = os.path.join(model_root,"XTTS-v2","samples/zh-cn-sample.wav"),
language: str = "zh-cn"):
language: str = "zh-cn",save_file: bool = True):
config_path = os.path.join(model_path,"config.json")
self.tts = TTS(model_path=model_path,config_path=config_path).to(self.device)
tts = TTS(model_path=model_path,config_path=config_path)
# self.tts = TTS(model_name="tts_models--zh-CN--baker--tacotron2-DDC-GST").to(self.device)

super(TextToSpeech, self).__init__(llm=tts.synthesizer)
self.tts = tts.to(self.device)
self.speaker_wav = speaker_wav
self.language = language
self.model = self.tts.synthesizer
super(TextToSpeech, self).__init__(llm=self.model)
self.save_file = save_file

def list_available_models(self):
return self.tts.list_models()

def invoke(
self, input: MultiModalMessage, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> MultiModalMessage:
samples = self.tts.tts(text=input.content, speaker_wav=self.speaker_wav, language=self.language)
return MultiModalMessage(content=input.content,audio=Audio(samples=samples))
output = None
if self.save_file:
file_path = self.save_audio_to_file(text=input.content)
output = MultiModalMessage(content=input.content,audio=Audio(file_path=file_path))
else:
samples = self.tts.tts(text=input.content, speaker_wav=self.speaker_wav, language=self.language)
output = MultiModalMessage(content=input.content,audio=Audio(samples=samples))
return output

def save_audio_to_file(self, text: str, file_path: str):
def save_audio_to_file(self, text: str, file_path: str=CACHE_DIR):
# self.tts.tts_to_file(text=text, speaker_wav=self.speaker_wav, language=self.language, file_path=file_path)
self.tts.tts_to_file(text=text, speaker_wav=self.speaker_wav, file_path=file_path)
return self.tts.tts_to_file(text=text, speaker_wav=self.speaker_wav, file_path=file_path)

# Usage example
if __name__ == "__main__":
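Taken together, the tts.py changes switch the default model to XTTS-v2, construct the TTS engine first and pass its synthesizer to the CustomerLLM constructor, and add a save_file flag: when it is set (the default), invoke writes the audio to disk and returns the path, otherwise it returns raw samples. A rough usage sketch, assuming the XTTS-v2 weights are present under MODEL_PATH (not part of the commit):

from agi.llms.tts import TextToSpeech
from agi.llms.base import MultiModalMessage

tts = TextToSpeech()  # save_file=True by default
msg = tts.invoke(MultiModalMessage(content="你好，世界"))  # Chinese input for the default zh-cn voice
print(msg.audio.file_path)  # path returned by save_audio_to_file

tts_in_memory = TextToSpeech(save_file=False)
msg = tts_in_memory.invoke(MultiModalMessage(content="你好，世界"))
print(len(msg.audio.samples))  # raw samples kept in memory instead
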
23 changes: 23 additions & 0 deletions tests/llms/test_speech2text.py
@@ -0,0 +1,23 @@
import unittest


class TestText2Image(unittest.TestCase):

def setUp(self):
from agi.llms.speech2text import Speech2Text
from agi.llms.base import MultiModalMessage,Audio
import torch
self.instance = Speech2Text()
self.input = MultiModalMessage(content="a midlife crisis man")

def test_image2image(self):
output = self.instance.invoke(self.input)
self.assertIsNotNone(output)
self.assertIsNotNone(output.image)
self.assertIsNotNone(output.image.pil_image)
self.assertIsNotNone(output.content)
print(output.content)


if __name__ == "__main__":
unittest.main()
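
As committed, this test keeps the class name, method name, and image assertions of the text-to-image test it appears to have been copied from, and it feeds plain text to a speech-to-text model. A hedged sketch of what a speech-to-text check might look like instead, assuming a local recording at the placeholder path tests/data/sample.wav and that Speech2Text accepts audio input via MultiModalMessage:

import unittest

class TestSpeech2Text(unittest.TestCase):

    def setUp(self):
        from agi.llms.speech2text import Speech2Text
        from agi.llms.base import MultiModalMessage, Audio
        self.instance = Speech2Text()
        # Placeholder recording; substitute any local audio file.
        self.input = MultiModalMessage(content="", audio=Audio.from_local("tests/data/sample.wav"))

    def test_speech2text(self):
        output = self.instance.invoke(self.input)
        self.assertIsNotNone(output)
        self.assertIsNotNone(output.content)  # transcribed text
        print(output.content)

if __name__ == "__main__":
    unittest.main()
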
44 changes: 44 additions & 0 deletions tests/llms/test_tts.py
@@ -0,0 +1,44 @@
import unittest


class TestTextToSpeech(unittest.TestCase):

def setUp(self):
from agi.llms.tts import TextToSpeech
from agi.llms.base import MultiModalMessage,Audio
import torch
self.instance = TextToSpeech()
print(self.instance.list_available_models())
content = '''
以下是每个缩写的简要解释:
hag: Hanga — 指的是一种语言,主要在巴布亚新几内亚的Hanga地区使用。
hnn: Hanunoo — 指的是菲律宾的一种语言,主要由Hanunoo人使用,属于马来-波利尼西亚语系。
bgc: Haryanvi — 指的是印度哈里亚纳邦的一种方言,属于印地语的一种变体。
had: Hatam — 指的是巴布亚新几内亚的一种语言,主要在Hatam地区使用。
hau: Hausa — 指的是西非的一种语言,广泛用于尼日利亚和尼日尔,是主要的交易语言之一。
hwc: Hawaii Pidgin — 指的是夏威夷的一种克里奥尔语,受英语和夏威夷土著语言影响,常用于当地的日常交流。
hvn: Hawu — 指的是印度尼西亚的一种语言,主要在西努沙登加拉省的Hawu地区使用。
hay: Haya — 指的是坦桑尼亚的一种语言,由Haya人使用,属于尼日尔-刚果语系。
'''
self.input = MultiModalMessage(content=content)

def test_text2speech(self):
output = self.instance.invoke(self.input)
self.assertIsNotNone(output)
self.assertIsNotNone(output.audio)
self.assertIsNotNone(output.audio.file_path)
self.assertIsNotNone(output.content)
print(output.content)
print(output.audio.file_path)


if __name__ == "__main__":
unittest.main()
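
Both new test modules can be run individually, e.g. python -m unittest tests.llms.test_tts (assuming the test directories are importable as packages), and both expect the corresponding model weights to be available under MODEL_PATH. The Chinese passage in test_tts.py is deliberate: it is the synthesis input for the default zh-cn voice.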
