-
Notifications
You must be signed in to change notification settings - Fork 0
/
semantic_search_hnswlib.py
163 lines (129 loc) · 6.02 KB
/
semantic_search_hnswlib.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
"""
This example uses Approximate Nearest Neighbor Search (ANN) with Hnswlib (https://github.com/nmslib/hnswlib/).
Searching a large corpus with millions of embeddings can be time-consuming. To speed this up,
ANN can index the existing vectors. For a new query vector, this index can be used to find the nearest neighbors.
ANN search is approximate, i.e., it might not find all of the true top-k nearest neighbors.
In this example, we use Hnswlib: it is a fast and easy-to-use library with excellent results on common benchmarks.
Usually you can install Hnswlib by running:
pip install hnswlib
For more details, see https://github.com/nmslib/hnswlib/
The original example uses the Quora Duplicate Questions dataset, which contains about 500k questions (only 100k are used):
https://www.quora.com/q/quoradata/First-Quora-Dataset-Release-Question-Pairs
This version instead reads up to 100k unique sentences from the "content" column of the local CSV file datasets/g4boyz100k.csv.
As the embedding model, we use the SBERT model 'quora-distilbert-multilingual',
which is aligned across 100 languages, i.e., you can type in a question in various languages and it will
return the closest questions in the corpus (questions in the corpus are mainly in English).
"""
# STL
import os
import csv
import time
import pickle
# PDM
import hnswlib
from sentence_transformers import SentenceTransformer, util
# --- Configuration ---------------------------------------------------------
model_name = "quora-distilbert-multilingual"
model = SentenceTransformer(model_name)

# Local CSV corpus; no download URL is configured, so the file must exist.
dataset_path = os.path.join("./", "datasets", "g4boyz100k.csv")
max_corpus_size = 100000

# The embedding cache is keyed by model name and corpus size so that changing
# either one produces a fresh cache instead of reusing stale vectors.
embedding_cache_path = "g4boyz-embeddings-{}-size-{}.pkl".format(
    model_name.replace("/", "_"), max_corpus_size
)

# Size of embeddings. Read it from the model so it cannot drift from
# model_name; fall back to 768 if the model does not report a dimension.
embedding_size = model.get_sentence_embedding_dimension() or 768
top_k_hits = 10  # Output k hits
# Check if embedding cache path exists
if not os.path.exists(embedding_cache_path):
# Check if the dataset exists. If not, download and extract
# Download dataset if needed
# if not os.path.exists(dataset_path):
# print("Download dataset")
# util.http_get(url, dataset_path)
# Get all unique sentences from the file
corpus_sentences = set()
with open(dataset_path, encoding="utf8") as fIn:
reader = csv.DictReader(fIn, delimiter=",", quoting=csv.QUOTE_MINIMAL)
for row in reader:
corpus_sentences.add(row["content"])
if len(corpus_sentences) >= max_corpus_size:
break
corpus_sentences = list(corpus_sentences)
print("Encode the corpus. This might take a while")
corpus_embeddings = model.encode(
corpus_sentences, show_progress_bar=True, convert_to_numpy=True
)
print("Store file on disc")
with open(embedding_cache_path, "wb") as fOut:
pickle.dump(
{"sentences": corpus_sentences, "embeddings": corpus_embeddings}, fOut
)
else:
print("Load pre-computed embeddings from disc")
with open(embedding_cache_path, "rb") as fIn:
cache_data = pickle.load(fIn)
corpus_sentences = cache_data["sentences"]
corpus_embeddings = cache_data["embeddings"]
# --- Search ------------------------------------------------------------------
# True: answer queries with the HNSW approximate index and report its recall
# against exact search. False: use exact brute-force search only.
use_ann = False

if not corpus_sentences:
    raise SystemExit("Corpus is empty; nothing to search.")

# Never ask for more neighbours than the corpus holds: hnswlib's knn_query
# raises when it cannot return k results.
k = min(top_k_hits, len(corpus_sentences))

