-
Notifications
You must be signed in to change notification settings - Fork 35
/
Copy pathtest_splade.py
51 lines (42 loc) · 1.95 KB
/
test_splade.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from pytest import raises
from pinecone_text.sparse import SpladeEncoder
class TestSplade:
    """Smoke tests for ``SpladeEncoder`` document and query encoding."""

    def test_splade_single_doc_inference(self):
        """One document yields equally long ``indices`` and ``values``."""
        encoder = SpladeEncoder()
        sparse_vector = encoder.encode_documents("This is a test")
        n_indices = len(sparse_vector["indices"])
        n_values = len(sparse_vector["values"])
        assert n_indices == n_values

    def test_splade_batch_inference(self):
        """A two-document batch yields two well-formed, overlapping vectors."""
        encoder = SpladeEncoder()
        batch = ["This is a test", "This is a test also"]
        vectors = encoder.encode_documents(batch)

        assert len(vectors) == 2
        for position in (0, 1):
            vector = vectors[position]
            assert len(vector["indices"]) == len(vector["values"])

        # Jaccard similarity of the two index sets: |A & B| / |A | B|.
        # The test expects the sets to overlap heavily without being identical.
        first_ids = set(vectors[0]["indices"])
        second_ids = set(vectors[1]["indices"])
        jaccard = len(first_ids & second_ids) / len(first_ids | second_ids)
        assert 0.5 < jaccard < 1.0

    def test_splade_single_query_inference(self):
        """One query yields equally long ``indices`` and ``values``."""
        encoder = SpladeEncoder()
        sparse_vector = encoder.encode_queries("This is a test")
        n_indices = len(sparse_vector["indices"])
        n_values = len(sparse_vector["values"])
        assert n_indices == n_values

    def test_splade_batch_query_inference(self):
        """A two-query batch yields two well-formed, overlapping vectors."""
        encoder = SpladeEncoder()
        batch = ["This is a test", "This is a test also"]
        vectors = encoder.encode_queries(batch)

        assert len(vectors) == 2
        for position in (0, 1):
            vector = vectors[position]
            assert len(vector["indices"]) == len(vector["values"])

        # Jaccard similarity of the two index sets: |A & B| / |A | B|.
        # The test expects the sets to overlap heavily without being identical.
        first_ids = set(vectors[0]["indices"])
        second_ids = set(vectors[1]["indices"])
        jaccard = len(first_ids & second_ids) / len(first_ids | second_ids)
        assert 0.5 < jaccard < 1.0

    def test_splade_init_invalid_max_seq_length(self):
        """Constructing with ``max_seq_length`` of 0 or 513 raises ``ValueError``."""
        # NOTE(review): presumably the accepted range is 1..512 -- confirm
        # against the SpladeEncoder constructor.
        for bad_length in (0, 513):
            with raises(ValueError):
                SpladeEncoder(max_seq_length=bad_length)