-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f2ffad9
commit 75599bd
Showing
12 changed files
with
595 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,47 @@ | ||
# KoAP_RF | ||
Бот для ответов на вопросы по кодексу Российской Федерации об административных правонарушениях. | ||
# KoAP_RF | ||
|
||
Бот отвечает на вопросы по документу "Кодекс Российской Федерации об административных правонарушениях" от 30.12.2001 N 195-ФЗ (ред. от 08.08.2024) (с изм. и доп., вступ. в силу с 08.09.2024). | ||
|
||
## Установка | ||
|
||
Для работы скриптов необходимо установить зависимости из файла `requirements.txt`: | ||
|
||
```bash | ||
pip install -r requirements.txt | ||
``` | ||
|
||
## Использование | ||
|
||
### 1. Создание базы | ||
|
||
Для создания базы данных и загрузки документа в коллекцию ChromaDB выполните следующую команду: | ||
|
||
```bash | ||
python ingest.py | ||
``` | ||
|
||
### Опции для `run.py` | ||
|
||
Скрипт `run.py` поддерживает несколько опций для настройки его поведения: | ||
|
||
- `--show_sources`, `-s`: Показывать источники вместе с ответами (по умолчанию False). | ||
- Пример использования: | ||
```bash | ||
python run.py --show_sources | ||
``` | ||
- Или: | ||
```bash | ||
python run.py -s | ||
``` | ||
|
||
- `--save_qa`: Сохранять пары вопросов и ответов в CSV-файл (по умолчанию False). | ||
- Пример использования: | ||
```bash | ||
python run.py --save_qa | ||
``` | ||
|
||
### Пример использования с обоими флагами | ||
|
||
```bash | ||
python run.py --show_sources --save_qa | ||
``` |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
client_id: | ||
client_secret: | ||
authorization_data: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import os | ||
from chromadb.config import Settings | ||
|
||
# Directory containing this file; the project paths below are built from it.
ROOT_DIRECTORY = os.path.dirname(os.path.realpath(__file__))

# Path to the source document.
SOURCE_DOCUMENT = f"{ROOT_DIRECTORY}/SOURCE_DOCUMENTS/KoAP_RF.docx"

# Directory where the database is persisted.
PERSIST_DIRECTORY = f"{ROOT_DIRECTORY}/DB"

# Settings for ChromaDB.
CHROMA_SETTINGS = Settings(
    anonymized_telemetry=False,
    is_persistent=True,
    persist_directory=PERSIST_DIRECTORY,
)

# Name of the collection inside the database.
COLLECTION_NAME = "KoAP_RF_test_v4.1"

# Maximum number of tokens in a single chunk.
MAX_N_TOKENS = 512

# Stride used when tokenizing into chunks.
STRIDE = 128

# Name of the model used to generate embeddings.
EMBEDDING_MODEL_NAME = "ai-forever/ru-en-RoSBERTa"
# Alternative checkpoint: "deepvk/roberta-base"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
from constants import EMBEDDING_MODEL_NAME | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
from transformers import AutoTokenizer, AutoModel | ||
|
||
from typing import List, Tuple | ||
|
||
def pool(hidden_state, mask, pooling_method="cls"):
    """Reduce per-token hidden states to one vector per sequence.

    Args:
        hidden_state: Tensor indexed as (batch, token, feature).
        mask: Tensor indexed as (batch, token); non-zero entries mark the
            tokens that take part in mean pooling. Unused for "cls".
        pooling_method: "cls" returns the hidden state at token position 0;
            "mean" returns the mask-weighted average over the token axis.

    Returns:
        Tensor indexed as (batch, feature).

    Raises:
        ValueError: If ``pooling_method`` is neither "cls" nor "mean".
    """
    if pooling_method == "cls":
        return hidden_state[:, 0]

    if pooling_method == "mean":
        weights = mask.unsqueeze(-1).float()
        summed = torch.sum(hidden_state * weights, dim=1)
        # Clamp so that a row whose mask is all zeros yields zeros, not NaN.
        counts = mask.sum(dim=1, keepdim=True).float().clamp(min=1e-9)
        return summed / counts

    # Previously an unknown method fell through and silently returned None.
    raise ValueError(
        f"Unknown pooling_method {pooling_method!r}; expected 'cls' or 'mean'"
    )
