Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/stanford-oval/suql into main
Browse files (browse the repository at this point in the history)
  • Loading branch information
george1459 committed May 31, 2024
2 parents 62c17a8 + d118e6c commit bae7e09
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# Package metadata
name = "suql"
- version = "1.1.7a5"
+ version = "1.1.7a6"
description = "Structured and Unstructured Query Language (SUQL) Python API"
author = "Shicheng Liu"
author_email = "[email protected]"
Expand Down
9 changes: 5 additions & 4 deletions src/suql/faiss_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,11 +285,12 @@ def dot_product(self, id_list, query, top, individual_id_list=[]):
for sublist in map(lambda x: self.id2document[x], individual_id_list)
for item in sublist
]
- embedding_indices = [
+ # remove potential duplicates here
+ embedding_indices = list(dict.fromkeys([
  item
  for sublist in map(lambda x: self.document2embedding[x], document_indices)
  for item in sublist
- ]
+ ]))

query_embedding = embed_query(query)

Expand All @@ -301,8 +302,8 @@ def dot_product(self, id_list, query, top, individual_id_list=[]):
params=faiss.SearchParametersIVF(sel=sel),
)
else:
- if top > self.embeddings.ntotal:
- top = self.embeddings.ntotal
+ if top > min(self.embeddings.ntotal, len(embedding_indices)):
+ top = min(self.embeddings.ntotal, len(embedding_indices))
D, I = self.embeddings.search(
query_embedding, top, params=faiss.SearchParametersIVF(sel=sel)
)
Expand Down
7 changes: 5 additions & 2 deletions src/suql/sql_free_text_support/execute_free_text_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,10 +723,13 @@ def _retrieve_and_verify(
enforce_ordering=True if node.sortClause is not None else False,
)
else:
- id_res = []
+ id_res = set()
  for each_res in parsed_result:
  if _verify_single_res(each_res, field_query_list, llm_model_name):
- id_res.append(each_res[0])
+ if isinstance(each_res[0], list):
+ id_res.update(each_res[0])
+ else:
+ id_res.add(each_res[0])

end_time = time.time()
logging.info("retrieve + verification time {}s".format(end_time - start_time))
Expand Down

0 comments on commit bae7e09

Please sign in to comment.