Skip to content

Commit

Permalink
update sim type.
Browse files Browse the repository at this point in the history
  • Loading branch information
shibing624 committed Mar 11, 2022
1 parent 20e9de9 commit 85a043a
Show file tree
Hide file tree
Showing 8 changed files with 297 additions and 184 deletions.
9 changes: 4 additions & 5 deletions examples/benchmarking/benchmark_bm25.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,8 @@ def get_dbpedia():
#### Loading test queries and corpus in DBPedia
corpus, queries, qrels = SearchDataLoader(data_path).load(split="test")
corpus_ids, query_ids = list(corpus), list(queries)
print(len(corpus))
print(len(queries))
print(len(qrels))
logger.info(f"corpus: {len(corpus)}, queries: {len(queries)}")

#### Randomly sample 1M pairs from Original Corpus (4.63M pairs)
#### First include all relevant documents (i.e. present in qrels)
corpus_set = set()
Expand All @@ -66,7 +65,7 @@ def get_dbpedia():
#### Remove already seen k relevant documents and sample (1M - k) docs randomly
remaining_corpus = list(set(corpus_ids) - corpus_set)
sample = min(1000000 - len(corpus_set), len(remaining_corpus))
sample = 10
# sample = 10

for corpus_id in random.sample(remaining_corpus, sample):
corpus_new[corpus_id] = corpus[corpus_id]
Expand Down Expand Up @@ -110,4 +109,4 @@ def get_dbpedia():

#### Evaluate your retrieval using NDCG@k, MAP@K ...
ndcg, _map, recall, precision = evaluate(qrels, results)
print(ndcg, _map, recall, precision)
logger.info(f"MAP: {_map}")
18 changes: 9 additions & 9 deletions examples/benchmarking/benchmark_sbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,13 @@ def get_dbpedia():
#### Loading test queries and corpus in DBPedia
corpus, queries, qrels = SearchDataLoader(data_path).load(split="test")
corpus_ids, query_ids = list(corpus), list(queries)
print(len(corpus))
print(len(queries))
query_keys = list(queries.keys())[:10]
queries = {key: queries[key] for key in query_keys}
print(len(queries))
print(len(qrels))
logger.info(f"corpus: {len(corpus)}, queries: {len(queries)}")

# query_keys = list(queries.keys())[:10]
# queries = {key: queries[key] for key in query_keys}
# print(len(queries))
# print(len(qrels))

#### Randomly sample 1M pairs from Original Corpus (4.63M pairs)
#### First include all relevant documents (i.e. present in qrels)
corpus_set = set()
Expand All @@ -70,15 +71,14 @@ def get_dbpedia():
#### Remove already seen k relevant documents and sample (1M - k) docs randomly
remaining_corpus = list(set(corpus_ids) - corpus_set)
sample = min(1000000 - len(corpus_set), len(remaining_corpus))
sample = 10

for corpus_id in random.sample(remaining_corpus, sample):
corpus_new[corpus_id] = corpus[corpus_id]