if use_ann:
    # Name the index after the same model / corpus size as the embedding
    # cache so indexes built for different settings do not clash.
    index_path = "g4boyz-hnswlib-{}-size-{}.index".format(
        model_name.replace("/", "_"), max_corpus_size
    )
    index = None
    if os.path.exists(index_path):
        print("Loading index...")
        index = hnswlib.Index(space="cosine", dim=embedding_size)
        index.load_index(index_path)
        # An index whose size differs from the corpus was built from other
        # data and would map labels to the wrong sentences: rebuild it.
        if index.get_current_count() != len(corpus_embeddings):
            print("Index does not match the corpus, rebuilding")
            index = None
    if index is None:
        print("Start creating HNSWLIB index")
        # The "cosine" space reports distance = 1 - cosine similarity.
        index = hnswlib.Index(space="cosine", dim=embedding_size)
        index.init_index(
            max_elements=len(corpus_embeddings), ef_construction=400, M=64
        )
        # HNSW needs no training step: vectors are inserted into the graph
        # directly, labelled with their position in corpus_sentences.
        index.add_items(corpus_embeddings, list(range(len(corpus_embeddings))))
        print("Saving index to:", index_path)
        index.save_index(index_path)
    # ef controls the recall/speed trade-off at query time; it must be >= k.
    index.set_ef(max(50, k))

print("Corpus loaded with {} sentences / embeddings".format(len(corpus_sentences)))
while True:
    try:
        inp_question = input("Please enter a question: ")
    except (EOFError, KeyboardInterrupt):
        # End of input (Ctrl-D) or Ctrl-C: leave the loop cleanly.
        print()
        break
    if not inp_question.strip():
        continue

    start_time = time.time()
    question_embedding = model.encode(inp_question)

    if use_ann:
        corpus_ids, distances = index.knn_query(question_embedding, k=k)
        # Convert numpy labels/distances to plain Python numbers so they
        # compare cleanly with the ids returned by util.semantic_search.
        hits = [
            {"corpus_id": int(corpus_id), "score": 1 - float(distance)}
            for corpus_id, distance in zip(corpus_ids[0], distances[0])
        ]
        hits.sort(key=lambda hit: hit["score"], reverse=True)
        end_time = time.time()
        # Exact results, used below only to measure the ANN recall.
        correct_hits = util.semantic_search(
            question_embedding, corpus_embeddings, top_k=k
        )[0]
    else:
        correct_hits = util.semantic_search(
            question_embedding, corpus_embeddings, top_k=k
        )[0]
        end_time = time.time()
        hits = correct_hits

    print("Input question:", inp_question)
    print("Results (after {:.3f} seconds):".format(end_time - start_time))
    for hit in hits:
        print("\t{:.3f}\t{}".format(hit["score"], corpus_sentences[hit["corpus_id"]]))

    if use_ann:
        # ANN is not exact and may miss entries with high cosine similarity;
        # compare against the exact top-k to report its recall.
        correct_hits_ids = {hit["corpus_id"] for hit in correct_hits}
        ann_corpus_ids = {hit["corpus_id"] for hit in hits}
        if len(ann_corpus_ids) != len(correct_hits_ids):
            print(
                "Approximate Nearest Neighbor returned a different number of results than expected"
            )
        recall = (
            len(ann_corpus_ids & correct_hits_ids) / len(correct_hits_ids)
            if correct_hits_ids
            else 1.0
        )
        print("\nApproximate Nearest Neighbor Recall@{}: {:.2f}".format(k, recall * 100))
        if recall < 1:
            print("Missing results:")
            for hit in correct_hits:
                if hit["corpus_id"] not in ann_corpus_ids:
                    print(
                        "\t{:.3f}\t{}".format(
                            hit["score"], corpus_sentences[hit["corpus_id"]]
                        )
                    )

    print("\n\n========\n")