
Commit

works
sdan committed Jun 30, 2023
1 parent 0bf5f3a commit 84873af
Showing 6 changed files with 33 additions and 82 deletions.
Binary file modified __pycache__/main.cpython-310.pyc
Binary file not shown.
Binary file modified __pycache__/model.cpython-310.pyc
Binary file not shown.
Binary file modified __pycache__/utils.cpython-310.pyc
Binary file not shown.
68 changes: 26 additions & 42 deletions main.py
@@ -12,21 +12,19 @@
import pickle
import torch

from sentence_transformers import SentenceTransformer, util

class VLite:
'''
vlite is a simple vector database that stores vectors in a numpy array.
'''
def __init__(self, collection='vlite.pkl'):
self.collection = collection
self.model = EmbeddingModel()
self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
# try:
# with open(self.collection, 'rb') as f:
# self.data = pickle.load(f)
# except FileNotFoundError:
self.text = {}

self.texts = []
self.metadata = {}
self.vectors = np.empty((0, 384))

@@ -36,13 +34,13 @@ def memorize(self, text, id=None, metadata=None):
chunks = chop_and_chunk(text)
for chunk in chunks:
encoded_data = self.model.embed(chunk)
encoded_data_bench = self.embedder.encode(chunk)

print("[+] Encoded:", encoded_data.shape)
print("[+] Encoded bench:", encoded_data_bench.shape)

self.text[chunk] = id
self.metadata[id] = metadata
self.texts.append(chunk)
self.metadata[len(self.texts) - 1] = metadata or {}
self.metadata[len(self.texts) - 1]['index'] = len(self.texts) - 1

self.vectors = np.vstack((self.vectors, encoded_data))

print("[+] Memorizing with ID:", id)
@@ -63,47 +61,33 @@ def remember(self, text=None, id=None, top_k=2):
sims = cos_sim(query, corpus)

print("[+] Similarities:", sims)

sims = sims.flatten()

top_k_idx = np.argsort(sims)[::-1][:top_k]  # use the top_k argument rather than a hardcoded 3
print("[+] Top indices:", top_k_idx)

# iterate over the top_k most similar sentences
for idx in top_k_idx:
print("[+] Index:", idx)
print("[+] Sentence:", self.texts[idx])
print("[+] Metadata:", self.metadata[idx])



def remember_bench(self, text=None):
# def remember_bench(self, query, corpus):

# query_embedding = self.embedder.encode(query)
# corpus_embeddings = self.embedder.encode(corpus)

query_embedding = self.model.embed(text)
# cos_scores = util.cos_sim(query_embedding, corpus_embeddings)[0]

print("[+] Query shape:", query_embedding.shape)
# print("[+] Cos scores:", cos_scores)

print("[+] Corpus shape:", self.vectors.shape)
# hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=5)

hits = util.semantic_search(query_embedding, self.vectors, top_k=5)
hits = hits[0]  # get the hits for the first query
for hit in hits:
print("(Score: {:.4f})".format(hit['score']))
# print("[+] Hits:", hits)

def save(self):
with open(self.collection, 'wb') as f:
# dump the live fields; self.data is never set after this refactor
pickle.dump((self.texts, self.metadata, self.vectors), f)
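
With save writing a (texts, metadata, vectors) tuple, the commented-out loader in __init__ could be revived to match — a sketch, not part of this commit:

try:
    with open(self.collection, 'rb') as f:
        self.texts, self.metadata, self.vectors = pickle.load(f)
except FileNotFoundError:
    self.texts, self.metadata = [], {}
    self.vectors = np.empty((0, 384))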



# def remove(self, text):
# '''
# Removes the sentence from the database.
# '''
# if text in self.id_to_index:
# index = self.id_to_index[text]
# self.vectors = np.delete(self.vectors, index, 0)
# del self.id_to_index[text]
# for key, value in self.id_to_index.items():
# if value > index:
# self.id_to_index[key] = value - 1
# else:
# print(f"Text: {text} not found in database.")

# def relevancy(self, query, top_k=5):
# '''
# Returns the top_k most relevant sentences in the corpus to the query.
# '''
# query_vector = self.model.embed(query)
# scores = np.dot(self.vectors, query_vector)
# top_k_indices = np.argsort(scores)[-top_k:]
# top_k_texts = {k: v for k, v in self.id_to_index.items() if v in top_k_indices}
# return top_k_texts
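
Taken together, the API after this commit reads roughly as follows — a usage sketch with placeholder texts, assuming main.py is importable:

from main import VLite

db = VLite()                                   # in-memory store, persisted by save()
db.memorize("Civil law systems rely on codified statutes.", id="doc1")
db.memorize("Common law develops through judicial precedent.", id="doc2")
db.remember(text="civil law", top_k=2)         # prints the closest stored chunks
db.save()                                      # writes vlite.pkl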
14 changes: 6 additions & 8 deletions test.py
@@ -18,16 +18,14 @@ def setUp(self):
self.metadata = {"name": "test"}

def test_memorize(self):
self.vlite.memorize(text=self.multiple_data)
print("[mem] test 1")
self.vlite.memorize(text=self.long_data)
print("[mem] test 2")
self.vlite.memorize(text=self.long_data_2)
print("[mem] test 3")
# self.vlite.remember_bench(text="civil law")
# print("[remember] test 4")
print("[test_memorize] memorized")
self.vlite.remember(text="civil law")
print("[remember] test 5")
print("[test_memorize] remembered")

# def test_sentence_transformers(self):
# self.vlite.remember_bench(query="civil law", corpus=self.long_data_2)
# print("[test_sentence_transformers] test 1")

if __name__ == '__main__':
unittest.main()
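
Since the file ends with unittest.main(), the suite runs directly via python test.py (or python -m unittest test).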
33 changes: 1 addition & 32 deletions utils.py
@@ -4,10 +4,6 @@
import pickle

def chop_and_chunk(text, max_seq_length=128):
'''
Chop text into chunks of max_seq_length.
'''

# text can be a string or a list of strings
if isinstance(text, str):
# if text is a string, create a list with the string as the only element
@@ -16,33 +12,6 @@ def chop_and_chunk(text, max_seq_length=128):


def cos_sim(vec, mat):
"""
Computes the cosine similarity between a query vector and each row of a matrix.
:return: Array with res[j] = cos_sim(vec, mat[j])
"""
# if not isinstance(a, torch.Tensor):
# a = torch.tensor(a)

# if not isinstance(b, torch.Tensor):
# b = torch.tensor(b)

# if len(a.shape) == 1:
# a = a.unsqueeze(0)

# if len(b.shape) == 1:
# b = b.unsqueeze(0)

# a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
# b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
# return torch.mm(a_norm, b_norm.transpose(0, 1))

# product = np.dot(a, b)
# norm = np.linalg.norm(a) * np.linalg.norm(b)
# return product / norm

sim = vec @ mat.T  # vector-matrix multiplication: raw dot products

# normalize by the query norm and each row norm to get cosine similarity
sim /= np.linalg.norm(vec) * np.linalg.norm(mat, axis=1)

return sim
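
A quick shape check on cos_sim with toy 4-dimensional inputs (the real vectors are 384-d):

import numpy as np

vec = np.array([1.0, 0.0, 0.0, 0.0])             # query, shape (4,)
mat = np.array([[1.0, 0.0, 0.0, 0.0],
                [0.0, 1.0, 0.0, 0.0]])           # corpus, shape (2, 4)

sim = vec @ mat.T                                # raw dot products, shape (2,)
sim /= np.linalg.norm(vec) * np.linalg.norm(mat, axis=1)
print(sim)                                       # -> [1. 0.]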