corpus_docs = {corpus_id: corpus_new[corpus_id]['title'] + corpus_new[corpus_id]['text'] for corpus_id, corpus in
corpus_new.items()}
#### Index 1M passages into the index (separately)
model = Similarity(corpus=corpus_docs)
model = Similarity(corpus=corpus_docs, model_name_or_path="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
logger.debug(model)
#### Saving benchmark times with batch
# queries = [queries[query_id] for query_id in query_ids]
Expand All @@ -94,7 +94,7 @@ def get_dbpedia():

#### Evaluate your retrieval using NDCG@k, MAP@K ...
ndcg, _map, recall, precision = evaluate(qrels, results)
print(ndcg, _map, recall, precision)
logger.info(f"MAP: {_map}")

#### Measuring Index size consumed by document embeddings
corpus_embs = model.corpus_embeddings
Expand Down
35 changes: 29 additions & 6 deletions examples/image_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""
import sys
import glob
from PIL import Image

sys.path.append('..')
from similarities.imagesim import ImageHashSimilarity, SiftSimilarity, ClipSimilarity
Expand All @@ -13,29 +14,51 @@
def sim_and_search(m):
print(m)
# similarity
sim_scores = m.similarity(image_fps1, image_fps2)
sim_scores = m.similarity(imgs1, imgs2)
print('sim scores: ', sim_scores)
for (idx, i), j in zip(enumerate(image_fps1), image_fps2):
s = sim_scores[idx] if isinstance(sim_scores, list) else sim_scores[idx][idx]
print(f"{i} vs {j}, score: {s:.4f}")
# search
m.add_corpus(corpus)
queries = image_fps1
m.add_corpus(corpus_imgs)
queries = imgs1
res = m.most_similar(queries, topn=3)
print('sim search: ', res)
for q_id, c in res.items():
print('query:', queries[q_id])
print('query:', image_fps1[q_id])
print("search top 3:")
for corpus_id, s in c.items():
print(f'\t{m.corpus[corpus_id]}: {s:.4f}')
print(f'\t{m.corpus[corpus_id].filename}: {s:.4f}')
print('-' * 50 + '\n')


def clip_demo():
    """Demonstrate CLIP similarity scoring between images and text captions.

    Opens two sample images, scores them against two captions with
    ``ClipSimilarity.similarity`` and prints the score of each image against
    the caption at the same position (read from the diagonal of the result).
    """
    m = ClipSimilarity()
    print(m)
    # Image paths and captions are paired by position.
    image_fps = [
        'data/image3.png',  # yellow flower image
        'data/image1.png',  # tiger image
    ]
    texts = ['a yellow flower', 'a tiger']
    imgs = list(map(Image.open, image_fps))
    sim_scores = m.similarity(imgs, texts)
    print('sim scores: ', sim_scores)
    # Entry [idx][idx] is the score of image idx against caption idx.
    for idx, (i, j) in enumerate(zip(image_fps, texts)):
        s = sim_scores[idx][idx]
        print(f"{i} vs {j}, score: {s:.4f}")
    print('-' * 50 + '\n')


if __name__ == "__main__":
image_fps1 = ['data/image1.png', 'data/image3.png']
image_fps2 = ['data/image12-like-image1.png', 'data/image10.png']
corpus = glob.glob('data/*.jpg') + glob.glob('data/*.png')
imgs1 = [Image.open(i) for i in image_fps1]
imgs2 = [Image.open(i) for i in image_fps2]
corpus_fps = glob.glob('data/*.jpg') + glob.glob('data/*.png')
corpus_imgs = [Image.open(i) for i in corpus_fps]
# 1. image and text similarity
clip_demo()

# 2. image and image similarity score
sim_and_search(ClipSimilarity()) # the best result
sim_and_search(ImageHashSimilarity(hash_function='phash'))
sim_and_search(SiftSimilarity())
14 changes: 7 additions & 7 deletions similarities/clip_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,21 +91,22 @@ def save(self, output_path: str):
@staticmethod
def load(input_path: str):
return CLIPModel(model_name=input_path)
def _text_length(self, text: Union[List[int], List[List[int]]]):

def _text_length(self, text):
"""
Helper function to get the length of the input text. Text can be either
a list of ints (which means a single text as input), or a tuple of lists of ints
(representing several text inputs to the model).
"""

if isinstance(text, dict): #{key: value} case
if isinstance(text, dict): # {key: value} case
return len(next(iter(text.values())))
elif not hasattr(text, '__len__'): #Object has no len() method
elif not hasattr(text, '__len__'): # Object has no len() method
return 1
elif len(text) == 0 or isinstance(text[0], int): #Empty string or list of ints
elif len(text) == 0 or isinstance(text[0], int): # Empty string or list of ints
return len(text)
else:
return sum([len(t) for t in text]) #Sum of length of individual strings
return sum([len(t) for t in text]) # Sum of length of individual strings

@staticmethod
def batch_to_device(batch):
Expand All @@ -117,7 +118,6 @@ def batch_to_device(batch):
batch[key] = batch[key].to(device)
return batch


def encode(
self,
sentences: Union[str, List[str]],
Expand All @@ -127,7 +127,7 @@ def encode(
normalize_embeddings: bool = False
):
"""
Computes sentence embeddings
Computes sentence and image embeddings
:param sentences: the sentences to embed
:param batch_size: the batch size used for the computation
Expand Down
Loading

0 comments on commit 85a043a

Please sign in to comment.