forked from pixegami/rag-tutorial-v2
Commit c3138f3 (1 parent: 1c98fd5)

Showing 10 changed files with 297 additions and 4 deletions.
Binary file not shown.
@@ -1,5 +1,4 @@
pypdf
langchain
chromadb  # Vector storage
pytest
boto3
PyPDF2
faiss-gpu
sentence-transformers
@@ -0,0 +1,58 @@
import os

import numpy as np
import PyPDF2
from sentence_transformers import SentenceTransformer


def extract_text_from_pdf(pdf_path):
    with open(pdf_path, 'rb') as file:
        reader = PyPDF2.PdfReader(file)
        text = ''
        for page in reader.pages:
            # extract_text() can return an empty string for image-only pages
            text += page.extract_text() or ''
    return text


def chunk_text(text, max_token_length=512):
    """Greedily pack paragraphs into chunks of at most max_token_length words."""
    paragraphs = text.split("\n\n")
    chunks = []
    current_chunk = []
    current_length = 0

    for para in paragraphs:
        para_length = len(para.split())
        if current_length + para_length <= max_token_length:
            current_chunk.append(para)
            current_length += para_length
        else:
            if current_chunk:  # avoid emitting an empty chunk on an oversized first paragraph
                chunks.append(" ".join(current_chunk))
            current_chunk = [para]
            current_length = para_length

    if current_chunk:
        chunks.append(" ".join(current_chunk))

    return chunks


def vectorize_text_chunks(chunks, model: SentenceTransformer = None):
    # Instantiate the default model lazily rather than in the signature, so
    # importing this module does not trigger a model load
    if model is None:
        model = SentenceTransformer('all-MiniLM-L6-v2')
    vectors = model.encode(chunks)
    return vectors


def process_pdf(pdf_file, index, metadata_list, processed_files):
    print(f"Processing: {pdf_file}")
    pdf_text = extract_text_from_pdf(pdf_file)
    chunks = chunk_text(pdf_text)
    vectors = vectorize_text_chunks(chunks)

    # Add vectors to the FAISS index (FAISS expects float32)
    vectors_np = np.asarray(vectors, dtype='float32')
    index.add(vectors_np)

    # Create metadata for each chunk; note that "page" records the 1-based
    # chunk index, not the PDF page number
    metadata_list.extend([{
        "text": chunk,
        "file_name": os.path.basename(pdf_file),
        "page": i + 1
    } for i, chunk in enumerate(chunks)])

    # Mark the file as processed
    processed_files[os.path.basename(pdf_file)] = True
    return processed_files, metadata_list
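A minimal sketch of how these helpers fit together for a single file; "example.pdf" is a placeholder path, and 384 is the output dimension of the default all-MiniLM-L6-v2 model:

import faiss

# Hypothetical one-file run of the pipeline above
index = faiss.IndexFlatL2(384)  # 384 = all-MiniLM-L6-v2 embedding size
processed, metadata = process_pdf("example.pdf", index, metadata_list=[], processed_files={})
print(index.ntotal, "vectors indexed")          # one vector per chunk
print(metadata[0]["file_name"], metadata[0]["text"][:80])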
Binary file not shown.
@@ -0,0 +1,52 @@
import argparse
import os

from sentence_transformers import SentenceTransformer

from extract_from_pdf import process_pdf
from utils import (
    load_processed_files,
    load_faiss_index,
    save_faiss_index,
    save_processed_files,
    load_metadata,
    save_metadata,
)


def main(args):
    print(args)
    processed_files = load_processed_files(args.processed_file_path)

    # Load existing metadata if the metadata file exists
    metadata_list = load_metadata(args.metadata_file)

    # Initialize the sentence transformer model.
    # NOTE: process_pdf() encodes with its own default model
    # ('all-MiniLM-L6-v2'), so this instance is only used to probe the vector
    # dimension; --embedding_model must match that default for the index to
    # stay consistent
    model = SentenceTransformer(args.embedding_model)

    # Initialize the FAISS index (or load it from disk)
    sample_vector = model.encode(["sample text"])[0]  # probe vector dimensions
    vector_dim = len(sample_vector)
    index = load_faiss_index(vector_dim)

    # Iterate through all PDF files in the directory, skipping already-processed ones
    for pdf_file in os.listdir(args.knowledge_dir):
        if pdf_file.endswith(".pdf") and pdf_file not in processed_files:
            full_path = os.path.join(args.knowledge_dir, pdf_file)
            processed_files, metadata_list = process_pdf(full_path, index, metadata_list, processed_files)

    # Save the updated FAISS index
    save_faiss_index(index)

    # Save the updated metadata
    save_metadata(metadata_list, args.metadata_file)

    # Save the list of processed files
    save_processed_files(processed_files, args.processed_file_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--knowledge_dir', type=str, default='/scratch/mrahma45/nrl_rag_pipeline/rag-infoex/data')
    parser.add_argument('--processed_file_path', type=str, default='./processed_files.json')
    parser.add_argument('--metadata_file', type=str, default='./metadata.json')
    parser.add_argument('--embedding_model', type=str, default='all-MiniLM-L6-v2')

    args = parser.parse_args()
    main(args)
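For reference, a hedged sketch of driving main() programmatically with the same defaults; the knowledge_dir value here is a placeholder:

from argparse import Namespace

# Mirrors the argparse defaults above
main(Namespace(
    knowledge_dir="./data",                       # placeholder PDF directory
    processed_file_path="./processed_files.json",
    metadata_file="./metadata.json",
    embedding_model="all-MiniLM-L6-v2",
))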
@@ -0,0 +1,130 @@
import os
import json

import numpy as np
import PyPDF2
import faiss
from sentence_transformers import SentenceTransformer

# Imports needed by the active code below; the original script relied on the
# now commented-out local definitions instead
from extract_from_pdf import process_pdf
from utils import (
    load_processed_files,
    save_processed_files,
    load_faiss_index,
    save_faiss_index,
)

# Directory containing PDF files
PDF_DIR = '/scratch/mrahma45/nrl_rag_pipeline/rag-infoex/data'

# File to store metadata of processed PDFs
PROCESSED_FILES = 'processed_files.json'

# FAISS index file
FAISS_INDEX_FILE = 'faiss_index.idx'

# JSON file to store metadata
METADATA_FILE = 'metadata.json'

# Initialize the sentence transformer model
model = SentenceTransformer('all-MiniLM-L6-v2')

# NOTE: the commented-out helpers below were superseded by the versions in
# extract_from_pdf.py and utils.py and are kept here for reference.

# def load_processed_files():
#     if os.path.exists(PROCESSED_FILES):
#         with open(PROCESSED_FILES, 'r') as f:
#             return json.load(f)
#     return {}

# def save_processed_files(processed_files):
#     with open(PROCESSED_FILES, 'w') as f:
#         json.dump(processed_files, f)

# def extract_text_from_pdf(pdf_path):
#     with open(pdf_path, 'rb') as file:
#         reader = PyPDF2.PdfReader(file)
#         text = ''
#         for page in range(len(reader.pages)):
#             text += reader.pages[page].extract_text()
#         return text

# def chunk_text(text, max_token_length=512):
#     paragraphs = text.split("\n\n")
#     chunks = []
#     current_chunk = []
#     current_length = 0
#
#     for para in paragraphs:
#         para_length = len(para.split())
#         if current_length + para_length <= max_token_length:
#             current_chunk.append(para)
#             current_length += para_length
#         else:
#             chunks.append(" ".join(current_chunk))
#             current_chunk = [para]
#             current_length = para_length
#
#     if current_chunk:
#         chunks.append(" ".join(current_chunk))
#
#     return chunks

# def vectorize_text_chunks(chunks):
#     vectors = model.encode(chunks)
#     return vectors

# def create_faiss_index(vector_dim):
#     index = faiss.IndexFlatL2(vector_dim)
#     return index

# def load_faiss_index(vector_dim):
#     if os.path.exists(FAISS_INDEX_FILE):
#         index = faiss.read_index(FAISS_INDEX_FILE)
#     else:
#         index = create_faiss_index(vector_dim)
#     return index

# def save_faiss_index(index):
#     faiss.write_index(index, FAISS_INDEX_FILE)

# def process_pdf(pdf_file, index, metadata_list, processed_files):
#     print(f"Processing: {pdf_file}")
#     pdf_text = extract_text_from_pdf(pdf_file)
#     chunks = chunk_text(pdf_text)
#     vectors = vectorize_text_chunks(chunks)
#
#     # Add vectors to FAISS index
#     vectors_np = np.array(vectors)
#     index.add(vectors_np)
#
#     # Create metadata for each chunk
#     metadata_list.extend([{
#         "text": chunk,
#         "file_name": os.path.basename(pdf_file),
#         "page": i+1
#     } for i, chunk in enumerate(chunks)])
#
#     # Mark the file as processed
#     processed_files[os.path.basename(pdf_file)] = True


def process_pdfs_in_directory(directory):
    processed_files = load_processed_files(PROCESSED_FILES)

    # Check if the metadata file exists and load it; fall back to an empty
    # list so metadata_list is always defined
    if os.path.exists(METADATA_FILE):
        with open(METADATA_FILE, 'r') as f:
            metadata_list = json.load(f)
    else:
        metadata_list = []

    # Initialize the FAISS index (or load it from disk)
    sample_vector = model.encode(["sample text"])[0]  # probe vector dimensions
    vector_dim = len(sample_vector)
    index = load_faiss_index(vector_dim)

    # Iterate through all PDF files in the directory
    for pdf_file in os.listdir(directory):
        if pdf_file.endswith(".pdf") and pdf_file not in processed_files:
            full_path = os.path.join(directory, pdf_file)
            process_pdf(full_path, index, metadata_list, processed_files)

    # Save the updated FAISS index
    save_faiss_index(index)

    # Save the updated metadata
    with open(METADATA_FILE, 'w') as f:
        json.dump(metadata_list, f)

    # Save the list of processed files
    save_processed_files(processed_files, PROCESSED_FILES)


# Run the script
process_pdfs_in_directory(PDF_DIR)
@@ -0,0 +1 @@
{"monopoly.pdf": true, "ticket_to_ride.pdf": true, "mushfiq_CV.pdf": true}
@@ -0,0 +1,12 @@
import faiss
import numpy as np


def create_faiss_index(vectors):
    vectors_np = np.array(vectors)         # convert to a NumPy array
    vector_dim = vectors_np.shape[1]       # dimensionality of the vectors
    index = faiss.IndexFlatL2(vector_dim)  # L2 distance for similarity search
    index.add(vectors_np)                  # add the vectors to the index
    return index


# Create the FAISS index; `chunk_vectors` is expected to come from an earlier
# vectorization step (e.g. vectorize_text_chunks) and is not defined in this file
faiss_index = create_faiss_index(chunk_vectors)
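A short query-side sketch against the index built above, assuming the stored vectors were produced by the same all-MiniLM-L6-v2 model:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer('all-MiniLM-L6-v2')
query_vec = model.encode(["example query"]).astype('float32')  # shape (1, dim)
distances, ids = faiss_index.search(query_vec, 3)              # top-3 nearest chunks
print(ids[0], distances[0])  # row indices into the indexed chunk list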
@@ -0,0 +1,40 @@
import os
import json

import faiss


def load_processed_files(file_path: str):
    if os.path.exists(file_path):
        with open(file_path, 'r') as f:
            return json.load(f)
    return {}


def load_metadata(metadata_file: str):
    metadata_list = load_processed_files(file_path=metadata_file)
    # A missing file yields {}, so only return the payload when it is a list
    if isinstance(metadata_list, list):
        return metadata_list
    return []


def save_processed_files(processed_dict: dict, file_path: str) -> None:
    with open(file_path, 'w') as f:
        json.dump(processed_dict, f)


def save_metadata(metadata_list: list, metadata_file: str) -> None:
    with open(metadata_file, 'w') as f:
        json.dump(metadata_list, f)


def create_faiss_index(vector_dim):
    index = faiss.IndexFlatL2(vector_dim)
    return index


def load_faiss_index(vector_dim, FAISS_INDEX_FILE: str = "faiss_index.idx"):
    if os.path.exists(FAISS_INDEX_FILE):
        index = faiss.read_index(FAISS_INDEX_FILE)
    else:
        index = create_faiss_index(vector_dim)
    return index


def save_faiss_index(index, FAISS_INDEX_FILE: str = "faiss_index.idx"):
    faiss.write_index(index, FAISS_INDEX_FILE)
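A minimal persistence round-trip using these helpers; file names follow the defaults above, and the entry added is hypothetical:

index = load_faiss_index(vector_dim=384)  # 384 = all-MiniLM-L6-v2 output size
save_faiss_index(index)                   # writes faiss_index.idx

files = load_processed_files("processed_files.json")
files["example.pdf"] = True               # hypothetical entry
save_processed_files(files, "processed_files.json")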