|
||
class EmbeddingGenerator:
    """Split text into overlapping token chunks and embed each chunk."""

    def __init__(self):
        # Tokenizer and encoder are both loaded from EMBEDDING_MODEL_NAME.
        self.tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL_NAME)
        self.model = AutoModel.from_pretrained(EMBEDDING_MODEL_NAME)

    def token_chunker(self, context, max_chunk_size=512, stride=128, min_chunk_len=50):
        """Tokenize ``context`` once and cut it into overlapping windows.

        Args:
            context: Text to split.
            max_chunk_size: Maximum number of tokens in one window.
            stride: Number of tokens shared by two consecutive windows; each
                window starts ``max_chunk_size - stride`` tokens after the
                previous one.
            min_chunk_len: Windows with fewer tokens than this are dropped.

        Returns:
            A list of dicts, one per kept window, with keys:
            "token_ids" (int32 tensor of shape (1, n_tokens)),
            "context" (the window decoded back to text) and
            "attention_mask" (int32 tensor of ones, shape (1, n_tokens)).

        Raises:
            ValueError: If ``stride >= max_chunk_size``; the window would
                never advance and the loop would not terminate.
        """
        step = max_chunk_size - stride
        if step <= 0:
            raise ValueError(
                f"stride ({stride}) must be smaller than max_chunk_size ({max_chunk_size})"
            )

        # Special tokens are not added, so windows contain text tokens only.
        context_tokens = self.tokenizer.encode(context, add_special_tokens=False)
        total = len(context_tokens)

        chunk_holder = []
        current_pos = 0
        while current_pos < total:
            end_point = min(current_pos + max_chunk_size, total)
            window = context_tokens[current_pos:end_point]

            # Keep only windows that are long enough to be useful.
            if len(window) >= min_chunk_len:
                chunk_holder.append(
                    {
                        "token_ids": torch.tensor(window, dtype=torch.int32).unsqueeze(0),
                        "context": self.tokenizer.decode(window, skip_special_tokens=True),
                        "attention_mask": torch.ones((1, len(window)), dtype=torch.int32),
                    }
                )

            # Once a window has reached the last token, any further window
            # would be fully contained in it, so stop here.
            if end_point == total:
                break
            current_pos += step

        return chunk_holder

    def get_embeddings(
        self, context: str, max_length: int = 512, overlap: int = 128, min_chunk_len: int = 10
    ) -> Tuple[List[str], List[List[float]]]:
        """Chunk ``context`` and compute one embedding per chunk.

        Args:
            context: Text to embed.
            max_length: Maximum number of tokens per chunk.
            overlap: Number of tokens shared by consecutive chunks.
            min_chunk_len: Chunks with fewer tokens than this are skipped.

        Returns:
            ``(documents, embeddings)``: the decoded text of every chunk and,
            in the same order, its embedding as a list of floats. Both lists
            are empty when no chunk is long enough.
        """
        chunks = self.token_chunker(
            context, max_chunk_size=max_length, stride=overlap, min_chunk_len=min_chunk_len
        )

        embeddings = []
        documents = []
        for chunk in chunks:
            input_ids = chunk["token_ids"]
            attention_mask = chunk["attention_mask"]

            # Inference only: no gradients are needed.
            with torch.no_grad():
                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)

            # NOTE(review): chunks are tokenized with add_special_tokens=False,
            # so position 0 used by "cls" pooling is the first text token, not
            # a dedicated classification token - confirm this is intended.
            chunk_embeddings = pool(
                outputs.last_hidden_state,
                attention_mask,
                pooling_method="cls",  # "mean" is the alternative
            )
            embeddings.append(chunk_embeddings)
            documents.append(chunk["context"])

        # torch.cat cannot take an empty list, so return early.
        if not embeddings:
            return [], []

        # Stack the per-chunk (1, feature) tensors along the first axis.
        stacked = torch.cat(embeddings, dim=0)
        return documents, stacked.tolist()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import json | ||
