forked from facebookresearch/faiss
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_clone.py
88 lines (68 loc) · 2.68 KB
/
test_clone.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import faiss
import numpy as np
from faiss.contrib import datasets
# Module-level side effect: ask faiss to use 4 OpenMP threads. This runs once
# at import time, before any test in this file is executed.
faiss.omp_set_num_threads(4)
class TestClone(unittest.TestCase):
    """
    Test clone_index for various index combinations.

    Each test builds an index from an index_factory string, clones it with
    faiss.clone_index and checks that the clone returns exactly the same
    search results as the index it was cloned from.
    """

    def do_test_clone(self, factory, with_ids=False):
        """
        Verify that cloning works for a given index type.

        factory: index_factory string describing the index to build.
        with_ids: when True the database is inserted with add_with_ids
            (explicit int64 ids) instead of add; used by the IDMap tests.

        Fails the test if the clone has a different Python type than the
        original, or if distances / labels of a k-NN search differ.
        """
        # Several tests call this helper once per factory string. With a
        # runner that supports subtests, subTest tags a failure with the
        # offending factory and lets the remaining factories still run;
        # with other runners it is a no-op and the first failure propagates
        # exactly as before.
        with self.subTest(factory=factory):
            d = 32
            k = 5
            # Synthetic dataset of dimension d; the other sizes are passed
            # positionally (presumably train / database / query counts).
            ds = datasets.SyntheticDataset(d, 1000, 2000, 10)
            xb = ds.get_database()
            xq = ds.get_queries()

            index1 = faiss.index_factory(d, factory)
            index1.train(ds.get_train())
            if with_ids:
                # Offset the ids so they differ from the plain 0..nb-1
                # positions: a clone that dropped the id mapping would then
                # presumably return different labels and be caught below.
                ids = np.arange(ds.nb, dtype="int64") + 1000
                index1.add_with_ids(xb, ids)
            else:
                index1.add(xb)

            Dref1, Iref1 = index1.search(xq, k)

            index2 = faiss.clone_index(index1)
            self.assertEqual(type(index1), type(index2))

            # Drop the original before querying the clone, so the clone
            # cannot depend on this test keeping the original alive.
            del index1

            Dref2, Iref2 = index2.search(xq, k)
            np.testing.assert_array_equal(Dref1, Dref2)
            np.testing.assert_array_equal(Iref1, Iref2)

    def test_RFlat(self):
        self.do_test_clone("SQ4,RFlat")

    def test_Refine(self):
        self.do_test_clone("SQ4,Refine(SQ8)")

    def test_IVF(self):
        self.do_test_clone("IVF16,Flat")

    def test_PCA(self):
        self.do_test_clone("PCA8,Flat")

    def test_IDMap(self):
        self.do_test_clone("IVF16,Flat,IDMap", with_ids=True)

    def test_IDMap2(self):
        self.do_test_clone("IVF16,Flat,IDMap2", with_ids=True)

    def test_NSGPQ(self):
        # NOTE(review): despite the test name, the factory string below is
        # "NSG32,Flat", not a PQ variant. Kept as-is so the test id and the
        # index under test do not change.
        self.do_test_clone("NSG32,Flat")

    def test_IVFAdditiveQuantizer(self):
        for factory in (
            "IVF16,LSQ5x6_Nqint8",
            "IVF16,RQ5x6_Nqint8",
            "IVF16,PLSQ4x3x5_Nqint8",
            "IVF16,PRQ4x3x5_Nqint8",
        ):
            self.do_test_clone(factory)

    def test_IVFAdditiveQuantizerFastScan(self):
        for factory in (
            "IVF16,LSQ3x4fs_32_Nlsq2x4",
            "IVF16,RQ3x4fs_32_Nlsq2x4",
            "IVF16,PLSQ2x3x4fs_Nlsq2x4",
            "IVF16,PRQ2x3x4fs_Nrq2x4",
        ):
            self.do_test_clone(factory)

    def test_AdditiveQuantizer(self):
        for factory in (
            "LSQ5x6_Nqint8",
            "RQ5x6_Nqint8",
            "PLSQ4x3x5_Nqint8",
            "PRQ4x3x5_Nqint8",
        ):
            self.do_test_clone(factory)

    def test_AdditiveQuantizerFastScan(self):
        for factory in (
            "LSQ3x4fs_32_Nlsq2x4",
            "RQ3x4fs_32_Nlsq2x4",
            "PLSQ2x3x4fs_Nlsq2x4",
            "PRQ2x3x4fs_Nrq2x4",
        ):
            self.do_test_clone(factory)