import requests | ||
|
||
|
||
def get_gpt_response(content, access_token, timeout=60):
    """Send one user message to the GigaChat chat-completions endpoint.

    Args:
        content: Text placed in the "user" message.
        access_token: Bearer token sent in the Authorization header.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout the request could block indefinitely.

    Returns:
        The "content" field of the first choice in the JSON response.

    Raises:
        requests.exceptions.RequestException: On connection problems or when
            ``timeout`` expires.
        json.JSONDecodeError, KeyError, IndexError: When the response body is
            not JSON or lacks the expected "choices" structure.
    """
    url = "https://gigachat.devices.sberbank.ru/api/v1/chat/completions"

    system_prompt = "Ты профессиональный юрист. Ответь на вопрос пользователя используя статью нормативного акта. Ничего не придумывай. Не пиши ничего лишнего. Если не знаешь ответа, то напиши «не знаю»."

    payload = json.dumps(
        {
            "model": "GigaChat",
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": content},
            ],
        }
    )

    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",
        "Authorization": f"Bearer {access_token}",
    }

    # The certificate bundle path is relative to the current working directory.
    response = requests.post(
        url,
        headers=headers,
        data=payload,
        verify="./russiantrustedca.pem",
        timeout=timeout,
    )
    # NOTE(review): the HTTP status is not checked; an error reply surfaces
    # below as a parsing/lookup error rather than an HTTP error.
    return json.loads(response.text)["choices"][0]["message"]["content"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
import re | ||
import chromadb | ||
from docx import Document | ||
|
||
from constants import ( | ||
CHROMA_SETTINGS, | ||
SOURCE_DOCUMENT, | ||
COLLECTION_NAME, | ||
MAX_N_TOKENS, | ||
STRIDE, | ||
) | ||
|
||
from embedder import EmbeddingGenerator | ||
|
||
from tqdm import tqdm | ||
from typing import List, Tuple | ||
import logging | ||
|
||
def get_collection(chroma_client, collection_name: str):
    """Return the collection named ``collection_name``, creating it if absent.

    Args:
        chroma_client: Client object used to list, fetch and create collections.
        collection_name: Name of the collection.

    Returns:
        The existing collection when one with that name is listed by the
        client, otherwise a newly created one.
    """
    # Look for an already existing collection with the requested name.
    for existing in chroma_client.list_collections():
        if existing.name == collection_name:
            logging.info(f"Коллекция '{collection_name}' уже существует.")
            return chroma_client.get_collection(name=collection_name)

    # Nothing matched: create it.
    created = chroma_client.create_collection(name=collection_name)
    logging.info(f"Коллекция '{collection_name}' успешно создана.")
    return created
|
||
|
||
def load_documents(doc_path: str) -> Tuple[List[str], List[str]]:
    """Load a .docx file and split it into segments at "Статья ..." headings.

    A paragraph whose text starts with the heading pattern opens a new
    segment; following paragraphs are appended to it, joined by newlines.
    Empty paragraphs are ignored, and text before the first heading is
    discarded.

    Args:
        doc_path: Path to the document.

    Returns:
        ``(text_segments, text_ids)``: the segment texts with the leading
        heading match removed, and the matched heading strings in the same
        order.
    """
    doc = Document(doc_path)

    pattern = re.compile(r"^Статья \d+(\.\d+){0,2}(\-\d+)?\.?")

    text_segments = []
    text_ids = []

    segment_name = ""
    segment_txt = ""

    for paragraph in doc.paragraphs:
        para = paragraph.text.strip()
        if not para:
            continue

        match = pattern.match(para)
        if match is None:
            # Continuation of the current segment.
            segment_txt += "\n" + para
            continue

        # A new heading closes the segment collected so far.
        if segment_name:
            text_segments.append(pattern.sub("", segment_txt).strip())
            text_ids.append(segment_name)
        segment_name = match.group(0)
        segment_txt = para

    # Flush the final segment: no later heading exists to trigger it, and
    # without this step the last article of the document was lost.
    if segment_name:
        text_segments.append(pattern.sub("", segment_txt).strip())
        text_ids.append(segment_name)

    return text_segments, text_ids
|
||
|
||
def main():
    """Load the source document, embed its segments and store them.

    Steps:
        1. Create the client and get or create the collection.
        2. Load the document and split it into segments.
        3. Generate embeddings for the chunks of every segment.
        4. Add the chunk texts and embeddings to the collection.
    """
    client = chromadb.Client(CHROMA_SETTINGS)
    collection = get_collection(client, COLLECTION_NAME)

    logging.info(f"Загрузка документа {SOURCE_DOCUMENT}")
    texts, ids = load_documents(SOURCE_DOCUMENT)
    texts_len = len(texts)
    logging.info(f"Разделено на {texts_len} частей текста")

    logging.info("Добавление строк и эмбеддингов в коллекцию")
    embed_generator = EmbeddingGenerator()

    # zip() has no length, so pass the total explicitly for the progress bar.
    for texts_chunk, ids_chunk in tqdm(
        zip(texts, ids), total=texts_len, desc="Обработка частей"
    ):
        documents_chunk, text_embeds_chunk = embed_generator.get_embeddings(
            texts_chunk, MAX_N_TOKENS, STRIDE
        )

        # get_embeddings returns empty lists when no chunk is long enough;
        # do not pass an empty batch to the collection.
        # NOTE(review): chromadb is expected to reject empty id lists - confirm.
        if not text_embeds_chunk:
            logging.warning(f"Пропуск '{ids_chunk}': нет чанков достаточной длины")
            continue

        # Each chunk id is the segment heading plus a running part index.
        collection.add(
            documents=documents_chunk,
            ids=[ids_chunk + f"_part{pi}" for pi in range(len(text_embeds_chunk))],
            embeddings=text_embeds_chunk,
        )
    logging.info("Готово")
|
||
|
||
if __name__ == "__main__":
    # Configure INFO-level logging before starting the ingestion run.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)s - %(message)s",
    )
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
python-docx==1.1.2 | ||
chromadb==0.5.5 | ||
torch==2.2.1 | ||
transformers==4.40.0 |
Oops, something went wrong.